app v1
parent 694975742c
commit d9abed5c15

dih.py (125 lines deleted)
@@ -1,125 +0,0 @@
import time
import joblib
import pandas as pd
import numpy as np
import random

from sklearn.base import BaseEstimator, TransformerMixin


# ===== CRITICALLY IMPORTANT =====
# These classes MUST be defined before joblib.load() is called.

class TextExtractor(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X['full_text'].fillna('')


class NumberExtractor(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[['amount']].fillna(0)


# ============================
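
[Editor's note] Why these stubs are mandatory: pickle serializes custom classes by reference (module path plus class name), not by value, so joblib.load() has to re-import TextExtractor and NumberExtractor in the loading process. A minimal sketch of the failure mode if they were missing (assuming the pipeline was pickled from a script where the classes lived in __main__):

import joblib

try:
    model = joblib.load("mcc_model.pkl")
except AttributeError as e:
    # typical error when the transformer classes are not importable:
    # AttributeError: Can't get attribute 'TextExtractor' on <module '__main__'>
    print(f"Load failed: {e}")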

MODEL_PATH = "mcc_model.pkl"

BASE_TX = {
    "transaction_id": "TX00001116",
    "terminal_name": "STORE001",
    "terminal_description": "common common common thing",
    "city": "NYC",
    "amount": 272.80,
    "items": [
        {"name": "basic loyalty", "price": 58.20},
        {"name": "Bringiong item lifes", "price": 28.99},
        {"name": "regular item basic item", "price": 56.91}
    ]
}


def json_to_df(data_list):
    rows = []
    for d in data_list:
        items = d.get("items", [])
        item_text = " ".join(str(i.get("name", "")) for i in items)
        full_text = f"{d.get('terminal_name', '')} {d.get('terminal_description', '')} {item_text}".lower()

        rows.append({
            "full_text": full_text,
            "amount": float(d.get("amount", 0))
        })

    return pd.DataFrame(rows)
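
[Editor's note] As a sanity check, one BASE_TX flattens into exactly the two columns the pickled pipeline consumes (runnable as-is next to the definitions above):

df = json_to_df([BASE_TX])
print(df.columns.tolist())   # ['full_text', 'amount']
print(df.loc[0, "amount"])   # 272.8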

def make_batch(n=100):
    batch = []
    for i in range(n):
        tx = BASE_TX.copy()
        tx["transaction_id"] = f"TX{i:06d}"
        tx["amount"] = round(random.uniform(3, 500), 2)
        batch.append(tx)
    return batch


# ============================
# BENCH
# ============================

print("\n[1] LOADING MODEL (cold start)...")
t0 = time.perf_counter()
model = joblib.load(MODEL_PATH)
t1 = time.perf_counter()
print(f"Cold load time: {t1 - t0:.4f} sec")

# warmup
print("\n[2] WARMUP...")
warm_df = json_to_df([BASE_TX])
for _ in range(5):
    model.predict(warm_df)

# single inference benchmark
print("\n[3] SINGLE REQUEST BENCH (1000 runs)...")
times = []

for _ in range(1000):
    df = json_to_df([BASE_TX])
    t0 = time.perf_counter()
    model.predict_proba(df)
    times.append(time.perf_counter() - t0)

times = np.array(times)

print(f"avg : {times.mean()*1000:.2f} ms")
print(f"p95 : {np.percentile(times, 95)*1000:.2f} ms")
print(f"max : {times.max()*1000:.2f} ms")

# batch benchmark
print("\n[4] BATCH 100 BENCH...")
batch = make_batch(100)
df_batch = json_to_df(batch)

t0 = time.perf_counter()
model.predict_proba(df_batch)
dt = time.perf_counter() - t0

print(f"Batch 100 time: {dt:.3f} sec")

# verdict
print("\n[5] VERDICT")
if np.percentile(times, 95) < 0.2:
    print("✅ /predict meets the latency budget (<200ms p95)")
else:
    print("❌ /predict does NOT meet the latency budget")

if dt < 5:
    print("✅ /predict/batch passes (<5s)")
else:
    print("❌ /predict/batch does NOT pass")

mcc_model.pkl
Binary file not shown.
prepare_data.py (100 lines deleted)

@@ -1,100 +0,0 @@
import pandas as pd
import re

DATA_DIR = "data"

# ---------- Text cleaning ----------
def clean_text(s: str, max_len=1000):
    if not isinstance(s, str):
        return ""

    s = s.lower()

    # --- strip a leading t" or t' prefix ---
    s = re.sub(r'^t["\']', '', s)

    # --- normalize quotes ---
    s = s.replace('"', ' ').replace("'", " ").replace('`', ' ')

    # drop non-ASCII (emoji, CJK characters, etc.)
    s = re.sub(r'[^\x00-\x7F]+', ' ', s)

    # keep only letters and digits
    s = re.sub(r'[^a-z0-9\s]', ' ', s)

    # collapse consecutive duplicate words
    words = s.split()
    dedup = []
    prev = None
    for w in words:
        if w != prev:
            dedup.append(w)
        prev = w
    s = " ".join(dedup)

    # normalize whitespace and truncate
    s = re.sub(r'\s+', ' ', s).strip()
    return s[:max_len]
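
[Editor's note] To make the cleaning steps concrete, here is what the function does to a deliberately noisy, made-up terminal string (illustrative input, not taken from the dataset):

print(clean_text('t"Кафе CAFÉ   cafe cafe №1!!!'))
# -> 'caf cafe 1'
# the t" prefix is stripped, non-ASCII characters drop out, punctuation goes,
# and the consecutive duplicate 'cafe' collapses to one token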

# ---------- LOAD ----------
transactions = pd.read_csv(f"{DATA_DIR}/transactions.csv")
terminals = pd.read_csv(f"{DATA_DIR}/terminals.csv")
receipts = pd.read_csv(f"{DATA_DIR}/receipts.csv")

# ---------- CLEAN TEXT ----------
for col in ["terminal_name", "terminal_description", "terminal_city"]:
    if col == "terminal_city":
        terminals[col] = terminals[col].astype(str).apply(clean_text)
    else:
        terminals[col] = terminals[col].apply(clean_text)

receipts["item_name"] = receipts["item_name"].apply(clean_text)

# ---------- AGGREGATE RECEIPTS ----------
receipt_agg = receipts.groupby("transaction_id").agg(
    items_text=("item_name", lambda x: " ".join(x)),
    items_count=("item_name", "count"),
    items_total_price=("item_price", "sum"),
    items_max_price=("item_price", "max"),
    items_min_price=("item_price", "min"),
).reset_index()

# ---------- MERGE WITH TRANSACTIONS ----------
df = transactions[["transaction_id", "terminal_id", "amount", "true_mcc"]].merge(
    terminals[["terminal_id", "terminal_name", "terminal_description", "terminal_city"]],
    on="terminal_id",
    how="left"
)

df = df.merge(receipt_agg, on="transaction_id", how="left")

# ---------- FILL NA ----------
for col in ["items_text", "terminal_name", "terminal_description", "terminal_city"]:
    df[col] = df[col].fillna("")

for col in ["items_count", "items_total_price", "items_max_price", "items_min_price"]:
    df[col] = df[col].fillna(0)

# ---------- BUILD FINAL TEXT ----------
df["text"] = (
    df["terminal_name"] + " " +
    df["terminal_description"] + " " +
    df["terminal_city"] + " " +
    " items " + df["items_text"] + " items " +
    df["items_text"]
)

df["text"] = df["text"].apply(clean_text)

# ---------- FINAL CHECK ----------
print("rows:", len(df))
print("unique tx:", df["transaction_id"].nunique())
print(df["true_mcc"].value_counts())

assert len(df) == df["transaction_id"].nunique()
assert df["text"].str.len().min() > 0

# ---------- SAVE ----------
df.to_csv("train.csv", index=False)
print("saved train.csv")

Dockerfile (new file)
@@ -0,0 +1,13 @@
FROM python:3.12-slim

WORKDIR /app

COPY solution/req.txt .
RUN pip install --no-cache-dir -r req.txt

COPY ./solution .

EXPOSE 8080

# note: --reload is a development flag and is usually dropped in production images
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080", "--reload"]
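
[Editor's note] Assuming the image is built from the repository root (the image tag below is arbitrary), a typical local build-and-run looks like:

docker build -t mcc-api .
docker run -p 8080:8080 mcc-api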

main.py (new file)
@@ -0,0 +1,131 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, validator
from typing import List, Optional
import pandas as pd
import numpy as np
import joblib
import re
from rich.console import Console

console = Console()

MAX_TEXT_LEN = 10000  # maximum text length fed to the model

# Pydantic models for validation and for catching bad input early
class Item(BaseModel):
    name: str = ""
    price: float = 0.0

    @validator("price", pre=True)
    def fix_negative_price(cls, v):
        return max(0, float(v) if v is not None else 0)

class Transaction(BaseModel):
    transaction_id: str
    terminal_name: Optional[str] = "unknown"
    terminal_description: Optional[str] = "unknown"
    city: Optional[str] = "unknown"
    amount: float = 0.0
    items: Optional[List[Item]] = []

    @validator("amount", pre=True)
    def fix_negative_amount(cls, v):
        return max(0, float(v) if v is not None else 0)

class BatchRequest(BaseModel):
    transactions: List[Transaction]
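
[Editor's note] One compatibility caveat: @validator(..., pre=True) is the Pydantic v1 API, and the pydantic pin is not visible in the requirements hunks below, so this is an assumption. If the project ever moves to Pydantic v2, the equivalent spelling would be roughly:

from pydantic import BaseModel, field_validator

class Item(BaseModel):
    name: str = ""
    price: float = 0.0

    # mode="before" is the v2 counterpart of v1's pre=True
    @field_validator("price", mode="before")
    @classmethod
    def fix_negative_price(cls, v):
        return max(0, float(v) if v is not None else 0)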


app = FastAPI()

try:
    model = joblib.load("model/mcc_model.pkl")
    console.log("[green]Model loaded successfully[/green]")
    MODEL_READY = True
except Exception as e:
    console.log(f"[red]Failed to load the model: {e}[/red]")
    MODEL_READY = False

# Preprocess transactions through the custom types so prediction doesn't fall over
def clean_text(text: str) -> str:
    text = re.sub(r"[^\w\s\.,'-]", " ", str(text))
    text = re.sub(r"\s+", " ", text).strip()
    return text[:MAX_TEXT_LEN]

def preprocess_transaction(tx: Transaction) -> pd.DataFrame:
    terminal_name = clean_text(tx.terminal_name)
    terminal_desc = clean_text(tx.terminal_description)
    city = clean_text(tx.city)

    item_names = [clean_text(i.name) for i in tx.items] if tx.items else []
    items_text = " ".join(item_names)

    combined_text = f"{terminal_name} {terminal_desc} {city} items {items_text}".lower()[:MAX_TEXT_LEN]

    # guard every items access: `items` is Optional and may arrive as None
    df = pd.DataFrame([{
        "full_text": combined_text,
        "amount": tx.amount,
        "items_text": items_text,
        "items_count": len(tx.items) if tx.items else 0,
        "items_total_price": sum(i.price for i in tx.items) if tx.items else 0,
        "items_max_price": max((i.price for i in tx.items), default=0) if tx.items else 0,
        "items_min_price": min((i.price for i in tx.items), default=0) if tx.items else 0
    }])
    return df
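
[Editor's note] A caveat on the model load above: unlike dih.py, this module never defines TextExtractor and NumberExtractor, yet dih.py warns that those classes MUST exist before joblib.load(). If the pickled pipeline references them, the load will raise AttributeError and the service will quietly start with MODEL_READY = False. The usual fix is to declare (or import) the transformer classes before the joblib.load call, exactly as dih.py does; depending on how the pipeline was pickled, they may need to be importable under the module path recorded inside the pickle (often __main__ when training was run as a script).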


@app.get("/health")
def health():
    if MODEL_READY:
        return {"status": "ok"}
    raise HTTPException(status_code=503, detail="Model not up")

@app.get("/model/info")
def info():
    return {"model_name": "mcc_classifier",
            "model_version": "1.0.67"}

@app.post("/predict")
def predict(tx: Transaction):
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")

    try:
        df = preprocess_transaction(tx)
        pred = model.predict(df)[0]
        conf = model.predict_proba(df).max()
        console.log(f"[yellow]Prediction for transaction id {tx.transaction_id}: {pred} (conf={conf:.4f})[/yellow]")
        return {
            "transaction_id": tx.transaction_id,
            "predicted_mcc": int(pred),
            "confidence": round(float(conf), 4)
        }
    except Exception as e:
        console.log(f"[red]Prediction failed for id {tx.transaction_id}: {e}[/red]")
        raise HTTPException(status_code=400, detail=str(e))

@app.post("/predict/batch")
def predict_batch(batch: BatchRequest):
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")
    results = []
    for tx in batch.transactions:
        try:
            df = preprocess_transaction(tx)
            pred = model.predict(df)[0]
            conf = model.predict_proba(df).max()
            results.append({
                "transaction_id": tx.transaction_id,
                "predicted_mcc": int(pred),
                "confidence": round(float(conf), 4)
            })
            console.log(f"[yellow]Prediction for id {tx.transaction_id}: {pred} (conf={conf:.4f})[/yellow]")
        except Exception as e:
            console.log(f"[red]Prediction failed for id {tx.transaction_id}: {e}[/red]")
            results.append({
                "transaction_id": tx.transaction_id,
                "predicted_mcc": None,
                "confidence": 0.0
            })
    return {"predictions": results}

Binary file not shown.

@@ -16,8 +16,10 @@ joblib==1.5.3
jupyter-client==8.8.0
jupyter-core==5.9.1
kiwisolver==1.4.9
+markdown-it-py==4.0.0
matplotlib==3.10.8
matplotlib-inline==0.2.1
+mdurl==0.1.2
narwhals==2.15.0
nest-asyncio==1.6.0
numpy==2.4.1

@@ -38,6 +40,7 @@ pyparsing==3.3.2
python-dateutil==2.9.0.post0
pytz==2025.2
pyzmq==27.1.0
+rich==14.2.0
scikit-learn==1.8.0
scipy==1.17.0
six==1.17.0

@@ -24,18 +24,18 @@ console = Console()
# Turns the text column into TF-IDF vectors
class TextExtractor(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None): return self

    def transform(self, X):
        return X['full_text'].fillna('')

# For StandardScaler
class NumberExtractor(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None): return self

    def transform(self, X):
        return X[['amount']].fillna(0)




def train_model():
    console.log("[yellow]Loading data from data/...[/yellow]")
    try:

@@ -43,10 +43,10 @@ def train_model():
        terminals = pd.read_csv('data/terminals.csv')
        receipts = pd.read_csv('data/receipts.csv')
    except FileNotFoundError as e:
-        console.log(f"Training data files not found :( \n {e}", style="white on red")
+        console.log(
+            f"Training data files not found :( \n {e}", style="white on red")
        return

    console.log("[yellow]Preprocessing the data...[/yellow]")
    # Glue the item names together
    receipts_agg = receipts.groupby('transaction_id')['item_name'].apply(

@@ -54,12 +54,14 @@ def train_model():
    ).reset_index()

    # Build one big dataframe to work with
-    df = tx.merge(terminals[['terminal_id', 'terminal_name', 'terminal_description']], on='terminal_id', how='left')
+    df = tx.merge(terminals[['terminal_id', 'terminal_name',
+                             'terminal_description']], on='terminal_id', how='left')
    df = df.merge(receipts_agg, on='transaction_id', how='left')

    # Build the text field for TF-IDF
    df['full_text'] = (
        df['terminal_name'].astype(str) + " " +
        df['terminal_description'].astype(str) + " " +  # originally I tried gluing in the transaction id too, but the model fixated on it
        df['item_name'].astype(str)
    ).str.lower()

@@ -102,13 +104,12 @@ def train_model():

    # Validation

-    console.log("[yellow]Evaluating quality on the validation split...[/yellow]")
+    console.log(
+        "[yellow]Evaluating quality on the validation split...[/yellow]")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )


    pipeline.fit(X_train, y_train)
    y_pred = pipeline.predict(X_test)
    probs = pipeline.predict_proba(X_test)

@@ -118,7 +119,6 @@ def train_model():

    table = Table(box=box.ROUNDED, title="Report")

    table.add_column("Metric", justify="center", style="yellow")
    table.add_column("Value", justify="center", style="yellow")

@@ -129,7 +129,7 @@ def train_model():
    console.print("[yellow]Per-class report[/yellow]", justify="center")
    console.print(classification_report(y_test, y_pred), justify="center")

-    # Metrics look fine; train on all the data and save the data
+    # Metrics look fine; train on all the data and save the model

    console.log("[yellow]Training on everything we have...[/yellow]")
    pipeline.fit(X, y)

@@ -138,6 +138,8 @@ def train_model():
    joblib.dump(pipeline, 'solution/model/mcc_model.pkl')
    console.log(Markdown("Saved the model to **solution/model/mcc_model.pkl**"))

-with console.status("Training the model..."):
+if __name__ == "__main__":
+    with console.status("Training the model..."):
        train_model()
-        console.print(Markdown("*Model is ready :)*"))
+    console.print(Markdown("*Model is ready :)*"))
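
[Editor's note] The added if __name__ == "__main__": guard is more than style: it keeps train_model() from running when this module is imported rather than executed, so other code can import the TextExtractor/NumberExtractor definitions above (which must be in scope before joblib.load(), per the warning in dih.py) without accidentally kicking off a full retrain.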