124 lines
4.9 KiB
Python
124 lines
4.9 KiB
Python
import pandas as pd
|
||
import numpy as np
|
||
import joblib
|
||
import os
|
||
from sklearn.model_selection import train_test_split
|
||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||
from sklearn.linear_model import LogisticRegression
|
||
from sklearn.pipeline import Pipeline, FeatureUnion
|
||
from sklearn.preprocessing import StandardScaler
|
||
from sklearn.base import BaseEstimator, TransformerMixin
|
||
from sklearn.metrics import accuracy_score, classification_report
|
||
|
||
# --- 1. Кастомные трансформеры для пайплайна ---
|
||
|
||
class TextExtractor(BaseEstimator, TransformerMixin):
|
||
"""Извлекает текстовую колонку для TF-IDF"""
|
||
def fit(self, X, y=None): return self
|
||
def transform(self, X):
|
||
return X['full_text'].fillna('')
|
||
|
||
class NumberExtractor(BaseEstimator, TransformerMixin):
|
||
"""Извлекает числовую колонку 'amount'"""
|
||
def fit(self, X, y=None): return self
|
||
def transform(self, X):
|
||
# Возвращаем как DataFrame (2D массив) для StandardScaler
|
||
return X[['amount']].fillna(0)
|
||
|
||
def train_model():
|
||
print("Загрузка данных...")
|
||
# Пути к файлам (предполагаем, что скрипт запускается из корня, где есть папка data/)
|
||
try:
|
||
tx = pd.read_csv('data/transactions.csv')
|
||
terminals = pd.read_csv('data/terminals.csv')
|
||
receipts = pd.read_csv('data/receipts.csv')
|
||
except FileNotFoundError as e:
|
||
print(f"Ошибка: Не найдены файлы данных в папке data/. {e}")
|
||
return
|
||
|
||
# --- 2. Предобработка и сборка признаков ---
|
||
|
||
print("Предобработка...")
|
||
# Агрегируем названия товаров в одну строку для каждой транзакции
|
||
receipts_agg = receipts.groupby('transaction_id')['item_name'].apply(
|
||
lambda x: ' '.join(str(i) for i in x)
|
||
).reset_index()
|
||
|
||
# Объединяем транзакции с данными терминалов и чеками
|
||
df = tx.merge(terminals[['terminal_id', 'terminal_name', 'terminal_description']], on='terminal_id', how='left')
|
||
df = df.merge(receipts_agg, on='transaction_id', how='left')
|
||
|
||
# Создаем единое текстовое поле (игнорируем transaction_id, чтобы не было утечки!)
|
||
df['full_text'] = (
|
||
df['terminal_name'].astype(str) + " " +
|
||
df['terminal_description'].astype(str) + " " +
|
||
df['item_name'].astype(str)
|
||
).str.lower()
|
||
|
||
X = df[['full_text', 'amount']]
|
||
y = df['true_mcc']
|
||
|
||
# --- 3. Создание пайплайна ---
|
||
|
||
pipeline = Pipeline([
|
||
('features', FeatureUnion([
|
||
# Ветка ТЕКСТА
|
||
('text_branch', Pipeline([
|
||
('extract', TextExtractor()),
|
||
('tfidf_union', FeatureUnion([
|
||
# Слова (смысл)
|
||
('word', TfidfVectorizer(
|
||
ngram_range=(1, 2),
|
||
analyzer='word',
|
||
stop_words='english',
|
||
max_features=5000
|
||
)),
|
||
# Символы (опечатки)
|
||
('char', TfidfVectorizer(
|
||
ngram_range=(2, 5),
|
||
analyzer='char_wb',
|
||
max_features=10000
|
||
))
|
||
]))
|
||
])),
|
||
# Ветка ЧИСЕЛ
|
||
('numeric_branch', Pipeline([
|
||
('extract', NumberExtractor()),
|
||
('scaler', StandardScaler())
|
||
]))
|
||
])),
|
||
# Классификатор
|
||
('clf', LogisticRegression(C=1.0, max_iter=1000))
|
||
])
|
||
|
||
# --- 4. Оценка качества (Валидация) ---
|
||
|
||
print("Оценка качества на валидационной выборке...")
|
||
X_train, X_test, y_train, y_test = train_test_split(
|
||
X, y, test_size=0.2, random_state=42, stratify=y
|
||
)
|
||
|
||
pipeline.fit(X_train, y_train)
|
||
y_pred = pipeline.predict(X_test)
|
||
probs = pipeline.predict_proba(X_test)
|
||
|
||
acc = accuracy_score(y_test, y_pred)
|
||
conf = np.mean(np.max(probs, axis=1))
|
||
|
||
print(f"\n[РЕЗУЛЬТАТЫ]")
|
||
print(f"Accuracy: {acc:.4f}")
|
||
print(f"Average Confidence: {conf:.4f}")
|
||
print("\nОтчет по категориям:")
|
||
print(classification_report(y_test, y_pred))
|
||
|
||
# --- 5. Финальное обучение и сохранение ---
|
||
|
||
print("\nФинальное обучение на всех данных...")
|
||
pipeline.fit(X, y)
|
||
|
||
os.makedirs('solution/model', exist_ok=True)
|
||
joblib.dump(pipeline, 'solution/model/mcc_model.pkl')
|
||
print("Модель успешно сохранена в solution/model/mcc_model.pkl")
|
||
|
||
if __name__ == "__main__":
|
||
train_model() |