# main.py
import os

import torch

from config import *
from tokenizers.word_tokenizer import WordTokenizer
from models.gpt import GPT
from training.trainer import TextDataset, train


def load_texts():
    """Assemble the raw training corpus from books, OpenWebText, and chat logs."""
    text = ""

    # --- Books ---
    book_dir = os.path.join("data", "books")
    os.makedirs(book_dir, exist_ok=True)
    print(f"[INFO] Scanning books from: {book_dir}")
    for file in os.listdir(book_dir):
        path = os.path.join(book_dir, file)
        if file.endswith(".txt"):
            print(f" 📚 Loading {file}")
            try:
                with open(path, encoding="utf-8") as f:
                    text += f.read() + "\n"
            except Exception as e:
                print(f" ❌ Failed to read {file}: {e}")

    # --- OpenWebText ---
    owt_path = os.path.join("data/openwebtext", "owt_20000.jsonl")
    print(f"[INFO] Scanning OpenWebText: {owt_path}")
    if os.path.exists(owt_path):
        with open(owt_path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                if i % 1000 == 0:
                    print(f" ⏳ {i} lines read...")
                try:
                    # Each JSONL line is appended as-is (no JSON parsing).
                    text += line.strip() + "\n"
                except Exception as e:
                    print(f" ❌ Line {i} decode error: {e}")
    else:
        print(f"[WARN] OpenWebText file not found: {owt_path}")

    # --- Chat logs ---
    if os.path.exists("catlin_chatlog.txt"):
        print("[INFO] Appending chat log...")
        with open("catlin_chatlog.txt", encoding="utf-8") as f:
            text += "\n" + f.read()

    print(f"[INFO] Raw text loaded: {len(text)} characters")
    # Cap the raw corpus at roughly MAX_TOKENS worth of text (rough chars-per-token heuristic).
    return text[:MAX_TOKENS * 10]


def main():
    print("[INFO] Starting main()")
    raw_text = load_texts()
    print(f"[INFO] Loaded text: {len(raw_text)} characters")

    # Build and persist the word-level tokenizer.
    tokenizer = WordTokenizer(VOCAB_SIZE)
    tokenizer.fit(raw_text)
    tokenizer.save("catlin_tokenizer.pkl")
    print("[INFO] Tokenizer built and saved")

    tokens = tokenizer.encode(raw_text)
    print(f"[INFO] Total tokens: {len(tokens)}")

    dataset = TextDataset(tokens, CONTEXT_SIZE)
    if len(dataset) == 0:
        print("❌ ERROR: Dataset is empty. Aborting.")
        return

    model = GPT(VOCAB_SIZE, CONTEXT_SIZE, EMBED_DIM, NUM_HEADS, NUM_LAYERS)
    print("[INFO] Model initialized")

    # Fall back to CPU when CUDA is unavailable.
    train(model, dataset, DEVICE if torch.cuda.is_available() else "cpu",
          LEARNING_RATE, BATCH_SIZE, epochs=10)
    print("[INFO] Training complete")

    torch.save(model.state_dict(), "catlin_model.pt")
    print("[INFO] Model saved to catlin_model.pt")


if __name__ == "__main__":
    main()