app v1
This commit is contained in:
parent
694975742c
commit
d9abed5c15
125
dih.py
125
dih.py
|
|
@ -1,125 +0,0 @@
|
||||||
import time
|
|
||||||
import joblib
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import random
|
|
||||||
|
|
||||||
from sklearn.base import BaseEstimator, TransformerMixin
|
|
||||||
|
|
||||||
|
|
||||||
# ===== CRITICAL =====
# These classes MUST be defined before joblib.load(): the pickled
# pipeline references them by name and unpickling fails otherwise.

class TextExtractor(BaseEstimator, TransformerMixin):
    """Stateless transformer that selects the 'full_text' column.

    Missing values are replaced with the empty string so downstream
    text vectorizers never see NaN.
    """

    def fit(self, X, y=None):
        # Nothing to learn; present only to satisfy the sklearn API.
        return self

    def transform(self, X):
        texts = X['full_text']
        return texts.fillna('')
|
|
||||||
|
|
||||||
|
|
||||||
class NumberExtractor(BaseEstimator, TransformerMixin):
    """Stateless transformer that selects the numeric 'amount' column.

    Returns a single-column DataFrame (note the double brackets) with
    NaN amounts replaced by 0, ready for a numeric scaler.
    """

    def fit(self, X, y=None):
        # No fitting required.
        return self

    def transform(self, X):
        amounts = X[['amount']]
        return amounts.fillna(0)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================

# Path to the serialized sklearn pipeline loaded below.
MODEL_PATH = "mcc_model.pkl"

# Template transaction used for warmup and the single-request benchmark;
# make_batch() copies it and randomizes the id and amount.
BASE_TX = {
    "transaction_id": "TX00001116",
    "terminal_name": "STORE001",
    "terminal_description": "common common common thing",
    "city": "NYC",
    "amount": 272.80,
    "items": [
        {"name": "basic loyalty", "price": 58.20},
        {"name": "Bringiong item lifes", "price": 28.99},
        {"name": "regular item basic item", "price": 56.91}
    ]
}
|
|
||||||
|
|
||||||
|
|
||||||
def json_to_df(data_list):
    """Convert raw transaction dicts into the two-column frame the model expects.

    Each row gets:
      * full_text -- lower-cased "terminal_name terminal_description item names"
      * amount    -- transaction amount as float (0 when absent)
    """
    def _row(d):
        names = " ".join(str(item.get("name", "")) for item in d.get("items", []))
        text = f"{d.get('terminal_name', '')} {d.get('terminal_description', '')} {names}"
        return {"full_text": text.lower(), "amount": float(d.get("amount", 0))}

    return pd.DataFrame([_row(d) for d in data_list])
|
|
||||||
|
|
||||||
|
|
||||||
def make_batch(n=100):
    """Build n synthetic transactions by cloning BASE_TX.

    Each clone gets a unique id and a random amount in [3, 500);
    the 'items' list is shared between clones (shallow copy), which is
    fine here because the benchmark never mutates it.
    """
    out = []
    for idx in range(n):
        clone = dict(BASE_TX)
        clone["transaction_id"] = f"TX{idx:06d}"
        clone["amount"] = round(random.uniform(3, 500), 2)
        out.append(clone)
    return out
|
|
||||||
|
|
||||||
|
|
||||||
# ============================
# BENCH
# ============================

# Cold start: measures joblib deserialization of the full pipeline.
print("\n[1] LOADING MODEL (cold start)...")
t0 = time.perf_counter()
model = joblib.load(MODEL_PATH)
t1 = time.perf_counter()
print(f"Cold load time: {t1 - t0:.4f} sec")

# warmup
# A few throwaway predictions so lazy initialization does not skew timings.
print("\n[2] WARMUP...")
warm_df = json_to_df([BASE_TX])
for _ in range(5):
    model.predict(warm_df)

# single inference benchmark
# DataFrame construction is deliberately OUTSIDE the timed section:
# only predict_proba is measured.
print("\n[3] SINGLE REQUEST BENCH (1000 runs)...")
times = []

for _ in range(1000):
    df = json_to_df([BASE_TX])
    t0 = time.perf_counter()
    model.predict_proba(df)
    times.append(time.perf_counter() - t0)

times = np.array(times)

