From a869c0e899f33704f0748d10450a0456d53fd5e7 Mon Sep 17 00:00:00 2001 From: arman Date: Fri, 14 Feb 2025 23:54:47 +0100 Subject: [PATCH] subset update+plots --- cnn_bootstrap_agg.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cnn_bootstrap_agg.py b/cnn_bootstrap_agg.py index d53036a..36e6599 100644 --- a/cnn_bootstrap_agg.py +++ b/cnn_bootstrap_agg.py @@ -42,7 +42,7 @@ def train_model(model, train_dataset, val_dataset, criterion, optimizer, epochs, history['train_loss'].append(train_loss) history['train_r2'].append(train_r2) - # **Validierung nach jeder Epoche** + model.eval() val_loss = 0 all_val_preds, all_val_targets = [], [] @@ -64,12 +64,12 @@ def train_model(model, train_dataset, val_dataset, criterion, optimizer, epochs, print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train R²: {train_r2:.4f}, Val R²: {val_r2:.4f}") - return history # **Gibt die Verlaufsdaten zurück** + return history def bootstrap_aggregation(ModelClass, train_dataset, num_models=3, epochs=5, batch_size=32, learning_rate=0.001): models = [] - all_histories = [] # **Speichert Trainingsverlauf aller Modelle** + all_histories = [] subset_size = len(train_dataset) // num_models @@ -82,7 +82,7 @@ def bootstrap_aggregation(ModelClass, train_dataset, num_models=3, epochs=5, bat subset = Subset(train_dataset, subset_indices) - # **Validierungsdaten als restliche Daten** + val_indices = list(range(start_idx, end_idx)) val_subset = Subset(train_dataset, val_indices) @@ -91,7 +91,7 @@ def bootstrap_aggregation(ModelClass, train_dataset, num_models=3, epochs=5, bat optimizer = optim.Adam(model.parameters(), lr=learning_rate) history = train_model(model, subset, val_subset, criterion, optimizer, epochs, batch_size) - all_histories.append(history) # **Speichere Verlaufsdaten** + all_histories.append(history) models.append(model) return models, all_histories @@ -116,7 +116,7 @@ def plot_training_histories(histories, num_models): fig, axes = plt.subplots(1, 2, figsize=(14, 5)) - # **Links: Trainings- und Validierungsverlust** + for i in range(num_models): axes[0].plot(epochs, histories[i]['train_loss'], label=f"Train Loss Model {i+1}") axes[0].plot(epochs, histories[i]['val_loss'], linestyle='dashed', label=f"Val Loss Model {i+1}") @@ -126,7 +126,6 @@ def plot_training_histories(histories, num_models): axes[0].set_ylabel("Loss") axes[0].legend() - # **Rechts: R²-Werte für Training und Validierung** for i in range(num_models): axes[1].plot(epochs, histories[i]['train_r2'], label=f"Train R² Model {i+1}") axes[1].plot(epochs, histories[i]['val_r2'], linestyle='dashed', label=f"Val R² Model {i+1}")