# ANLP_WS24_CA2/lstm_1b.py

import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import ml_helper
import ml_history


class ImprovedLSTMBinaryClassifier(nn.Module):
    """(Bi)LSTM binary classifier: Embedding -> LSTM -> LayerNorm -> Linear -> Sigmoid."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout=0.1, bidirectional=False):
        super(ImprovedLSTMBinaryClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers,
                            batch_first=True, dropout=dropout, bidirectional=bidirectional)
        # A bidirectional LSTM concatenates forward and backward states, doubling the feature size
        lstm_out_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.layer_norm = nn.LayerNorm(lstm_out_dim)
        self.fc = nn.Linear(lstm_out_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids):
        input_ids = input_ids.long()
        embedded = self.embedding(input_ids)
        lstm_output, _ = self.lstm(embedded)
        # Pool by taking the output at the last time step (sequences are padded to a fixed length)
        pooled_output = lstm_output[:, -1, :]
        pooled_output = self.layer_norm(pooled_output)
        logits = self.fc(pooled_output)
        # Returns probabilities in [0, 1]; pair with nn.BCELoss, not BCEWithLogitsLoss
        return self.sigmoid(logits)
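
# Shape flow (assuming fixed-length padded inputs, as produced by the idx_based_padded data):
#   input_ids: (batch, max_len) -> embedding: (batch, max_len, embed_dim)
#   -> LSTM output: (batch, max_len, hidden_dim * num_directions)
#   -> last-step pooling + LayerNorm -> linear + sigmoid: (batch, 1) probabilities
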
if __name__ == "__main__":
    # Load the data
    data_path = 'data/idx_based_padded'
    train_dataset = torch.load(data_path + '/train.pt')
    test_dataset = torch.load(data_path + '/test.pt')
    val_dataset = torch.load(data_path + '/val.pt')

    # +2 for the padding and unk tokens
    vocab_size = train_dataset.vocab_size + 2
    embed_dim = 100  # train_dataset.emb_dim
    # NOTE: From the data exploration notebook: 280 is the max length;
    # 139 covers 80% and 192 covers 95% of the examples
    max_len = 280

    device = ml_helper.get_device(verbose=True)

    # Model hyperparameters
    hidden_dim = 256
    num_layers = 2
    dropout = 0.3
    bidirectional = True  # Enable bidirectional LSTM
    model = ImprovedLSTMBinaryClassifier(vocab_size, embed_dim, hidden_dim, num_layers, dropout, bidirectional)

    # Training parameters
    epochs = 3
    batch_size = 8
    learning_rate = 2e-5

    # Optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # The model already applies a sigmoid, so use BCELoss on probabilities
    # (BCEWithLogitsLoss would apply the sigmoid a second time)
    criterion = nn.BCELoss()

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
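
    # Each dataset is assumed to yield dict-style batches with an 'input_ids' tensor of shape
    # (batch_size, max_len) and a 'labels' tensor of shape (batch_size,), matching how the
    # batches are unpacked in the training and evaluation loops below.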

    ################################################################################################
    # Training
    ################################################################################################
    # Initialize the history
    history = ml_history.History()
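    # ml_history.History is assumed to accumulate per-batch statistics via batch_update_train /
    # batch_update_val, reset them with batch_reset(), and fold them into per-epoch entries on
    # update() under the 'loss', 'train_acc' and 'val_acc' keys read in the epoch summary below.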
    # Model to device
    model.to(device)

    print("Starting training...")
    start_training_time = time.time()

    # Training loop
    model.train()
    for epoch in range(epochs):
        epoch_start_time = time.time()
        history.batch_reset()
        for batch in train_loader:
            optimizer.zero_grad()
            # prepare batch (BCELoss expects float targets of shape (batch_size, 1))
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].float().unsqueeze(1).to(device)
            # forward pass
            outputs = model(input_ids)
            loss = criterion(outputs, labels)
            # backward pass
            loss.backward()
            optimizer.step()
            # calculate accuracy train
            preds = outputs.round()
            train_acc = accuracy_score(labels.detach().cpu().numpy(),
                                       preds.detach().cpu().numpy())
            # update batch history
            history.batch_update_train(loss.item(), train_acc)
        # calculate accuracy val (once per epoch, over the whole validation set)
        model.eval()
        with torch.no_grad():
            for val_batch in val_loader:
                val_input_ids = val_batch['input_ids'].to(device)
                val_labels_batch = val_batch['labels'].unsqueeze(1).to(device)
                val_outputs = model(val_input_ids)
                val_acc = accuracy_score(val_labels_batch.cpu().numpy(),
                                         val_outputs.round().cpu().numpy())
                history.batch_update_val(val_acc)
        model.train()
        # update epoch history
        history.update()
        epoch_end_time = time.time()
        print(f"Epoch {epoch + 1}/{epochs}, Time: {epoch_end_time - epoch_start_time:.2f} sec, "
              f"Loss: {history.history['loss'][-1]:.4f}, "
              f"Train Acc: {history.history['train_acc'][-1]:.4f}, "
              f"Val Acc: {history.history['val_acc'][-1]:.4f}")

    end_training_time = time.time()
    print(f"Training finished in {end_training_time - start_training_time:.2f} seconds")

    ################################################################################################
    # Evaluation
    ################################################################################################
    print("Starting evaluation...")
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].unsqueeze(1).to(device)
            outputs = model(input_ids)
            preds = outputs.round()
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
    accuracy = accuracy_score(true_labels, predictions)
    print(f"Accuracy: {accuracy}")

    ################################################################################################
    # Save model and hyperparameters
    ################################################################################################
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    ml_helper.save_model_and_hyperparameters(model, 'improved_lstm', accuracy, timestamp,
                                             max_len=max_len,
                                             vocab_size=vocab_size,
                                             embed_dim=embed_dim,
                                             hidden_dim=hidden_dim,
                                             num_layers=num_layers,
                                             dropout=dropout,
                                             epochs=epochs,
                                             batch_size=batch_size,
                                             learning_rate=learning_rate)

    # Save history
    history_path = f'models/improved_lstm_history_{timestamp}.json'
    with open(history_path, 'w') as f:
        json.dump(history.get_history(), f)
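
    # To reuse the trained model later, rebuild ImprovedLSTMBinaryClassifier with the saved
    # hyperparameters and reload its weights (e.g. via model.load_state_dict() if the helper
    # stores a state_dict); the exact file layout depends on ml_helper.save_model_and_hyperparameters.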