mlops/train.py

124 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import numpy as np
import joblib
import os
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import accuracy_score, classification_report
# --- 1. Кастомные трансформеры для пайплайна ---
class TextExtractor(BaseEstimator, TransformerMixin):
"""Извлекает текстовую колонку для TF-IDF"""
def fit(self, X, y=None): return self
def transform(self, X):
return X['full_text'].fillna('')
class NumberExtractor(BaseEstimator, TransformerMixin):
"""Извлекает числовую колонку 'amount'"""
def fit(self, X, y=None): return self
def transform(self, X):
# Возвращаем как DataFrame (2D массив) для StandardScaler
return X[['amount']].fillna(0)
def train_model():
print("Загрузка данных...")
# Пути к файлам (предполагаем, что скрипт запускается из корня, где есть папка data/)
try:
tx = pd.read_csv('data/transactions.csv')
terminals = pd.read_csv('data/terminals.csv')
receipts = pd.read_csv('data/receipts.csv')
except FileNotFoundError as e:
print(f"Ошибка: Не найдены файлы данных в папке data/. {e}")
return
# --- 2. Предобработка и сборка признаков ---
print("Предобработка...")
# Агрегируем названия товаров в одну строку для каждой транзакции
receipts_agg = receipts.groupby('transaction_id')['item_name'].apply(
lambda x: ' '.join(str(i) for i in x)
).reset_index()
# Объединяем транзакции с данными терминалов и чеками
df = tx.merge(terminals[['terminal_id', 'terminal_name', 'terminal_description']], on='terminal_id', how='left')
df = df.merge(receipts_agg, on='transaction_id', how='left')
# Создаем единое текстовое поле (игнорируем transaction_id, чтобы не было утечки!)
df['full_text'] = (
df['terminal_name'].astype(str) + " " +
df['terminal_description'].astype(str) + " " +
df['item_name'].astype(str)
).str.lower()
X = df[['full_text', 'amount']]
y = df['true_mcc']
# --- 3. Создание пайплайна ---
pipeline = Pipeline([
('features', FeatureUnion([
# Ветка ТЕКСТА
('text_branch', Pipeline([
('extract', TextExtractor()),
('tfidf_union', FeatureUnion([
# Слова (смысл)
('word', TfidfVectorizer(
ngram_range=(1, 2),
analyzer='word',
stop_words='english',
max_features=5000
)),
# Символы (опечатки)
('char', TfidfVectorizer(
ngram_range=(2, 5),
analyzer='char_wb',
max_features=10000
))
]))
])),
# Ветка ЧИСЕЛ
('numeric_branch', Pipeline([
('extract', NumberExtractor()),
('scaler', StandardScaler())
]))
])),
# Классификатор
('clf', LogisticRegression(C=1.0, max_iter=1000))
])
# --- 4. Оценка качества (Валидация) ---
print("Оценка качества на валидационной выборке...")
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)
probs = pipeline.predict_proba(X_test)
acc = accuracy_score(y_test, y_pred)
conf = np.mean(np.max(probs, axis=1))
print(f"\n[РЕЗУЛЬТАТЫ]")
print(f"Accuracy: {acc:.4f}")
print(f"Average Confidence: {conf:.4f}")
print("\nОтчет по категориям:")
print(classification_report(y_test, y_pred))
# --- 5. Финальное обучение и сохранение ---
print("\nФинальное обучение на всех данных...")
pipeline.fit(X, y)
os.makedirs('solution/model', exist_ok=True)
joblib.dump(pipeline, 'solution/model/mcc_model.pkl')
print("Модель успешно сохранена в solution/model/mcc_model.pkl")
if __name__ == "__main__":
train_model()