print(f"avg : {times.mean()*1000:.2f} ms")
print(f"p95 : {np.percentile(times,95)*1000:.2f} ms")
print(f"max : {times.max()*1000:.2f} ms")

# batch benchmark
print("\n[4] BATCH 100 BENCH...")
batch = make_batch(100)
df_batch = json_to_df(batch)

t0 = time.perf_counter()
model.predict_proba(df_batch)
dt = time.perf_counter() - t0

print(f"Batch 100 time: {dt:.3f} sec")

# verdict
# Thresholds: 200 ms p95 for /predict, 5 s for a 100-item batch.
print("\n[5] VERDICT")
if np.percentile(times,95) < 0.2:
    print("✅ /predict проходит по latency (<200ms p95)")
else:
    print("❌ /predict НЕ проходит по latency")

if dt < 5:
    print("✅ /predict/batch проходит (<5s)")
else:
    print("❌ /predict/batch НЕ проходит")
|
|
||||||
Binary file not shown.
BIN
mcc_model.pkl
BIN
mcc_model.pkl
Binary file not shown.
100
prepare_data.py
100
prepare_data.py
|
|
@ -1,100 +0,0 @@
|
||||||
import pandas as pd
|
|
||||||
import re
|
|
||||||
|
|
||||||
DATA_DIR = "data"  # directory holding the raw CSV inputs
|
|
||||||
|
|
||||||
# ---------- Text cleaning ----------
def clean_text(s: str, max_len=1000):
    """Normalize free-form merchant text for TF-IDF.

    Lower-cases, strips a leading t"/t' export artifact, replaces quotes
    and all non-alphanumeric characters with spaces, collapses consecutive
    duplicate words, squeezes whitespace and truncates to max_len
    characters. Non-string input yields "".
    """
    if not isinstance(s, str):
        return ""

    s = s.lower()

    # Drop the stray t" / t' prefix left over by the export.
    s = re.sub(r'^t["\']', '', s)

    # Unify every kind of quote into a plain space.
    for quote in ('"', "'", '`'):
        s = s.replace(quote, ' ')

    # Strip non-ASCII (emoji, CJK characters, etc.).
    s = re.sub(r'[^\x00-\x7F]+', ' ', s)

    # Keep only letters, digits and whitespace.
    s = re.sub(r'[^a-z0-9\s]', ' ', s)

    # Collapse runs of the same word ("pizza pizza pizza" -> "pizza").
    deduped = []
    for word in s.split():
        if not deduped or deduped[-1] != word:
            deduped.append(word)
    s = " ".join(deduped)

    # Normalize whitespace and cap the length.
    s = re.sub(r'\s+', ' ', s).strip()
    return s[:max_len]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------- LOAD ----------
transactions = pd.read_csv(f"{DATA_DIR}/transactions.csv")
terminals = pd.read_csv(f"{DATA_DIR}/terminals.csv")
receipts = pd.read_csv(f"{DATA_DIR}/receipts.csv")

# ---------- CLEAN TEXT ----------
# terminal_city may contain non-string values, hence the extra astype(str).
# NOTE(review): astype(str) turns NaN into the literal string "nan", which
# clean_text keeps — confirm this is acceptable downstream.
for col in ["terminal_name", "terminal_description", "terminal_city"]:
    if col == "terminal_city":
        terminals[col] = terminals[col].astype(str).apply(clean_text)
    else:
        terminals[col] = terminals[col].apply(clean_text)

receipts["item_name"] = receipts["item_name"].apply(clean_text)

# ---------- AGGREGATE RECEIPTS ----------
# One row per transaction: concatenated item names plus count/price stats.
receipt_agg = receipts.groupby("transaction_id").agg(
    items_text=("item_name", lambda x: " ".join(x)),
    items_count=("item_name", "count"),
    items_total_price=("item_price", "sum"),
    items_max_price=("item_price", "max"),
    items_min_price=("item_price", "min"),
).reset_index()

# ---------- MERGE WITH TRANSACTIONS ----------
df = transactions[["transaction_id", "terminal_id", "amount", "true_mcc"]].merge(
    terminals[["terminal_id", "terminal_name", "terminal_description", "terminal_city"]],
    on="terminal_id",
    how="left"
)

