21 lines
702 B
Python
21 lines
702 B
Python
import torch
|
|
import torch.nn as nn
|
|
import os
|
|
from model.brain_architecture import TinyTransformer
|
|
from model.tokenizer import Tokenizer
|
|
|
|
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
_DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_DEVICE_TYPE)

# Where the model weights are written (a relative path, so it resolves
# against the current working directory).
MODEL_SAVE_PATH = "data/memory/model.pt"

# Shared tokenizer instance; the size of its vocabulary sets the model's
# vocab_size below.
tokenizer = Tokenizer()

# Extra vocabulary slots on top of the tokenizer's current vocabulary
# ("a small buffer").
# NOTE(review): presumably reserved for tokens added later — confirm.
_VOCAB_BUFFER = 10
VOCAB_SIZE = len(tokenizer.vocab) + _VOCAB_BUFFER

# The model is moved to DEVICE before the optimizer is built, so the optimizer
# receives the parameters that live on that device.
model = TinyTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)

# Adam with lr=3e-4. StepLR multiplies the learning rate by gamma=0.95 once
# every 500 calls to scheduler.step().
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=500,
    gamma=0.95,
)

loss_fn = nn.CrossEntropyLoss()
|
|
|
|
|
|
def save_model(path=None, module=None):
    """Write a model's ``state_dict`` to disk, creating parent directories.

    Args:
        path: Destination file. When omitted, the module-level
            ``MODEL_SAVE_PATH`` is used (looked up at call time).
        module: The ``nn.Module`` whose ``state_dict`` is saved. When omitted,
            the module-level ``model`` is used.

    The checkpoint is first written to ``<path>.tmp`` in the same directory
    and then moved over ``path`` with ``os.replace``. An interrupted or failed
    save therefore cannot leave a truncated file where the previous checkpoint
    used to be.
    """
    if path is None:
        path = MODEL_SAVE_PATH
    if module is None:
        module = model

    # dirname() is "" for a bare filename, and os.makedirs("") raises,
    # so only create directories when there is a directory component.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Keep the temp file next to the target so os.replace stays a rename
    # within the same filesystem.
    tmp_path = f"{path}.tmp"
    try:
        torch.save(module.state_dict(), tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file, then propagate the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|