43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
import torch
|
|
import threading
|
|
import time
|
|
from model.brain_architecture import TinyTransformer
|
|
from model.brain_state import model, tokenizer, DEVICE
|
|
|
|
# Lock held by expand_model_if_needed() for the whole check-and-replace
# sequence, so only one thread at a time can swap the model/optimizer.
expand_lock = threading.Lock()

# Expansion bookkeeping. _last_expansion_time is overwritten with
# time.time() after each expansion; _last_vocab_size is not referenced
# by the code visible in this module.
_last_vocab_size = 0
_last_expansion_time = 0

# Adam optimizer over the parameters of the model imported from
# model.brain_state. expand_model_if_needed() rebinds this name, so read
# it through get_optimizer() rather than caching the object.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
|
|
|
|
|
|
def get_optimizer():
    """Return the optimizer currently bound at module level.

    The module-level ``optimizer`` name is rebound whenever
    ``expand_model_if_needed`` replaces the model, so callers should fetch
    it through this accessor each time instead of holding on to an
    earlier instance.
    """
    current_optimizer = optimizer
    return current_optimizer
|
|
|
|
|
|
def _copy_overlapping_state(old_state, new_model):
    """Copy every tensor of ``old_state`` that fits into ``new_model``.

    Tensors with an identical shape are copied whole. Tensors that have the
    same number of dimensions and are no larger than their counterpart along
    every axis (e.g. vocabulary-sized embedding / output-head tensors after
    the vocabulary grew) are copied into the leading region of the new
    tensor; the remaining entries keep their fresh initialisation.
    Entries that are missing from ``old_state`` or cannot fit are skipped.
    """
    with torch.no_grad():
        # state_dict() returns tensors that share storage with the module's
        # parameters and persistent buffers, so in-place copies update the
        # module itself (and buffers are carried over, not just parameters).
        for name, new_tensor in new_model.state_dict().items():
            old_tensor = old_state.get(name)
            if not (torch.is_tensor(old_tensor) and torch.is_tensor(new_tensor)):
                continue
            if old_tensor.dim() != new_tensor.dim():
                continue
            if old_tensor.shape == new_tensor.shape:
                new_tensor.copy_(old_tensor)
            elif all(o <= n for o, n in zip(old_tensor.shape, new_tensor.shape)):
                region = tuple(slice(0, o) for o in old_tensor.shape)
                new_tensor[region].copy_(old_tensor)


def expand_model_if_needed():
    """Grow the model when the tokenizer vocabulary outgrows its output head.

    The target size is ``len(tokenizer.vocab) + 10``. If that does not exceed
    ``model.head.out_features`` nothing happens. Otherwise a new
    ``TinyTransformer`` with the larger vocabulary is created on ``DEVICE``,
    the old weights are transferred into it (vocabulary-sized tensors keep
    their existing rows), the old model's train/eval mode is carried over,
    and the module-level ``model``, ``optimizer`` and
    ``_last_expansion_time`` are rebound. The whole sequence runs under
    ``expand_lock``.

    Returns:
        None.

    NOTE(review): ``model`` is imported by name from ``model.brain_state``,
    so rebinding it here only changes this module's binding; code that reads
    ``model.brain_state.model`` directly keeps seeing the old instance.
    Confirm that is intended.
    """
    global model, optimizer, _last_expansion_time

    with expand_lock:
        # Spare slots added on top of the current vocabulary — presumably so
        # a few new tokens do not trigger another expansion; TODO confirm.
        headroom = 10
        current_vocab_size = len(tokenizer.vocab) + headroom
        old_vocab_size = model.head.out_features

        if current_vocab_size <= old_vocab_size:
            return

        old_state = model.state_dict()
        new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
        _copy_overlapping_state(old_state, new_model)

        # A freshly built module starts in training mode; keep whatever mode
        # the model being replaced was in.
        new_model.train(model.training)

        model = new_model
        # A new optimizer is required because the parameter objects changed;
        # its Adam state starts from scratch.
        # NOTE(review): lr is 1e-4 here while the initial module-level
        # optimizer uses 3e-4 — confirm the lower rate is deliberate.
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        _last_expansion_time = time.time()
|