import torch

from model import TinyGPT
from tokenizer import simple_tokenizer, detokenizer, load_vocab


def test_model():
    vocab = load_vocab()

    # Load the trained model
    model = TinyGPT(vocab_size=len(vocab), embed_size=32, num_heads=2, num_layers=2).cuda()
    # weights_only=True restricts torch.load to tensor data instead of unpickling arbitrary objects
    model.load_state_dict(torch.load("ruby_model.pth", weights_only=True))
    model.eval()

    # Tokenize a test prompt and run a single forward pass
    test_input = torch.tensor(simple_tokenizer("abc", vocab), dtype=torch.long).cuda()
    with torch.no_grad():
        output = model(test_input.unsqueeze(0), test_input.unsqueeze(0))
        # The highest-scoring token at the final position is the predicted next character
        predicted_idx = output.argmax(-1).squeeze()[-1].item()

    predicted_char = detokenizer([predicted_idx], vocab)
    print(f"Ruby says: {predicted_char}")


if __name__ == "__main__":
    test_model()