Ruby/model/brain_state.py
2025-04-25 23:16:18 -04:00

13 lines
375 B
Python

import torch
import torch.nn as nn
from model.brain_architecture import TinyTransformer
from model.tokenizer import Tokenizer
# Compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Module-level tokenizer instance, built with its default arguments.
tokenizer = Tokenizer()

# Extra slots added on top of the tokenizer's current vocabulary size.
# NOTE(review): presumably headroom for tokens added later — confirm.
_VOCAB_PADDING = 10
VOCAB_SIZE = len(tokenizer.vocab) + _VOCAB_PADDING

# Build the network for VOCAB_SIZE entries, then place it on DEVICE.
model = TinyTransformer(vocab_size=VOCAB_SIZE)
model = model.to(DEVICE)

# Cross-entropy loss with PyTorch's default settings.
loss_fn = nn.CrossEntropyLoss()