""" This file contains the transformer model. """ # TODO refactor the code # TODO create ml helper script # TODO create ml evaluation script # TODO track overfitting better # TODO validate model in training (accuracy, loss, etc) # TODO set length to a constant value which is the max length of the sentences or nearly # TODO user gloVe embeddings #TODO: add attention mask # TODO: add positional encoding #TODO: add dropout (if needed) import time import json import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from transformers import AdamW from sklearn.metrics import accuracy_score import ml_helper import ml_history class TransformerBinaryClassifier(nn.Module): def __init__(self, vocab_size, embed_dim, num_heads, num_layers, hidden_dim, dropout=0.1): super(TransformerBinaryClassifier, self).__init__() self.embedding = nn.Embedding(vocab_size, embed_dim) self.transformer = nn.Transformer(embed_dim, num_heads, num_layers, num_layers, hidden_dim, dropout) self.fc = nn.Linear(embed_dim, 1) self.sigmoid = nn.Sigmoid() def forward(self, input_ids): input_ids = input_ids.long() embedded = self.embedding(input_ids) transformer_output = self.transformer(embedded, embedded) pooled_output = transformer_output.mean(dim=1) logits = self.fc(pooled_output) return self.sigmoid(logits) if __name__ == "__main__": # Load the data data_path = 'data/idx_based_padded' train_dataset = torch.load(data_path + '/train.pt') test_dataset = torch.load(data_path + '/test.pt') val_dataset = torch.load(data_path + '/val.pt') # +2 for padding and unk tokens vocab_size = train_dataset.vocab_size + 2 embed_dim = 100 #train_dataset.emb_dim # NOTE: Info comes from data explore notebook: 280 is max length, # 139 contains 80% and 192 contains 95% of the data max_len = 280 device = ml_helper.get_device(verbose=True) # Model hyperparameters num_heads = 2 num_layers = 2 hidden_dim = 256 model = TransformerBinaryClassifier(vocab_size, embed_dim, num_heads, num_layers, hidden_dim) # Training parameters epochs = 3 #3 batch_size = 8 learning_rate = 2e-5 # Optimizer and loss function optimizer = AdamW(model.parameters(), lr=learning_rate) criterion = nn.BCEWithLogitsLoss() # Data loaders train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) ################################################################################################ # Training ################################################################################################ # Initialize the history history = ml_history.History() # Model to device model.to(device) print("Starting training...") start_training_time = time.time() # Training loop model.train() for epoch in range(epochs): # init batch tracking epoch_start_time = time.time() history.batch_reset() for batch in train_loader: optimizer.zero_grad() # prepare batch input_ids = batch['input_ids'].to(device) labels = batch['labels'].unsqueeze(1).to(device) # forward pass outputs = model(input_ids) loss = criterion(outputs, labels) # backward pass loss.backward() optimizer.step() # calculate accuracy train preds = outputs.round() train_acc = accuracy_score(labels.cpu().detach().numpy(), preds.cpu().detach().numpy()) # update batch history history.batch_update_train(loss.item(), train_acc) # calculate accuracy val model.eval() with torch.no_grad(): for val_batch in val_loader: 
                val_input_ids = val_batch['input_ids'].to(device)
                val_labels_batch = val_batch['labels'].unsqueeze(1).to(device)
                val_outputs = model(val_input_ids)
                val_acc = accuracy_score(val_labels_batch.cpu().numpy(), val_outputs.round().cpu().numpy())
                history.batch_update_val(val_acc)
        model.train()

        # update epoch history
        history.update()
        epoch_end_time = time.time()
        print(f"Epoch {epoch + 1}/{epochs}, "
              f"Time: {epoch_end_time - epoch_start_time:.2f} sec, "
              f"Loss: {history.history['loss'][-1]:.4f}, "
              f"Train Acc: {history.history['train_acc'][-1]:.4f}, "
              f"Val Acc: {history.history['val_acc'][-1]:.4f}")

    end_training_time = time.time()
    print(f"Training finished in {end_training_time - start_training_time:.2f} seconds")

    ################################################################################################
    # Evaluation
    ################################################################################################
    print("Starting evaluation...")
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].unsqueeze(1).to(device)
            outputs = model(input_ids)
            preds = outputs.round()
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(true_labels, predictions)
    print(f"Accuracy: {accuracy}")

    ################################################################################################
    # Save model and hyperparameters
    ################################################################################################
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    ml_helper.save_model_and_hyperparameters(model, 'transformer', accuracy, timestamp,
                                             max_len=max_len, vocab_size=vocab_size,
                                             embed_dim=embed_dim, num_heads=num_heads,
                                             num_layers=num_layers, hidden_dim=hidden_dim,
                                             epochs=epochs, batch_size=batch_size,
                                             learning_rate=learning_rate)

    # save history
    history_path = f'models/transformer_history_{timestamp}.json'
    with open(history_path, 'w') as f:
        json.dump(history.get_history(), f)
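

# ------------------------------------------------------------------------------------------------
# Sketch: positional encoding (see the "add positional encoding" TODO at the top of this file)
# ------------------------------------------------------------------------------------------------
# A minimal sinusoidal positional-encoding module, included only as an illustrative sketch and NOT
# wired into TransformerBinaryClassifier above. The class name and the way it would be used
# (e.g. `self.pos_encoding = PositionalEncoding(embed_dim, max_len)` in __init__ and
# `embedded = self.pos_encoding(self.embedding(input_ids))` in forward) are assumptions.
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=280, dropout=0.1):
        # assumes an even embed_dim (e.g. 100 as used above)
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-np.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, embed_dim)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim), batch-first to match the transformer above
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)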