# talk_to_catlin.py
import os
import sys
from datetime import datetime

import torch

from config import CONTEXT_SIZE, EMBED_DIM, NUM_HEADS, NUM_LAYERS, VOCAB_SIZE
from models.gpt import GPT
from tokenizers.word_tokenizer import WordTokenizer

CHAT_LOG_PATH = "catlin_chatlog.txt"
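
# For reference: this script assumes config.py defines at least the following
# hyperparameters. The values shown are illustrative placeholders, not the
# model's actual trained settings.
#
#     VOCAB_SIZE   = 10_000   # word-level vocabulary size
#     CONTEXT_SIZE = 128      # max sequence length the model attends over
#     EMBED_DIM    = 256      # token embedding width
#     NUM_HEADS    = 8        # attention heads per transformer block
#     NUM_LAYERS   = 6        # number of transformer blocks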

def load_model(device):
    """Build the GPT and restore the trained weights from catlin_model.pt."""
    model = GPT(VOCAB_SIZE, CONTEXT_SIZE, EMBED_DIM, NUM_HEADS, NUM_LAYERS)
    model.load_state_dict(torch.load("catlin_model.pt", map_location=device))
    model.to(device)
    model.eval()  # inference mode: disables dropout and similar training-only layers
    return model

def load_tokenizer(path="catlin_tokenizer.pkl"):
    """Load the trained word-level tokenizer, or exit with a clear error."""
    if not os.path.exists(path):
        print(f"[ERROR] Tokenizer file '{path}' not found. Please train first.")
        sys.exit(1)
    return WordTokenizer.load(path)
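
# The WordTokenizer interface assumed by this script (inferred from the calls
# below; see tokenizers/word_tokenizer.py for the actual definition):
#
#     WordTokenizer.load(path) -> WordTokenizer   # unpickle a trained tokenizer
#     tokenizer.encode(text)   -> list[int]       # words -> token ids
#     tokenizer.decode(ids)    -> str             # token ids -> text
#     tokenizer.id_to_word     -> dict[int, str]  # id-to-word lookup table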

def top_p_sampling(logits, p=0.9):
    """Nucleus sampling: draw from the smallest set of tokens whose
    cumulative probability exceeds p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    # Keep a token if the cumulative mass *before* it is still <= p.
    # Shifting the mask right by one and forcing the top token in
    # guarantees at least one candidate always survives.
    sorted_indices_to_keep = cumulative_probs <= p
    sorted_indices_to_keep[1:] = sorted_indices_to_keep[:-1].clone()
    sorted_indices_to_keep[0] = True
    filtered_logits = sorted_logits[sorted_indices_to_keep]
    filtered_indices = sorted_indices[sorted_indices_to_keep]
    # Renormalise over the surviving tokens and sample one.
    probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
    next_token = filtered_indices[torch.multinomial(probs, 1)]
    return next_token.item()
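
# Illustrative sanity check (not part of the chat loop): with a sharply peaked
# distribution, nearly all of the probability mass sits in the top token, so
# p=0.9 keeps only that token and sampling becomes effectively greedy:
#
#     demo_logits = torch.tensor([5.0, 1.0, 0.5, 0.1])
#     top_p_sampling(demo_logits, p=0.9)  # always returns index 0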

def generate_response(model, tokenizer, prompt, device, max_tokens=50, temperature=1.0, top_p=0.9):
    """Autoregressively generate a reply, stopping at sentence-ending punctuation."""
    tokens = tokenizer.encode(prompt)[-CONTEXT_SIZE:]
    input_ids = torch.tensor([tokens], dtype=torch.long).to(device)
    generated = []  # track new tokens separately; input_ids may be truncated below
    for _ in range(max_tokens):
        with torch.no_grad():
            logits = model(input_ids)
        next_token_logits = logits[0, -1, :] / temperature
        next_token = top_p_sampling(next_token_logits, p=top_p)
        generated.append(next_token)
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)
        # Slide the context window so the model never sees more than CONTEXT_SIZE tokens.
        if input_ids.shape[1] > CONTEXT_SIZE:
            input_ids = input_ids[:, -CONTEXT_SIZE:]
        word = tokenizer.id_to_word.get(next_token, "")
        if word in {".", "!", "?"}:
            break
    return tokenizer.decode(generated)

def log_chat(user_input, catlin_response):
    """Append a timestamped exchange to the chat log."""
    timestamp = datetime.now().isoformat()
    with open(CHAT_LOG_PATH, "a", encoding="utf-8") as f:
        f.write(f"[{timestamp}] You: {user_input}\n")
        f.write(f"[{timestamp}] Catlin: {catlin_response}\n\n")

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(device)
    tokenizer = load_tokenizer()
    print("🤖 Catlin is online. Type 'exit' to end the conversation.\n")
    history = []
    while True:
        user_input = input("You: ")
        if user_input.strip().lower() == "exit":
            break
        history.append(f"You: {user_input}")
        # Condition the model on a rolling memory window of the last 10 history lines.
        memory_text = " ".join(history[-10:])
        catlin_response = generate_response(model, tokenizer, memory_text, device)
        print(f"Catlin: {catlin_response}\n")
        history.append(f"Catlin: {catlin_response}")
        log_chat(user_input, catlin_response)

if __name__ == "__main__":
    main()