138 lines
4.8 KiB
Python
138 lines
4.8 KiB
Python
import random
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader
|
|
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
|
from transformers import BertForSequenceClassification, AutoTokenizer
|
|
import numpy as np
|
|
|
|
import Datasets
|
|
import dataset_helper
|
|
import EarlyStopping
|
|
import ml_helper
|
|
import ml_history
|
|
import ml_train
|
|
|
|
# Single seed shared by every random number generator used in this script.
SEED = 501

# Seed Python's `random`, NumPy, PyTorch (CPU) and all CUDA devices, in that order.
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(SEED)
del _seed_rng

# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True
|
|
|
|
class CustomBert(nn.Module):
    """Pretrained BERT sequence classifier with a dropout + linear head on top.

    The first element of the BERT output is passed through dropout and a
    ``Linear(2, 1)`` layer, yielding one value per sample.

    Args:
        dropout: Dropout probability applied to the BERT output before the
            final linear layer.
    """

    def __init__(self, dropout):
        super().__init__()
        # BERT + custom layers. The wrapped model's output is indexed with [0]
        # in forward(); Linear(2, 1) presumably matches the two logits of the
        # default classification head -- TODO confirm if num_labels changes.
        self.bfsc = BertForSequenceClassification.from_pretrained("bert-base-uncased")
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(2, 1)

    def forward(self, input_ids, attention_mask):
        """Run the model and return one prediction per sample.

        Args:
            input_ids: Token ids forwarded to the wrapped BERT model.
            attention_mask: Attention mask forwarded to the wrapped BERT model.

        Returns:
            Tensor with the trailing size-1 dimension removed, i.e. shape
            ``(batch,)`` for a ``(batch, 2)`` BERT output.
        """
        x = self.bfsc(input_ids, attention_mask=attention_mask)
        x = self.dropout(x[0])
        x = self.classifier(x)
        # Squeeze only the last dimension: a bare squeeze() would also drop the
        # batch dimension for a batch of one sample and return a 0-d tensor.
        return x.squeeze(-1)

    def freeze_bert_params(self):
        """Disable gradients for every parameter of the wrapped BERT model."""
        for param in self.bfsc.parameters():
            param.requires_grad_(False)

    def unfreeze_bert_params(self):
        """Enable gradients for every parameter of the wrapped BERT model."""
        for param in self.bfsc.parameters():
            param.requires_grad_(True)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: fine-tune CustomBert on the CSV data set, keep the
    # best checkpoint via early stopping, then evaluate on the test split.

    # Hyperparameters and configuration
    params = {
        # Config
        "max_len": 128,
        # Training
        "epochs": 10,
        "patience": 7,
        "batch_size": 32,
        "learning_rate": 0.001,
        "weight_decay": 5e-4,
        # Model
        # NOTE(review): "filter_sizes" and "num_filters" are not read anywhere
        # in this script; kept so the params dict stays unchanged.
        "filter_sizes": [2, 3, 4, 5],
        "num_filters": 150,
        "dropout": 0.6,
    }

    # Configs
    MODEL_NAME = 'BERT.pt'
    HIST_NAME = 'BERT_history'
    DATA_PATH = 'data/hack.csv'
    FREEZE_BERT = False
    TEST_SIZE = 0.1
    VAL_SIZE = 0.1

    # Load and preprocess the data
    X, y = dataset_helper.load_preprocess_data(path_data=DATA_PATH, verbose=True)

    # Split the data into train / validation / test
    data_split = dataset_helper.split_data(X, y, test_size=TEST_SIZE, val_size=VAL_SIZE)

    # Initialize the BERT tokenizer from the pretrained checkpoint
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", do_lower_case=True)
    print("Tokenizer Initialized")

    # Datasets and DataLoaders (only the training loader is shuffled)
    train_dataset = Datasets.BertDataset(
        tokenizer, data_split['train']['X'], data_split['train']['y'], max_len=params["max_len"])
    val_dataset = Datasets.BertDataset(
        tokenizer, data_split['val']['X'], data_split['val']['y'], max_len=params["max_len"])
    test_dataset = Datasets.BertDataset(
        tokenizer, data_split['test']['X'], data_split['test']['y'], max_len=params["max_len"])

    train_loader = DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=params["batch_size"], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=params["batch_size"], shuffle=False)

    # Initialize the model and move it to the selected device
    model = CustomBert(dropout=params["dropout"])

    device = ml_helper.get_device(verbose=True, include_mps=False)
    model = model.to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(
        model.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
    early_stopping = EarlyStopping.EarlyStoppingCallback(
        patience=params["patience"], verbose=True, model_name=MODEL_NAME)

    hist = ml_history.History()

    # Training and validation
    for epoch in range(params["epochs"]):
        ml_train.train_epoch(
            model, train_loader, criterion, optimizer, device, hist, epoch, params["epochs"],
            bert_freeze=FREEZE_BERT, is_bert=True)

        val_rmse = ml_train.validate_epoch(
            model, val_loader, epoch, criterion, device, hist, is_bert=True)

        early_stopping(val_rmse, model)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    # Load best model.
    # NOTE(review): assumes EarlyStoppingCallback writes its checkpoint to
    # 'models/checkpoints/<model_name>' -- confirm against EarlyStopping.
    # map_location places the loaded tensors on the active device (assumes
    # ml_helper.get_device returns a torch.device or device string).
    model.load_state_dict(torch.load('models/checkpoints/' + MODEL_NAME, map_location=device))

    # Test evaluation
    test_labels, test_preds = ml_train.test_loop(model, test_loader, device, is_bert=True)

    hist.add_test_results(test_labels, test_preds)

    # Save training history
    hist.save_history(HIST_NAME)

    # RMSE, MAE and R² score on the test set
    test_mae = mean_absolute_error(test_labels, test_preds)
    test_rmse = np.sqrt(mean_squared_error(test_labels, test_preds))
    test_r2 = r2_score(test_labels, test_preds)
    print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}")
|