"""Train a many-to-one LSTM regressor on pre-embedded, padded sequences.

DataLoaders are expected to yield dict batches with:
  - 'input_ids': float tensor, shape (batch, seq_len, input_dim)
  - 'labels':    float tensor, shape (batch,)
(assumed from the squeeze/MSELoss usage below — TODO confirm against the
dataset builder that produced the .pt files.)
"""
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

# tqdm is a nice-to-have progress bar; fall back to a no-op wrapper so the
# module still imports and runs where tqdm is not installed.
try:
    from tqdm import tqdm
except ImportError:  # pragma: no cover
    def tqdm(iterable, **_kwargs):
        return iterable


class LSTMNetwork(nn.Module):
    """Many-to-one LSTM: encode a sequence, regress from its last time step."""

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim, dropout=0.3):
        super().__init__()
        # NOTE: nn.LSTM applies `dropout` only *between* stacked layers,
        # so it is a no-op when num_layers == 1.
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers,
                            dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq_len, input_dim) -> (batch, output_dim)."""
        lstm_out, _ = self.lstm(x)
        # Regress from the hidden state at the final time step only.
        return self.fc(self.dropout(lstm_out[:, -1, :]))


def compute_metrics(predictions, labels):
    """Return (MSE, R^2) for 1-D sequences of predictions vs. targets.

    Implemented with numpy, matching sklearn.metrics.mean_squared_error /
    r2_score semantics, so the module has no hard sklearn dependency.
    """
    preds = np.asarray(predictions, dtype=np.float64)
    targets = np.asarray(labels, dtype=np.float64)
    mse = float(np.mean((targets - preds) ** 2))
    ss_res = float(np.sum((targets - preds) ** 2))
    ss_tot = float(np.sum((targets - targets.mean()) ** 2))
    if ss_tot == 0.0:
        # Constant targets: perfect fit -> 1.0, anything else -> 0.0
        # (same convention as sklearn.metrics.r2_score).
        r2 = 1.0 if ss_res == 0.0 else 0.0
    else:
        r2 = 1.0 - ss_res / ss_tot
    return mse, r2


def _evaluate(model, loader, criterion, device):
    """Run `model` over `loader` without gradients.

    Returns (avg_loss, predictions, labels) where predictions/labels are
    plain Python lists of floats.
    """
    model.eval()
    total_loss = 0.0
    preds, targets = [], []
    with torch.no_grad():
        for batch in loader:
            inputs = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            # squeeze(-1) (NOT squeeze()): a batch of size 1 keeps its batch
            # dimension, so shapes still match `labels` and .tolist() below
            # never sees a 0-d tensor.
            outputs = model(inputs).squeeze(-1)
            total_loss += criterion(outputs, labels).item()
            preds.extend(outputs.cpu().tolist())
            targets.extend(labels.cpu().tolist())
    return total_loss / len(loader), preds, targets


def train_model(model, train_loader, val_loader, test_loader, epochs=10, device='cuda'):
    """Train `model` with Adam + MSE, LR plateau scheduling and early stopping.

    Checkpoints the best model (by validation loss) to "best_lstm_model.pth".
    The original selected the checkpoint on *test-set* R^2, which leaks the
    test set into model choice; selection is now on validation loss, while
    test metrics are still tracked for reporting.

    Returns:
        dict with per-epoch lists: 'train_loss', 'val_loss',
        'test_r2', 'test_mse'.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # `verbose` is deprecated/removed in recent torch releases, so omitted.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    best_val_loss = float('inf')
    patience = 3  # epochs without val-loss improvement before stopping
    counter = 0
    history = {'train_loss': [], 'val_loss': [], 'test_r2': [], 'test_mse': []}

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        start_time = time.time()

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100):
            optimizer.zero_grad()
            inputs = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            # squeeze(-1) keeps the batch dimension for size-1 batches.
            outputs = model(inputs).squeeze(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_train_loss = total_loss / len(train_loader)

        avg_val_loss, _, _ = _evaluate(model, val_loader, criterion, device)
        _, test_preds, test_labels = _evaluate(model, test_loader, criterion, device)
        test_mse, test_r2 = compute_metrics(test_preds, test_labels)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['test_r2'].append(test_r2)
        history['test_mse'].append(test_mse)

        scheduler.step(avg_val_loss)

        epoch_time = time.time() - start_time
        print(f'Epoch {epoch+1}/{epochs} | Time: {epoch_time:.2f}s')
        print(f'Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}')
        print(f'Test MSE: {test_mse:.4f} | Test R2: {test_r2:.4f}\n')

        # Checkpointing and early stopping are both keyed on validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            counter = 0
            torch.save(model.state_dict(), "best_lstm_model.pth")
            print(f"🚀 New best model saved (val loss: {avg_val_loss:.4f})")
        else:
            counter += 1
            if counter >= patience:
                print("⛔ Early stopping triggered!")
                break

    return history


if __name__ == "__main__":
    data_path = '/content/drive/MyDrive/Colab Notebooks/ANLP_WS24_CA2/data/embedded_padded'
    # weights_only=False: these .pt files hold full Dataset objects, not bare
    # tensors; torch >= 2.6 defaults to weights_only=True and would refuse
    # them. torch.load unpickles arbitrary objects — only load trusted files.
    train_dataset = torch.load(f'{data_path}/train.pt', weights_only=False)
    test_dataset = torch.load(f'{data_path}/test.pt', weights_only=False)
    val_dataset = torch.load(f'{data_path}/val.pt', weights_only=False)

    input_dim = 100
    hidden_dim = 1024
    num_layers = 2
    output_dim = 1
    dropout = 0.2
    batch_size = 256
    epochs = 5

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LSTMNetwork(input_dim=input_dim, hidden_dim=hidden_dim,
                        num_layers=num_layers, output_dim=output_dim,
                        dropout=dropout).to(device)
    history = train_model(model, train_loader, val_loader, test_loader,
                          epochs=epochs, device=device)