From 99fddcab4df4aecbf2b6b16580c6668ca231197b Mon Sep 17 00:00:00 2001
From: Dani
Date: Sun, 27 Apr 2025 13:42:29 -0400
Subject: [PATCH] Fixing another string of CUDA errors

---
 model/reweaver.py |  1 +
 model/trainer.py  | 31 +++++++++++++++++++++++--------
 2 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/model/reweaver.py b/model/reweaver.py
index 4687bfb..ce7c192 100644
--- a/model/reweaver.py
+++ b/model/reweaver.py
@@ -1,4 +1,5 @@
 import random
+import asyncio
 from context.context import load_context
 from model.trainer import train_on_message
 from model.dynamic_expand import expand_model_if_needed
diff --git a/model/trainer.py b/model/trainer.py
index 032c484..0d60ba2 100644
--- a/model/trainer.py
+++ b/model/trainer.py
@@ -23,7 +23,7 @@ def train_on_message(text: str, source: str = "user"):
     expand_model_if_needed()
 
     now = time.time()
-    if now - _last_expansion_time < 5:  # If expansion happened within the last 5 seconds
+    if now - _last_expansion_time < 5:
         print("[Train] Skipping to stabilize after expansion.")
         return
 
@@ -35,21 +35,32 @@ def train_on_message(text: str, source: str = "user"):
     model.train()
     context_texts = get_recent_context(10)
     augmented_text = " ".join(context_texts + [text])
+
     tokens = tokenizer.tokenize(augmented_text)
-    if len(tokens) < 2:
+    if not tokens or len(tokens) < 2:
         return
 
     max_token_id = model.head.out_features - 1
-    tokens = [min(t, max_token_id) for t in tokens]
-    if len(tokens) < 2:
+    # Clamp each token to be inside model's head size
+    clamped_tokens = []
+    for token in tokens:
+        if token > max_token_id:
+            clamped_tokens.append(max_token_id)
+        elif token < 0:
+            clamped_tokens.append(0)
+        else:
+            clamped_tokens.append(token)
+
+    # Clamp sequence length
+    clamped_tokens = clamped_tokens[:128]
+
+    if len(clamped_tokens) < 2:
         return
 
-    tokens = tokens[:128]
-
-    input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
-    target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
+    input_tensor = torch.tensor(clamped_tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
+    target_tensor = torch.tensor(clamped_tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
 
     opt = get_optimizer()
@@ -63,5 +74,9 @@ def train_on_message(text: str, source: str = "user"):
         log_loss(loss.item())
         log_vocab_growth()
         add_to_context(text, source=source)
+
+    except Exception as e:
+        print(f"[Train] Exception during training: {e}")
+    finally:
        expand_lock.release()
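
Note (illustrative, not part of the patch): the per-token clamping loop added to trainer.py keeps every id inside [0, model.head.out_features - 1], which is what prevents the device-side index asserts this commit targets. The same bounds can be expressed with torch.clamp; a minimal sketch, assuming the same 128-token cap (the helper name clamp_tokens is hypothetical, not something defined in this repo):

    import torch

    def clamp_tokens(tokens, max_token_id, max_len=128):
        # Build the tensor on CPU, then clamp ids into [0, max_token_id]
        # before anything is moved to the GPU, so an out-of-range index
        # never reaches a CUDA kernel.
        t = torch.tensor(tokens, dtype=torch.long)
        t = t.clamp(0, max_token_id)
        # Same sequence-length cap as the patch.
        return t[:max_len]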