diff --git a/04_vocab_encode_decode.py b/04_vocab_encode_decode.py new file mode 100644 index 0000000..8137812 --- /dev/null +++ b/04_vocab_encode_decode.py @@ -0,0 +1,125 @@ +# 04_vocab_encode_decode.py +""" +Build a character vocabulary and encode/decode text. + +What it does: +- Loads UTF-8 text from train.txt (preferred) or data.txt (fallback included). +- Normalizes newlines to "\\n". +- Builds char<->id mappings, then encodes and decodes a sample for a round-trip check. + +How to run: + python 04_vocab_encode_decode.py + +Notes: +- The vocabulary is built from the loaded text. Encoding a string with unseen characters + will raise a KeyError (expected at this stage). +- Everything is CPU-only and uses Python stdlib. +""" + +from __future__ import annotations +from pathlib import Path +from typing import Dict, List + +FALLBACK = "hello\\nworld\\n" + + +def load_source() -> str: + """Load source text from train.txt, then data.txt, else fallback; normalize newlines. + + Returns: + Text with only '\\n' newlines. + + >>> isinstance(load_source(), str) + True + """ + p = Path("train.txt") if Path("train.txt").exists() else Path("data.txt") + text = p.read_text(encoding="utf-8") if p.exists() else FALLBACK + return text.replace("\\r\\n", "\\n").replace("\\r", "\\n") + + +def build_vocab(text: str) -> Dict[str, int]: + """Build a sorted character vocabulary mapping char->index. + + Args: + text: Source text. + + Returns: + stoi dict mapping each unique character to an integer id. + + >>> stoi = build_vocab("ab\\n") + >>> sorted(stoi) == ['\\n','a','b'] + True + """ + chars = sorted(set(text)) + return {ch: i for i, ch in enumerate(chars)} + + +def invert(stoi: Dict[str, int]) -> Dict[int, str]: + """Invert a char->id mapping to id->char. + + Args: + stoi: Mapping from characters to integer ids. + + Returns: + itos dict mapping integer ids back to characters. + + >>> invert({'a': 0, 'b': 1})[1] + 'b' + """ + return {v: k for k, v in stoi.items()} + + +def encode(text: str, stoi: Dict[str, int]) -> List[int]: + """Encode a string into a list of integer ids using the provided vocabulary. + + Args: + text: Input text. + stoi: Character-to-index mapping. + + Returns: + List of integer ids. + + Raises: + KeyError: If a character is not in the vocabulary. + + >>> encode("ab", {'a':0, 'b':1}) + [0, 1] + """ + return [stoi[ch] for ch in text] + + +def decode(ids: List[int], itos: Dict[int, str]) -> str: + """Decode a list of integer ids back into a string. + + Args: + ids: Sequence of token ids. + itos: Index-to-character mapping. + + Returns: + Decoded string. + + >>> decode([0, 1], {0:'a', 1:'b'}) + 'ab' + """ + return "".join(itos[i] for i in ids) + + +def main() -> None: + text = load_source() + stoi = build_vocab(text) + itos = invert(stoi) + + # Take a small sample to test round-trip + sample = text[:200] + ids = encode(sample, stoi) + roundtrip = decode(ids, itos) + + print(f"Vocab size: {len(stoi)}") + print("Sample (first 200 chars, \\n shown literally):") + print(sample.replace("\\n", "\\\\n")) + print("\nEncoded (first 40 ids):", ids[:40], "...") + print("Roundtrip OK:", roundtrip == sample) + + +if __name__ == "__main__": + main()