Merge branch 'main' of https://gitty.informatik.hs-mannheim.de/3016498/ANLP_WS24_CA2
commit
e926dbbf96
14
CNN.py
14
CNN.py
|
|
@ -58,14 +58,14 @@ if __name__ == '__main__':
|
|||
# Hyperparameter und Konfigurationen
|
||||
params = {
|
||||
# Training
|
||||
"epochs": [5],
|
||||
"epochs": [20],
|
||||
"patience": [7],
|
||||
"learning_rate": [0.001],
|
||||
"weight_decay": [5e-4] ,
|
||||
"weight_decay": [5e-4],
|
||||
# Model
|
||||
"filter_sizes": [[2, 3, 4, 5]],
|
||||
"num_filters": [150],
|
||||
"dropout": [0.6]
|
||||
"dropout": [0.3]
|
||||
}
|
||||
|
||||
# Generate permutations of hyperparameters
|
||||
|
|
@ -74,8 +74,8 @@ if __name__ == '__main__':
|
|||
best_params = {}
|
||||
best_params_rmse = -1
|
||||
# Example usage of grid_params
|
||||
# for param_set in grid_params:
|
||||
# print(param_set)
|
||||
for param_set in grid_params:
|
||||
print(param_set)
|
||||
print('Number of grid_params:', len(grid_params))
|
||||
|
||||
# Configs
|
||||
|
|
@ -89,7 +89,7 @@ if __name__ == '__main__':
|
|||
BATCH_SIZE = 32
|
||||
|
||||
N_MODELS = 1
|
||||
USE_GIRD_SEARCH = True
|
||||
USE_GIRD_SEARCH = False
|
||||
|
||||
models = []
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
|
@ -114,6 +114,8 @@ if __name__ == '__main__':
|
|||
|
||||
subset_size = len(train_dataset) // N_MODELS
|
||||
device = ml_helper.get_device(verbose=True, include_mps=False)
|
||||
#device = torch.device("mps")
|
||||
#print('Using device:', device)
|
||||
|
||||
# assert if N_MODLES > 1, than grid_params should be len 1
|
||||
if N_MODELS > 1 and len(grid_params) > 1 or N_MODELS > 1 and USE_GIRD_SEARCH:
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue