""" This file contains the Datasets class. """ import torch from nltk.tokenize import word_tokenize from torch.utils.data import Dataset from transformers import AutoTokenizer class GloveDataset(Dataset): def __init__(self, texts, labels, word_index, max_len=50): self.original_indices = labels.index.to_list() self.texts = texts.reset_index(drop=True) self.labels = labels.reset_index(drop=True) self.word_index = word_index self.max_len = max_len def __len__(self): return len(self.texts) def __getitem__(self, idx): texts = self.texts[idx] tokens = word_tokenize(texts.lower()) label = self.labels[idx] # Tokenize and convert to indices input_ids = [self.word_index.get(word, self.word_index['']) for word in tokens] # Pad or truncate to max_len if len(input_ids) < self.max_len: input_ids += [self.word_index['']] * (self.max_len - len(input_ids)) else: input_ids = input_ids[:self.max_len] # Convert to PyTorch tensors input_ids = torch.tensor(input_ids, dtype=torch.long) label = torch.tensor(label, dtype=torch.float) return input_ids, label class BertDataset(Dataset): def __init__(self,tokenizer:AutoTokenizer, texts, labels, max_len:int=128): super(BertDataset,self).__init__() self.tokenizer = tokenizer self.max_length = max_len self.text = texts.to_numpy() self.labels = labels.to_numpy() def __getitem__(self,idx:int): text = self.text[idx] labels = self.labels[idx] encoding = self.tokenizer( text, padding="max_length", return_attention_mask = True, max_length=self.max_length, truncation = True, return_tensors = 'pt' ) input_ids = encoding['input_ids'].flatten() attention_mask = encoding['attention_mask'].flatten() return { 'input_ids': torch.as_tensor(input_ids,dtype=torch.long), 'attention_mask':torch.as_tensor(attention_mask,dtype=torch.long), 'labels':torch.tensor(labels,dtype=torch.float) } def __len__(self): return len(self.labels)