ANLP_WS24_CA2/cnn_reg.ipynb

437 lines
40 KiB
Plaintext
Raw Blame History

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CNN Regression"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import required libraries\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm # Progress bar\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
"import matplotlib.patches as mpatches\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/l7/061cw0t95vz1myntpf9bj9540000gn/T/ipykernel_16242/2331049751.py:6: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" train_dataset = torch.load(data_path + '/train.pt')\n",
"/var/folders/l7/061cw0t95vz1myntpf9bj9540000gn/T/ipykernel_16242/2331049751.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" test_dataset = torch.load(data_path + '/test.pt')\n",
"/var/folders/l7/061cw0t95vz1myntpf9bj9540000gn/T/ipykernel_16242/2331049751.py:8: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" val_dataset = torch.load(data_path + '/val.pt')\n"
]
}
],
"source": [
"# Define the data path and batch size\n",
"data_path = 'data/embedded_padded'\n",
"BATCH_SIZE = 32\n",
"\n",
"# Load datasets\n",
"train_dataset = torch.load(data_path + '/train.pt')\n",
"test_dataset = torch.load(data_path + '/test.pt')\n",
"val_dataset = torch.load(data_path + '/val.pt')\n",
"\n",
"# Define the collate function for DataLoader\n",
"def collate_fn(batch):\n",
" input_ids = torch.stack([item[\"input_ids\"] for item in batch]) \n",
" labels = torch.tensor([item[\"labels\"] for item in batch], dtype=torch.float32).unsqueeze(1) \n",
" return input_ids, labels\n",
"\n",
"# Create DataLoaders\n",
"train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)\n",
"val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)\n",
"test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/michellegoppinger/Documents/Dokumente  Laptop von Michelle/Uni/Master/ANLP/ANLP_WS24_CA2/HumorDataset.py:56: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:281.)\n",
" item = {'input_ids': torch.tensor(self.data[idx], dtype=torch.float)}\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsAAAAIjCAYAAAAN/63DAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAQi1JREFUeJzt3Ql4VNX9//FvICQsEraYBAoEBNk3AY2URRBKBIooWBcEQkEWBf0LipSKbFrAqCxVhFJZpIIg/SEqILKLSgBFkUWNgEBKSQAFDIskJMz/+Z7fc+c3kwVCSDKTnPfreW4nd5m5Z+YG+8mZ7zk3wOVyuQQAAACwRDFfNwAAAAAoSARgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAfq1///5So0aNXD13woQJEhAQIIXdwoULzfv46quv8uw1ff3Z6Lm1Dflty5Yt5lz66Gjfvr00atRICsKRI0fM+fUaAvAfBGAAuaL/p56TxTN42Bbcb7rpJrGB/oHiXO9ixYpJ+fLlpXHjxjJ48GDZsWNHnp1nyZIlMmPGDPFH/tw2AJkFZrENAK7pX//6l9f6okWLZP369Zm2169f/4bO889//lOuXLmSq+eOHTtW/vKXv9zQ+ZEzzZo1k2eeecb8fO7cOfn+++9l+fLl5vqNGDFCpk2b5nX8b7/9JoGBgdcdMvft2ydPP/10jp/Trl07c66goCDJT9m1LTIy0py/RIkS+Xp+ANeHAAwgV/r06eO1vn37dhOAM27P6OLFi1K6dOkcn+dGgoMGrOsNWcid3/3ud5mu/csvvyy9e/eW6dOny6233iqPP/64e1/JkiXztT2XLl0yoVd7pPP7XFejveK+PD+ArFECASDfOLWWu3btMj1xGnz/+te/mn0ffPCBdOvWTapUqSLBwcFSq1YtefHFFyU9Pf2qNcBOTeWrr74qc+fONc/T599+++3y5ZdfXrPOVdeHDx8uK1euNG3T5zZs2FDWrl2bqf1avtGyZUsTYPQ8//jHP/K0dvbo0aPyxBNPSN26daVUqVJSqVIl+dOf/mTeY3Z/PAwZMsQcFxISIv369ZMzZ85kOu7jjz+Wtm3bSpkyZaRs2bLmc96/f/8126N/wLRp08aUMGj5hrbLuV65oe9JvxGoWLGi/O1vfxOXy5VtDbD2GmvvqV5rvSZhYWHyhz/8Qb7++mv379Lq1avNZ+aUWzi/F06d79KlS02vv4Zx/V1LTk7OsgbYob+Xv//97007a9asKXPmzMmy9jrj9cj4mldrW3Y1wJs2bXJfI/28e/ToYXrNPTm/awcPHjT/DvS4cuXKyZ///GfzuwAg9+gaAZCvfvnlF+nSpYs8/PDDpocwPDzcbNdAoCFr5MiR5lEDwbhx40xoeeWVV3L0lbOGJg2EGhJiY2OlZ8+e8tNPP12z1/jzzz+XFStWmPCpAfHvf/+79OrVSxISEky4VN98843cc889UrlyZZk4caIJ5pMmTZKbb745jz4ZMYF927Zt5rOpWrWqCUuzZ882geq7777L1FOuwV1DkAaj+Ph4c6yGLieQKQ2cMTExEh0dbXpgNSjpcRps9T1lN6BQA/If//hHadKkiXmfGkI1eH3xxRc39B712t5///0yb9488570j42sDB06VP7973+b99igQQPze6PXSUNh8+bN5fnnn5dff/1Vjh07ZnqUndf2pH9Aaa/vs88+KykpKVcte9A/HLp27SoPPvigPPLII/Lee++ZHmp9zoABA67rPeakbZ42bNhg/k3ccsst5lpqicTrr78urVu3NoE/4zXSNmpAnzJlitn/1ltvmT8Q9PoCyCUXAOSBYcOGafee17a77rrLbJszZ06m4y9evJhp25AhQ1ylS5d2Xbp0yb0tJibGFRkZ6V4/fPiwec1KlSq5Tp8+7d7+wQcfmO0fffSRe9v48eMztUnXg4KCXAcPHnRv+/bbb832119/3b2te/fupi3//e9/3dsOHDjgCgwMzPSaWdF2lylT5qrHZPUZxMXFmddftGiRe9uCBQvMthYtWrhSU1Pd22NjY812fe/q3LlzrvLly7sGDRrk9ZpJSUmucuXKeW3P+NlMnz7drJ86dcp1vfT6dOvWLdv9zms77VS6rm1waPv0d+hq9ByevwuOzZs3m9e75ZZbMn2mzj59zPh7+dprr7m3paSkuJo1a+YKCwtzf8bO566/c9d6zeza5vy+6ms5nPP88ssvXr+DxYoVc/Xr1y/TNRowYIDXa95///3m9x9A7lECASBfaU+ifmWbkX7t7NCe3J9//tl8Jaw9lj/88MM1X/ehhx6SChUquNf1uUp7gK+lU6dOpqTBob2eWlLgPFd7e7WX7r777jMlGo7atWubnru84vkZXL582fR66jm0l9f56t+Tzqrg2butPZZa47xmzRp3CcPZs2dNj6Z+ns5SvHhxiYqKks2bN2fbFj2nU5qS20GH2XF6Q/U6X+38OmPE8ePHc30e7fn2/EyvRj83/fbAoT2/un7y5ElTGpFfEhMTZffu3aakQUtDPH8HteTDuZYZe8c96e+6/q7otyUAcocADCBfaT1mVl9F61fu+tW41jRq+NTSAmcQlX6dfC3Vq1f3WnfCcFY1sdd6rvN857kagvRraQ2jGWW1Lbf0HFr2Ua1aNfOHQmhoqPkcNMRm9RnoQLKMwVJLNJwa1QMHDpjHu+++27yO57Ju3Trzvq72B4V+Bf/YY4+ZMhUty9CygLwIw+fPnzePWm6SHS1h0VkU9LO44447TGlATv6Y8aRlAjmlf9ho/a2nOnXqmMfsarDzgpasKK2vzkhnTNE/WC5cuJBnv+sAskYNMIB8lVWPnAa8u+66ywRfrTfV3lgdaKa9nqNHj85R6NJezax4DrTKj+fmpSeffFIWLFhgBn+1atXK/DGgtbwaPnMTPJ3naB1wREREpv1XmxFDr9PWrVtNL7EO6NJBgcuWLTNhWsNzdp9ZTmiwvdYfD1rnqj2b77//vjmf1oFrjavWaue01z2nvb85ld1gx4wDNfObv/y+AkUJARhAgdNBW/oVroYbnR3CcfjwYfEHOsBIA7kOAssoq225pYO+9Gv71157zWv6Lv0DISvaw9uhQwevnlX9Sl0HcymnrEPbr2Ue10unDOvYsaNZdN7eyZMnmwFeGopz83pOGzXUas/uteaE1t5sHZioi/ZW6+A3nT3CCcB5eec6LbXQnlbPXuAff/zRPDqD0Jye1ozXw+nF9ZTTtum8wEoHMWakpT/6LUDGnmkAeY8SCAAFzunR8uzBSk1NlTfffFP8pX0a+HSqNM+aVA2/OsVYXp4nYy+ezgaQXQ+jTvumtcIOnd0hLS3NHRB15gftVdfg6nmc49SpU9m25fTp01ne3ELpjAq5LfHo27eveW0N0lfrUc1Y8qEhXssUPM+twTAn5TE5oZ+bTmvn+fun61ou0qJFC68/KLRn3LOteh0yymnbNOTr5/r22297BWvtJdeeb+ePGQD5ix5gAAVO517V3jXt/XzqqadMMNKv7f3pK12tQdVAonWxOthMg88bb7xh5g7WQUw5oSH0pZdeyrRdBz9pL6dOO6bvW0sfdOqvuLg4M/jOmYotIw1p2jur5QLag6h/MOj0Zvfee6/Zr+FXQ7GGTu091VIKDXQ6vZuWNeh70feQFS1F0aCncwZrL6X2wOrr6/Rseo5r+e9//yvvvPOOu9dXpzzTO8ElJSWZO8R5DjjLSAfH6XkeeOABadq0qalt1s9Bp4nz7B3XYKplGTp1ns77rMd1795dckPDtZZYaL2v1v7q6+p11XDrDDTUKdvuvPNOGTNmjAnxet10rmENzxldT9u0vEP/aNGyl4EDB7qnQdPfA8+5kQHkHwIwgAKnAW/VqlUmGOmNCzQM6wA4DXfai+kPNNBob6/OKfvCCy+Yr/A1JOq8tDmZpcIJrPrcjLRnUQPwzJkzTS/w4sWLTemDBlQNftl9Bhpe9VgdOKfhWmd70DmMPXtW9c5rGu6mTp1qgpb2oOpARK2vzWo2DoeGaA2D8+fPNwOx9Kt4rdPWOZA1mF2LhkcN3toWHeymn5cGQB1Up4ParkbnO9bPQ//g0LIYrWXWemEN4J53j9Nj9DxaN63z7WpQz20A1t857YXVOmy9XbMO/NPPd9CgQV7H6eet4V0/T52pQgOrlqHojA2erqdt+u2C1liPHz/eXEsN3PpZayC/noF8AHIvQOdCu4HnA4BVdGo0ncHCmXEBAFD4UAMMANnQr6Y9aejVeVr1Tm0AgMKLHmAAuMqAJb1hgd6yVkf+a32tlhToLYUzzskLACg8qAEGgGzcc8898u6775qBXHqjCh20pDMsEH4BoHCjBxgAAABWoQYYAAAAViEAAwAAwCrUAOeAzkmpd4PSuS3z8lacAAAAyBta1as31tG50PXW7ldDAM4BDb86qTsAAAD823/+8x9zd8mrIQDngPb8Oh+o3moUAAAA/iU5Odl0WDq57WoIwDnglD1o+CUAAwAA+K+clKsyCA4AAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKxCAAYAAIBVfBqAp0yZIrfffruULVtWwsLC5L777pP4+HivYy5duiTDhg2TSpUqyU033SS9evWSEydOeB2TkJAg3bp1k9KlS5vXGTVqlKSlpXkds2XLFmnevLkEBwdL7dq1ZeHChQXyHgEAAOBffBqAP/30UxNut2/fLuvXr5fLly9L586d5cKFC+5jRowYIR999JEsX77cHH/8+HHp2bOne396eroJv6mpqbJt2zZ5++23TbgdN26c+5jDhw+bYzp06CC7d++Wp59+Wh577DH55JNPCvw9AwAAwLcCXC6XS/zEqVOnTA+uBt127drJr7/+KjfffLMsWbJEHnjgAXPMDz/8IPXr15e4uDi588475eOPP5Y//vGPJhiHh4ebY+bMmSOjR482rxcUFGR+Xr16tezbt899rocffljOnj0ra9euvWa7kpOTpVy5cqY9ISEh+fgJAAAAIDeuJ68Fih/RBquKFSuax127dple4U6dOrmPqVevnlSvXt0dgPWxcePG7vCroqOj5fHHH5f9+/fLbbfdZo7xfA3nGO0JzkpKSopZPD9QX9DSjp9//rnAzhcaGmo+WwAAgKLMbwLwlStXTCBt3bq1NGrUyGxLSkoyPbjly5f3OlbDru5zjvEMv85+Z9/VjtFg+9tvv0mpUqUy1SZPnDhRfEnDb7169eW33y4W2DlLlSotP/zwPSEYAAAUaX4TgLUWWEsUPv/8c183RcaMGSMjR450r2tQrlatWoG2QXt+NfxGDRgvIZVr5Pv5khOPyI75E815CcAAAKAo84sAPHz4cFm1apVs3bpVqlat6t4eERFhBrdpra5nL7DOAqH7nGN27tzp9XrOLBGex2ScOULXtT4kY++v0pkidPEHGn4rVq/r62YAAAAUGT6dBULH32n4ff/992XTpk1Ss2ZNr/0tWrSQEiVKyMaNG93bdJo0LQ9o1aqVWdfHvXv3ysmTJ93H6IwSGm4bNGjgPsbzNZxjnNcAAACAPQJ9XfagMzx88MEHZi5gp2ZXR/Bpz6w+Dhw40JQj6MA4DbVPPvmkCa46AE7ptGkadPv27SuxsbHmNcaOHWte2+nFHTp0qLzxxhvy3HPPyYABA0zYfu+998zMEAAAALCLT3uAZ8+ebWZ+aN++vVSuXNm9LFu2zH3M9OnTzTRnegMMnRpNyxlWrFjh3l+8eHFTPqGPGoz79Okj/fr1k0mTJrmP0Z5lDbva69u0aVN57bXX5K233jIzQQAAAMAuPu0BzskUxCVLlpRZs2aZJTuRkZGyZs2aq76OhuxvvvkmV+0EAABA0eHTHmAAAACgoBGAAQAAYBUCMAAAAKxCAAYAAIBV/OJGGAAAAMh/ei+Fn3/+ucDOFxoa6pd3mCUAAwAAWBJ+69WrL7/9drHAzlmqVGn54Yfv/S4EE4ABAAAsoD2/Gn6jBoyXkMo18v18yYlHZMf8iea8BGAAAAD4TEjlGlKxel2xGYPgAAAAYBUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALCKTwPw1q1bpXv37lKlShUJCAiQlStXeu3XbVktr7zyivuYGjVqZNo/depUr9fZs2ePtG3bVkqWLCnVqlWT2NjYAnuPAAAA8C8+DcAXLlyQpk2byqxZs7Lcn5iY6LXMnz/fBNxevXp5HTdp0iSv45588kn3vuTkZOncubNERkbKrl27THieMGGCzJ07N9/fHwAAAPxPoC9P3qVLF7NkJyIiwmv9gw8+kA4dOsgtt9zitb1s2bKZjnUsXrxYUlNTTXgOCgqShg0byu7du2XatGkyePDgPHonAAAAKCwKTQ3wiRMnZPXq1TJw4MBM+7TkoVKlSnLbbbeZHt60tDT3vri4OGnXrp0Jv47o6GiJj4+XM2fOZHmulJQU03PsuQAAAKBo8GkP8PV4++23TU9vz549vbY/9dRT0rx5c6lYsaJs27ZNxowZY8ogtIdXJSUlSc2aNb2eEx4e7t5XoUKFTOeaMmWKTJw4MV/fDwAAAHyj0ARgLWF49NFHzUA2TyNHjnT/3KRJE9PTO2TIEBNig4ODc3UuDdGer6s9wDp4DgAAAIVfoQjAn332mSlZWLZs2TWPjYqKMiUQR44ckbp165raYC2f8OSsZ1c3rME5t+EZAAAA/q1Q1ADPmzdPWrRoYWaMuBYd4FasWDEJCwsz661atTLTrV2+fNl9zPr16004zqr8AQAAAEWbTwPw+fPnTWDVRR0+fNj8nJCQ4FV+sHz5cnnssccyPV8HuM2YMUO+/fZb+emnn8yMDyNGjJA+ffq4w23v3r1NWYQOntu/f7/pRZ45c6ZXiQMAAADs4dMSiK+++spMa+ZwQmlMTIwsXLjQ/Lx06VJxuVzyyCOPZHq+linofp3XV2du0MFuGoA9w225cuVk3bp1MmzYMNOLHBoaKuPGjWMKNAAAAEv5NAC3b9/ehNur0aCaXVjV2R+2b99+zfPo4DitIwYAAAAKRQ0wAAAAkFcIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCo+DcBbt26V7t27S5UqVSQgIEBWrlzptb9///5mu+dyzz33eB1z+vRpefTRRyUkJETKly8vAwcOlPPnz3sds2fPHmnbtq2ULFlSqlWrJrGxsQXy/gAAAOB/fBqAL1y4IE2bNpVZs2Zle4wG3sTERPfy7rvveu3X8Lt//35Zv369rFq1yoTqwYMHu/cnJydL586dJTIyUnbt2iWvvPKKTJgwQebOnZuv7w0AAAD+KdCXJ+/SpYtZriY4OFgiIiKy3Pf999/L2rVr5csvv5SWLVuaba+//rp07dpVXn31VdOzvHjxYklNTZX58+dLUFCQNGzYUHbv3i3Tpk3zCsoAAACwg9/XAG/ZskXCwsKkbt268vjjj8svv/zi3hcXF2fKHpzwqzp16iTFihWTHTt2uI9p166dCb+O6OhoiY+PlzNnzmR5zpSUFNNz7LkAAACgaPDrAKzlD4sWLZKNGzfKyy+/LJ9++qnpMU5PTzf7k5KSTDj2FBgYKBUrVjT7nGPCw8O9jnHWnWMymjJlipQrV869aN0wAAAAigaflkBcy8MPP+z+uXHjxtKkSROpVauW6RXu2LFjvp13zJgxMnLkSPe69gATggEAAIoGv+4BzuiWW26R0NBQOXjwoFnX2uCTJ096HZOWlmZmhnDqhvXxxIkTXsc469nVFmvdsc4q4bkAAACgaChUAfjYsWOmBrhy5cpmvVWrVnL27Fkzu4Nj06ZNcuXKFYmKinIfozNDXL582X2MzhihNcUVKlTwwbsAAACAtQFY5+vVGRl0UYcPHzY/JyQkmH2jRo2S7du3y5EjR0wdcI8ePaR27dpmEJuqX7++qRMeNGiQ7Ny5U7744gsZPny4KZ3QGSBU7969zQA4nR9Yp0tbtmyZzJw506vEAQAAAPbwaQD+6quv5LbbbjOL0lCqP48bN06KFy9ubmBx7733Sp06dUyAbdGihXz22WemRMGh05zVq1fP1ATr9Gdt2rTxmuNXB7GtW7fOhGt9/jPPPGNenynQAAAA7OTTQXDt27cXl8uV7f5PPvnkmq+hMz4sWbLkqsfo4DkNzgAAAEChqgEGAAAAbhQBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWMWnAXjr1q3SvXt3qVKligQEBMjKlSvd+y5fviyjR4+Wxo0bS5kyZcwx/fr1k+PHj3u9Ro0aNcxzPZepU6d6HbNnzx5p27atlCxZUqpVqyaxsbEF9h4BAADgX3wagC9cuCBNmzaVWbNmZdp38eJF+frrr+WFF14wjytWrJD4+Hi59957Mx07adIkSUxMdC9PPvmke19ycrJ07txZIiMjZdeuXfLKK6/IhAkTZO7cufn+/gAAAOB/An158i5dupglK+XKlZP169d7bXvjjTfkjjvukISEBKlevbp7e9myZSUiIiLL11m8eLGkpqbK/PnzJSgoSBo2bCi7d++WadOmyeDBg/P4HQEAAMDfFaoa4F9//dWUOJQvX95ru5Y8VKpUSW677TbTw5uWlubeFxcXJ+3atTPh1xEdHW16k8+cOZPleVJSUkzPsecCAACAosGnPcDX49KlS6Ym+JFHHpGQkBD39qeeekqaN28uFStWlG3btsmYMWNMGYT28KqkpCSpWbOm12uFh4e791WoUCHTuaZMmSITJ07M9/cEAACAglcoArAOiHvwwQfF5XLJ7NmzvfaNHDnS/XOTJk1MT++QIUNMiA0ODs7V+TREe76u9gDr4DkAAAAUfoGFJfwePXpUNm3a5NX7m5WoqChTAnHkyBGpW7euqQ0+ceKE1zHOenZ1wxqccxueAQAA4N+KFYbwe+DAAdmwYYOp870WHeBWrFgxCQsLM+utWrUy063pazl0cJ2G46zKHwAAAFC0+bQH+Pz583Lw4EH3+uHDh02A1XreypUrywMPPGCmQFu1apWkp6ebml2l+7XUQQe47dixQzp06GBmgtD1ESNGSJ8+fdzhtnfv3qaed+DAgaaGeN++fTJz5kyZPn26z943AAAALA3AX331lQmvDqfuNiYmxszV++GHH5r1Zs2aeT1v8+bN0r59e1OmsHTpUnOsztygg900AHvW7+p0auvWrZNhw4ZJixYtJDQ0VMaNG8cUaAAAAJbyaQDWEKsD27JztX1KZ3/Yvn37Nc+jg+M+++yzXLURAAAARYtf1wADAAAAeY0ADAAAAKsQgAEAAGAVAjAAAACskqsA/NNPP+V9SwAAAAB/DcC1a9c205e98847cunSpbxvFQAAAOBPAVhvTqFTi+l8u3o74SFDhsjOnTvzvnUAAACAPwRgvTGF3k3t+PHjMn/+fElMTJQ2bdpIo0aNZNq0aXLq1Km8bicAAADg+0FwgYGB0rNnT1m+fLm8/PLL5rbGzz77rFSrVk369etngjEAAABQZAKw3sr4iSeekMqVK5ueXw2/hw4dkvXr15ve4R49euRdSwEAAABf3QpZw+6CBQskPj5eunbtKosWLTKPxYr9b56uWbOmLFy4UGrUqJEXbQQAAAB8G4Bnz54tAwYMkP79+5ve36yEhYXJvHnzbrR9AAAAgO8D8IEDB655TFBQkMTExOTm5QEAAAD/qgHW8gcd+JaRbnv77bfzol0AAACA/wTgKVOmSGhoaJZlD5MnT86LdgEAAAD+E4ATEhLMQLeMIiMjzT4AAACgSAVg7ends2dPpu3ffvutVKpUKS/aBQAAAPhPAH7kkUfkqaeeks2bN0t6erpZNm3aJP/v//0/efjhh/O+lQAAAIAvZ4F48cUX5ciRI9KxY0dzNzh15coVc/c3aoABAABQ5AKwTnG2bNkyE4S17KFUqVLSuHFjUwMMAAAAFLkA7KhTp45ZAAAAgCIdgLXmV291vHHjRjl58qQpf/Ck9cAAAABAkQnAOthNA3C3bt2kUaNGEhAQkPctAwAAAPwlAC9dulTee+896dq1a963CAAAAPC3adB0EFzt2rXzvjUAAACAPwbgZ555RmbOnCkulyvvWwQAAAD4WwnE559/bm6C8fHHH0vDhg2lRIkSXvtXrFiRV+0DAAAAfB+Ay5cvL/fff3/etgQAAADw1wC8YMGCvG8JAAAA4K81wCotLU02bNgg//jHP+TcuXNm2/Hjx+X8+fN52T4AAADA9z3AR48elXvuuUcSEhIkJSVF/vCHP0jZsmXl5ZdfNutz5szJ21YCAAAAvuwB1hthtGzZUs6cOSOlSpVyb9e6YL07HAAAAFCkeoA/++wz2bZtm5kP2FONGjXkv//9b161DQAAAPCPHuArV65Ienp6pu3Hjh0zpRAAAABAkQrAnTt3lhkzZrjXAwICzOC38ePHc3tkAAAAFL0SiNdee02io6OlQYMGcunSJendu7ccOHBAQkND5d133837VgIAAAC+DMBVq1aVb7/9VpYuXSp79uwxvb8DBw6URx991GtQHAAAAFAkArB5YmCg9OnTJ29bAwAAAPhjAF60aNFV9/fr1y+37QEAAAD8LwDrPMCeLl++LBcvXjTTopUuXZoADAAAgKI1C4TeAMNz0Rrg+Ph4adOmDYPgAAAAUPQCcFZuvfVWmTp1aqbeYQAAAKBIBmBnYNzx48fz8iUBAAAA39cAf/jhh17rLpdLEhMT5Y033pDWrVvnVdsAAAAA/+gBvu+++7yWnj17yoQJE6RJkyYyf/78HL/O1q1bpXv37lKlShVzN7mVK1dmCtbjxo2TypUrm/mFO3XqZG644en06dNm/uGQkBApX768mY9Ya5I96VzFbdu2lZIlS0q1atUkNjY2N28bAAAAtgbgK1eueC3p6emSlJQkS5YsMWE1py5cuCBNmzaVWbNmZblfg+rf//53mTNnjuzYsUPKlClj7kCnd59zaPjdv3+/rF+/XlatWmVC9eDBg937k5OTza2bIyMjZdeuXfLKK6+YsD537tzcvHUAAADYeiOMvNClSxezZEV7f2fMmCFjx46VHj16uOcfDg8PNz3FDz/8sHz//feydu1a+fLLL6Vly5bmmNdff126du0qr776qulZXrx4saSmppqeaZ2mrWHDhrJ7926ZNm2aV1AGAACAHXIVgEeOHJnjYzVo5sbhw4dNr7KWPTjKlSsnUVFREhcXZwKwPmrZgxN+lR5frFgx02N8//33m2PatWtnwq9De5FffvllM4VbhQoVMp07JSXFLJ69yAAAALA4AH/zzTdm0Rtg1K1b12z78ccfpXjx4tK8eXP3cVrXm1safpX2+HrSdWefPoaFhWWaiaJixYpex9SsWTPTazj7sgrAU6ZMkYkTJ+a67QAAAChiAVgHrpUtW1befvttd4DU3tQ///nPZrDZM888I4XZmDFjvHq5tQdYB88BAADA0kFwr732mukl9ew91Z9feuklsy8vREREmMcTJ054bdd1Z58+njx50mt/WlqamRnC85isXsPzHBkFBwebWSU8FwAAAFgcgLVH9NSpU5m267Zz587lRbtM2YIG1I0bN3qdV2t7W7VqZdb18ezZs2Z2B8emTZvMzBRaK+wcozNDaLmGQ2eM0NKNrMofAAAAULTlKgDr4DItd1ixYoUcO3bMLP/zP/9j5uDVOYFzSufr1RkZdHEGvunPCQkJpn746aefNr3KeuONvXv3Sr9+/czMDjr3sKpfv77cc889MmjQINm5c6d88cUXMnz4cDNATo9TvXv3NgPgtG06XdqyZctk5syZ1zWQDwAAAJbXAOu8vM8++6wJl07Pqg4+05Cp8+zm1FdffSUdOnRwrzuhNCYmRhYuXCjPPfecmStYpyvTnt42bdqYac/0hhYOneZMQ2/Hjh3N7A+9evUycwd7zhyxbt06GTZsmLRo0UJCQ0PNzTWYAg0AAMBOuQrApUuXljfffNOE3UOHDplttWrVMjequB7t27c38/1mR3uBJ02aZJbs6IwPegOOq9E71H322WfX1TYAAAAUTbkqgXAkJiaa5dZbbzXh92phFgAAACi0AfiXX34xJQd16tQxd13TEKy0BKKwT4EGAACAoi1XAXjEiBFSokQJM1hNyyEcDz30kKnRBQAAAIpUDbAOKvvkk0+katWqXtu1FOLo0aN51TYAAADAP3qAdWYGz55fh96AQm8iAQAAABSpAKy3O160aJHXbA1684nY2Fivac0AAACAIlECoUFXB8HpPL6pqalmvl69yYT2AOvNKAAAAIAi1QPcqFEj+fHHH82NKXr06GFKIvQOcN98842ZDxgAAAAoMj3Aeuc3vf2w3g3u+eefz59WAQAAAP7SA6zTn+3Zsyd/WgMAAAD4YwlEnz59ZN68eXnfGgAAAMAfB8GlpaXJ/PnzZcOGDdKiRQtzG2RP06ZNy6v2AQAAAL4LwD/99JPUqFFD9u3bJ82bNzfbdDCcJ50SDQAAACgSAVjv9JaYmCibN2923/r473//u4SHh+dX+wAAAADf1QC7XC6v9Y8//thMgQYAAAAU6UFw2QViAAAAoEgFYK3vzVjjS80vAAAAimwNsPb49u/fX4KDg836pUuXZOjQoZlmgVixYkXethIAAADwRQCOiYnJNB8wAAAAUGQD8IIFC/KvJQAAAIC/D4IDAAAAChsCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAq/h9AK5Ro4YEBARkWoYNG2b2t2/fPtO+oUOHer1GQkKCdOvWTUqXLi1hYWEyatQoSUtL89E7AgAAgC8Fip/78ssvJT093b2+b98++cMf/iB/+tOf3NsGDRokkyZNcq9r0HXoczX8RkREyLZt2yQxMVH69esnJUqUkMmTJxfgOwEAAIA/8PsAfPPNN3utT506VWrVqiV33XWXV+DVgJuVdevWyXfffScbNmyQ8PBwadasmbz44osyevRomTBhggQFBeX7ewAAAID/8PsSCE+pqanyzjvvyIABA0ypg2Px4sUSGhoqjRo1kjFjxsjFixfd++Li4qRx48Ym/Dqio6MlOTlZ9u/fn+V5UlJSzH7PBQAAAEWD3/cAe1q5cqWcPXtW+vfv797Wu3dviYyMlCpVqsiePXtMz258fLysWLHC7E9KSvIKv8pZ131ZmTJlikycODFf3wsAAAB8o1AF4Hnz5kmXLl1M2HUMHjzY/bP29FauXFk6duwohw4dMqUSuaG9yCNHjnSvaw9wtWrVbrD1AAAA8AeFJgAfPXrU1PE6PbvZiYqKMo8HDx40AVhrg3fu3Ol1zIkTJ8xjdnXDwcHBZgEAAEDRU2hqgBcsWGCmMNMZHa5m9+7d5lF7glWrVq1k7969cvLkSfcx69evl5CQEGnQoEE+txoAAAD+plD0AF+5csUE4JiYGAkM/L8ma5nDkiVLpGvXrlKpUiVTAzxixAhp166dNGnSxBzTuXNnE3T79u0rsbGxpu537NixZh5henkBAADsUygCsJY+6M0sdPYHTzqFme6bMWOGXLhwwdTp9urVywRcR/HixWXVqlXy+OOPm97gMmXKmCDtOW8wAAAA7FEoArD24rpcrkzbNfB++umn13y+zhKxZs2afGodAAAACpNCUwMMAAAA5AUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsIpfB+AJEyZIQECA11KvXj33/kuXLsmwYcOkUqVKctNNN0mvXr3kxIkTXq+RkJAg3bp1k9KlS0tYWJiMGjVK0tLSfPBuAAAA4A8Cxc81bNhQNmzY4F4PDPy/Jo8YMUJWr14ty5cvl3Llysnw4cOlZ8+e8sUXX5j96enpJvxGRETItm3bJDExUfr16yclSpSQyZMn++T9AAAAwLf8PgBr4NUAm9Gvv/4q8+bNkyVLlsjdd99tti1YsEDq168v27dvlzvvvFPWrVsn3333nQnQ4eHh0qxZM3nxxRdl9OjRpnc5KCjIB+8IAAAAvuTXJRDqwIEDUqVKFbnlllvk0UcfNSUNateuXXL58mXp1KmT+1gtj6hevbrExcWZdX1s3LixCb+O6OhoSU5Olv3792d7zpSUFHOM5wIAAICiwa8DcFRUlCxcuFDWrl0rs2fPlsOHD0vbtm3l3LlzkpSUZHpwy5cv7/UcDbu6T+mjZ/h19jv7sjNlyhRTUuEs1apVy5f3BwAAgILn1yUQXbp0cf/cpEkTE4gjIyPlvffek1KlSuXbeceMGSMjR450r2sPMCEYAACgaPDrHuCMtLe3Tp06cvDgQVMXnJqaKmfPnvU6RmeBcGqG9THjrBDOelZ1xY7g4GAJCQnxWgAAAFA0FKoAfP78eTl06JBUrlxZWrRoYWZz2Lhxo3t/fHy8qRFu1aqVWdfHvXv3ysmTJ93HrF+/3gTaBg0a+OQ9AAAAwLf8ugTi2Wefle7du5uyh+PHj8v48eOlePHi8sgjj5ja3IEDB5pShYoVK5pQ++STT5rQqzNAqM6dO5ug27dvX4mNjTV1v2PHjjVzB2svLwAAAOzj1wH42LFjJuz+8ssvcvPNN0ubNm3MFGf6s5o+fboUK1bM3ABDZ27QGR7efPNN9/M1LK9atUoef/xxE4zLlCkjMTExMmnSJB++KwAAAPiSXwfgpUuXXnV/yZIlZdasWWbJjvYer1mzJh9aBwAAgMKoUNUAAwAAADeKAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKzi1wF4ypQpcvvtt0vZsmUlLCxM7rvvPomPj/c6pn379hIQEOC1DB061OuYhIQE6datm5QuXdq8zqhRoyQtLa2A3w0AAAD8QaD4sU8//VSGDRtmQrAG1r/+9a/SuXNn+e6776RMmTLu4wYNGiSTJk1yr2vQdaSnp5vwGxERIdu2bZPExETp16+flChRQiZPnlzg7wkAAAC+5dcBeO3atV7rCxcuND24u3btknbt2nkFXg24WVm3bp0JzBs2bJDw8HBp1qyZvPjiizJ69GiZMGGCBAUF5fv7AAAAgP/w6xKIjH799VfzWLFiRa/tixcvltDQUGnUqJGMGTNGLl686N4XFxcnjRs3NuHXER0dLcnJybJ///4sz5OSkmL2ey4AAAAoGvy6B9jTlStX5Omnn5bWrVuboOvo3bu3REZGSpUqVWTPnj2mZ1frhFesWGH2JyUleYVf5azrvuxqjydOnJiv7wcAAAC+UWgCsNYC79u3Tz7//HOv7YMHD3b/rD29lStXlo4dO8qhQ4ekVq1auTqX9iKPHDnSva49wNWqVbuB1gMAAMBfFIoSiOHDh8uqVatk8+bNUrVq1aseGxUVZR4PHjxoHrU2+MSJE17HOOvZ1Q0HBwdLSEiI1wIAAICiwa8DsMvlMuH3/fffl02bNknNmjWv+Zzdu3ebR+0JVq1atZK9e/fKyZMn3cesX7/ehNoGDRrkY+sBAADgjwL9vexhyZIl8sEHH5i5gJ2a3XLlykmpUqVMmYPu79q1q1SqVMnUAI8YMcLMENGkSRNzrE6bpkG3b9++Ehsba15j7Nix5rW1pxcAAAB28ese4NmzZ5uZH/RmF9qj6yzLli0z+3UKM53eTENuvXr15JlnnpFevXrJRx995H6N4sWLm/IJfdTe4D59+ph5gD3nDQYAAIA9Av29BOJqdGCa3izjWnSWiDVr1uRhywAAAFBY+XUPMAAAAJDXCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAFiFAAwAAACrEIABAABgFQIwAAAArEIABgAAgFUIwAAAALAKARgAAABWIQADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqVgXgWbNmSY0aNaRkyZISFRUlO3fu9HWTAAAAUMCsCcDLli2TkSNHyvjx4+Xrr7+Wpk2bSnR0tJw8edLXTQMAAEABsiYAT5s2TQYNGiR//vOfpUGDBjJnzhwpXbq0zJ8/39dNAwAAQAEKFAukpqbKrl27ZMyYMe5txYoVk06dOklcXFym41NSUszi+PXXX81jcnJyAbVY5Pz58+bx9NF4SUv5Ld/Pl5yUYB71c3LOnd/0Gly5cqVAzsX5isY5OV/hPp8vzsn5OJ+/n7MgzxcfH++TbHH+/PkCyVDOOVwu1zWPDXDl5KhC7vjx4/K73/1Otm3bJq1atXJvf+655+TTTz+VHTt2eB0/YcIEmThxog9aCgAAgBvxn//8R6pWrXrVY6zoAb5e2lOs9cIO/cvs9OnTUqlSJQkICCiQNuhfMdWqVTMXMSQkpEDOibzD9Sv8uIaFH9ewcOP6FX7JBXwNtU/33LlzUqVKlWsea0UADg0NleLFi8uJEye8tut6REREpuODg4PN4ql8+fLiC/oLwz/8wovrV/hxDQs/rmHhxvUr/EIK8BqWK1cuR8dZMQguKChIWrRoIRs3bvTq1dV1z5IIAAAAFH1W9AArLWmIiYmRli1byh133CEzZsyQCxcumFkhAAAAYA9rAvBDDz0kp06dknHjxklSUpI0a9ZM1q5dK+Hh4eKPtARD5yzOWIqBwoHrV/hxDQs/rmHhxvUr/IL9+BpaMQsEAAAAYFUNMAAAAOAgAAMAAMAqBGAAAABYhQAMAAAAqxCAfWjWrFlSo0YNKVmypERFRcnOnTuvevzy5culXr165vjGjRvLmjVrCqytuLHr989//lPatm0rFSpUMEunTp2ueb3hf/8GHUuXLjV3hbzvvvvyvY3I22t49uxZGTZsmFSuXNmMTK9Tpw7/LS1E10+nMK1bt66UKlXK3GFsxIgRcunSpQJrL7xt3bpVunfvbu68pv9NXLlypVzLli1bpHnz5ubfX+3atWXhwoXiEzoLBAre0qVLXUFBQa758+e79u/f7xo0aJCrfPnyrhMnTmR5/BdffOEqXry4KzY21vXdd9+5xo4d6ypRooRr7969Bd52XP/16927t2vWrFmub775xvX999+7+vfv7ypXrpzr2LFjBd525O4aOg4fPuz63e9+52rbtq2rR48eBdZe3Pg1TElJcbVs2dLVtWtX1+eff26u5ZYtW1y7d+8u8Lbj+q/f4sWLXcHBweZRr90nn3ziqly5smvEiBEF3nb8rzVr1rief/5514oVK3RGMdf777/vupqffvrJVbp0adfIkSNNlnn99ddNtlm7dq2roBGAfeSOO+5wDRs2zL2enp7uqlKlimvKlClZHv/ggw+6unXr5rUtKirKNWTIkHxvK278+mWUlpbmKlu2rOvtt9/Ox1Yir6+hXrff//73rrfeessVExNDAC5k13D27NmuW265xZWamlqArUReXT899u677/bapkGqdevW+d5WXFtOAvBzzz3natiwode2hx56yBUdHe0qaJRA+EBqaqrs2rXLfA3uKFasmFmPi4vL8jm63fN4FR0dne3x8K/rl9HFixfl8uXLUrFixXxsKfL6Gk6aNEnCwsJk4MCBBdRS5OU1/PDDD6VVq1amBEJvgtSoUSOZPHmypKenF2DLkdvr9/vf/948xymT+Omnn0z5SteuXQus3bgx/pRlrLkTnD/5+eefzX9wM96FTtd/+OGHLJ+jd6/L6njdDv+/fhmNHj3a1Exl/A8B/Pcafv755zJv3jzZvXt3AbUSeX0NNTBt2rRJHn30UROcDh48KE888YT5Y1TvVgX/vn69e/c2z2vTpo1+ey1paWkydOhQ+etf/1pArcaNyi7LJCcny2+//WZquwsKPcBAAZs6daoZRPX++++bgR/wf+fOnZO+ffuawYyhoaG+bg5y6cqVK6YHf+7cudKiRQt56KGH5Pnnn5c5c+b4umnIAR08pT32b775pnz99deyYsUKWb16tbz44ou+bhoKIXqAfUD/D7R48eJy4sQJr+26HhERkeVzdPv1HA//un6OV1991QTgDRs2SJMmTfK5pcira3jo0CE5cuSIGe3sGaZUYGCgxMfHS61atQqg5biRf4c680OJEiXM8xz169c3vVL6lXxQUFC+txu5v34vvPCC+UP0scceM+s6G9KFCxdk8ODB5g8ZLaGAf4vIJsuEhIQUaO+v4rfFB/Q/str7sHHjRq//M9V1rU/Lim73PF6tX78+2+PhX9dPxcbGmp6KtWvXSsuWLQuotciLa6jTD+7du9eUPzjLvffeKx06dDA/63RM8P9/h61btzZlD84fL+rHH380wZjw6//XT8dOZAy5zh8z/zsGC/6ulT9lmQIfdgf39C86ncvChQvNVCCDBw82078kJSWZ/X379nX95S9/8ZoGLTAw0PXqq6+aabTGjx/PNGiF6PpNnTrVTPfz73//25WYmOhezp0758N3YbfrvYYZMQtE4buGCQkJZvaV4cOHu+Lj412rVq1yhYWFuV566SUfvgt7Xe/10//f0+v37rvvmum01q1b56pVq5aZJQm+ce7cOTO9py4aKadNm2Z+Pnr0qNmv10+vY8Zp0EaNGmWyjE4PyjRoFtL576pXr26CkU4Hs337dve+u+66y/wfrKf33nvPVadOHXO8TiOyevVqH7Qaubl+kZGR5j8OGRf9DzoKz79BTwTgwnkNt23bZqaQ1OClU6L97W9/M9Pbwf+v3+XLl10TJkwwobdkyZKuatWquZ544gnXmTNnfNR6bN68Ocv/b3Oumz7qdcz4nGbNmplrrv8GFyxY4JO2B+j/FHy/MwAAAOAb1AADAADAKgRgAAAAWIUADAAAAKsQgAEAAGAVAjAAAACsQgAGAACAVQjAAAAAsAoBGAAAAFYhAAMAAMAqBGAAuAH9+/eX++67L9P2LVu2SEBAgJw9e1YKi/T0dJk6darUq1dPSpUqJRUrVpSoqCh56623fN00AMhTgXn7cgAAf5eamipBQUGZtk+cOFH+8Y9/yBtvvCEtW7aU5ORk+eqrr+TMmTMF3hYAyE/0AANAAZgwYYI0a9bMa9uMGTOkRo0amXqTJ0+eLOHh4VK+fHmZNGmSpKWlyahRo0yPbNWqVWXBggVer7N37165++67Ta9tpUqVZPDgwXL+/PlMr/u3v/1NqlSpInXr1s2yjR9++KE88cQT8qc//Ulq1qwpTZs2lYEDB8qzzz7rPubKlSsSGxsrtWvXluDgYKlevbp53Rtty3/+8x958MEHzXvW99mjRw85cuSIV4/6HXfcIWXKlDHHtG7dWo4ePZrLqwHAdgRgAPAjmzZtkuPHj8vWrVtl2rRpMn78ePnjH/8oFSpUkB07dsjQoUNlyJAhcuzYMXP8hQsXJDo62uz/8ssvZfny5bJhwwYZPny41+tu3LhR4uPjZf369bJq1aoszx0REWHOf+rUqWzbN2bMGFMm8cILL8h3330nS5YsMWH9Rtpy+fJl87yyZcvKZ599Jl988YXcdNNNcs8995geYv0DQEPzXXfdJXv27JG4uDgTrLXEBAByxQUAyLWYmBhX8eLFXWXKlPFaSpYs6dL/xJ45c8YcN378eFfTpk29njt9+nRXZGSk12vpenp6untb3bp1XW3btnWvp6Wlmdd/9913zfrcuXNdFSpUcJ0/f959zOrVq13FihVzJSUluV83PDzclZKSctX3sn//flf9+vXNcxs3buwaMmSIa82aNe79ycnJruDgYNc///nPLJ+f27b861//Mu/zypUr7m26v1SpUq5PPvnE9csvv5jPcsuWLVdtPwDkFD3AAHCDOnToILt37/ZacjtwrGHDhlKs2P/9p1l7Vxs3buxeL168uCktOHnypFn//vvvTamClgY4tDxASxW0l9Whr3GtWtsGDRrIvn37ZPv27TJgwABzju7du8tjjz3mPldKSop07Ngxy+fnti3ffvutHDx40PQAa8+vLloGcenSJTl06JD5WUsntJdY2zNz5kxJTEzM8WcKABkxCA4AbpAGPq2J9eSUKDg01Lpc2pH5f/Sr/4xKlCjhta5f82e1TUPl9bYxJ7Sdt99+u1mefvppeeedd6Rv377y/PPPm7revJCxLVoj3KJFC1m8eHGmY2+++WbzqHXPTz31lKxdu1aWLVsmY8eONSUUd955Z560CYBd6AEGgAKgQS4pKckrBGtP8Y2qX7++6UHV+luH1tBqkM1usNv10F5hpa9/6623mhCsNbx52ZbmzZvLgQMHJCwszPwh4bmUK1fOfdxtt91mapC3bdsmjRo1MvXHAJAbBGAAKADt27c3g8t0BgX9Wn/WrFny8ccf3/DrPvroo1KyZEmJiYkx5QubN2+WJ5980vTaOoPTcuqBBx6Q6dOnm8F2OsOCzrwwbNgwqVOnjpkbWM8zevRoee6552TRokXmfWi5xLx5826oLfq80NBQM/ODDoI7fPiwObf2+GpPuq5r8NXBb9qudevWmcCsgRsAcoMADAAFQMPam2++aYKv1snu3LnTa3qx3CpdurR88skncvr0aVO2oCFWa3R1Lt/rpTW2H330kamz1dCrQVaDrwbOwMD/rZjT2R+eeeYZGTdunHlPDz30kLseObdt0efprBc6pVrPnj3N6+r0a1oDHBISYvb/8MMP0qtXL9MunQFCg7nOhgEAuRGgI+Fy9UwAAACgEKIHGAAAAFYhAAMAAMAqBGAAAABYhQAMAAAAqxCAAQAAYBUCMAAAAKxCAAYAAIBVCMAAAACwCgEYAAAAViEAAwAAwCoEYAAAAIhN/j+wLVS5ctgASAAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 800x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualize label distribution in training data\n",
"train_labels = [item[\"labels\"].item() for item in train_dataset]\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"sns.histplot(train_labels, bins=20)\n",
"plt.xlabel(\"Humor Scores\")\n",
"plt.ylabel(\"Frequency\")\n",
"plt.title(\"Training Labels Distribution\")\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Define the CNN model for regression\n",
"class CNN_HumorRegressor(nn.Module):\n",
" def __init__(self, embed_dim, filter_sizes, num_filters, dropout=0.5):\n",
" super(CNN_HumorRegressor, self).__init__()\n",
" self.convs = nn.ModuleList([\n",
" nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(fs, embed_dim)) \n",
" for fs in filter_sizes\n",
" ])\n",
" self.highway = nn.Linear(len(filter_sizes) * num_filters, len(filter_sizes) * num_filters)\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.fc1 = nn.Linear(len(filter_sizes) * num_filters, 256)\n",
" self.fc2 = nn.Linear(256, 128)\n",
" self.fc3 = nn.Linear(128, 1)\n",
"\n",
" def forward(self, x):\n",
" x = x.unsqueeze(1) # [Batch Size, 1, Seq Length, Embed Dim]\n",
" conved = [F.relu(conv(x)).squeeze(3) for conv in self.convs]\n",
" pooled = [F.max_pool1d(c, c.size(2)).squeeze(2) for c in conved]\n",
" cat = torch.cat(pooled, dim=1)\n",
" highway = F.relu(self.highway(cat))\n",
" highway = self.dropout(highway + cat)\n",
" fc_out = F.relu(self.fc1(highway))\n",
" fc_out = F.relu(self.fc2(fc_out))\n",
" return torch.sigmoid(self.fc3(fc_out)) # Sigmoid for range [0, 1]\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Define the weighted MSE loss\n",
"class WeightedMSELoss(nn.Module):\n",
" def __init__(self, weights):\n",
" super(WeightedMSELoss, self).__init__()\n",
" self.weights = weights\n",
"\n",
" def forward(self, inputs, targets):\n",
" weights = self.weights[targets.long()]\n",
" loss = weights * (inputs - targets) ** 2\n",
" return loss.mean()\n",
"\n",
"# Define weights for loss function\n",
"weights = torch.tensor([2.0 if 0.2 <= x <= 0.8 else 1.0 for x in range(2)], dtype=torch.float32)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Define the training function with ReduceLROnPlateau\n",
"def train_model_with_plateau_scheduler(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs, device, patience=3):\n",
" train_losses = []\n",
" val_losses = []\n",
" best_val_loss = float('inf')\n",
" patience_counter = 0\n",
"\n",
" for epoch in range(epochs):\n",
" model.train()\n",
" total_loss = 0\n",
"\n",
" # Training phase\n",
" with tqdm(train_loader, unit=\"batch\", desc=f\"Epoch {epoch+1}/{epochs}\") as tepoch:\n",
" for inputs, labels in tepoch:\n",
" inputs, labels = inputs.to(device), labels.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" total_loss += loss.item()\n",
" tepoch.set_postfix(loss=loss.item())\n",
"\n",
" avg_train_loss = total_loss / len(train_loader)\n",
" train_losses.append(avg_train_loss)\n",
"\n",
" # Validation phase\n",
" val_loss, val_r2, val_mae = evaluate_with_metrics(model, val_loader, criterion, device)\n",
" val_losses.append(val_loss)\n",
"\n",
" print(f\"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f} - Val Loss: {val_loss:.4f}\")\n",
" print(f\"Validation R²: {val_r2:.4f} | Validation MAE: {val_mae:.4f}\")\n",
"\n",
" # Scheduler step\n",
" scheduler.step(val_loss)\n",
"\n",
" # Early stopping logic\n",
" if val_loss < best_val_loss:\n",
" best_val_loss = val_loss\n",
" patience_counter = 0\n",
" torch.save(model.state_dict(), \"best_model.pt\") # Save best model\n",
" else:\n",
" patience_counter += 1\n",
" print(f\"No improvement for {patience_counter} epoch(s).\")\n",
"\n",
" if patience_counter >= patience:\n",
" print(\"Early stopping triggered.\")\n",
" break\n",
"\n",
" # Load best model after training\n",
" model.load_state_dict(torch.load(\"best_model.pt\"))\n",
"\n",
"# Evaluation function with metrics\n",
"def evaluate_with_metrics(model, data_loader, criterion, device):\n",
" model.eval()\n",
" total_loss = 0\n",
" predictions, actuals = [], []\n",
"\n",
" with torch.no_grad():\n",
" for inputs, labels in data_loader:\n",
" inputs, labels = inputs.to(device), labels.to(device)\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, labels)\n",
" total_loss += loss.item()\n",
" predictions.extend(outputs.cpu().numpy().flatten())\n",
" actuals.extend(labels.cpu().numpy().flatten())\n",
"\n",
" avg_loss = total_loss / len(data_loader)\n",
" r2 = r2_score(actuals, predictions)\n",
" mae = mean_absolute_error(actuals, predictions)\n",
"\n",
" return avg_loss, r2, mae\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/michellegoppinger/.pyenv/versions/3.12.3/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.\n",
" warnings.warn(\n",
"Epoch 1/10: 100%|██████████| 124/124 [00:31<00:00, 3.98batch/s, loss=0.22] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10 - Train Loss: 0.2443 - Val Loss: 0.2275\n",
"Validation R²: 0.0946 | Validation MAE: 0.4442\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 2/10: 100%|██████████| 124/124 [00:30<00:00, 4.12batch/s, loss=0.267]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/10 - Train Loss: 0.2150 - Val Loss: 0.2126\n",
"Validation R²: 0.1520 | Validation MAE: 0.4143\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 3/10: 100%|██████████| 124/124 [00:30<00:00, 4.13batch/s, loss=0.12] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/10 - Train Loss: 0.1805 - Val Loss: 0.2393\n",
"Validation R²: 0.0442 | Validation MAE: 0.3811\n",
"No improvement for 1 epoch(s).\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 4/10: 100%|██████████| 124/124 [00:30<00:00, 4.11batch/s, loss=0.119] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/10 - Train Loss: 0.1306 - Val Loss: 0.2551\n",
"Validation R²: -0.0116 | Validation MAE: 0.3799\n",
"No improvement for 2 epoch(s).\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5/10: 100%|██████████| 124/124 [00:30<00:00, 4.08batch/s, loss=0.0157]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/10 - Train Loss: 0.0840 - Val Loss: 0.2769\n",
"Validation R²: -0.0851 | Validation MAE: 0.3798\n",
"No improvement for 3 epoch(s).\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 6/10: 100%|██████████| 124/124 [00:30<00:00, 4.12batch/s, loss=0.00121]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6/10 - Train Loss: 0.0412 - Val Loss: 0.2997\n",
"Validation R²: -0.1832 | Validation MAE: 0.3758\n",
"No improvement for 4 epoch(s).\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 7/10: 100%|██████████| 124/124 [00:30<00:00, 4.12batch/s, loss=0.11] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7/10 - Train Loss: 0.0245 - Val Loss: 0.2891\n",
"Validation R²: -0.1477 | Validation MAE: 0.3619\n",
"No improvement for 5 epoch(s).\n",
"Early stopping triggered.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/l7/061cw0t95vz1myntpf9bj9540000gn/T/ipykernel_16242/4163769425.py:53: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" model.load_state_dict(torch.load(\"best_model.pt\"))\n"
]
}
],
"source": [
"# Hyperparameters\n",
"EMBED_DIM = train_dataset[0][\"input_ids\"].shape[1]\n",
"FILTER_SIZES = [2, 3, 4, 5]\n",
"NUM_FILTERS = 300\n",
"DROPOUT = 0.5\n",
"LR = 0.001\n",
"EPOCHS = 10\n",
"\n",
"device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Initialize model, loss, optimizer, and scheduler\n",
"model = CNN_HumorRegressor(EMBED_DIM, FILTER_SIZES, NUM_FILTERS, DROPOUT).to(device)\n",
"criterion = WeightedMSELoss(weights.to(device))\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n",
"scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)\n",
"\n",
"# Train the model\n",
"train_model_with_plateau_scheduler(model, train_loader, val_loader, criterion, optimizer, scheduler, EPOCHS, device, patience=5)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Set Metrics:\n",
"Test Loss (MSE): 0.2196\n",
"Test R²: 0.1218\n",
"Test MAE: 0.4207\n"
]
}
],
"source": [
"# Evaluate the model on test set\n",
"test_loss, test_r2, test_mae = evaluate_with_metrics(model, test_loader, criterion, device)\n",
"print(\"Test Set Metrics:\")\n",
"print(f\"Test Loss (MSE): {test_loss:.4f}\")\n",
"print(f\"Test R²: {test_r2:.4f}\")\n",
"print(f\"Test MAE: {test_mae:.4f}\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}