subset update+plots
parent
299e01a820
commit
a869c0e899
|
|
@ -42,7 +42,7 @@ def train_model(model, train_dataset, val_dataset, criterion, optimizer, epochs,
|
|||
history['train_loss'].append(train_loss)
|
||||
history['train_r2'].append(train_r2)
|
||||
|
||||
# **Validierung nach jeder Epoche**
|
||||
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
all_val_preds, all_val_targets = [], []
|
||||
|
|
@ -64,12 +64,12 @@ def train_model(model, train_dataset, val_dataset, criterion, optimizer, epochs,
|
|||
|
||||
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train R²: {train_r2:.4f}, Val R²: {val_r2:.4f}")
|
||||
|
||||
return history # **Gibt die Verlaufsdaten zurück**
|
||||
return history
|
||||
|
||||
|
||||
def bootstrap_aggregation(ModelClass, train_dataset, num_models=3, epochs=5, batch_size=32, learning_rate=0.001):
|
||||
models = []
|
||||
all_histories = [] # **Speichert Trainingsverlauf aller Modelle**
|
||||
all_histories = []
|
||||
|
||||
subset_size = len(train_dataset) // num_models
|
||||
|
||||
|
|
@ -82,7 +82,7 @@ def bootstrap_aggregation(ModelClass, train_dataset, num_models=3, epochs=5, bat
|
|||
|
||||
subset = Subset(train_dataset, subset_indices)
|
||||
|
||||
# **Validierungsdaten als restliche Daten**
|
||||
|
||||
val_indices = list(range(start_idx, end_idx))
|
||||
val_subset = Subset(train_dataset, val_indices)
|
||||
|
||||
|
|
@ -91,7 +91,7 @@ def bootstrap_aggregation(ModelClass, train_dataset, num_models=3, epochs=5, bat
|
|||
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
history = train_model(model, subset, val_subset, criterion, optimizer, epochs, batch_size)
|
||||
all_histories.append(history) # **Speichere Verlaufsdaten**
|
||||
all_histories.append(history)
|
||||
models.append(model)
|
||||
|
||||
return models, all_histories
|
||||
|
|
@ -116,7 +116,7 @@ def plot_training_histories(histories, num_models):
|
|||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
# **Links: Trainings- und Validierungsverlust**
|
||||
|
||||
for i in range(num_models):
|
||||
axes[0].plot(epochs, histories[i]['train_loss'], label=f"Train Loss Model {i+1}")
|
||||
axes[0].plot(epochs, histories[i]['val_loss'], linestyle='dashed', label=f"Val Loss Model {i+1}")
|
||||
|
|
@ -126,7 +126,6 @@ def plot_training_histories(histories, num_models):
|
|||
axes[0].set_ylabel("Loss")
|
||||
axes[0].legend()
|
||||
|
||||
# **Rechts: R²-Werte für Training und Validierung**
|
||||
for i in range(num_models):
|
||||
axes[1].plot(epochs, histories[i]['train_r2'], label=f"Train R² Model {i+1}")
|
||||
axes[1].plot(epochs, histories[i]['val_r2'], linestyle='dashed', label=f"Val R² Model {i+1}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue