Fixed vocab issues
This commit is contained in:
parent
b876dede9e
commit
e72678b242
3
.gitignore
vendored
3
.gitignore
vendored
@ -172,4 +172,5 @@ cython_debug/
|
||||
/data/books/alice_in_wonderland.txt
|
||||
/data/books/wizard_of_oz.txt
|
||||
/data/memory/context.json
|
||||
/data/memory/dreams.json
|
||||
/data/memory/dreams.json
|
||||
/data/memory/vocab.json
|
@ -3,13 +3,13 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from model.memory import save_dream
|
||||
from model.brain_state import model, tokenizer, DEVICE
|
||||
from context.context import get_recent_context
|
||||
|
||||
recent_dreams = []
|
||||
|
||||
|
||||
def generate_response():
|
||||
model.eval()
|
||||
# Pick a real known word to seed from context memory
|
||||
context_texts = get_recent_context(5)
|
||||
if context_texts:
|
||||
start = random.choice(context_texts)
|
||||
@ -22,12 +22,13 @@ def generate_response():
|
||||
seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
|
||||
|
||||
output = model(seed)
|
||||
pred = torch.argmax(output, dim=-1).squeeze().tolist()
|
||||
pred = torch.argmax(output, dim=-1).squeeze().item()
|
||||
|
||||
if not isinstance(pred, list):
|
||||
pred = [pred]
|
||||
# Clamp prediction into known vocab range
|
||||
if pred >= tokenizer.next_id:
|
||||
pred = random.randint(0, tokenizer.next_id - 1)
|
||||
|
||||
return tokenizer.detokenize(pred)
|
||||
return tokenizer.detokenize([pred])
|
||||
|
||||
|
||||
def score_sentence(sentence: str) -> float:
|
||||
|
@ -6,7 +6,7 @@ from model.tokenizer import Tokenizer
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
tokenizer = Tokenizer()
|
||||
VOCAB_SIZE = 10000 # Expandable if needed
|
||||
VOCAB_SIZE = len(tokenizer.vocab) + 10 # Slight buffer
|
||||
|
||||
model = TinyTransformer(vocab_size=VOCAB_SIZE).to(DEVICE)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||
|
41
model/dynamic_expand.py
Normal file
41
model/dynamic_expand.py
Normal file
@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from model.brain_architecture import TinyTransformer
|
||||
from model.brain_state import model, tokenizer, DEVICE, optimizer
|
||||
import copy
|
||||
|
||||
|
||||
def expand_model_if_needed():
|
||||
current_vocab_size = len(tokenizer.vocab) + 10 # Tiny buffer
|
||||
old_vocab_size = model.head.out_features
|
||||
|
||||
if current_vocab_size <= old_vocab_size:
|
||||
return # No expansion needed
|
||||
|
||||
print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
|
||||
|
||||
# Save old model
|
||||
old_model = copy.deepcopy(model).to('cpu')
|
||||
|
||||
# Create new model
|
||||
new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
|
||||
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)
|
||||
|
||||
# Copy parameters
|
||||
with torch.no_grad():
|
||||
for name, param in old_model.named_parameters():
|
||||
if name in dict(new_model.named_parameters()):
|
||||
try:
|
||||
new_param = dict(new_model.named_parameters())[name]
|
||||
if param.shape == new_param.shape:
|
||||
new_param.copy_(param)
|
||||
else:
|
||||
print(f"Skipping mismatched param: {name}")
|
||||
except Exception as e:
|
||||
print(f"Error copying param: {name} — {e}")
|
||||
|
||||
# Replace global references
|
||||
globals()["model"] = new_model
|
||||
globals()["optimizer"] = new_optimizer
|
||||
|
||||
print("Expansion complete.")
|
||||
|
@ -2,6 +2,7 @@ import torch
|
||||
import time
|
||||
from model.brain_state import model, tokenizer, DEVICE, optimizer, loss_fn
|
||||
from context.context import add_to_context, get_recent_context
|
||||
from model.dynamic_expand import expand_model_if_needed
|
||||
|
||||
LOSS_FILE = "data/logs/loss.log"
|
||||
|
||||
@ -12,6 +13,7 @@ def log_loss(value: float):
|
||||
|
||||
|
||||
def train_on_message(text: str):
|
||||
expand_model_if_needed()
|
||||
model.train()
|
||||
context_texts = get_recent_context(3)
|
||||
augmented_text = " ".join(context_texts + [text])
|
||||
|
Loading…
x
Reference in New Issue
Block a user