add laplace smoothing implementation for bigram model with doctests and text sampling
07_laplace_smoothing.py (new file, 173 lines)
@@ -0,0 +1,173 @@
# 07_laplace_smoothing.py
"""
Laplace (add-one) smoothing for a bigram counts model; compare with unsmoothed.

What it does:
- Loads UTF-8 text from train.txt (or data.txt; tiny fallback if missing).
- Builds a VxV bigram counts matrix.
- Derives:
  * Unsmoothed row-wise probabilities (zeros allowed).
  * Smoothed probabilities with Laplace alpha=1.0 (no zeros).
- Samples text from both models to visualize the difference.

How to run:
    python 07_laplace_smoothing.py
    # optional doctests:
    python -m doctest -v 07_laplace_smoothing.py

Notes:
- Smoothing removes zero-probability transitions, helping avoid "dead rows".
- We'll add temperature/top-k in the next lesson.
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Tuple
import random
import math


# Tiny deterministic corpus used when neither train.txt nor data.txt exists.
FALLBACK = "aaab\n"


# ---------- IO & Vocab ----------


def load_text() -> str:
    """Load and normalize text from train.txt or data.txt, else fallback.

    Returns:
        Text with only LF ('\\n') newlines.

    >>> isinstance(load_text(), str)
    True
    """
    p = Path("train.txt") if Path("train.txt").exists() else Path("data.txt")
    t = p.read_text(encoding="utf-8") if p.exists() else FALLBACK
    return t.replace("\r\n", "\n").replace("\r", "\n")


def build_vocab(text: str) -> Tuple[List[str], Dict[str, int], Dict[int, str]]:
    """Return (sorted chars, stoi, itos).

    >>> chars, stoi, itos = build_vocab("ba\\n")
    >>> chars == ['\\n', 'a', 'b']
    True
    """
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}
    itos = {i: c for c, i in stoi.items()}
    return chars, stoi, itos


# ---------- counts -> probabilities ----------


def bigram_counts(text: str, stoi: Dict[str, int]) -> List[List[int]]:
    """Build VxV counts matrix for bigrams.

    >>> M = bigram_counts("aba", {'a': 0, 'b': 1})
    >>> M[0][1], M[1][0]  # a->b and b->a each once
    (1, 1)
    """
    V = len(stoi)
    M: List[List[int]] = [[0 for _ in range(V)] for _ in range(V)]
    # Map characters to ids, then count each adjacent (current, next) pair.
    ids = [stoi[c] for c in text if c in stoi]
    for a, b in zip(ids[:-1], ids[1:]):
        M[a][b] += 1
    return M


def probs_unsmoothed(M: List[List[int]]) -> List[List[float]]:
    """Row-normalize counts without smoothing (zeros allowed).

    >>> P = probs_unsmoothed([[0, 3], [0, 0]])
    >>> P[0] == [0.0, 1.0] and P[1] == [0.0, 0.0]
    True
    """
    P: List[List[float]] = []
    for row in M:
        s = sum(row)
        if s == 0:
            # A character never seen as a context keeps an all-zero row
            # (a "dead row" the sampler cannot follow meaningfully).
            P.append([0.0 for _ in row])
        else:
            P.append([c / s for c in row])
    return P


def probs_laplace(M: List[List[int]], alpha: float = 1.0) -> List[List[float]]:
    """Row-wise Laplace smoothing (add-alpha) -> probabilities.

    Each row i:
        p_ij = (c_ij + alpha) / (sum_j c_ij + alpha * V)

    E.g. row [0, 3] with alpha=1.0 and V=2: (0+1)/5 = 0.2 and (3+1)/5 = 0.8.

    >>> P = probs_laplace([[0, 3], [0, 0]], alpha=1.0)
    >>> abs(sum(P[0]) - 1.0) < 1e-9 and abs(sum(P[1]) - 1.0) < 1e-9
    True
    >>> all(p > 0.0 for p in P[1])  # previously all-zero row now has mass
    True
    """
    V = len(M[0]) if M else 0
    out: List[List[float]] = []
    for row in M:
        s = sum(row)
        denom = s + alpha * V
        out.append([(c + alpha) / denom for c in row])
    return out


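# Illustrative helper, a minimal sketch not exercised by main(): it quantifies
# the "dead rows" note in the module docstring by scoring a corpus under a
# model. The name avg_nll and its exact behavior are an assumption added for
# illustration, not part of the lesson's API.
def avg_nll(P: List[List[float]], text: str, stoi: Dict[str, int]) -> float:
    """Average negative log-likelihood per observed bigram (lower is better).

    An unsmoothed model returns inf as soon as it assigns zero probability
    to a bigram that actually occurs; a Laplace-smoothed model never does.

    >>> P = probs_laplace([[0, 3], [0, 0]], alpha=1.0)
    >>> avg_nll(P, "ab", {'a': 0, 'b': 1}) > 0.0
    True
    """
    ids = [stoi[c] for c in text if c in stoi]
    total, n = 0.0, 0
    for a, b in zip(ids[:-1], ids[1:]):
        p = P[a][b]
        if p <= 0.0:
            return math.inf  # an observed transition with zero probability
        total += -math.log(p)
        n += 1
    return total / n if n else 0.0

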
# ---------- sampling ----------


def _categorical_sample(probs: List[float], rng: random.Random) -> int:
    """Sample an index from a probability vector.

    >>> rng = random.Random(0)
    >>> _categorical_sample([0.0, 1.0, 0.0], rng)
    1
    """
    # Inverse-CDF sampling: walk the cumulative sum until it passes r.
    r = rng.random()
    acc = 0.0
    for i, p in enumerate(probs):
        acc += p
        if r <= acc:
            return i
    # Floating-point rounding (or an all-zero row) can leave acc below r;
    # fall back to the most probable index instead of raising.
    return max(range(len(probs)), key=lambda i: probs[i])


def sample_text(P: List[List[float]], itos: Dict[int, str], length: int = 300, seed: int = 123) -> str:
    """Generate text by walking the bigram model (row-stochastic P).

    >>> chars, stoi, itos = build_vocab("aba")
    >>> M = bigram_counts("aba", stoi)
    >>> P = probs_laplace(M, alpha=1.0)
    >>> s = sample_text(P, itos, length=5, seed=0)
    >>> len(s) == 5
    True
    """
    rng = random.Random(seed)
    V = len(itos)
    cur = rng.randrange(V)  # start from a uniformly random character
    out = [itos[cur]]
    for _ in range(length - 1):
        row = P[cur]
        cur = _categorical_sample(row, rng)
        out.append(itos[cur])
    return "".join(out)


# ---------- main ----------


def main() -> None:
    text = load_text()
    chars, stoi, itos = build_vocab(text)
    M = bigram_counts(text, stoi)

    P_unsm = probs_unsmoothed(M)
    P_lap = probs_laplace(M, alpha=1.0)

    print(f"Vocab size: {len(chars)} | corpus chars: {len(text)}")
    print("\n=== Unsmoothed sample ===")
    print(sample_text(P_unsm, itos, length=400, seed=123))
    print("\n=== Laplace-smoothed (alpha=1.0) sample ===")
    print(sample_text(P_lap, itos, length=400, seed=123))


if __name__ == "__main__":
    main()