14 lines
427 B
Python
14 lines
427 B
Python
import torch
|
|
import torch.nn as nn
|
|
from model.brain_architecture import TinyTransformer
|
|
from model.tokenizer import Tokenizer
|
|
|
|
# Module-level training setup: device, tokenizer, model, optimizer, and loss.

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = Tokenizer()

# Extra embedding rows reserved beyond the tokenizer's current vocabulary
# ("slight buffer" in the original).
# NOTE(review): presumably headroom for tokens added later — confirm intent.
VOCAB_BUFFER = 10
VOCAB_SIZE = len(tokenizer.vocab) + VOCAB_BUFFER

# Step size passed to the Adam optimizer below.
LEARNING_RATE = 1e-4

# Move the model to DEVICE *before* building the optimizer, so the optimizer
# is given the parameters that already live on the target device.
model = TinyTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# CrossEntropyLoss applies log-softmax internally, so the values fed to it
# should be raw logits rather than softmax probabilities.
loss_fn = nn.CrossEntropyLoss()