Nora/utils.py

85 lines
2.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
utils.py
Common utilities: logging setup, checkpoint saving & loading, device checks, etc.
"""
import os
import logging
import torch
def setup_logging(log_file: str = None):
    """
    Attach INFO-level console and (optionally) file handlers to the root logger.

    Both handlers share the format ``[YYYY-mm-dd HH:MM:SS] [LEVEL] message``.

    Args:
        log_file: Optional path of a log file to also write to. Its parent
            directory is created if needed; a bare file name (no directory
            component) is written relative to the current working directory.
            ``None`` or ``""`` disables file logging.

    Note:
        Each call appends new handlers to the root logger, so calling this
        more than once causes every record to be emitted multiple times.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "[%(asctime)s] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )

    # Console handler. logging.StreamHandler() with no argument writes to
    # sys.stderr, not stdout.
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)
    root.addHandler(ch)

    # File handler
    if log_file:
        log_dir = os.path.dirname(log_file)
        # dirname() is "" for a bare file name, and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when one is given.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Explicit encoding so non-ASCII messages don't depend on the locale.
        fh = logging.FileHandler(log_file, encoding="utf-8")
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        root.addHandler(fh)
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    step: int,
    checkpoint_dir: str,
    tokenizer=None,
):
    """
    Save model state, optimizer state, and (optionally) tokenizer vocabulary.

    The checkpoint is written to ``<checkpoint_dir>/nora_step_<step>.pt`` as a
    dict with keys ``"step"``, ``"model_state_dict"``, ``"optimizer_state_dict"``
    and, when a tokenizer is given, ``"tokenizer_stoi"``.

    The file is first written to a temporary path next to the target and then
    moved into place with ``os.replace``. An interrupted save therefore never
    leaves a truncated checkpoint under the final name.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        step: Training step; stored in the checkpoint and used in the file name.
        checkpoint_dir: Output directory; created if it does not exist.
        tokenizer: Optional object exposing a ``stoi`` attribute to store.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    ckpt_path = os.path.join(checkpoint_dir, f"nora_step_{step}.pt")
    state = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    # Compare against None instead of relying on truthiness. A tokenizer that
    # defines __len__ (e.g. vocabulary size) would otherwise be skipped when
    # it reports length 0.
    if tokenizer is not None:
        # tokenizer.stoi is JSON-serializable (author's note); presumably a
        # str -> int vocabulary mapping -- TODO confirm against the tokenizer.
        state["tokenizer_stoi"] = tokenizer.stoi

    tmp_path = ckpt_path + ".tmp"
    try:
        torch.save(state, tmp_path)
        # Atomic rename when source and target are on the same filesystem.
        os.replace(tmp_path, ckpt_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    logging.info("[checkpoint] Saved checkpoint to %s", ckpt_path)
def load_checkpoint(
    ckpt_path: str, model: torch.nn.Module, optimizer: torch.optim.Optimizer = None
):
    """
    Restore model weights, and optionally optimizer state, from a checkpoint.

    Args:
        ckpt_path: Path to a file written with ``torch.save``. It must contain
            a ``"model_state_dict"`` entry.
        model: Module that receives ``"model_state_dict"`` via ``load_state_dict``.
        optimizer: When given, receives ``"optimizer_state_dict"`` if the
            checkpoint has that key. When omitted, only model weights are loaded.

    Returns:
        The checkpoint's ``"step"`` value, or 0 if it has none.

    Raises:
        FileNotFoundError: If ``ckpt_path`` is not an existing regular file.
    """
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # NOTE(review): torch.load unpickles the file unless weights_only=True is
    # in effect (the default only since PyTorch 2.6) -- load trusted files only.
    # Tensors are mapped to CPU regardless of the device they were saved from.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    restored_step = checkpoint.get("step", 0)

    model.load_state_dict(checkpoint["model_state_dict"])

    has_optimizer_state = "optimizer_state_dict" in checkpoint
    if optimizer and has_optimizer_state:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    logging.info(f"[checkpoint] Loaded checkpoint from {ckpt_path} (step {restored_step})")
    return restored_step
def get_default_device():
    """
    Choose the default torch device name for this process.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"