Merge branch 'main' of https://gitty.informatik.hs-mannheim.de/3016498/ANLP_WS24_CA2
commit
a3ba704cd8
22
BERT.py
22
BERT.py
|
|
@ -54,12 +54,12 @@ if __name__ == '__main__':
|
|||
# Hyperparameter und Konfigurationen
|
||||
params = {
|
||||
# Training
|
||||
"epochs": [1],
|
||||
"epochs": [20],
|
||||
"patience": [7],
|
||||
"learning_rate": [1e-5, 1e-6],
|
||||
"learning_rate": [1e-5],
|
||||
"weight_decay": [5e-4],
|
||||
# Model
|
||||
"dropout": [0.6]
|
||||
"dropout": [0.3]
|
||||
}
|
||||
|
||||
# Generate permutations of hyperparameters
|
||||
|
|
@ -67,9 +67,9 @@ if __name__ == '__main__':
|
|||
grid_params = [dict(zip(keys, v)) for v in itertools.product(*values)]
|
||||
best_params = {}
|
||||
best_params_rmse = -1
|
||||
# Example usage of grid_params
|
||||
# for param_set in grid_params:
|
||||
# print(param_set)
|
||||
#Example usage of grid_params
|
||||
for param_set in grid_params:
|
||||
print(param_set)
|
||||
print('Number of grid_params:', len(grid_params))
|
||||
|
||||
|
||||
|
|
@ -81,18 +81,18 @@ if __name__ == '__main__':
|
|||
TEST_SIZE = 0.1
|
||||
VAL_SIZE = 0.1
|
||||
|
||||
MAX_LEN = 280
|
||||
MAX_LEN = 128 #280
|
||||
BATCH_SIZE = 32
|
||||
|
||||
N_MODELS = 1
|
||||
USE_GIRD_SEARCH = True
|
||||
USE_GIRD_SEARCH = False
|
||||
|
||||
models = []
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
# Daten laden und vorbereiten
|
||||
embedding_matrix, word_index, vocab_size, d_model = dataset_helper.get_embedding_matrix(
|
||||
gloVe_path=GLOVE_PATH, emb_len=EMBEDDING_DIM)
|
||||
# embedding_matrix, word_index, vocab_size, d_model = dataset_helper.get_embedding_matrix(
|
||||
# gloVe_path=GLOVE_PATH, emb_len=EMBEDDING_DIM)
|
||||
|
||||
X, y = dataset_helper.load_preprocess_data(path_data=DATA_PATH, verbose=True)
|
||||
|
||||
|
|
@ -183,7 +183,7 @@ if __name__ == '__main__':
|
|||
test_r2 = r2_score(test_labels, test_preds)
|
||||
print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}")
|
||||
|
||||
if test_rmse > best_params_rmse:
|
||||
if test_rmse < best_params_rmse:
|
||||
best_params_rmse = test_rmse
|
||||
best_params = params
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue