Updated the model capacity

This commit is contained in:
2025-06-30 18:08:11 -04:00
parent 159be1eb82
commit 6366f72716
6 changed files with 95 additions and 10058 deletions

66
main.py
View File

@ -2,6 +2,7 @@
import os
import torch
import json
from config import *
from tokenizers.word_tokenizer import WordTokenizer
from models.gpt import GPT
@ -11,45 +12,42 @@ from training.trainer import TextDataset, train
def load_texts():
    """Gather the raw training text from the local data folders.

    Two sources are read, in this order:

    1. Every ``*.txt`` file directly inside ``data/books``.
    2. Every ``*.jsonl`` file directly inside ``data/openwebtext``; each
       line is parsed as JSON and the value of its ``"text"`` key (empty
       string when the key is absent) is appended.

    A source directory that does not exist is silently skipped.

    Returns:
        str: the concatenated text, clipped to ``MAX_TOKENS * 10``
        characters.

    NOTE(review): ``MAX_TOKENS`` is not defined in this block; it is
    presumably supplied by ``from config import *`` -- confirm.
    """
    text = ""

    print("[INFO] Scanning books from: data\\books")
    books_dir = "data/books"
    if os.path.exists(books_dir):
        for file in os.listdir(books_dir):
            if file.endswith(".txt"):
                print(f" 📚 Loading {file}")
                with open(os.path.join(books_dir, file), encoding="utf-8") as f:
                    text += f.read()

    owt_dir = "data/openwebtext"
    if os.path.exists(owt_dir):
        print(f"[INFO] Scanning OpenWebText: {owt_dir}")
        for fname in os.listdir(owt_dir):
            if fname.endswith(".jsonl"):
                path = os.path.join(owt_dir, fname)
                print(f" 📄 Loading {fname}")
                try:
                    with open(path, encoding="utf-8") as f:
                        for line in f:
                            try:
                                obj = json.loads(line)
                                text += obj.get("text", "")
                            except json.JSONDecodeError:
                                # Skip lines that are not valid JSON.
                                continue
                            # Stop reading this file once enough raw text
                            # has been collected; the final clip happens
                            # below.
                            if len(text) >= MAX_TOKENS * 10:
                                break
                except Exception as e:
                    # Best effort: report the failure and move on to the
                    # next file instead of aborting the whole load.
                    print(f" ⚠️ Error reading {fname}: {e}")

    print("[INFO] Raw text loaded:", len(text), "characters")

    # Truncate to MAX_TOKENS * 10 (rough estimate)
    clipped = text[:MAX_TOKENS * 10]
    print("[INFO] Loaded text:", len(clipped), "characters")
    return clipped
def main():