619 lines
131 KiB
Plaintext
619 lines
131 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# CNN Regression of Humor Scores"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"from torch.utils.data import DataLoader\n",
|
||
"from tqdm import tqdm # Fortschrittsbalken\n",
|
||
"import numpy as np\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import seaborn as sns\n",
|
||
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
|
||
"\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import numpy as np"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/var/folders/l7/061cw0t95vz1myntpf9bj9540000gn/T/ipykernel_46830/3644220936.py:6: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
||
" train_dataset = torch.load(data_path + '/train.pt')\n",
|
||
"/var/folders/l7/061cw0t95vz1myntpf9bj9540000gn/T/ipykernel_46830/3644220936.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
||
" test_dataset = torch.load(data_path + '/test.pt')\n",
|
||
"/var/folders/l7/061cw0t95vz1myntpf9bj9540000gn/T/ipykernel_46830/3644220936.py:8: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
||
" val_dataset = torch.load(data_path + '/val.pt')\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Load the pre-embedded, padded dataset splits\n",
|
||
"\n",
|
||
"data_path = 'data/embedded_padded'\n",
|
||
"BATCH_SIZE = 32\n",
|
||
"SEED = 42\n",
|
||
"\n",
|
||
"# Seed torch's global RNG so weight init, dropout and loader shuffling are reproducible\n",
|
||
"torch.manual_seed(SEED)\n",
|
||
"\n",
|
||
"# The .pt files hold pickled dataset objects (test_dataset.original_indices is used later),\n",
|
||
"# not plain tensors, so weights_only=False is required; future PyTorch releases default to True.\n",
|
||
"# SECURITY: unpickling can execute arbitrary code - only load these files from a trusted source.\n",
|
||
"train_dataset = torch.load(data_path + '/train.pt', weights_only=False)\n",
|
||
"test_dataset = torch.load(data_path + '/test.pt', weights_only=False)\n",
|
||
"val_dataset = torch.load(data_path + '/val.pt', weights_only=False)\n",
|
||
"\n",
|
||
"# Prepare the DataLoaders\n",
|
||
"def collate_fn(batch):\n",
|
||
"    \"\"\"Stack a list of samples into an (input_ids, labels) pair of batch tensors.\n",
|
||
"\n",
|
||
"    labels are returned as float32 with shape (batch_size, 1) to match the model output.\n",
|
||
"    \"\"\"\n",
|
||
"    input_ids = torch.stack([item[\"input_ids\"] for item in batch])\n",
|
||
"    labels = torch.tensor([item[\"labels\"] for item in batch], dtype=torch.float32).unsqueeze(1)\n",
|
||
"    return input_ids, labels\n",
|
||
"\n",
|
||
"# Dedicated generator: the training shuffle order depends only on SEED\n",
|
||
"train_generator = torch.Generator().manual_seed(SEED)\n",
|
||
"\n",
|
||
"train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, generator=train_generator)\n",
|
||
"val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)\n",
|
||
"test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/Users/michellegoppinger/Documents/Dokumente – Laptop von Michelle/Uni/Master/ANLP/ANLP_WS24_CA2/HumorDataset.py:56: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:281.)\n",
|
||
" item = {'input_ids': torch.tensor(self.data[idx], dtype=torch.float)}\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 800x600 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Collect the training labels as Python scalars\n",
|
||
"train_labels = [sample[\"labels\"].item() for sample in train_dataset]\n",
|
||
"\n",
|
||
"# Visualise the label distribution\n",
|
||
"fig, ax = plt.subplots(figsize=(8, 6))\n",
|
||
"sns.histplot(train_labels, bins=20, ax=ax)\n",
|
||
"ax.set_xlabel(\"Humor Scores\")\n",
|
||
"ax.set_ylabel(\"Frequency\")\n",
|
||
"ax.set_title(\"Verteilung der Trainingslabels\")\n",
|
||
"plt.show()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class WeightedMSELoss(nn.Module):\n",
|
||
"    \"\"\"Mean squared error in which each sample is weighted by the bin of its target.\n",
|
||
"\n",
|
||
"    Args:\n",
|
||
"        weights: 1-D tensor of K per-bin weights. The target range [0, 1] is split into\n",
|
||
"            K equal-width bins and a target t uses weights[min(int(t * K), K - 1)].\n",
|
||
"            With K == 2 and binary targets this is identical to weights[target].\n",
|
||
"    \"\"\"\n",
|
||
"\n",
|
||
"    def __init__(self, weights):\n",
|
||
"        super(WeightedMSELoss, self).__init__()\n",
|
||
"        # A buffer follows the module on .to(device) and is not returned by .parameters()\n",
|
||
"        self.register_buffer(\"weights\", torch.as_tensor(weights, dtype=torch.float32))\n",
|
||
"\n",
|
||
"    def forward(self, inputs, targets):\n",
|
||
"        num_bins = self.weights.numel()\n",
|
||
"        # Map each continuous target onto its bin; a plain targets.long() would truncate\n",
|
||
"        # every target below 1.0 to index 0\n",
|
||
"        bin_idx = (targets * num_bins).long().clamp(0, num_bins - 1)\n",
|
||
"        loss = self.weights[bin_idx] * (inputs - targets) ** 2\n",
|
||
"        return loss.mean()\n",
|
||
"\n",
|
||
"# Up-weight the rarer intermediate scores: bins whose centre lies in [0.2, 0.8] get 2.0.\n",
|
||
"# (Iterating over range(2) only ever yields 0 and 1, which would make every weight 1.0.)\n",
|
||
"NUM_WEIGHT_BINS = 10\n",
|
||
"weights = torch.tensor(\n",
|
||
"    [2.0 if 0.2 <= (i + 0.5) / NUM_WEIGHT_BINS <= 0.8 else 1.0 for i in range(NUM_WEIGHT_BINS)],\n",
|
||
"    dtype=torch.float32,\n",
|
||
")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class CNN_HumorRegressor(nn.Module):\n",
|
||
"    \"\"\"Text CNN that maps a sequence of embedding vectors to one score in [0, 1].\n",
|
||
"\n",
|
||
"    Args:\n",
|
||
"        embed_dim: width of each input embedding (the last input dimension).\n",
|
||
"        filter_sizes: convolution window heights, in tokens; one Conv2d per size.\n",
|
||
"        num_filters: output channels produced for every filter size.\n",
|
||
"        dropout: dropout probability applied once, after the skip-connected feature layer.\n",
|
||
"    \"\"\"\n",
|
||
"\n",
|
||
"    def __init__(self, embed_dim, filter_sizes, num_filters, dropout=0.5):\n",
|
||
"        super(CNN_HumorRegressor, self).__init__()\n",
|
||
"        feature_dim = len(filter_sizes) * num_filters\n",
|
||
"\n",
|
||
"        # One Conv2d per window size; each kernel spans the full embedding width\n",
|
||
"        conv_layers = []\n",
|
||
"        for size in filter_sizes:\n",
|
||
"            conv_layers.append(nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(size, embed_dim)))\n",
|
||
"        self.convs = nn.ModuleList(conv_layers)\n",
|
||
"\n",
|
||
"        # Named \"highway\", but forward() uses it as relu(W x) + x: an ungated residual layer\n",
|
||
"        self.highway = nn.Linear(feature_dim, feature_dim)\n",
|
||
"        self.dropout = nn.Dropout(dropout)\n",
|
||
"\n",
|
||
"        # Regression head: feature_dim -> 256 -> 128 -> 1\n",
|
||
"        self.fc1 = nn.Linear(feature_dim, 256)\n",
|
||
"        self.fc2 = nn.Linear(256, 128)\n",
|
||
"        self.fc3 = nn.Linear(128, 1)\n",
|
||
"\n",
|
||
"    def forward(self, x):\n",
|
||
"        # Add a channel axis: [Batch Size, 1, Seq Length, Embed Dim]\n",
|
||
"        grid = x.unsqueeze(1)\n",
|
||
"\n",
|
||
"        pooled = []\n",
|
||
"        for conv in self.convs:\n",
|
||
"            # ReLU feature map with the width-1 embedding axis dropped: (batch, num_filters, positions)\n",
|
||
"            feature_map = F.relu(conv(grid)).squeeze(3)\n",
|
||
"            # Global max over all positions -> (batch, num_filters)\n",
|
||
"            pooled.append(F.max_pool1d(feature_map, feature_map.size(2)).squeeze(2))\n",
|
||
"        features = torch.cat(pooled, dim=1)\n",
|
||
"\n",
|
||
"        # Residual \"highway\" step followed by dropout\n",
|
||
"        features = self.dropout(F.relu(self.highway(features)) + features)\n",
|
||
"\n",
|
||
"        hidden = F.relu(self.fc1(features))\n",
|
||
"        hidden = F.relu(self.fc2(hidden))\n",
|
||
"        # Sigmoid keeps the prediction in [0, 1]\n",
|
||
"        return torch.sigmoid(self.fc3(hidden))\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Hyperparameters; the embedding width is read from the first training sample\n",
|
||
"EMBED_DIM = train_dataset[0][\"input_ids\"].shape[1]\n",
|
||
"FILTER_SIZES = [2, 3, 4, 5]\n",
|
||
"NUM_FILTERS = 300\n",
|
||
"DROPOUT = 0.5\n",
|
||
"LR = 0.001\n",
|
||
"EPOCHS = 10\n",
|
||
"\n",
|
||
"# Pick the accelerator: Apple MPS first, then CUDA, otherwise the CPU\n",
|
||
"if torch.backends.mps.is_available():\n",
|
||
"    device = torch.device(\"mps\")\n",
|
||
"elif torch.cuda.is_available():\n",
|
||
"    device = torch.device(\"cuda\")\n",
|
||
"else:\n",
|
||
"    device = torch.device(\"cpu\")\n",
|
||
"\n",
|
||
"# Build the model directly on the chosen device\n",
|
||
"model = CNN_HumorRegressor(\n",
|
||
"    embed_dim=EMBED_DIM,\n",
|
||
"    filter_sizes=FILTER_SIZES,\n",
|
||
"    num_filters=NUM_FILTERS,\n",
|
||
"    dropout=DROPOUT,\n",
|
||
").to(device)\n",
|
||
"\n",
|
||
"# Weighted loss (its weights on the same device as the model) and Adam optimiser\n",
|
||
"criterion = WeightedMSELoss(weights.to(device))\n",
|
||
"optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def train_model(model, train_loader, val_loader, criterion, optimizer, epochs, device):\n",
|
||
"    \"\"\"Train `model` for `epochs` epochs and print the mean train and validation loss per epoch.\n",
|
||
"\n",
|
||
"    Both printed losses are per-sample means, so they are directly comparable (printing the\n",
|
||
"    plain sum of batch losses inflated the train figure by roughly the number of batches).\n",
|
||
"    \"\"\"\n",
|
||
"    for epoch in range(epochs):\n",
|
||
"        model.train()\n",
|
||
"        total_loss = 0.0\n",
|
||
"        num_samples = 0\n",
|
||
"\n",
|
||
"        # Progress bar over the training batches\n",
|
||
"        with tqdm(train_loader, unit=\"batch\", desc=f\"Epoch {epoch+1}/{epochs}\") as tepoch:\n",
|
||
"            for inputs, labels in tepoch:\n",
|
||
"                inputs, labels = inputs.to(device), labels.to(device)\n",
|
||
"\n",
|
||
"                optimizer.zero_grad()\n",
|
||
"                outputs = model(inputs)\n",
|
||
"                loss = criterion(outputs, labels)\n",
|
||
"                loss.backward()\n",
|
||
"                optimizer.step()\n",
|
||
"\n",
|
||
"                # criterion returns a batch mean; weight it by batch size for an exact epoch mean\n",
|
||
"                batch_loss = loss.item()\n",
|
||
"                batch_size = inputs.size(0)\n",
|
||
"                total_loss += batch_loss * batch_size\n",
|
||
"                num_samples += batch_size\n",
|
||
"                tepoch.set_postfix(loss=batch_loss)\n",
|
||
"\n",
|
||
"        train_loss = total_loss / num_samples if num_samples else float(\"nan\")\n",
|
||
"        val_loss = evaluate(model, val_loader, criterion, device)\n",
|
||
"        print(f\"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}\")\n",
|
||
"\n",
|
||
"def evaluate(model, test_loader, criterion, device):\n",
|
||
"    \"\"\"Return the per-sample mean of `criterion` over every batch of `test_loader`.\n",
|
||
"\n",
|
||
"    Assumes `criterion` returns a batch mean; each batch is re-weighted by its size so a\n",
|
||
"    smaller final batch is not over-weighted. Returns NaN when the loader is empty.\n",
|
||
"    \"\"\"\n",
|
||
"    model.eval()\n",
|
||
"    total_loss = 0.0\n",
|
||
"    num_samples = 0\n",
|
||
"    with tqdm(test_loader, unit=\"batch\", desc=\"Evaluating\") as tepoch:\n",
|
||
"        with torch.no_grad():\n",
|
||
"            for inputs, labels in tepoch:\n",
|
||
"                inputs, labels = inputs.to(device), labels.to(device)\n",
|
||
"                outputs = model(inputs)\n",
|
||
"                loss = criterion(outputs, labels)\n",
|
||
"                batch_size = inputs.size(0)\n",
|
||
"                total_loss += loss.item() * batch_size\n",
|
||
"                num_samples += batch_size\n",
|
||
"    return total_loss / num_samples if num_samples else float(\"nan\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 1/10: 0%| | 0/124 [00:00<?, ?batch/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 1/10: 100%|██████████| 124/124 [00:41<00:00, 2.99batch/s, loss=0.249]\n",
|
||
"Evaluating: 100%|██████████| 16/16 [00:02<00:00, 7.09batch/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 1/10 - Train Loss: 30.5479 - Val Loss: 0.2415\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 2/10: 100%|██████████| 124/124 [00:40<00:00, 3.08batch/s, loss=0.297]\n",
|
||
"Evaluating: 100%|██████████| 16/16 [00:02<00:00, 7.36batch/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 2/10 - Train Loss: 27.5358 - Val Loss: 0.2162\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 3/10: 100%|██████████| 124/124 [00:40<00:00, 3.03batch/s, loss=0.286]\n",
|
||
"Evaluating: 100%|██████████| 16/16 [00:02<00:00, 7.26batch/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 3/10 - Train Loss: 23.0742 - Val Loss: 0.2215\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 4/10: 100%|██████████| 124/124 [00:43<00:00, 2.83batch/s, loss=0.123] \n",
|
||
"Evaluating: 100%|██████████| 16/16 [00:02<00:00, 7.24batch/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 4/10 - Train Loss: 16.9821 - Val Loss: 0.2608\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 5/10: 100%|██████████| 124/124 [00:41<00:00, 3.01batch/s, loss=0.104] \n",
|
||
"Evaluating: 100%|██████████| 16/16 [00:02<00:00, 7.43batch/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 5/10 - Train Loss: 10.0560 - Val Loss: 0.2646\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 6/10: 100%|██████████| 124/124 [00:28<00:00, 4.34batch/s, loss=0.138] \n",
|
||
"Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.23batch/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 6/10 - Train Loss: 9.3069 - Val Loss: 0.2535\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 7/10: 100%|██████████| 124/124 [00:28<00:00, 4.31batch/s, loss=0.00183]\n",
|
||
"Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.21batch/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 7/10 - Train Loss: 6.4416 - Val Loss: 0.2688\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 8/10: 100%|██████████| 124/124 [00:28<00:00, 4.37batch/s, loss=0.00722]\n",
|
||
"Evaluating: 100%|██████████| 16/16 [00:30<00:00, 1.92s/batch]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 8/10 - Train Loss: 4.9270 - Val Loss: 0.2915\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 9/10: 100%|██████████| 124/124 [00:29<00:00, 4.13batch/s, loss=0.0664] \n",
|
||
"Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.17batch/s]\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 9/10 - Train Loss: 2.8456 - Val Loss: 0.3152\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 10/10: 100%|██████████| 124/124 [00:27<00:00, 4.48batch/s, loss=0.0111] \n",
|
||
"Evaluating: 100%|██████████| 16/16 [00:01<00:00, 10.12batch/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 10/10 - Train Loss: 2.2282 - Val Loss: 0.2945\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Train for EPOCHS epochs, evaluating on the validation split after each one\n",
|
||
"train_model(model, train_loader, val_loader, criterion, optimizer, epochs=EPOCHS, device=device)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Evaluating: 100%|██████████| 16/16 [00:01<00:00, 9.59batch/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Test Loss (MSE): 0.3395\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Loss of the trained model on the held-out test split\n",
|
||
"test_loss = evaluate(model, test_loader, criterion, device=device)\n",
|
||
"print(f\"Test Loss (MSE): {test_loss:.4f}\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Evaluation Metrics on Test Data:\n",
|
||
"Mean Squared Error (MSE): 0.3358\n",
|
||
"Root Mean Squared Error (RMSE): 0.5795\n",
|
||
"Mean Absolute Error (MAE): 0.3900\n",
|
||
"R² Score: -0.3445\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"def evaluate_metrics(model, test_loader, device):\n",
|
||
"    \"\"\"Collect predictions for `test_loader` and compute regression metrics.\n",
|
||
"\n",
|
||
"    Returns:\n",
|
||
"        (mse, rmse, mae, r2, actuals, predictions); `actuals` and `predictions` are flat\n",
|
||
"        lists (numpy scalars) in loader order.\n",
|
||
"\n",
|
||
"    Raises:\n",
|
||
"        ValueError: if `test_loader` yields no samples.\n",
|
||
"    \"\"\"\n",
|
||
"    model.eval()\n",
|
||
"    predictions = []\n",
|
||
"    actuals = []\n",
|
||
"    with torch.no_grad():\n",
|
||
"        for inputs, labels in test_loader:\n",
|
||
"            outputs = model(inputs.to(device))\n",
|
||
"            predictions.extend(outputs.cpu().numpy().flatten())\n",
|
||
"            # Labels are only needed on the host: no need to copy them to the device and back\n",
|
||
"            actuals.extend(labels.cpu().numpy().flatten())\n",
|
||
"\n",
|
||
"    if not actuals:\n",
|
||
"        raise ValueError(\"test_loader yielded no samples\")\n",
|
||
"\n",
|
||
"    mse = mean_squared_error(actuals, predictions)\n",
|
||
"    rmse = np.sqrt(mse)\n",
|
||
"    mae = mean_absolute_error(actuals, predictions)\n",
|
||
"    r2 = r2_score(actuals, predictions)\n",
|
||
"\n",
|
||
"    return mse, rmse, mae, r2, actuals, predictions\n",
|
||
"\n",
|
||
"mse, rmse, mae, r2, actuals, predictions = evaluate_metrics(model, test_loader, device)\n",
|
||
"\n",
|
||
"print(\"Evaluation Metrics on Test Data:\")\n",
|
||
"print(f\"Mean Squared Error (MSE): {mse:.4f}\")\n",
|
||
"print(f\"Root Mean Squared Error (RMSE): {rmse:.4f}\")\n",
|
||
"print(f\"Mean Absolute Error (MAE): {mae:.4f}\")\n",
|
||
"print(f\"R² Score: {r2:.4f}\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 800x600 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"import matplotlib.patches as mpatches\n",
|
||
"\n",
|
||
"# Binarise true and predicted scores at the threshold to mark each prediction correct/incorrect\n",
|
||
"threshold = 0.5\n",
|
||
"predicted_labels = (np.array(predictions) > threshold).astype(int)\n",
|
||
"true_labels = (np.array(actuals) > threshold).astype(int)\n",
|
||
"\n",
|
||
"# Boolean mask: True where the predicted class matches the true class\n",
|
||
"correct = predicted_labels == true_labels\n",
|
||
"\n",
|
||
"# Green for correct, red for incorrect predictions\n",
|
||
"colors = ['green' if hit else 'red' for hit in correct]\n",
|
||
"\n",
|
||
"# Scatter plot of true vs. predicted scores\n",
|
||
"fig, ax = plt.subplots(figsize=(8, 6))\n",
|
||
"ax.scatter(actuals, predictions, c=colors, alpha=0.6, edgecolor='k')\n",
|
||
"\n",
|
||
"# Legend with one entry per colour\n",
|
||
"green_patch = mpatches.Patch(color='green', label='Correct Predictions')\n",
|
||
"red_patch = mpatches.Patch(color='red', label='Incorrect Predictions')\n",
|
||
"ax.legend(handles=[green_patch, red_patch])\n",
|
||
"\n",
|
||
"# Axes and title\n",
|
||
"ax.set_title('True vs. Predicted Humor Scores')\n",
|
||
"ax.set_xlabel('True Humor Score')\n",
|
||
"ax.set_ylabel('Predicted Humor Score')\n",
|
||
"plt.show()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"239\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"from matplotlib import patches as mpatches\n",
|
||
"from matplotlib.lines import Line2D\n",
|
||
"\n",
|
||
"# Load the raw CSV and pick the test rows; test_loader does not shuffle, so the row order\n",
|
||
"# matches the order of `predicted_labels` / `true_labels`\n",
|
||
"df = pd.read_csv('data/hack.csv')\n",
|
||
"df_test = df.iloc[test_dataset.original_indices].copy()\n",
|
||
"df_test['prediction'] = predicted_labels\n",
|
||
"df_test['label'] = true_labels\n",
|
||
"df_test['pred_correct'] = (df_test['prediction'] == df_test['label'])\n",
|
||
"\n",
|
||
"df_test_sorted = df_test.sort_values(by='humor_rating').reset_index(drop=True)\n",
|
||
"\n",
|
||
"median_rating = df['humor_rating'].median()\n",
|
||
"# First sorted position whose humor_rating exceeds the median; None if no rating does\n",
|
||
"# (taking [0] of an empty selection would raise IndexError)\n",
|
||
"above_median = df_test_sorted.index[df_test_sorted['humor_rating'] > median_rating]\n",
|
||
"median_idx = above_median[0] if len(above_median) > 0 else None\n",
|
||
"print(median_idx)\n",
|
||
"\n",
|
||
"# One bar per test sample, ordered by humor_rating and coloured by prediction correctness\n",
|
||
"range_idx = range(len(df_test_sorted))\n",
|
||
"colors = df_test_sorted['pred_correct'].map({True: 'g', False: 'r'})\n",
|
||
"plt.bar(range_idx, df_test_sorted['humor_rating'], color=colors)\n",
|
||
"# Dashed vertical line at the median rating, used as the True/False cut-off\n",
|
||
"# NOTE(review): assumes the binary labels were built as humor_rating > median - confirm\n",
|
||
"if median_idx is not None:\n",
|
||
"    plt.axvline(x=median_idx, color='black', linestyle='--')\n",
|
||
"\n",
|
||
"# Legend handles; the cut-off uses a Line2D so it is drawn as a dashed line, not a filled box\n",
|
||
"green_patch = mpatches.Patch(color='g', label='Correct Prediction')\n",
|
||
"red_patch = mpatches.Patch(color='r', label='Incorrect Prediction')\n",
|
||
"line_patch = Line2D([0], [0], color='black', linestyle='--', label='humor_rating cut off')\n",
|
||
"plt.title('Humor Rating vs Prediction for Test Set')\n",
|
||
"plt.xlabel('Index')\n",
|
||
"plt.ylabel('Humor Rating')\n",
|
||
"plt.legend(handles=[green_patch, red_patch, line_patch])\n",
|
||
"plt.show()"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.3"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|