-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathdata.py
More file actions
38 lines (30 loc) · 1.25 KB
/
data.py
File metadata and controls
38 lines (30 loc) · 1.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
class NetflixDataset(Dataset):
    """Tokenized Netflix show descriptions, padded to a common length.

    Each description is wrapped as ``"</s>" + description + "</s>"``,
    tokenized, truncated/padded to ``self.max_length`` (the token length of
    the longest wrapped description) and stored as tensors. Indexing returns
    an ``(input_ids, attention_mask)`` pair; ``netflix_collator`` turns a list
    of such pairs into a training batch.

    Args:
        tokenizer: Object exposing ``encode(text)`` (returning a sequence of
            token ids) and ``__call__(text, truncation=..., max_length=...,
            padding="max_length")`` returning a mapping with ``"input_ids"``
            and ``"attention_mask"`` — presumably a Hugging Face tokenizer;
            not checked here.
        texts: Optional iterable of description strings. When omitted, the
            ``"description"`` column of the ``hugginglearners/netflix-shows``
            train split is downloaded via ``load_dataset`` (original behavior).
    """

    def __init__(self, tokenizer, texts=None):
        super().__init__()
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []
        # Never populated: labels are built from input_ids in netflix_collator.
        # Kept only so code that reads this attribute keeps working.
        self.labels = []

        if texts is None:
            texts = load_dataset("hugginglearners/netflix-shows", split="train")["description"]
        self.txt_list = list(texts)

        wrapped = ["</s>" + txt + "</s>" for txt in self.txt_list]
        # Measure the *wrapped* strings, i.e. exactly what is tokenized below.
        # Measuring the bare descriptions undercounts by the marker tokens, so
        # truncation would silently drop the trailing "</s>" of the longest items.
        # default=0 keeps an empty text list from raising in max().
        self.max_length = max((len(self.tokenizer.encode(text)) for text in wrapped), default=0)

        for text in wrapped:
            encodings = self.tokenizer(text, truncation=True, max_length=self.max_length, padding="max_length")
            self.input_ids.append(torch.tensor(encodings["input_ids"]))
            self.attn_masks.append(torch.tensor(encodings["attention_mask"]))

    def __len__(self):
        """Return the number of tokenized descriptions."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` tensor pair at ``idx``."""
        return self.input_ids[idx], self.attn_masks[idx]
def netflix_collator(data, *, ignore_index=-100):
    """Stack ``(input_ids, attention_mask)`` pairs into a language-model batch.

    Args:
        data: Sequence of ``(input_ids, attention_mask)`` tensor pairs, as
            returned by ``NetflixDataset.__getitem__``. All tensors must have
            the same shape, since they are combined with ``torch.stack``.
        ignore_index: Label value written at padding positions. ``-100`` is
            the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss`` and
            the value Hugging Face causal-LM heads skip when computing loss.

    Returns:
        dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
        ``labels`` is a copy of ``input_ids`` whose positions with
        ``attention_mask == 0`` are set to ``ignore_index``, so padding does
        not contribute to the loss.
    """
    input_ids = torch.stack([ids for ids, _ in data])
    attention_mask = torch.stack([mask for _, mask in data])
    # masked_fill returns a new tensor, so input_ids itself is left untouched.
    labels = input_ids.masked_fill(attention_mask == 0, ignore_index)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }