Coverage for src/models/train_model.py: 92%
106 statements
coverage.py v7.13.5, created at 2026-04-26 08:30 +0000
import logging
import os

import joblib
import mlflow
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

from src.features.build_features import clean_raw_data, get_preprocessor
from src.models.churn_mlp import ChurnMLP

# Corporate standard logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Champion hyperparameters (selected via grid search)
LEARNING_RATE = 0.001
BATCH_SIZE = 64
EPOCHS = 100
PATIENCE = 10


class ChurnDataset(Dataset):
    """Wrap preprocessed feature/label arrays as float32 tensors for PyTorch."""

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32).unsqueeze(1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


def main():
    logger.info("Starting modular training of the Telco Churn PyTorch MLP...")

    # 0. MLflow tracking setup
    mlflow.set_tracking_uri("sqlite:///mlflow.db")
    experiment_name = "churn_mlp_pytorch_modular"

    if not mlflow.get_experiment_by_name(experiment_name):
        mlflow.create_experiment(
            name=experiment_name, artifact_location="file:./mlruns"
        )
    mlflow.set_experiment(experiment_name)

    # Wrap the whole execution in an MLflow run
    with mlflow.start_run(run_name="Modular_Training_v1"):
        mlflow.log_param("learning_rate", LEARNING_RATE)
        mlflow.log_param("batch_size", BATCH_SIZE)
        mlflow.log_param("patience", PATIENCE)

        # 1. Load and clean the data
        data_path = "data/raw/dataset.csv"
        if not os.path.exists(data_path):
            logger.error(f"Dataset not found at {data_path}.")
            raise FileNotFoundError(
                f"Dataset not found at {data_path}. Run the download step first."
            )

        df = pd.read_csv(data_path)
        df = clean_raw_data(df)

        # Separate features/target and split into train and test sets
        X = df.drop(columns=["Churn", "customerID"])
        y = df["Churn"].map({"Yes": 1, "No": 0}).values

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, stratify=y, random_state=42
        )

        # 2. Preprocessing
        logger.info("Processing features via ColumnTransformer...")
        preprocessor = get_preprocessor(list(X.columns))

        X_train_tf = preprocessor.fit_transform(X_train)
        X_test_tf = preprocessor.transform(X_test)

        # Persist the fitted preprocessor
        os.makedirs("models", exist_ok=True)
        joblib.dump(preprocessor, "models/preprocessor.joblib")
        logger.info("Preprocessor saved to models/preprocessor.joblib")

        # 3. PyTorch data preparation
        train_dataset = ChurnDataset(X_train_tf, y_train)
        test_dataset = ChurnDataset(X_test_tf, y_test)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        # 4. Model initialization
        input_dim = X_train_tf.shape[1]
        model = ChurnMLP(input_dim=input_dim)

        # pos_weight for class balancing: weight the positive (churn) class by the
        # negative/positive ratio so BCEWithLogitsLoss compensates for the imbalance
        num_positives = np.sum(y_train == 1)
        num_negatives = np.sum(y_train == 0)
        pos_weight_val = num_negatives / num_positives
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([pos_weight_val], dtype=torch.float32)
        )
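        # Illustrative arithmetic (hypothetical numbers, not computed from this dataset):
        # if roughly 27% of training customers churned, pos_weight would be about
        # 73 / 27 ≈ 2.7, so each positive (churn) example contributes ~2.7x as much to the loss.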

        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

        # 5. Training loop with early stopping
        logger.info("Starting neural network training...")
        best_val_loss = float("inf")
        epochs_no_improve = 0
        final_epoch = 0

        for epoch in range(EPOCHS):
            model.train()
            train_loss = 0.0

            for batch_X, batch_y in train_loader:
                optimizer.zero_grad()
                outputs = model(batch_X)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item() * batch_X.size(0)

            train_loss /= len(train_loader.dataset)

            # Validation
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for batch_X, batch_y in test_loader:
                    outputs = model(batch_X)
                    loss = criterion(outputs, batch_y)
                    val_loss += loss.item() * batch_X.size(0)

            val_loss /= len(test_loader.dataset)

            # Log epoch metrics to MLflow
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("val_loss", val_loss, step=epoch)

            # Track epochs actually run (also correct when training reaches EPOCHS without early stopping)
            final_epoch = epoch + 1

            # Early stopping check
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
                torch.save(model.state_dict(), "models/churn_mlp.pth")
            else:
                epochs_no_improve += 1

            if (epoch + 1) % 10 == 0:
                logger.info(
                    f"Epoch {epoch + 1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}"
                )

            if epochs_no_improve >= PATIENCE:
                logger.warning(f"Early stopping triggered at epoch {epoch + 1}!")
                break
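        # Note: the checkpoint at models/churn_mlp.pth holds the weights from the
        # epoch with the lowest validation loss, while `model` in memory still holds
        # the weights from the last epoch run. To continue with the best checkpoint,
        # reload it with model.load_state_dict(torch.load("models/churn_mlp.pth")).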

        # Log the final epoch count and the best validation loss for this run
        mlflow.log_metric("epochs_run", final_epoch)
        mlflow.log_metric("best_val_loss", best_val_loss)

        logger.info(
            "Training complete. Neural network weights saved to models/churn_mlp.pth"
        )


if __name__ == "__main__":
    main()
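
As a quick sketch of how the artifacts written by this script could be reused for inference. It assumes the ChurnMLP(input_dim=...) constructor and the clean_raw_data helper behave as they do above, and uses a hypothetical new_customers.csv as input:

import joblib
import pandas as pd
import torch

from src.features.build_features import clean_raw_data
from src.models.churn_mlp import ChurnMLP

# Load the fitted preprocessor and the best checkpoint written by train_model.py
preprocessor = joblib.load("models/preprocessor.joblib")

new_df = pd.read_csv("data/raw/new_customers.csv")  # hypothetical scoring file
new_df = clean_raw_data(new_df)
X_new = new_df.drop(columns=["Churn", "customerID"], errors="ignore")
X_new_tf = preprocessor.transform(X_new)

model = ChurnMLP(input_dim=X_new_tf.shape[1])
model.load_state_dict(torch.load("models/churn_mlp.pth"))
model.eval()

with torch.no_grad():
    # The network outputs raw logits (it was trained with BCEWithLogitsLoss),
    # so apply a sigmoid to obtain churn probabilities.
    logits = model(torch.tensor(X_new_tf, dtype=torch.float32))
    churn_probability = torch.sigmoid(logits).squeeze(1)

print(churn_probability)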