Changing how she generates replies.

parent bae6bff5ce
commit 12df801a44
@@ -10,33 +10,27 @@ from context.context import get_recent_context
 recent_dreams = []


-@torch.no_grad()
+@torch.inference_mode()
 def generate_response():
     model.eval()
-    seed = torch.randint(0, model.head.out_features, (1, 1), device=DEVICE)
-    input_ids = seed
+    # Start from an empty tensor: she speaks purely from herself
+    input_ids = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)

     output_tokens = []
+    max_length = 50

-    for _ in range(50):  # Max 50 tokens (short sentences)
+    for _ in range(max_length):
         output = model(input_ids)
-        next_token_logits = output[:, -1, :] / 0.8  # temperature 0.8
+        next_token_logits = output[:, -1, :]
+        next_token = torch.argmax(next_token_logits, dim=-1)

-        # Top-K Sampling
-        top_k = 40
-        values, indices = torch.topk(next_token_logits, k=top_k)
-        probs = F.softmax(values, dim=-1)
-        sampled_idx = torch.multinomial(probs, num_samples=1)
-        next_token = indices.gather(-1, sampled_idx)
+        # Stop if the model predicts padding or unknown token
+        if next_token.item() in [tokenizer.token_to_id("<pad>"), tokenizer.token_to_id("<unk>")]:
+            break

         output_tokens.append(next_token.item())
-        input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=1)
-
-        # Break if punctuation (end of sentence)
-        word = tokenizer.detokenize(next_token.item())
-        if word in [".", "!", "?"]:
-            break
+        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

     return tokenizer.detokenize(output_tokens)
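The net effect of this hunk: temperature-scaled top-k sampling (temperature 0.8, k=40) becomes greedy argmax decoding, the random seed token becomes a zero tensor, and generation now stops on <pad>/<unk> instead of sentence-ending punctuation. A minimal runnable sketch of the new loop, with a hypothetical DummyModel and hard-coded pad/unk ids standing in for the project's model and tokenizer:

import torch

VOCAB = 16
PAD, UNK = 0, 1  # assumed ids; the real code looks them up with tokenizer.token_to_id

class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(VOCAB, 8)
        self.head = torch.nn.Linear(8, VOCAB)

    def forward(self, ids):
        # ids: (1, T) int64 -> logits: (1, T, VOCAB)
        return self.head(self.embed(ids))

@torch.inference_mode()
def greedy_decode(model, max_length=50):
    input_ids = torch.zeros((1, 1), dtype=torch.long)  # the "empty" seed from the diff
    output_tokens = []
    for _ in range(max_length):
        next_token_logits = model(input_ids)[:, -1, :]        # last-position logits, (1, VOCAB)
        next_token = torch.argmax(next_token_logits, dim=-1)  # greedy pick, shape (1,)
        if next_token.item() in (PAD, UNK):                   # stop on pad/unk
            break
        output_tokens.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)  # (1,) -> (1, 1)
    return output_tokens

print(greedy_decode(DummyModel()))

One consequence worth noting: argmax is deterministic, so for a fixed checkpoint and the same zero seed this loop always produces the same reply, where the old top-k sampler varied between calls.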
@@ -7,7 +7,8 @@ import json
 BOOK_DIR = "data/books"
 PROGRESS_FILE = "data/memory/book_progress.json"
-READ_DELAY = 10  # seconds between lines
+READ_DELAY = 0.2  # seconds between lines
+PARAGRAPH_MIN_LENGTH = 20


 def get_books():
@@ -29,8 +30,6 @@ def save_progress(prog):
 async def read_books_forever():
     books = get_books()
     progress = load_progress()
-    buffered_lines = []

     while True:
         for book in books:
             path = os.path.join(BOOK_DIR, book)
@@ -41,26 +40,26 @@ async def read_books_forever():
                 lines = f.readlines()

             idx = progress.get(book, 0)
+            paragraph = ""

             while idx < len(lines):
                 line = lines[idx].strip()
                 idx += 1

+                if not line:
+                    if len(paragraph) > PARAGRAPH_MIN_LENGTH:
+                        train_on_message(paragraph.strip(), source="book")
+                    paragraph = ""
+                    await asyncio.sleep(READ_DELAY)
+                    set_next_action(READ_DELAY, "Reading")
+                else:
+                    paragraph += " " + line

                 progress[book] = idx
                 save_progress(progress)

-                if is_valid_line(line):
-                    buffered_lines.append(line)
-
-                    # If we have enough lines buffered, combine and train
-                    if len(buffered_lines) >= 3:
-                        combined_text = " ".join(buffered_lines)
-                        train_on_message(combined_text, source="book")
-                        buffered_lines.clear()
-
-                set_next_action(READ_DELAY, "Reading")
-                await asyncio.sleep(READ_DELAY)
-
-            # End of a book: train whatever lines are left buffered
-            if buffered_lines:
-                combined_text = " ".join(buffered_lines)
-                train_on_message(combined_text, source="book")
-                buffered_lines.clear()
+            # train last paragraph if any
+            if paragraph and len(paragraph) > PARAGRAPH_MIN_LENGTH:
+                train_on_message(paragraph.strip(), source="book")
+
+            await asyncio.sleep(READ_DELAY)
+            set_next_action(READ_DELAY, "Reading")
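This hunk replaces the old buffer-three-valid-lines strategy with paragraph accumulation: lines are joined until a blank line, and a paragraph is trained on only if it exceeds PARAGRAPH_MIN_LENGTH characters, with a final flush when the book ends. The same rule isolated as a runnable sketch (the helper name paragraphs is illustrative, not from the repo):

PARAGRAPH_MIN_LENGTH = 20

def paragraphs(lines):
    paragraph = ""
    for raw in lines:
        line = raw.strip()
        if not line:
            # a blank line closes the paragraph; keep it only if long enough
            if len(paragraph) > PARAGRAPH_MIN_LENGTH:
                yield paragraph.strip()
            paragraph = ""
        else:
            paragraph += " " + line
    # end of input: flush the leftover paragraph, as the diff does per book
    if paragraph and len(paragraph) > PARAGRAPH_MIN_LENGTH:
        yield paragraph.strip()

sample = ["It was a dark and", "stormy night.", "", "Ok."]
print(list(paragraphs(sample)))  # ['It was a dark and stormy night.']; 'Ok.' is too short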