import time
import json

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

import ml_helper
import ml_history


class ImprovedLSTMBinaryClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout=0.1, bidirectional=False):
        super(ImprovedLSTMBinaryClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers,
                            batch_first=True, dropout=dropout, bidirectional=bidirectional)
        lstm_out_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.layer_norm = nn.LayerNorm(lstm_out_dim)
        self.fc = nn.Linear(lstm_out_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids):
        input_ids = input_ids.long()
        embedded = self.embedding(input_ids)
        lstm_output, _ = self.lstm(embedded)
        # Use the hidden state of the last time step as the sequence representation
        pooled_output = lstm_output[:, -1, :]
        pooled_output = self.layer_norm(pooled_output)
        logits = self.fc(pooled_output)
        return self.sigmoid(logits)


if __name__ == "__main__":
    # Load the data
    data_path = 'data/idx_based_padded'
    train_dataset = torch.load(data_path + '/train.pt')
    test_dataset = torch.load(data_path + '/test.pt')
    val_dataset = torch.load(data_path + '/val.pt')

    # +2 for padding and unk tokens
    vocab_size = train_dataset.vocab_size + 2
    embed_dim = 100  # train_dataset.emb_dim

    # NOTE: Info comes from data explore notebook: 280 is max length,
    # 139 contains 80% and 192 contains 95% of the data
    max_len = 280

    device = ml_helper.get_device(verbose=True)

    # Model hyperparameters
    hidden_dim = 256
    num_layers = 2
    dropout = 0.3
    bidirectional = True  # Enable bidirectional LSTM

    model = ImprovedLSTMBinaryClassifier(vocab_size, embed_dim, hidden_dim, num_layers, dropout, bidirectional)

    # Training parameters
    epochs = 3
    batch_size = 8
    learning_rate = 2e-5

    # Optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # The model already applies a sigmoid in forward(), so use BCELoss on the returned
    # probabilities (BCEWithLogitsLoss would apply a sigmoid a second time).
    criterion = nn.BCELoss()

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    ################################################################################################
    # Training
    ################################################################################################

    # Initialize the history
    history = ml_history.History()

    # Model to device
    model.to(device)

    print("Starting training...")
    start_training_time = time.time()

    # Training loop
    model.train()
    for epoch in range(epochs):
        epoch_start_time = time.time()
        history.batch_reset()

        for batch in train_loader:
            optimizer.zero_grad()

            # prepare batch (BCELoss expects float targets)
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].unsqueeze(1).float().to(device)

            # forward pass
            outputs = model(input_ids)
            loss = criterion(outputs, labels)

            # backward pass
            loss.backward()
            optimizer.step()

            # calculate accuracy train
            preds = outputs.round()
            train_acc = accuracy_score(labels.cpu().detach().numpy(), preds.cpu().detach().numpy())

            # update batch history
            history.batch_update_train(loss.item(), train_acc)

        # calculate accuracy val (once per epoch, over the whole validation set)
        model.eval()
        with torch.no_grad():
            for val_batch in val_loader:
                val_input_ids = val_batch['input_ids'].to(device)
                val_labels_batch = val_batch['labels'].unsqueeze(1).to(device)
                val_outputs = model(val_input_ids)
                val_acc = accuracy_score(val_labels_batch.cpu().numpy(), val_outputs.round().cpu().numpy())
                history.batch_update_val(val_acc)
        model.train()

        # update epoch history
        history.update()

        epoch_end_time = time.time()
        print(f"Epoch {epoch + 1}/{epochs}, "
              f"Time: {epoch_end_time - epoch_start_time:.2f} sec, "
              f"Loss: {history.history['loss'][-1]:.4f}, "
              f"Train Acc: {history.history['train_acc'][-1]:.4f}, "
              f"Val Acc: {history.history['val_acc'][-1]:.4f}")

    end_training_time = time.time()
    print(f"Training finished in {end_training_time - start_training_time:.2f} seconds")

    ################################################################################################
    # Evaluation
    ################################################################################################

    print("Starting evaluation...")
    model.eval()

    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].unsqueeze(1).to(device)

            outputs = model(input_ids)
            preds = outputs.round()

            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(true_labels, predictions)
    print(f"Accuracy: {accuracy}")

    ################################################################################################
    # Save model and hyperparameters
    ################################################################################################

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    ml_helper.save_model_and_hyperparameters(model, 'improved_lstm', accuracy, timestamp,
                                             max_len=max_len,
                                             vocab_size=vocab_size,
                                             embed_dim=embed_dim,
                                             hidden_dim=hidden_dim,
                                             num_layers=num_layers,
                                             dropout=dropout,
                                             epochs=epochs,
                                             batch_size=batch_size,
                                             learning_rate=learning_rate)

    # Save history
    history_path = f'models/improved_lstm_history_{timestamp}.json'
    with open(history_path, 'w') as f:
        json.dump(history.get_history(), f)
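
    # ------------------------------------------------------------------------------------------------
    # Example (sketch, commented out): reloading the saved history JSON for a quick look at the
    # final metrics. This assumes get_history() returns the same dict accessed above, i.e. with
    # 'loss', 'train_acc' and 'val_acc' keys; adjust if ml_history stores a different layout.
    #
    #   with open(history_path, 'r') as f:
    #       loaded_history = json.load(f)
    #   print(f"Final val acc: {loaded_history['val_acc'][-1]:.4f}")
    # ------------------------------------------------------------------------------------------------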