# ANLP_WS24_CA2/ml_train.py
#
# 113 lines
# 4.4 KiB
# Python

from tqdm import tqdm
import torch
import numpy as np
def train_epoch(model, train_loader, criterion, optimizer, device, history, epoch, total_epochs, bert_freeze=False, is_bert=False):
    """Run one training epoch and record per-batch metrics in ``history``.

    Args:
        model: Model to train; switched to train mode first.
        train_loader: Iterable of batches. With ``is_bert`` each batch is a
            mapping with ``'input_ids'``, ``'attention_mask'`` and
            ``'labels'``; otherwise it is an ``(inputs, targets)`` pair.
        criterion: Loss function called as ``criterion(predictions, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        history: Metrics tracker; ``batch_update_train`` is called per batch,
            then ``update`` and ``batch_reset`` once after the epoch.
        epoch: Zero-based epoch index (shown 1-based in the progress bar).
        total_epochs: Total number of epochs, for the progress bar label.
        bert_freeze: If True and the model defines ``freeze_bert_params``,
            that method is called before training starts.
        is_bert: Selects the batch layout / model call signature above.
    """
    model.train()
    if bert_freeze and hasattr(model, 'freeze_bert_params'):
        model.freeze_bert_params()
    with tqdm(train_loader, desc=f"├ Epoch {epoch + 1}/{total_epochs}") as pbar:
        for batch in pbar:
            optimizer.zero_grad()
            if is_bert:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device).float()
                predictions = model(input_ids, attention_mask=attention_mask).float()
            else:
                X_batch, y_batch = batch
                X_batch, y_batch = X_batch.to(device), y_batch.to(device).float()
                predictions = model(X_batch).float()
                labels = y_batch
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            # Read the scalar once: every .item() call forces a device->host sync.
            batch_loss = loss.item()
            preds = predictions.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
            history.batch_update_train(batch_loss, preds, labels)
            # Update progress bar
            pbar.set_postfix({"Train Loss": batch_loss})
    history.update()
    history.batch_reset()
def validate_epoch(model, val_loader, epoch, criterion, device, history, is_bert=False):
    """Evaluate ``model`` on ``val_loader`` and record the epoch in ``history``.

    Args:
        model: Model to evaluate; switched to eval mode first.
        val_loader: Iterable of batches (it does not need ``__len__``). With
            ``is_bert`` each batch is a mapping with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``; otherwise it is an
            ``(inputs, targets)`` pair.
        epoch: Epoch index, forwarded to ``history.batch_update_val``.
        criterion: Loss function called as ``criterion(predictions, labels)``.
        device: Device the batch tensors are moved to.
        history: Metrics tracker providing ``calculate_rmse``,
            ``batch_update_val``, ``update`` and ``batch_reset``.
        is_bert: Selects the batch layout / model call signature above.

    Returns:
        The value of ``history.calculate_rmse`` on all predictions and labels.

    Raises:
        ValueError: If ``val_loader`` yields no batches (mean loss undefined).
    """
    model.eval()
    val_loss = 0.0
    num_batches = 0
    val_preds, val_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            if is_bert:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device).float()
                predictions = model(input_ids, attention_mask=attention_mask).float()
            else:
                X_batch, y_batch = batch
                X_batch, labels = X_batch.to(device), y_batch.to(device).float()
                predictions = model(X_batch).float()
            loss = criterion(predictions, labels)
            val_loss += loss.item()
            # Count batches ourselves so iterators without __len__ also work.
            num_batches += 1
            val_preds.extend(predictions.cpu().detach().numpy())
            val_labels.extend(labels.cpu().detach().numpy())
    if num_batches == 0:
        raise ValueError("Validation loader yielded no batches.")
    # Convert once and reuse for both history calls.
    preds_array = np.array(val_preds)
    labels_array = np.array(val_labels)
    val_rmse = history.calculate_rmse(preds_array, labels_array)
    history.batch_update_val(val_loss / num_batches, preds_array, labels_array, epoch)
    history.update()
    history.batch_reset()
    return val_rmse
def test_loop(model, test_loader, device, is_bert=False):
model.eval()
test_preds, test_labels = [], []
with torch.no_grad():
for batch in test_loader:
if is_bert:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device).float()
predictions = model(input_ids, attention_mask=attention_mask).float()
else:
X_batch, y_batch = batch
X_batch, y_batch = X_batch.to(device), y_batch.to(device).float()
labels = y_batch
predictions = model(X_batch).float()
test_preds.extend(predictions.cpu().detach().numpy())
test_labels.extend(labels.cpu().detach().numpy())
return test_labels, test_preds
def ensemble_predict(models, test_loader, device, is_bert=False):
    """Collect the predictions of every model in ``models`` over ``test_loader``.

    Args:
        models: Iterable of models (a one-shot iterator is accepted). Each is
            switched to eval mode and called on every batch.
        test_loader: Iterable of batches. With ``is_bert`` each batch is a
            mapping with ``'input_ids'`` and ``'attention_mask'``; otherwise
            it is an ``(inputs, targets)`` pair whose targets are ignored.
        device: Device the model inputs are moved to.
        is_bert: Selects the batch layout / model call signature above.

    Returns:
        NumPy array whose first axis indexes the models (in ``models`` order)
        and whose second axis indexes samples; per-batch results are
        concatenated along axis 1.

    Raises:
        ValueError: If ``models`` is empty or ``test_loader`` yields no batches.
    """
    # Materialize so a generator is not exhausted by the eval() loop below.
    models = list(models)
    if not models:
        raise ValueError("No predictions were made in ensemble prediction.")
    for model in models:
        model.eval()
    test_preds = []
    with torch.no_grad():
        for batch in test_loader:
            if is_bert:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                predictions = [model(input_ids, attention_mask=attention_mask).float().cpu().detach().numpy() for model in models]
            else:
                # Targets are not needed for prediction, so they stay on the host.
                X_batch, _ = batch
                X_batch = X_batch.to(device)
                predictions = [model(X_batch).float().cpu().detach().numpy() for model in models]
            test_preds.append(predictions)
    # An empty loader previously hit IndexError on test_preds[0].
    if not test_preds:
        raise ValueError("No predictions were made in ensemble prediction.")
    return np.concatenate(test_preds, axis=1)