mlops/dih.py

93 lines
3.7 KiB
Python
Raw Normal View History

2026-01-21 21:08:24 +02:00
import pandas as pd
import numpy as np
from catboost import CatBoostClassifier
class MCCPredictor:
def __init__(self, model_path='mcc_classifier.cbm'):
self.model = CatBoostClassifier()
self.model.load_model(model_path)
# Порядок колонок должен быть В ТОЧНОСТИ как при обучении
self.feature_order = [
'terminal_name', 'terminal_description', 'terminal_city', 'items_text', 'text',
'amount', 'items_count', 'items_total_price', 'items_max_price', 'items_min_price',
'terminal_id'
]
def _preprocess_json(self, data):
"""Конвертирует входящий JSON в плоскую структуру для модели"""
# 1. Агрегируем данные из списка items
items = data.get('items', [])
item_names = [str(i.get('name', '')) for i in items]
item_prices = [float(i.get('price', 0)) for i in items]
items_text = " ".join(item_names)
items_count = len(items)
items_total_price = sum(item_prices)
items_max_price = max(item_prices) if item_prices else 0
items_min_price = min(item_prices) if item_prices else 0
# 2. Формируем ту самую склеенную колонку 'text'
# Важно: используй тот же формат, что был в train.csv
combined_text = f"{data.get('terminal_name', '')} {data.get('terminal_description', '')} {data.get('city', '')} items {items_text}"
# 3. Собираем финальный словарь
flat_data = {
'terminal_name': str(data.get('terminal_name', '')),
'terminal_description': str(data.get('terminal_description', '')),
'terminal_city': str(data.get('city', '')), # city -> terminal_city
'items_text': items_text,
'text': combined_text.lower(),
'amount': float(data.get('amount', 0)),
'items_count': float(items_count),
'items_total_price': float(items_total_price),
'items_max_price': float(items_max_price),
'items_min_price': float(items_min_price),
'terminal_id': 'unknown' # В запросе нет ID, ставим заглушку
}
return flat_data
def predict(self, raw_json):
# Если пришла одна транзакция, оборачиваем в список
if isinstance(raw_json, dict):
raw_json = [raw_json]
# Препроцессинг всех транзакций в списке
processed_data = [self._preprocess_json(t) for t in raw_json]
df = pd.DataFrame(processed_data)
# Проверка порядка колонок
df = df[self.feature_order]
# Предсказание
mcc_codes = self.model.predict(df)
probs = self.model.predict_proba(df)
results = []
for i in range(len(raw_json)):
results.append({
"transaction_id": raw_json[i].get('transaction_id'),
"mcc": int(mcc_codes[i][0]),
"confidence": round(float(np.max(probs[i])), 4)
})
return results
# --- ТЕСТ ---
predictor = MCCPredictor('mcc_classifier.cbm')
request_data = {
"transaction_id": "TX00001116",
"terminal_name": "STORE001",
"terminal_description": "common common common thing",
"city": "NYC",
"amount": 272.80,
"items": [
{"name": "basic loyalty", "price": 58.20},
{"name": "Bringiong item lifes", "price": 28.99},
{"name": "regular item basic item", "price": 56.91}
]
}
res = predictor.predict(request_data)
print(res)