Corrected a possible 'poisoning' of the model
parent d456c833bf
commit 97b43f832b
@@ -123,7 +123,7 @@ def journal():
 @app.route("/concepts")
 def concepts():
     clusters = cluster_vocab(n_clusters=10)
-    return render_template("concepts.html", clusters=clusters)
+    return render_template("concepts.html", clusters={i: cluster for i, cluster in enumerate(clusters)})
 
 
 @app.route("/dreams")
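The /concepts change above only reshapes what the template receives: the route now hands concepts.html an index-keyed dict instead of a bare list. A minimal sketch of that transformation, using hypothetical cluster data in place of cluster_vocab() output:

# Hypothetical stand-in for cluster_vocab(n_clusters=10)
clusters = [["dream", "sleep", "night"], ["code", "python", "bug"]]

# Same reshaping the route now performs before rendering
clusters_by_id = {i: cluster for i, cluster in enumerate(clusters)}
print(clusters_by_id)  # {0: ['dream', 'sleep', 'night'], 1: ['code', 'python', 'bug']}

This gives the template stable integer keys to label and iterate over instead of relying on list positions.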
@@ -11,34 +11,33 @@ recent_dreams = []
 
 
 @torch.inference_mode()
-def generate_response():
+def generate_response(max_tokens: int = 50):
     model.eval()
+    input_ids = torch.tensor([tokenizer.token_to_id("<start>")], device=DEVICE).unsqueeze(0)
 
-    # Start from an empty input: purely organic thought
-    input_ids = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
-
-    output_tokens = []
-    max_length = 50
+    generated = []
 
-    for _ in range(max_length):
+    for _ in range(max_tokens):
         output = model(input_ids)
+        if torch.isnan(output).any():
+            print("[Brain] Detected NaN in output, restarting generation.")
+            return "..."
 
         next_token_logits = output[:, -1, :]
         next_token = torch.argmax(next_token_logits, dim=-1)
 
-        # Get token id values for special tokens
-        pad_token_id = tokenizer.vocab.get("<pad>", None)
-        unk_token_id = tokenizer.vocab.get("<unk>", None)
+        token_id = next_token.item()
 
-        # Stop if the model predicts <pad> or <unk>
-        if pad_token_id is not None and next_token.item() == pad_token_id:
-            break
-        if unk_token_id is not None and next_token.item() == unk_token_id:
+        # If she outputs <end> token, stop generation
+        if tokenizer.reverse_vocab.get(token_id, "") == "<end>":
             break
 
-        output_tokens.append(next_token.item())
+        generated.append(token_id)
 
+        next_token = next_token.unsqueeze(0)
         input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
 
-    return tokenizer.detokenize(output_tokens)
+    return tokenizer.detokenize(generated)
 
 
 def score_sentence(sentence: str) -> float:
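The rewritten generate_response() above is where the 'poisoning' fix lands: generation is now seeded with the tokenizer's <start> token instead of a zero tensor (id 0 is <pad>), runs under a max_tokens budget, bails out with "..." when the model emits NaNs, and stops on <end> rather than on <pad>/<unk>. A minimal sketch of just the seeding difference, assuming the special-token ids introduced elsewhere in this commit and a token_to_id helper on the tokenizer (implied by the new call, not shown in this diff):

import torch

DEVICE = "cpu"  # assumption for the sketch

class SketchTokenizer:
    vocab = {"<pad>": 0, "<unk>": 1, "<start>": 2, "<end>": 3}

    def token_to_id(self, token):
        return self.vocab.get(token, self.vocab["<unk>"])

tokenizer = SketchTokenizer()

# Old seed: a zero tensor, i.e. the <pad> token -- the model is asked to
# continue padding, which is the likely "poisoning" the commit corrects.
old_seed = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)

# New seed: the real <start> marker, shaped (1, 1) like the old one.
new_seed = torch.tensor([tokenizer.token_to_id("<start>")], device=DEVICE).unsqueeze(0)

print(old_seed.tolist(), new_seed.tolist())  # [[0]] [[2]]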
@@ -19,9 +19,9 @@ def save_vocab(vocab):
 
 class Tokenizer:
     def __init__(self):
-        self.vocab = load_vocab()
-        self.reverse_vocab = {v: k for k, v in self.vocab.items()}
-        self.next_id = max(self.vocab.values(), default=0) + 1
+        self.vocab = {"<pad>": 0, "<unk>": 1, "<start>": 2, "<end>": 3}
+        self.reverse_vocab = {0: "<pad>", 1: "<unk>", 2: "<start>", 3: "<end>"}
+        self.next_id = 4
 
     def tokenize(self, text):
         words = re.findall(r"\b\w+\b", text.lower())
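Seeding the Tokenizer with fixed ids for <pad>, <unk>, <start> and <end> (and starting next_id at 4) guarantees the markers wrapped around training text always map to the same ids and can never be claimed by ordinary words. A minimal sketch of the resulting id layout, assuming unseen words are appended at next_id (the growth logic itself is not part of this diff):

import re

class SketchTokenizer:
    """Sketch of the reworked vocab seeding, not the project's full Tokenizer."""

    def __init__(self):
        self.vocab = {"<pad>": 0, "<unk>": 1, "<start>": 2, "<end>": 3}
        self.reverse_vocab = {v: k for k, v in self.vocab.items()}
        self.next_id = 4

    def tokenize(self, text):
        words = re.findall(r"\b\w+\b", text.lower())
        ids = []
        for word in words:
            if word not in self.vocab:  # assumption: new words take the next free id
                self.vocab[word] = self.next_id
                self.reverse_vocab[self.next_id] = word
                self.next_id += 1
            ids.append(self.vocab[word])
        return ids

t = SketchTokenizer()
print(t.tokenize("hello hello world"))  # [4, 4, 5] -- ids 0-3 stay reserved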
@@ -2,6 +2,7 @@ import torch
 import time
 from model.dynamic_expand import expand_model_if_needed, _last_expansion_time, get_optimizer, expand_lock
 from model.brain_state import model, tokenizer, DEVICE, loss_fn
+from model.brainmap import update_brainmap
 from context.context import add_to_context, get_recent_context
 
 LOSS_FILE = "data/logs/loss.log"
@@ -34,38 +35,32 @@ def train_on_message(text: str, source: str = "user"):
     try:
         model.train()
         context_texts = get_recent_context(10)
-        augmented_text = " ".join(context_texts + [text])
+        # Here's the important change:
+        augmented_text = "<start> " + " ".join(context_texts + [text]) + " <end>"
 
         tokens = tokenizer.tokenize(augmented_text)
 
-        if not tokens or len(tokens) < 2:
+        if len(tokens) < 2:
             return
 
         max_token_id = model.head.out_features - 1
+        tokens = [t if t <= max_token_id else max_token_id for t in tokens]
+        tokens = tokens[:128]
 
-        # Clamp each token to be inside model's head size
-        clamped_tokens = []
-        for token in tokens:
-            if token > max_token_id:
-                clamped_tokens.append(max_token_id)
-            elif token < 0:
-                clamped_tokens.append(0)
-            else:
-                clamped_tokens.append(token)
-
-        # Clamp sequence length
-        clamped_tokens = clamped_tokens[:128]
-
-        if len(clamped_tokens) < 2:
+        if len(tokens) < 2:
             return
 
-        input_tensor = torch.tensor(clamped_tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
-        target_tensor = torch.tensor(clamped_tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
+        input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
+        target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
 
         opt = get_optimizer()
 
         output = model(input_tensor)
         loss = loss_fn(output.view(-1, output.size(-1)), target_tensor.view(-1))
+        if torch.isnan(loss):
+            print("[Trainer] Detected NaN loss, skipping update.")
+            return
 
         opt.zero_grad()
         loss.backward()
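In train_on_message() above the same special tokens reach the training side: the context window is wrapped in explicit <start>/<end> markers, the hand-rolled clamping loop is collapsed into a comprehension plus a 128-token truncation, and a NaN loss now aborts the update instead of being backpropagated. A minimal sketch of that preprocessing with hypothetical values (head size 8, made-up token ids):

text = "hello world"
context_texts = ["an earlier message"]  # hypothetical recent context
augmented_text = "<start> " + " ".join(context_texts + [text]) + " <end>"

tokens = [2, 4, 5, 9, 11, 3]  # hypothetical tokenizer output
max_token_id = 8 - 1          # hypothetical model.head.out_features - 1

# Same clamp-and-truncate the new code applies before building the tensors
tokens = [t if t <= max_token_id else max_token_id for t in tokens]
tokens = tokens[:128]
print(tokens)  # [2, 4, 5, 7, 7, 3] -- out-of-range ids collapse onto the last head index

Note that, unlike the old loop, the comprehension no longer clamps negative ids up to 0; presumably the tokenizer can only produce non-negative ids.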
@@ -74,9 +69,6 @@ def train_on_message(text: str, source: str = "user"):
         log_loss(loss.item())
         log_vocab_growth()
         add_to_context(text, source=source)
-    except Exception as e:
-        print(f"[Train] Exception during training: {e}")
-
+        update_brainmap(augmented_text.split())
     finally:
         expand_lock.release()