subset update+plots

main
arman 2025-02-14 23:54:47 +01:00
parent 299e01a820
commit a869c0e899
1 changed files with 6 additions and 7 deletions

View File

@ -42,7 +42,7 @@ def train_model(model, train_dataset, val_dataset, criterion, optimizer, epochs,
history['train_loss'].append(train_loss)
history['train_r2'].append(train_r2)
# **Validierung nach jeder Epoche**
model.eval()
val_loss = 0
all_val_preds, all_val_targets = [], []
@ -64,12 +64,12 @@ def train_model(model, train_dataset, val_dataset, criterion, optimizer, epochs,
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train R²: {train_r2:.4f}, Val R²: {val_r2:.4f}")
return history # **Gibt die Verlaufsdaten zurück**
return history
def bootstrap_aggregation(ModelClass, train_dataset, num_models=3, epochs=5, batch_size=32, learning_rate=0.001):
models = []
all_histories = [] # **Speichert Trainingsverlauf aller Modelle**
all_histories = []
subset_size = len(train_dataset) // num_models
@ -82,7 +82,7 @@ def bootstrap_aggregation(ModelClass, train_dataset, num_models=3, epochs=5, bat
subset = Subset(train_dataset, subset_indices)
# **Validierungsdaten als restliche Daten**
val_indices = list(range(start_idx, end_idx))
val_subset = Subset(train_dataset, val_indices)
@ -91,7 +91,7 @@ def bootstrap_aggregation(ModelClass, train_dataset, num_models=3, epochs=5, bat
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
history = train_model(model, subset, val_subset, criterion, optimizer, epochs, batch_size)
all_histories.append(history) # **Speichere Verlaufsdaten**
all_histories.append(history)
models.append(model)
return models, all_histories
@ -116,7 +116,7 @@ def plot_training_histories(histories, num_models):
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# **Links: Trainings- und Validierungsverlust**
for i in range(num_models):
axes[0].plot(epochs, histories[i]['train_loss'], label=f"Train Loss Model {i+1}")
axes[0].plot(epochs, histories[i]['val_loss'], linestyle='dashed', label=f"Val Loss Model {i+1}")
@ -126,7 +126,6 @@ def plot_training_histories(histories, num_models):
axes[0].set_ylabel("Loss")
axes[0].legend()
# **Rechts: R²-Werte für Training und Validierung**
for i in range(num_models):
axes[1].plot(epochs, histories[i]['train_r2'], label=f"Train R² Model {i+1}")
axes[1].plot(epochs, histories[i]['val_r2'], linestyle='dashed', label=f"Val R² Model {i+1}")