132 lines
4.4 KiB
Python
132 lines
4.4 KiB
Python
from fastapi import FastAPI, HTTPException
|
||
from pydantic import BaseModel, Field, validator
|
||
from typing import List, Optional
|
||
import pandas as pd
|
||
import numpy as np
|
||
import joblib
|
||
import re
|
||
from rich.console import Console
|
||
|
||
# Shared rich console used for the service's log output.
console = Console()

# Maximum length (in characters) of the text passed to the model.
MAX_TEXT_LEN = 10_000
|
||
|
||
# Models used for request validation and for catching malformed input
|
||
class Item(BaseModel):
    """A single line item of a transaction.

    Both fields are optional in the payload: ``name`` defaults to an empty
    string and ``price`` to ``0.0``.
    """

    name: str = ""
    price: float = 0.0

    @validator("price", pre=True)
    def fix_negative_price(cls, v):
        """Clamp the raw price to be non-negative; ``None`` becomes 0."""
        if v is None:
            return 0
        # Runs before pydantic's own float parsing (pre=True), so convert
        # explicitly before comparing against zero.
        return max(0, float(v))
|
||
|
||
class Transaction(BaseModel):
    """Payload describing one transaction to classify.

    Only ``transaction_id`` is required. The textual fields fall back to
    ``"unknown"``, ``amount`` to ``0.0`` and ``items`` to an empty list.
    """

    transaction_id: str
    terminal_name: Optional[str] = "unknown"
    terminal_description: Optional[str] = "unknown"
    city: Optional[str] = "unknown"
    amount: float = 0.0
    items: Optional[List[Item]] = []

    @validator("amount", pre=True)
    def fix_negative_amount(cls, v):
        """Clamp the raw amount to be non-negative; ``None`` becomes 0."""
        if v is None:
            return 0
        # Runs before pydantic's own float parsing (pre=True), so convert
        # explicitly before comparing against zero.
        return max(0, float(v))
|
||
|
||
class BatchRequest(BaseModel):
    """Request body of the batch endpoint: a list of transactions."""

    transactions: List[Transaction]
|
||
|
||
|
||
app = FastAPI()

# Load the serialized model once at import time. A failure is logged rather
# than raised, and MODEL_READY records whether `model` is usable; the
# endpoints consult that flag before touching the model.
MODEL_READY = False
try:
    model = joblib.load("model/mcc_model.pkl")
    console.log("[green]Успешно загрузили модель[/green]")
    MODEL_READY = True
except Exception as load_error:
    console.log(f"[red]Ошибка загрузки модели: {load_error}[/red]")
|
||
|
||
# Preprocess transactions (with their custom types) defensively so that prediction does not crash
|
||
def clean_text(text: Optional[str], max_len: Optional[int] = None) -> str:
    """Normalize free-form text before it is passed to the model.

    Every character other than word characters, whitespace and ``. , ' -``
    is replaced with a space; runs of whitespace are then collapsed to a
    single space, the result is stripped and finally truncated.

    Args:
        text: Value to clean. Non-string values are converted with ``str()``;
            ``None`` is treated as empty text.
        max_len: Maximum length of the returned string. When omitted, the
            module-level ``MAX_TEXT_LEN`` is used.

    Returns:
        The cleaned and truncated string.
    """
    if text is None:
        # str(None) would otherwise leak the literal word "None" into the
        # model input for fields that were explicitly sent as null.
        return ""
    limit = MAX_TEXT_LEN if max_len is None else max_len
    cleaned = re.sub(r"[^\w\s\.,'-]", " ", str(text))
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    return cleaned[:limit]
|
||
|
||
def preprocess_transaction(tx: Transaction) -> pd.DataFrame:
    """Turn one transaction into the single-row DataFrame fed to the model.

    Args:
        tx: Validated transaction. ``tx.items`` may be ``None`` or empty; both
            are treated as "no items".

    Returns:
        A one-row DataFrame with the columns ``full_text``, ``amount``,
        ``items_text``, ``items_count``, ``items_total_price``,
        ``items_max_price`` and ``items_min_price``.
    """
    # `items` is Optional, so an explicit null in the payload arrives here as
    # None. Normalize once so every aggregate below works on a real list.
    items = tx.items or []

    terminal_name = clean_text(tx.terminal_name)
    terminal_desc = clean_text(tx.terminal_description)
    city = clean_text(tx.city)

    items_text = " ".join(clean_text(item.name) for item in items)

    combined_text = (
        f"{terminal_name} {terminal_desc} {city} items {items_text}"
        .lower()[:MAX_TEXT_LEN]
    )

    prices = [item.price for item in items]
    return pd.DataFrame([{
        "full_text": combined_text,
        "amount": tx.amount,
        "items_text": items_text,
        "items_count": len(items),
        "items_total_price": sum(prices),
        # default=0 keeps the aggregates defined when there are no items.
        "items_max_price": max(prices, default=0),
        "items_min_price": min(prices, default=0),
    }])
|
||
|
||
|
||
@app.get("/health")
def health():
    """Health probe: report ``ok`` when the model is loaded, else HTTP 503."""
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")
    return {"status": "ok"}
|
||
|
||
@app.get("/model/info")
def info():
    """Return the static name and version of the served model."""
    metadata = {
        "model_name": "mcc_classifier",
        "model_version": "1.0.67",
    }
    return metadata
|
||
|
||
@app.post("/predict")
def predict(tx: Transaction):
    """Predict the MCC for a single transaction.

    Returns a dict with ``transaction_id``, ``predicted_mcc`` and
    ``confidence`` (the largest predicted probability, rounded to 4 places).

    Raises:
        HTTPException: 503 when the model is not loaded; 400 with the error
            text when preprocessing or prediction fails.
    """
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")

    tx_id = tx.transaction_id
    try:
        features = preprocess_transaction(tx)
        label = model.predict(features)[0]
        confidence = model.predict_proba(features).max()
        console.log(f"[yellow]Предсказание для транзакции с id {tx_id}: {label} (conf={confidence:.4f})[/yellow]")
        response = {
            "transaction_id": tx_id,
            "predicted_mcc": int(label),
            "confidence": round(float(confidence), 4),
        }
    except Exception as exc:
        console.log(f"[red]Ошибка предсказания для id {tx_id}: {exc}[/red]")
        raise HTTPException(status_code=400, detail=str(exc))
    return response
|
||
|
||
@app.post("/predict/batch")
def predict_batch(batch: BatchRequest):
    """Predict the MCC for every transaction in a batch.

    Transactions are processed one by one. A failure on one transaction does
    not abort the batch: it is logged and reported with ``predicted_mcc`` set
    to ``None`` and ``confidence`` set to ``0.0``.

    Returns:
        ``{"predictions": [...]}`` with one entry per processed transaction.

    Raises:
        HTTPException: 503 when the model is not loaded.
    """
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")

    predictions = []
    for tx in batch.transactions:
        tx_id = tx.transaction_id
        try:
            features = preprocess_transaction(tx)
            label = model.predict(features)[0]
            confidence = model.predict_proba(features).max()
            entry = {
                "transaction_id": tx_id,
                "predicted_mcc": int(label),
                "confidence": round(float(confidence), 4),
            }
            predictions.append(entry)
            console.log(f"[yellow]Предсказание для id {tx_id}: {label} (conf={confidence:.4f})[/yellow]")
        except Exception as exc:
            console.log(f"[red]Ошибка предсказания для id {tx_id}: {exc}[/red]")
            failed = {
                "transaction_id": tx_id,
                "predicted_mcc": None,
                "confidence": 0.0,
            }
            predictions.append(failed)
    return {"predictions": predictions}
|