mlops/solution/app.py

132 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
from typing import List, Optional

import joblib
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, validator
from rich.console import Console
from rich.markup import escape
# Shared Rich console used for all log output in this module.
console = Console()
MAX_TEXT_LEN = 10000  # maximum text length passed to the model
# Models used for request validation and for catching malformed input.
class Item(BaseModel):
    """A single line item of a transaction.

    Attributes:
        name: Item name; defaults to an empty string.
        price: Item price; defaults to 0.0 and is clamped to be non-negative.
    """

    name: str = ""
    price: float = 0.0

    @validator("price", pre=True)
    def fix_negative_price(cls, v):
        """Clamp the raw price to zero or above; a missing (None) price becomes 0."""
        if v is None:
            return 0
        # Runs before type coercion (pre=True), so convert explicitly.
        return max(0, float(v))
class Transaction(BaseModel):
    """A single transaction submitted for MCC prediction.

    Attributes:
        transaction_id: Required identifier, echoed back in the response.
        terminal_name: Terminal name; "unknown" when omitted.
        terminal_description: Terminal description; "unknown" when omitted.
        city: City; "unknown" when omitted.
        amount: Transaction amount; defaults to 0.0 and is clamped to be
            non-negative.
        items: Purchased items; defaults to an empty list.
    """

    transaction_id: str
    terminal_name: Optional[str] = "unknown"
    terminal_description: Optional[str] = "unknown"
    city: Optional[str] = "unknown"
    amount: float = 0.0
    items: Optional[List[Item]] = []

    @validator("amount", pre=True)
    def fix_negative_amount(cls, v):
        """Clamp the raw amount to zero or above; a missing (None) amount becomes 0."""
        if v is None:
            return 0
        # Runs before type coercion (pre=True), so convert explicitly.
        return max(0, float(v))
class BatchRequest(BaseModel):
    """Request body for batch prediction: a list of transactions."""

    transactions: List[Transaction]
app = FastAPI()

# Load the classifier once at import time. MODEL_READY records whether the
# load succeeded; the endpoints check it before touching `model`.
# NOTE(review): joblib.load unpickles the file, so the artifact must come
# from a trusted source.
MODEL_READY = False
try:
    model = joblib.load("model/mcc_model.pkl")
    console.log("[green]Успешно загрузили модель[/green]")
    MODEL_READY = True
except Exception as exc:
    # Keep the service importable without a model; the failure is only logged.
    console.log(f"[red]Ошибка загрузки модели: {exc}[/red]")
# Preprocess transactions defensively (values of any type are stringified)
# so that preprocessing does not crash.
def clean_text(text: str) -> str:
    """Normalize free text before it is fed to the model.

    The value is converted with ``str()``, every character that is not a word
    character, whitespace, or one of ``. , ' -`` is replaced by a space, runs
    of whitespace are collapsed to a single space, the result is stripped and
    finally truncated to ``MAX_TEXT_LEN`` characters.

    Args:
        text: Value to clean; non-string values are stringified first.

    Returns:
        The cleaned, truncated string.
    """
    raw = str(text)
    without_symbols = re.sub(r"[^\w\s\.,'-]", " ", raw)
    collapsed = re.sub(r"\s+", " ", without_symbols)
    return collapsed.strip()[:MAX_TEXT_LEN]
def preprocess_transaction(tx: Transaction) -> pd.DataFrame:
    """Turn one transaction into the single-row feature frame the model expects.

    Args:
        tx: Validated transaction. ``tx.items`` may be ``None`` or empty; both
            are treated as "no items".

    Returns:
        A one-row DataFrame with the columns ``full_text``, ``amount``,
        ``items_text``, ``items_count``, ``items_total_price``,
        ``items_max_price`` and ``items_min_price``.
    """
    # `items` is Optional, so an explicit null arrives as None. Normalize it
    # once so that len()/min()/max() below cannot fail with TypeError.
    items = tx.items or []
    prices = [item.price for item in items]

    terminal_name = clean_text(tx.terminal_name)
    terminal_desc = clean_text(tx.terminal_description)
    city = clean_text(tx.city)
    items_text = " ".join(clean_text(item.name) for item in items)

    # The literal word "items" separates terminal/city text from item names.
    combined_text = (
        f"{terminal_name} {terminal_desc} {city} items {items_text}"
        .lower()[:MAX_TEXT_LEN]
    )

    return pd.DataFrame([{
        "full_text": combined_text,
        "amount": tx.amount,
        "items_text": items_text,
        "items_count": len(items),
        "items_total_price": sum(prices),
        "items_max_price": max(prices, default=0),
        "items_min_price": min(prices, default=0),
    }])
