feat: add vocabulary encoding/decoding script with character-level tokenization
This commit introduces a new script that implements character-level vocabulary building and text encoding/decoding. The script loads text from train.txt, falls back to data.txt (and then to a small built-in sample), normalizes line endings, builds character-to-id mappings, and validates a round-trip encode/decode of a sample. It is designed for CPU-only operation, uses only Python standard library modules, and raises a KeyError when asked to encode a character that is not in the vocabulary.
This commit is contained in:
125
04_vocab_encode_decode.py
Normal file
125
04_vocab_encode_decode.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# 04_vocab_encode_decode.py
|
||||||
|
"""
|
||||||
|
Build a character vocabulary and encode/decode text.
|
||||||
|
|
||||||
|
What it does:
|
||||||
|
- Loads UTF-8 text from train.txt (preferred) or data.txt (fallback included).
|
||||||
|
- Normalizes newlines to "\\n".
|
||||||
|
- Builds char<->id mappings, then encodes and decodes a sample for a round-trip check.
|
||||||
|
|
||||||
|
How to run:
|
||||||
|
python 04_vocab_encode_decode.py
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- The vocabulary is built from the loaded text. Encoding a string with unseen characters
|
||||||
|
will raise a KeyError (expected at this stage).
|
||||||
|
- Everything is CPU-only and uses Python stdlib.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
# Built-in two-line sample used when neither train.txt nor data.txt exists.
# The escapes are real newline characters, not a literal backslash followed by "n".
FALLBACK = "hello\nworld\n"


def load_source() -> str:
    """Load source text from train.txt, then data.txt, else fallback; normalize newlines.

    The files are looked up relative to the current working directory and
    are read as UTF-8. If neither file exists, the built-in FALLBACK sample
    is used instead.

    Returns:
        Text with only '\\n' newlines.

    >>> isinstance(load_source(), str)
    True
    """
    for name in ("train.txt", "data.txt"):
        path = Path(name)
        if path.exists():
            text = path.read_text(encoding="utf-8")
            break
    else:
        # No input file found: use the embedded sample.
        text = FALLBACK

    # Replace actual CR LF / CR characters. Text-mode reading normally
    # translates line endings already, so this is a safety net; it must not
    # touch a literal backslash followed by "r" that is part of the data.
    return text.replace("\r\n", "\n").replace("\r", "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def build_vocab(text: str) -> Dict[str, int]:
    """Map every distinct character of ``text`` to a consecutive integer id.

    Ids are assigned in sorted character order, starting at 0.

    Args:
        text: Source text.

    Returns:
        stoi dict from each unique character to its integer id.

    >>> stoi = build_vocab("ab\\n")
    >>> sorted(stoi) == ['\\n','a','b']
    True
    """
    alphabet = sorted(set(text))
    # Pair each character with its position in the sorted alphabet.
    return dict(zip(alphabet, range(len(alphabet))))
|
||||||
|
|
||||||
|
|
||||||
|
def invert(stoi: Dict[str, int]) -> Dict[int, str]:
    """Build the reverse lookup (id -> char) of a char -> id mapping.

    Args:
        stoi: Mapping from characters to integer ids.

    Returns:
        itos dict mapping integer ids back to characters.

    >>> invert({'a': 0, 'b': 1})[1]
    'b'
    """
    itos: Dict[int, str] = {}
    for char, index in stoi.items():
        itos[index] = char
    return itos
|
||||||
|
|
||||||
|
|
||||||
|
def encode(text: str, stoi: Dict[str, int]) -> List[int]:
    """Translate ``text`` into integer ids, one id per character.

    Args:
        text: Input text.
        stoi: Character-to-index mapping.

    Returns:
        List of integer ids, in the same order as the input characters.

    Raises:
        KeyError: If a character is not in the vocabulary.

    >>> encode("ab", {'a':0, 'b':1})
    [0, 1]
    """
    ids: List[int] = []
    for char in text:
        # A plain subscript is used on purpose: unseen characters raise KeyError.
        ids.append(stoi[char])
    return ids
|
||||||
|
|
||||||
|
|
||||||
|
def decode(ids: List[int], itos: Dict[int, str]) -> str:
    """Turn a sequence of integer ids back into the string they represent.

    Args:
        ids: Sequence of token ids.
        itos: Index-to-character mapping.

    Returns:
        Decoded string.

    >>> decode([0, 1], {0:'a', 1:'b'})
    'ab'
    """
    pieces: List[str] = []
    for index in ids:
        pieces.append(itos[index])
    return "".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Run the demo: load text, build the vocabulary, and check a round trip.

    Prints the vocabulary size, the first 200 characters of the text (with
    newlines rendered as a visible ``\\n``), the first 40 encoded ids, and
    whether decoding the encoded sample reproduces it exactly.
    """
    text = load_source()
    stoi = build_vocab(text)
    itos = invert(stoi)

    # Take a small sample to test round-trip
    sample = text[:200]
    ids = encode(sample, stoi)
    roundtrip = decode(ids, itos)

    print(f"Vocab size: {len(stoi)}")
    print("Sample (first 200 chars, \\n shown literally):")
    # Swap each real newline character for the two visible characters "\n"
    # so the sample prints on one line, as the header above promises.
    print(sample.replace("\n", "\\n"))
    print("\nEncoded (first 40 ids):", ids[:40], "...")
    print("Roundtrip OK:", roundtrip == sample)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
Reference in New Issue
Block a user