ANLP_WS24_CA2/cnn_class.ipynb

847 lines
112 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CNN 1b"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load Packages"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"from tqdm import tqdm\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# local imports\n",
"from dataset_generator import create_embedding_matrix, split_data, load_preprocess_data\n",
"from HumorDataset import TextDataset\n",
"from BalancedCELoss import BalancedCELoss\n",
"import ml_helper"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Datensatz laden und DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"# Fix the NumPy and PyTorch RNG seeds so repeated runs start from the same state.\n",
"np.random.seed(0)\n",
"torch.manual_seed(0)\n",
"\n",
"# File name for the best model checkpoint (presumably written during training).\n",
"best_model_filename = 'best_cnn_class_model.pt'\n",
"\n",
"# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda\")\n",
"else:\n",
"    device = torch.device(\"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"### CNN-Modell definieren\n"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"# Hyperparameters and configuration (same keys, values and order as before)\n",
"params = dict(\n",
"    embedding_dim=100,\n",
"    filter_sizes=[2, 3, 4, 5],\n",
"    num_filters=150,\n",
"    batch_size=32,\n",
"    learning_rate=0.001,\n",
"    epochs=25,\n",
"    glove_path='data/glove.6B.100d.txt',\n",
"    max_len=280,\n",
"    test_size=0.1,\n",
"    val_size=0.1,\n",
"    patience=5,\n",
"    data_path='data/hack.csv',\n",
"    dropout=0.6,\n",
"    weight_decay=5e-4,\n",
"    alpha=0.1,  # alpha for the balance term of the loss function\n",
"    early_stopping_patience=5,  # patience for early stopping: 5 (3 to 10)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"class CNNBinaryClassifier(nn.Module):\n",
"    \"\"\"Multi-width convolutional text classifier that emits two class logits.\n",
"\n",
"    Token ids are embedded, one convolution branch per filter size slides over\n",
"    the embedded sequence, each branch is max-pooled over all positions, and the\n",
"    concatenated branch features pass through two linear layers.\n",
"\n",
"    Args:\n",
"        vocab_size: Accepted for interface compatibility; not used, because the\n",
"            embedding table is built directly from ``embedding_matrix``.\n",
"        embedding_dim: Second dimension of every convolution kernel\n",
"            (presumably equal to the width of ``embedding_matrix``).\n",
"        filter_sizes: Iterable of kernel heights; one branch is built per entry.\n",
"        num_filters: Number of output channels of every branch.\n",
"        embedding_matrix: Pretrained weights handed to\n",
"            ``nn.Embedding.from_pretrained``; they stay trainable (``freeze=False``).\n",
"        dropout: Dropout probability used inside each branch and before ``fc2``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size, embedding_dim, filter_sizes, num_filters, embedding_matrix, dropout):\n",
"        super().__init__()\n",
"        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)\n",
"        self.convs = nn.ModuleList([\n",
"            nn.Sequential(\n",
"                nn.Conv2d(1, num_filters, (fs, embedding_dim)),\n",
"                nn.BatchNorm2d(num_filters),\n",
"                nn.ReLU(),\n",
"                # Global max over all positions. The window is derived from the\n",
"                # actual input instead of the notebook-level params[\"max_len\"],\n",
"                # so the model no longer depends on global state and accepts any\n",
"                # sequence length >= fs. For inputs of length max_len the result\n",
"                # is identical to the former fixed-size MaxPool2d.\n",
"                nn.AdaptiveMaxPool2d((1, 1)),\n",
"                nn.Dropout(dropout)\n",
"            )\n",
"            for fs in filter_sizes\n",
"        ])\n",
"        self.fc1 = nn.Linear(len(filter_sizes) * num_filters, 128)\n",
"        self.fc2 = nn.Linear(128, 2)  # two classes -> two logits for a cross-entropy style loss\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return raw logits of shape (batch, 2).\n",
"\n",
"        Args:\n",
"            x: Integer token ids, expected as (batch, seq_len) so that the\n",
"                embedded tensor becomes (batch, 1, seq_len, embedding_dim).\n",
"        \"\"\"\n",
"        x = self.embedding(x).unsqueeze(1)\n",
"        # Each branch yields (batch, num_filters, 1, 1); flatten to (batch, num_filters).\n",
"        conv_outputs = [conv(x).flatten(1) for conv in self.convs]\n",
"        x = torch.cat(conv_outputs, 1)\n",
"        x = torch.relu(self.fc1(x))\n",
"        x = self.dropout(x)\n",
"        return self.fc2(x)  # no softmax here; the loss is expected to apply it"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"400002\n",
"vocab_size: 400002, d_model: 100\n"
]
}
],
"source": [
"# Load the GloVe embedding matrix together with its vocabulary lookup\n",
"embedding_matrix, word_index, vocab_size, d_model = create_embedding_matrix(\n",
"    gloVe_path=params[\"glove_path\"],\n",
"    emb_len=params[\"embedding_dim\"],\n",
")\n",
"\n",
"# Load and preprocess the raw samples and their labels\n",
"X, y = load_preprocess_data(path_data=params[\"data_path\"])\n"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train 3945 3945\n",
"test 494 494\n",
"val 493 493\n"
]
}
],
"source": [
"\n",
"# Split the data into train / validation / test partitions\n",
"data_split = split_data(X, y, test_size=params[\"test_size\"], val_size=params[\"val_size\"])\n",
"\n",
"# One TextDataset per partition, built in the order train, val, test\n",
"train_dataset, val_dataset, test_dataset = (\n",
"    TextDataset(data_split[part]['X'], data_split[part]['y'], word_index, max_len=params[\"max_len\"])\n",
"    for part in ('train', 'val', 'test')\n",
")\n",
"\n",
"# Only the training loader shuffles; validation and test keep their order\n",
"train_loader, val_loader, test_loader = (\n",
"    DataLoader(dataset, batch_size=params[\"batch_size\"], shuffle=do_shuffle)\n",
"    for dataset, do_shuffle in ((train_dataset, True), (val_dataset, False), (test_dataset, False))\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"import EarlyStopping as EarlyStopping\n",
"\n",
"# Instantiate the classifier and move it to the selected device\n",
"model = CNNBinaryClassifier(\n",
"    vocab_size=vocab_size,\n",
"    embedding_dim=params[\"embedding_dim\"],\n",
"    filter_sizes=params[\"filter_sizes\"],\n",
"    num_filters=params[\"num_filters\"],\n",
"    embedding_matrix=embedding_matrix,\n",
"    dropout=params[\"dropout\"],\n",
").to(device)\n",
"\n",
"# Loss function: BalancedCELoss with the configured alpha\n",
"criterion = BalancedCELoss(alpha=params[\"alpha\"])\n",
"\n",
"# Optimizer: Adam with the configured learning rate and weight decay\n",
"optimizer = optim.Adam(\n",
"    model.parameters(),\n",
"    lr=params[\"learning_rate\"],\n",
"    weight_decay=params[\"weight_decay\"],\n",
")\n",
"\n",
"# Early-stopping helper with the configured patience\n",
"early_stopping = EarlyStopping.EarlyStopping(\n",
"    patience=params['early_stopping_patience'],\n",
"    verbose=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 1/25: 100%|██████████| 124/124 [00:38<00:00, 3.26it/s, Train Loss=0.734]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 1, Train Loss: 105.9015, Val Loss: 12.5712\n",
"Train Accuracy: 0.4958, Val Accuracy: 0.5314\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 2/25: 100%|██████████| 124/124 [00:36<00:00, 3.39it/s, Train Loss=0.79] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 2, Train Loss: 91.0446, Val Loss: 12.5252\n",
"Train Accuracy: 0.5141, Val Accuracy: 0.5274\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 3/25: 100%|██████████| 124/124 [00:36<00:00, 3.39it/s, Train Loss=0.826]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 3, Train Loss: 93.3248, Val Loss: 12.5840\n",
"Train Accuracy: 0.5039, Val Accuracy: 0.5254\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 4/25: 100%|██████████| 124/124 [00:36<00:00, 3.40it/s, Train Loss=0.7] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 4, Train Loss: 92.2199, Val Loss: 12.5006\n",
"Train Accuracy: 0.4984, Val Accuracy: 0.5517\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5/25: 100%|██████████| 124/124 [00:37<00:00, 3.29it/s, Train Loss=0.768]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 5, Train Loss: 91.2856, Val Loss: 11.9061\n",
"Train Accuracy: 0.5290, Val Accuracy: 0.5862\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 6/25: 100%|██████████| 124/124 [00:40<00:00, 3.09it/s, Train Loss=0.694]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 6, Train Loss: 90.5596, Val Loss: 11.3011\n",
"Train Accuracy: 0.5430, Val Accuracy: 0.6126\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 7/25: 100%|██████████| 124/124 [00:39<00:00, 3.18it/s, Train Loss=0.771]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 7, Train Loss: 89.5808, Val Loss: 11.5313\n",
"Train Accuracy: 0.5582, Val Accuracy: 0.6207\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 8/25: 100%|██████████| 124/124 [00:36<00:00, 3.38it/s, Train Loss=0.697]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 8, Train Loss: 88.8963, Val Loss: 11.0529\n",
"Train Accuracy: 0.5648, Val Accuracy: 0.6308\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 9/25: 100%|██████████| 124/124 [00:37<00:00, 3.34it/s, Train Loss=0.846]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 9, Train Loss: 88.4877, Val Loss: 11.0292\n",
"Train Accuracy: 0.5706, Val Accuracy: 0.6207\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 10/25: 100%|██████████| 124/124 [00:36<00:00, 3.41it/s, Train Loss=0.756]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 10, Train Loss: 88.5556, Val Loss: 11.0032\n",
"Train Accuracy: 0.5833, Val Accuracy: 0.6308\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 11/25: 100%|██████████| 124/124 [00:36<00:00, 3.41it/s, Train Loss=0.664]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 11, Train Loss: 88.3764, Val Loss: 10.7751\n",
"Train Accuracy: 0.5706, Val Accuracy: 0.6389\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 12/25: 100%|██████████| 124/124 [00:38<00:00, 3.26it/s, Train Loss=0.866]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 12, Train Loss: 88.9168, Val Loss: 11.1027\n",
"Train Accuracy: 0.5721, Val Accuracy: 0.6085\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 13/25: 100%|██████████| 124/124 [00:39<00:00, 3.13it/s, Train Loss=0.711]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 13, Train Loss: 88.4298, Val Loss: 11.0765\n",
"Train Accuracy: 0.5888, Val Accuracy: 0.6288\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 14/25: 100%|██████████| 124/124 [00:39<00:00, 3.11it/s, Train Loss=0.728]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 14, Train Loss: 88.7229, Val Loss: 11.1684\n",
"Train Accuracy: 0.5823, Val Accuracy: 0.6349\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 15/25: 100%|██████████| 124/124 [00:37<00:00, 3.28it/s, Train Loss=0.774]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 15, Train Loss: 89.3287, Val Loss: 11.4475\n",
"Train Accuracy: 0.5830, Val Accuracy: 0.6146\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 16/25: 100%|██████████| 124/124 [00:35<00:00, 3.48it/s, Train Loss=0.797]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 16, Train Loss: 85.6701, Val Loss: 10.7575\n",
"Train Accuracy: 0.6175, Val Accuracy: 0.6329\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 17/25: 100%|██████████| 124/124 [00:38<00:00, 3.23it/s, Train Loss=0.649]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 17, Train Loss: 83.7000, Val Loss: 10.7996\n",
"Train Accuracy: 0.6294, Val Accuracy: 0.6166\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 18/25: 100%|██████████| 124/124 [00:37<00:00, 3.31it/s, Train Loss=0.703]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 18, Train Loss: 80.2727, Val Loss: 10.7781\n",
"Train Accuracy: 0.6679, Val Accuracy: 0.6450\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 19/25: 100%|██████████| 124/124 [00:38<00:00, 3.24it/s, Train Loss=0.519]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 19, Train Loss: 73.5981, Val Loss: 11.1218\n",
"Train Accuracy: 0.7113, Val Accuracy: 0.6247\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 20/25: 100%|██████████| 124/124 [00:36<00:00, 3.41it/s, Train Loss=1.05] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 20, Train Loss: 66.4704, Val Loss: 11.3424\n",
"Train Accuracy: 0.7592, Val Accuracy: 0.6227\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 21/25: 100%|██████████| 124/124 [00:25<00:00, 4.90it/s, Train Loss=0.794]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 21, Train Loss: 59.3716, Val Loss: 12.2167\n",
"Train Accuracy: 0.8043, Val Accuracy: 0.6024\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 22/25: 100%|██████████| 124/124 [00:25<00:00, 4.79it/s, Train Loss=0.261]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 22, Train Loss: 48.0339, Val Loss: 13.4658\n",
"Train Accuracy: 0.8525, Val Accuracy: 0.6085\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 23/25: 100%|██████████| 124/124 [00:23<00:00, 5.23it/s, Train Loss=0.218]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 23, Train Loss: 36.6165, Val Loss: 15.3780\n",
"Train Accuracy: 0.8966, Val Accuracy: 0.5963\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 24/25: 100%|██████████| 124/124 [00:23<00:00, 5.29it/s, Train Loss=0.166] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 24, Train Loss: 29.4375, Val Loss: 21.4867\n",
"Train Accuracy: 0.9202, Val Accuracy: 0.5822\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 25/25: 100%|██████████| 124/124 [00:22<00:00, 5.40it/s, Train Loss=0.209] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 25, Train Loss: 21.5571, Val Loss: 31.7498\n",
"Train Accuracy: 0.9437, Val Accuracy: 0.5578\n"
]
}
],
"source": [
"# Training with validation-based early stopping and best-checkpoint restore.\n",
"# `history` stores per-epoch averages: loss per batch, accuracy per sample.\n",
"history = {\n",
"    \"train_loss\": [],\n",
"    \"val_loss\": [],\n",
"    \"train_acc\": [],\n",
"    \"val_acc\": [],\n",
"}\n",
"\n",
"# Early-stopping state: lowest average validation loss seen so far and the\n",
"# number of consecutive epochs that failed to improve on it.\n",
"best_val_loss = float(\"inf\")\n",
"epochs_without_improvement = 0\n",
"\n",
"for epoch in range(params[\"epochs\"]):\n",
"    model.train()\n",
"    train_loss, correct, total = 0.0, 0, 0\n",
"\n",
"    with tqdm(train_loader, desc=f\"Epoch {epoch + 1}/{params['epochs']}\") as pbar:\n",
"        for X_batch, y_batch in pbar:\n",
"            X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
"            optimizer.zero_grad()\n",
"            outputs = model(X_batch)\n",
"            loss = criterion(outputs, y_batch)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            train_loss += loss.item()\n",
"            predicted = torch.argmax(outputs, dim=1)\n",
"            correct += (predicted == y_batch).sum().item()\n",
"            total += y_batch.size(0)\n",
"\n",
"            pbar.set_postfix({\"Train Loss\": loss.item()})\n",
"\n",
"    train_acc = correct / total\n",
"    avg_train_loss = train_loss / len(train_loader)\n",
"    history[\"train_loss\"].append(avg_train_loss)\n",
"    history[\"train_acc\"].append(train_acc)\n",
"\n",
"    # Validation\n",
"    model.eval()\n",
"    val_loss, correct, total = 0.0, 0, 0\n",
"    with torch.no_grad():\n",
"        for X_batch, y_batch in val_loader:\n",
"            X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
"            outputs = model(X_batch)\n",
"            loss = criterion(outputs, y_batch)\n",
"            val_loss += loss.item()\n",
"            predicted = torch.argmax(outputs, dim=1)\n",
"            correct += (predicted == y_batch).sum().item()\n",
"            total += y_batch.size(0)\n",
"\n",
"    val_acc = correct / total\n",
"    avg_val_loss = val_loss / len(val_loader)\n",
"    history[\"val_loss\"].append(avg_val_loss)\n",
"    history[\"val_acc\"].append(val_acc)\n",
"\n",
"    # Report per-batch averages so the printed values match `history` and the\n",
"    # progress bar (the summed loss scales with the number of batches).\n",
"    print(f\"\\nEpoch {epoch + 1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}\")\n",
"    print(f\"Train Accuracy: {train_acc:.4f}, Val Accuracy: {val_acc:.4f}\")\n",
"\n",
"    if avg_val_loss < best_val_loss:\n",
"        # New best epoch: reset the patience counter and checkpoint the weights.\n",
"        best_val_loss = avg_val_loss\n",
"        epochs_without_improvement = 0\n",
"        torch.save(model.state_dict(), best_model_filename)\n",
"    else:\n",
"        epochs_without_improvement += 1\n",
"        if epochs_without_improvement >= params[\"early_stopping_patience\"]:\n",
"            print(f\"Early stopping after epoch {epoch + 1}: no validation-loss improvement for {epochs_without_improvement} epochs\")\n",
"            break\n",
"\n",
"# Continue with the weights of the best validation epoch rather than the last\n",
"# (possibly overfit) one. Skipped when no checkpoint was written, e.g. if the\n",
"# validation loss was never finite.\n",
"if best_val_loss != float(\"inf\"):\n",
"    model.load_state_dict(torch.load(best_model_filename, map_location=device))"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Accuracy: 0.6235\n"
]
}
],
"source": [
"# Evaluate on the test split and collect labels / predictions for later plots\n",
"model.eval()\n",
"all_labels, all_preds = [], []\n",
"test_correct = test_total = 0\n",
"\n",
"with torch.no_grad():\n",
"    for X_batch, y_batch in test_loader:\n",
"        X_batch = X_batch.to(device)\n",
"        y_batch = y_batch.to(device)\n",
"        outputs = model(X_batch)\n",
"        predicted = outputs.argmax(dim=1)  # index of the larger logit\n",
"\n",
"        all_labels.extend(y_batch.cpu().numpy())\n",
"        all_preds.extend(predicted.cpu().numpy())\n",
"\n",
"        test_correct += predicted.eq(y_batch).sum().item()\n",
"        test_total += len(y_batch)\n",
"\n",
"test_accuracy = test_correct / test_total\n",
"print(f\"Test Accuracy: {test_accuracy:.4f}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 Finale Test Accuracy: 0.6235\n",
"🚀 Finale Test F1 Score: 0.6189\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 600x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import ml_evaluation as ml_eval\n",
"\n",
"# Final test metrics, computed by the project's evaluation helpers\n",
"final_accuracy = ml_eval.get_accuracy(all_preds, all_labels)\n",
"print(f'🚀 Finale Test Accuracy: {final_accuracy:.4f}')\n",
"final_f1 = ml_eval.get_f1_score(all_preds, all_labels)\n",
"print(f'🚀 Finale Test F1 Score: {final_f1:.4f}')\n",
"\n",
"# Confusion matrix for the two classes\n",
"class_names = ['0', '1']\n",
"con_plt = ml_eval.plot_confusion_matrix(all_preds, all_labels, class_names)\n",
"con_plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Class 0: 0.44\n",
"Class 1: 0.56\n"
]
}
],
"source": [
"ml_eval.get_label_distribution(all_labels, all_preds)"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import ml_evaluation as ml_eval\n",
"\n",
"# Build the rating/prediction figure with the project helper, then display it\n",
"rating_plot = ml_eval.plot_rating_preds(all_preds, all_labels, test_dataset)\n",
"rating_plot.show()"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
"def visualize_distribution(true_values, predicted_values):\n",
"    \"\"\"Show a grouped bar chart comparing true and predicted class counts.\n",
"\n",
"    Args:\n",
"        true_values: Sequence of non-negative integer class ids (ground truth).\n",
"        predicted_values: Sequence of non-negative integer class ids (predictions).\n",
"\n",
"    Both inputs are counted with ``np.bincount(..., minlength=2)``; index 0 is\n",
"    drawn under 'No Humor' and index 1 under 'Humor'. The figure is displayed\n",
"    with ``plt.show()`` and nothing is returned.\n",
"    \"\"\"\n",
"    fig, ax = plt.subplots(figsize=(10, 6))\n",
"\n",
"    # Per-class frequencies\n",
"    counts_true = np.bincount(true_values, minlength=2)\n",
"    counts_pred = np.bincount(predicted_values, minlength=2)\n",
"\n",
"    # One bar group per class: true counts on the left, predictions on the right\n",
"    class_names = ['No Humor', 'Humor']\n",
"    positions = np.arange(len(class_names))\n",
"    bar_width = 0.4\n",
"    half_width = bar_width / 2\n",
"\n",
"    ax.bar(positions - half_width, counts_true, width=bar_width,\n",
"           color='skyblue', label='Wahre Werte', edgecolor='black')\n",
"    ax.bar(positions + half_width, counts_pred, width=bar_width,\n",
"           color='salmon', label='Vorhergesagte Werte', edgecolor='black')\n",
"\n",
"    ax.set_title('Verteilung der wahren Werte und Vorhersagen')\n",
"    ax.set_xticks(positions)\n",
"    ax.set_xticklabels(class_names)\n",
"    ax.set_ylabel('Häufigkeit')\n",
"    ax.set_xlabel('Klassen')\n",
"    ax.legend()\n",
"    ax.grid(axis='y', linestyle='--', alpha=0.7)\n",
"    fig.tight_layout()\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualisierung der Verteilung (Barplot)\n",
"visualize_distribution(all_labels, all_preds)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}