Ruby/model/brain_state.py

import torch
import torch.nn as nn

from model.brain_architecture import TinyTransformer
from model.tokenizer import Tokenizer

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer and vocabulary size; leave a slight buffer above the current vocab size.
tokenizer = Tokenizer()
VOCAB_SIZE = len(tokenizer.vocab) + 10

# Module-level shared state: the model, its optimizer, and the loss function.
model = TinyTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
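

# --- Usage sketch (not part of the original file) ---
# A minimal example of how these shared objects might be consumed for a single
# training step. `input_ids` and `target_ids` are hypothetical LongTensors of
# shape (batch, seq_len), and TinyTransformer is assumed to return logits of
# shape (batch, seq_len, VOCAB_SIZE); adjust the reshaping if its output differs.
def _example_train_step(input_ids: torch.Tensor, target_ids: torch.Tensor) -> float:
    model.train()
    optimizer.zero_grad()
    logits = model(input_ids.to(DEVICE))            # (B, T, V), assumed output shape
    loss = loss_fn(
        logits.reshape(-1, logits.size(-1)),        # flatten to (B*T, V)
        target_ids.to(DEVICE).reshape(-1),          # flatten to (B*T,)
    )
    loss.backward()
    optimizer.step()
    return loss.item()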