NIls Rekus 2025-02-16 19:09:58 +01:00
commit a3ba704cd8
16 changed files with 144244 additions and 11 deletions

22
BERT.py
View File

@@ -54,12 +54,12 @@ if __name__ == '__main__':
# Hyperparameter und Konfigurationen
params = {
# Training
"epochs": [1],
"epochs": [20],
"patience": [7],
"learning_rate": [1e-5, 1e-6],
"learning_rate": [1e-5],
"weight_decay": [5e-4],
# Model
"dropout": [0.6]
"dropout": [0.3]
}
# Generate permutations of hyperparameters
@@ -67,9 +67,9 @@ if __name__ == '__main__':
grid_params = [dict(zip(keys, v)) for v in itertools.product(*values)]
best_params = {}
best_params_rmse = -1
# Example usage of grid_params
# for param_set in grid_params:
# print(param_set)
#Example usage of grid_params
for param_set in grid_params:
print(param_set)
print('Number of grid_params:', len(grid_params))
@@ -81,18 +81,18 @@ if __name__ == '__main__':
TEST_SIZE = 0.1
VAL_SIZE = 0.1
MAX_LEN = 280
MAX_LEN = 128 #280
BATCH_SIZE = 32
N_MODELS = 1
USE_GIRD_SEARCH = True
USE_GIRD_SEARCH = False
models = []
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# Daten laden und vorbereiten
embedding_matrix, word_index, vocab_size, d_model = dataset_helper.get_embedding_matrix(
gloVe_path=GLOVE_PATH, emb_len=EMBEDDING_DIM)
# embedding_matrix, word_index, vocab_size, d_model = dataset_helper.get_embedding_matrix(
# gloVe_path=GLOVE_PATH, emb_len=EMBEDDING_DIM)
X, y = dataset_helper.load_preprocess_data(path_data=DATA_PATH, verbose=True)
@@ -183,7 +183,7 @@ if __name__ == '__main__':
test_r2 = r2_score(test_labels, test_preds)
print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}")
if test_rmse > best_params_rmse:
if test_rmse < best_params_rmse:
best_params_rmse = test_rmse
best_params = params

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long