# 06_bigram_counts.py
"""
Build a bigram counts model (no smoothing) and sample text.

What it does:
- Loads UTF-8 text from train.txt (preferred) or data.txt; normalizes to "\\n".
- Builds a sorted vocabulary.
- Fills a VxV integer bigram counts matrix.
- Converts to per-row probabilities (no smoothing).
- Samples text via a simple categorical draw per step.

How to run:
    python 06_bigram_counts.py
    # optional doctests
    python -m doctest -v 06_bigram_counts.py

Notes:
- This version has **no smoothing**, so many entries will be zero. In the next
  lesson we'll add Laplace smoothing to remove zeros.
"""
from __future__ import annotations

import random
from pathlib import Path
from typing import Dict, List, Tuple

# Fallback corpus used when neither train.txt nor data.txt exists on disk.
FALLBACK = "ababa\n"


# ---------- IO & Vocab ----------
def load_text() -> str:
    """Load and normalize corpus text from disk.

    Prefers ``train.txt``; falls back to ``data.txt``; if neither file
    exists, returns the module-level FALLBACK corpus.

    Returns:
        UTF-8 text with only LF ('\\n') newlines.

    >>> isinstance(load_text(), str)
    True
    """
    # Pick the first candidate that exists on disk, if any.
    for candidate in ("train.txt", "data.txt"):
        path = Path(candidate)
        if path.exists():
            raw = path.read_text(encoding="utf-8")
            break
    else:
        raw = FALLBACK
    # Normalize CRLF and lone CR to LF so downstream counts see one newline char.
    return raw.replace("\r\n", "\n").replace("\r", "\n")
def build_vocab(text: str) -> Tuple[List[str], Dict[str, int], Dict[int, str]]:
    """Build a sorted character vocabulary plus both index mappings.

    Args:
        text: Corpus string.

    Returns:
        (chars, stoi, itos): sorted unique characters, char->index map,
        and index->char map.

    >>> chars, stoi, itos = build_vocab("ba\\n")
    >>> chars == ['\\n', 'a', 'b']
    True
    >>> stoi['a'], itos[2]
    (1, 'b')
    """
    chars = sorted(set(text))
    stoi: Dict[str, int] = {}
    itos: Dict[int, str] = {}
    # One pass fills both directions of the mapping.
    for index, ch in enumerate(chars):
        stoi[ch] = index
        itos[index] = ch
    return chars, stoi, itos


# ---------- counts -> probabilities ----------
def bigram_counts(text: str, stoi: Dict[str, int]) -> List[List[int]]:
    """Return the bigram counts matrix of shape [V, V].

    Args:
        text: Corpus.
        stoi: Char-to-index mapping.

    Returns:
        VxV list of ints, where M[i][j] = count of (i->j).

    >>> M = bigram_counts("aba", {'a':0, 'b':1})
    >>> M[0][1], M[1][0]  # a->b, b->a
    (1, 1)
    """
    size = len(stoi)
    counts: List[List[int]] = [[0] * size for _ in range(size)]
    prev = None
    for ch in text:
        cur = stoi.get(ch)
        if cur is None:
            # Out-of-vocab chars are dropped; pairs bridge across them,
            # matching a filter-then-pair formulation.
            continue
        if prev is not None:
            counts[prev][cur] += 1
        prev = cur
    return counts
def row_to_probs(row: List[int]) -> List[float]:
    """Normalize one count row into probabilities (no smoothing).

    Args:
        row: List of non-negative counts.

    Returns:
        Probabilities that sum to 1.0 if the row has any mass; otherwise all zeros.

    >>> row_to_probs([0, 2, 2])
    [0.0, 0.5, 0.5]
    >>> row_to_probs([0, 0, 0])
    [0.0, 0.0, 0.0]
    """
    total = sum(row)
    if total:
        return [count / total for count in row]
    # A zero-mass row stays all-zero (no smoothing in this lesson).
    return [0.0] * len(row)
def counts_to_probs(M: List[List[int]]) -> List[List[float]]:
    """Convert a counts matrix to a matrix of row-wise probabilities.

    Each row is normalized independently via row_to_probs.

    >>> P = counts_to_probs([[0, 1], [2, 0]])
    >>> sum(P[0]), sum(P[1])
    (1.0, 1.0)
    """
    return list(map(row_to_probs, M))


# ---------- sampling ----------
def _categorical_sample(probs: List[float], rng: random.Random) -> int:
|
|
"""Sample an index from a probability vector.
|
|
|
|
Args:
|
|
probs: Probabilities (should sum to ~1.0).
|
|
rng: Random number generator.
|
|
|
|
Returns:
|
|
Chosen index.
|
|
|
|
>>> rng = random.Random(0)
|
|
>>> _categorical_sample([0.0, 1.0, 0.0], rng)
|
|
1
|
|
"""
|
|
r = rng.random()
|
|
acc = 0.0
|
|
for i, p in enumerate(probs):
|
|
acc += p
|
|
if r <= acc:
|
|
return i
|
|
# Numerical edge case: if due to randoming we didn't return, pick last max-prob.
|
|
return max(range(len(probs)), key=lambda i: probs[i])
|
|
|
|
|
|
def sample_text(P: List[List[float]], itos: Dict[int, str], length: int = 400, seed: int = 123, start_char: str | None = None) -> str:
    """Generate text by walking the bigram model.

    Args:
        P: Row-stochastic matrix [V, V] of P(next|current).
        itos: Index-to-char map.
        length: Number of characters to produce.
        seed: RNG seed for reproducibility.
        start_char: If provided, start from this character; else pick randomly.

    Returns:
        Generated string of length 'length'.

    >>> chars, stoi, itos = build_vocab("aba")
    >>> M = bigram_counts("aba", stoi)
    >>> P = counts_to_probs(M)
    >>> out = sample_text(P, itos, length=5, seed=0, start_char='a')
    >>> len(out) == 5
    True
    """
    rng = random.Random(seed)
    vocab_size = len(itos)

    # Choose the starting index.
    if start_char is None:
        cur = rng.randrange(vocab_size)
    else:
        # Invert itos on the fly; raises KeyError for out-of-vocab start_char.
        cur = {ch: idx for idx, ch in itos.items()}[start_char]

    pieces = [itos[cur]]
    for _ in range(length - 1):
        row = P[cur]
        if row and sum(row) != 0.0:
            cur = _categorical_sample(row, rng)
        else:
            # Dead end (a row with no mass): restart from a random character.
            cur = rng.randrange(vocab_size)
        pieces.append(itos[cur])
    return "".join(pieces)


# ---------- main ----------
def main() -> None:
    """Entry point: load the corpus, fit bigram counts, print a sample."""
    corpus = load_text()
    chars, stoi, itos = build_vocab(corpus)
    counts = bigram_counts(corpus, stoi)
    probs = counts_to_probs(counts)

    print(f"Vocab size: {len(chars)} | corpus chars: {len(corpus)}")
    print("\n=== Sample (no smoothing) ===")
    print(sample_text(probs, itos, length=500, seed=123))


if __name__ == "__main__":
    main()