From bcde2ad478bc5e857ceea0299ac7a4e14cfd5d88 Mon Sep 17 00:00:00 2001
From: Dani
Date: Tue, 23 Sep 2025 22:02:04 -0400
Subject: [PATCH] add bigram counts model with text sampling functionality and
 doctests

---
 06_bigram_counts.py | 212 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 212 insertions(+)
 create mode 100644 06_bigram_counts.py

diff --git a/06_bigram_counts.py b/06_bigram_counts.py
new file mode 100644
index 0000000..81043bd
--- /dev/null
+++ b/06_bigram_counts.py
@@ -0,0 +1,212 @@
+# 06_bigram_counts.py
+"""
+Build a bigram counts model (no smoothing) and sample text.
+
+What it does:
+- Loads UTF-8 text from train.txt (preferred) or data.txt; normalizes to "\\n".
+- Builds a sorted vocabulary.
+- Fills a VxV integer bigram counts matrix.
+- Converts to per-row probabilities (no smoothing).
+- Samples text via a simple categorical draw per step.
+
+How to run:
+    python 06_bigram_counts.py
+    # optional doctests
+    python -m doctest -v 06_bigram_counts.py
+
+Notes:
+- This version has **no smoothing**, so many entries will be zero. In the next
+  lesson we'll add Laplace smoothing to remove zeros.
+"""
+
+from __future__ import annotations
+from pathlib import Path
+from typing import Dict, List, Tuple
+import random
+
+FALLBACK = "ababa\n"
+
+# ---------- IO & Vocab ----------
+
+
+def load_text() -> str:
+    """Load and normalize text from train.txt or data.txt, else fallback.
+
+    Returns:
+        UTF-8 text with only LF ('\\n') newlines.
+
+    >>> isinstance(load_text(), str)
+    True
+    """
+    p = Path("train.txt") if Path("train.txt").exists() else Path("data.txt")
+    txt = p.read_text(encoding="utf-8") if p.exists() else FALLBACK
+    return txt.replace("\r\n", "\n").replace("\r", "\n")
+
+
+def build_vocab(text: str) -> Tuple[List[str], Dict[str, int], Dict[int, str]]:
+    """Build sorted vocabulary and mapping dicts.
+
+    Args:
+        text: Corpus string.
+
+    Returns:
+        (chars, stoi, itos)
+
+    >>> chars, stoi, itos = build_vocab("ba\\n")
+    >>> chars == ['\\n', 'a', 'b']
+    True
+    >>> stoi['a'], itos[2]
+    (1, 'b')
+    """
+    chars = sorted(set(text))
+    stoi = {ch: i for i, ch in enumerate(chars)}
+    itos = {i: ch for ch, i in stoi.items()}
+    return chars, stoi, itos
+
+
+# ---------- counts -> probabilities ----------
+
+def bigram_counts(text: str, stoi: Dict[str, int]) -> List[List[int]]:
+    """Return bigram counts matrix of shape [V, V].
+
+    Args:
+        text: Corpus.
+        stoi: Char-to-index mapping.
+
+    Returns:
+        VxV list of ints, where M[i][j] = count of (i->j).
+
+    >>> M = bigram_counts("aba", {'a':0, 'b':1})
+    >>> M[0][1], M[1][0]  # a->b, b->a
+    (1, 1)
+    """
+    V = len(stoi)
+    M: List[List[int]] = [[0 for _ in range(V)] for _ in range(V)]
+    ids = [stoi[c] for c in text if c in stoi]
+    for a, b in zip(ids[:-1], ids[1:]):
+        M[a][b] += 1
+    return M
+
+
+def row_to_probs(row: List[int]) -> List[float]:
+    """Normalize a count row into probabilities (no smoothing).
+
+    Args:
+        row: List of non-negative counts.
+
+    Returns:
+        Probabilities that sum to 1.0 if the row has any mass; otherwise all zeros.
+
+    >>> row_to_probs([0, 2, 2])
+    [0.0, 0.5, 0.5]
+    >>> row_to_probs([0, 0, 0])
+    [0.0, 0.0, 0.0]
+    """
+    s = sum(row)
+    if s == 0:
+        return [0.0 for _ in row]
+    return [c / s for c in row]
+
+
+def counts_to_probs(M: List[List[int]]) -> List[List[float]]:
+    """Convert a counts matrix to a matrix of row-wise probabilities.
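+
+    Args:
+        M: VxV matrix of bigram counts.
+
+    Returns:
+        VxV matrix of row-wise probabilities; rows whose total count is
+        zero come back as all-zero rows.
+
+    >>> counts_to_probs([[0, 0], [3, 1]])
+    [[0.0, 0.0], [0.75, 0.25]]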
+
+    >>> P = counts_to_probs([[0, 1], [2, 0]])
+    >>> sum(P[0]), sum(P[1])
+    (1.0, 1.0)
+    """
+    return [row_to_probs(r) for r in M]
+
+
+# ---------- sampling ----------
+
+def _categorical_sample(probs: List[float], rng: random.Random) -> int:
+    """Sample an index from a probability vector.
+
+    Args:
+        probs: Probabilities (should sum to ~1.0).
+        rng: Random number generator.
+
+    Returns:
+        Chosen index.
+
+    >>> rng = random.Random(0)
+    >>> _categorical_sample([0.0, 1.0, 0.0], rng)
+    1
+    """
+    r = rng.random()
+    acc = 0.0
+    for i, p in enumerate(probs):
+        acc += p
+        if r <= acc:
+            return i
+    # Numerical edge case: if rounding kept us from returning, pick the max-prob index.
+    return max(range(len(probs)), key=lambda i: probs[i])
+
+
+def sample_text(P: List[List[float]], itos: Dict[int, str], length: int = 400, seed: int = 123, start_char: str | None = None) -> str:
+    """Generate text by walking the bigram model.
+
+    Args:
+        P: Row-stochastic matrix [V, V] of P(next|current).
+        itos: Index-to-char map.
+        length: Number of characters to produce.
+        seed: RNG seed for reproducibility.
+        start_char: If provided, start from this character; else pick randomly.
+
+    Returns:
+        Generated string of length 'length'.
+
+    >>> chars, stoi, itos = build_vocab("aba")
+    >>> M = bigram_counts("aba", stoi)
+    >>> P = counts_to_probs(M)
+    >>> out = sample_text(P, itos, length=5, seed=0, start_char='a')
+    >>> len(out) == 5
+    True
+    """
+    rng = random.Random(seed)
+    V = len(itos)
+    # choose starting index
+    if start_char is not None:
+        start_idx = {v: k for k, v in itos.items()}[start_char]
+    else:
+        start_idx = rng.randrange(V)
+
+    cur = start_idx
+    out = [itos[cur]]
+    for _ in range(length - 1):
+        row = P[cur]
+        if not row or sum(row) == 0.0:
+            # dead-end row (no outgoing mass): jump to a random char
+            cur = rng.randrange(V)
+        else:
+            cur = _categorical_sample(row, rng)
+        out.append(itos[cur])
+    return "".join(out)
+
+
+# ---------- main ----------
+
+def main() -> None:
+    text = load_text()
+    chars, stoi, itos = build_vocab(text)
+    M = bigram_counts(text, stoi)
+    P = counts_to_probs(M)
+
+    print(f"Vocab size: {len(chars)} | corpus chars: {len(text)}")
+    print("\n=== Sample (no smoothing) ===")
+    print(sample_text(P, itos, length=500, seed=123))
+
+
+if __name__ == "__main__":
+    main()