From 8f1a9081f9c93c943276488269f7c656bbfa02f3 Mon Sep 17 00:00:00 2001 From: Nils <1826514@stud.hs-mannheim.de> Date: Sun, 16 Feb 2025 19:13:47 +0100 Subject: [PATCH] jetzt kleiner und nicht groesser --- Transformer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Transformer.py b/Transformer.py index 1568cd0..01e5e64 100644 --- a/Transformer.py +++ b/Transformer.py @@ -110,14 +110,14 @@ if __name__ == '__main__': # Hyperparameter und Konfigurationen params = { # Training - "epochs": [1], + "epochs": [20], "patience": [7], "learning_rate": [1e-4], # 1e-4 "weight_decay": [5e-4], # Model 'nhead': [2], # 5 "dropout": [0.2], - 'hiden_dim': [1024, 2048], + 'hiden_dim': [512, 1024], 'num_layers': [6] } @@ -127,8 +127,8 @@ if __name__ == '__main__': best_params = {} best_params_rmse = -1 # Example usage of grid_params - # for param_set in grid_params: - # print(param_set) + for param_set in grid_params: + print(param_set) print('Number of grid_params:', len(grid_params)) # Configs @@ -142,7 +142,7 @@ if __name__ == '__main__': BATCH_SIZE = 32 N_MODELS = 1 - USE_GIRD_SEARCH = True + USE_GIRD_SEARCH = False models = [] timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") @@ -243,7 +243,7 @@ if __name__ == '__main__': test_r2 = r2_score(test_labels, test_preds) print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}") - if test_rmse > best_params_rmse: + if test_rmse < best_params_rmse: best_params_rmse = test_rmse best_params = params