feat: add Laplace-smoothed bigram model perplexity computation script

This commit introduces a new script that implements a Laplace-smoothed bigram language model for computing validation perplexity. The implementation includes:
- Data loading and splitting functionality (90/10 train/validation split)
- Character vocabulary building from training data only
- Bigram counting and Laplace smoothing with alpha=1.0
- Negative log-likelihood and perplexity computation
- Proper handling of out-of-vocabulary characters during evaluation

The script can process existing train.txt/val.txt files or automatically split a data.txt file if the required input files are missing, making it self-contained and easy to use for language model evaluation tasks.
This commit is contained in:
2025-09-24 00:33:23 -04:00
parent 119bf8e40c
commit 674c53651c

130
09_perplexity.py Normal file
View File

@@ -0,0 +1,130 @@
# 09_perplexity.py
"""
Compute validation perplexity for a Laplace-smoothed bigram model (Numpy).
What it does:
- Loads train/val text (or splits data.txt 90/10 if missing).
- Builds char vocab on train ONLY.
- Trains a Laplace-smoothed bigram model (alpha=1.0).
- Evaluates NLL and Perplexity on val.
How to run:
python 09_perplexity.py
# doctests (optional)
python -m doctest -v 09_perplexity.py
Notes:
- If val has many OOV characters (not seen in train), those positions are skipped.
- Perplexity = exp(mean NLL), where NLL is computed over observed bigrams.
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
FALLBACK = "aabbaabb\n"
# ---------- IO ----------
def load_or_split():
"""Return (train_text, val_text), normalizing newlines to '\\n'.
If train/val files exist, read them; else split data.txt 90/10.
>>> tr, va = load_or_split()
>>> isinstance(tr, str) and isinstance(va, str)
True
"""
tr_p, va_p = Path("train.txt"), Path("val.txt")
if tr_p.exists() and va_p.exists():
tr = tr_p.read_text(encoding="utf-8")
va = va_p.read_text(encoding="utf-8")
else:
base = Path("data.txt")
txt = base.read_text(encoding="utf-8") if base.exists() else FALLBACK
cut = int(0.9 * len(txt))
tr, va = txt[:cut], txt[cut:]
# normalize
to_lf = lambda s: s.replace("\r\n", "\n").replace("\r", "\n")
return to_lf(tr), to_lf(va)
# ---------- vocab ----------
def build_vocab(text: str):
"""Build vocab from train text: returns (chars, stoi, itos)."""
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}
return chars, stoi, itos
# ---------- model ----------
def bigram_counts(text: str, stoi: dict[str, int]) -> np.ndarray:
"""Return VxV bigram counts for text restricted to train vocab.
>>> import numpy as _np
>>> stoi = {'a':0, 'b':1}
>>> M = bigram_counts("abba", stoi)
>>> int(M[0,1]), int(M[1,1])
(1, 1)
"""
ids = [stoi[c] for c in text if c in stoi]
V = len(stoi)
mat = np.zeros((V, V), dtype=np.int64)
for a, b in zip(ids[:-1], ids[1:]):
mat[a, b] += 1
return mat
def laplace(mat: np.ndarray, alpha: float = 1.0) -> np.ndarray:
"""Row-wise Laplace smoothing; returns probabilites.
>>> import numpy as _np
>>> P = laplace(_np.array([[0,1],[0,0]]), alpha=1.0)
>>> _np.allclose(P.sum(axis=1), 1.0)
True
"""
P = mat.astype(np.float64) + alpha
P /= P.sum(axis=1, keepdims=True)
return P
# ---------- evaluation ----------
def nll_on_text(text: str, stoi: dict[str, int], P: np.ndarray) -> float:
"""Average negative log-likelihood over bigrams present in vocab.
>>> import numpy as _np
>>> stoi={'a':0, 'b':1}; P=_np.array([[0.5, 0.5], [0.5, 0.5]])
>>> round(nll_on_text("ab", stoi, P), 5)
0.69315
"""
ids = [stoi[c] for c in text if c in stoi]
logs = []
for a, b in zip(ids[:-1], ids[1:]):
p = float(P[a, b])
logs.append(-np.log(max(1e-12, p)))
return float(np.mean(logs)) if logs else float("inf")
def perplexity(nll: float) -> float:
"""Perplexity = exp(NLL)."""
return float(np.exp(nll))
def main() -> None:
train, val = load_or_split()
chars, stoi, itos = build_vocab(train)
M = bigram_counts(train, stoi)
P = laplace(M, alpha=1.0)
nll = nll_on_text(val, stoi, P)
ppl = perplexity(nll)
print(f"Validation NLL: {nll:.4f}")
print(f"Perplexity: {ppl:.4f}")
if __name__ == "__main__":
main()