Merge branch 'main' of https://gitty.informatik.hs-mannheim.de/3016498/ANLP_WS24_CA2
commit
b243b64551
|
|
@ -110,14 +110,14 @@ if __name__ == '__main__':
|
||||||
# Hyperparameter und Konfigurationen
|
# Hyperparameter und Konfigurationen
|
||||||
params = {
|
params = {
|
||||||
# Training
|
# Training
|
||||||
"epochs": [1],
|
"epochs": [20],
|
||||||
"patience": [7],
|
"patience": [7],
|
||||||
"learning_rate": [1e-4], # 1e-4
|
"learning_rate": [1e-4], # 1e-4
|
||||||
"weight_decay": [5e-4],
|
"weight_decay": [5e-4],
|
||||||
# Model
|
# Model
|
||||||
'nhead': [2], # 5
|
'nhead': [2], # 5
|
||||||
"dropout": [0.2],
|
"dropout": [0.2],
|
||||||
'hiden_dim': [1024, 2048],
|
'hiden_dim': [512, 1024],
|
||||||
'num_layers': [6]
|
'num_layers': [6]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -127,8 +127,8 @@ if __name__ == '__main__':
|
||||||
best_params = {}
|
best_params = {}
|
||||||
best_params_rmse = -1
|
best_params_rmse = -1
|
||||||
# Example usage of grid_params
|
# Example usage of grid_params
|
||||||
# for param_set in grid_params:
|
for param_set in grid_params:
|
||||||
# print(param_set)
|
print(param_set)
|
||||||
print('Number of grid_params:', len(grid_params))
|
print('Number of grid_params:', len(grid_params))
|
||||||
|
|
||||||
# Configs
|
# Configs
|
||||||
|
|
@ -142,7 +142,7 @@ if __name__ == '__main__':
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
|
|
||||||
N_MODELS = 1
|
N_MODELS = 1
|
||||||
USE_GIRD_SEARCH = True
|
USE_GIRD_SEARCH = False
|
||||||
|
|
||||||
models = []
|
models = []
|
||||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
|
@ -243,7 +243,7 @@ if __name__ == '__main__':
|
||||||
test_r2 = r2_score(test_labels, test_preds)
|
test_r2 = r2_score(test_labels, test_preds)
|
||||||
print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}")
|
print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}")
|
||||||
|
|
||||||
if test_rmse > best_params_rmse:
|
if test_rmse < best_params_rmse:
|
||||||
best_params_rmse = test_rmse
|
best_params_rmse = test_rmse
|
||||||
best_params = params
|
best_params = params
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue