jetzt kleiner und nicht groesser

main
NIls Rekus 2025-02-16 19:13:47 +01:00
parent a3ba704cd8
commit 8f1a9081f9
1 changed files with 6 additions and 6 deletions

View File

@ -110,14 +110,14 @@ if __name__ == '__main__':
# Hyperparameter und Konfigurationen
params = {
# Training
"epochs": [1],
"epochs": [20],
"patience": [7],
"learning_rate": [1e-4], # 1e-4
"weight_decay": [5e-4],
# Model
'nhead': [2], # 5
"dropout": [0.2],
'hiden_dim': [1024, 2048],
'hiden_dim': [512, 1024],
'num_layers': [6]
}
@ -127,8 +127,8 @@ if __name__ == '__main__':
best_params = {}
best_params_rmse = -1
# Example usage of grid_params
# for param_set in grid_params:
# print(param_set)
for param_set in grid_params:
print(param_set)
print('Number of grid_params:', len(grid_params))
# Configs
@ -142,7 +142,7 @@ if __name__ == '__main__':
BATCH_SIZE = 32
N_MODELS = 1
USE_GIRD_SEARCH = True
USE_GIRD_SEARCH = False
models = []
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
@ -243,7 +243,7 @@ if __name__ == '__main__':
test_r2 = r2_score(test_labels, test_preds)
print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}")
if test_rmse > best_params_rmse:
if test_rmse < best_params_rmse:
best_params_rmse = test_rmse
best_params = params