feat: add temperature and top-k sampling (NumPy) and exempt requirements.txt from the *.txt gitignore pattern
3  .gitignore (vendored)
@@ -216,4 +216,5 @@ __marimo__/
 .streamlit/secrets.toml

 # Data/Material that should not be synced
 *.txt
+!requirements.txt
180  08_temp_topk_numpy.py (new file)
@@ -0,0 +1,180 @@
# 08_temp_topk_numpy.py
"""
Temperature and top-k sampling on a Laplace-smoothed bigram model (NumPy).

What it does:
- Builds a bigram probability matrix (with Laplace smoothing).
- Applies temperature scaling and optional top-k truncation per step.
- Generates sample text for a few temperature settings.

How to run:
    python 08_temp_topk_numpy.py
    # Optional flags:
    # --length 300 --seed 123 --temperature 0.8 --top_k 50

Notes:
- Temperature < 1.0 sharpens (more deterministic); > 1.0 flattens (more random).
- Top-k keeps only the k most likely candidates each step (others set to prob=0).
"""
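
# A quick worked illustration of temperature (comments only; numbers are
# illustrative, not produced by this script). Rescaling p -> p**(1/T) and
# renormalizing is the same as softmax(log(p) / T), which temp_topk_row()
# below computes:
#   p = [0.10, 0.20, 0.70]
#   T = 0.5 -> [0.019, 0.074, 0.907]  (sharper)
#   T = 2.0 -> [0.198, 0.279, 0.523]  (flatter)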

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, Tuple

import numpy as np

FALLBACK = "to be, or not to be\n"

# ---------- IO & Vocab ----------


def load_text() -> str:
    """Load and normalize text from train.txt or data.txt (fallback if missing).

    Returns:
        UTF-8 text with only LF ('\\n') newlines.

    >>> isinstance(load_text(), str)
    True
    """
    p = Path("train.txt") if Path("train.txt").exists() else Path("data.txt")
    t = p.read_text(encoding="utf-8") if p.exists() else FALLBACK
    return t.replace("\r\n", "\n").replace("\r", "\n")


def vocab_maps(text: str) -> Tuple[list[str], Dict[str, int], Dict[int, str]]:
    """Build sorted vocabulary and mapping dicts.

    >>> chars, stoi, itos = vocab_maps("ba\\n")
    >>> chars == ['\\n', 'a', 'b']
    True
    """
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}
    itos = {i: c for c, i in stoi.items()}
    return chars, stoi, itos


# ---------- counts -> probabilities ----------
def bigram_counts(text: str, stoi: Dict[str, int]) -> np.ndarray:
    """Return bigram counts matrix [V, V].

    >>> M = bigram_counts("aba", {'a': 0, 'b': 1})
    >>> int(M[0, 1]), int(M[1, 0])  # a->b, b->a
    (1, 1)
    """
    ids = [stoi[c] for c in text if c in stoi]
    V = len(stoi)
    M = np.zeros((V, V), dtype=np.int64)
    if len(ids) > 1:
        a = np.array(ids[:-1], dtype=np.int64)
        b = np.array(ids[1:], dtype=np.int64)
        for i, j in zip(a, b):
            M[i, j] += 1
    return M


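# Sketch: a vectorized variant of bigram_counts() for comparison; it is not
# called anywhere in this script. np.add.at does unbuffered in-place addition,
# so repeated (i, j) pairs accumulate correctly where fancy-indexed `+=` would not.
def bigram_counts_vectorized(text: str, stoi: Dict[str, int]) -> np.ndarray:
    """Vectorized bigram counting via np.add.at (illustrative only).

    >>> M = bigram_counts_vectorized("aba", {'a': 0, 'b': 1})
    >>> int(M[0, 1]), int(M[1, 0])
    (1, 1)
    """
    ids = np.array([stoi[c] for c in text if c in stoi], dtype=np.int64)
    M = np.zeros((len(stoi), len(stoi)), dtype=np.int64)
    if ids.size > 1:
        np.add.at(M, (ids[:-1], ids[1:]), 1)
    return M

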
def laplace_probs(M: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    """Add-alpha smoothing per row (returns probabilities).

    >>> import numpy as _np
    >>> P = laplace_probs(_np.array([[0, 2], [0, 0]]), alpha=1.0)
    >>> _np.allclose(P.sum(axis=1), 1.0)
    True
    """
    P = M.astype(np.float64) + alpha
    P /= P.sum(axis=1, keepdims=True)
    return P


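# Worked example of the smoothing arithmetic (illustrative, alpha=1):
# counts row [0, 2] -> [0+1, 2+1] = [1, 3] -> probabilities [0.25, 0.75],
# so unseen bigrams get a small nonzero probability instead of exactly zero.

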
# ---------- temperature + top-k ----------
def temp_topk_row(row: np.ndarray, temperature: float = 1.0, k: int | None = None) -> np.ndarray:
    """Apply temperature and top-k to a probability row; return new probs.

    Args:
        row: 1-D probability vector (sums to ~1).
        temperature: > 0; lower = sharper, higher = flatter.
        k: If set and k < V, keep only the top-k by logit; others set to 0.

    Returns:
        A re-normalized probability vector.

    >>> import numpy as _np
    >>> r = _np.array([0.1, 0.2, 0.7])
    >>> out = temp_topk_row(r, temperature=0.7, k=2)
    >>> bool(_np.isclose(out.sum(), 1.0))
    True
    >>> _np.count_nonzero(out) == 2
    True
    """
    assert temperature > 0.0, "temperature must be > 0"
    # Convert probs -> logits for temperature scaling & masking.
    logits = np.log(np.clip(row, 1e-12, 1.0))
    logits = logits / temperature
    if k is not None and 0 < k < logits.size:
        # Indices of the k largest logits (argpartition avoids a full sort).
        top_idx = np.argpartition(-logits, k)[:k]
        mask = np.full_like(logits, -np.inf)
        mask[top_idx] = logits[top_idx]
        logits = mask
    # Back to probs with a stable softmax: subtract the max before exp.
    logits = logits - np.nanmax(logits)
    p = np.exp(logits)
    p_sum = p.sum()
    if not np.isfinite(p_sum) or p_sum <= 0.0:
        # Fall back to uniform if things went sideways numerically.
        p = np.ones_like(p) / p.size
    else:
        p /= p_sum
    return p


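# Illustrative check, not part of the pipeline: as temperature drops, the
# distribution from temp_topk_row() concentrates on the argmax.
def _demo_sharpening() -> bool:
    """Sketch of the low-temperature limit (greedy-like behavior).

    >>> _demo_sharpening()
    True
    """
    row = np.array([0.1, 0.2, 0.7])
    cold = temp_topk_row(row, temperature=0.1)
    return bool(cold.argmax() == row.argmax() and cold.max() > 0.99)

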
# ---------- sampling ----------


def sample_text(P: np.ndarray, itos: Dict[int, str], length: int = 300, seed: int = 123,
                temperature: float = 1.0, top_k: int | None = None) -> str:
    """Sample text from a row-stochastic matrix with temperature/top-k.

    >>> import numpy as _np
    >>> itos = {0: 'a', 1: 'b', 2: 'c'}
    >>> P = _np.array([[0.0, 1.0, 0.0], [0.5, 0.0, 0.5], [1.0, 0.0, 0.0]])
    >>> s = sample_text(P, itos, length=5, seed=0, temperature=1.0, top_k=2)
    >>> len(s) == 5
    True
    """
    rng = np.random.default_rng(seed)
    V = P.shape[0]
    cur = int(rng.integers(low=0, high=V))
    out = [itos[cur]]
    for _ in range(length - 1):
        row = P[cur]
        p = temp_topk_row(row, temperature=temperature, k=top_k)
        cur = int(rng.choice(V, p=p))
        out.append(itos[cur])
    return "".join(out)


# ---------- main ----------


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--length", type=int, default=300)
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    args = parser.parse_args()

    text = load_text()
    chars, stoi, itos = vocab_maps(text)
    M = bigram_counts(text, stoi)
    P = laplace_probs(M, alpha=1.0)

    print(f"Vocab size: {len(chars)} | corpus chars: {len(text)}")
    for T in (max(0.3, args.temperature / 2), args.temperature, max(0.3, args.temperature * 1.5)):
        print(f"\n=== temperature={T:.2f} | top_k={args.top_k} ===")
        s = sample_text(P, itos, length=args.length, seed=args.seed,
                        temperature=T, top_k=min(args.top_k, len(chars)))
        print(s)


if __name__ == "__main__":
    main()
1  requirements.txt (new file)
@@ -0,0 +1 @@
numpy