Ruby/test.py

import torch

from model import TinyGPT
from tokenizer import simple_tokenizer, detokenizer, load_vocab


def test_model():
    vocab = load_vocab()

    # Load the trained model; fall back to CPU when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyGPT(vocab_size=len(vocab), embed_size=32, num_heads=2, num_layers=2).to(device)
    # weights_only=True avoids unpickling arbitrary objects from the checkpoint;
    # map_location lets a GPU-trained checkpoint load on CPU.
    model.load_state_dict(torch.load("ruby_model.pth", weights_only=True, map_location=device))
    model.eval()

    # Tokenize a test prompt and run a single forward pass.
    test_input = torch.tensor(simple_tokenizer("abc", vocab), dtype=torch.long).to(device)
    with torch.no_grad():
        # The model takes source and target sequences; here both are the prompt.
        output = model(test_input.unsqueeze(0), test_input.unsqueeze(0))

    # Take the most likely token at the last position and map it back to a character.
    predicted_idx = output.argmax(-1).squeeze()[-1].item()
    predicted_char = detokenizer([predicted_idx], vocab)
    print(f"Ruby says: {predicted_char}")


if __name__ == "__main__":
    test_model()
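
For reference, here is a minimal sketch of the tokenizer helpers this script imports. Ruby/tokenizer.py is not shown here, so the vocab file name, storage format, and unknown-character handling below are assumptions inferred from how simple_tokenizer, detokenizer, and load_vocab are called above, not the actual implementation.

# tokenizer.py (hypothetical sketch): character-level vocab helpers
import json

def load_vocab(path="vocab.json"):
    # Assumption: the vocab maps each character to an integer id, stored as JSON.
    with open(path) as f:
        return json.load(f)

def simple_tokenizer(text, vocab):
    # Assumption: map each character to its id, skipping characters not in the vocab.
    return [vocab[ch] for ch in text if ch in vocab]

def detokenizer(ids, vocab):
    # Assumption: invert the vocab mapping and join ids back into a string.
    inv = {i: ch for ch, i in vocab.items()}
    return "".join(inv[i] for i in ids)

With a trained ruby_model.pth and a matching vocab on disk, running python Ruby/test.py prints the model's predicted next character after the prompt "abc".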