Trying to fix assert errors

parent e506032364
commit 6ab7b7586a

@@ -6,8 +6,7 @@ from model.tokenizer import Tokenizer
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 tokenizer = Tokenizer()
-VOCAB_SIZE = len(tokenizer.vocab) + 10  # Slight buffer
+VOCAB_SIZE = len(tokenizer.vocab) + 10  # with a small buffer
 
 model = TinyTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
 loss_fn = nn.CrossEntropyLoss()
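Note on the hunk above (evidently model/brain_state.py, given the from model.brain_state imports in the next hunk): the Adam instance is removed here and recreated in dynamic_expand behind a get_optimizer() accessor. A plausible motivation is that a name pulled in with "from module import optimizer" keeps pointing at the original object even after the module rebinds its global, so callers would keep stepping a stale optimizer. A minimal sketch of that binding behaviour, using a toy module that is not part of this repo:

# Illustrative sketch, not repo code: a name imported with
# "from module import optimizer" does not follow later rebinds.
import types

state = types.ModuleType("state")
state.optimizer = object()            # stands in for the original Adam instance

imported_optimizer = state.optimizer  # what "from state import optimizer" binds

state.optimizer = object()            # module swaps in a rebuilt optimizer

print(imported_optimizer is state.optimizer)  # False: callers still hold the old one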
@@ -1,41 +1,36 @@
 import torch
 from model.brain_architecture import TinyTransformer
-from model.brain_state import model, tokenizer, DEVICE, optimizer
-import copy
+from model.brain_state import model, tokenizer, DEVICE
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
 
 
+def get_optimizer():
+    global optimizer
+    return optimizer
+
+
 def expand_model_if_needed():
-    current_vocab_size = len(tokenizer.vocab) + 10  # Tiny buffer
+    global model, optimizer
+
+    current_vocab_size = len(tokenizer.vocab) + 10  # Buffer
     old_vocab_size = model.head.out_features
 
     if current_vocab_size <= old_vocab_size:
-        return  # No expansion needed
+        return
 
     print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
 
-    # Save old model
-    old_model = copy.deepcopy(model).to('cpu')
+    old_state = model.state_dict()
 
-    # Create new model
     new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
-    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)
 
-    # Copy parameters
+    # Transfer matching parameters
     with torch.no_grad():
-        for name, param in old_model.named_parameters():
-            if name in dict(new_model.named_parameters()):
-                try:
-                    new_param = dict(new_model.named_parameters())[name]
-                    if param.shape == new_param.shape:
-                        new_param.copy_(param)
-                    else:
-                        print(f"Skipping mismatched param: {name}")
-                except Exception as e:
-                    print(f"Error copying param: {name} — {e}")
+        for name, param in new_model.named_parameters():
+            if name in old_state and old_state[name].shape == param.shape:
+                param.copy_(old_state[name])
 
-    # Replace global references
-    globals()["model"] = new_model
-    globals()["optimizer"] = new_optimizer
+    model = new_model
+    opt = get_optimizer()
 
-    print("Expansion complete.")
+    print("Model expanded and optimizer rebuilt.")
-
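The rewritten expand_model_if_needed() above reduces to three steps: snapshot the old weights with state_dict(), build a fresh TinyTransformer at the larger vocabulary, then copy every tensor whose name and shape still match. A self-contained sketch of that shape-checked transfer, using a toy embedding/projection/head module as a stand-in for TinyTransformer (the module and sizes are illustrative assumptions, not repo code):

# Illustrative sketch, not repo code: shape-checked state_dict transfer
# into a model rebuilt with a larger vocabulary.
import torch
import torch.nn as nn


class ToyLM(nn.Module):  # stand-in for TinyTransformer
    def __init__(self, vocab_size, d_model=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.proj = nn.Linear(d_model, d_model)  # vocab-independent, should survive expansion
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.head(self.proj(self.embed(x)))


old_model = ToyLM(vocab_size=100)
old_state = old_model.state_dict()

new_model = ToyLM(vocab_size=120)  # expanded vocabulary
with torch.no_grad():
    for name, param in new_model.named_parameters():
        if name in old_state and old_state[name].shape == param.shape:
            param.copy_(old_state[name])  # only same-shaped tensors carry over

# proj.* is preserved; embed.weight and head.* changed shape and stay freshly initialised.
assert torch.equal(new_model.proj.weight, old_model.proj.weight)

Only the vocabulary-sized tensors (embedding and output head) are re-initialised; everything in between carries over, which is what lets training continue after an expansion.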
@@ -1,8 +1,8 @@
 import torch
 import time
-from model.brain_state import model, tokenizer, DEVICE, optimizer, loss_fn
+from model.brain_state import model, tokenizer, DEVICE, loss_fn
 from context.context import add_to_context, get_recent_context
-from model.dynamic_expand import expand_model_if_needed
+from model.dynamic_expand import expand_model_if_needed, get_optimizer
 from model.brainmap import update_brainmap
 
 LOSS_FILE = "data/logs/loss.log"
@@ -31,9 +31,10 @@ def train_on_message(text: str):
     output = model(input_tensor)
     loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
 
-    optimizer.zero_grad()
+    opt = get_optimizer()
+    opt.zero_grad()
     loss.backward()
-    optimizer.step()
+    opt.step()
 
     log_loss(loss.item())
     add_to_context(text)
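With the hunk above, the training step asks dynamic_expand for the optimizer on every call instead of importing it once, so it always steps whatever instance that module currently holds. A rough sketch of the pattern, including a hypothetical rebuild step that the expansion path would need so the returned Adam actually owns the new model's parameters (rebuild_optimizer and the toy model are assumptions for illustration, not code from this commit):

# Illustrative pattern, not code from this commit: the train step always
# fetches the current optimizer, and expansion rebinds both globals.
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def get_optimizer():
    return optimizer


def rebuild_optimizer(new_model):
    # hypothetical helper: after expansion, point both globals at the new objects
    global model, optimizer
    model = new_model
    optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)


def train_step(x, y):
    opt = get_optimizer()  # always the instance the module currently holds
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()


print(train_step(torch.randn(4, 8), torch.randn(4, 8)))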