From cde0068725c96e271a0fba6ffb2b70cae9962910 Mon Sep 17 00:00:00 2001
From: Dani
Date: Sun, 27 Apr 2025 13:51:49 -0400
Subject: [PATCH] Fix CUDA errors and improve reply generation

Replace the single-token argmax reply with an autoregressive loop that
samples up to 32 tokens using temperature and top-k sampling, stopping
at sentence-ending punctuation. Preserve embed_dim, depth, and head
count when rebuilding TinyTransformer during vocabulary expansion, so
the old weights stay compatible with the new model.
---
 model/brain.py          | 49 +++++++++++++++++++++++++++--------------
 model/dynamic_expand.py |  7 ++++++-
 2 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/model/brain.py b/model/brain.py
index 9646247..2f265ec 100644
--- a/model/brain.py
+++ b/model/brain.py
@@ -10,28 +10,43 @@ from context.context import get_recent_context
 
 recent_dreams = []
 
+@torch.no_grad()
 def generate_response():
     model.eval()
-    context_texts = get_recent_context(5)
-    if context_texts:
-        start = random.choice(context_texts)
-        seed_tokens = tokenizer.tokenize(start)
-        if seed_tokens:
-            seed = torch.tensor([seed_tokens[-1]], device=DEVICE).unsqueeze(0)
-            seed = seed[:, -128:]
-        else:
-            seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
-    else:
-        seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
+    context_texts = get_recent_context(10)
+    seed_text = " ".join(context_texts[-1:])
+    tokens = tokenizer.tokenize(seed_text)
 
-    output = model(seed)
-    pred = torch.argmax(output, dim=-1).squeeze().item()
+    input_tensor = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
 
-    # Clamp prediction into known vocab range
-    if pred >= tokenizer.next_id:
-        pred = random.randint(0, tokenizer.next_id - 1)
+    output_tokens = []
+    max_tokens = 32
 
-    return tokenizer.detokenize([pred])
+    for _ in range(max_tokens):
+        output = model(input_tensor)
+        logits = output[:, -1, :].squeeze(0)
+
+        # Apply temperature (soft randomness)
+        temperature = 0.8
+        logits = logits / temperature
+
+        # Top-k sampling
+        k = 10
+        topk_logits, topk_indices = torch.topk(logits, k)
+        probs = torch.nn.functional.softmax(topk_logits, dim=-1)
+        next_token = topk_indices[torch.multinomial(probs, 1)].item()
+
+        output_tokens.append(next_token)
+
+        input_tensor = torch.cat([input_tensor, torch.tensor([[next_token]], device=DEVICE)], dim=1)
+
+        # Optional: stop if next_token maps to period, question mark, or exclamation
+        next_char = tokenizer.detokenize([next_token])
+        if any(p in next_char for p in [".", "?", "!"]):
+            break
+
+    text = tokenizer.detokenize(output_tokens)
+    return text
 
 
 def score_sentence(sentence: str) -> float:
diff --git a/model/dynamic_expand.py b/model/dynamic_expand.py
index ada6331..c90c3b9 100644
--- a/model/dynamic_expand.py
+++ b/model/dynamic_expand.py
@@ -28,7 +28,12 @@ def expand_model_if_needed():
         # print(f"[Expand] Expanding model from {old_vocab_size} -> {current_vocab_size}")
 
         old_state = model.state_dict()
-        new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
+        new_model = TinyTransformer(
+            vocab_size=current_vocab_size,
+            embed_dim=model.token_embed.embedding_dim,
+            depth=len(model.blocks),
+            heads=model.blocks[0].attn.heads
+        ).to(DEVICE)
 
         with torch.no_grad():
             for name, param in new_model.named_parameters():
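
Note: the sampling step in generate_response is easy to verify in
isolation. Below is a minimal standalone sketch of the same
temperature + top-k scheme; the helper name sample_top_k and the toy
vocabulary are illustrative assumptions, and only torch is required.

    import torch

    def sample_top_k(logits: torch.Tensor, k: int = 10, temperature: float = 0.8) -> int:
        # Temperatures below 1.0 sharpen the distribution; above 1.0 flatten it.
        logits = logits / temperature
        # Keep only the k highest-scoring tokens.
        topk_logits, topk_indices = torch.topk(logits, k)
        # Renormalize over the survivors and draw one token id.
        probs = torch.nn.functional.softmax(topk_logits, dim=-1)
        choice = torch.multinomial(probs, num_samples=1)
        return topk_indices[choice].item()

    # Example: draw a token id from a toy 20-token vocabulary.
    token_id = sample_top_k(torch.randn(20))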
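
The model/dynamic_expand.py hunk ends where the parameter copy begins.
A hedged sketch of what that loop presumably does (the helper
copy_overlapping_params and its slicing policy are assumptions, not
code from this patch): copy each old tensor into the leading slice of
the matching new parameter, so only the rows added for new vocabulary
ids keep their fresh initialization.

    import torch

    @torch.no_grad()
    def copy_overlapping_params(old_state: dict, new_model: torch.nn.Module) -> None:
        # Assumed policy: reuse old weights wherever names match,
        # copying only the overlapping region when shapes differ.
        for name, param in new_model.named_parameters():
            if name not in old_state:
                continue  # parameter is new; keep its fresh init
            old = old_state[name]
            if old.shape == param.shape:
                param.copy_(old)
            else:
                # Shapes differ only along the grown vocab dimension;
                # copy the overlap, leave the new rows untouched.
                overlap = tuple(slice(0, min(o, n)) for o, n in zip(old.shape, param.shape))
                param[overlap].copy_(old[overlap])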
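
On the "CUDA errors" in the subject: out-of-range ids in an
nn.Embedding lookup are a common source of opaque device-side asserts,
which is what the old clamp-to-vocab code worked around and the
architecture-preserving expansion now avoids. A hedged diagnostic
sketch (the helper name is an assumption, not part of this patch):

    import torch

    def assert_ids_in_range(input_tensor: torch.Tensor, vocab_size: int) -> None:
        # Validate token ids before an embedding lookup; doing this on
        # the CPU side gives a readable error instead of a CUDA assert.
        bad = (input_tensor < 0) | (input_tensor >= vocab_size)
        if bad.any():
            raise ValueError(f"token ids out of range: {input_tensor[bad].tolist()}")

    # Example: raises ValueError because id 12 exceeds a 10-token vocab.
    # assert_ids_in_range(torch.tensor([[3, 12]]), vocab_size=10)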