- Added 03_train_val_split.py to create deterministic train/validation splits from data.txt or fallback text - Updated .gitignore to un-comment .vscode/ directory exclusion - Changed data.txt pattern to *.txt for better file matching in gitignore - Script handles UTF-8 text loading with newline normalization and writes train.txt/val.txt files - Includes doctest examples and proper type hints
80 lines
2.1 KiB
Python
80 lines
2.1 KiB
Python
# 03_train_val_split.py
"""
Create a train/val split from data.txt (or a small fallback) and save it to disk.

What it does:
- Loads UTF-8 text and normalizes newlines to "\\n".
- Splits deterministically by character index (default 90/10).
- Writes 'train.txt' and 'val.txt'.

How to run:
    python 03_train_val_split.py
    # optional: run the doctests
    python -m doctest -v 03_train_val_split.py

Notes:
- Keep 'data.txt' next to this script and run the script from that
  directory. If 'data.txt' is missing, a small built-in fallback text is used.
"""

from __future__ import annotations

from pathlib import Path
from typing import Tuple


# Tiny built-in corpus used when no data file is available.
FALLBACK = "abcdefg\nhijklmn\nopqrst\nuvwx\nyz\n"


def load_and_normalize(path: Path) -> str:
    """Load text from *path* if it exists; else use FALLBACK. Normalize to LF.

    The file is decoded as UTF-8. A leading byte-order mark (as written by
    some Windows editors) is dropped so it does not end up in the corpus as a
    stray U+FEFF character.

    Args:
        path: Path to 'data.txt'.

    Returns:
        Text with only '\\n' line endings.

    >>> load_and_normalize(Path("this-file-does-not-exist.txt")) == FALLBACK
    True
    """
    try:
        # "utf-8-sig" decodes plain UTF-8 identically but strips a leading BOM.
        text = path.read_text(encoding="utf-8-sig")
    except (FileNotFoundError, NotADirectoryError):
        # Missing file (or a path through a non-directory): use the fallback.
        text = FALLBACK
    # read_text already applies universal-newline translation; the explicit
    # replace keeps the LF-only guarantee regardless of where the text came from.
    return text.replace("\r\n", "\n").replace("\r", "\n")


def split_indices(n: int, train_ratio: float = 0.9) -> Tuple[int, int]:
    """Return (train_end, val_start) indices for a 1D split of length n.

    The train portion is ``[0, train_end)`` and the validation portion is
    ``[val_start, n)``. ``train_end`` is ``int(n * train_ratio)`` (truncated),
    so the split is deterministic for given inputs.

    Args:
        n: Total number of characters (must be >= 0).
        train_ratio: Fraction for train portion (0.0 < train_ratio < 1.0).

    Returns:
        (train_end, val_start) where val_start == train_end.

    Raises:
        ValueError: If n is negative, or train_ratio is not strictly between
            0 and 1 (NaN included).

    >>> split_indices(100, 0.8)
    (80, 80)
    >>> split_indices(5, 0.6)
    (3, 3)
    >>> split_indices(0)
    (0, 0)
    """
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let bad input silently produce a bogus split.
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if not 0.0 < train_ratio < 1.0:
        raise ValueError("train_ratio must be between 0 and 1")
    train_end = int(n * train_ratio)
    return train_end, train_end


def main(
    data_path: Path | None = None,
    out_dir: Path | None = None,
    train_ratio: float = 0.9,
) -> None:
    """Split the corpus into 'train.txt' / 'val.txt' and print a summary.

    Args:
        data_path: Source text file. Defaults to 'data.txt' in the directory
            containing this script (not the current working directory), so
            the script behaves the same no matter where it is launched from.
        out_dir: Directory that receives 'train.txt' and 'val.txt'. Defaults
            to the directory containing ``data_path``.
        train_ratio: Fraction of characters assigned to the train split.
    """
    if data_path is None:
        data_path = Path(__file__).resolve().parent / "data.txt"
    if out_dir is None:
        out_dir = data_path.parent

    text = load_and_normalize(data_path)
    n = len(text)
    tr_end, va_start = split_indices(n, train_ratio)
    train, val = text[:tr_end], text[va_start:]

    # write_bytes (rather than write_text) bypasses text-mode newline
    # translation, so the files keep the LF-only line endings produced by
    # load_and_normalize even on Windows (where write_text would emit CRLF).
    (out_dir / "train.txt").write_bytes(train.encode("utf-8"))
    (out_dir / "val.txt").write_bytes(val.encode("utf-8"))

    print(f"Total chars: {n}")
    print(f"Train chars: {len(train)}")
    print(f"Val chars: {len(val)}")
    print("Wrote train.txt and val.txt")


if __name__ == "__main__":
    main()