NIls Rekus 2025-02-16 18:24:33 +01:00
commit e926dbbf96
18 changed files with 150245 additions and 6 deletions

14
CNN.py
View File

@ -58,14 +58,14 @@ if __name__ == '__main__':
# Hyperparameter und Konfigurationen # Hyperparameter und Konfigurationen
params = { params = {
# Training # Training
"epochs": [5], "epochs": [20],
"patience": [7], "patience": [7],
"learning_rate": [0.001], "learning_rate": [0.001],
"weight_decay": [5e-4] , "weight_decay": [5e-4],
# Model # Model
"filter_sizes": [[2, 3, 4, 5]], "filter_sizes": [[2, 3, 4, 5]],
"num_filters": [150], "num_filters": [150],
"dropout": [0.6] "dropout": [0.3]
} }
# Generate permutations of hyperparameters # Generate permutations of hyperparameters
@ -74,8 +74,8 @@ if __name__ == '__main__':
best_params = {} best_params = {}
best_params_rmse = -1 best_params_rmse = -1
# Example usage of grid_params # Example usage of grid_params
# for param_set in grid_params: for param_set in grid_params:
# print(param_set) print(param_set)
print('Number of grid_params:', len(grid_params)) print('Number of grid_params:', len(grid_params))
# Configs # Configs
@ -89,7 +89,7 @@ if __name__ == '__main__':
BATCH_SIZE = 32 BATCH_SIZE = 32
N_MODELS = 1 N_MODELS = 1
USE_GIRD_SEARCH = True USE_GIRD_SEARCH = False
models = [] models = []
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
@ -114,6 +114,8 @@ if __name__ == '__main__':
subset_size = len(train_dataset) // N_MODELS subset_size = len(train_dataset) // N_MODELS
device = ml_helper.get_device(verbose=True, include_mps=False) device = ml_helper.get_device(verbose=True, include_mps=False)
#device = torch.device("mps")
#print('Using device:', device)
# assert if N_MODLES > 1, than grid_params should be len 1 # assert if N_MODLES > 1, than grid_params should be len 1
if N_MODELS > 1 and len(grid_params) > 1 or N_MODELS > 1 and USE_GIRD_SEARCH: if N_MODELS > 1 and len(grid_params) > 1 or N_MODELS > 1 and USE_GIRD_SEARCH:

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long