import asyncio
import time

import torch

from model.tokenizer import Tokenizer
from model.brain_state import save_model, DEVICE, model, optimizer

tokenizer = Tokenizer()

# Serialize expansions so concurrent callers cannot resize the head twice.
expand_lock = asyncio.Lock()
_last_expansion_time = 0


async def expand_model_if_needed():
    """Grow the model's output head in place once the tokenizer has minted new token ids."""
    global _last_expansion_time

    async with expand_lock:
        needed_vocab_size = tokenizer.next_id
        current_vocab_size = model.head.out_features

        if needed_vocab_size <= current_vocab_size:
            return

        print(f"[Expand] Expanding vocabulary: {current_vocab_size} -> {needed_vocab_size}")

        # Build a wider output projection and copy the learned rows into it.
        old_head_weight = model.head.weight.data
        old_out_features = old_head_weight.size(0)
        in_features = model.head.in_features

        new_head = torch.nn.Linear(in_features, needed_vocab_size, bias=False).to(DEVICE)

        with torch.no_grad():
            new_head.weight[:old_out_features] = old_head_weight

        model.head = new_head

        # The optimizer still tracks the old head's parameters, so register the
        # replacement head with it; otherwise the new rows never receive updates.
        optimizer.add_param_group({"params": new_head.parameters()})

        _last_expansion_time = time.time()
        save_model()
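

# Usage sketch (illustrative; names other than expand_model_if_needed and the
# imported globals are assumptions about the surrounding codebase): callers
# await this coroutine from an async loop after the tokenizer may have assigned
# new token ids, so the head is resized before logits are computed, e.g.:
#
#     async def handle_text(text: str):
#         ids = tokenizer.encode(text)      # hypothetical Tokenizer method
#         await expand_model_if_needed()
#         logits = model(torch.tensor([ids], device=DEVICE))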