df = df.merge(receipt_agg, on="transaction_id", how="left")

# ---------- FILL NA ----------
# Transactions without receipts/terminals get empty text and zero stats.
for col in ["items_text", "terminal_name", "terminal_description", "terminal_city"]:
    df[col] = df[col].fillna("")

for col in ["items_count", "items_total_price", "items_max_price", "items_min_price"]:
    df[col] = df[col].fillna(0)

# ---------- BUILD FINAL TEXT ----------
# NOTE(review): items_text appears twice — presumably to up-weight item
# names relative to terminal metadata; TODO confirm this is intentional.
df["text"] = (
    df["terminal_name"] + " " +
    df["terminal_description"] + " " +
    df["terminal_city"] + " " +
    " items " + df["items_text"] + " items " +
    df["items_text"]
)

df["text"] = df["text"].apply(clean_text)

# ---------- FINAL CHECK ----------
print("rows:", len(df))
print("unique tx:", df["transaction_id"].nunique())
print(df["true_mcc"].value_counts())

# Sanity: exactly one row per transaction and no empty training texts.
assert len(df) == df["transaction_id"].nunique()
assert df["text"].str.len().min() > 0

# ---------- SAVE ----------
df.to_csv("train.csv", index=False)
print("saved train.csv")
|
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
FROM python:3.12-slim

WORKDIR /app

# Install dependencies first so this layer stays cached across code changes.
COPY solution/req.txt .
RUN pip install --no-cache-dir -r req.txt

COPY ./solution .

EXPOSE 8080

# Fix: dropped --reload — it is a development-only flag that watches the
# filesystem and restarts the worker on changes; in a production image it
# wastes CPU and can reload the model unexpectedly.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
|
||||||
|
|
||||||
|
|
@ -0,0 +1,131 @@
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
from typing import List, Optional
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import joblib
|
||||||
|
import re
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
# Shared rich console for structured server-side logging.
console = Console()

MAX_TEXT_LEN = 10000  # maximum text length fed to the model
|
||||||
|
|
||||||
|
# Pydantic models for request validation and graceful error handling.
class Item(BaseModel):
    """A single receipt line; tolerant defaults so partial payloads validate."""

    # name defaults to empty so items without a name still validate
    name: str = ""
    # price defaults to 0.0; negative/null prices are clamped below
    price: float = 0.0

    # pre=True: runs before pydantic's own float coercion, so explicit
    # nulls and negative values are clamped to 0 instead of failing.
    @validator("price", pre=True)
    def fix_negative_price(cls, v):
        return max(0, float(v) if v is not None else 0)
|
||||||
|
|
||||||
|
class Transaction(BaseModel):
    """Incoming transaction; optional fields fall back to placeholders.

    NOTE(review): items may be explicitly null in the payload (Optional),
    so consumers of tx.items must handle None as well as [].
    """

    transaction_id: str
    terminal_name: Optional[str] = "unknown"
    terminal_description: Optional[str] = "unknown"
    city: Optional[str] = "unknown"
    amount: float = 0.0
    # Pydantic copies field defaults per instance, so the mutable [] is safe here.
    items: Optional[List[Item]] = []

    # Clamp null/negative amounts to 0 before standard float coercion.
    @validator("amount", pre=True)
    def fix_negative_amount(cls, v):
        return max(0, float(v) if v is not None else 0)
|
||||||
|
|
||||||
|
class BatchRequest(BaseModel):
    """Envelope for /predict/batch: the list of transactions to score."""

    transactions: List[Transaction]
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()

# Load the model once at import time. On failure the server still starts;
# endpoints check MODEL_READY and answer 503 instead of crashing.
try:
    model = joblib.load("model/mcc_model.pkl")
    console.log("[green]Успешно загрузили модель[/green]")
    MODEL_READY = True
except Exception as e:
    console.log(f"[red]Ошибка загрузки модели: {e}[/red]")
    MODEL_READY = False
|
||||||
|
|
||||||
|
# Pre-process transactions carrying the custom types so prediction does not fall over.
def clean_text(text: str) -> str:
    """Replace disallowed characters with spaces, squeeze whitespace, cap length."""
    allowed = re.sub(r"[^\w\s\.,'-]", " ", str(text))
    collapsed = re.sub(r"\s+", " ", allowed).strip()
    return collapsed[:MAX_TEXT_LEN]
