diff --git a/model/brain.py b/model/brain.py index a6c3b48..74d0c32 100644 --- a/model/brain.py +++ b/model/brain.py @@ -10,33 +10,27 @@ from context.context import get_recent_context recent_dreams = [] -@torch.no_grad() +@torch.inference_mode() def generate_response(): model.eval() - seed = torch.randint(0, model.head.out_features, (1, 1), device=DEVICE) - input_ids = seed + + # Start from an empty tensor: she speaks purely from herself + input_ids = torch.zeros((1, 1), dtype=torch.long, device=DEVICE) + + output_tokens = [] + max_length = 50 - for _ in range(50): # Max 50 tokens (short sentences) + for _ in range(max_length): output = model(input_ids) - next_token_logits = output[:, -1, :] / 0.8 # temperature 0.8 + next_token_logits = output[:, -1, :] + next_token = torch.argmax(next_token_logits, dim=-1) - # Top-K Sampling - top_k = 40 - values, indices = torch.topk(next_token_logits, k=top_k) - probs = F.softmax(values, dim=-1) - sampled_idx = torch.multinomial(probs, num_samples=1) - - next_token = indices.gather(-1, sampled_idx) + # Stop if the model predicts padding or unknown token + if next_token.item() in [tokenizer.token_to_id("<pad>"), tokenizer.token_to_id("<unk>")]: + break output_tokens.append(next_token.item()) - - input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=1) - - # Break if punctuation (end of sentence) - word = tokenizer.detokenize(next_token.item()) - if word in [".", "!", "?"]: - break + input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1) return tokenizer.detokenize(output_tokens) diff --git a/reader/reader.py b/reader/reader.py index 181c850..9783ea0 100644 --- a/reader/reader.py +++ b/reader/reader.py @@ -7,7 +7,8 @@ import json BOOK_DIR = "data/books" PROGRESS_FILE = "data/memory/book_progress.json" -READ_DELAY = 10 # seconds between lines +READ_DELAY = 0.2 # seconds between lines +PARAGRAPH_MIN_LENGTH = 20 def get_books(): @@ -29,8 +30,6 @@ def save_progress(prog): async def read_books_forever(): books = get_books() 
progress = load_progress() - buffered_lines = [] - while True: for book in books: path = os.path.join(BOOK_DIR, book) @@ -41,26 +40,26 @@ async def read_books_forever(): lines = f.readlines() idx = progress.get(book, 0) + paragraph = "" + while idx < len(lines): line = lines[idx].strip() idx += 1 + + if not line: + if len(paragraph) > PARAGRAPH_MIN_LENGTH: + train_on_message(paragraph.strip(), source="book") + paragraph = "" + await asyncio.sleep(READ_DELAY) + set_next_action(READ_DELAY, "Reading") + else: + paragraph += " " + line + progress[book] = idx save_progress(progress) - if is_valid_line(line): - buffered_lines.append(line) - - # If we have enough lines buffered, combine and train - if len(buffered_lines) >= 3: - combined_text = " ".join(buffered_lines) - train_on_message(combined_text, source="book") - buffered_lines.clear() - - set_next_action(READ_DELAY, "Reading") + # train last paragraph if any + if paragraph and len(paragraph) > PARAGRAPH_MIN_LENGTH: + train_on_message(paragraph.strip(), source="book") await asyncio.sleep(READ_DELAY) - - # End of a book: train whatever lines are left buffered - if buffered_lines: - combined_text = " ".join(buffered_lines) - train_on_message(combined_text, source="book") - buffered_lines.clear() + set_next_action(READ_DELAY, "Reading")