Coverage for src / models / predict_model.py: 82%
56 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:30 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:30 +0000
1import logging
2import os
3from typing import Any, Dict
5import joblib
6import pandas as pd
7import torch
9from src.models.churn_mlp import ChurnMLP
10from src.schemas.data_schema import CustomerSchema
12# Configuração Padrão Corporativa de Logging
13logging.basicConfig(
14 level=logging.INFO,
15 format="%(asctime)s [%(levelname)s] %(message)s",
16 datefmt="%Y-%m-%d %H:%M:%S",
17)
18logger = logging.getLogger(__name__)
20# Limiar de Negócio definido na Etapa 2 (Custo Financeiro Ótimo)
21BUSINESS_THRESHOLD = 0.30
23# Variáveis Globais de Estado (Lazy Loading)
24_preprocessor = None
25_model = None
28def load_artifacts():
29 """
30 Carrega o preprocessor (Scikit-Learn) e os pesos (PyTorch) do disco.
31 Utiliza Lazy Loading para carregar apenas uma vez na inicialização da API.
32 """
33 global _preprocessor, _model
35 if _preprocessor is None:
36 logger.info("Carregando Preprocessor do disco...")
37 preprocessor_path = "models/preprocessor.joblib"
38 if not os.path.exists(preprocessor_path):
39 logger.error(f"Artefato {preprocessor_path} não encontrado.")
40 raise FileNotFoundError(
41 f"Artefato {preprocessor_path} não encontrado. Execute train_model.py primeiro."
42 )
43 _preprocessor = joblib.load(preprocessor_path)
45 if _model is None:
46 logger.info("Carregando Pesos da Rede Neural do disco...")
47 model_path = "models/churn_mlp.pth"
48 if not os.path.exists(model_path):
49 logger.error(f"Artefato {model_path} não encontrado.")
50 raise FileNotFoundError(
51 f"Artefato {model_path} não encontrado. Execute train_model.py primeiro."
52 )
54 # Recupera o número de features pós-OneHotEncoder (do preprocessor fitado)
55 state_dict = torch.load(
56 model_path, map_location=torch.device("cpu"), weights_only=True
57 )
58 input_dim = state_dict["fc1.weight"].shape[1]
60 _model = ChurnMLP(input_dim=input_dim)
61 _model.load_state_dict(state_dict)
62 _model.eval()
65def predict_churn(raw_data: Dict[str, Any]) -> Dict[str, Any]:
66 """
67 Função de inferência end-to-end.
68 Recebe um dicionário JSON com os dados do cliente, valida com Pandera, processa e aplica o threshold.
69 """
70 load_artifacts()
72 # 0. Converte para DataFrame
73 df = pd.DataFrame([raw_data])
75 # 1. Validação de Contrato (Pandera)
76 try:
77 CustomerSchema.validate(df)
78 logger.info("Validação do Schema Pandera passou com sucesso.")
79 except Exception as e:
80 logger.error(f"Falha na validação do Schema dos dados de entrada: {e}")
81 raise ValueError(f"Dados de entrada inválidos segundo o Schema. Detalhe: {e}")
83 # Tratamento específico que fizemos na fase de limpeza (TotalCharges missing)
84 if "TotalCharges" in df.columns:
85 df["TotalCharges"] = pd.to_numeric(df["TotalCharges"], errors="coerce")
86 df["TotalCharges"] = df["TotalCharges"].fillna(0.0)
88 # 2. Pipeline do Scikit-Learn
89 try:
90 X_tf = _preprocessor.transform(df)
91 except ValueError as e:
92 logger.error(f"Erro no ColumnTransformer: {e}")
93 raise ValueError(f"Erro no pré-processamento. Detalhe: {e}")
95 # 3. PyTorch Inference
96 X_tensor = torch.tensor(X_tf, dtype=torch.float32)
98 with torch.no_grad():
99 logits = _model(X_tensor)
100 probability = torch.sigmoid(logits).item()
102 # 4. Regra de Negócio (Threshold)
103 is_churn = bool(probability >= BUSINESS_THRESHOLD)
105 logger.info(
106 f"Inferência concluída. Probabilidade: {probability:.4f} | Churn: {is_churn}"
107 )
109 return {
110 "churn_prediction": is_churn,
111 "probability": round(probability, 4),
112 "threshold_used": BUSINESS_THRESHOLD,
113 }