|
||||||
|
|
||||||
|
def preprocess_transaction(tx: Transaction) -> pd.DataFrame:
    """Turn a validated Transaction into the single-row DataFrame the model expects.

    Fix: tx.items is Optional and may be None (explicit JSON null); the
    original called len(tx.items) and iterated it for max/min unguarded,
    raising TypeError. All item-derived features now go through a single
    normalized list instead.
    """
    items = tx.items or []  # treat explicit null the same as an empty list

    terminal_name = clean_text(tx.terminal_name)
    terminal_desc = clean_text(tx.terminal_description)
    city = clean_text(tx.city)

    item_names = [clean_text(i.name) for i in items]
    items_text = " ".join(item_names)

    # Same text recipe as at training time; hard-capped at MAX_TEXT_LEN chars.
    combined_text = f"{terminal_name} {terminal_desc} {city} items {items_text}".lower()[:MAX_TEXT_LEN]

    prices = [i.price for i in items]
    df = pd.DataFrame([{
        "full_text": combined_text,
        "amount": tx.amount,
        "items_text": items_text,
        "items_count": len(items),
        "items_total_price": sum(prices),
        "items_max_price": max(prices, default=0),
        "items_min_price": min(prices, default=0)
    }])
    return df
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
def health():
    """Liveness probe: 200 when the model loaded, 503 otherwise."""
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")
    return {"status": "ok"}
|
||||||
|
|
||||||
|
@app.get("/model/info")
def info():
    """Static model metadata for clients and monitoring."""
    return {
        "model_name": "mcc_classifier",
        "model_version": "1.0.67",
    }
|
||||||
|
|
||||||
|
@app.post("/predict")
def predict(tx: Transaction):
    """Score one transaction.

    Returns the predicted MCC plus the model's top-class probability.
    Raises 503 when the model failed to load, 400 when scoring raises.
    """
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")

    try:
        features = preprocess_transaction(tx)
        label = model.predict(features)[0]
        probability = model.predict_proba(features).max()
        console.log(f"[yellow]Предсказание для транзакции с id {tx.transaction_id}: {label} (conf={probability:.4f})[/yellow]")
        response = {
            "transaction_id": tx.transaction_id,
            "predicted_mcc": int(label),
            "confidence": round(float(probability), 4),
        }
        return response
    except Exception as e:
        console.log(f"[red]Ошибка предсказания для id {tx.transaction_id}: {e}[/red]")
        raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
@app.post("/predict/batch")
def predict_batch(batch: BatchRequest):
    """Score a list of transactions.

    Per-item failures are logged and reported as predicted_mcc=None with
    confidence 0.0 instead of failing the whole batch.
    Raises 503 when the model failed to load.
    """
    if not MODEL_READY:
        raise HTTPException(status_code=503, detail="Model not up")

    outcomes = []
    for tx in batch.transactions:
        # Start from the failure shape; overwrite both fields atomically
        # only after every conversion has succeeded.
        record = {
            "transaction_id": tx.transaction_id,
            "predicted_mcc": None,
            "confidence": 0.0,
        }
        try:
            features = preprocess_transaction(tx)
            label = model.predict(features)[0]
            probability = model.predict_proba(features).max()
            mcc = int(label)
            conf_value = round(float(probability), 4)
            record["predicted_mcc"] = mcc
            record["confidence"] = conf_value
            outcomes.append(record)
            console.log(f"[yellow]Предсказание для id {tx.transaction_id}: {label} (conf={probability:.4f})[/yellow]")
        except Exception as e:
            console.log(f"[red]Ошибка предсказания для id {tx.transaction_id}: {e}[/red]")
            outcomes.append(record)
    return {"predictions": outcomes}
