# talk_to_catlin.py
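"""Interactive chat client for Catlin, a small word-level GPT.

Loads trained weights (catlin_model.pt) and a WordTokenizer (catlin_tokenizer.pkl),
runs a console chat loop that replies using temperature-scaled top-p (nucleus)
sampling, and appends each exchange to catlin_chatlog.txt. The wildcard import
from config is expected to provide VOCAB_SIZE, CONTEXT_SIZE, EMBED_DIM,
NUM_HEADS and NUM_LAYERS.
"""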

import os
from datetime import datetime

import torch

from config import *
from models.gpt import GPT
from tokenizers.word_tokenizer import WordTokenizer

CHAT_LOG_PATH = "catlin_chatlog.txt"


def load_model(device):
    model = GPT(VOCAB_SIZE, CONTEXT_SIZE, EMBED_DIM, NUM_HEADS, NUM_LAYERS)
    model.load_state_dict(torch.load("catlin_model.pt", map_location=device))
    model.to(device)
    model.eval()
    return model


def load_tokenizer(path="catlin_tokenizer.pkl"):
    if not os.path.exists(path):
        print(f"[ERROR] Tokenizer file '{path}' not found. Please train first.")
        raise SystemExit(1)
    return WordTokenizer.load(path)


def top_p_sampling(logits, p=0.9):
    # Nucleus sampling: keep the smallest set of top tokens whose cumulative
    # probability reaches p, then sample from that renormalized set.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_keep = cumulative_probs <= p
    # Shift the mask right so the token that crosses the threshold is included,
    # and always keep the most probable token.
    sorted_indices_to_keep[1:] = sorted_indices_to_keep[:-1].clone()
    sorted_indices_to_keep[0] = True

    filtered_logits = sorted_logits[sorted_indices_to_keep]
    filtered_indices = sorted_indices[sorted_indices_to_keep]

    probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
    next_token = filtered_indices[torch.multinomial(probs, 1)]
    return next_token.item()


def generate_response(model, tokenizer, prompt, device, max_tokens=50, top_p=0.9, temperature=1.0):
    tokens = tokenizer.encode(prompt)[-CONTEXT_SIZE:]
    input_ids = torch.tensor([tokens], dtype=torch.long).to(device)
    generated = []  # collect generated ids so context truncation cannot corrupt the reply

    for _ in range(max_tokens):
        with torch.no_grad():
            logits = model(input_ids)
        next_token_logits = logits[0, -1, :] / temperature
        next_token = top_p_sampling(next_token_logits, p=top_p)
        generated.append(next_token)

        # Append the new token and keep only the last CONTEXT_SIZE tokens as model input.
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device=device)], dim=1)
        if input_ids.shape[1] > CONTEXT_SIZE:
            input_ids = input_ids[:, -CONTEXT_SIZE:]

        # Stop at the end of a sentence.
        word = tokenizer.id_to_word.get(next_token, "")
        if word in {".", "!", "?"}:
            break

    return tokenizer.decode(generated)


def log_chat(user_input, catlin_response):
    with open(CHAT_LOG_PATH, "a", encoding="utf-8") as f:
        f.write(f"[{datetime.now().isoformat()}] You: {user_input}\n")
        f.write(f"[{datetime.now().isoformat()}] Catlin: {catlin_response}\n\n")


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(device)
    tokenizer = load_tokenizer()

    print("🤖 Catlin is online. Type 'exit' to end the conversation.\n")
    history = []

    while True:
        user_input = input("You: ")
        if user_input.strip().lower() == "exit":
            break

        history.append(f"You: {user_input}")
        # Build the memory window from the last 10 lines of the conversation.
        memory_text = " ".join(history[-10:])
        catlin_response = generate_response(model, tokenizer, memory_text, device)

        print(f"Catlin: {catlin_response}\n")
        history.append(f"Catlin: {catlin_response}")
        log_chat(user_input, catlin_response)


if __name__ == "__main__":
    main()
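

# A minimal, non-interactive smoke test (hypothetical, not part of the original
# script) could exercise the same functions without the chat loop:
#
#     device = torch.device("cpu")
#     model = load_model(device)
#     tokenizer = load_tokenizer()
#     print(generate_response(model, tokenizer, "You: hello", device, max_tokens=20))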