Ruby/brain/brain_state.py

import torch
import torch.nn as nn
import os
from brain.brain_architecture import TinyTransformer
from ego.tokenizer import Tokenizer

# Use the GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_SAVE_PATH = "memory/model.pt"

# Shared tokenizer; the vocabulary size gets a small buffer of 10 extra ids.
tokenizer = Tokenizer()
VOCAB_SIZE = len(tokenizer.vocab) + 10

# Module-level model, optimizer, LR schedule, and loss.
model = TinyTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.95)
loss_fn = nn.CrossEntropyLoss()
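
# Sketch of how these module-level objects fit together in one training step.
# `train_step` and its (inputs, targets) signature are illustrative assumptions,
# not part of the original file; it presumes TinyTransformer maps token-id
# tensors of shape (batch, seq) to logits of shape (batch, seq, vocab_size).
def train_step(inputs, targets):
    model.train()
    optimizer.zero_grad()
    logits = model(inputs.to(DEVICE))  # assumed output: (batch, seq, vocab)
    loss = loss_fn(logits.reshape(-1, VOCAB_SIZE), targets.to(DEVICE).reshape(-1))
    loss.backward()
    optimizer.step()
    scheduler.step()  # LR decays by gamma=0.95 every 500 scheduler steps
    return loss.item()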

def save_model():
    """Persist the current model weights to MODEL_SAVE_PATH, creating the directory if needed."""
    os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
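
# Hypothetical loading counterpart (an assumption; the original repo may
# restore state elsewhere): reloads saved weights into the module-level
# `model` if a checkpoint produced by save_model() exists.
def load_model():
    if not os.path.exists(MODEL_SAVE_PATH):
        return False
    state_dict = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    return True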