This commit introduces a new script that implements a Laplace-smoothed bigram language model for computing validation perplexity. The implementation includes:

- Data loading and splitting functionality (90/10 train/validation split)
- Character vocabulary building from training data only
- Bigram counting and Laplace smoothing with alpha=1.0
- Negative log-likelihood and perplexity computation
- Proper handling of out-of-vocabulary characters during evaluation

The script can process existing train.txt/val.txt files or automatically split a data.txt file if the required input files are missing, making it self-contained and easy to use for language model evaluation tasks.
# 09_perplexity.py
"""
Compute validation perplexity for a Laplace-smoothed bigram model (NumPy).

What it does:
- Loads train/val text (or splits data.txt 90/10 if missing).
- Builds char vocab on train ONLY.
- Trains a Laplace-smoothed bigram model (alpha=1.0).
- Evaluates NLL and perplexity on val.

How to run:

    python 09_perplexity.py

    # doctests (optional)
    python -m doctest -v 09_perplexity.py

Notes:
- If val has many OOV characters (not seen in train), those positions are skipped.
- Perplexity = exp(mean NLL), where NLL is computed over observed bigrams.
"""
from __future__ import annotations

from pathlib import Path

import numpy as np

FALLBACK = "aabbaabb\n"


# ---------- IO ----------


def load_or_split():
    """Return (train_text, val_text), normalizing newlines to '\\n'.

    If train/val files exist, read them; else split data.txt 90/10.

    >>> tr, va = load_or_split()
    >>> isinstance(tr, str) and isinstance(va, str)
    True
    """
    tr_p, va_p = Path("train.txt"), Path("val.txt")
    if tr_p.exists() and va_p.exists():
        tr = tr_p.read_text(encoding="utf-8")
        va = va_p.read_text(encoding="utf-8")
    else:
        base = Path("data.txt")
        txt = base.read_text(encoding="utf-8") if base.exists() else FALLBACK
        cut = int(0.9 * len(txt))
        tr, va = txt[:cut], txt[cut:]
    # normalize CRLF/CR line endings to '\n' so counts are platform-independent
    to_lf = lambda s: s.replace("\r\n", "\n").replace("\r", "\n")
    return to_lf(tr), to_lf(va)


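# Note: the fallback 90/10 split above is by character offset, not by line,
# so the cut can land mid-line; for a character-level model this is harmless.

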
# ---------- vocab ----------


def build_vocab(text: str):
    """Build vocab from train text: returns (chars, stoi, itos)."""
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}
    itos = {i: c for c, i in stoi.items()}
    return chars, stoi, itos


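# Example: build_vocab("abba") returns
#   (['a', 'b'], {'a': 0, 'b': 1}, {0: 'a', 1: 'b'})

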
# ---------- model ----------


def bigram_counts(text: str, stoi: dict[str, int]) -> np.ndarray:
    """Return VxV bigram counts for text restricted to train vocab.

    >>> stoi = {'a': 0, 'b': 1}
    >>> M = bigram_counts("abba", stoi)
    >>> int(M[0, 1]), int(M[1, 1])
    (1, 1)
    """
    ids = [stoi[c] for c in text if c in stoi]
    V = len(stoi)
    mat = np.zeros((V, V), dtype=np.int64)
    for a, b in zip(ids[:-1], ids[1:]):
        mat[a, b] += 1
    return mat


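# Example: with stoi = {'a': 0, 'b': 1}, bigram_counts("abba", stoi) gives
#   [[0, 1],   # 'a'->'a': 0, 'a'->'b': 1
#    [1, 1]]   # 'b'->'a': 1, 'b'->'b': 1

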
def laplace(mat: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    """Row-wise Laplace smoothing; returns probabilities.

    >>> import numpy as _np
    >>> P = laplace(_np.array([[0, 1], [0, 0]]), alpha=1.0)
    >>> _np.allclose(P.sum(axis=1), 1.0)
    True
    """
    P = mat.astype(np.float64) + alpha
    P /= P.sum(axis=1, keepdims=True)
    return P


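# For reference, with V = vocab size the smoothed estimate above is
#   P[a, b] = (count[a, b] + alpha) / (count[a, :].sum() + alpha * V)
# so every transition, seen or not, gets a nonzero probability.

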
# ---------- evaluation ----------


def nll_on_text(text: str, stoi: dict[str, int], P: np.ndarray) -> float:
    """Average negative log-likelihood over bigrams present in vocab.

    >>> import numpy as _np
    >>> stoi = {'a': 0, 'b': 1}; P = _np.array([[0.5, 0.5], [0.5, 0.5]])
    >>> round(nll_on_text("ab", stoi, P), 5)
    0.69315
    """
    ids = [stoi[c] for c in text if c in stoi]
    logs = []
    for a, b in zip(ids[:-1], ids[1:]):
        p = float(P[a, b])
        logs.append(-np.log(max(1e-12, p)))  # clamp to avoid log(0)
    return float(np.mean(logs)) if logs else float("inf")


def perplexity(nll: float) -> float:
    """Perplexity = exp(NLL)."""
    return float(np.exp(nll))


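# Sanity check: a uniform model over V symbols has NLL = log(V) and
# perplexity = V; V = 2 gives NLL ~ 0.69315 and perplexity 2.0, matching
# the doctest above.

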
def main() -> None:
    train, val = load_or_split()
    chars, stoi, itos = build_vocab(train)
    M = bigram_counts(train, stoi)
    P = laplace(M, alpha=1.0)

    nll = nll_on_text(val, stoi, P)
    ppl = perplexity(nll)
    print(f"Validation NLL: {nll:.4f}")
    print(f"Perplexity: {ppl:.4f}")


if __name__ == "__main__":
    main()
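
# Expected output format (values depend on your data):
#   Validation NLL: <float>
#   Perplexity: <float>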