Updated the model capacity

This commit is contained in:
2025-06-30 18:08:11 -04:00
parent 159be1eb82
commit 6366f72716
6 changed files with 95 additions and 10058 deletions

6
.gitignore vendored
View File

@ -205,4 +205,8 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/
/data/books
/data/books
/data/openwebtext
/catlin_tokenizer.pkl
/catlin_model.pt
/catlin_chatlog.txt

View File

@ -8,4 +8,4 @@ NUM_LAYERS = 6
BATCH_SIZE = 16
LEARNING_RATE = 3e-4
DEVICE = "cuda" # fallback handled in trainer
MAX_TOKENS = 100_000 # Used to cap input corpus size
MAX_TOKENS = 500_000 # Used to cap input corpus size

File diff suppressed because one or more lines are too long

66
main.py
View File

@ -2,6 +2,7 @@
import os
import torch
import json
from config import *
from tokenizers.word_tokenizer import WordTokenizer
from models.gpt import GPT
@ -11,45 +12,42 @@ from training.trainer import TextDataset, train
def load_texts():
text = ""
# --- Books ---
book_dir = os.path.join("data", "books")
os.makedirs(book_dir, exist_ok=True)
print("[INFO] Scanning books from: data\\books")
books_dir = "data/books"
if os.path.exists(books_dir):
for file in os.listdir(books_dir):
if file.endswith(".txt"):
print(f" 📚 Loading {file}")
with open(os.path.join(books_dir, file), encoding="utf-8") as f:
text += f.read()
print(f"[INFO] Scanning books from: {book_dir}")
for file in os.listdir(book_dir):
path = os.path.join(book_dir, file)
if file.endswith(".txt"):
print(f" 📚 Loading {file}")
try:
with open(path, encoding="utf-8") as f:
text += f.read() + "\n"
except Exception as e:
print(f" ❌ Failed to read {file}: {e}")
# --- OpenWebText ---
owt_path = os.path.join("data/openwebtext", "owt_20000.jsonl")
print(f"[INFO] Scanning OpenWebText: {owt_path}")
if os.path.exists(owt_path):
with open(owt_path, encoding="utf-8") as f:
for i, line in enumerate(f):
if i % 1000 == 0:
print(f"{i} lines read...")
owt_dir = "data/openwebtext"
if os.path.exists(owt_dir):
print(f"[INFO] Scanning OpenWebText: {owt_dir}")
for fname in os.listdir(owt_dir):
if fname.endswith(".jsonl"):
path = os.path.join(owt_dir, fname)
print(f" 📄 Loading {fname}")
try:
text += line.strip() + "\n"
with open(path, encoding="utf-8") as f:
for i, line in enumerate(f):
try:
obj = json.loads(line)
text += obj.get("text", "")
except json.JSONDecodeError:
continue
if len(text) >= MAX_TOKENS * 10:
break
except Exception as e:
print(f" ❌ Line {i} decode error: {e}")
else:
print(f"[WARN] OpenWebText file not found: {owt_path}")
print(f" ⚠️ Error reading {fname}: {e}")
# --- Chat logs ---
if os.path.exists("catlin_chatlog.txt"):
print(f"[INFO] Appending chat log...")
with open("catlin_chatlog.txt", encoding="utf-8") as f:
text += "\n" + f.read()
print("[INFO] Raw text loaded:", len(text), "characters")
print(f"[INFO] Raw text loaded: {len(text)} characters")
return text[:MAX_TOKENS * 10]
# Truncate to MAX_TOKENS * 10 (rough estimate)
clipped = text[:MAX_TOKENS * 10]
print("[INFO] Loaded text:", len(clipped), "characters")
return clipped
def main():

View File

@ -4,22 +4,42 @@ import torch
import torch.nn as nn
class GPTBlock(nn.Module):
def __init__(self, embed_dim, num_heads):
class CausalSelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout):
super().__init__()
self.attn = nn.MultiheadAttention(
embed_dim, num_heads,
dropout=dropout,
batch_first=True,
bias=True
)
def forward(self, x):
B, T, _ = x.size()
# Create causal mask: (T, T) with float('-inf') for future positions
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
# Pass it in as attn_mask
return self.attn(x, x, x, attn_mask=mask)[0]
class GPTBlock(nn.Module):
def __init__(self, embed_dim, num_heads, dropout):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.ln1 = nn.LayerNorm(embed_dim)
self.attn = CausalSelfAttention(embed_dim, num_heads, dropout)
self.ln2 = nn.LayerNorm(embed_dim)
self.ff = nn.Sequential(
nn.Linear(embed_dim, 4 * embed_dim),
nn.GELU(),
nn.Linear(4 * embed_dim, embed_dim)
nn.Linear(4 * embed_dim, embed_dim),
nn.Dropout(dropout)
)
self.ln2 = nn.LayerNorm(embed_dim)
def forward(self, x):
attn_out, _ = self.attn(x, x, x, need_weights=False)
x = self.ln1(x + attn_out)
x = self.ln2(x + self.ff(x))
x = x + self.attn(self.ln1(x))
x = x + self.ff(self.ln2(x))
return x
@ -28,16 +48,18 @@ class GPT(nn.Module):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, embed_dim)
self.pos_emb = nn.Parameter(torch.zeros(1, context_size, embed_dim))
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([GPTBlock(embed_dim, num_heads) for _ in range(num_layers)])
self.drop = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
GPTBlock(embed_dim, num_heads, dropout)
for _ in range(num_layers)
])
self.ln_f = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, vocab_size)
def forward(self, x):
B, T = x.size()
tok_emb = self.token_emb(x)
x = tok_emb + self.pos_emb[:, :T, :]
x = self.dropout(x)
x = self.token_emb(x) + self.pos_emb[:, :T, :]
x = self.drop(x)
for block in self.blocks:
x = block(x)
x = self.ln_f(x)

View File

@ -1,5 +1,3 @@
# tools/openwebtext_fetcher.py
from datasets import load_dataset
import os
from tqdm import tqdm
@ -8,14 +6,29 @@ TARGET_DIR = "data/openwebtext"
os.makedirs(TARGET_DIR, exist_ok=True)
def fetch_subset(n=10000, split="train"):
def fetch_subset(total=20000, chunk_size=5000, split="train"):
ds = load_dataset("stas/openwebtext-10k", split=split)
with open(os.path.join(TARGET_DIR, f"owt_{n}.jsonl"), "w", encoding="utf-8") as f:
for i, item in tqdm(enumerate(ds), total=n, desc="Writing JSONL"):
f.write(f"{item['text'].replace(chr(10),' ')}\n")
if i + 1 >= n:
break
print(f"[INFO] Total to fetch: {total} | Chunk size: {chunk_size}")
count = 0
file_index = 0
f = None
for item in tqdm(ds, desc="Downloading"):
if count % chunk_size == 0:
if f: f.close()
file_path = os.path.join(TARGET_DIR, f"owt_{file_index:05d}.jsonl")
f = open(file_path, "w", encoding="utf-8")
print(f"[INFO] Created {file_path}")
file_index += 1
f.write(f"{item['text'].replace(chr(10), ' ')}\n")
count += 1
if count >= total:
break
if f:
f.close()
print(f"[INFO] ✅ Done. {count} samples across {file_index} files.")
if __name__ == "__main__":
fetch_subset(20000) # fetch 20K examples (~100MB)
fetch_subset(total=100000, chunk_size=5000)