|
||||||
Binary file not shown.
|
|
@ -16,8 +16,10 @@ joblib==1.5.3
|
||||||
jupyter-client==8.8.0
|
jupyter-client==8.8.0
|
||||||
jupyter-core==5.9.1
|
jupyter-core==5.9.1
|
||||||
kiwisolver==1.4.9
|
kiwisolver==1.4.9
|
||||||
|
markdown-it-py==4.0.0
|
||||||
matplotlib==3.10.8
|
matplotlib==3.10.8
|
||||||
matplotlib-inline==0.2.1
|
matplotlib-inline==0.2.1
|
||||||
|
mdurl==0.1.2
|
||||||
narwhals==2.15.0
|
narwhals==2.15.0
|
||||||
nest-asyncio==1.6.0
|
nest-asyncio==1.6.0
|
||||||
numpy==2.4.1
|
numpy==2.4.1
|
||||||
|
|
@ -38,6 +40,7 @@ pyparsing==3.3.2
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
pytz==2025.2
|
pytz==2025.2
|
||||||
pyzmq==27.1.0
|
pyzmq==27.1.0
|
||||||
|
rich==14.2.0
|
||||||
scikit-learn==1.8.0
|
scikit-learn==1.8.0
|
||||||
scipy==1.17.0
|
scipy==1.17.0
|
||||||
six==1.17.0
|
six==1.17.0
|
||||||
|
|
@ -9,7 +9,7 @@ from sklearn.pipeline import Pipeline, FeatureUnion
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
from sklearn.base import BaseEstimator, TransformerMixin
|
from sklearn.base import BaseEstimator, TransformerMixin
|
||||||
from sklearn.metrics import accuracy_score, classification_report
|
from sklearn.metrics import accuracy_score, classification_report
|
||||||
import rich # <!-- дебаг через rich - моя guilty pleasure, очень уж люблю на красивые выводы смотреть
|
import rich # <!-- дебаг через rich - моя guilty pleasure, очень уж люблю на красивые выводы смотреть
|
||||||
from rich import print as rpint
|
from rich import print as rpint
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich import box
|
from rich import box
|
||||||
|
|
@ -24,16 +24,16 @@ console = Console()
|
||||||
# Для преобразования TF-IDF в вектора
|
# Для преобразования TF-IDF в вектора
|
||||||
class TextExtractor(BaseEstimator, TransformerMixin):
|
class TextExtractor(BaseEstimator, TransformerMixin):
|
||||||
def fit(self, X, y=None): return self
|
def fit(self, X, y=None): return self
|
||||||
|
|
||||||
def transform(self, X):
|
def transform(self, X):
|
||||||
return X['full_text'].fillna('')
|
return X['full_text'].fillna('')
|
||||||
|
|
||||||
# Для StandardScaler
|
# Для StandardScaler
|
||||||
class NumberExtractor(BaseEstimator, TransformerMixin):
|
class NumberExtractor(BaseEstimator, TransformerMixin):
|
||||||
def fit(self, X, y=None): return self
|
def fit(self, X, y=None): return self
|
||||||
|
|
||||||
def transform(self, X):
|
def transform(self, X):
|
||||||
return X[['amount']].fillna(0)
|
return X[['amount']].fillna(0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
|
|
@ -43,24 +43,26 @@ def train_model():
|
||||||
terminals = pd.read_csv('data/terminals.csv')
|
terminals = pd.read_csv('data/terminals.csv')
|
||||||
receipts = pd.read_csv('data/receipts.csv')
|
receipts = pd.read_csv('data/receipts.csv')
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
console.log(f"Файлы для обучения не найдены :( \n {e}", style="white on red")
|
console.log(
|
||||||
|
f"Файлы для обучения не найдены :( \n {e}", style="white on red")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
console.log("[yellow]Предобрабатываем данные...[/yellow]")
|
console.log("[yellow]Предобрабатываем данные...[/yellow]")
|
||||||
# Приклеиваеем вместе имена товаров
|
# Приклеиваеем вместе имена товаров
|
||||||
receipts_agg = receipts.groupby('transaction_id')['item_name'].apply(
|
receipts_agg = receipts.groupby('transaction_id')['item_name'].apply(
|
||||||
lambda x: ' '.join(str(i) for i in x)
|
lambda x: ' '.join(str(i) for i in x)
|
||||||
).reset_index()
|
).reset_index()
|
||||||
|
|
||||||
# Делаем один большой датафрейм с которым будем работать
|
# Делаем один большой датафрейм с которым будем работать
|
||||||
df = tx.merge(terminals[['terminal_id', 'terminal_name', 'terminal_description']], on='terminal_id', how='left')
|
df = tx.merge(terminals[['terminal_id', 'terminal_name',
|
||||||
|
'terminal_description']], on='terminal_id', how='left')
|
||||||
df = df.merge(receipts_agg, on='transaction_id', how='left')
|
df = df.merge(receipts_agg, on='transaction_id', how='left')
|
||||||
|
|
||||||
# Делаем текстовое поле для TF-IDF
|
# Делаем текстовое поле для TF-IDF
|
||||||
df['full_text'] = (
|
df['full_text'] = (
|
||||||
df['terminal_name'].astype(str) + " " +
|
df['terminal_name'].astype(str) + " " +
|
||||||
df['terminal_description'].astype(str) + " " + # <!-- изначально я пробовал клеить id транзакции, однако модель слишком на ней зацикливалась
|
|
||||||
|
df['terminal_description'].astype(str) + " " + # <!-- изначально я пробовал клеить id транзакции, однако модель слишком на ней зацикливалась
|
||||||
df['item_name'].astype(str)
|
df['item_name'].astype(str)
|
||||||
).str.lower()
|
).str.lower()
|
||||||
|
|
||||||
|
|
@ -102,13 +104,12 @@ def train_model():
|
||||||
|
|
||||||
# Валидация
|
# Валидация
|
||||||
|
|
||||||
console.log("[yellow]Оцениваем качество на валидационной выборке...[/yellow]")
|
console.log(
|
||||||
|
"[yellow]Оцениваем качество на валидационной выборке...[/yellow]")
|
||||||
X_train, X_test, y_train, y_test = train_test_split(
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
X, y, test_size=0.2, random_state=42, stratify=y
|
X, y, test_size=0.2, random_state=42, stratify=y
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
pipeline.fit(X_train, y_train)
|
pipeline.fit(X_train, y_train)
|
||||||
y_pred = pipeline.predict(X_test)
|
y_pred = pipeline.predict(X_test)
|
||||||
probs = pipeline.predict_proba(X_test)
|
probs = pipeline.predict_proba(X_test)
|
||||||
|
|
@ -118,7 +119,6 @@ def train_model():
|
||||||
|
|
||||||
table = Table(box=box.ROUNDED, title="Отчёт")
|
table = Table(box=box.ROUNDED, title="Отчёт")
|
||||||
|
|
||||||
|
|
||||||
table.add_column("Метрика", justify="center", style="yellow")
|
table.add_column("Метрика", justify="center", style="yellow")
|
||||||
table.add_column("Значение", justify="center", style="yellow")
|
table.add_column("Значение", justify="center", style="yellow")
|
||||||
|
|
||||||
|
|
@ -129,7 +129,7 @@ def train_model():
|
||||||
console.print("[yellow]Репорт по классам[/yellow]", justify="center")
|
console.print("[yellow]Репорт по классам[/yellow]", justify="center")
|
||||||
console.print(classification_report(y_test, y_pred), justify="center")
|
console.print(classification_report(y_test, y_pred), justify="center")
|
||||||
|
|
||||||
# Метрики норм, учимся на всех данных и сохранем данные
|
# Метрики норм, учимся на всех данных и сохранем модель
|
||||||
|
|
||||||
console.log("[yellow]Учимся на всем, что есть...[/yellow]")
|
console.log("[yellow]Учимся на всем, что есть...[/yellow]")
|
||||||
pipeline.fit(X, y)
|
pipeline.fit(X, y)
|
||||||
|
|
@ -138,6 +138,8 @@ def train_model():
|
||||||
joblib.dump(pipeline, 'solution/model/mcc_model.pkl')
|
joblib.dump(pipeline, 'solution/model/mcc_model.pkl')
|
||||||
console.log(Markdown("Сохранили модель в **solution/model/mcc_model.pkl**"))
|
console.log(Markdown("Сохранили модель в **solution/model/mcc_model.pkl**"))
|
||||||
|
|
||||||
with console.status("Учим модель..."):
|
|
||||||
train_model()
|
if __name__ == "__main__":
|
||||||
console.print(Markdown("*Модель готова :)*"))
|
with console.status("Учим модель..."):
|
||||||
|
train_model()
|
||||||
|
console.print(Markdown("*Модель готова :)*"))
|
||||||
Loading…
Reference in New Issue