feat: add vocabulary encoding/decoding script with character-level tokenization

This commit introduces a new script implementing character-level vocabulary building and text encoding/decoding. The script loads text from train.txt, falling back to data.txt or a built-in sample string if neither file exists, normalizes line endings to "\n", builds char<->id mappings, and runs a round-trip encode/decode check on a sample. It is CPU-only, uses only Python standard library modules, and documents that encoding characters absent from the vocabulary raises a KeyError at this stage.
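In miniature, the round trip the script performs looks like this (an illustrative sketch, not part of the committed file; the real functions are build_vocab, invert, encode, and decode below):

    text = "hi\n"
    stoi = {ch: i for i, ch in enumerate(sorted(set(text)))}  # {'\n': 0, 'h': 1, 'i': 2}
    itos = {i: ch for ch, i in stoi.items()}
    ids = [stoi[ch] for ch in text]                           # [1, 2, 0]
    assert "".join(itos[i] for i in ids) == text              # round trip holds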
2025-09-23 20:57:48 -04:00
parent abba60a798
commit feecf05ee3

04_vocab_encode_decode.py (new file, 125 lines)

@@ -0,0 +1,125 @@
# 04_vocab_encode_decode.py
"""
Build a character vocabulary and encode/decode text.

What it does:
- Loads UTF-8 text from train.txt (preferred) or data.txt; a built-in fallback string
  is used if neither file exists.
- Normalizes newlines to "\\n".
- Builds char<->id mappings, then encodes and decodes a sample for a round-trip check.

How to run:
    python 04_vocab_encode_decode.py

Notes:
- The vocabulary is built from the loaded text. Encoding a string with unseen characters
  will raise a KeyError (expected at this stage).
- Everything is CPU-only and uses only the Python stdlib.
"""
from __future__ import annotations

from pathlib import Path
from typing import Dict, List

FALLBACK = "hello\nworld\n"
def load_source() -> str:
    """Load source text from train.txt, then data.txt, else the fallback; normalize newlines.

    Returns:
        Text with only '\\n' newlines.

    >>> isinstance(load_source(), str)
    True
    """
    p = Path("train.txt") if Path("train.txt").exists() else Path("data.txt")
    text = p.read_text(encoding="utf-8") if p.exists() else FALLBACK
    return text.replace("\r\n", "\n").replace("\r", "\n")
def build_vocab(text: str) -> Dict[str, int]:
    """Build a sorted character vocabulary mapping char->index.

    Args:
        text: Source text.

    Returns:
        stoi dict mapping each unique character to an integer id.

    >>> stoi = build_vocab("ab\\n")
    >>> sorted(stoi) == ['\\n', 'a', 'b']
    True
    """
    chars = sorted(set(text))
    return {ch: i for i, ch in enumerate(chars)}
def invert(stoi: Dict[str, int]) -> Dict[int, str]:
    """Invert a char->id mapping to id->char.

    Args:
        stoi: Mapping from characters to integer ids.

    Returns:
        itos dict mapping integer ids back to characters.

    >>> invert({'a': 0, 'b': 1})[1]
    'b'
    """
    return {v: k for k, v in stoi.items()}
def encode(text: str, stoi: Dict[str, int]) -> List[int]:
    """Encode a string into a list of integer ids using the provided vocabulary.

    Args:
        text: Input text.
        stoi: Character-to-index mapping.

    Returns:
        List of integer ids.

    Raises:
        KeyError: If a character is not in the vocabulary.

    >>> encode("ab", {'a': 0, 'b': 1})
    [0, 1]
    """
    return [stoi[ch] for ch in text]
def decode(ids: List[int], itos: Dict[int, str]) -> str:
    """Decode a list of integer ids back into a string.

    Args:
        ids: Sequence of token ids.
        itos: Index-to-character mapping.

    Returns:
        Decoded string.

    >>> decode([0, 1], {0: 'a', 1: 'b'})
    'ab'
    """
    return "".join(itos[i] for i in ids)
def main() -> None:
    """Build the vocabulary, run a round-trip encode/decode check, and print a summary."""
    text = load_source()
    stoi = build_vocab(text)
    itos = invert(stoi)

    # Take a small sample to test the encode/decode round trip.
    sample = text[:200]
    ids = encode(sample, stoi)
    roundtrip = decode(ids, itos)

    print(f"Vocab size: {len(stoi)}")
    print("Sample (first 200 chars, \\n shown literally):")
    print(sample.replace("\n", "\\n"))
    print("\nEncoded (first 40 ids):", ids[:40], "...")
    print("Roundtrip OK:", roundtrip == sample)


if __name__ == "__main__":
    main()
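Note: the docstring examples double as doctests; one quick way to exercise them (optional, not part of this commit) is the standard doctest runner:

    python -m doctest 04_vocab_encode_decode.py -v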