Corrected a possible 'posioning' of the model

This commit is contained in:
Dani 2025-04-27 14:36:38 -04:00
parent d456c833bf
commit 97b43f832b
4 changed files with 34 additions and 43 deletions

View File

@ -123,7 +123,7 @@ def journal():
@app.route("/concepts")
def concepts():
clusters = cluster_vocab(n_clusters=10)
return render_template("concepts.html", clusters=clusters)
return render_template("concepts.html", clusters={i: cluster for i, cluster in enumerate(clusters)})
@app.route("/dreams")

View File

@ -11,34 +11,33 @@ recent_dreams = []
@torch.inference_mode()
def generate_response():
def generate_response(max_tokens: int = 50):
model.eval()
input_ids = torch.tensor([tokenizer.token_to_id("<start>")], device=DEVICE).unsqueeze(0)
# Start from an empty input: purely organic thought
input_ids = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
generated = []
output_tokens = []
max_length = 50
for _ in range(max_length):
for _ in range(max_tokens):
output = model(input_ids)
if torch.isnan(output).any():
print("[Brain] Detected NaN in output, restarting generation.")
return "..."
next_token_logits = output[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
# Get token id values for special tokens
pad_token_id = tokenizer.vocab.get("<pad>", None)
unk_token_id = tokenizer.vocab.get("<unk>", None)
token_id = next_token.item()
# Stop if the model predicts <pad> or <unk>
if pad_token_id is not None and next_token.item() == pad_token_id:
break
if unk_token_id is not None and next_token.item() == unk_token_id:
# If she outputs <end> token, stop generation
if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
break
output_tokens.append(next_token.item())
generated.append(token_id)
next_token = next_token.unsqueeze(0)
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
return tokenizer.detokenize(output_tokens)
return tokenizer.detokenize(generated)
def score_sentence(sentence: str) -> float:

View File

@ -19,9 +19,9 @@ def save_vocab(vocab):
class Tokenizer:
def __init__(self):
self.vocab = load_vocab()
self.reverse_vocab = {v: k for k, v in self.vocab.items()}
self.next_id = max(self.vocab.values(), default=0) + 1
self.vocab = {"<pad>": 0, "<unk>": 1, "<start>": 2, "<end>": 3}
self.reverse_vocab = {0: "<pad>", 1: "<unk>", 2: "<start>", 3: "<end>"}
self.next_id = 4
def tokenize(self, text):
words = re.findall(r"\b\w+\b", text.lower())

View File

@ -2,6 +2,7 @@ import torch
import time
from model.dynamic_expand import expand_model_if_needed, _last_expansion_time, get_optimizer, expand_lock
from model.brain_state import model, tokenizer, DEVICE, loss_fn
from model.brainmap import update_brainmap
from context.context import add_to_context, get_recent_context
LOSS_FILE = "data/logs/loss.log"
@ -34,38 +35,32 @@ def train_on_message(text: str, source: str = "user"):
try:
model.train()
context_texts = get_recent_context(10)
augmented_text = " ".join(context_texts + [text])
# Here's the important change:
augmented_text = "<start> " + " ".join(context_texts + [text]) + " <end>"
tokens = tokenizer.tokenize(augmented_text)
if not tokens or len(tokens) < 2:
if len(tokens) < 2:
return
max_token_id = model.head.out_features - 1
tokens = [t if t <= max_token_id else max_token_id for t in tokens]
tokens = tokens[:128]
# Clamp each token to be inside model's head size
clamped_tokens = []
for token in tokens:
if token > max_token_id:
clamped_tokens.append(max_token_id)
elif token < 0:
clamped_tokens.append(0)
else:
clamped_tokens.append(token)
# Clamp sequence length
clamped_tokens = clamped_tokens[:128]
if len(clamped_tokens) < 2:
if len(tokens) < 2:
return
input_tensor = torch.tensor(clamped_tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
target_tensor = torch.tensor(clamped_tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
opt = get_optimizer()
output = model(input_tensor)
loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
if torch.isnan(loss):
print("[Trainer] Detected NaN loss, skipping update.")
return
opt.zero_grad()
loss.backward()
@ -74,9 +69,6 @@ def train_on_message(text: str, source: str = "user"):
log_loss(loss.item())
log_vocab_growth()
add_to_context(text, source=source)
except Exception as e:
print(f"[Train] Exception during training: {e}")
update_brainmap(augmented_text.split())
finally:
expand_lock.release()
expand_lock.release()