"""Latency benchmark for the pickled MCC-classification sklearn pipeline.

Loads the joblib-serialized model, then measures:
  [1] cold-start load time,
  [2] warmup,
  [3] single-request latency over 1000 runs (avg / p95 / max),
  [4] batch-of-100 throughput,
  [5] pass/fail verdict against the service SLA
      (<200 ms p95 single request, <5 s for a 100-item batch).
"""

import random
import time

import joblib
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


# ===== CRITICALLY IMPORTANT =====
# These transformer classes MUST exist (be importable under the same
# qualified names) BEFORE joblib.load(): the pickle references them by
# name, and unpickling fails if they are missing.

class TextExtractor(BaseEstimator, TransformerMixin):
    """Select the 'full_text' column for the text branch (NaN -> '')."""

    def fit(self, X, y=None):
        # Stateless selector: nothing to learn.
        return self

    def transform(self, X):
        return X['full_text'].fillna('')


class NumberExtractor(BaseEstimator, TransformerMixin):
    """Select the numeric 'amount' column for the numeric branch (NaN -> 0)."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[['amount']].fillna(0)


# ============================

MODEL_PATH = "mcc_model.pkl"

# Template transaction replicated for every benchmark request.
BASE_TX = {
    "transaction_id": "TX00001116",
    "terminal_name": "STORE001",
    "terminal_description": "common common common thing",
    "city": "NYC",
    "amount": 272.80,
    "items": [
        {"name": "basic loyalty", "price": 58.20},
        {"name": "Bringiong item lifes", "price": 28.99},
        {"name": "regular item basic item", "price": 56.91}
    ]
}


def json_to_df(data_list):
    """Convert raw transaction dicts into the model's input DataFrame.

    Produces one row per transaction with two columns:
    'full_text' (lower-cased terminal name + description + item names)
    and 'amount' (float, defaulting to 0 when absent).

    NOTE(review): 'city' is NOT included in 'full_text' here — confirm
    this matches the feature layout used at training time.
    """
    rows = []
    for d in data_list:
        items = d.get("items", [])
        item_text = " ".join(str(i.get("name", "")) for i in items)
        full_text = f"{d.get('terminal_name','')} {d.get('terminal_description','')} {item_text}".lower()

        rows.append({
            "full_text": full_text,
            "amount": float(d.get("amount", 0))
        })

    return pd.DataFrame(rows)


def make_batch(n=100):
    """Build *n* synthetic transactions with unique IDs and random amounts.

    Fix: the original used a plain BASE_TX.copy() (shallow), so every
    transaction in the batch aliased the SAME 'items' list; one level of
    nested copying keeps the transactions independent.
    """
    batch = []
    for i in range(n):
        tx = BASE_TX.copy()
        tx["items"] = [item.copy() for item in BASE_TX["items"]]
        tx["transaction_id"] = f"TX{i:06d}"
        tx["amount"] = round(random.uniform(3, 500), 2)
        batch.append(tx)
    return batch


# ============================
# BENCH
# ============================

print("\n[1] LOADING MODEL (cold start)...")
t0 = time.perf_counter()
model = joblib.load(MODEL_PATH)
t1 = time.perf_counter()
print(f"Cold load time: {t1 - t0:.4f} sec")

# Warmup.
# Fix: also warm predict_proba — that is the path actually benchmarked
# below; warming only predict() would let first-call lazy initialization
# inflate the measured latencies.
print("\n[2] WARMUP...")
warm_df = json_to_df([BASE_TX])
for _ in range(5):
    model.predict(warm_df)
    model.predict_proba(warm_df)

# Single-inference benchmark: the JSON->DataFrame conversion is kept
# OUTSIDE the timed window so only model latency is measured.
print("\n[3] SINGLE REQUEST BENCH (1000 runs)...")
times = []

for _ in range(1000):
    df = json_to_df([BASE_TX])
    t0 = time.perf_counter()
    model.predict_proba(df)
    times.append(time.perf_counter() - t0)

times = np.array(times)

print(f"avg : {times.mean()*1000:.2f} ms")
print(f"p95 : {np.percentile(times,95)*1000:.2f} ms")
print(f"max : {times.max()*1000:.2f} ms")

# Batch benchmark
print("\n[4] BATCH 100 BENCH...")
batch = make_batch(100)
df_batch = json_to_df(batch)

t0 = time.perf_counter()
model.predict_proba(df_batch)
dt = time.perf_counter() - t0

print(f"Batch 100 time: {dt:.3f} sec")

# SLA verdict (thresholds in seconds: 0.2 s p95 single, 5 s batch)
print("\n[5] VERDICT")
if np.percentile(times, 95) < 0.2:
    print("✅ /predict проходит по latency (<200ms p95)")
else:
    print("❌ /predict НЕ проходит по latency")

if dt < 5:
    print("✅ /predict/batch проходит (<5s)")
else:
    print("❌ /predict/batch НЕ проходит")