# 07_laplace_smoothing.py
"""
Laplace (add-one) smoothing for a bigram counts model; compare with unsmoothed.

What it does:
- Loads UTF-8 text from train.txt (or data.txt; tiny fallback if missing).
- Builds a VxV bigram counts matrix.
- Derives:
    * Unsmoothed row-wise probabilities (zeros allowed).
    * Smoothed probabilities with Laplace alpha=1.0 (no zeros).
- Samples text from both models to visualize the difference.

How to run:
    python 07_laplace_smoothing.py
    # optional doctests:
    python -m doctest -v 07_laplace_smoothing.py

Notes:
- Smoothing removes zero-probability transitions, helping avoid "dead rows".
- We'll add temperature/top-k in the next lesson.
"""
from __future__ import annotations

import random
from pathlib import Path
from typing import Dict, List, Tuple

# Corpus used when neither train.txt nor data.txt exists on disk.
FALLBACK = "aaab\n"


# ---------- IO & Vocab ----------

def load_text() -> str:
    """Load and normalize text from train.txt or data.txt, else fallback.

    Returns:
        Text with only LF ('\\n') newlines.

    >>> isinstance(load_text(), str)
    True
    """
    # Prefer train.txt; fall back to data.txt; finally a tiny built-in corpus.
    p = Path("train.txt") if Path("train.txt").exists() else Path("data.txt")
    t = p.read_text(encoding="utf-8") if p.exists() else FALLBACK
    # Normalize CRLF and bare CR to LF so '\n' is the only newline in the vocab.
    return t.replace("\r\n", "\n").replace("\r", "\n")


def build_vocab(text: str) -> Tuple[List[str], Dict[str, int], Dict[int, str]]:
    """Return (sorted chars, stoi, itos).

    >>> chars, stoi, itos = build_vocab("ba\\n")
    >>> chars == ['\\n', 'a', 'b']
    True
    """
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}
    itos = {i: c for c, i in stoi.items()}
    return chars, stoi, itos


# ---------- counts -> probabilities ----------

def bigram_counts(text: str, stoi: Dict[str, int]) -> List[List[int]]:
    """Build VxV counts matrix for bigrams.

    >>> M = bigram_counts("aba", {'a':0, 'b':1})
    >>> M[0][1], M[1][0]  # a->b and b->a each once
    (1, 1)
    """
    V = len(stoi)
    M: List[List[int]] = [[0 for _ in range(V)] for _ in range(V)]
    # Characters not in stoi are silently skipped (robust to vocab mismatch).
    ids = [stoi[c] for c in text if c in stoi]
    for a, b in zip(ids[:-1], ids[1:]):
        M[a][b] += 1
    return M


def probs_unsmoothed(M: List[List[int]]) -> List[List[float]]:
    """Row-normalize counts without smoothing (zeros allowed).

    >>> P = probs_unsmoothed([[0,3],[0,0]])
    >>> P[0] == [0.0, 1.0] and P[1] == [0.0, 0.0]
    True
    """
    P: List[List[float]] = []
    for row in M:
        s = sum(row)
        if s == 0:
            # Dead row: no observed transitions; leave it all-zero.
            P.append([0.0 for _ in row])
        else:
            P.append([c / s for c in row])
    return P


def probs_laplace(M: List[List[int]], alpha: float = 1.0) -> List[List[float]]:
    """Row-wise Laplace smoothing (add-alpha) -> probabilities.

    Each row i: p_ij = (c_ij + alpha) / (sum_j c_ij + alpha * V)

    >>> P = probs_laplace([[0, 3],[0, 0]], alpha=1.0)
    >>> abs(sum(P[0]) - 1.0) < 1e-9 and abs(sum(P[1]) - 1.0) < 1e-9
    True
    >>> all(p > 0.0 for p in P[1])  # previously all-zero row now has mass
    True
    """
    V = len(M[0]) if M else 0
    out: List[List[float]] = []
    for row in M:
        s = sum(row)
        denom = s + alpha * V
        if denom == 0:
            # Only possible when alpha == 0 and the row is all zeros;
            # mirror probs_unsmoothed instead of dividing by zero.
            out.append([0.0 for _ in row])
        else:
            out.append([(c + alpha) / denom for c in row])
    return out


# ---------- sampling ----------

def _categorical_sample(probs: List[float], rng: random.Random) -> int:
    """Sample an index from a probability vector.

    >>> rng = random.Random(0)
    >>> _categorical_sample([0.0, 1.0, 0.0], rng)
    1
    """
    r = rng.random()
    acc = 0.0
    for i, p in enumerate(probs):
        acc += p
        if r <= acc:
            return i
    # Fallback for rows whose mass sums below r (e.g. an all-zero
    # unsmoothed row, or float round-off): pick the most likely index.
    return max(range(len(probs)), key=lambda i: probs[i])


def sample_text(P: List[List[float]], itos: Dict[int, str],
                length: int = 300, seed: int = 123) -> str:
    """Generate text by walking the bigram model (row-stochastic P).

    >>> chars, stoi, itos = build_vocab("aba")
    >>> M = bigram_counts("aba", stoi)
    >>> P = probs_laplace(M, alpha=1.0)
    >>> s = sample_text(P, itos, length=5, seed=0)
    >>> len(s) == 5
    True
    """
    if length <= 0:
        # Guard: without this, length=0 would still emit the seed character.
        return ""
    rng = random.Random(seed)
    V = len(itos)
    cur = rng.randrange(V)  # uniform random start state
    out = [itos[cur]]
    for _ in range(length - 1):
        row = P[cur]
        cur = _categorical_sample(row, rng)
        out.append(itos[cur])
    return "".join(out)


# ---------- main ----------

def main() -> None:
    text = load_text()
    chars, stoi, itos = build_vocab(text)
    M = bigram_counts(text, stoi)
    P_unsm = probs_unsmoothed(M)
    P_lap = probs_laplace(M, alpha=1.0)

    print(f"Vocab size: {len(chars)} | corpus chars: {len(text)}")
    print("\n=== Unsmoothed sample ===")
    print(sample_text(P_unsm, itos, length=400, seed=123))
    print("\n=== Laplace-smoothed (alpha=1.0) sample ===")
    print(sample_text(P_lap, itos, length=400, seed=123))


if __name__ == "__main__":
    main()