Ruby/model/dynamic_expand.py
Dani a8b3129806 Redid the dashboard home page into more of a status page.
Fixed the weird desync/threading issue that was stopping ruby from working.
2025-04-29 23:04:53 -04:00

38 lines
1.1 KiB
Python

import torch
import asyncio
import time
from model.tokenizer import Tokenizer
from model.brain_state import save_model, DEVICE, model, optimizer
# Module-level Tokenizer instance; expand_model_if_needed() reads its
# `next_id` to decide how many output rows the model head needs.
tokenizer = Tokenizer()
# Held by expand_model_if_needed() around its check-and-resize sequence so
# that coroutines calling it run that sequence one at a time.
expand_lock = asyncio.Lock()
# time.time() value recorded by the most recent expansion; stays 0 until the
# first expansion happens.
_last_expansion_time = 0
async def expand_model_if_needed():
    """Grow ``model.head`` so it has one output row per tokenizer id.

    Compares ``tokenizer.next_id`` with ``model.head.out_features`` while
    holding ``expand_lock``. When the tokenizer has outgrown the head, a new
    bias-free ``torch.nn.Linear`` is built on ``DEVICE``, the existing weight
    rows are copied into its first rows (the extra rows keep the default
    ``Linear`` initialisation), the head is swapped in, the optimizer is
    pointed at the new weight, ``_last_expansion_time`` is set to
    ``time.time()`` and ``save_model()`` is called.

    Returns ``None`` in every case; when the head is already large enough the
    function returns without touching the model, the optimizer or the
    timestamp.

    No LR scheduler is created here: a scheduler object that is neither stored
    nor stepped does not change how the optimizer trains.
    """
    global _last_expansion_time

    async with expand_lock:
        needed_vocab_size = tokenizer.next_id
        old_head = model.head
        current_vocab_size = old_head.out_features

        if needed_vocab_size <= current_vocab_size:
            return

        print(f"[Expand] Expanding vocabulary: {current_vocab_size} -> {needed_vocab_size}")

        old_weight = old_head.weight
        # Match the old weight's dtype so a model that is not in the default
        # float32 keeps a head it can still multiply with.
        # NOTE(review): the head is rebuilt without a bias, as before; this
        # assumes model.head never carries one -- confirm against the model.
        new_head = torch.nn.Linear(
            old_head.in_features, needed_vocab_size, bias=False
        ).to(device=DEVICE, dtype=old_weight.dtype)

        with torch.no_grad():
            new_head.weight[:current_vocab_size] = old_weight

        model.head = new_head

        # The optimizer holds references to Parameter objects, not to modules,
        # so without this step it would keep updating the discarded weight and
        # the resized head would never train.
        old_weight_orphaned = all(p is not old_weight for p in model.parameters())
        _retarget_optimizer_param(
            optimizer, old_weight, new_head.weight, drop_old=old_weight_orphaned
        )

        _last_expansion_time = time.time()
        save_model()


def _retarget_optimizer_param(opt, old_param, new_param, *, drop_old):
    """Make *opt* update *new_param* in the param group that holds *old_param*.

    With ``drop_old`` true, *new_param* takes *old_param*'s slot and the
    per-parameter state stored for *old_param* is discarded (its shape no
    longer matches). With ``drop_old`` false, *old_param* is still used
    elsewhere in the model, so *new_param* is appended to the same group
    instead. Either way the new parameter inherits that group's
    hyper-parameters.

    If *opt* was not tracking *old_param* at all, nothing is changed.
    """
    for group in opt.param_groups:
        params = group["params"]
        for index, param in enumerate(params):
            if param is not old_param:
                continue
            if drop_old:
                params[index] = new_param
                opt.state.pop(old_param, None)
            else:
                params.append(new_param)
            return