bert with history

main
Felix Jan Michael Mucha 2025-02-16 19:08:19 +01:00
parent 0116985d2b
commit 9a6599f3f6
16 changed files with 144244 additions and 11 deletions

20
BERT.py
View File

@ -54,12 +54,12 @@ if __name__ == '__main__':
# Hyperparameter und Konfigurationen # Hyperparameter und Konfigurationen
params = { params = {
# Training # Training
"epochs": [1], "epochs": [20],
"patience": [7], "patience": [7],
"learning_rate": [1e-5, 1e-6], "learning_rate": [1e-5],
"weight_decay": [5e-4], "weight_decay": [5e-4],
# Model # Model
"dropout": [0.6] "dropout": [0.3]
} }
# Generate permutations of hyperparameters # Generate permutations of hyperparameters
@ -68,8 +68,8 @@ if __name__ == '__main__':
best_params = {} best_params = {}
best_params_rmse = -1 best_params_rmse = -1
#Example usage of grid_params #Example usage of grid_params
# for param_set in grid_params: for param_set in grid_params:
# print(param_set) print(param_set)
print('Number of grid_params:', len(grid_params)) print('Number of grid_params:', len(grid_params))
@ -81,18 +81,18 @@ if __name__ == '__main__':
TEST_SIZE = 0.1 TEST_SIZE = 0.1
VAL_SIZE = 0.1 VAL_SIZE = 0.1
MAX_LEN = 280 MAX_LEN = 128 #280
BATCH_SIZE = 32 BATCH_SIZE = 32
N_MODELS = 1 N_MODELS = 1
USE_GIRD_SEARCH = True USE_GIRD_SEARCH = False
models = [] models = []
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# Daten laden und vorbereiten # Daten laden und vorbereiten
embedding_matrix, word_index, vocab_size, d_model = dataset_helper.get_embedding_matrix( # embedding_matrix, word_index, vocab_size, d_model = dataset_helper.get_embedding_matrix(
gloVe_path=GLOVE_PATH, emb_len=EMBEDDING_DIM) # gloVe_path=GLOVE_PATH, emb_len=EMBEDDING_DIM)
X, y = dataset_helper.load_preprocess_data(path_data=DATA_PATH, verbose=True) X, y = dataset_helper.load_preprocess_data(path_data=DATA_PATH, verbose=True)
@ -183,7 +183,7 @@ if __name__ == '__main__':
test_r2 = r2_score(test_labels, test_preds) test_r2 = r2_score(test_labels, test_preds)
print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}") print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}")
if test_rmse > best_params_rmse: if test_rmse < best_params_rmse:
best_params_rmse = test_rmse best_params_rmse = test_rmse
best_params = params best_params = params

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long