First good level of progress
This commit is contained in:
84
main.py
Normal file
84
main.py
Normal file
@ -0,0 +1,84 @@
|
||||
# main.py
|
||||
|
||||
import os
|
||||
import torch
|
||||
from config import *
|
||||
from tokenizers.word_tokenizer import WordTokenizer
|
||||
from models.gpt import GPT
|
||||
from training.trainer import TextDataset, train
|
||||
|
||||
|
||||
def load_texts():
    """Assemble the raw training corpus from three sources and return it as one string.

    Sources, concatenated in order:
      1. every ``*.txt`` file under ``data/books/`` (directory is created if absent),
      2. the OpenWebText dump ``data/openwebtext/owt_20000.jsonl`` (one record per line,
         appended verbatim — NOTE(review): lines are not ``json.loads``-parsed; confirm
         whether raw JSONL text is the intended training input),
      3. ``catlin_chatlog.txt`` in the working directory, if present.

    Returns:
        str: the combined text, truncated to ``MAX_TOKENS * 10`` characters
        (``MAX_TOKENS`` comes from ``from config import *``).
    """
    # Accumulate pieces in a list and join once at the end: repeated `text += ...`
    # across thousands of lines is quadratic in the corpus size.
    parts = []

    # --- Books ---
    book_dir = os.path.join("data", "books")
    os.makedirs(book_dir, exist_ok=True)

    print(f"[INFO] Scanning books from: {book_dir}")
    for file in os.listdir(book_dir):
        path = os.path.join(book_dir, file)
        if file.endswith(".txt"):
            print(f"  📚 Loading {file}")
            try:
                with open(path, encoding="utf-8") as f:
                    parts.append(f.read() + "\n")
            except Exception as e:
                # Best-effort: skip unreadable/mis-encoded books rather than abort the run.
                print(f"  ❌ Failed to read {file}: {e}")

    # --- OpenWebText ---
    owt_path = os.path.join("data/openwebtext", "owt_20000.jsonl")
    print(f"[INFO] Scanning OpenWebText: {owt_path}")

    if os.path.exists(owt_path):
        # errors="replace": a UnicodeDecodeError is raised by the file *iterator*
        # (at the `for` line), so a try/except around the append could never catch
        # it — one bad byte sequence would have crashed the whole load. Replacing
        # undecodable bytes keeps the tolerate-bad-lines behavior this loop intended.
        with open(owt_path, encoding="utf-8", errors="replace") as f:
            for i, line in enumerate(f):
                if i % 1000 == 0:
                    print(f"  ⏳ {i} lines read...")
                parts.append(line.strip() + "\n")
    else:
        print(f"[WARN] OpenWebText file not found: {owt_path}")

    # --- Chat logs ---
    if os.path.exists("catlin_chatlog.txt"):
        print(f"[INFO] Appending chat log...")
        with open("catlin_chatlog.txt", encoding="utf-8") as f:
            parts.append("\n" + f.read())

    text = "".join(parts)
    print(f"[INFO] Raw text loaded: {len(text)} characters")
    # Rough cap on corpus size; assumes ~10 characters per token on average —
    # TODO confirm against the tokenizer's actual ratio.
    return text[:MAX_TOKENS * 10]
|
||||
|
||||
|
||||
def main():
    """End-to-end training entry point: load corpus, build tokenizer, train GPT, save artifacts.

    Side effects: writes ``catlin_tokenizer.pkl`` and ``catlin_model.pt`` to the
    working directory. All hyperparameters (VOCAB_SIZE, CONTEXT_SIZE, EMBED_DIM,
    NUM_HEADS, NUM_LAYERS, LEARNING_RATE, BATCH_SIZE, DEVICE, MAX_TOKENS) come
    from ``from config import *`` at the top of the file.
    """
    print("[INFO] Starting main()")
    # Raw corpus: books + OpenWebText + chat log, capped by load_texts().
    raw_text = load_texts()
    print(f"[INFO] Loaded text: {len(raw_text)} characters")

    # Fit a word-level tokenizer on the corpus and persist it so inference
    # can reuse the exact same vocabulary.
    tokenizer = WordTokenizer(VOCAB_SIZE)
    tokenizer.fit(raw_text)
    tokenizer.save("catlin_tokenizer.pkl")
    print("[INFO] Tokenizer built and saved")

    tokens = tokenizer.encode(raw_text)
    print(f"[INFO] Total tokens: {len(tokens)}")

    # Slice the token stream into fixed-length context windows for training.
    dataset = TextDataset(tokens, CONTEXT_SIZE)
    if len(dataset) == 0:
        # Guard: too few tokens to form even one training window — abort
        # instead of handing an empty dataset to the trainer.
        print("❌ ERROR: Dataset is empty. Aborting.")
        return

    model = GPT(VOCAB_SIZE, CONTEXT_SIZE, EMBED_DIM, NUM_HEADS, NUM_LAYERS)
    print("[INFO] Model initialized")

    # Fall back to CPU when CUDA is unavailable. NOTE(review): assumes DEVICE
    # in config names a CUDA device — confirm. epochs is hard-coded to 10 here
    # rather than taken from config.
    train(model, dataset, DEVICE if torch.cuda.is_available() else "cpu", LEARNING_RATE, BATCH_SIZE, epochs=10)
    print("[INFO] Training complete")

    # Persist only the weights (state_dict), not the full module object.
    torch.save(model.state_dict(), "catlin_model.pt")
    print("[INFO] Model saved to catlin_model.pt")


# Standard script guard: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
Reference in New Issue
Block a user