93 lines
3.7 KiB
Python
93 lines
3.7 KiB
Python
import pandas as pd
|
||
import numpy as np
|
||
from catboost import CatBoostClassifier
|
||
|
||
class MCCPredictor:
|
||
def __init__(self, model_path='mcc_classifier.cbm'):
|
||
self.model = CatBoostClassifier()
|
||
self.model.load_model(model_path)
|
||
|
||
# Порядок колонок должен быть В ТОЧНОСТИ как при обучении
|
||
self.feature_order = [
|
||
'terminal_name', 'terminal_description', 'terminal_city', 'items_text', 'text',
|
||
'amount', 'items_count', 'items_total_price', 'items_max_price', 'items_min_price',
|
||
'terminal_id'
|
||
]
|
||
|
||
def _preprocess_json(self, data):
|
||
"""Конвертирует входящий JSON в плоскую структуру для модели"""
|
||
|
||
# 1. Агрегируем данные из списка items
|
||
items = data.get('items', [])
|
||
item_names = [str(i.get('name', '')) for i in items]
|
||
item_prices = [float(i.get('price', 0)) for i in items]
|
||
|
||
items_text = " ".join(item_names)
|
||
items_count = len(items)
|
||
items_total_price = sum(item_prices)
|
||
items_max_price = max(item_prices) if item_prices else 0
|
||
items_min_price = min(item_prices) if item_prices else 0
|
||
|
||
# 2. Формируем ту самую склеенную колонку 'text'
|
||
# Важно: используй тот же формат, что был в train.csv
|
||
combined_text = f"{data.get('terminal_name', '')} {data.get('terminal_description', '')} {data.get('city', '')} items {items_text}"
|
||
|
||
# 3. Собираем финальный словарь
|
||
flat_data = {
|
||
'terminal_name': str(data.get('terminal_name', '')),
|
||
'terminal_description': str(data.get('terminal_description', '')),
|
||
'terminal_city': str(data.get('city', '')), # city -> terminal_city
|
||
'items_text': items_text,
|
||
'text': combined_text.lower(),
|
||
'amount': float(data.get('amount', 0)),
|
||
'items_count': float(items_count),
|
||
'items_total_price': float(items_total_price),
|
||
'items_max_price': float(items_max_price),
|
||
'items_min_price': float(items_min_price),
|
||
'terminal_id': 'unknown' # В запросе нет ID, ставим заглушку
|
||
}
|
||
return flat_data
|
||
|
||
def predict(self, raw_json):
|
||
# Если пришла одна транзакция, оборачиваем в список
|
||
if isinstance(raw_json, dict):
|
||
raw_json = [raw_json]
|
||
|
||
# Препроцессинг всех транзакций в списке
|
||
processed_data = [self._preprocess_json(t) for t in raw_json]
|
||
df = pd.DataFrame(processed_data)
|
||
|
||
# Проверка порядка колонок
|
||
df = df[self.feature_order]
|
||
|
||
# Предсказание
|
||
mcc_codes = self.model.predict(df)
|
||
probs = self.model.predict_proba(df)
|
||
|
||
results = []
|
||
for i in range(len(raw_json)):
|
||
results.append({
|
||
"transaction_id": raw_json[i].get('transaction_id'),
|
||
"mcc": int(mcc_codes[i][0]),
|
||
"confidence": round(float(np.max(probs[i])), 4)
|
||
})
|
||
return results
|
||
|
||
# --- ТЕСТ ---
|
||
predictor = MCCPredictor('mcc_classifier.cbm')
|
||
|
||
request_data = {
|
||
"transaction_id": "TX00001116",
|
||
"terminal_name": "STORE001",
|
||
"terminal_description": "common common common thing",
|
||
"city": "NYC",
|
||
"amount": 272.80,
|
||
"items": [
|
||
{"name": "basic loyalty", "price": 58.20},
|
||
{"name": "Bringiong item lifes", "price": 28.99},
|
||
{"name": "regular item basic item", "price": 56.91}
|
||
]
|
||
}
|
||
|
||
res = predictor.predict(request_data)
|
||
print(res) |