"""
This file contains the HumorDataset class.
"""
import torch
import numpy as np
from nltk.tokenize import word_tokenize
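
# Note: word_tokenize relies on the NLTK "punkt" tokenizer models; if they are
# not installed yet, a one-off nltk.download('punkt') is required.
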
class TextRegDataset(torch.utils.data.Dataset):
    """Token-index dataset with continuous (float) labels for regression."""

    def __init__(self, texts, labels, word_index, max_len=50):
        self.original_indices = labels.index.to_list()
        self.texts = texts.reset_index(drop=True)
        self.labels = labels.reset_index(drop=True)
        self.word_index = word_index
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = word_tokenize(text.lower())
        label = self.labels[idx]
        # Map tokens to vocabulary indices; unknown words fall back to <UNK>
        input_ids = [self.word_index.get(word, self.word_index['<UNK>']) for word in tokens]
        # Pad or truncate to max_len
        if len(input_ids) < self.max_len:
            input_ids += [self.word_index['<PAD>']] * (self.max_len - len(input_ids))
        else:
            input_ids = input_ids[:self.max_len]
        # Convert to PyTorch tensors
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        label = torch.tensor(label, dtype=torch.float)
        return input_ids, label
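

# Illustrative usage sketch (not part of the original module): feeding
# TextRegDataset to a DataLoader. The texts, scores, and word_index below are
# hypothetical; a real word_index only needs '<PAD>' and '<UNK>' entries, and
# texts/labels are expected to be pandas Series.
def _example_textreg_usage():
    import pandas as pd
    from torch.utils.data import DataLoader

    texts = pd.Series(["a short joke", "another example sentence"])
    scores = pd.Series([1.2, 0.4])  # continuous ratings (made up for illustration)
    word_index = {'<PAD>': 0, '<UNK>': 1, 'a': 2, 'short': 3, 'joke': 4}

    dataset = TextRegDataset(texts, scores, word_index, max_len=10)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    for input_ids, labels in loader:
        # input_ids: LongTensor (batch, max_len); labels: FloatTensor (batch,)
        print(input_ids.shape, labels.shape)
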
class TextDataset(torch.utils.data.Dataset):
    """Token-index dataset with integer class labels for classification."""

    def __init__(self, texts, labels, word_index, max_len=50):
        self.original_indices = labels.index.to_list()
        self.texts = texts.reset_index(drop=True)
        self.labels = labels.reset_index(drop=True)
        self.word_index = word_index
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = word_tokenize(text.lower())
        label = self.labels[idx]
        # Map tokens to vocabulary indices; unknown words fall back to <UNK>
        input_ids = [self.word_index.get(word, self.word_index['<UNK>']) for word in tokens]
        # Pad or truncate to max_len
        if len(input_ids) < self.max_len:
            input_ids += [self.word_index['<PAD>']] * (self.max_len - len(input_ids))
        else:
            input_ids = input_ids[:self.max_len]
        # Convert to PyTorch tensors
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        label = torch.tensor(label, dtype=torch.long)
        return input_ids, label
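

# Illustrative usage sketch (not part of the original module): TextDataset is
# the classification counterpart of TextRegDataset, returning integer labels.
# The sample data and vocabulary below are hypothetical.
def _example_text_classification_usage():
    import pandas as pd
    from torch.utils.data import DataLoader

    texts = pd.Series(["this one is funny", "this one is not"])
    classes = pd.Series([1, 0])  # binary labels (made up for illustration)
    word_index = {'<PAD>': 0, '<UNK>': 1, 'this': 2, 'one': 3, 'is': 4}

    dataset = TextDataset(texts, classes, word_index, max_len=10)
    loader = DataLoader(dataset, batch_size=2)
    for input_ids, labels in loader:
        # labels is a LongTensor, ready for e.g. nn.CrossEntropyLoss
        print(input_ids.dtype, labels.dtype)
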
class HumorDataset(torch.utils.data.Dataset):
    """Dataset wrapping precomputed numeric features and float labels."""

    def __init__(self, data, labels, vocab_size=0, emb_dim=None):
        self.original_indices = labels.index.to_list()
        self.data = data
        self.labels = labels.reset_index(drop=True)
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        # TODO: bug fix (get_single_shape returns None for pandas objects such as the labels Series)
        self.shape = self.get_shape()

    def __getitem__(self, idx):
        item = {'input_ids': torch.tensor(self.data[idx], dtype=torch.float)}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

    def __len__(self):
        return len(self.labels)

    def get_single_shape(self, data):
        # Shape of a single example; None if the container type is not recognised
        shape_data = None
        if isinstance(data, list):
            shape_data = len(data[0])
        elif isinstance(data, (torch.Tensor, np.ndarray)):
            shape_data = data[0].shape
        return shape_data

    def get_shape(self):
        shape_data = self.get_single_shape(self.data)
        shape_labels = self.get_single_shape(self.labels)
        return shape_data, shape_labels
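

# Illustrative usage sketch (not part of the original module): HumorDataset
# wraps per-example numeric features (e.g. precomputed embeddings) and a
# pandas Series of float labels. The random feature matrix is hypothetical.
def _example_humor_dataset_usage():
    import pandas as pd
    from torch.utils.data import DataLoader

    features = np.random.rand(4, 8).astype(np.float32)  # 4 examples, 8 features each
    scores = pd.Series([0.1, 0.7, 0.3, 0.9])

    dataset = HumorDataset(features, scores, emb_dim=8)
    print(dataset.shape)  # (per-example feature shape, None for the pandas labels)
    loader = DataLoader(dataset, batch_size=2)
    for batch in loader:
        print(batch['input_ids'].shape, batch['labels'].shape)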