# mlops/dih.py
#
# 93 lines
# 3.7 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import numpy as np
from catboost import CatBoostClassifier
class MCCPredictor:
    """Predict MCC codes for transactions with a pre-trained CatBoost model.

    The model is loaded from ``model_path`` on construction. Incoming
    transactions are plain dicts (parsed JSON) that are flattened into the
    feature layout listed in ``feature_order`` before inference.
    """

    def __init__(self, model_path='mcc_classifier.cbm'):
        """Load the CatBoost model stored at *model_path*."""
        self.model = CatBoostClassifier()
        self.model.load_model(model_path)
        # Column order must match the order used during training EXACTLY.
        self.feature_order = [
            'terminal_name', 'terminal_description', 'terminal_city', 'items_text', 'text',
            'amount', 'items_count', 'items_total_price', 'items_max_price', 'items_min_price',
            'terminal_id',
        ]

    @staticmethod
    def _to_str(value):
        """Return ``str(value)``, mapping ``None`` (JSON null) to ``''``."""
        return '' if value is None else str(value)

    @staticmethod
    def _to_float(value):
        """Return ``float(value)``, mapping ``None`` (JSON null) to ``0.0``."""
        return 0.0 if value is None else float(value)

    def _preprocess_json(self, data):
        """Flatten one incoming transaction dict into the model's feature row.

        Args:
            data: One transaction as parsed JSON. Reads ``terminal_name``,
                ``terminal_description``, ``city``, ``amount`` and ``items``
                (a list of dicts with ``name`` and ``price``). Missing or
                null fields fall back to ``''`` for text and ``0`` for numbers.

        Returns:
            A dict keyed by exactly the names in ``feature_order``.
        """
        # 1. Aggregate the list of purchased items; a null ``items`` counts as empty.
        items = data.get('items') or []
        item_names = [self._to_str(item.get('name')) for item in items]
        item_prices = [self._to_float(item.get('price')) for item in items]
        items_text = " ".join(item_names)

        terminal_name = self._to_str(data.get('terminal_name'))
        terminal_description = self._to_str(data.get('terminal_description'))
        # The request carries the city as ``city``; the model feature is ``terminal_city``.
        terminal_city = self._to_str(data.get('city'))

        # 2. Build the concatenated 'text' column.
        # NOTE: this format must stay identical to the one used to build train.csv.
        combined_text = f"{terminal_name} {terminal_description} {terminal_city} items {items_text}"

        # 3. Assemble the final feature row.
        return {
            'terminal_name': terminal_name,
            'terminal_description': terminal_description,
            'terminal_city': terminal_city,
            'items_text': items_text,
            'text': combined_text.lower(),
            'amount': self._to_float(data.get('amount')),
            'items_count': float(len(items)),
            'items_total_price': float(sum(item_prices)),
            'items_max_price': float(max(item_prices, default=0)),
            'items_min_price': float(min(item_prices, default=0)),
            # The request has no terminal ID, so a placeholder value is used.
            'terminal_id': 'unknown',
        }

    def predict(self, raw_json):
        """Predict the MCC code for one or more transactions.

        Args:
            raw_json: A single transaction dict, or an iterable of such dicts.

        Returns:
            A list with one dict per transaction, in input order, holding
            ``transaction_id`` (``None`` if absent), ``mcc`` (int) and
            ``confidence`` (the highest class probability, rounded to 4
            decimal places). An empty input yields ``[]`` without calling
            the model.
        """
        # A single transaction becomes a one-element batch; any other iterable
        # is materialized so it can be traversed more than once below.
        transactions = [raw_json] if isinstance(raw_json, dict) else list(raw_json)
        if not transactions:
            # pd.DataFrame([]) has no columns, so the column selection would raise KeyError.
            return []

        df = pd.DataFrame([self._preprocess_json(t) for t in transactions])
        # Enforce the training column order (raises KeyError if a feature is missing).
        df = df[self.feature_order]

        # The original code indexed predictions as [i][0], i.e. shape (n, 1);
        # flattening also tolerates a 1-D (n,) result.
        mcc_codes = np.asarray(self.model.predict(df)).reshape(-1)
        confidences = np.asarray(self.model.predict_proba(df)).max(axis=1)

        return [
            {
                "transaction_id": transaction.get('transaction_id'),
                "mcc": int(code),
                "confidence": round(float(confidence), 4),
            }
            for transaction, code, confidence in zip(transactions, mcc_codes, confidences)
        ]
# --- Smoke test ---
# Runs only when the file is executed directly, so importing MCCPredictor
# elsewhere does not load the model from the working directory and print.
if __name__ == "__main__":
    predictor = MCCPredictor('mcc_classifier.cbm')
    request_data = {
        "transaction_id": "TX00001116",
        "terminal_name": "STORE001",
        "terminal_description": "common common common thing",
        "city": "NYC",
        "amount": 272.80,
        "items": [
            {"name": "basic loyalty", "price": 58.20},
            {"name": "Bringiong item lifes", "price": 28.99},
            {"name": "regular item basic item", "price": 56.91},
        ],
    }
    res = predictor.predict(request_data)
    print(res)