ANLP_WS24_CA2/Datasets.py

"""
This file contains the Datasets class.
"""
import torch
from nltk.tokenize import word_tokenize
from torch.utils.data import Dataset
from transformers import AutoTokenizer
class GloveDataset(Dataset):
def __init__(self, texts, labels, word_index, max_len=50):
self.original_indices = labels.index.to_list()
self.texts = texts.reset_index(drop=True)
self.labels = labels.reset_index(drop=True)
self.word_index = word_index
self.max_len = max_len
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
texts = self.texts[idx]
tokens = word_tokenize(texts.lower())
label = self.labels[idx]
# Tokenize and convert to indices
input_ids = [self.word_index.get(word, self.word_index['<UNK>']) for word in tokens]
# Pad or truncate to max_len
if len(input_ids) < self.max_len:
input_ids += [self.word_index['<PAD>']] * (self.max_len - len(input_ids))
else:
input_ids = input_ids[:self.max_len]
# Convert to PyTorch tensors
input_ids = torch.tensor(input_ids, dtype=torch.long)
label = torch.tensor(label, dtype=torch.float)
return input_ids, label


class BertDataset(Dataset):
    """Tokenizes raw texts on the fly with a Hugging Face tokenizer."""

    def __init__(self, tokenizer: AutoTokenizer, texts, labels, max_len: int = 128):
        super(BertDataset, self).__init__()
        self.tokenizer = tokenizer
        self.max_length = max_len
        self.text = texts.to_numpy()
        self.labels = labels.to_numpy()

    def __getitem__(self, idx: int):
        text = self.text[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            padding="max_length",
            return_attention_mask=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )
        # The tokenizer returns tensors of shape (1, max_length); flatten to (max_length,)
        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()
        return {
            'input_ids': torch.as_tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.as_tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(label, dtype=torch.float)
        }

    def __len__(self):
        return len(self.labels)
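

# Illustrative usage sketch (an assumption, not part of the original training
# pipeline): it assumes `texts` and `labels` are pandas Series, that the GloVe
# vocabulary dict already contains '<PAD>' and '<UNK>' entries, that NLTK's
# 'punkt' tokenizer data has been downloaded, and that 'bert-base-uncased' is
# a reasonable stand-in for the tokenizer actually used.
if __name__ == "__main__":
    import pandas as pd
    from torch.utils.data import DataLoader

    # Hypothetical toy data
    texts = pd.Series(["a short example sentence", "another example"])
    labels = pd.Series([1.0, 0.0])

    # GloVe-style dataset: word_index would normally be built from the GloVe vocabulary
    word_index = {'<PAD>': 0, '<UNK>': 1, 'a': 2, 'short': 3, 'example': 4,
                  'sentence': 5, 'another': 6}
    glove_ds = GloveDataset(texts, labels, word_index, max_len=10)
    glove_ids, glove_labels = next(iter(DataLoader(glove_ds, batch_size=2)))
    print(glove_ids.shape, glove_labels.shape)  # torch.Size([2, 10]) torch.Size([2])

    # BERT-style dataset: any Hugging Face tokenizer checkpoint works here
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    bert_ds = BertDataset(tokenizer, texts, labels, max_len=16)
    batch = next(iter(DataLoader(bert_ds, batch_size=2)))
    print(batch['input_ids'].shape, batch['labels'].shape)  # torch.Size([2, 16]) torch.Size([2])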