""" This file contains the HumorDataset class. """ import torch import numpy as np from nltk.tokenize import word_tokenize class TextDataset(torch.utils.data.Dataset): def __init__(self, texts, labels, word_index, max_len=50): self.original_indices = labels.index.to_list() self.texts = texts.reset_index(drop=True) self.labels = labels.reset_index(drop=True) self.word_index = word_index self.max_len = max_len def __len__(self): return len(self.texts) def __getitem__(self, idx): texts = self.texts[idx] tokens = word_tokenize(texts.lower()) label = self.labels[idx] # Tokenize and convert to indices input_ids = [self.word_index.get(word, self.word_index['']) for word in tokens] # Pad or truncate to max_len if len(input_ids) < self.max_len: input_ids += [self.word_index['']] * (self.max_len - len(input_ids)) else: input_ids = input_ids[:self.max_len] # Convert to PyTorch tensors input_ids = torch.tensor(input_ids, dtype=torch.long) label = torch.tensor(label, dtype=torch.long) return input_ids, label class HumorDataset(torch.utils.data.Dataset): def __init__(self, data, labels, vocab_size=0, emb_dim=None): self.original_indices = labels.index.to_list() self.data = data self.labels = labels.reset_index(drop=True) self.vocab_size = vocab_size self.emb_dim = emb_dim # TODO: bug fix self.shape = self.get_shape() def __getitem__(self, idx): item = {'input_ids': torch.tensor(self.data[idx], dtype=torch.float)} item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float) return item def __len__(self): return len(self.labels) def get_single_shape(self, data): shape_data = None if type(data) == list: shape_data = len(data[0]) elif type(data) == torch.Tensor: shape_data = data[0].shape elif type(data) == np.ndarray: shape_data = data[0].shape return shape_data def get_shape(self): shape_data = self.get_single_shape(self.data) shape_labels = self.get_single_shape(self.labels) return shape_data, shape_labels