# talk_to_catlin.py
import os
import sys
from datetime import datetime

import torch

from config import *  # VOCAB_SIZE, CONTEXT_SIZE, EMBED_DIM, NUM_HEADS, NUM_LAYERS
from tokenizers.word_tokenizer import WordTokenizer
from models.gpt import GPT

CHAT_LOG_PATH = "catlin_chatlog.txt"


def load_model(device):
    """Load the trained GPT weights and put the model in eval mode."""
    model = GPT(VOCAB_SIZE, CONTEXT_SIZE, EMBED_DIM, NUM_HEADS, NUM_LAYERS)
    model.load_state_dict(torch.load("catlin_model.pt", map_location=device))
    model.to(device)
    model.eval()
    return model


def load_tokenizer(path="catlin_tokenizer.pkl"):
    if not os.path.exists(path):
        print(f"[ERROR] Tokenizer file '{path}' not found. Please train first.")
        sys.exit(1)
    return WordTokenizer.load(path)


def top_p_sampling(logits, p=0.9):
    """Nucleus sampling: draw from the smallest set of top tokens whose
    cumulative probability reaches p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(
        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
    )
    # Keep every token whose *preceding* cumulative mass is still <= p, so the
    # token that crosses the threshold is included; always keep the top token.
    keep = cumulative_probs <= p
    keep[1:] = keep[:-1].clone()
    keep[0] = True
    filtered_logits = sorted_logits[keep]
    filtered_indices = sorted_indices[keep]
    probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
    next_token = filtered_indices[torch.multinomial(probs, 1)]
    return next_token.item()


def generate_response(model, tokenizer, prompt, device, max_tokens=50,
                      temperature=1.0, p=0.9):
    tokens = tokenizer.encode(prompt)[-CONTEXT_SIZE:]
    input_ids = torch.tensor([tokens], dtype=torch.long, device=device)
    # Track sampled ids separately: slicing the final input_ids by len(tokens)
    # returns the wrong tokens once the context window starts sliding.
    generated = []
    for _ in range(max_tokens):
        with torch.no_grad():
            logits = model(input_ids)
        next_token_logits = logits[0, -1, :] / temperature
        next_token = top_p_sampling(next_token_logits, p=p)
        generated.append(next_token)
        input_ids = torch.cat(
            [input_ids, torch.tensor([[next_token]], device=device)], dim=1
        )
        # Slide the context window once it exceeds the model's capacity.
        if input_ids.shape[1] > CONTEXT_SIZE:
            input_ids = input_ids[:, -CONTEXT_SIZE:]
        # Stop at the first sentence-ending token.
        word = tokenizer.id_to_word.get(next_token, "")
        if word in {".", "!", "?"}:
            break
    return tokenizer.decode(generated)


def log_chat(user_input, catlin_response):
    with open(CHAT_LOG_PATH, "a", encoding="utf-8") as f:
        f.write(f"[{datetime.now().isoformat()}] You: {user_input}\n")
        f.write(f"[{datetime.now().isoformat()}] Catlin: {catlin_response}\n\n")


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(device)
    tokenizer = load_tokenizer()
    print("🤖 Catlin is online. Type 'exit' to end the conversation.\n")
    history = []
    while True:
        user_input = input("You: ")
        if user_input.strip().lower() == "exit":
            break
        history.append(f"You: {user_input}")
        # Build the memory window from the last 10 lines of dialogue.
        memory_text = " ".join(history[-10:])
        catlin_response = generate_response(model, tokenizer, memory_text, device)
        print(f"Catlin: {catlin_response}\n")
        history.append(f"Catlin: {catlin_response}")
        log_chat(user_input, catlin_response)


if __name__ == "__main__":
    main()
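
# ---------------------------------------------------------------------------
# Hedged example: a quick offline sanity check of top_p_sampling, kept in a
# comment so it never runs as part of the chat script. The logits below are
# made-up illustration values, not output from the Catlin model; paste into a
# REPL or a separate script in the repo root (so `talk_to_catlin` and the
# modules it imports resolve):
#
#   import torch
#   from talk_to_catlin import top_p_sampling
#
#   logits = torch.tensor([3.0, 1.5, 0.2, -2.0])  # toy distribution
#   # softmax ~= [0.78, 0.17, 0.05, 0.005]; with p=0.9 the nucleus is the
#   # first two tokens, so only ids 0 and 1 should ever be sampled.
#   counts = {}
#   for _ in range(1000):
#       tok = top_p_sampling(logits, p=0.9)
#       counts[tok] = counts.get(tok, 0) + 1
#   print(counts)  # roughly {0: ~820, 1: ~180}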