# Ruby/dataset.py
# (21 lines, 522 B, Python)
import torch
from tokenizer import simple_tokenizer, load_vocab
def create_dataset():
    """Build a tiny toy dataset for next-character prediction.

    Returns:
        A pair ``(inputs, targets)`` of parallel lists. Each element is a
        1-D ``torch.long`` tensor of token ids produced by
        ``simple_tokenizer`` under the vocabulary from ``load_vocab()``:
        ``inputs[i]`` encodes a source string and ``targets[i]`` encodes
        the single character that follows it.
    """
    vocab = load_vocab()

    # Hard-coded (source, next-char) pairs.
    pairs = [
        ("a", "b"),
        ("ab", "c"),
        ("abc", "d"),
    ]

    def _encode(text):
        # Token ids as a 1-D LongTensor — same form for sources and targets.
        return torch.tensor(simple_tokenizer(text, vocab), dtype=torch.long)

    inputs = []
    targets = []
    for src, tgt in pairs:
        inputs.append(_encode(src))
        targets.append(_encode(tgt))
    return inputs, targets
if __name__ == "__main__":
    # Smoke test: build the toy dataset and show the encoded tensors.
    encoded_sources, encoded_targets = create_dataset()
    print(encoded_sources)
    print(encoded_targets)