Felix Jan Michael Mucha 2025-02-16 19:15:15 +01:00
commit b243b64551
1 changed files with 6 additions and 6 deletions

View File

@ -110,14 +110,14 @@ if __name__ == '__main__':
# Hyperparameter und Konfigurationen # Hyperparameter und Konfigurationen
params = { params = {
# Training # Training
"epochs": [1], "epochs": [20],
"patience": [7], "patience": [7],
"learning_rate": [1e-4], # 1e-4 "learning_rate": [1e-4], # 1e-4
"weight_decay": [5e-4], "weight_decay": [5e-4],
# Model # Model
'nhead': [2], # 5 'nhead': [2], # 5
"dropout": [0.2], "dropout": [0.2],
'hiden_dim': [1024, 2048], 'hiden_dim': [512, 1024],
'num_layers': [6] 'num_layers': [6]
} }
@ -127,8 +127,8 @@ if __name__ == '__main__':
best_params = {} best_params = {}
best_params_rmse = -1 best_params_rmse = -1
# Example usage of grid_params # Example usage of grid_params
# for param_set in grid_params: for param_set in grid_params:
# print(param_set) print(param_set)
print('Number of grid_params:', len(grid_params)) print('Number of grid_params:', len(grid_params))
# Configs # Configs
@ -142,7 +142,7 @@ if __name__ == '__main__':
BATCH_SIZE = 32 BATCH_SIZE = 32
N_MODELS = 1 N_MODELS = 1
USE_GIRD_SEARCH = True USE_GIRD_SEARCH = False
models = [] models = []
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
@ -243,7 +243,7 @@ if __name__ == '__main__':
test_r2 = r2_score(test_labels, test_preds) test_r2 = r2_score(test_labels, test_preds)
print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}") print(f"Test RMSE: {test_rmse:.4f}, Test MAE: {test_mae:.4f}, Test R²: {test_r2:.4f}")
if test_rmse > best_params_rmse: if test_rmse < best_params_rmse:
best_params_rmse = test_rmse best_params_rmse = test_rmse
best_params = params best_params = params