113 lines
4.4 KiB
Python
113 lines
4.4 KiB
Python
from tqdm import tqdm
|
|
import torch
|
|
import numpy as np
|
|
|
|
|
|
def train_epoch(model, train_loader, criterion, optimizer, device, history, epoch, total_epochs, bert_freeze=False, is_bert=False):
    """Run one optimisation pass over ``train_loader`` and log it to ``history``.

    Args:
        model: Module to train. Called as
            ``model(input_ids, attention_mask=...)`` when ``is_bert`` is true,
            otherwise as ``model(X_batch)``.
        train_loader: Iterable of batches. With ``is_bert`` each batch is
            indexed by ``'input_ids'``, ``'attention_mask'`` and ``'labels'``;
            otherwise each batch unpacks into a ``(features, targets)`` pair.
        criterion: Loss callable invoked as ``criterion(predictions, labels)``.
        optimizer: Optimizer; gradients are zeroed and a step is taken for
            every batch.
        device: Device the batch tensors are moved to.
        history: Metrics tracker. Receives ``batch_update_train`` for every
            batch, then ``update`` and ``batch_reset`` after the last batch.
        epoch: Epoch index; shown as ``epoch + 1`` in the progress bar.
        total_epochs: Number of epochs, used only for the progress bar label.
        bert_freeze: If true and the model has a ``freeze_bert_params``
            method, that method is called before iterating.
        is_bert: Selects which batch layout / model call is used.
    """
    model.train()

    wants_freeze = bert_freeze and hasattr(model, 'freeze_bert_params')
    if wants_freeze:
        model.freeze_bert_params()

    bar_label = f"├ Epoch {epoch + 1}/{total_epochs}"
    with tqdm(train_loader, desc=bar_label) as pbar:
        for batch in pbar:
            optimizer.zero_grad()

            if not is_bert:
                features, targets = batch
                features = features.to(device)
                targets = targets.to(device).float()
                outputs = model(features).float()
            else:
                token_ids = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)
                targets = batch['labels'].to(device).float()
                outputs = model(token_ids, attention_mask=mask).float()

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Metrics are tracked on detached CPU copies so the autograd
            # graph is not kept alive by the history object.
            batch_loss = loss.item()
            history.batch_update_train(
                batch_loss,
                outputs.detach().cpu().numpy(),
                targets.detach().cpu().numpy(),
            )

            # Update progress bar
            pbar.set_postfix({"Train Loss": batch_loss})

    history.update()
    history.batch_reset()
|
|
|
|
|
|
def validate_epoch(model, val_loader, epoch, criterion, device, history, is_bert=False):
    """Evaluate ``model`` on ``val_loader`` without gradients and log the result.

    Args:
        model: Module to evaluate. Called as
            ``model(input_ids, attention_mask=...)`` when ``is_bert`` is true,
            otherwise as ``model(X_batch)``.
        val_loader: Iterable of batches that also supports ``len()`` (used to
            average the per-batch losses). With ``is_bert`` each batch is
            indexed by ``'input_ids'``, ``'attention_mask'`` and ``'labels'``;
            otherwise each batch unpacks into a ``(features, targets)`` pair.
        epoch: Epoch value forwarded unchanged to ``history.batch_update_val``.
        criterion: Loss callable invoked as ``criterion(predictions, labels)``.
        device: Device the batch tensors are moved to.
        history: Metrics tracker providing ``calculate_rmse``,
            ``batch_update_val``, ``update`` and ``batch_reset``.
        is_bert: Selects which batch layout / model call is used.

    Returns:
        Whatever ``history.calculate_rmse`` returns for the predictions and
        labels collected over the whole loader.
    """
    model.eval()
    val_loss = 0.0
    val_preds, val_labels = [], []

    with torch.no_grad():
        for batch in val_loader:
            if is_bert:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device).float()
                predictions = model(input_ids, attention_mask=attention_mask).float()
            else:
                X_batch, y_batch = batch
                X_batch = X_batch.to(device)
                labels = y_batch.to(device).float()
                predictions = model(X_batch).float()

            val_loss += criterion(predictions, labels).item()
            # extend() adds one entry per element along the first axis.
            val_preds.extend(predictions.detach().cpu().numpy())
            val_labels.extend(labels.detach().cpu().numpy())

    # Convert once and reuse; previously each list was converted twice.
    preds_array = np.array(val_preds)
    labels_array = np.array(val_labels)

    val_rmse = history.calculate_rmse(preds_array, labels_array)
    # The logged loss is the mean of the per-batch loss values.
    history.batch_update_val(val_loss / len(val_loader), preds_array, labels_array, epoch)
    history.update()
    history.batch_reset()

    return val_rmse
|
|
|
|
|
|
def test_loop(model, test_loader, device, is_bert=False):
|
|
model.eval()
|
|
test_preds, test_labels = [], []
|
|
with torch.no_grad():
|
|
for batch in test_loader:
|
|
if is_bert:
|
|
input_ids = batch['input_ids'].to(device)
|
|
attention_mask = batch['attention_mask'].to(device)
|
|
labels = batch['labels'].to(device).float()
|
|
predictions = model(input_ids, attention_mask=attention_mask).float()
|
|
else:
|
|
X_batch, y_batch = batch
|
|
X_batch, y_batch = X_batch.to(device), y_batch.to(device).float()
|
|
labels = y_batch
|
|
predictions = model(X_batch).float()
|
|
test_preds.extend(predictions.cpu().detach().numpy())
|
|
test_labels.extend(labels.cpu().detach().numpy())
|
|
|
|
return test_labels, test_preds
|
|
|
|
def ensemble_predict(models, test_loader, device, is_bert=False):
    """Run every model in ``models`` over ``test_loader`` and gather the outputs.

    Args:
        models: Re-iterable collection of modules (it is iterated once to
            switch to eval mode and again for every batch). Each is called as
            ``model(input_ids, attention_mask=...)`` when ``is_bert`` is true,
            otherwise as ``model(X_batch)``.
        test_loader: Iterable of batches. With ``is_bert`` each batch is
            indexed by ``'input_ids'`` and ``'attention_mask'``; otherwise
            each batch unpacks into a ``(features, targets)`` pair, of which
            only the features are used.
        device: Device the input tensors are moved to.
        is_bert: Selects which batch layout / model call is used.

    Returns:
        A NumPy array built by concatenating the per-batch results along
        axis 1. Axis 0 indexes the models; axis 1 runs over the samples of
        all batches, assuming each model output has the batch dimension
        first.

    Raises:
        ValueError: If no predictions were produced, i.e. ``test_loader``
            yielded no batches or ``models`` is empty.
    """
    for model in models:
        model.eval()

    test_preds = []
    with torch.no_grad():
        for batch in test_loader:
            if is_bert:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                predictions = [
                    model(input_ids, attention_mask=attention_mask).float().cpu().detach().numpy()
                    for model in models
                ]
            else:
                # Targets are not needed for inference, so only the features
                # are moved to the device.
                X_batch, _ = batch
                X_batch = X_batch.to(device)
                predictions = [
                    model(X_batch).float().cpu().detach().numpy()
                    for model in models
                ]

            # One entry per batch, each holding one array per model.
            test_preds.append(predictions)

    # Guard both "no batches" and "no models"; indexing test_preds[0] on an
    # empty loader would otherwise raise IndexError instead of this error.
    if not test_preds or not test_preds[0]:
        raise ValueError("No predictions were made in ensemble prediction.")

    return np.concatenate(test_preds, axis=1)