# ANLP_WS24_CA2/lstm_1b.py


import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
# Automatic device selection (Apple MPS, CUDA, CPU)
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Using device:', device)
class ImprovedLSTMBinaryClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, dropout=0.1):
        super(ImprovedLSTMBinaryClassifier, self).__init__()
        self.lstm = nn.LSTM(input_dim,
                            hidden_dim,
                            num_layers,
                            batch_first=True,
                            dropout=dropout,
                            bidirectional=False)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        # Additional fully connected layers without ReLU
        self.fc1 = nn.Linear(hidden_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)
    def forward(self, input_ids):
        lstm_out, _ = self.lstm(input_ids)
        lstm_out = self.dropout(lstm_out)
        pooled = lstm_out[:, -1, :]  # last hidden state (output at the final time step)
        normalized = self.layer_norm(pooled)
        # Several fully connected layers
        x = self.fc1(normalized)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        return self.sigmoid(x)
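
# Usage sketch (illustrative; shapes follow from the forward pass above): the model
# expects pre-embedded sequences of shape (batch, seq_len, input_dim) and returns
# sigmoid probabilities of shape (batch, 1). For example:
#   model = ImprovedLSTMBinaryClassifier(input_dim=100, hidden_dim=256, num_layers=2)
#   probs = model(torch.randn(8, 50, 100))  # -> torch.Size([8, 1]); 8 and 50 are arbitrary
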
# Training and evaluation
if __name__ == "__main__":
    # Load data (assumption: the embedded data has already been prepared)
    data_path = '/content/drive/MyDrive/Colab Notebooks/ANLP_WS24_CA2/data/embedded_padded'
    train_dataset = torch.load(data_path + '/train.pt')
    test_dataset = torch.load(data_path + '/test.pt')
    val_dataset = torch.load(data_path + '/val.pt')
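    # Assumption (inferred from the batch indexing below): each dataset yields dicts with
    # an 'input_ids' tensor of pre-embedded, padded sequences of shape (seq_len, input_dim)
    # and a binary 'labels' entry per example.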
    # Hyperparameters
    input_dim = 100
    hidden_dim = 256
    num_layers = 2
    dropout = 0.3
    batch_size = 64
    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Initialize the model
    model = ImprovedLSTMBinaryClassifier(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        dropout=dropout
    ).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
    best_val_loss = float('inf')
    best_test_accuracy = 0
    patience = 3
    counter = 0
    history = {'train_loss': [], 'val_loss': [], 'test_acc': [], 'test_f1': []}
    epochs = 5
    for epoch in range(epochs):
        # Training
        model.train()
        total_loss = 0
        start_time = time.time()
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].unsqueeze(1).to(device)
            outputs = model(input_ids)
            loss = criterion(outputs, labels)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1)  # clip gradients to max-norm 1
            optimizer.step()
            total_loss += loss.item()
        avg_train_loss = total_loss / len(train_loader)
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                labels = batch['labels'].unsqueeze(1).to(device)
                outputs = model(input_ids)
                val_loss += criterion(outputs, labels).item()
        avg_val_loss = val_loss / len(val_loader)
        # Test evaluation
        test_preds = []
        test_labels = []
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(device)
                labels = batch['labels'].unsqueeze(1).to(device)
                outputs = model(input_ids)
                preds = (outputs > 0.5).float()
                test_preds.extend(preds.cpu().numpy())
                test_labels.extend(labels.cpu().numpy())
        test_accuracy = accuracy_score(test_labels, test_preds)
        test_f1 = f1_score(test_labels, test_preds)
        # Update history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['test_acc'].append(test_accuracy)
        history['test_f1'].append(test_f1)
        # Adjust the learning rate
        scheduler.step(avg_val_loss)
        # Logging
        epoch_time = time.time() - start_time
        print(f'Epoch {epoch+1}/{epochs} | Time: {epoch_time:.2f}s')
        print(f'Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}')
        print(f'Test Acc: {test_accuracy:.4f} | Test F1: {test_f1:.4f}\n')
        # Save the best model
        if test_accuracy > best_test_accuracy:
            best_test_accuracy = test_accuracy
            torch.save(model.state_dict(), "best_lstm_model.pth")
            print(f"🚀 New best model saved (Acc: {test_accuracy:.4f})")
        # Early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("⛔ Early stopping triggered!")
                break
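
    # Sketch (not in the original script): persist the collected history and plot the
    # loss curves, using the json/matplotlib imports above; filenames are placeholders.
    with open('lstm_1b_history.json', 'w') as f:
        # cast to built-in floats so numpy scalars from sklearn metrics serialize cleanly
        json.dump({k: [float(v) for v in vals] for k, vals in history.items()}, f, indent=2)
    plt.figure()
    plt.plot(history['train_loss'], label='train loss')
    plt.plot(history['val_loss'], label='val loss')
    plt.xlabel('epoch')
    plt.ylabel('BCE loss')
    plt.legend()
    plt.savefig('lstm_1b_loss_curves.png')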