bert with history
parent
0116985d2b
commit
9a6599f3f6
22
BERT.py
22
BERT.py
|
|
@ -54,12 +54,12 @@ if __name__ == '__main__':
|
|||
# Hyperparameter und Konfigurationen
|
||||
params = {
|
||||
# Training
|
||||
"epochs": [1],
|
||||
"epochs": [20],
|
||||
"patience": [7],
|
||||
"learning_rate": [1e-5, 1e-6],
|
||||
"learning_rate": [1e-5],
|
||||
"weight_decay": [5e-4],
|
||||
# Model
|
||||
"dropout": [0.6]
|
||||
"dropout": [0.3]
|
||||
}
|
||||
|
||||
# Generate permutations of hyperparameters
|
||||
|
|
@ -67,9 +67,9 @@ if __name__ == '__main__':
|
|||
grid_params = [dict(zip(keys, v)) for v in itertools.product(*values)]
|
||||
best_params = {}
|
||||
best_params_rmse = -1
|
||||
# Example usage of grid_params
|
||||
# for param_set in grid_params:
|
||||
# print(param_set)
|
||||
#Example usage of grid_params
|
||||
for param_set in grid_params:
|
||||
print(param_set)
|
||||
print('Number of grid_params:', len(grid_params))
|
||||
|
||||
|
||||
|
|
@ -81,18 +81,18 @@ if __name__ == '__main__':
|
|||
TEST_SIZE = 0.1
|
||||
VAL_SIZE = 0.1
|
||||
|
||||
MAX_LEN = 280
|
||||
MAX_LEN = 128 #280
|
||||
BATCH_SIZE = 32
|
||||
|
||||
N_MODELS = 1
|
||||
USE_GIRD_SEARCH = True
|
||||
USE_GIRD_SEARCH = False
|
||||
|
||||
models = []
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
# Daten laden und vorbereiten
|
||||
embedding_matrix, word_index, vocab_size, d_model = dataset_helper.get_embedding_matrix(
|
||||
gloVe_path=GLOVE_PATH, emb_len=EMBEDDING_DIM)
|
||||
# embedding_matrix, word_index, vocab_size, d_model = dataset_helper.get_embedding_matrix(
|
||||
# gloVe_path=GLOVE_PATH, emb_len=EMBEDDING_DIM)
|
||||
|
||||
X, y = dataset_helper.load_preprocess_data(path_data=DATA_PATH, verbose=True)
|
||||
|
||||
|
|
@ -183,7 +183,7 @@ if __name__ == '__main__':
|
|||
test_r2 = r2_score(test_labels, test_preds)
|
||||
print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}")
|
||||
|
||||
if test_rmse > best_params_rmse:
|
||||
if test_rmse < best_params_rmse:
|
||||
best_params_rmse = test_rmse
|
||||
best_params = params
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue