199 lines
6.8 KiB
Python
199 lines
6.8 KiB
Python
"""
|
|
This file contains the transformer model.
|
|
"""
|
|
|
|
|
|
# TODO refactor the code
|
|
# TODO create ml helper script
|
|
# TODO create ml evaluation script
|
|
|
|
# TODO track overfitting better
|
|
# TODO validate model in training (accuracy, loss, etc.)
|
|
|
|
# TODO set the sequence length to a constant value equal to (or close to) the maximum sentence length
|
|
|
|
|
|
# TODO use GloVe embeddings
|
|
|
|
#TODO: add attention mask
|
|
# TODO: add positional encoding
|
|
#TODO: add dropout (if needed)
|
|
|
|
import time
|
|
import json
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader
|
|
from transformers import AdamW
|
|
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
import ml_helper
|
|
import ml_history
|
|
|
|
class TransformerBinaryClassifier(nn.Module):
    """Binary sequence classifier built on ``nn.Transformer``.

    Token ids are embedded and passed through an encoder-decoder transformer.
    The same embeddings are used as both source and target. The result is
    mean-pooled over the sequence axis, projected to one value, and squashed
    by a sigmoid.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding / model dimension (``d_model``); must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        num_layers: Number of layers, used for both the encoder and decoder.
        hidden_dim: Feed-forward dimension inside each transformer layer.
        dropout: Dropout probability inside the transformer layers.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, hidden_dim, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # The embeddings fed to the transformer are (batch, seq, embed), so
        # batch_first=True is required. With the default (False), the batch
        # axis would be treated as the sequence axis and samples in a batch
        # would attend to each other.
        self.transformer = nn.Transformer(
            d_model=embed_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(embed_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids):
        """Return per-sample probabilities of shape ``(batch, 1)``.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``. They are
                cast to ``long`` before the embedding lookup.
        """
        embedded = self.embedding(input_ids.long())  # (batch, seq, embed)
        transformer_output = self.transformer(embedded, embedded)
        # Mean-pool over the sequence axis. NOTE(review): no padding mask is
        # applied, so padded positions (if any) contribute to the mean.
        pooled_output = transformer_output.mean(dim=1)  # (batch, embed)
        logits = self.fc(pooled_output)  # (batch, 1)
        return self.sigmoid(logits)
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    # Load the data
    data_path = 'data/idx_based_padded'

    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # these dataset files from a trusted source.
    train_dataset = torch.load(data_path + '/train.pt')
    test_dataset = torch.load(data_path + '/test.pt')
    val_dataset = torch.load(data_path + '/val.pt')

    # +2 for padding and unk tokens
    vocab_size = train_dataset.vocab_size + 2
    embed_dim = 100  # train_dataset.emb_dim

    # NOTE: Info comes from data explore notebook: 280 is max length,
    # 139 contains 80% and 192 contains 95% of the data
    max_len = 280

    device = ml_helper.get_device(verbose=True)

    # Model hyperparameters
    num_heads = 2
    num_layers = 2
    hidden_dim = 256

    model = TransformerBinaryClassifier(vocab_size, embed_dim, num_heads, num_layers, hidden_dim)
    # Move the model to its device before constructing the optimizer, as the
    # torch.optim documentation recommends.
    model.to(device)

    # Training parameters
    epochs = 3
    batch_size = 8
    learning_rate = 2e-5

    # Optimizer: torch.optim.AdamW replaces the deprecated transformers.AdamW.
    # eps=1e-6 and weight_decay=0.0 reproduce that implementation's defaults.
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)
    # Loss: the model's forward() already ends with a sigmoid, so the loss
    # must consume probabilities. BCEWithLogitsLoss would apply a second
    # sigmoid on top of it.
    criterion = nn.BCELoss()

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    ################################################################################################
    # Training
    ################################################################################################

    # Initialize the history
    history = ml_history.History()

    print("Starting training...")
    start_training_time = time.time()

    # Training loop
    model.train()
    for epoch in range(epochs):
        # init batch tracking
        epoch_start_time = time.time()
        history.batch_reset()

        for batch in train_loader:
            optimizer.zero_grad()

            # prepare batch; BCELoss needs float targets shaped like the
            # (batch, 1) model output
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].float().unsqueeze(1).to(device)

            # forward pass
            outputs = model(input_ids)
            loss = criterion(outputs, labels)

            # backward pass
            loss.backward()
            optimizer.step()

            # calculate train accuracy by rounding probabilities to 0/1
            preds = outputs.detach().round()
            train_acc = accuracy_score(labels.cpu().numpy().ravel(),
                                       preds.cpu().numpy().ravel())

            # update batch history
            history.batch_update_train(loss.item(), train_acc)

        # calculate val accuracy, recorded per validation batch
        model.eval()
        with torch.no_grad():
            for val_batch in val_loader:
                val_input_ids = val_batch['input_ids'].to(device)
                val_labels_batch = val_batch['labels'].float().unsqueeze(1).to(device)
                val_outputs = model(val_input_ids)
                # accuracy_score takes (y_true, y_pred)
                val_acc = accuracy_score(val_labels_batch.cpu().numpy().ravel(),
                                         val_outputs.round().cpu().numpy().ravel())
                history.batch_update_val(val_acc)
        model.train()

        # update epoch history
        history.update()

        epoch_end_time = time.time()
        print(f"Epoch {epoch + 1}/{epochs}, Time: {epoch_end_time - epoch_start_time:.2f} sec, Loss: {history.history['loss'][-1]:.4f}, Train Acc: {history.history['train_acc'][-1]:.4f}, Val Acc: {history.history['val_acc'][-1]:.4f}")

    end_training_time = time.time()
    print(f"Training finished in {end_training_time - start_training_time:.2f} seconds")

    ################################################################################################
    # Evaluation
    ################################################################################################

    print("Starting evaluation...")

    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].float().unsqueeze(1).to(device)

            outputs = model(input_ids)
            preds = outputs.round()
            # flatten the (batch, 1) columns so the lists hold scalars
            predictions.extend(preds.cpu().numpy().ravel())
            true_labels.extend(labels.cpu().numpy().ravel())

    accuracy = accuracy_score(true_labels, predictions)
    print(f"Accuracy: {accuracy}")

    ################################################################################################
    # Save model and hyperparameters
    ################################################################################################

    timestamp = time.strftime("%Y%m%d-%H%M%S")

    ml_helper.save_model_and_hyperparameters(model, 'transformer', accuracy, timestamp,
                                             max_len=max_len,
                                             vocab_size=vocab_size,
                                             embed_dim=embed_dim,
                                             num_heads=num_heads,
                                             num_layers=num_layers,
                                             hidden_dim=hidden_dim,
                                             epochs=epochs,
                                             batch_size=batch_size,
                                             learning_rate=learning_rate)

    # save history
    history_path = f'models/transformer_history_{timestamp}.json'
    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(history.get_history(), f)