ANLP_WS24_CA2/BertFine.ipynb

418 lines
82 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Notebook: Fine-Tuning von BERT\n",
"In diesem Notebook wird BERT, genauer `BertForSequenceClassification`, feinabgestimmt (Fine-Tuning). <br>\n",
"Die verwendeten Hilfsfunktionen werden aus dem [Skript](bert_no_ernie.py) geladen."
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"import pandas as pd\n",
"\n",
"from bert_no_ernie import *\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Rohdaten einlesen\n",
"An dieser Stelle wird der Hackathon-Datensatz eingelesen, der annotierte Daten enthält.\n",
"Die wichtigsten Attribute dieses Datensatzes sind *text* (enthält den \"Witz\" als String) und *is_humor* (ein durch 0 und 1 dargestellter Wahrheitswert), der angibt, ob der Text in der jeweiligen Zeile ein Witz ist oder nicht."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>text</th>\n",
" <th>is_humor</th>\n",
" <th>humor_rating</th>\n",
" <th>humor_controversy</th>\n",
" <th>offense_rating</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>TENNESSEE: We're the best state. Nobody even c...</td>\n",
" <td>1</td>\n",
" <td>2.42</td>\n",
" <td>1.0</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>A man inserted an advertisement in the classif...</td>\n",
" <td>1</td>\n",
" <td>2.50</td>\n",
" <td>1.0</td>\n",
" <td>1.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>How many men does it take to open a can of bee...</td>\n",
" <td>1</td>\n",
" <td>1.95</td>\n",
" <td>0.0</td>\n",
" <td>2.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>Told my mom I hit 1200 Twitter followers. She ...</td>\n",
" <td>1</td>\n",
" <td>2.11</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>Roses are dead. Love is fake. Weddings are bas...</td>\n",
" <td>1</td>\n",
" <td>2.78</td>\n",
" <td>0.0</td>\n",
" <td>0.1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id text is_humor \\\n",
"0 1 TENNESSEE: We're the best state. Nobody even c... 1 \n",
"1 2 A man inserted an advertisement in the classif... 1 \n",
"2 3 How many men does it take to open a can of bee... 1 \n",
"3 4 Told my mom I hit 1200 Twitter followers. She ... 1 \n",
"4 5 Roses are dead. Love is fake. Weddings are bas... 1 \n",
"\n",
" humor_rating humor_controversy offense_rating \n",
"0 2.42 1.0 0.2 \n",
"1 2.50 1.0 1.1 \n",
"2 1.95 0.0 2.4 \n",
"3 2.11 1.0 0.0 \n",
"4 2.78 0.0 0.1 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the annotated hackathon dataset and preview its first rows.\n",
"df = pd.read_csv(filepath_or_buffer=\"data/hack.csv\")\n",
"df.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Hyperparameters and random seeds, kept in one place so runs are reproducible.\n",
"# Maximum number of training epochs\n",
"EPOCH = 10\n",
"# Dropout probability passed to CustomBert\n",
"DROPOUT = 0.1\n",
"# Batch size passed to create_dataloaders for every split\n",
"BATCH_SIZE = 16\n",
"# Learning rate of the Adam optimizer\n",
"LEARNING_RATE = 1e-5\n",
"# Seed shared by every random number generator below\n",
"RNDM_SEED = 501\n",
"# Passed to training_loop as freeze_bert (presumably freezes the pretrained BERT weights; verify in bert_no_ernie.py)\n",
"FREEZE = True\n",
"\n",
"# Seed every RNG that may be used: Python's `random`, NumPy, and PyTorch (CPU and all GPUs).\n",
"random.seed(RNDM_SEED)\n",
"np.random.seed(RNDM_SEED)\n",
"torch.manual_seed(RNDM_SEED)\n",
"torch.cuda.manual_seed_all(RNDM_SEED)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Load the tokenizer for the bert-base-uncased checkpoint.\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
"    pretrained_model_name_or_path=\"google-bert/bert-base-uncased\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Split the data 70/15/15 and wrap each split in the project's custom Dataset class.\n",
"# NOTE(review): the meaning of the trailing `True` flag is defined in create_datasets (bert_no_ernie.py) - confirm.\n",
"train_data, test_data, val_data = create_datasets(tokenizer, df, 0.7, True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Build one DataLoader per split (order: train, test, validation).\n",
"# Only the training data is shuffled (assumes shufflelist is matched positionally to the dataset list).\n",
"# Evaluation data needs no shuffling, and a fixed order keeps the per-batch test accuracies\n",
"# independent of how much RNG state was consumed during training.\n",
"train_loader, test_loader, validation_loader = create_dataloaders(\n",
"    [train_data, test_data, val_data],\n",
"    batchsize=BATCH_SIZE,\n",
"    shufflelist=[True, False, False],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"# Set up the loss, the model (with the configured dropout, moved to DEVICE)\n",
"# and an Adam optimizer over all model parameters.\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"mybert = CustomBert(DROPOUT)\n",
"mybert.to(DEVICE)\n",
"optimizer = optim.Adam(params=mybert.parameters(), lr=LEARNING_RATE)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For 1 the Scores are: \n",
"Training Loss is 0.6827\n",
"Validation Loss: 0.6828 ### Validation Accuracy 60.8333%\n",
"For 2 the Scores are: \n",
"Training Loss is 0.6836\n",
"Validation Loss: 0.6825 ### Validation Accuracy 60.8333%\n",
"For 3 the Scores are: \n",
"Training Loss is 0.6824\n",
"Validation Loss: 0.6821 ### Validation Accuracy 60.8333%\n",
"For 4 the Scores are: \n",
"Training Loss is 0.6815\n",
"Validation Loss: 0.6817 ### Validation Accuracy 60.8333%\n",
"For 5 the Scores are: \n",
"Training Loss is 0.6808\n",
"Validation Loss: 0.6814 ### Validation Accuracy 60.8333%\n",
"For 6 the Scores are: \n",
"Training Loss is 0.6809\n",
"Validation Loss: 0.6810 ### Validation Accuracy 60.8333%\n",
"For 7 the Scores are: \n",
"Training Loss is 0.6801\n",
"Validation Loss: 0.6807 ### Validation Accuracy 60.7500%\n",
"For 8 the Scores are: \n",
"Training Loss is 0.6795\n",
"Validation Loss: 0.6804 ### Validation Accuracy 60.7500%\n",
"For 9 the Scores are: \n",
"Training Loss is 0.6797\n",
"Validation Loss: 0.6801 ### Validation Accuracy 60.7500%\n",
"For 10 the Scores are: \n",
"Training Loss is 0.6793\n",
"Validation Loss: 0.6799 ### Validation Accuracy 60.7500%\n"
]
}
],
"source": [
"# Train for EPOCH epochs; after each epoch store the training loss and the validation loss.\n",
"loss_vals = np.zeros(EPOCH)\n",
"eval_vals = np.zeros(EPOCH)\n",
"\n",
"for epoch in range(EPOCH):\n",
"    print(f\"For {epoch+1} the Scores are: \")\n",
"    loss_vals[epoch] = training_loop(\n",
"        mybert,\n",
"        optimizer=optimizer,\n",
"        criterion=criterion,\n",
"        train_loader=train_loader,\n",
"        freeze_bert=FREEZE,\n",
"    )\n",
"    eval_vals[epoch] = eval_loop(\n",
"        mybert,\n",
"        criterion=criterion,\n",
"        validation_loader=validation_loader,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.68267711, 0.68355761, 0.68237029, 0.68148399, 0.68079539,\n",
" 0.68086683, 0.68012043, 0.67948493, 0.67972843, 0.67932365])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([0.68283186, 0.68245001, 0.68208028, 0.68170239, 0.68136094,\n",
" 0.68103237, 0.68071597, 0.68041458, 0.68011246, 0.67985092])"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Per-epoch training losses, followed by per-epoch validation losses.\n",
"display(loss_vals, eval_vals)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"def test_loop(model:CustomBert, test_loader:DataLoader):\n",
"    \"\"\"Compute the classification accuracy of `model` on every batch of `test_loader`.\n",
"\n",
"    The model is switched to evaluation mode for the duration of the loop and its\n",
"    previous training/eval mode is restored afterwards.\n",
"\n",
"    Args:\n",
"        model: Classifier called as model(input_ids, attention_mask); the argmax over\n",
"            dim 1 of its output is taken as the predicted class.\n",
"        test_loader: DataLoader whose batches are mappings holding input ids,\n",
"            attention mask and labels, in that insertion order.\n",
"\n",
"    Returns:\n",
"        np.ndarray with one accuracy value (0.0 - 1.0) per batch.\n",
"    \"\"\"\n",
"    test_accuracy = np.zeros(len(test_loader))\n",
"    # Evaluation mode disables dropout, so test predictions are deterministic.\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    try:\n",
"        with torch.no_grad():\n",
"            for index, batch in enumerate(test_loader):\n",
"                # Relies on the batch mapping's order: ids, attention mask, labels.\n",
"                input_ids, att_mask, labels = (tensor.to(DEVICE) for tensor in batch.values())\n",
"                output = model(input_ids, att_mask)\n",
"                pred_flat = output.argmax(dim=1).flatten().cpu()\n",
"                test_accuracy[index] = accuracy_score(labels.cpu(), pred_flat)\n",
"    finally:\n",
"        # Restore the caller's mode so a later training run is unaffected.\n",
"        model.train(was_training)\n",
"\n",
"    return test_accuracy"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.6875, 0.5625, 0.75 , 0.625 , 0.625 , 0.75 , 0.6875, 0.5 ,\n",
" 0.375 , 0.1875, 0.4375, 0.75 , 0.75 , 0.8125, 0.5 , 0.5 ,\n",
" 0.8125, 0.5 , 0.8125, 0.625 , 0.5625, 0.4375, 0.5625, 0.8125,\n",
" 0.6875, 0.8125, 0.625 , 0.6875, 0.5625, 0.75 , 0.8125, 0.8125,\n",
" 0.75 , 0.5 , 0.625 , 0.6875, 0.6875, 0.5 , 0.625 , 0.5625,\n",
" 0.625 , 0.4375, 0.6875, 0.75 , 0.6875, 0.1875, 0.625 , 0.5 ,\n",
" 0.875 , 0.625 , 0.625 , 0.4375, 0.5625, 0.6875, 0.6875, 0.625 ,\n",
" 0.375 , 0.4375, 0.6875, 0.6875, 0.5625, 0.4375, 0.5 , 0.5625,\n",
" 0.6875, 0.5625, 0.4375, 0.8125, 0.75 , 0.75 , 0.625 , 0.6875,\n",
" 0.5625, 0.9375, 0.5625])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Per-batch accuracy of the trained model on the held-out test split.\n",
"test_acc_score = test_loop(model=mybert, test_loader=test_loader)\n",
"test_acc_score"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"def plot_test_metrics(accuracy):\n",
"    \"\"\"Plot the per-batch test accuracy together with a dotted line at its mean.\n",
"\n",
"    Args:\n",
"        accuracy: Sequence of accuracy values (0.0 - 1.0), one per test batch,\n",
"            e.g. the array returned by test_loop.\n",
"\n",
"    Raises:\n",
"        ValueError: If `accuracy` is empty (nothing to plot or average).\n",
"    \"\"\"\n",
"    accuracy = np.asarray(accuracy, dtype=float)\n",
"    if accuracy.size == 0:\n",
"        raise ValueError(\"accuracy must contain at least one value\")\n",
"    mean_accuracy = accuracy.mean()\n",
"\n",
"    _, ax = plt.subplots()\n",
"    ax.plot(accuracy)\n",
"    # The curve is drawn at x = 0 .. len-1, so the mean line spans exactly that range.\n",
"    ax.hlines(mean_accuracy, 0, len(accuracy) - 1, colors='red', linestyles='dotted',\n",
"              label='Mean Accuracy {:.4f}'.format(mean_accuracy))\n",
"    ax.set_title(\"Accuracy of Test\")\n",
"    ax.set_xlabel(\"Batch index\")\n",
"    ax.set_ylabel(\"Accuracy (0.0 - 1.0)\")\n",
"    ax.grid(True)\n",
"    ax.legend()\n",
"    plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualise the per-batch test accuracy and its mean.\n",
"plot_test_metrics(accuracy=test_acc_score)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}