add laplace smoothing implementation for bigram model with doctests and text sampling

This commit is contained in:
2025-09-23 22:36:57 -04:00
parent bcde2ad478
commit 46e6eb557f

173
07_laplace_smoothing.py Normal file
View File

@@ -0,0 +1,173 @@
# 07_laplace_smoothing.py
"""
Laplace (add-one) smoothing for a bigram counts model; compare with unsmoothed.
What it does:
- Loads UTF-8 text from train.txt (or data.txt; tiny fallback if misisng).
- Builds a VxV bigram counts matrix.
- Derives:
* Unsmoothed row-wise probabilites (zeros allowed).
* Smoothed probabilites with Laplace alpha=1.0 (no zeros).
- Samples text from both models to visualize the difference.
How to run:
python 07_laplace_smoothing.py
# optional doctests:
python -m doctest -v 07_laplace_smoothing.py
Notes:
- Smoothing removes zero-probability transitions, helping avoid "dead rows".
- We'll add temperature/top-k in the next lesson.
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Tuple
import random
import math
FALLBACK = "aaab\n"
# ---------- IO & Vocab ----------
def load_text() -> str:
"""Loads and normalize text from train.txt or data.txt, else fallback.
Returns:
Text with only LF ('\\n') newlines.
>>> isinstance(load_text(), str)
True
"""
p = Path("train.txt") if Path("train.txt").exists() else Path("data.txt")
t = p.read_text(encoding="utf-8") if p.exists() else FALLBACK
return t.replace("\r\n", "\n").replace("\r", "\n")
def build_vocab(text: str) -> Tuple[List[str], Dict[str, int], Dict[int, str]]:
"""Return (sorted chars, stoi, itos).
>>> chars, stoi, itos = build_vocab("ba\\n")
>>> chars == ['\\n', 'a', 'b']
True
"""
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}
return chars, stoi, itos
# ---------- counts -> probabilities ----------
def bigram_counts(text: str, stoi: Dict[str, int]) -> List[List[int]]:
"""Build VxV counts matrix for bigrams.
>>> M = bigram_counts("aba", {'a':0, 'b':1})
>>> M[0][1], M[1][0] # a->b and b->a each once
(1, 1)
"""
V = len(stoi)
M: List[List[int]] = [[0 for _ in range(V)] for _ in range(V)]
ids = [stoi[c] for c in text if c in stoi]
for a, b in zip(ids[:-1], ids[1:]):
M[a][b] += 1
return M
def probs_unsmoothed(M: List[List[int]]) -> List[List[float]]:
"""Row-normalize counts without smoothing (zeros allowed).
>>> P = probs_unsmoothed([[0,3],[0,0]])
>>> P[0] == [0.0, 1.0] and P[1] == [0.0, 0.0]
True
"""
P: List[List[float]] = []
for row in M:
s = sum(row)
if s == 0:
P.append([0.0 for _ in row])
else:
P.append([c / s for c in row])
return P
def probs_laplace(M: List[List[int]], alpha: float = 1.0) -> List[List[float]]:
"""Row-wise Laplace smoothing (add-alpha) -> probabilities.
Each row i:
p_ij = (c_ij + alpha) / (sum_j c_ij + alpha * V)
>>> P = probs_laplace([[0, 3],[0, 0]], alpha=1.0)
>>> abs(sum(P[0]) - 1.0) < 1e-9 and abs(sum(P[1]) - 1.0) < 1e-9
True
>>> all(p > 0.0 for p in P[1]) # previously all-zero row now has mass
True
"""
V = len(M[0]) if M else 0
out: List[List[float]] = []
for row in M:
s = sum(row)
denom = s + alpha * V
out.append([(c + alpha) / denom for c in row])
return out
# ---------- sampling ----------
def _categorical_sample(probs: List[float], rng: random.Random) -> int:
"""Sample an index from a probability vector.
>>> rng = random.Random(0)
>>> _categorical_sample([0.0, 1.0, 0.0], rng)
1
"""
r = rng.random()
acc = 0.0
for i, p in enumerate(probs):
acc += p
if r <= acc:
return i
return max(range(len(probs)), key=lambda i: probs[i])
def sample_text(P: List[List[float]], itos: Dict[int, str], length: int = 300, seed: int = 123) -> str:
"""Generate text by walking the bigram model (row-stochastic P).
>>> chars, stoi, itos = build_vocab("aba")
>>> M = bigram_counts("aba", stoi)
>>> P = probs_laplace(M, alpha=1.0)
>>> s = sample_text(P, itos, length=5, seed=0)
>>> len(s) == 5
True
"""
rng = random.Random(seed)
V = len(itos)
cur = rng.randrange(V)
out = [itos[cur]]
for _ in range(length - 1):
row = P[cur]
cur = _categorical_sample(row, rng)
out.append(itos[cur])
return "".join(out)
# ---------- main ----------
def main() -> None:
text = load_text()
chars, stoi, itos = build_vocab(text)
M = bigram_counts(text, stoi)
P_unsm = probs_unsmoothed(M)
P_lap = probs_laplace(M, alpha=1.0)
print(f"Vocab size: {len(chars)} | corpus chars: {len(text)}")
print("\n=== Unsmoothed sample ===")
print(sample_text(P_unsm, itos, length=400, seed=123))
print("\n=== Laplace-smoothed (alpha=1.0) sample ===")
print(sample_text(P_lap, itos, length=400, seed=123))
if __name__ == "__main__":
main()