Coverage for src / models / predict_model.py: 82%

56 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:30 +0000

1import logging 

2import os 

3from typing import Any, Dict 

4 

5import joblib 

6import pandas as pd 

7import torch 

8 

9from src.models.churn_mlp import ChurnMLP 

10from src.schemas.data_schema import CustomerSchema 

11 

12# Configuração Padrão Corporativa de Logging 

13logging.basicConfig( 

14 level=logging.INFO, 

15 format="%(asctime)s [%(levelname)s] %(message)s", 

16 datefmt="%Y-%m-%d %H:%M:%S", 

17) 

18logger = logging.getLogger(__name__) 

19 

20# Limiar de Negócio definido na Etapa 2 (Custo Financeiro Ótimo) 

21BUSINESS_THRESHOLD = 0.30 

22 

23# Variáveis Globais de Estado (Lazy Loading) 

24_preprocessor = None 

25_model = None 

26 

27 

28def load_artifacts(): 

29 """ 

30 Carrega o preprocessor (Scikit-Learn) e os pesos (PyTorch) do disco. 

31 Utiliza Lazy Loading para carregar apenas uma vez na inicialização da API. 

32 """ 

33 global _preprocessor, _model 

34 

35 if _preprocessor is None: 

36 logger.info("Carregando Preprocessor do disco...") 

37 preprocessor_path = "models/preprocessor.joblib" 

38 if not os.path.exists(preprocessor_path): 

39 logger.error(f"Artefato {preprocessor_path} não encontrado.") 

40 raise FileNotFoundError( 

41 f"Artefato {preprocessor_path} não encontrado. Execute train_model.py primeiro." 

42 ) 

43 _preprocessor = joblib.load(preprocessor_path) 

44 

45 if _model is None: 

46 logger.info("Carregando Pesos da Rede Neural do disco...") 

47 model_path = "models/churn_mlp.pth" 

48 if not os.path.exists(model_path): 

49 logger.error(f"Artefato {model_path} não encontrado.") 

50 raise FileNotFoundError( 

51 f"Artefato {model_path} não encontrado. Execute train_model.py primeiro." 

52 ) 

53 

54 # Recupera o número de features pós-OneHotEncoder (do preprocessor fitado) 

55 state_dict = torch.load( 

56 model_path, map_location=torch.device("cpu"), weights_only=True 

57 ) 

58 input_dim = state_dict["fc1.weight"].shape[1] 

59 

60 _model = ChurnMLP(input_dim=input_dim) 

61 _model.load_state_dict(state_dict) 

62 _model.eval() 

63 

64 

65def predict_churn(raw_data: Dict[str, Any]) -> Dict[str, Any]: 

66 """ 

67 Função de inferência end-to-end. 

68 Recebe um dicionário JSON com os dados do cliente, valida com Pandera, processa e aplica o threshold. 

69 """ 

70 load_artifacts() 

71 

72 # 0. Converte para DataFrame 

73 df = pd.DataFrame([raw_data]) 

74 

75 # 1. Validação de Contrato (Pandera) 

76 try: 

77 CustomerSchema.validate(df) 

78 logger.info("Validação do Schema Pandera passou com sucesso.") 

79 except Exception as e: 

80 logger.error(f"Falha na validação do Schema dos dados de entrada: {e}") 

81 raise ValueError(f"Dados de entrada inválidos segundo o Schema. Detalhe: {e}") 

82 

83 # Tratamento específico que fizemos na fase de limpeza (TotalCharges missing) 

84 if "TotalCharges" in df.columns: 

85 df["TotalCharges"] = pd.to_numeric(df["TotalCharges"], errors="coerce") 

86 df["TotalCharges"] = df["TotalCharges"].fillna(0.0) 

87 

88 # 2. Pipeline do Scikit-Learn 

89 try: 

90 X_tf = _preprocessor.transform(df) 

91 except ValueError as e: 

92 logger.error(f"Erro no ColumnTransformer: {e}") 

93 raise ValueError(f"Erro no pré-processamento. Detalhe: {e}") 

94 

95 # 3. PyTorch Inference 

96 X_tensor = torch.tensor(X_tf, dtype=torch.float32) 

97 

98 with torch.no_grad(): 

99 logits = _model(X_tensor) 

100 probability = torch.sigmoid(logits).item() 

101 

102 # 4. Regra de Negócio (Threshold) 

103 is_churn = bool(probability >= BUSINESS_THRESHOLD) 

104 

105 logger.info( 

106 f"Inferência concluída. Probabilidade: {probability:.4f} | Churn: {is_churn}" 

107 ) 

108 

109 return { 

110 "churn_prediction": is_churn, 

111 "probability": round(probability, 4), 

112 "threshold_used": BUSINESS_THRESHOLD, 

113 }