Created Nora. She has been designed from scratch. At this point, I have settled on the best hyperparameters for training her. The next step is to get her communicating on Discord and see how she handles it.

Dani 2025-06-03 23:43:58 -04:00
commit 16289fc942
10 changed files with 955 additions and 0 deletions

.gitignore (vendored, new file, 198 lines)

@@ -0,0 +1,198 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
checkpoints/nora_step_*.pt
data/books
*.json

LICENSE (new file, 21 lines)

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 [fullname]
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md (new file, 33 lines)

@@ -0,0 +1,33 @@
# Nora: Train a Transformer LM from Scratch
> A minimal, from-scratch language model. No pretrained weights—just public-domain books and your GPU (or CPU).
## Overview
Nora is a character-level Transformer language model written entirely in PyTorch. It learns from whatever plaintext `.txt` files you place in `data/books/`. Over time, you can extend Nora's codebase (e.g., with reinforcement-learning loops or self-improvement modules) toward more advanced AI, if you wish.
## Why “Nora”?
- A simple, humanlike female name.
- Short, easy to pronounce.
- As the project scales, “Nora” could theoretically be extended with modules to approach more general intelligence.
## Requirements
- **Python 3.10.6** (Windows 11 or any OS)
- **CUDA-capable GPU** (if you want to train faster; otherwise CPU)
- **PyTorch** (install with `pip install torch torchvision`)
- **tqdm** (`pip install tqdm`)
- **Other Python packages**: `numpy` (the `typing` module used in the code ships with Python's standard library, so it needs no separate install)
## Folder Structure
- `nora/`
  - `config.py`
  - `tokenizer.py`
  - `data_loader.py`
  - `model.py`
  - `train.py`
  - `utils.py`
  - `main.py`
  - `README.md`
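To sanity-check that these modules fit together, here is a minimal smoke-test sketch (not one of the committed files; it assumes a few public-domain `.txt` files already sit in `data/books/` and that it is run from the repo root):

```python
# smoke_test.py -- illustrative only
import torch

from tokenizer import CharTokenizer
from data_loader import get_dataloader
from model import NoraTransformerLM

if __name__ == "__main__":  # guard needed because the DataLoader spawns worker processes
    tokenizer = CharTokenizer(vocab_path="data/vocab.json", data_dir="data/books")
    loader = get_dataloader("data/books", tokenizer, seq_length=32, batch_size=4)

    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=64, nhead=2, num_layers=2, dim_feedforward=256, max_seq_len=32,
    )

    inputs, targets = next(iter(loader))  # both: (4, 32) tensors of token IDs
    with torch.no_grad():
        logits = model(inputs)            # (4, 32, vocab_size)
    print(inputs.shape, targets.shape, logits.shape)
```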

config.py (new file, 146 lines)

@@ -0,0 +1,146 @@
"""
config.py
Define hyperparameters, file paths, and other settings via argparse.
"""
import argparse
import torch
def get_config():
parser = argparse.ArgumentParser(description="Nora: Train a Transformer from scratch")
# Data & paths
parser.add_argument(
"--data_dir",
type=str,
default="data/books",
help="Path to folder containing .txt files (public-domain books).",
)
parser.add_argument(
"--vocab_path",
type=str,
default="data/vocab.json",
help="Where to save/load the tokenizer vocabulary.",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
default="checkpoints",
help="Directory to save model checkpoints.",
)
# Model hyperparameters
parser.add_argument(
"--d_model",
type=int,
default=512,
help="Transformer embedding size (d_model).",
)
parser.add_argument(
"--nhead",
type=int,
default=8,
help="Number of attention heads.",
)
parser.add_argument(
"--num_layers",
type=int,
default=6,
help="Number of Transformer encoder layers.",
)
parser.add_argument(
"--dim_feedforward",
type=int,
default=2048,
help="Inner feedforward dimension.",
)
parser.add_argument(
"--dropout",
type=float,
default=0.1,
help="Dropout rate in Transformer.",
)
# Training hyperparameters
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="Batch size per training step.",
)
parser.add_argument(
"--seq_length",
type=int,
default=128,
help="Sequence length (context window) in tokens.",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
help="Number of training epochs.",
)
parser.add_argument(
"--lr",
type=float,
default=1e-4,
help="Learning rate.",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=1000,
help="Linear learning rate warmup steps.",
)
parser.add_argument(
"--max_grad_norm",
type=float,
default=1.0,
help="Gradient clipping norm.",
)
# Logging & checkpointing
parser.add_argument(
"--log_interval",
type=int,
default=100,
help="Print training loss every N steps.",
)
parser.add_argument(
"--save_interval",
type=int,
default=1000,
help="Save a checkpoint every N steps.",
)
# Device selection
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to train on ('cuda' or 'cpu').",
)
# Scaling options (for Pi vs GPU)
parser.add_argument(
"--tiny",
action="store_true",
help="If set, override model sizes to be tiny (for Pi 3B or very low-compute).",
)
args = parser.parse_args()
# If --tiny is set, override some hyperparameters to very small values:
if args.tiny:
args.d_model = 64
args.nhead = 2
args.num_layers = 2
args.dim_feedforward = 256
args.batch_size = 8
args.seq_length = 32
args.lr = 1e-3
args.epochs = 5
return args
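A hedged usage sketch (not part of the commit): since `get_config()` parses `sys.argv`, you can simulate the command line to see exactly what `--tiny` does to the defaults.

```python
# illustrative only: simulate "python main.py --tiny" and print the overrides
import sys

from config import get_config

sys.argv = ["main.py", "--tiny"]   # pretend CLI arguments, for illustration
args = get_config()
print(args.d_model, args.nhead, args.num_layers)   # 64 2 2
print(args.dim_feedforward, args.batch_size)       # 256 8
print(args.seq_length, args.lr, args.epochs)       # 32 0.001 5
```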

data_loader.py (new file, 63 lines)

@@ -0,0 +1,63 @@
"""
data_loader.py
Loads all .txt files from data_dir, concatenates them, tokenizes them,
and creates a Dataset of (input_seq, target_seq) for language modeling.
"""
import os
import torch
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
def __init__(self, data_dir: str, tokenizer, seq_length: int):
"""
- data_dir: folder of .txt public-domain books.
- tokenizer: instance of CharTokenizer (from tokenizer.py).
- seq_length: context length in tokens.
"""
super().__init__()
self.seq_length = seq_length
self.tokenizer = tokenizer
# Read and concatenate all text files into one long string
texts = []
for root, _, files in os.walk(data_dir):
for fname in files:
if not fname.lower().endswith(".txt"):
continue
path = os.path.join(root, fname)
with open(path, "r", encoding="utf-8", errors="ignore") as f:
texts.append(f.read())
full_text = "\n".join(texts)
token_ids = self.tokenizer.encode(full_text)
# Prepare input-target pairs
self.examples = []
stride = 32
for i in range(0, len(token_ids) - seq_length, stride):
inp = token_ids[i : i + seq_length]
targ = token_ids[i + 1 : i + seq_length + 1]
self.examples.append((inp, targ))
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
inp, targ = self.examples[idx]
return torch.tensor(inp, dtype=torch.long), torch.tensor(targ, dtype=torch.long)
def get_dataloader(
data_dir: str, tokenizer, seq_length: int, batch_size: int, shuffle: bool = True
) -> DataLoader:
dataset = TextDataset(data_dir, tokenizer, seq_length)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=True,
num_workers=8,
pin_memory=True,
)
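The stride-32 windowing in `TextDataset` produces overlapping pairs in which each target is the input shifted one token to the right. A standalone toy illustration of that indexing (no files or tokenizer needed):

```python
# illustrative only: mimic TextDataset's windowing on a toy token stream
token_ids = list(range(10))   # pretend these are encoded characters
seq_length, stride = 4, 2

examples = []
for i in range(0, len(token_ids) - seq_length, stride):
    inp = token_ids[i : i + seq_length]
    targ = token_ids[i + 1 : i + seq_length + 1]
    examples.append((inp, targ))

# examples[0] == ([0, 1, 2, 3], [1, 2, 3, 4])
# examples[1] == ([2, 3, 4, 5], [3, 4, 5, 6])
print(examples[:2])
```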

main.py (new file, 89 lines)

@@ -0,0 +1,89 @@
"""
main.py
Orchestrates the entire Nora project:
- Parses arguments
- Builds or loads tokenizer
- Constructs dataset & dataloader
- Instantiates the model
- Sets up optimizer, scheduler
- Calls train()
"""
import os
import torch
import logging
from config import get_config
from tokenizer import CharTokenizer
from data_loader import get_dataloader
from model import NoraTransformerLM
from train import train
from utils import setup_logging, load_checkpoint, save_checkpoint
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
def main():
args = get_config()
# 1) Logging setup
log_file = os.path.join(args.checkpoint_dir, "train.log")
setup_logging(log_file)
logging.info(f"[main] Using device: {args.device}")
logging.info(f"[main] Config: {args}")
# 2) Tokenizer: if vocab exists, load; else build from data_dir
tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)
# 3) DataLoader
dataloader = get_dataloader(
data_dir=args.data_dir,
tokenizer=tokenizer,
seq_length=args.seq_length,
batch_size=args.batch_size,
shuffle=True,
)
# 4) Model instantiation
model = NoraTransformerLM(
vocab_size=tokenizer.vocab_size(),
d_model=args.d_model,
nhead=args.nhead,
num_layers=args.num_layers,
dim_feedforward=args.dim_feedforward,
dropout=args.dropout,
max_seq_len=args.seq_length,
)
# 5) Optimizer & scheduler (linear warmup + decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
def lr_lambda(current_step):
# Linear warmup for first warmup_steps, then decay with 1/sqrt(step)
if current_step < args.warmup_steps:
return float(current_step) / float(max(1, args.warmup_steps))
return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# 6) Check for existing checkpoint to resume
start_step = 0
ckpts = sorted(os.listdir(args.checkpoint_dir)) if os.path.isdir(args.checkpoint_dir) else []
ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
if ckpts:
latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
start_step = load_checkpoint(latest_ckpt, model, optimizer)
# 7) Begin training
try:
train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
except Exception as e:
logging.exception("[main] Exception during training")
raise e
if __name__ == "__main__":
main()
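The `lr_lambda` above gives a multiplier of `step / warmup_steps` during warmup and `sqrt(warmup_steps / step)` afterwards, so the learning rate peaks at `--lr` exactly at `warmup_steps` and then decays. A small worked example with the defaults (`warmup_steps=1000`, `lr=1e-4`), illustrative only:

```python
# illustrative only: evaluate the schedule multiplier at a few steps
warmup_steps, base_lr = 1000, 1e-4

def lr_lambda(current_step):
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return (warmup_steps ** 0.5) * float(current_step ** -0.5)

for step in (100, 500, 1000, 4000, 16000):
    print(step, f"{base_lr * lr_lambda(step):.2e}")
# 100    1.00e-05   (warming up)
# 500    5.00e-05
# 1000   1.00e-04   (peak = base lr)
# 4000   5.00e-05   (decayed by 1/sqrt(4))
# 16000  2.50e-05   (decayed by 1/sqrt(16))
```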

model.py (new file, 100 lines)

@@ -0,0 +1,100 @@
"""
model.py
Defines a Transformerbased language model from scratch, using PyTorchs nn.Transformer.
No pretrained weightseverything is initialized randomly.
"""
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_len: int = 10_000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # shape: (1, max_len, d_model)
self.register_buffer("pe", pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (batch_size, seq_length, d_model)
returns x + positional encodings for the first seq_length positions.
"""
x = x + self.pe[:, : x.size(1), :]
return x
class NoraTransformerLM(nn.Module):
def __init__(
self,
vocab_size: int,
d_model: int = 512,
nhead: int = 8,
num_layers: int = 6,
dim_feedforward: int = 2048,
dropout: float = 0.1,
max_seq_len: int = 512,
):
super().__init__()
self.model_type = "TransformerLM"
self.d_model = d_model
self.vocab_size = vocab_size
# Token embedding + positional encoding
self.token_embed = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)
# Transformer encoder layers
encoder_layers = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation="gelu",
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layers, num_layers=num_layers
)
# Final linear layer to project to vocabulary
self.fc_out = nn.Linear(d_model, vocab_size)
# Initialization
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.token_embed.weight, mean=0, std=self.d_model ** -0.5)
nn.init.zeros_(self.fc_out.bias)
nn.init.normal_(self.fc_out.weight, mean=0, std=self.d_model ** -0.5)
def forward(self, src: torch.Tensor) -> torch.Tensor:
"""
src: (batch_size, seq_length), token IDs
returns: logits (batch_size, seq_length, vocab_size)
"""
# Embed tokens and add positional encoding
x = self.token_embed(src) * math.sqrt(self.d_model) # (B, S, D)
x = self.pos_encoder(x) # (B, S, D)
# PyTorch Transformer expects (S, B, D)
x = x.permute(1, 0, 2) # (seq_length, batch_size, d_model)
# Create a causal mask so each position can only attend to previous positions
seq_len = x.size(0)
mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
# Pass through Transformer encoder
x = self.transformer_encoder(x, mask=mask) # (seq_length, batch_size, d_model)
# Back to (B, S, D)
x = x.permute(1, 0, 2) # (batch_size, seq_length, d_model)
logits = self.fc_out(x) # (batch_size, seq_length, vocab_size)
return logits
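Text generation is not part of this commit, but because the model is a standard causal LM, a greedy sampling loop can be sketched on top of `forward()`. In the sketch below, `greedy_generate` and `max_new_tokens` are made-up names, and `max_seq_len` should match the value the model was built with:

```python
# illustrative only: greedy decoding on top of NoraTransformerLM.forward()
import torch

@torch.no_grad()
def greedy_generate(model, tokenizer, prompt: str, max_new_tokens: int = 100, max_seq_len: int = 128):
    model.eval()
    device = next(model.parameters()).device
    ids = tokenizer.encode(prompt)
    for _ in range(max_new_tokens):
        # keep only the most recent max_seq_len tokens as context
        context = torch.tensor([ids[-max_seq_len:]], dtype=torch.long, device=device)
        logits = model(context)                # (1, S, vocab_size)
        next_id = int(logits[0, -1].argmax())  # most likely next character
        ids.append(next_id)
    return tokenizer.decode(ids)

# e.g. print(greedy_generate(model, tokenizer, "Once upon a time"))
```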

tokenizer.py (new file, 86 lines)

@@ -0,0 +1,86 @@
"""
tokenizer.py
A simple characterlevel tokenizer that builds its own vocabulary from all text files.
Saves/loads vocab to/from JSON. You can extend this to a wordlevel tokenizer if you wish.
"""
import json
import os
from collections import Counter
from typing import List, Dict, Union
class CharTokenizer:
def __init__(self, vocab_path: str, data_dir: str):
"""
If vocab_path exists, load it; otherwise, build from raw text in data_dir.
"""
self.vocab_path = vocab_path
self.data_dir = data_dir
self.stoi: Dict[str, int] = {}
self.itos: Dict[int, str] = {}
if os.path.isfile(self.vocab_path):
self._load_vocab()
else:
self._build_vocab()
def _build_vocab(self):
"""
Read all .txt files under data_dir, count character frequencies,
build a sorted vocabulary, and save to vocab_path.
"""
counter = Counter()
print(f"[tokenizer] Building vocabulary from data in '{self.data_dir}'...")
for root, _, files in os.walk(self.data_dir):
for fname in files:
if not fname.lower().endswith(".txt"):
continue
path = os.path.join(root, fname)
with open(path, "r", encoding="utf-8", errors="ignore") as f:
text = f.read()
counter.update(text)
# Ensure a consistent ordering: sort by frequency descending, then Unicode codepoint
sorted_chars = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
unique_chars = [ch for ch, _ in sorted_chars]
# Add special tokens
tokens = ["<pad>", "<unk>"] + unique_chars
self.stoi = {ch: i for i, ch in enumerate(tokens)}
self.itos = {i: ch for i, ch in enumerate(tokens)}
# Save to JSON
os.makedirs(os.path.dirname(self.vocab_path), exist_ok=True)
with open(self.vocab_path, "w", encoding="utf-8") as f:
json.dump(self.stoi, f, ensure_ascii=False, indent=2)
print(f"[tokenizer] Built vocab size = {len(self.stoi)}; saved to '{self.vocab_path}'.")
def _load_vocab(self):
"""
Load existing vocabulary from vocab_path.
"""
print(f"[tokenizer] Loading vocabulary from '{self.vocab_path}'...")
with open(self.vocab_path, "r", encoding="utf-8") as f:
self.stoi = json.load(f)
self.itos = {i: ch for ch, i in self.stoi.items()}
print(f"[tokenizer] Loaded vocab size = {len(self.stoi)}.")
def encode(self, text: str) -> List[int]:
"""
Convert a string to a list of integer token IDs (characterlevel).
Unrecognized chars map to <unk>.
"""
unk_id = self.stoi.get("<unk>")
return [self.stoi.get(ch, unk_id) for ch in text]
def decode(self, token_ids: List[int]) -> str:
"""
Convert a list of token IDs back into a string.
"""
return "".join(self.itos.get(i, "<unk>") for i in token_ids)
def vocab_size(self) -> int:
return len(self.stoi)
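A self-contained round-trip sketch (illustrative only): build a vocabulary from a throwaway corpus in a temporary directory, then confirm that `decode(encode(...))` reproduces seen text and that unseen characters map to `<unk>`:

```python
# illustrative only: CharTokenizer round trip on a throwaway corpus
import os
import tempfile

from tokenizer import CharTokenizer

with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "book.txt"), "w", encoding="utf-8") as f:
        f.write("hello nora")
    tok = CharTokenizer(vocab_path=os.path.join(tmp, "vocab.json"), data_dir=tmp)

    print(tok.vocab_size())                  # 10: 2 special tokens + 8 unique characters
    print(tok.decode(tok.encode("hello")))   # "hello"
    print(tok.decode(tok.encode("hi!")))     # 'i' and '!' were never seen -> "h<unk><unk>"
```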

train.py (new file, 135 lines)

@@ -0,0 +1,135 @@
"""
train.py
Training loop for Nora, with automatic mixed precision (AMP) to speed up on CUDA GPUs.
Uses tqdm for progress bars, logging for metrics, and gradient clipping + LR scheduler.
"""
import time
import logging
import math
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.amp import GradScaler, autocast # <-- updated import
def train(
model: torch.nn.Module,
dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
scheduler,
tokenizer,
config,
start_step: int = 0,
):
"""
model: NoraTransformerLM
dataloader: DataLoader for TextDataset
optimizer: AdamW (or Adam)
scheduler: LR scheduler with warmup
tokenizer: CharTokenizer
config: namespace from config.py
start_step: if resuming from checkpoint
"""
device = config.device
model.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"])
scaler = GradScaler()
global_step = start_step
steps_per_epoch = len(dataloader)
total_steps = config.epochs * steps_per_epoch
logging.info(
f"[train] Starting training for {config.epochs} epochs, "
f"{steps_per_epoch} steps/epoch, total approx {total_steps} steps."
)
for epoch in range(config.epochs):
model.train()
epoch_loss = 0.0
epoch_start = time.time()
# If you want to profile the first 100 steps, uncomment below:
# if global_step == start_step:
# t0 = time.time()
pbar = tqdm(
enumerate(dataloader),
total=steps_per_epoch,
desc=f"Epoch {epoch + 1}",
ncols=100,
unit="step",
)
for step, (inputs, targets) in pbar:
inputs = inputs.to(device)
targets = targets.to(device)
optimizer.zero_grad()
# Mixed precision forward/backward (specify device_type="cuda")
with autocast(device_type="cuda", dtype=torch.float16):
logits = model(inputs) # (batch, seq_len, vocab_size)
loss = criterion(
logits.view(-1, tokenizer.vocab_size()),
targets.view(-1),
)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
clip_grad_norm_(model.parameters(), config.max_grad_norm)
scaler.step(optimizer)
scaler.update()
scheduler.step()
epoch_loss += loss.item()
global_step += 1
# Log every log_interval steps
if global_step % config.log_interval == 0:
avg_loss = epoch_loss / (step + 1)
ppl = math.exp(avg_loss)
logging.info(
f"[step {global_step}/{total_steps}] "
f"avg_loss = {avg_loss:.4f}, ppl = {ppl:.2f}, "
f"lr = {scheduler.get_last_lr()[0]:.2e}"
)
# Save checkpoint every save_interval steps
if global_step % config.save_interval == 0:
from utils import save_checkpoint
save_checkpoint(
model,
optimizer,
global_step,
config.checkpoint_dir,
tokenizer=tokenizer,
)
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
# (Optional) Profile first 100 steps
# if global_step == start_step + 100:
# elapsed = time.time() - t0
# print(
# f"[profile] avg time/step over first 100 batches: "
# f"{elapsed/100:.4f} s"
# )
epoch_time = time.time() - epoch_start
avg_epoch_loss = epoch_loss / steps_per_epoch
logging.info(
f"[epoch {epoch + 1}/{config.epochs}] "
f"avg_loss = {avg_epoch_loss:.4f}, time = {epoch_time:.1f}s"
)
# Final checkpoint at end of all epochs
from utils import save_checkpoint
save_checkpoint(model, optimizer, global_step, config.checkpoint_dir, tokenizer=tokenizer)
logging.info("[train] Training complete.")

utils.py (new file, 84 lines)

@@ -0,0 +1,84 @@
"""
utils.py
Common utilities: logging setup, checkpoint saving & loading, device checks, etc.
"""
import os
import logging
import torch
def setup_logging(log_file: str = None):
"""
Set up logging to stdout (and optionally to a file).
"""
root = logging.getLogger()
root.setLevel(logging.INFO)
formatter = logging.Formatter(
"[%(asctime)s] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
# Console handler
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
root.addHandler(ch)
# File handler
if log_file:
os.makedirs(os.path.dirname(log_file), exist_ok=True)
fh = logging.FileHandler(log_file)
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
root.addHandler(fh)
def save_checkpoint(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
step: int,
checkpoint_dir: str,
tokenizer=None,
):
"""
Save model state, optimizer state, and tokenizer (optional) to a checkpoint file.
"""
os.makedirs(checkpoint_dir, exist_ok=True)
ckpt_path = os.path.join(checkpoint_dir, f"nora_step_{step}.pt")
state = {
"step": step,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
}
if tokenizer:
# tokenizer.stoi is JSONserializable
state["tokenizer_stoi"] = tokenizer.stoi
torch.save(state, ckpt_path)
logging.info(f"[checkpoint] Saved checkpoint to {ckpt_path}")
def load_checkpoint(
ckpt_path: str, model: torch.nn.Module, optimizer: torch.optim.Optimizer = None
):
"""
Load model & optimizer state from a checkpoint. Returns step.
If optimizer is None, only loads model weights.
"""
if not os.path.isfile(ckpt_path):
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
state = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state["model_state_dict"])
step = state.get("step", 0)
if optimizer and "optimizer_state_dict" in state:
optimizer.load_state_dict(state["optimizer_state_dict"])
logging.info(f"[checkpoint] Loaded checkpoint from {ckpt_path} (step {step})")
return step
def get_default_device():
"""
Return 'cuda' if available; otherwise 'cpu'.
"""
return "cuda" if torch.cuda.is_available() else "cpu"