diff --git a/lstm_1b.py b/lstm_1b.py
new file mode 100644
index 0000000..06404ca
--- /dev/null
+++ b/lstm_1b.py
@@ -0,0 +1,169 @@
+import time
+import json
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from sklearn.metrics import accuracy_score
+
+import ml_helper
+import ml_history
+
+class ImprovedLSTMBinaryClassifier(nn.Module):
+    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout=0.1, bidirectional=False):
+        super(ImprovedLSTMBinaryClassifier, self).__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True,
+                            dropout=dropout, bidirectional=bidirectional)
+        self.layer_norm = nn.LayerNorm(hidden_dim * 2 if bidirectional else hidden_dim)
+        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, 1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input_ids):
+        input_ids = input_ids.long()
+        embedded = self.embedding(input_ids)
+        lstm_output, _ = self.lstm(embedded)
+        # Use the hidden state at the last timestep as the sequence representation
+        pooled_output = lstm_output[:, -1, :]
+        pooled_output = self.layer_norm(pooled_output)
+        logits = self.fc(pooled_output)
+        return self.sigmoid(logits)
+
+
+if __name__ == "__main__":
+    # Load the data
+    data_path = 'data/idx_based_padded'
+    train_dataset = torch.load(data_path + '/train.pt')
+    test_dataset = torch.load(data_path + '/test.pt')
+    val_dataset = torch.load(data_path + '/val.pt')
+
+    # +2 for padding and unk tokens
+    vocab_size = train_dataset.vocab_size + 2
+    embed_dim = 100  # train_dataset.emb_dim
+
+    # NOTE: Info comes from the data exploration notebook: 280 is the max length;
+    # 139 covers 80% and 192 covers 95% of the data
+    max_len = 280
+
+    device = ml_helper.get_device(verbose=True)
+
+    # Model hyperparameters
+    hidden_dim = 256
+    num_layers = 2
+    dropout = 0.3
+    bidirectional = True  # Enable bidirectional LSTM
+
+    model = ImprovedLSTMBinaryClassifier(vocab_size, embed_dim, hidden_dim, num_layers, dropout, bidirectional)
+
+    # Training parameters
+    epochs = 3
+    batch_size = 8
+    learning_rate = 2e-5
+
+    # Optimizer and loss function
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+    # The model already applies a sigmoid in forward(), so use BCELoss here;
+    # BCEWithLogitsLoss would apply the sigmoid a second time
+    criterion = nn.BCELoss()
+
+    # Data loaders
+    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+
+    ################################################################################################
+    # Training
+    ################################################################################################
+    # Initialize the history
+    history = ml_history.History()
+
+    # Model to device
+    model.to(device)
+
+    print("Starting training...")
+    start_training_time = time.time()
+
+    # Training loop
+    model.train()
+    for epoch in range(epochs):
+        epoch_start_time = time.time()
+        history.batch_reset()
+
+        for batch in train_loader:
+            optimizer.zero_grad()
+            # prepare batch
+            input_ids = batch['input_ids'].to(device)
+            labels = batch['labels'].float().unsqueeze(1).to(device)
+            # forward pass
+            outputs = model(input_ids)
+            loss = criterion(outputs, labels)
+            # backward pass
+            loss.backward()
+            optimizer.step()
+            # calculate train accuracy
+            preds = outputs.round()
+            train_acc = accuracy_score(labels.cpu().detach().numpy(),
+                                       preds.cpu().detach().numpy())
+            # update batch history
+            history.batch_update_train(loss.item(), train_acc)
+
+        # calculate val accuracy
+        model.eval()
+        with torch.no_grad():
+            for val_batch in val_loader:
+                val_input_ids = val_batch['input_ids'].to(device)
+                val_labels_batch = val_batch['labels'].float().unsqueeze(1).to(device)
+                val_outputs = model(val_input_ids)
+                val_acc = accuracy_score(val_labels_batch.cpu().numpy(),
+                                         val_outputs.round().cpu().numpy())
+                history.batch_update_val(val_acc)
+        model.train()
+
+        # update epoch history
+        history.update()
+
+        epoch_end_time = time.time()
+
+        print(f"Epoch {epoch + 1}/{epochs}, Time: {epoch_end_time - epoch_start_time:.2f} sec, "
+              f"Loss: {history.history['loss'][-1]:.4f}, "
+              f"Train Acc: {history.history['train_acc'][-1]:.4f}, "
+              f"Val Acc: {history.history['val_acc'][-1]:.4f}")
+
+    end_training_time = time.time()
+    print(f"Training finished in {end_training_time - start_training_time:.2f} seconds")
+
+    ################################################################################################
+    # Evaluation
+    ################################################################################################
+    print("Starting evaluation...")
+
+    model.eval()
+    predictions, true_labels = [], []
+    with torch.no_grad():
+        for batch in test_loader:
+            input_ids = batch['input_ids'].to(device)
+            labels = batch['labels'].unsqueeze(1).to(device)
+
+            outputs = model(input_ids)
+            preds = outputs.round()
+            predictions.extend(preds.cpu().numpy())
+            true_labels.extend(labels.cpu().numpy())
+
+    accuracy = accuracy_score(true_labels, predictions)
+    print(f"Accuracy: {accuracy}")
+
+    ################################################################################################
+    # Save model and hyperparameters
+    ################################################################################################
+    timestamp = time.strftime("%Y%m%d-%H%M%S")
+
+    ml_helper.save_model_and_hyperparameters(model, 'improved_lstm', accuracy, timestamp,
+                                             max_len=max_len,
+                                             vocab_size=vocab_size,
+                                             embed_dim=embed_dim,
+                                             hidden_dim=hidden_dim,
+                                             num_layers=num_layers,
+                                             dropout=dropout,
+                                             epochs=epochs,
+                                             batch_size=batch_size,
+                                             learning_rate=learning_rate)
+
+    # Save history
+    history_path = f'models/improved_lstm_history_{timestamp}.json'
+    with open(history_path, 'w') as f:
+        json.dump(history.get_history(), f)