mlops/train.py

143 lines
5.6 KiB
Python
Raw Normal View History

2026-01-21 21:08:24 +02:00
import pandas as pd
import numpy as np
import joblib
import os
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import accuracy_score, classification_report
2026-01-22 11:45:39 +02:00
import rich # <!-- дебаг через rich - моя guilty pleasure, очень уж люблю на красивые выводы смотреть
from rich import print as rpint
from rich.console import Console
from rich import box
from rich.table import Table
from rich.markdown import Markdown
2026-01-21 21:08:24 +02:00
2026-01-22 11:45:39 +02:00
console = Console()
2026-01-21 21:08:24 +02:00
2026-01-22 11:45:39 +02:00
# инициализируем специальные классы, чтобы раскидать данные по категориям
# Для преобразования TF-IDF в вектора
2026-01-21 21:08:24 +02:00
class TextExtractor(BaseEstimator, TransformerMixin):
def fit(self, X, y=None): return self
def transform(self, X):
return X['full_text'].fillna('')
2026-01-22 11:45:39 +02:00
# Для StandardScaler
2026-01-21 21:08:24 +02:00
class NumberExtractor(BaseEstimator, TransformerMixin):
def fit(self, X, y=None): return self
def transform(self, X):
return X[['amount']].fillna(0)
2026-01-22 11:45:39 +02:00
2026-01-21 21:08:24 +02:00
def train_model():
2026-01-22 11:45:39 +02:00
console.log("[yellow]Грузим данные из data...[/yellow]")
2026-01-21 21:08:24 +02:00
try:
tx = pd.read_csv('data/transactions.csv')
terminals = pd.read_csv('data/terminals.csv')
receipts = pd.read_csv('data/receipts.csv')
except FileNotFoundError as e:
2026-01-22 11:45:39 +02:00
console.log(f"Файлы для обучения не найдены :( \n {e}", style="white on red")
2026-01-21 21:08:24 +02:00
return
2026-01-22 11:45:39 +02:00
console.log("[yellow]Предобрабатываем данные...[/yellow]")
# Приклеиваеем вместе имена товаров
2026-01-21 21:08:24 +02:00
receipts_agg = receipts.groupby('transaction_id')['item_name'].apply(
lambda x: ' '.join(str(i) for i in x)
).reset_index()
2026-01-22 11:45:39 +02:00
# Делаем один большой датафрейм с которым будем работать
2026-01-21 21:08:24 +02:00
df = tx.merge(terminals[['terminal_id', 'terminal_name', 'terminal_description']], on='terminal_id', how='left')
df = df.merge(receipts_agg, on='transaction_id', how='left')
2026-01-22 11:45:39 +02:00
# Делаем текстовое поле для TF-IDF
2026-01-21 21:08:24 +02:00
df['full_text'] = (
df['terminal_name'].astype(str) + " " +
2026-01-22 11:45:39 +02:00
df['terminal_description'].astype(str) + " " + # <!-- изначально я пробовал клеить id транзакции, однако модель слишком на ней зацикливалась
2026-01-21 21:08:24 +02:00
df['item_name'].astype(str)
).str.lower()
X = df[['full_text', 'amount']]
y = df['true_mcc']
2026-01-22 11:45:39 +02:00
# Пайплайн обучения
2026-01-21 21:08:24 +02:00
pipeline = Pipeline([
('features', FeatureUnion([
2026-01-22 11:45:39 +02:00
# Ветка для слов
2026-01-21 21:08:24 +02:00
('text_branch', Pipeline([
('extract', TextExtractor()),
('tfidf_union', FeatureUnion([
2026-01-22 11:45:39 +02:00
# Векторизуем слова и удаляем лишние слова без смысла
2026-01-21 21:08:24 +02:00
('word', TfidfVectorizer(
ngram_range=(1, 2),
analyzer='word',
stop_words='english',
max_features=5000
)),
2026-01-22 11:45:39 +02:00
# Фиксим очепятки
2026-01-21 21:08:24 +02:00
('char', TfidfVectorizer(
ngram_range=(2, 5),
analyzer='char_wb',
max_features=10000
))
]))
])),
2026-01-22 11:45:39 +02:00
# Ветка для чисел
2026-01-21 21:08:24 +02:00
('numeric_branch', Pipeline([
('extract', NumberExtractor()),
('scaler', StandardScaler())
]))
])),
2026-01-22 11:45:39 +02:00
# Для классификации юзаем логрег
('clf', LogisticRegression(C=1.0, max_iter=1000)) # <!-- были разные коэффиценты, в итоге оставил что-то средненькое
2026-01-21 21:08:24 +02:00
])
2026-01-22 11:45:39 +02:00
# Валидация
2026-01-21 21:08:24 +02:00
2026-01-22 11:45:39 +02:00
console.log("[yellow]Оцениваем качество на валидационной выборке...[/yellow]")
2026-01-21 21:08:24 +02:00
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
2026-01-22 11:45:39 +02:00
2026-01-21 21:08:24 +02:00
pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)
probs = pipeline.predict_proba(X_test)
acc = accuracy_score(y_test, y_pred)
conf = np.mean(np.max(probs, axis=1))
2026-01-22 11:45:39 +02:00
table = Table(box=box.ROUNDED, title="Отчёт")
table.add_column("Метрика", justify="center", style="yellow")
table.add_column("Значение", justify="center", style="yellow")
table.add_row("Accuracy", f"{acc:.4f}")
table.add_row("Avg Confidence", f"{conf:.4f}")
console.print(table, justify="center")
console.print("[yellow]Репорт по классам[/yellow]", justify="center")
console.print(classification_report(y_test, y_pred), justify="center")
2026-01-21 21:08:24 +02:00
2026-01-22 11:45:39 +02:00
# Метрики норм, учимся на всех данных и сохранем данные
2026-01-21 21:08:24 +02:00
2026-01-22 11:45:39 +02:00
console.log("[yellow]Учимся на всем, что есть...[/yellow]")
2026-01-21 21:08:24 +02:00
pipeline.fit(X, y)
os.makedirs('solution/model', exist_ok=True)
joblib.dump(pipeline, 'solution/model/mcc_model.pkl')
2026-01-22 11:45:39 +02:00
console.log(Markdown("Сохранили модель в **solution/model/mcc_model.pkl**"))
2026-01-21 21:08:24 +02:00
2026-01-22 11:45:39 +02:00
with console.status("Учим модель..."):
train_model()
console.print(Markdown("*Модель готова :)*"))