commit 16289fc942888e648b4b0a1404d02279d8f27e8d
Author: Dani
Date:   Tue Jun 3 23:43:58 2025 -0400

    Created NORA. She has been designed from scratch. At this point, I have
    determined the best hyperparameters for her training. The next step is to
    help her communicate on Discord and see how she handles it.

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7299dec
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,198 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+#poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore +checkpoints/nora_step_*.pt +data/books +*.json diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..fb5f6af --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 [fullname] + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..f1030d5 --- /dev/null +++ b/README.md @@ -0,0 +1,33 @@ +# Nora: Train a Transformer LM from Scratch + +> A minimal, from-scratch language model. No pretrained weights—just public-domain books and your GPU (or CPU). + +## Overview + +Nora is a character-level Transformer language model written entirely in PyTorch. 
It learns from whatever plain‐text `.txt` files you place in `data/books/`. Over time, you can extend Nora’s codebase (e.g., add reinforcement-learning loops, self-improvement modules, etc.) toward more advanced AI, if you wish. + +## Why “Nora”? + +- A simple, human‐like female name. +- Short, easy to pronounce. +- As the project scales, “Nora” could theoretically be extended with modules to approach more general intelligence. + +## Requirements + +- **Python 3.10.6** (Windows 11 or any OS) +- **CUDA-capable GPU** (if you want to train faster; otherwise CPU) +- **PyTorch** (install with `pip install torch torchvision`) +- **tqdm** (`pip install tqdm`) +- **Other Python packages**: `numpy`, `typing` + +## Folder Structure + +- nora/ +- ├── config.py +- ├── tokenizer.py +- ├── data_loader.py +- ├── model.py +- ├── train.py +- ├── utils.py +- ├── main.py +- └── README.md \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..0ad435c --- /dev/null +++ b/config.py @@ -0,0 +1,146 @@ +""" +config.py + +Define hyperparameters, file paths, and other settings via argparse. +""" + +import argparse +import torch + + +def get_config(): + parser = argparse.ArgumentParser(description="Nora: Train a Transformer from scratch") + + # Data & paths + parser.add_argument( + "--data_dir", + type=str, + default="data/books", + help="Path to folder containing .txt files (public-domain books).", + ) + parser.add_argument( + "--vocab_path", + type=str, + default="data/vocab.json", + help="Where to save/load the tokenizer vocabulary.", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default="checkpoints", + help="Directory to save model checkpoints.", + ) + + # Model hyperparameters + parser.add_argument( + "--d_model", + type=int, + default=512, + help="Transformer embedding size (d_model).", + ) + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads.", + ) + parser.add_argument( + "--num_layers", + type=int, + default=6, + help="Number of Transformer encoder layers.", + ) + parser.add_argument( + "--dim_feedforward", + type=int, + default=2048, + help="Inner feedforward dimension.", + ) + parser.add_argument( + "--dropout", + type=float, + default=0.1, + help="Dropout rate in Transformer.", + ) + + # Training hyperparameters + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size per training step.", + ) + parser.add_argument( + "--seq_length", + type=int, + default=128, + help="Sequence length (context window) in tokens.", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + help="Number of training epochs.", + ) + parser.add_argument( + "--lr", + type=float, + default=1e-4, + help="Learning rate.", + ) + parser.add_argument( + "--warmup_steps", + type=int, + default=1000, + help="Linear learning rate warmup steps.", + ) + parser.add_argument( + "--max_grad_norm", + type=float, + default=1.0, + help="Gradient clipping norm.", + ) + + # Logging & checkpointing + parser.add_argument( + "--log_interval", + type=int, + default=100, + help="Print training loss every N steps.", + ) + parser.add_argument( + "--save_interval", + type=int, + default=1000, + help="Save a checkpoint every N steps.", + ) + + # Device selection + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to train on ('cuda' or 'cpu').", + ) + + # Scaling options (for Pi vs GPU) + parser.add_argument( + "--tiny", + action="store_true", + 
help="If set, override model sizes to be tiny (for Pi 3B or very low-compute).", + ) + + args = parser.parse_args() + + # If --tiny is set, override some hyperparameters to very small values: + if args.tiny: + args.d_model = 64 + args.nhead = 2 + args.num_layers = 2 + args.dim_feedforward = 256 + args.batch_size = 8 + args.seq_length = 32 + args.lr = 1e-3 + args.epochs = 5 + + return args diff --git a/data_loader.py b/data_loader.py new file mode 100644 index 0000000..4a9b5ae --- /dev/null +++ b/data_loader.py @@ -0,0 +1,63 @@ +""" +data_loader.py + +Loads all .txt files from data_dir, concatenates them, tokenizes them, +and creates a Dataset of (input_seq, target_seq) for language modeling. +""" + +import os +import torch +from torch.utils.data import Dataset, DataLoader + + +class TextDataset(Dataset): + def __init__(self, data_dir: str, tokenizer, seq_length: int): + """ + - data_dir: folder of .txt public-domain books. + - tokenizer: instance of CharTokenizer (from tokenizer.py). + - seq_length: context length in tokens. + """ + super().__init__() + self.seq_length = seq_length + self.tokenizer = tokenizer + + # Read and concatenate all text files into one long string + texts = [] + for root, _, files in os.walk(data_dir): + for fname in files: + if not fname.lower().endswith(".txt"): + continue + path = os.path.join(root, fname) + with open(path, "r", encoding="utf-8", errors="ignore") as f: + texts.append(f.read()) + full_text = "\n".join(texts) + token_ids = self.tokenizer.encode(full_text) + + # Prepare input-target pairs + self.examples = [] + stride = 32 + for i in range(0, len(token_ids) - seq_length, stride): + inp = token_ids[i : i + seq_length] + targ = token_ids[i + 1 : i + seq_length + 1] + self.examples.append((inp, targ)) + + def __len__(self): + return len(self.examples) + + def __getitem__(self, idx): + inp, targ = self.examples[idx] + return torch.tensor(inp, dtype=torch.long), torch.tensor(targ, dtype=torch.long) + + +def get_dataloader( + data_dir: str, tokenizer, seq_length: int, batch_size: int, shuffle: bool = True +) -> DataLoader: + dataset = TextDataset(data_dir, tokenizer, seq_length) + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + drop_last=True, + num_workers=8, + pin_memory=True, + ) diff --git a/main.py b/main.py new file mode 100644 index 0000000..5f48e9d --- /dev/null +++ b/main.py @@ -0,0 +1,89 @@ +""" +main.py + +Orchestrates the entire Nora project: +- Parses arguments +- Builds or loads tokenizer +- Constructs dataset & dataloader +- Instantiates the model +- Sets up optimizer, scheduler +- Calls train() +""" + +import os +import torch +import logging +from config import get_config +from tokenizer import CharTokenizer +from data_loader import get_dataloader +from model import NoraTransformerLM +from train import train +from utils import setup_logging, load_checkpoint, save_checkpoint + +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.enabled = True + + +def main(): + args = get_config() + + # 1) Logging setup + log_file = os.path.join(args.checkpoint_dir, "train.log") + setup_logging(log_file) + + logging.info(f"[main] Using device: {args.device}") + logging.info(f"[main] Config: {args}") + + # 2) Tokenizer: if vocab exists, load; else build from data_dir + tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir) + + # 3) DataLoader + dataloader = get_dataloader( + data_dir=args.data_dir, + tokenizer=tokenizer, + seq_length=args.seq_length, + batch_size=args.batch_size, + shuffle=True, + 
)
+
+    # 4) Model instantiation
+    model = NoraTransformerLM(
+        vocab_size=tokenizer.vocab_size(),
+        d_model=args.d_model,
+        nhead=args.nhead,
+        num_layers=args.num_layers,
+        dim_feedforward=args.dim_feedforward,
+        dropout=args.dropout,
+        max_seq_len=args.seq_length,
+    )
+
+    # 5) Optimizer & scheduler (linear warmup + decay)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
+
+    def lr_lambda(current_step):
+        # Linear warmup for first warmup_steps, then decay with 1/sqrt(step)
+        if current_step < args.warmup_steps:
+            return float(current_step) / float(max(1, args.warmup_steps))
+        return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)
+
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+    # 6) Check for existing checkpoint to resume
+    start_step = 0
+    ckpts = os.listdir(args.checkpoint_dir) if os.path.isdir(args.checkpoint_dir) else []
+    ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
+    if ckpts:
+        # Sort numerically by step so that e.g. nora_step_10000.pt is newer than nora_step_2000.pt
+        ckpts.sort(key=lambda f: int(f[len("nora_step_"):-len(".pt")]))
+        latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
+        logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
+        start_step = load_checkpoint(latest_ckpt, model, optimizer)
+
+    # 7) Begin training
+    try:
+        train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
+    except Exception as e:
+        logging.exception("[main] Exception during training")
+        raise e
+
+
+if __name__ == "__main__":
+    main()
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..8eacd70
--- /dev/null
+++ b/model.py
@@ -0,0 +1,100 @@
+"""
+model.py
+
+Defines a Transformer‐based language model from scratch, using PyTorch’s nn.Transformer.
+No pretrained weights—everything is initialized randomly.
+"""
+
+import torch
+import torch.nn as nn
+import math
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model: int, max_len: int = 10_000):
+        super().__init__()
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)  # shape: (1, max_len, d_model)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (batch_size, seq_length, d_model)
+        returns x + positional encodings for the first seq_length positions.
+ """ + x = x + self.pe[:, : x.size(1), :] + return x + + +class NoraTransformerLM(nn.Module): + def __init__( + self, + vocab_size: int, + d_model: int = 512, + nhead: int = 8, + num_layers: int = 6, + dim_feedforward: int = 2048, + dropout: float = 0.1, + max_seq_len: int = 512, + ): + super().__init__() + self.model_type = "TransformerLM" + self.d_model = d_model + self.vocab_size = vocab_size + + # Token embedding + positional encoding + self.token_embed = nn.Embedding(vocab_size, d_model) + self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len) + + # Transformer encoder layers + encoder_layers = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation="gelu", + ) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layers, num_layers=num_layers + ) + + # Final linear layer to project to vocabulary + self.fc_out = nn.Linear(d_model, vocab_size) + + # Initialization + self._init_weights() + + def _init_weights(self): + nn.init.normal_(self.token_embed.weight, mean=0, std=self.d_model ** -0.5) + nn.init.zeros_(self.fc_out.bias) + nn.init.normal_(self.fc_out.weight, mean=0, std=self.d_model ** -0.5) + + def forward(self, src: torch.Tensor) -> torch.Tensor: + """ + src: (batch_size, seq_length), token IDs + returns: logits (batch_size, seq_length, vocab_size) + """ + + # Embed tokens and add positional encoding + x = self.token_embed(src) * math.sqrt(self.d_model) # (B, S, D) + x = self.pos_encoder(x) # (B, S, D) + # PyTorch Transformer expects (S, B, D) + x = x.permute(1, 0, 2) # (seq_length, batch_size, d_model) + + # Create a causal mask so each position can only attend to previous positions + seq_len = x.size(0) + mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool() + + # Pass through Transformer encoder + x = self.transformer_encoder(x, mask=mask) # (seq_length, batch_size, d_model) + + # Back to (B, S, D) + x = x.permute(1, 0, 2) # (batch_size, seq_length, d_model) + logits = self.fc_out(x) # (batch_size, seq_length, vocab_size) + return logits diff --git a/tokenizer.py b/tokenizer.py new file mode 100644 index 0000000..de69267 --- /dev/null +++ b/tokenizer.py @@ -0,0 +1,86 @@ +""" +tokenizer.py + +A simple character‐level tokenizer that builds its own vocabulary from all text files. +Saves/loads vocab to/from JSON. You can extend this to a word‐level tokenizer if you wish. +""" + +import json +import os +from collections import Counter +from typing import List, Dict, Union + + +class CharTokenizer: + def __init__(self, vocab_path: str, data_dir: str): + """ + If vocab_path exists, load it; otherwise, build from raw text in data_dir. + """ + self.vocab_path = vocab_path + self.data_dir = data_dir + self.stoi: Dict[str, int] = {} + self.itos: Dict[int, str] = {} + + if os.path.isfile(self.vocab_path): + self._load_vocab() + else: + self._build_vocab() + + def _build_vocab(self): + """ + Read all .txt files under data_dir, count character frequencies, + build a sorted vocabulary, and save to vocab_path. 
+ """ + counter = Counter() + print(f"[tokenizer] Building vocabulary from data in '{self.data_dir}'...") + for root, _, files in os.walk(self.data_dir): + for fname in files: + if not fname.lower().endswith(".txt"): + continue + path = os.path.join(root, fname) + with open(path, "r", encoding="utf-8", errors="ignore") as f: + text = f.read() + counter.update(text) + + # Ensure a consistent ordering: sort by frequency descending, then Unicode codepoint + sorted_chars = sorted(counter.items(), key=lambda x: (-x[1], x[0])) + unique_chars = [ch for ch, _ in sorted_chars] + + # Add special tokens + tokens = ["", ""] + unique_chars + + self.stoi = {ch: i for i, ch in enumerate(tokens)} + self.itos = {i: ch for i, ch in enumerate(tokens)} + + # Save to JSON + os.makedirs(os.path.dirname(self.vocab_path), exist_ok=True) + with open(self.vocab_path, "w", encoding="utf-8") as f: + json.dump(self.stoi, f, ensure_ascii=False, indent=2) + print(f"[tokenizer] Built vocab size = {len(self.stoi)}; saved to '{self.vocab_path}'.") + + def _load_vocab(self): + """ + Load existing vocabulary from vocab_path. + """ + print(f"[tokenizer] Loading vocabulary from '{self.vocab_path}'...") + with open(self.vocab_path, "r", encoding="utf-8") as f: + self.stoi = json.load(f) + self.itos = {i: ch for ch, i in self.stoi.items()} + print(f"[tokenizer] Loaded vocab size = {len(self.stoi)}.") + + def encode(self, text: str) -> List[int]: + """ + Convert a string to a list of integer token IDs (character‐level). + Unrecognized chars map to . + """ + unk_id = self.stoi.get("") + return [self.stoi.get(ch, unk_id) for ch in text] + + def decode(self, token_ids: List[int]) -> str: + """ + Convert a list of token IDs back into a string. + """ + return "".join(self.itos.get(i, "") for i in token_ids) + + def vocab_size(self) -> int: + return len(self.stoi) diff --git a/train.py b/train.py new file mode 100644 index 0000000..4c2c0cf --- /dev/null +++ b/train.py @@ -0,0 +1,135 @@ +""" +train.py + +Training loop for Nora, with automatic mixed precision (AMP) to speed up on CUDA GPUs. +Uses tqdm for progress bars, logging for metrics, and gradient clipping + LR scheduler. +""" + +import time +import logging +import math +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from torch.nn.utils import clip_grad_norm_ +from torch.amp import GradScaler, autocast # <-- updated import + + +def train( + model: torch.nn.Module, + dataloader: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + scheduler, + tokenizer, + config, + start_step: int = 0, +): + """ + model: NoraTransformerLM + dataloader: DataLoader for TextDataset + optimizer: AdamW (or Adam) + scheduler: LR scheduler with warmup + tokenizer: CharTokenizer + config: namespace from config.py + start_step: if resuming from checkpoint + """ + + device = config.device + model.to(device) + criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi[""]) + scaler = GradScaler() + + global_step = start_step + steps_per_epoch = len(dataloader) + total_steps = config.epochs * steps_per_epoch + + logging.info( + f"[train] Starting training for {config.epochs} epochs, " + f"{steps_per_epoch} steps/epoch, total approx {total_steps} steps." 
+ ) + + for epoch in range(config.epochs): + model.train() + epoch_loss = 0.0 + epoch_start = time.time() + + # If you want to profile the first 100 steps, uncomment below: + # if global_step == start_step: + # t0 = time.time() + + pbar = tqdm( + enumerate(dataloader), + total=steps_per_epoch, + desc=f"Epoch {epoch + 1}", + ncols=100, + unit="step", + ) + for step, (inputs, targets) in pbar: + inputs = inputs.to(device) + targets = targets.to(device) + + optimizer.zero_grad() + + # Mixed precision forward/backward (specify device_type="cuda") + with autocast(device_type="cuda", dtype=torch.float16): + logits = model(inputs) # (batch, seq_len, vocab_size) + loss = criterion( + logits.view(-1, tokenizer.vocab_size()), + targets.view(-1), + ) + + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + clip_grad_norm_(model.parameters(), config.max_grad_norm) + scaler.step(optimizer) + scaler.update() + scheduler.step() + + epoch_loss += loss.item() + global_step += 1 + + # Log every log_interval steps + if global_step % config.log_interval == 0: + avg_loss = epoch_loss / (step + 1) + ppl = math.exp(avg_loss) + logging.info( + f"[step {global_step}/{total_steps}] " + f"avg_loss = {avg_loss:.4f}, ppl = {ppl:.2f}, " + f"lr = {scheduler.get_last_lr()[0]:.2e}" + ) + + # Save checkpoint every save_interval steps + if global_step % config.save_interval == 0: + from utils import save_checkpoint + + save_checkpoint( + model, + optimizer, + global_step, + config.checkpoint_dir, + tokenizer=tokenizer, + ) + + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + # (Optional) Profile first 100 steps + # if global_step == start_step + 100: + # elapsed = time.time() - t0 + # print( + # f"[profile] avg time/step over first 100 batches: " + # f"{elapsed/100:.4f} s" + # ) + + epoch_time = time.time() - epoch_start + avg_epoch_loss = epoch_loss / steps_per_epoch + logging.info( + f"[epoch {epoch + 1}/{config.epochs}] " + f"avg_loss = {avg_epoch_loss:.4f}, time = {epoch_time:.1f}s" + ) + + # Final checkpoint at end of all epochs + from utils import save_checkpoint + + save_checkpoint(model, optimizer, global_step, config.checkpoint_dir, tokenizer=tokenizer) + logging.info("[train] Training complete.") diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..e941ea1 --- /dev/null +++ b/utils.py @@ -0,0 +1,84 @@ +""" +utils.py + +Common utilities: logging setup, checkpoint saving & loading, device checks, etc. +""" + +import os +import logging +import torch + + +def setup_logging(log_file: str = None): + """ + Set up logging to stdout (and optionally to a file). + """ + root = logging.getLogger() + root.setLevel(logging.INFO) + formatter = logging.Formatter( + "[%(asctime)s] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) + + # Console handler + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + root.addHandler(ch) + + # File handler + if log_file: + os.makedirs(os.path.dirname(log_file), exist_ok=True) + fh = logging.FileHandler(log_file) + fh.setLevel(logging.INFO) + fh.setFormatter(formatter) + root.addHandler(fh) + + +def save_checkpoint( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + step: int, + checkpoint_dir: str, + tokenizer=None, +): + """ + Save model state, optimizer state, and tokenizer (optional) to a checkpoint file. 
+ """ + os.makedirs(checkpoint_dir, exist_ok=True) + ckpt_path = os.path.join(checkpoint_dir, f"nora_step_{step}.pt") + state = { + "step": step, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + } + if tokenizer: + # tokenizer.stoi is JSON‐serializable + state["tokenizer_stoi"] = tokenizer.stoi + + torch.save(state, ckpt_path) + logging.info(f"[checkpoint] Saved checkpoint to {ckpt_path}") + + +def load_checkpoint( + ckpt_path: str, model: torch.nn.Module, optimizer: torch.optim.Optimizer = None +): + """ + Load model & optimizer state from a checkpoint. Returns step. + If optimizer is None, only loads model weights. + """ + if not os.path.isfile(ckpt_path): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + state = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(state["model_state_dict"]) + step = state.get("step", 0) + if optimizer and "optimizer_state_dict" in state: + optimizer.load_state_dict(state["optimizer_state_dict"]) + logging.info(f"[checkpoint] Loaded checkpoint from {ckpt_path} (step {step})") + return step + + +def get_default_device(): + """ + Return 'cuda' if available; otherwise 'cpu'. + """ + return "cuda" if torch.cuda.is_available() else "cpu"