bert with history

main
Felix Jan Michael Mucha 2025-02-16 19:08:19 +01:00
parent 0116985d2b
commit 9a6599f3f6
16 changed files with 144244 additions and 11 deletions

22
BERT.py
View File

@ -54,12 +54,12 @@ if __name__ == '__main__':
# Hyperparameter und Konfigurationen
params = {
# Training
"epochs": [1],
"epochs": [20],
"patience": [7],
"learning_rate": [1e-5, 1e-6],
"learning_rate": [1e-5],
"weight_decay": [5e-4],
# Model
"dropout": [0.6]
"dropout": [0.3]
}
# Generate permutations of hyperparameters
@ -67,9 +67,9 @@ if __name__ == '__main__':
grid_params = [dict(zip(keys, v)) for v in itertools.product(*values)]
best_params = {}
best_params_rmse = -1
# Example usage of grid_params
# for param_set in grid_params:
# print(param_set)
#Example usage of grid_params
for param_set in grid_params:
print(param_set)
print('Number of grid_params:', len(grid_params))
@ -81,18 +81,18 @@ if __name__ == '__main__':
TEST_SIZE = 0.1
VAL_SIZE = 0.1
MAX_LEN = 280
MAX_LEN = 128 #280
BATCH_SIZE = 32
N_MODELS = 1
USE_GIRD_SEARCH = True
USE_GIRD_SEARCH = False
models = []
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# Daten laden und vorbereiten
embedding_matrix, word_index, vocab_size, d_model = dataset_helper.get_embedding_matrix(
gloVe_path=GLOVE_PATH, emb_len=EMBEDDING_DIM)
# embedding_matrix, word_index, vocab_size, d_model = dataset_helper.get_embedding_matrix(
# gloVe_path=GLOVE_PATH, emb_len=EMBEDDING_DIM)
X, y = dataset_helper.load_preprocess_data(path_data=DATA_PATH, verbose=True)
@ -183,7 +183,7 @@ if __name__ == '__main__':
test_r2 = r2_score(test_labels, test_preds)
print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}")
if test_rmse > best_params_rmse:
if test_rmse < best_params_rmse:
best_params_rmse = test_rmse
best_params = params

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long