# main.py
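"""Entry point: load raw training text, fit a word-level tokenizer, and train the GPT model."""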

import os
import json

import torch

from config import *
from tokenizers.word_tokenizer import WordTokenizer
from models.gpt import GPT
from training.trainer import TextDataset, train


def load_texts():
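    """Concatenate raw text from local book files and OpenWebText JSONL shards."""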
    text = ""

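    # Plain-text books: read every .txt file under data/books in full.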
print("[INFO] Scanning books from: data\\books")
|
|
books_dir = "data/books"
|
|
if os.path.exists(books_dir):
|
|
for file in os.listdir(books_dir):
|
|
if file.endswith(".txt"):
|
|
print(f" 📚 Loading {file}")
|
|
with open(os.path.join(books_dir, file), encoding="utf-8") as f:
|
|
text += f.read()
|
|
|
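    # OpenWebText shards: read JSONL files line by line, skipping malformed rows,
    # until roughly MAX_TOKENS * 10 characters have been collected.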
    owt_dir = "data/openwebtext"
    if os.path.exists(owt_dir):
        print(f"[INFO] Scanning OpenWebText: {owt_dir}")
        for fname in os.listdir(owt_dir):
            if fname.endswith(".jsonl"):
                path = os.path.join(owt_dir, fname)
                print(f" 📄 Loading {fname}")
                try:
                    with open(path, encoding="utf-8") as f:
                        for line in f:
                            try:
                                obj = json.loads(line)
                                text += obj.get("text", "")
                            except json.JSONDecodeError:
                                continue
                            if len(text) >= MAX_TOKENS * 10:
                                break
                except Exception as e:
                    print(f" ⚠️ Error reading {fname}: {e}")

    print("[INFO] Raw text loaded:", len(text), "characters")

    # Truncate to MAX_TOKENS * 10 characters (rough character budget for MAX_TOKENS tokens)
    clipped = text[:MAX_TOKENS * 10]
    print("[INFO] Clipped text:", len(clipped), "characters")

    return clipped


def main():
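    """Build the tokenizer and dataset, train the GPT model, and save the artifacts."""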
    print("[INFO] Starting main()")
    raw_text = load_texts()
    print(f"[INFO] Loaded text: {len(raw_text)} characters")

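    # Fit a word-level tokenizer on the raw corpus and persist it for later reuse.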
    tokenizer = WordTokenizer(VOCAB_SIZE)
    tokenizer.fit(raw_text)
    tokenizer.save("catlin_tokenizer.pkl")
    print("[INFO] Tokenizer built and saved")

    tokens = tokenizer.encode(raw_text)
    print(f"[INFO] Total tokens: {len(tokens)}")

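    # Build the training dataset over the token stream; abort if it is empty.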
    dataset = TextDataset(tokens, CONTEXT_SIZE)
    if len(dataset) == 0:
        print("❌ ERROR: Dataset is empty. Aborting.")
        return

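    # Initialize the GPT model from the hyperparameters in the config module.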
    model = GPT(VOCAB_SIZE, CONTEXT_SIZE, EMBED_DIM, NUM_HEADS, NUM_LAYERS)
    print("[INFO] Model initialized")

    device = DEVICE if torch.cuda.is_available() else "cpu"
    train(model, dataset, device, LEARNING_RATE, BATCH_SIZE, epochs=10)
    print("[INFO] Training complete")

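    # Save the trained model weights to disk.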
    torch.save(model.state_dict(), "catlin_model.pt")
    print("[INFO] Model saved to catlin_model.pt")


if __name__ == "__main__":
    main()