Nora/data_loader.py

"""
data_loader.py
Loads all .txt files from data_dir (and, when present, a sibling
"conversational" folder), concatenates them, tokenizes the result, and builds
a Dataset of (input_seq, target_seq) pairs for language modeling.
"""
import os
import torch
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    def __init__(self, data_dir: str, tokenizer, seq_length: int):
        """
        - data_dir: folder of .txt public-domain books.
        - tokenizer: instance of CharTokenizer (from tokenizer.py).
        - seq_length: context length in tokens.
        """
        super().__init__()
        self.seq_length = seq_length
        self.tokenizer = tokenizer

        # Read and concatenate all .txt files under two folders:
        #   - data/books/          (data_dir itself)
        #   - data/conversational/ (sibling of data_dir, used only if it exists)
        texts = []
        # Even when only data_dir is passed, look for a sibling "conversational" folder.
        conversational_dir = os.path.join(os.path.dirname(data_dir), "conversational")
        # Gather all folders to walk.
        dirs_to_walk = [data_dir]
        if os.path.isdir(conversational_dir):
            dirs_to_walk.append(conversational_dir)
        for d in dirs_to_walk:
            for root, _, files in os.walk(d):
                for fname in files:
                    if not fname.lower().endswith(".txt"):
                        continue
                    path = os.path.join(root, fname)
                    with open(path, "r", encoding="utf-8", errors="ignore") as f:
                        texts.append(f.read())
        full_text = "\n".join(texts)
        token_ids = self.tokenizer.encode(full_text)

        # Slice the token stream into (input, target) windows; the target is the
        # input shifted one token ahead, as required for next-token prediction.
        self.examples = []
        stride = 32
        for i in range(0, len(token_ids) - seq_length, stride):
            inp = token_ids[i : i + seq_length]
            targ = token_ids[i + 1 : i + seq_length + 1]
            self.examples.append((inp, targ))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        inp, targ = self.examples[idx]
        return torch.tensor(inp, dtype=torch.long), torch.tensor(targ, dtype=torch.long)
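
# Illustration of the windowing above (values invented for clarity): if
# token_ids = [5, 9, 2, 7, 1, ...] and seq_length = 3, the window at i = 0
# yields inp = [5, 9, 2] and targ = [9, 2, 7], i.e. the target is the input
# shifted one position so every position predicts the next token. With
# stride = 32 the next window starts 32 tokens later; a stride smaller than
# seq_length would produce overlapping windows and more training examples.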


def get_dataloader(
    data_dir: str, tokenizer, seq_length: int, batch_size: int, shuffle: bool = True
) -> DataLoader:
    """Builds a DataLoader of fixed-size (input, target) batches over TextDataset."""
    dataset = TextDataset(data_dir, tokenizer, seq_length)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,    # drop the ragged final batch so every batch has the same shape
        num_workers=8,     # tune to the number of available CPU cores
        pin_memory=True,   # faster host-to-GPU transfer when training on CUDA
    )
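

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes that
    # CharTokenizer (from tokenizer.py) can be constructed directly from raw
    # text and exposes encode(); the real constructor may differ. Paths and
    # hyperparameters below are placeholders.
    from tokenizer import CharTokenizer

    corpus_dir = "data/books"  # placeholder path matching the folder layout noted above
    raw_texts = []
    for root, _, files in os.walk(corpus_dir):
        for fname in files:
            if fname.lower().endswith(".txt"):
                with open(os.path.join(root, fname), encoding="utf-8", errors="ignore") as f:
                    raw_texts.append(f.read())

    tok = CharTokenizer("\n".join(raw_texts))  # assumed: builds its character vocab from raw text
    loader = get_dataloader(corpus_dir, tok, seq_length=128, batch_size=32)
    inputs, targets = next(iter(loader))
    print(inputs.shape, targets.shape)  # expected: torch.Size([32, 128]) for both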