132 lines
4.4 KiB
Python
132 lines
4.4 KiB
Python
|
|
from fastapi import FastAPI, HTTPException
|
|||
|
|
from pydantic import BaseModel, Field, validator
|
|||
|
|
from typing import List, Optional
|
|||
|
|
import pandas as pd
|
|||
|
|
import numpy as np
|
|||
|
|
import joblib
|
|||
|
|
import re
|
|||
|
|
from rich.console import Console
|
|||
|
|
|
|||
|
|
# Shared rich console used for log output throughout this module.
console = Console()
|
|||
|
|
|
|||
|
|
MAX_TEXT_LEN = 10_000  # maximum length of text passed to the model
|
|||
|
|
|
|||
|
|
# Models for validating input and catching errors
|
|||
|
|
class Item(BaseModel):
    """A single line item attached to a transaction.

    Both fields have defaults, so an item may omit either of them:
    ``name`` falls back to an empty string and ``price`` to ``0.0``.
    """

    name: str = ""
    price: float = 0.0

    @validator("price", pre=True)
    def fix_negative_price(cls, v):
        """Clamp the raw price to be non-negative; ``None`` is treated as zero.

        Runs before pydantic's own type coercion (``pre=True``), so ``v`` is
        whatever value arrived in the payload.
        """
        if v is None:
            return 0
        return max(0, float(v))
|
|||
|
|
|
|||
|
|
class Transaction(BaseModel):
    """A transaction submitted for MCC prediction.

    Only ``transaction_id`` is required. The textual fields default to
    ``"unknown"``, ``amount`` to ``0.0`` and ``items`` to an empty list.
    """

    transaction_id: str
    # Optional text fields: omitted values become "unknown"; the Optional
    # annotation also lets an explicit null through as None.
    terminal_name: Optional[str] = "unknown"
    terminal_description: Optional[str] = "unknown"
    city: Optional[str] = "unknown"
    amount: float = 0.0
    items: Optional[List[Item]] = []

    @validator("amount", pre=True)
    def fix_negative_amount(cls, v):
        """Clamp the raw amount to be non-negative; ``None`` is treated as zero.

        Runs before pydantic's own type coercion (``pre=True``), so ``v`` is
        whatever value arrived in the payload.
        """
        if v is None:
            return 0
        return max(0, float(v))
|
|||
|
|
|
|||
|
|
class BatchRequest(BaseModel):
    """Request body carrying a list of transactions to classify together."""

    transactions: List[Transaction]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# FastAPI application instance; the route handlers in this module register on it.
app = FastAPI()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Load the trained classifier once at import time. MODEL_READY records whether
# loading succeeded, so a failed load is logged instead of stopping start-up.
# NOTE(security): joblib.load unpickles the file and can execute arbitrary
# code, so the model path must only ever point at a trusted artifact.
model = None  # keep the name bound even when loading fails
try:
    model = joblib.load("model/mcc_model.pkl")
    console.log("[green]Успешно загрузили модель[/green]")
    MODEL_READY = True
except Exception as e:
    console.log(f"[red]Ошибка загрузки модели: {e}[/red]")
    MODEL_READY = False
|
|||
|
|
|
|||
|
|
# Preprocess transactions that carry custom types so that processing does not crash
|
|||
|
|
# Characters other than word characters, whitespace and . , ' - are replaced
# by a space; runs of whitespace are then collapsed to a single space.
_DISALLOWED_CHARS_RE = re.compile(r"[^\w\s\.,'-]")
_WHITESPACE_RUN_RE = re.compile(r"\s+")


def clean_text(text: Optional[str]) -> str:
    """Normalise free text before it is passed to the model.

    Args:
        text: Raw value. Non-string values are converted with ``str()``;
            ``None`` is treated as missing text.

    Returns:
        The text with disallowed characters replaced by spaces, whitespace
        collapsed and trimmed, cut to at most ``MAX_TEXT_LEN`` characters.
        ``None`` yields an empty string.
    """
    # Without this guard str(None) would inject the literal word "None".
    if text is None:
        return ""
    cleaned = _DISALLOWED_CHARS_RE.sub(" ", str(text))
    cleaned = _WHITESPACE_RUN_RE.sub(" ", cleaned).strip()
    return cleaned[:MAX_TEXT_LEN]
|
|||
|
|
|
|||
|
|
def preprocess_transaction(tx: Transaction) -> pd.DataFrame:
    """Build the single-row feature frame for one transaction.

    Args:
        tx: The validated transaction. Its optional fields may be ``None``
            when the payload carried an explicit null.

    Returns:
        A one-row DataFrame with the columns ``full_text``, ``amount``,
        ``items_text``, ``items_count``, ``items_total_price``,
        ``items_max_price`` and ``items_min_price``.
    """
    # An explicit null passes validation for Optional fields. Treat it like an
    # omitted field: an empty item list, and the same "unknown" placeholder
    # that the Transaction model declares as the default for text fields.
    items = tx.items or []
    terminal_name = clean_text("unknown" if tx.terminal_name is None else tx.terminal_name)
    terminal_desc = clean_text("unknown" if tx.terminal_description is None else tx.terminal_description)
    city = clean_text("unknown" if tx.city is None else tx.city)

    items_text = " ".join(clean_text(item.name) for item in items)
    prices = [item.price for item in items]

    combined_text = f"{terminal_name} {terminal_desc} {city} items {items_text}".lower()[:MAX_TEXT_LEN]

    return pd.DataFrame([{
        "full_text": combined_text,
        "amount": tx.amount,
        "items_text": items_text,
        "items_count": len(items),
        # With no items, the sum is 0 and max/min fall back to 0.
        "items_total_price": sum(prices),
        "items_max_price": max(prices, default=0),
        "items_min_price": min(prices, default=0),
    }])
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/health")
def health():
    """Report service health.

    Returns ``{"status": "ok"}`` when the model was loaded.

    Raises:
        HTTPException: 503 when the model is not available.
    """
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")
    return {"status": "ok"}
|
|||
|
|
|
|||
|
|
@app.get("/model/info")
def info():
    """Return the hard-coded name and version of the served model."""
    model_metadata = {
        "model_name": "mcc_classifier",
        "model_version": "1.0.67",
    }
    return model_metadata
|
|||
|
|
|
|||
|
|
@app.post("/predict")
def predict(tx: Transaction):
    """Classify a single transaction.

    Returns:
        A dict with ``transaction_id``, ``predicted_mcc`` (as an int) and
        ``confidence`` (the highest class probability, rounded to 4 places).

    Raises:
        HTTPException: 503 when the model is not loaded; 400 when
            preprocessing or inference raises, with the error text as detail.
    """
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")

    try:
        features = preprocess_transaction(tx)
        label = model.predict(features)[0]
        confidence = model.predict_proba(features).max()
        console.log(f"[yellow]Предсказание для транзакции с id {tx.transaction_id}: {label} (conf={confidence:.4f})[/yellow]")
        response = {
            "transaction_id": tx.transaction_id,
            "predicted_mcc": int(label),
            "confidence": round(float(confidence), 4),
        }
        return response
    except Exception as exc:
        # Every failure in the block above is reported to the client as a 400.
        console.log(f"[red]Ошибка предсказания для id {tx.transaction_id}: {exc}[/red]")
        raise HTTPException(status_code=400, detail=str(exc))
|
|||
|
|
|
|||
|
|
@app.post("/predict/batch")
def predict_batch(batch: BatchRequest):
    """Classify every transaction in a batch, one at a time.

    A failure on one transaction does not abort the batch: that transaction
    is reported with ``predicted_mcc`` set to ``None`` and ``confidence`` 0.0.

    Returns:
        ``{"predictions": [...]}`` with one entry per processed transaction,
        in request order.

    Raises:
        HTTPException: 503 when the model is not loaded.
    """
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")

    predictions = []
    for tx in batch.transactions:
        try:
            features = preprocess_transaction(tx)
            label = model.predict(features)[0]
            confidence = model.predict_proba(features).max()
            entry = {
                "transaction_id": tx.transaction_id,
                "predicted_mcc": int(label),
                "confidence": round(float(confidence), 4),
            }
            predictions.append(entry)
            console.log(f"[yellow]Предсказание для id {tx.transaction_id}: {label} (conf={confidence:.4f})[/yellow]")
        except Exception as exc:
            console.log(f"[red]Ошибка предсказания для id {tx.transaction_id}: {exc}[/red]")
            failed_entry = {
                "transaction_id": tx.transaction_id,
                "predicted_mcc": None,
                "confidence": 0.0,
            }
            predictions.append(failed_entry)
    return {"predictions": predictions}
|