# Ruby/core/dataset.py
import os
import torch
from torch.utils.data import Dataset
class CharDataset(Dataset):
    """Char-level next-character-prediction dataset.

    Reads every ``.txt`` file directly under ``books_dir``, concatenates
    their contents (joined with newlines), builds a character vocabulary,
    and serves ``(x, y)`` pairs where ``y`` is ``x`` shifted one position
    to the right.

    Attributes:
        stoi: dict mapping character -> integer id.
        itos: dict mapping integer id -> character.
        vocab_size: number of distinct characters in the corpus.
        data: 1-D ``torch.long`` tensor of the encoded corpus.
        block_size: sequence length of each sample.
    """

    def __init__(self, books_dir: str, block_size: int):
        """Build the dataset from all .txt files under *books_dir*.

        Args:
            books_dir: directory containing the .txt corpus files.
            block_size: length of each training sequence.

        Raises:
            FileNotFoundError: if *books_dir* does not exist.
        """
        texts: list[str] = []
        # sorted() makes the concatenation order -- and hence the exact
        # encoded stream and vocab indices -- deterministic; bare
        # os.listdir() order is filesystem-dependent.
        for fn in sorted(os.listdir(books_dir)):
            if fn.lower().endswith('.txt'):
                path = os.path.join(books_dir, fn)
                with open(path, 'r', encoding='utf8') as f:
                    texts.append(f.read())
        data = '\n'.join(texts)

        # Build the vocabulary from the characters actually present.
        chars = sorted(set(data))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(self.stoi)

        # Encode the whole corpus once as a single flat tensor.
        self.data = torch.tensor(
            [self.stoi[ch] for ch in data],
            dtype=torch.long,
        )
        self.block_size = block_size

    def __len__(self) -> int:
        # Clamp to 0 so a corpus shorter than block_size yields an empty
        # dataset instead of a negative length (which breaks DataLoader).
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx: int):
        """Return ``(x, y)``: a block and the same block shifted by one."""
        x = self.data[idx: idx + self.block_size]
        y = self.data[idx + 1: idx + 1 + self.block_size]
        return x, y