Compare commits: 684bf33675 ... cde0068725

2 Commits:
cde0068725
99fddcab4d
@@ -10,28 +10,43 @@ from context.context import get_recent_context
 recent_dreams = []
 
 
+@torch.no_grad()
 def generate_response():
     model.eval()
-    context_texts = get_recent_context(5)
-    if context_texts:
-        start = random.choice(context_texts)
-        seed_tokens = tokenizer.tokenize(start)
-        if seed_tokens:
-            seed = torch.tensor([seed_tokens[-1]], device=DEVICE).unsqueeze(0)
-            seed = seed[:, -128:]
-        else:
-            seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
-    else:
-        seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
+    context_texts = get_recent_context(10)
+    seed_text = " ".join(context_texts[-1:])
+    tokens = tokenizer.tokenize(seed_text)
 
-    output = model(seed)
-    pred = torch.argmax(output, dim=-1).squeeze().item()
+    input_tensor = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
 
-    # Clamp prediction into known vocab range
-    if pred >= tokenizer.next_id:
-        pred = random.randint(0, tokenizer.next_id - 1)
+    output_tokens = []
+    max_tokens = 32
 
-    return tokenizer.detokenize([pred])
+    for _ in range(max_tokens):
+        output = model(input_tensor)
+        logits = output[:, -1, :].squeeze(0)
+
+        # Apply temperature (soft randomness)
+        temperature = 0.8
+        logits = logits / temperature
+
+        # Top-k sampling
+        k = 10
+        topk_logits, topk_indices = torch.topk(logits, k)
+        probs = torch.nn.functional.softmax(topk_logits, dim=-1)
+        next_token = topk_indices[torch.multinomial(probs, 1)].item()
+
+        output_tokens.append(next_token)
+
+        input_tensor = torch.cat([input_tensor, torch.tensor([[next_token]], device=DEVICE)], dim=1)
+
+        # Optional: stop if next_token maps to period, question mark, or exclamation
+        next_char = tokenizer.detokenize([next_token])
+        if any(p in next_char for p in [".", "?", "!"]):
+            break
+
+    text = tokenizer.detokenize(output_tokens)
+    return text
 
 
 def score_sentence(sentence: str) -> float:
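The rewritten generate_response() replaces a single argmax prediction with an autoregressive loop that samples from a temperature-scaled, top-k-truncated distribution. A minimal self-contained sketch of that decoding step (the function name and the random stand-in logits are illustrative, not part of the repo):

    import torch
    import torch.nn.functional as F

    def sample_next_token(logits: torch.Tensor, temperature: float = 0.8, k: int = 10) -> int:
        # Temperature below 1 sharpens the distribution; above 1 flattens it.
        scaled = logits / temperature
        # Keep only the k highest-scoring tokens and renormalize over them.
        topk_logits, topk_indices = torch.topk(scaled, k)
        probs = F.softmax(topk_logits, dim=-1)
        # Draw one token id from the truncated distribution.
        return topk_indices[torch.multinomial(probs, 1)].item()

    dummy_logits = torch.randn(1000)  # hypothetical vocab of 1000 ids
    print(sample_next_token(dummy_logits))

Compared with argmax decoding, sampling keeps the output varied from one call to the next while still excluding low-probability tokens.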
@@ -28,7 +28,12 @@ def expand_model_if_needed():
     # print(f"[Expand] Expanding model from {old_vocab_size} -> {current_vocab_size}")
 
     old_state = model.state_dict()
-    new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
+    new_model = TinyTransformer(
+        vocab_size=current_vocab_size,
+        embed_dim=model.token_embed.embedding_dim,
+        depth=len(model.blocks),
+        heads=model.blocks[0].attn.heads
+    ).to(DEVICE)
 
     with torch.no_grad():
         for name, param in new_model.named_parameters():
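The expanded model is now built with the old model's embed_dim, depth, and head count instead of constructor defaults, so the copied weights actually fit. Presumably only the vocab-sized pieces (token embedding and output head) change shape; a toy sketch of the overlap-copy idea, with hypothetical sizes and names:

    import torch
    import torch.nn as nn

    old_embed = nn.Embedding(100, 32)  # hypothetical old vocab of 100 ids
    new_embed = nn.Embedding(150, 32)  # expanded vocab of 150 ids

    with torch.no_grad():
        # Copy the learned rows; the 50 new rows keep their fresh initialization.
        n = min(old_embed.num_embeddings, new_embed.num_embeddings)
        new_embed.weight[:n].copy_(old_embed.weight[:n])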
@@ -1,4 +1,5 @@
 import random
+import asyncio
 from context.context import load_context
 from model.trainer import train_on_message
 from model.dynamic_expand import expand_model_if_needed
@@ -23,7 +23,7 @@ def train_on_message(text: str, source: str = "user"):
     expand_model_if_needed()
 
     now = time.time()
-    if now - _last_expansion_time < 5:  # If expansion happened within the last 5 seconds
+    if now - _last_expansion_time < 5:
         print("[Train] Skipping to stabilize after expansion.")
         return
 
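This guard skips training while the model is stabilizing right after an expansion. A sketch of the same check in isolation (the constant name is illustrative; the diff hard-codes 5 seconds):

    import time

    _last_expansion_time = 0.0
    COOLDOWN_SECONDS = 5  # matches the hard-coded 5 in the diff

    def in_expansion_cooldown() -> bool:
        # True while we are inside the cooldown window after an expansion.
        return time.time() - _last_expansion_time < COOLDOWN_SECONDS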
@@ -35,21 +35,32 @@ def train_on_message(text: str, source: str = "user"):
         model.train()
         context_texts = get_recent_context(10)
         augmented_text = " ".join(context_texts + [text])
 
         tokens = tokenizer.tokenize(augmented_text)
 
-        if len(tokens) < 2:
+        if not tokens or len(tokens) < 2:
             return
 
         max_token_id = model.head.out_features - 1
-        tokens = [min(t, max_token_id) for t in tokens]
 
-        if len(tokens) < 2:
+        # Clamp each token to be inside model's head size
+        clamped_tokens = []
+        for token in tokens:
+            if token > max_token_id:
+                clamped_tokens.append(max_token_id)
+            elif token < 0:
+                clamped_tokens.append(0)
+            else:
+                clamped_tokens.append(token)
+
+        # Clamp sequence length
+        clamped_tokens = clamped_tokens[:128]
+
+        if len(clamped_tokens) < 2:
             return
 
-        tokens = tokens[:128]
-        input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
-        target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
+        input_tensor = torch.tensor(clamped_tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
+        target_tensor = torch.tensor(clamped_tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)
 
         opt = get_optimizer()
 
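The new clamping loop is functionally a per-token clip into [0, max_token_id] followed by a length cap, and the two tensors form a standard next-token (teacher-forcing) pair: inputs are the sequence minus its last token, targets the sequence shifted left by one. A toy illustration with made-up token ids:

    import torch

    tokens = [5, 9, 2, 7]  # hypothetical token ids
    max_token_id = 8

    # Equivalent one-liner for the clamping loop in the diff:
    clamped = [min(max(t, 0), max_token_id) for t in tokens][:128]

    input_tensor = torch.tensor(clamped[:-1]).unsqueeze(0)   # [[5, 8, 2]]
    target_tensor = torch.tensor(clamped[1:]).unsqueeze(0)   # [[8, 2, 7]]
    print(input_tensor, target_tensor)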
@@ -63,5 +74,9 @@ def train_on_message(text: str, source: str = "user"):
         log_loss(loss.item())
         log_vocab_growth()
         add_to_context(text, source=source)
+
+    except Exception as e:
+        print(f"[Train] Exception during training: {e}")
+
     finally:
         expand_lock.release()
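Adding the except clause means a failed training step is logged instead of propagating, and the finally block guarantees expand_lock is released either way, so a crash cannot leave expansion blocked. A minimal sketch of the pattern, assuming expand_lock is a threading.Lock acquired earlier in train_on_message (the diff only shows the release):

    import threading

    expand_lock = threading.Lock()  # stand-in for the repo's lock

    def train_step_safely():
        expand_lock.acquire()
        try:
            pass  # forward/backward pass would go here
        except Exception as e:
            print(f"[Train] Exception during training: {e}")
        finally:
            expand_lock.release()  # always runs, even after an exception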