better output
This commit is contained in:
parent
74df71d82d
commit
4b655fad30
186
dih.py
186
dih.py
|
|
@ -1,93 +1,125 @@
|
||||||
|
import time
|
||||||
|
import joblib
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from catboost import CatBoostClassifier
|
import random
|
||||||
|
|
||||||
class MCCPredictor:
|
from sklearn.base import BaseEstimator, TransformerMixin
|
||||||
def __init__(self, model_path='mcc_classifier.cbm'):
|
|
||||||
self.model = CatBoostClassifier()
|
|
||||||
self.model.load_model(model_path)
|
|
||||||
|
|
||||||
# Порядок колонок должен быть В ТОЧНОСТИ как при обучении
|
|
||||||
self.feature_order = [
|
|
||||||
'terminal_name', 'terminal_description', 'terminal_city', 'items_text', 'text',
|
|
||||||
'amount', 'items_count', 'items_total_price', 'items_max_price', 'items_min_price',
|
|
||||||
'terminal_id'
|
|
||||||
]
|
|
||||||
|
|
||||||
def _preprocess_json(self, data):
|
# ===== КРИТИЧЕСКИ ВАЖНО =====
|
||||||
"""Конвертирует входящий JSON в плоскую структуру для модели"""
|
# Эти классы ДОЛЖНЫ существовать до joblib.load()
|
||||||
|
|
||||||
# 1. Агрегируем данные из списка items
|
class TextExtractor(BaseEstimator, TransformerMixin):
|
||||||
items = data.get('items', [])
|
def fit(self, X, y=None):
|
||||||
item_names = [str(i.get('name', '')) for i in items]
|
return self
|
||||||
item_prices = [float(i.get('price', 0)) for i in items]
|
|
||||||
|
|
||||||
items_text = " ".join(item_names)
|
def transform(self, X):
|
||||||
items_count = len(items)
|
return X['full_text'].fillna('')
|
||||||
items_total_price = sum(item_prices)
|
|
||||||
items_max_price = max(item_prices) if item_prices else 0
|
|
||||||
items_min_price = min(item_prices) if item_prices else 0
|
|
||||||
|
|
||||||
# 2. Формируем ту самую склеенную колонку 'text'
|
|
||||||
# Важно: используй тот же формат, что был в train.csv
|
|
||||||
combined_text = f"{data.get('terminal_name', '')} {data.get('terminal_description', '')} {data.get('city', '')} items {items_text}"
|
|
||||||
|
|
||||||
# 3. Собираем финальный словарь
|
class NumberExtractor(BaseEstimator, TransformerMixin):
|
||||||
flat_data = {
|
def fit(self, X, y=None):
|
||||||
'terminal_name': str(data.get('terminal_name', '')),
|
return self
|
||||||
'terminal_description': str(data.get('terminal_description', '')),
|
|
||||||
'terminal_city': str(data.get('city', '')), # city -> terminal_city
|
|
||||||
'items_text': items_text,
|
|
||||||
'text': combined_text.lower(),
|
|
||||||
'amount': float(data.get('amount', 0)),
|
|
||||||
'items_count': float(items_count),
|
|
||||||
'items_total_price': float(items_total_price),
|
|
||||||
'items_max_price': float(items_max_price),
|
|
||||||
'items_min_price': float(items_min_price),
|
|
||||||
'terminal_id': 'unknown' # В запросе нет ID, ставим заглушку
|
|
||||||
}
|
|
||||||
return flat_data
|
|
||||||
|
|
||||||
def predict(self, raw_json):
|
def transform(self, X):
|
||||||
# Если пришла одна транзакция, оборачиваем в список
|
return X[['amount']].fillna(0)
|
||||||
if isinstance(raw_json, dict):
|
|
||||||
raw_json = [raw_json]
|
|
||||||
|
|
||||||
# Препроцессинг всех транзакций в списке
|
|
||||||
processed_data = [self._preprocess_json(t) for t in raw_json]
|
|
||||||
df = pd.DataFrame(processed_data)
|
|
||||||
|
|
||||||
# Проверка порядка колонок
|
# ============================
|
||||||
df = df[self.feature_order]
|
|
||||||
|
|
||||||
# Предсказание
|
MODEL_PATH = "mcc_model.pkl"
|
||||||
mcc_codes = self.model.predict(df)
|
|
||||||
probs = self.model.predict_proba(df)
|
|
||||||
|
|
||||||
results = []
|
BASE_TX = {
|
||||||
for i in range(len(raw_json)):
|
"transaction_id": "TX00001116",
|
||||||
results.append({
|
"terminal_name": "STORE001",
|
||||||
"transaction_id": raw_json[i].get('transaction_id'),
|
"terminal_description": "common common common thing",
|
||||||
"mcc": int(mcc_codes[i][0]),
|
"city": "NYC",
|
||||||
"confidence": round(float(np.max(probs[i])), 4)
|
"amount": 272.80,
|
||||||
})
|
"items": [
|
||||||
return results
|
{"name": "basic loyalty", "price": 58.20},
|
||||||
|
{"name": "Bringiong item lifes", "price": 28.99},
|
||||||
# --- ТЕСТ ---
|
{"name": "regular item basic item", "price": 56.91}
|
||||||
predictor = MCCPredictor('mcc_classifier.cbm')
|
]
|
||||||
|
|
||||||
request_data = {
|
|
||||||
"transaction_id": "TX00001116",
|
|
||||||
"terminal_name": "STORE001",
|
|
||||||
"terminal_description": "common common common thing",
|
|
||||||
"city": "NYC",
|
|
||||||
"amount": 272.80,
|
|
||||||
"items": [
|
|
||||||
{"name": "basic loyalty", "price": 58.20},
|
|
||||||
{"name": "Bringiong item lifes", "price": 28.99},
|
|
||||||
{"name": "regular item basic item", "price": 56.91}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res = predictor.predict(request_data)
|
|
||||||
print(res)
|
def json_to_df(data_list):
|
||||||
|
rows = []
|
||||||
|
for d in data_list:
|
||||||
|
items = d.get("items", [])
|
||||||
|
item_text = " ".join(str(i.get("name","")) for i in items)
|
||||||
|
full_text = f"{d.get('terminal_name','')} {d.get('terminal_description','')} {item_text}".lower()
|
||||||
|
|
||||||
|
rows.append({
|
||||||
|
"full_text": full_text,
|
||||||
|
"amount": float(d.get("amount", 0))
|
||||||
|
})
|
||||||
|
|
||||||
|
return pd.DataFrame(rows)
|
||||||
|
|
||||||
|
|
||||||
|
def make_batch(n=100):
|
||||||
|
batch = []
|
||||||
|
for i in range(n):
|
||||||
|
tx = BASE_TX.copy()
|
||||||
|
tx["transaction_id"] = f"TX{i:06d}"
|
||||||
|
tx["amount"] = round(random.uniform(3, 500), 2)
|
||||||
|
batch.append(tx)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
# ============================
|
||||||
|
# BENCH
|
||||||
|
# ============================
|
||||||
|
|
||||||
|
print("\n[1] LOADING MODEL (cold start)...")
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
model = joblib.load(MODEL_PATH)
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
print(f"Cold load time: {t1 - t0:.4f} sec")
|
||||||
|
|
||||||
|
# warmup
|
||||||
|
print("\n[2] WARMUP...")
|
||||||
|
warm_df = json_to_df([BASE_TX])
|
||||||
|
for _ in range(5):
|
||||||
|
model.predict(warm_df)
|
||||||
|
|
||||||
|
# single inference benchmark
|
||||||
|
print("\n[3] SINGLE REQUEST BENCH (1000 runs)...")
|
||||||
|
times = []
|
||||||
|
|
||||||
|
for _ in range(1000):
|
||||||
|
df = json_to_df([BASE_TX])
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
model.predict_proba(df)
|
||||||
|
times.append(time.perf_counter() - t0)
|
||||||
|
|
||||||
|
times = np.array(times)
|
||||||
|
|
||||||
|
print(f"avg : {times.mean()*1000:.2f} ms")
|
||||||
|
print(f"p95 : {np.percentile(times,95)*1000:.2f} ms")
|
||||||
|
print(f"max : {times.max()*1000:.2f} ms")
|
||||||
|
|
||||||
|
# batch benchmark
|
||||||
|
print("\n[4] BATCH 100 BENCH...")
|
||||||
|
batch = make_batch(100)
|
||||||
|
df_batch = json_to_df(batch)
|
||||||
|
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
model.predict_proba(df_batch)
|
||||||
|
dt = time.perf_counter() - t0
|
||||||
|
|
||||||
|
print(f"Batch 100 time: {dt:.3f} sec")
|
||||||
|
|
||||||
|
# verdict
|
||||||
|
print("\n[5] VERDICT")
|
||||||
|
if np.percentile(times,95) < 0.2:
|
||||||
|
print("✅ /predict проходит по latency (<200ms p95)")
|
||||||
|
else:
|
||||||
|
print("❌ /predict НЕ проходит по latency")
|
||||||
|
|
||||||
|
if dt < 5:
|
||||||
|
print("✅ /predict/batch проходит (<5s)")
|
||||||
|
else:
|
||||||
|
print("❌ /predict/batch НЕ проходит")
|
||||||
|
|
|
||||||
Binary file not shown.
81
train.py
81
train.py
|
|
@ -9,71 +9,80 @@ from sklearn.pipeline import Pipeline, FeatureUnion
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
from sklearn.base import BaseEstimator, TransformerMixin
|
from sklearn.base import BaseEstimator, TransformerMixin
|
||||||
from sklearn.metrics import accuracy_score, classification_report
|
from sklearn.metrics import accuracy_score, classification_report
|
||||||
|
import rich # <!-- дебаг через rich - моя guilty pleasure, очень уж люблю на красивые выводы смотреть
|
||||||
|
from rich import print as rpint
|
||||||
|
from rich.console import Console
|
||||||
|
from rich import box
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
# --- 1. Кастомные трансформеры для пайплайна ---
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
# инициализируем специальные классы, чтобы раскидать данные по категориям
|
||||||
|
|
||||||
|
# Для преобразования TF-IDF в вектора
|
||||||
class TextExtractor(BaseEstimator, TransformerMixin):
|
class TextExtractor(BaseEstimator, TransformerMixin):
|
||||||
"""Извлекает текстовую колонку для TF-IDF"""
|
|
||||||
def fit(self, X, y=None): return self
|
def fit(self, X, y=None): return self
|
||||||
def transform(self, X):
|
def transform(self, X):
|
||||||
return X['full_text'].fillna('')
|
return X['full_text'].fillna('')
|
||||||
|
|
||||||
|
# Для StandardScaler
|
||||||
class NumberExtractor(BaseEstimator, TransformerMixin):
|
class NumberExtractor(BaseEstimator, TransformerMixin):
|
||||||
"""Извлекает числовую колонку 'amount'"""
|
|
||||||
def fit(self, X, y=None): return self
|
def fit(self, X, y=None): return self
|
||||||
def transform(self, X):
|
def transform(self, X):
|
||||||
# Возвращаем как DataFrame (2D массив) для StandardScaler
|
|
||||||
return X[['amount']].fillna(0)
|
return X[['amount']].fillna(0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
print("Загрузка данных...")
|
console.log("[yellow]Грузим данные из data...[/yellow]")
|
||||||
# Пути к файлам (предполагаем, что скрипт запускается из корня, где есть папка data/)
|
|
||||||
try:
|
try:
|
||||||
tx = pd.read_csv('data/transactions.csv')
|
tx = pd.read_csv('data/transactions.csv')
|
||||||
terminals = pd.read_csv('data/terminals.csv')
|
terminals = pd.read_csv('data/terminals.csv')
|
||||||
receipts = pd.read_csv('data/receipts.csv')
|
receipts = pd.read_csv('data/receipts.csv')
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
print(f"Ошибка: Не найдены файлы данных в папке data/. {e}")
|
console.log(f"Файлы для обучения не найдены :( \n {e}", style="white on red")
|
||||||
return
|
return
|
||||||
|
|
||||||
# --- 2. Предобработка и сборка признаков ---
|
|
||||||
|
|
||||||
print("Предобработка...")
|
console.log("[yellow]Предобрабатываем данные...[/yellow]")
|
||||||
# Агрегируем названия товаров в одну строку для каждой транзакции
|
# Приклеиваеем вместе имена товаров
|
||||||
receipts_agg = receipts.groupby('transaction_id')['item_name'].apply(
|
receipts_agg = receipts.groupby('transaction_id')['item_name'].apply(
|
||||||
lambda x: ' '.join(str(i) for i in x)
|
lambda x: ' '.join(str(i) for i in x)
|
||||||
).reset_index()
|
).reset_index()
|
||||||
|
|
||||||
# Объединяем транзакции с данными терминалов и чеками
|
# Делаем один большой датафрейм с которым будем работать
|
||||||
df = tx.merge(terminals[['terminal_id', 'terminal_name', 'terminal_description']], on='terminal_id', how='left')
|
df = tx.merge(terminals[['terminal_id', 'terminal_name', 'terminal_description']], on='terminal_id', how='left')
|
||||||
df = df.merge(receipts_agg, on='transaction_id', how='left')
|
df = df.merge(receipts_agg, on='transaction_id', how='left')
|
||||||
|
|
||||||
# Создаем единое текстовое поле (игнорируем transaction_id, чтобы не было утечки!)
|
# Делаем текстовое поле для TF-IDF
|
||||||
df['full_text'] = (
|
df['full_text'] = (
|
||||||
df['terminal_name'].astype(str) + " " +
|
df['terminal_name'].astype(str) + " " +
|
||||||
df['terminal_description'].astype(str) + " " +
|
df['terminal_description'].astype(str) + " " + # <!-- изначально я пробовал клеить id транзакции, однако модель слишком на ней зацикливалась
|
||||||
df['item_name'].astype(str)
|
df['item_name'].astype(str)
|
||||||
).str.lower()
|
).str.lower()
|
||||||
|
|
||||||
X = df[['full_text', 'amount']]
|
X = df[['full_text', 'amount']]
|
||||||
y = df['true_mcc']
|
y = df['true_mcc']
|
||||||
|
|
||||||
# --- 3. Создание пайплайна ---
|
# Пайплайн обучения
|
||||||
|
|
||||||
pipeline = Pipeline([
|
pipeline = Pipeline([
|
||||||
('features', FeatureUnion([
|
('features', FeatureUnion([
|
||||||
# Ветка ТЕКСТА
|
# Ветка для слов
|
||||||
('text_branch', Pipeline([
|
('text_branch', Pipeline([
|
||||||
('extract', TextExtractor()),
|
('extract', TextExtractor()),
|
||||||
('tfidf_union', FeatureUnion([
|
('tfidf_union', FeatureUnion([
|
||||||
# Слова (смысл)
|
# Векторизуем слова и удаляем лишние слова без смысла
|
||||||
('word', TfidfVectorizer(
|
('word', TfidfVectorizer(
|
||||||
ngram_range=(1, 2),
|
ngram_range=(1, 2),
|
||||||
analyzer='word',
|
analyzer='word',
|
||||||
stop_words='english',
|
stop_words='english',
|
||||||
max_features=5000
|
max_features=5000
|
||||||
)),
|
)),
|
||||||
# Символы (опечатки)
|
# Фиксим очепятки
|
||||||
('char', TfidfVectorizer(
|
('char', TfidfVectorizer(
|
||||||
ngram_range=(2, 5),
|
ngram_range=(2, 5),
|
||||||
analyzer='char_wb',
|
analyzer='char_wb',
|
||||||
|
|
@ -81,23 +90,25 @@ def train_model():
|
||||||
))
|
))
|
||||||
]))
|
]))
|
||||||
])),
|
])),
|
||||||
# Ветка ЧИСЕЛ
|
# Ветка для чисел
|
||||||
('numeric_branch', Pipeline([
|
('numeric_branch', Pipeline([
|
||||||
('extract', NumberExtractor()),
|
('extract', NumberExtractor()),
|
||||||
('scaler', StandardScaler())
|
('scaler', StandardScaler())
|
||||||
]))
|
]))
|
||||||
])),
|
])),
|
||||||
# Классификатор
|
# Для классификации юзаем логрег
|
||||||
('clf', LogisticRegression(C=1.0, max_iter=1000))
|
('clf', LogisticRegression(C=1.0, max_iter=1000)) # <!-- были разные коэффиценты, в итоге оставил что-то средненькое
|
||||||
])
|
])
|
||||||
|
|
||||||
# --- 4. Оценка качества (Валидация) ---
|
# Валидация
|
||||||
|
|
||||||
print("Оценка качества на валидационной выборке...")
|
console.log("[yellow]Оцениваем качество на валидационной выборке...[/yellow]")
|
||||||
X_train, X_test, y_train, y_test = train_test_split(
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
X, y, test_size=0.2, random_state=42, stratify=y
|
X, y, test_size=0.2, random_state=42, stratify=y
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
pipeline.fit(X_train, y_train)
|
pipeline.fit(X_train, y_train)
|
||||||
y_pred = pipeline.predict(X_test)
|
y_pred = pipeline.predict(X_test)
|
||||||
probs = pipeline.predict_proba(X_test)
|
probs = pipeline.predict_proba(X_test)
|
||||||
|
|
@ -105,20 +116,28 @@ def train_model():
|
||||||
acc = accuracy_score(y_test, y_pred)
|
acc = accuracy_score(y_test, y_pred)
|
||||||
conf = np.mean(np.max(probs, axis=1))
|
conf = np.mean(np.max(probs, axis=1))
|
||||||
|
|
||||||
print(f"\n[РЕЗУЛЬТАТЫ]")
|
table = Table(box=box.ROUNDED, title="Отчёт")
|
||||||
print(f"Accuracy: {acc:.4f}")
|
|
||||||
print(f"Average Confidence: {conf:.4f}")
|
|
||||||
print("\nОтчет по категориям:")
|
|
||||||
print(classification_report(y_test, y_pred))
|
|
||||||
|
|
||||||
# --- 5. Финальное обучение и сохранение ---
|
|
||||||
|
|
||||||
print("\nФинальное обучение на всех данных...")
|
table.add_column("Метрика", justify="center", style="yellow")
|
||||||
|
table.add_column("Значение", justify="center", style="yellow")
|
||||||
|
|
||||||
|
table.add_row("Accuracy", f"{acc:.4f}")
|
||||||
|
table.add_row("Avg Confidence", f"{conf:.4f}")
|
||||||
|
console.print(table, justify="center")
|
||||||
|
|
||||||
|
console.print("[yellow]Репорт по классам[/yellow]", justify="center")
|
||||||
|
console.print(classification_report(y_test, y_pred), justify="center")
|
||||||
|
|
||||||
|
# Метрики норм, учимся на всех данных и сохранем данные
|
||||||
|
|
||||||
|
console.log("[yellow]Учимся на всем, что есть...[/yellow]")
|
||||||
pipeline.fit(X, y)
|
pipeline.fit(X, y)
|
||||||
|
|
||||||
os.makedirs('solution/model', exist_ok=True)
|
os.makedirs('solution/model', exist_ok=True)
|
||||||
joblib.dump(pipeline, 'solution/model/mcc_model.pkl')
|
joblib.dump(pipeline, 'solution/model/mcc_model.pkl')
|
||||||
print("Модель успешно сохранена в solution/model/mcc_model.pkl")
|
console.log(Markdown("Сохранили модель в **solution/model/mcc_model.pkl**"))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
with console.status("Учим модель..."):
|
||||||
train_model()
|
train_model()
|
||||||
|
console.print(Markdown("*Модель готова :)*"))
|
||||||
Loading…
Reference in New Issue