@app.get("/health")
def health():
    """Health probe: report OK only when the model was loaded.

    Raises:
        HTTPException: 503 when the model is not available.
    """
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")
    return {"status": "ok"}
@app.get("/model/info")
def info():
    """Return the static name and version of the served model."""
    metadata = {
        "model_name": "mcc_classifier",
        "model_version": "1.0.67",
    }
    return metadata
@app.post("/predict")
def predict(tx: Transaction):
    """Predict the MCC for a single transaction.

    Args:
        tx: Validated transaction from the request body.

    Returns:
        A dict with ``transaction_id``, ``predicted_mcc`` (int) and
        ``confidence`` (highest class probability, rounded to 4 decimals).

    Raises:
        HTTPException: 503 when the model is not loaded; 400 when
            preprocessing or inference fails (detail carries the error text).
    """
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")

    # transaction_id is client-supplied. Escape it so that text such as
    # "[/red]" is logged literally instead of being parsed as Rich markup,
    # which raises MarkupError on unbalanced tags.
    safe_id = escape(tx.transaction_id)

    try:
        df = preprocess_transaction(tx)
        pred = model.predict(df)[0]
        conf = model.predict_proba(df).max()
        response = {
            "transaction_id": tx.transaction_id,
            "predicted_mcc": int(pred),
            "confidence": round(float(conf), 4)
        }
    except Exception as e:
        console.log(f"[red]Ошибка предсказания для id {safe_id}: {escape(str(e))}[/red]")
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Logged outside the try so that a logging problem is never reported to
    # the client as a prediction error.
    console.log(f"[yellow]Предсказание для транзакции с id {safe_id}: {pred} (conf={conf:.4f})[/yellow]")
    return response
@app.post("/predict/batch")
def predict_batch(batch: BatchRequest):
    """Predict the MCC for every transaction in a batch.

    A failure on one transaction does not abort the batch: that transaction
    gets ``predicted_mcc=None`` and ``confidence=0.0``. Exactly one result is
    produced per input transaction, in input order.

    Args:
        batch: Request body holding the list of transactions.

    Returns:
        ``{"predictions": [...]}`` where each entry has ``transaction_id``,
        ``predicted_mcc`` and ``confidence``.

    Raises:
        HTTPException: 503 when the model is not loaded.
    """
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")

    results = []
    for tx in batch.transactions:
        # transaction_id is client-supplied; escape it so bracketed text is
        # not interpreted as Rich markup (unbalanced tags raise MarkupError).
        safe_id = escape(tx.transaction_id)
        try:
            df = preprocess_transaction(tx)
            pred = model.predict(df)[0]
            conf = model.predict_proba(df).max()
            entry = {
                "transaction_id": tx.transaction_id,
                "predicted_mcc": int(pred),
                "confidence": round(float(conf), 4)
            }
        except Exception as e:
            console.log(f"[red]Ошибка предсказания для id {safe_id}: {escape(str(e))}[/red]")
            entry = {
                "transaction_id": tx.transaction_id,
                "predicted_mcc": None,
                "confidence": 0.0
            }
        else:
            # Success is logged outside the try body so that a logging
            # failure cannot add a second, failed entry for this transaction.
            console.log(f"[yellow]Предсказание для id {safe_id}: {pred} (conf={conf:.4f})[/yellow]")
        results.append(entry)

    return {"predictions": results}