{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CNN 1b"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load Packages"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader\n",
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
"from sklearn.metrics import accuracy_score, f1_score, confusion_matrix\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load Dataset and Create DataLoaders"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_path = 'data/embedded_padded'\n",
"\n",
"BATCH_SIZE = 32\n",
"\n",
"# Load the pre-embedded, padded splits\n",
"train_dataset = torch.load(data_path + '/train.pt')\n",
"test_dataset = torch.load(data_path + '/test.pt')\n",
"val_dataset = torch.load(data_path + '/val.pt')\n",
"\n",
"# Prepare the DataLoaders\n",
"def collate_fn(batch):\n",
"    input_ids = torch.stack([item[\"input_ids\"] for item in batch])\n",
"    labels = torch.tensor([item[\"labels\"] for item in batch], dtype=torch.float32).unsqueeze(1)\n",
"    return input_ids, labels\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)\n",
"val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)\n",
"test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)"
]
},
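{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick sanity check on one batch (a minimal sketch; it assumes each stored item is a dict with an `input_ids` embedding matrix and a scalar `labels` entry, exactly as `collate_fn` above expects):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek at one batch to confirm the shapes the CNN will receive.\n",
"# Expected: texts -> (BATCH_SIZE, seq_len, embedding_dim), labels -> (BATCH_SIZE, 1)\n",
"texts, labels = next(iter(train_loader))\n",
"print(texts.shape, labels.shape)\n",
"print('Label range:', labels.min().item(), 'to', labels.max().item())"
]
},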
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the CNN Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class HumorCNN(nn.Module):\n",
"    def __init__(self, embedding_dim=100):\n",
"        super(HumorCNN, self).__init__()\n",
"\n",
"        # Three parallel convolutions over the embedded sequence:\n",
"        # kernel heights 3/4/5 (n-gram windows), each spanning the full embedding width\n",
"        self.conv1 = nn.Conv2d(1, 50, (3, embedding_dim))\n",
"        self.conv2 = nn.Conv2d(1, 50, (4, embedding_dim))\n",
"        self.conv3 = nn.Conv2d(1, 50, (5, embedding_dim))\n",
"\n",
"        # One BatchNorm per convolution branch\n",
"        self.bn1 = nn.BatchNorm1d(50)\n",
"        self.bn2 = nn.BatchNorm1d(50)\n",
"        self.bn3 = nn.BatchNorm1d(50)\n",
"\n",
"        # 3 branches x 50 filters = 150 features going into the classifier\n",
"        self.fc = nn.Linear(150, 1)\n",
"\n",
"        self.dropout = nn.Dropout(0.5)\n",
"\n",
"    def forward(self, x):\n",
"        # (batch, seq_len, emb_dim) -> (batch, 1, seq_len, emb_dim) for Conv2d\n",
"        x = x.unsqueeze(1)\n",
"\n",
"        # Convolve, drop the collapsed embedding axis, normalise, activate\n",
"        x1 = F.relu(self.bn1(self.conv1(x).squeeze(3)))\n",
"        x2 = F.relu(self.bn2(self.conv2(x).squeeze(3)))\n",
"        x3 = F.relu(self.bn3(self.conv3(x).squeeze(3)))\n",
"\n",
"        # Max-pool over time so each filter contributes a single feature\n",
"        x1 = F.max_pool1d(x1, x1.size(2)).squeeze(2)\n",
"        x2 = F.max_pool1d(x2, x2.size(2)).squeeze(2)\n",
"        x3 = F.max_pool1d(x3, x3.size(2)).squeeze(2)\n",
"\n",
"        # Concatenate the three branches, regularise, classify\n",
"        x = torch.cat((x1, x2, x3), 1)\n",
"        x = self.dropout(x)\n",
"        x = self.fc(x)\n",
"        return torch.sigmoid(x)\n"
]
},
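{
"cell_type": "markdown",
"metadata": {},
"source": [
"A throwaway forward pass to check that the architecture wires up (a minimal sketch; batch size 4 and sequence length 20 are arbitrary, only the embedding width of 100 matches the data):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Dummy input of shape (batch, seq_len, embedding_dim); the 3/4/5 kernels need seq_len >= 5\n",
"_model = HumorCNN()\n",
"_model.eval()  # eval mode: BatchNorm uses running stats, Dropout is disabled\n",
"with torch.no_grad():\n",
"    _out = _model(torch.randn(4, 20, 100))\n",
"print(_out.shape)  # expected: torch.Size([4, 1]), values in (0, 1) after the sigmoid\n",
"\n",
"# Model size: sum of all trainable parameter counts\n",
"print('Trainable parameters:', sum(p.numel() for p in _model.parameters() if p.requires_grad))"
]
},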
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train the Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Automatic device selection (Apple MPS, CUDA, CPU)\n",
"if torch.backends.mps.is_available():\n",
"    device = torch.device(\"mps\")\n",
"elif torch.cuda.is_available():\n",
"    device = torch.device(\"cuda\")\n",
"else:\n",
"    device = torch.device(\"cpu\")\n",
"\n",
"model = HumorCNN().to(device)\n",
"\n",
"criterion = nn.BCELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)\n",
"\n",
"# Halve the learning rate after 2 epochs without improvement in validation loss\n",
"scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)\n",
"\n",
"epochs = 10  # only 10 epochs\n",
"best_val_loss = float('inf')\n",
"best_test_accuracy = 0\n",
"patience = 3\n",
"counter = 0\n",
"\n",
"for epoch in range(epochs):\n",
"    model.train()\n",
"    total_loss = 0\n",
"\n",
"    for texts, labels in train_loader:\n",
"        texts, labels = texts.to(device), labels.to(device)\n",
"        optimizer.zero_grad()\n",
"        outputs = model(texts)\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        total_loss += loss.item()\n",
"\n",
"    avg_train_loss = total_loss / len(train_loader)\n",
"\n",
"    # ========================\n",
"    # Validation\n",
"    # ========================\n",
"    model.eval()\n",
"    val_loss = 0\n",
"    with torch.no_grad():\n",
"        for texts, labels in val_loader:\n",
"            texts, labels = texts.to(device), labels.to(device)\n",
"            outputs = model(texts)\n",
"            loss = criterion(outputs, labels)\n",
"            val_loss += loss.item()\n",
"\n",
"    avg_val_loss = val_loss / len(val_loader)\n",
"\n",
"    # ========================\n",
"    # Evaluation on the test data\n",
"    # ========================\n",
"    test_preds = []\n",
"    test_labels = []\n",
"    with torch.no_grad():\n",
"        for texts, labels in test_loader:\n",
"            texts, labels = texts.to(device), labels.to(device)\n",
"            outputs = model(texts)\n",
"            predictions = (outputs > 0.5).float()\n",
"            test_preds.extend(predictions.cpu().numpy())\n",
"            test_labels.extend(labels.cpu().numpy())\n",
"\n",
"    test_accuracy = accuracy_score(test_labels, test_preds)\n",
"    test_f1 = f1_score(test_labels, test_preds)\n",
"\n",
"    print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Test Acc: {test_accuracy:.4f}, Test F1: {test_f1:.4f}')\n",
"\n",
"    # ========================\n",
"    # Learning-rate adjustment\n",
"    # ========================\n",
"    scheduler.step(avg_val_loss)\n",
"\n",
"    # ========================\n",
"    # Save the best model\n",
"    # ========================\n",
"    if test_accuracy > best_test_accuracy:\n",
"        best_test_accuracy = test_accuracy\n",
"        torch.save(model.state_dict(), \"best_model.pth\")\n",
"        print(\"🚀 Best model saved with test accuracy:\", test_accuracy)\n",
"\n",
"    # ========================\n",
"    # Early stopping\n",
"    # ========================\n",
"    if avg_val_loss < best_val_loss:\n",
"        best_val_loss = avg_val_loss\n",
"        counter = 0\n",
"    else:\n",
"        counter += 1\n",
"        if counter >= patience:\n",
"            print(\"Early stopping triggered!\")\n",
"            break\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Final Evaluation & Confusion Matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload the best checkpoint and evaluate it once more on the test set\n",
"model.load_state_dict(torch.load(\"best_model.pth\"))\n",
"model.eval()\n",
"\n",
"all_preds = []\n",
"all_labels = []\n",
"\n",
"with torch.no_grad():\n",
"    for texts, labels in test_loader:\n",
"        texts, labels = texts.to(device), labels.to(device)\n",
"        outputs = model(texts)\n",
"        predictions = (outputs > 0.5).float()\n",
"        all_preds.extend(predictions.cpu().numpy())\n",
"        all_labels.extend(labels.cpu().numpy())\n",
"\n",
"all_preds = [int(p[0]) for p in all_preds]\n",
"all_labels = [int(l[0]) for l in all_labels]\n",
"\n",
"accuracy = accuracy_score(all_labels, all_preds)\n",
"f1 = f1_score(all_labels, all_preds)\n",
"\n",
"print(f'🚀 Final Test Accuracy: {accuracy:.4f}')\n",
"print(f'🚀 Final Test F1 Score: {f1:.4f}')\n",
"\n",
"# Confusion matrix\n",
"conf_matrix = confusion_matrix(all_labels, all_preds)\n",
"\n",
"plt.figure(figsize=(6, 5))\n",
"sns.heatmap(conf_matrix, annot=True, fmt='d', cmap=\"Blues\", xticklabels=['No Humor', 'Humor'], yticklabels=['No Humor', 'Humor'])\n",
"plt.xlabel(\"Predicted Label\")\n",
"plt.ylabel(\"True Label\")\n",
"plt.title(\"Confusion Matrix\")\n",
"plt.show()\n"
]
},
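{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, a per-class breakdown to complement the confusion matrix (a small addition using scikit-learn's `classification_report`, not part of the original pipeline):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"# Per-class precision, recall and F1 on the same test predictions as above\n",
"print(classification_report(all_labels, all_preds, target_names=['No Humor', 'Humor']))"
]
}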
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}