Changing how she generates replies.

Dani 2025-04-27 14:09:40 -04:00
parent bae6bff5ce
commit 12df801a44
2 changed files with 31 additions and 38 deletions

@@ -10,33 +10,27 @@ from context.context import get_recent_context
 recent_dreams = []

-@torch.no_grad()
+@torch.inference_mode()
 def generate_response():
     model.eval()
-    seed = torch.randint(0, model.head.out_features, (1, 1), device=DEVICE)
-    input_ids = seed
+    # Start from an empty tensor: she speaks purely from herself
+    input_ids = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
     output_tokens = []
+    max_length = 50

-    for _ in range(50):  # Max 50 tokens (short sentences)
+    for _ in range(max_length):
         output = model(input_ids)
-        next_token_logits = output[:, -1, :] / 0.8  # temperature 0.8
+        next_token_logits = output[:, -1, :]
+        next_token = torch.argmax(next_token_logits, dim=-1)

-        # Top-K Sampling
-        top_k = 40
-        values, indices = torch.topk(next_token_logits, k=top_k)
-        probs = F.softmax(values, dim=-1)
-        sampled_idx = torch.multinomial(probs, num_samples=1)
-        next_token = indices.gather(-1, sampled_idx)
+        # Stop if the model predicts padding or unknown token
+        if next_token.item() in [tokenizer.token_to_id("<pad>"), tokenizer.token_to_id("<unk>")]:
+            break

         output_tokens.append(next_token.item())
-        input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=1)
-
-        # Break if punctuation (end of sentence)
-        word = tokenizer.detokenize(next_token.item())
-        if word in [".", "!", "?"]:
-            break
+        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

     return tokenizer.detokenize(output_tokens)
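
In effect, the first file trades stochastic decoding for deterministic decoding: the old loop sampled from a temperature-scaled top-k distribution and stopped at sentence-ending punctuation, while the new loop takes the argmax at every step and stops when the model emits <pad> or <unk>. (Swapping @torch.no_grad() for @torch.inference_mode() is a further inference-time optimization available in recent PyTorch.) A minimal, self-contained sketch of the two strategies, with random logits standing in for the real model output and all names here illustrative rather than the project's:

import torch

torch.manual_seed(0)
logits = torch.randn(1, 100)  # stand-in for model output: (batch=1, vocab=100)

# Old strategy: temperature + top-k sampling (stochastic replies).
values, indices = torch.topk(logits / 0.8, k=40)       # keep the 40 best logits
probs = torch.softmax(values, dim=-1)                  # renormalize over the top-k
sampled_idx = torch.multinomial(probs, num_samples=1)  # draw one of them
sampled_token = indices.gather(-1, sampled_idx)        # map back to a vocab id

# New strategy: greedy argmax (deterministic; the same context always
# yields the same reply, at some risk of repetitive output).
greedy_token = torch.argmax(logits, dim=-1)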

@@ -7,7 +7,8 @@ import json
 BOOK_DIR = "data/books"
 PROGRESS_FILE = "data/memory/book_progress.json"
-READ_DELAY = 10  # seconds between lines
+READ_DELAY = 0.2  # seconds between lines
+PARAGRAPH_MIN_LENGTH = 20

 def get_books():
@@ -29,8 +30,6 @@ def save_progress(prog):
 async def read_books_forever():
     books = get_books()
     progress = load_progress()
-    buffered_lines = []
-
     while True:
         for book in books:
             path = os.path.join(BOOK_DIR, book)
@@ -41,26 +40,26 @@ async def read_books_forever():
             lines = f.readlines()
             idx = progress.get(book, 0)
+            paragraph = ""

             while idx < len(lines):
                 line = lines[idx].strip()
                 idx += 1

+                if not line:
+                    if len(paragraph) > PARAGRAPH_MIN_LENGTH:
+                        train_on_message(paragraph.strip(), source="book")
+                        paragraph = ""
+                        await asyncio.sleep(READ_DELAY)
+                        set_next_action(READ_DELAY, "Reading")
+                else:
+                    paragraph += " " + line
+
                 progress[book] = idx
                 save_progress(progress)

-                if is_valid_line(line):
-                    buffered_lines.append(line)
-
-                # If we have enough lines buffered, combine and train
-                if len(buffered_lines) >= 3:
-                    combined_text = " ".join(buffered_lines)
-                    train_on_message(combined_text, source="book")
-                    buffered_lines.clear()
-                    set_next_action(READ_DELAY, "Reading")
-
-                await asyncio.sleep(READ_DELAY)
-
-            # End of a book: train whatever lines are left buffered
-            if buffered_lines:
-                combined_text = " ".join(buffered_lines)
-                train_on_message(combined_text, source="book")
-                buffered_lines.clear()
+            # train last paragraph if any
+            if paragraph and len(paragraph) > PARAGRAPH_MIN_LENGTH:
+                train_on_message(paragraph.strip(), source="book")
+
+            await asyncio.sleep(READ_DELAY)
+            set_next_action(READ_DELAY, "Reading")
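
Net effect in the reader: instead of buffering every 3 valid lines and training on the joined text, the loop now accumulates lines into blank-line-delimited paragraphs and trains only on those longer than PARAGRAPH_MIN_LENGTH characters, sleeping per paragraph rather than per line. A rough standalone sketch of that chunking rule; the sample text and the split_paragraphs helper are made up for illustration:

PARAGRAPH_MIN_LENGTH = 20

def split_paragraphs(lines):
    """Yield blank-line-delimited paragraphs longer than the minimum."""
    paragraph = ""
    for raw in lines:
        line = raw.strip()
        if not line:  # a blank line closes the current paragraph
            if len(paragraph) > PARAGRAPH_MIN_LENGTH:
                yield paragraph.strip()
            paragraph = ""
        else:
            paragraph += " " + line
    # trailing paragraph at end of file, mirroring the post-loop check above
    if paragraph and len(paragraph) > PARAGRAPH_MIN_LENGTH:
        yield paragraph.strip()

lines = ["Call me Ishmael.", "Some years ago.", "", "Ok.", "", "It was a dark and stormy night."]
print(list(split_paragraphs(lines)))  # the lone "Ok." is dropped by the length check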