Files
Aria/06_bigram_counts.py

203 lines
5.3 KiB
Python

# 06_bigram_counts.py
"""
Build a bigram counts model (no smoothing) and sample text.
What it does:
- Loads UTF-8 text from train.txt (preferred) or data.txt; normalizes to "\\n".
- Builds a sorted vocabulary.
- Fills a VxV integer bigram counts matrix.
- Converts to per-row probabilities (no smoothing).
- Samples text via a simple categorical draw per step.
How to run:
python 06_bigram_counts.py
# optional doctests
python -m doctest -v 06_bigram_counts.py
Notes:
- This version has **no smoothing**, so many entries will be zero. In the next
lesson we'll add Laplace smoothing to remove zeros.
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Tuple
import random
FALLBACK = "ababa\n"  # tiny built-in corpus used when no data file is present
# ---------- IO & Vocab ----------
def load_text() -> str:
    """Read the training corpus, preferring train.txt, then data.txt.

    If neither file exists, the built-in ``FALLBACK`` string is used instead.
    Carriage returns ("\\r\\n" pairs first, then lone "\\r") are rewritten to
    "\\n" so downstream code only ever sees LF line endings.

    Returns:
        The corpus text (files are decoded as UTF-8) with LF-only newlines.

    >>> isinstance(load_text(), str)
    True
    """
    candidates = (Path("train.txt"), Path("data.txt"))
    source = next((path for path in candidates if path.exists()), None)
    raw = FALLBACK if source is None else source.read_text(encoding="utf-8")
    # Order matters: collapse CRLF pairs before converting stray CRs.
    for newline in ("\r\n", "\r"):
        raw = raw.replace(newline, "\n")
    return raw
def build_vocab(text: str) -> Tuple[List[str], Dict[str, int], Dict[int, str]]:
    """Collect the distinct characters of ``text`` and assign each an index.

    Characters are ordered with ``sorted`` (i.e. by code point), so index ``i``
    is the i-th smallest distinct character.

    Args:
        text: Corpus string.

    Returns:
        A tuple ``(chars, stoi, itos)``: the sorted character list, the
        char -> index map, and its inverse index -> char map.

    >>> chars, stoi, itos = build_vocab("ba\\n")
    >>> chars == ['\\n', 'a', 'b']
    True
    >>> stoi['a'], itos[2]
    (1, 'b')
    """
    itos = dict(enumerate(sorted(set(text))))
    stoi = {ch: idx for idx, ch in itos.items()}
    chars = list(itos.values())
    return chars, stoi, itos
# ---------- counts -> probabilities ----------
def bigram_counts(text: str, stoi: Dict[str, int]) -> List[List[int]]:
    """Return the bigram counts matrix of shape [V, V].

    Only pairs of characters that are adjacent in ``text`` AND both present in
    ``stoi`` are counted. A character missing from ``stoi`` breaks the chain
    rather than being dropped, so it never fabricates a bigram between its
    neighbours (e.g. "a?b" does not count a->b).

    Args:
        text: Corpus.
        stoi: Char-to-index mapping; indices must lie in ``range(len(stoi))``
            because the matrix has exactly ``len(stoi)`` rows and columns.

    Returns:
        VxV list of ints, where M[i][j] = count of (i->j).

    >>> M = bigram_counts("aba", {'a':0, 'b':1})
    >>> M[0][1], M[1][0]  # a->b, b->a
    (1, 1)
    >>> bigram_counts("a?b", {'a': 0, 'b': 1})  # '?' is unknown: no a->b
    [[0, 0], [0, 0]]
    """
    V = len(stoi)
    M: List[List[int]] = [[0] * V for _ in range(V)]
    for prev, nxt in zip(text, text[1:]):
        a = stoi.get(prev)
        b = stoi.get(nxt)
        if a is None or b is None:
            # Out-of-vocabulary character: this pair did not occur between
            # two known characters, so it contributes nothing.
            continue
        M[a][b] += 1
    return M
def row_to_probs(row: List[int]) -> List[float]:
    """Turn one row of counts into a probability distribution (no smoothing).

    Args:
        row: Non-negative counts for each possible next character.

    Returns:
        ``count / total`` for every entry, so the result sums to (approximately)
        1.0 when the row has any mass; an all-zero list of the same length when
        the total is zero.

    >>> row_to_probs([0, 2, 2])
    [0.0, 0.5, 0.5]
    >>> row_to_probs([0, 0, 0])
    [0.0, 0.0, 0.0]
    """
    total = sum(row)
    # A row with no counts has no mass to normalize; emit zeros instead of
    # dividing by zero.
    return [count / total if total else 0.0 for count in row]
def counts_to_probs(M: List[List[int]]) -> List[List[float]]:
    """Normalize every row of a counts matrix into P(next | current).

    Each row is normalized independently with ``row_to_probs``; rows whose
    counts are all zero stay all-zero.

    >>> P = counts_to_probs([[0, 1], [2, 0]])
    >>> sum(P[0]), sum(P[1])
    (1.0, 1.0)
    """
    return list(map(row_to_probs, M))
# ---------- sampling ----------
def _categorical_sample(probs: List[float], rng: random.Random) -> int:
"""Sample an index from a probability vector.
Args:
probs: Probabilities (should sum to ~1.0).
rng: Random number generator.
Returns:
Chosen index.
>>> rng = random.Random(0)
>>> _categorical_sample([0.0, 1.0, 0.0], rng)
1
"""
r = rng.random()
acc = 0.0
for i, p in enumerate(probs):
acc += p
if r <= acc:
return i
# Numerical edge case: if due to randoming we didn't return, pick last max-prob.
return max(range(len(probs)), key=lambda i: probs[i])
def sample_text(
    P: List[List[float]],
    itos: Dict[int, str],
    length: int = 400,
    seed: int = 123,
    start_char: str | None = None,
) -> str:
    """Generate text by walking the bigram model.

    The first character is ``start_char`` if given, otherwise a uniformly random
    vocabulary entry. Each following character is drawn from the row of ``P``
    for the current character. A row that is empty or sums to zero (a
    character with no observed successor) is treated as a dead end, and the
    walk jumps to a uniformly random character instead.

    Args:
        P: Row-stochastic matrix [V, V] of P(next|current).
        itos: Index-to-char map; expected to cover indices 0..V-1.
        length: Number of characters to produce, including the start
            character. Values <= 0 yield an empty string.
        seed: RNG seed for reproducibility.
        start_char: If provided, start from this character; else pick randomly.

    Returns:
        Generated string of exactly ``max(length, 0)`` characters.

    Raises:
        ValueError: If characters are requested from an empty vocabulary.
        KeyError: If ``start_char`` is not in the vocabulary.

    >>> chars, stoi, itos = build_vocab("aba")
    >>> M = bigram_counts("aba", stoi)
    >>> P = counts_to_probs(M)
    >>> out = sample_text(P, itos, length=5, seed=0, start_char='a')
    >>> len(out) == 5
    True
    >>> sample_text(P, itos, length=0)
    ''
    """
    if length <= 0:
        return ""
    V = len(itos)
    if V == 0:
        raise ValueError("cannot sample from an empty vocabulary")
    rng = random.Random(seed)
    # choose starting index
    if start_char is not None:
        stoi = {ch: i for i, ch in itos.items()}
        if start_char not in stoi:
            raise KeyError(f"start_char {start_char!r} is not in the vocabulary")
        cur = stoi[start_char]
    else:
        cur = rng.randrange(V)
    out = [itos[cur]]
    for _ in range(length - 1):
        row = P[cur]
        if not row or sum(row) == 0.0:
            # dead end: jump to a random char
            cur = rng.randrange(V)
        else:
            cur = _categorical_sample(row, rng)
        out.append(itos[cur])
    return "".join(out)
# ---------- main ----------
def main() -> None:
    """Train the unsmoothed bigram model on the local corpus and print a sample.

    Prints the vocabulary size and corpus length, then a 500-character sample
    generated with seed 123.
    """
    text = load_text()
    chars, stoi, itos = build_vocab(text)
    counts = bigram_counts(text, stoi)
    model = counts_to_probs(counts)

    print(f"Vocab size: {len(chars)} | corpus chars: {len(text)}")
    print("\n=== Sample (no smoothing) ===")
    generated = sample_text(model, itos, length=500, seed=123)
    print(generated)


if __name__ == "__main__":
    main()