LSTM trained

main
arman 2025-01-29 12:38:37 +01:00
parent 0ec8e4dcc8
commit e2066cb63d
1 changed file with 169 additions and 0 deletions

lstm_1b.py 100644

@@ -0,0 +1,169 @@
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import ml_helper
import ml_history
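
# ml_helper and ml_history are local project modules: ml_helper provides device selection
# and model/hyperparameter saving, ml_history tracks per-batch and per-epoch metrics.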


class ImprovedLSTMBinaryClassifier(nn.Module):
    """LSTM-based binary classifier with layer normalization and optional bidirectionality."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout=0.1, bidirectional=False):
        super(ImprovedLSTMBinaryClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers,
                            batch_first=True, dropout=dropout, bidirectional=bidirectional)
        # A bidirectional LSTM concatenates forward and backward states, doubling the feature size.
        lstm_out_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.layer_norm = nn.LayerNorm(lstm_out_dim)
        self.fc = nn.Linear(lstm_out_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids):
        input_ids = input_ids.long()
        embedded = self.embedding(input_ids)
        lstm_output, _ = self.lstm(embedded)
        # Use the hidden state of the last time step as the sequence representation.
        pooled_output = lstm_output[:, -1, :]
        pooled_output = self.layer_norm(pooled_output)
        logits = self.fc(pooled_output)
        # Output is a probability in [0, 1]; pair it with nn.BCELoss (not BCEWithLogitsLoss).
        return self.sigmoid(logits)
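

# Shape walkthrough for a padded batch of token ids (B = batch size, T = sequence length):
#   input_ids (B, T) -> embedding (B, T, embed_dim)
#   -> LSTM output (B, T, hidden_dim * 2 if bidirectional else hidden_dim)
#   -> last time step (B, hidden_dim or 2 * hidden_dim) -> fc + sigmoid (B, 1) probabilities
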
if __name__ == "__main__":
    # Load the data
    data_path = 'data/idx_based_padded'
    train_dataset = torch.load(data_path + '/train.pt')
    test_dataset = torch.load(data_path + '/test.pt')
    val_dataset = torch.load(data_path + '/val.pt')

    # +2 for the padding and unknown tokens
    vocab_size = train_dataset.vocab_size + 2
    embed_dim = 100  # train_dataset.emb_dim

    # NOTE: From the data exploration notebook: 280 is the max length,
    # 139 covers 80% and 192 covers 95% of the data.
    max_len = 280

    device = ml_helper.get_device(verbose=True)

    # Model hyperparameters
    hidden_dim = 256
    num_layers = 2
    dropout = 0.3
    bidirectional = True  # Enable bidirectional LSTM

    model = ImprovedLSTMBinaryClassifier(vocab_size, embed_dim, hidden_dim, num_layers, dropout, bidirectional)
    # Training parameters
    epochs = 3
    batch_size = 8
    learning_rate = 2e-5

    # Optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # The model already applies a sigmoid in forward(), so use BCELoss on probabilities
    # (BCEWithLogitsLoss would apply a second sigmoid internally).
    criterion = nn.BCELoss()

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
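
    # Each batch is assumed to be a dict with 'input_ids' (token ids padded to max_len)
    # and 'labels' (0/1 targets), as produced by the preprocessing that built the .pt datasets above.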
    ################################################################################################
    # Training
    ################################################################################################
    # Initialize the history
    history = ml_history.History()

    # Model to device
    model.to(device)

    print("Starting training...")
    start_training_time = time.time()

    # Training loop
    model.train()
    for epoch in range(epochs):
        epoch_start_time = time.time()
        history.batch_reset()

        for batch in train_loader:
            optimizer.zero_grad()

            # Prepare batch (BCELoss expects float targets)
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].unsqueeze(1).float().to(device)

            # Forward pass
            outputs = model(input_ids)
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Training accuracy for this batch
            preds = outputs.round()
            train_acc = accuracy_score(labels.cpu().detach().numpy(),
                                       preds.cpu().detach().numpy())

            # Update batch history
            history.batch_update_train(loss.item(), train_acc)

        # Validation accuracy for this epoch
        model.eval()
        with torch.no_grad():
            for val_batch in val_loader:
                val_input_ids = val_batch['input_ids'].to(device)
                val_labels_batch = val_batch['labels'].unsqueeze(1).float().to(device)
                val_outputs = model(val_input_ids)
                val_acc = accuracy_score(val_labels_batch.cpu().numpy(),
                                         val_outputs.round().cpu().numpy())
                history.batch_update_val(val_acc)
        model.train()

        # Update epoch history
        history.update()

        epoch_end_time = time.time()
        print(f"Epoch {epoch + 1}/{epochs}, Time: {epoch_end_time - epoch_start_time:.2f} sec, "
              f"Loss: {history.history['loss'][-1]:.4f}, "
              f"Train Acc: {history.history['train_acc'][-1]:.4f}, "
              f"Val Acc: {history.history['val_acc'][-1]:.4f}")

    end_training_time = time.time()
    print(f"Training finished in {end_training_time - start_training_time:.2f} seconds")
    ################################################################################################
    # Evaluation
    ################################################################################################
    print("Starting evaluation...")

    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].unsqueeze(1).to(device)

            outputs = model(input_ids)
            preds = outputs.round()

            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(true_labels, predictions)
    print(f"Accuracy: {accuracy}")
    ################################################################################################
    # Save model and hyperparameters
    ################################################################################################
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    ml_helper.save_model_and_hyperparameters(model, 'improved_lstm', accuracy, timestamp,
                                             max_len=max_len,
                                             vocab_size=vocab_size,
                                             embed_dim=embed_dim,
                                             hidden_dim=hidden_dim,
                                             num_layers=num_layers,
                                             dropout=dropout,
                                             epochs=epochs,
                                             batch_size=batch_size,
                                             learning_rate=learning_rate)

    # Save history
    history_path = f'models/improved_lstm_history_{timestamp}.json'
    with open(history_path, 'w') as f:
        json.dump(history.get_history(), f)