# 07_laplace_smoothing.py
"""
Laplace (add-one) smoothing for a bigram counts model; compare with unsmoothed.

What it does:
- Loads UTF-8 text from train.txt (or data.txt; tiny fallback if missing).
- Builds a VxV bigram counts matrix.
- Derives:
  * Unsmoothed row-wise probabilities (zeros allowed).
  * Smoothed probabilities with Laplace alpha=1.0 (no zeros).
- Samples text from both models to visualize the difference.

How to run:
    python 07_laplace_smoothing.py
    # optional doctests:
    python -m doctest -v 07_laplace_smoothing.py

Notes:
- Smoothing removes zero-probability transitions, helping avoid "dead rows"
  (characters whose count row is all zeros, so the unsmoothed model has no
  learned next step).
- We'll add temperature/top-k in the next lesson.
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Tuple
import random
import math


FALLBACK = "aaab\n"


# ---------- IO & Vocab ----------


def load_text() -> str:
    """Load and normalize text from train.txt or data.txt, else the fallback.

    Returns:
        Text with only LF ('\\n') newlines.

    >>> isinstance(load_text(), str)
    True
    """
    p = Path("train.txt") if Path("train.txt").exists() else Path("data.txt")
    t = p.read_text(encoding="utf-8") if p.exists() else FALLBACK
    return t.replace("\r\n", "\n").replace("\r", "\n")


def build_vocab(text: str) -> Tuple[List[str], Dict[str, int], Dict[int, str]]:
    """Return (sorted chars, stoi, itos).

    >>> chars, stoi, itos = build_vocab("ba\\n")
    >>> chars == ['\\n', 'a', 'b']
    True
    """
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}
    itos = {i: c for c, i in stoi.items()}
    return chars, stoi, itos


# ---------- counts -> probabilities ----------


def bigram_counts(text: str, stoi: Dict[str, int]) -> List[List[int]]:
    """Build VxV counts matrix for bigrams.

    >>> M = bigram_counts("aba", {'a': 0, 'b': 1})
    >>> M[0][1], M[1][0]  # a->b and b->a each once
    (1, 1)
    """
    V = len(stoi)
    M: List[List[int]] = [[0 for _ in range(V)] for _ in range(V)]
    ids = [stoi[c] for c in text if c in stoi]
    for a, b in zip(ids[:-1], ids[1:]):
        M[a][b] += 1
    return M
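

# A quick trace on the tiny fallback corpus "aaab\n" (an illustrative aside):
# the vocabulary is ['\n', 'a', 'b'], and the observed bigrams are a->a twice,
# a->b once, and b->'\n' once, so bigram_counts returns
#   [[0, 0, 0],   # '\n' -> nothing observed: a "dead row"
#    [0, 2, 1],   # 'a'  -> 'a' x2, 'a' -> 'b' x1
#    [1, 0, 0]]   # 'b'  -> '\n' x1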


def probs_unsmoothed(M: List[List[int]]) -> List[List[float]]:
    """Row-normalize counts without smoothing (zeros allowed).

    >>> P = probs_unsmoothed([[0, 3], [0, 0]])
    >>> P[0] == [0.0, 1.0] and P[1] == [0.0, 0.0]
    True
    """
    P: List[List[float]] = []
    for row in M:
        s = sum(row)
        if s == 0:
            # Dead row: no outgoing transitions were observed, so there is
            # nothing to normalize; leave the whole row at zero.
            P.append([0.0 for _ in row])
        else:
            P.append([c / s for c in row])
    return P
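

# With the fallback corpus traced above, probs_unsmoothed maps the '\n' row to
# [0.0, 0.0, 0.0]; probs_laplace (next) turns that same row into the uniform
# [1/3, 1/3, 1/3] -- the "no zeros" guarantee from the header docstring.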


def probs_laplace(M: List[List[int]], alpha: float = 1.0) -> List[List[float]]:
    """Row-wise Laplace smoothing (add-alpha) -> probabilities.

    Each row i:
        p_ij = (c_ij + alpha) / (sum_j c_ij + alpha * V)

    >>> P = probs_laplace([[0, 3], [0, 0]], alpha=1.0)
    >>> abs(sum(P[0]) - 1.0) < 1e-9 and abs(sum(P[1]) - 1.0) < 1e-9
    True
    >>> all(p > 0.0 for p in P[1])  # previously all-zero row now has mass
    True
    """
    V = len(M[0]) if M else 0
    out: List[List[float]] = []
    for row in M:
        s = sum(row)
        denom = s + alpha * V
        out.append([(c + alpha) / denom for c in row])
    return out
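

# Worked instance of the formula above (illustrative aside): for counts row
# [0, 3] with alpha=1.0 and V=2,
#   p_0 = (0 + 1) / (3 + 1*2) = 0.2
#   p_1 = (3 + 1) / (3 + 1*2) = 0.8
# and an all-zero row becomes uniform: each entry is alpha / (alpha*V) = 1/V.


# count_zero_transitions is a small helper added for this writeup (not part of
# the original lesson); main() uses it to count how many transitions smoothing
# rescues from zero probability.
def count_zero_transitions(P: List[List[float]]) -> int:
    """Count zero entries in a probability matrix.

    >>> count_zero_transitions([[0.0, 1.0], [0.5, 0.5]])
    1
    """
    return sum(1 for row in P for p in row if p == 0.0)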


# ---------- sampling ----------


def _categorical_sample(probs: List[float], rng: random.Random) -> int:
    """Sample an index from a probability vector.

    >>> rng = random.Random(0)
    >>> _categorical_sample([0.0, 1.0, 0.0], rng)
    1
    """
    r = rng.random()
    acc = 0.0
    for i, p in enumerate(probs):
        acc += p
        if r <= acc:
            return i
    # Fallback for floating-point rounding (or an all-zero dead row): return
    # the most probable index.
    return max(range(len(probs)), key=lambda i: probs[i])


def sample_text(P: List[List[float]], itos: Dict[int, str], length: int = 300, seed: int = 123) -> str:
    """Generate text by walking the bigram model (row-stochastic P).

    >>> chars, stoi, itos = build_vocab("aba")
    >>> M = bigram_counts("aba", stoi)
    >>> P = probs_laplace(M, alpha=1.0)
    >>> s = sample_text(P, itos, length=5, seed=0)
    >>> len(s) == 5
    True
    """
    rng = random.Random(seed)
    V = len(itos)
    cur = rng.randrange(V)  # start from a uniformly random character
    out = [itos[cur]]
    for _ in range(length - 1):
        row = P[cur]
        cur = _categorical_sample(row, rng)
        out.append(itos[cur])
    return "".join(out)
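

# ---------- evaluation (optional) ----------


# avg_nll is an illustrative addition, not part of the original lesson: it is
# one standard way to compare the two models numerically (and it puts the
# `math` import to work). Any transition the model assigns zero probability
# blows the score up to math.inf; Laplace smoothing keeps it finite.


def avg_nll(P: List[List[float]], text: str, stoi: Dict[str, int]) -> float:
    """Average negative log-likelihood per bigram (lower fits `text` better).

    >>> chars, stoi, itos = build_vocab("aba")
    >>> P = probs_laplace(bigram_counts("aba", stoi), alpha=1.0)
    >>> avg_nll(P, "aba", stoi) > 0.0
    True
    """
    ids = [stoi[c] for c in text if c in stoi]
    pairs = list(zip(ids[:-1], ids[1:]))
    if not pairs:
        return 0.0
    total = 0.0
    for a, b in pairs:
        p = P[a][b]
        if p == 0.0:
            return math.inf  # unseen transition: only unsmoothed models hit this
        total -= math.log(p)
    return total / len(pairs)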


# ---------- main ----------


def main() -> None:
    text = load_text()
    chars, stoi, itos = build_vocab(text)
    M = bigram_counts(text, stoi)

    P_unsm = probs_unsmoothed(M)
    P_lap = probs_laplace(M, alpha=1.0)
print(f"Vocab size: {len(chars)} | corpus chars: {len(text)}")
|
|
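
    # Diagnostics via the illustrative helpers above (added for this writeup,
    # not in the original lesson). The smoothed matrix should report 0 zero
    # entries; on the training text itself the unsmoothed model fits at least
    # as tightly (lower avg NLL), and smoothing pays off on unseen bigrams.
    print(f"Zero transitions: unsmoothed={count_zero_transitions(P_unsm)}, "
          f"laplace={count_zero_transitions(P_lap)}")
    print(f"Avg NLL (training text): unsmoothed={avg_nll(P_unsm, text, stoi):.4f}, "
          f"laplace={avg_nll(P_lap, text, stoi):.4f}")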
print("\n=== Unsmoothed sample ===")
|
|
print(sample_text(P_unsm, itos, length=400, seed=123))
|
|
print("\n=== Laplace-smoothed (alpha=1.0) sample ===")
|
|
print(sample_text(P_lap, itos, length=400, seed=123))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|