feat: Add train/val split script and update gitignore
- Added 03_train_val_split.py to create deterministic train/validation splits from data.txt or fallback text - Updated .gitignore to un-comment .vscode/ directory exclusion - Changed data.txt pattern to *.txt for better file matching in gitignore - Script handles UTF-8 text loading with newline normalization and writes train.txt/val.txt files - Includes doctest examples and proper type hints
This commit is contained in:
79
03_train_val_split.py
Normal file
79
03_train_val_split.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# 03_train_val_split.py
|
||||
"""
|
||||
Create a train/val split from data.txt or a small fallback, and save to disk.
|
||||
|
||||
What it does:
|
||||
- Loads UTF-8 text, normalizes newlines to "\\n".
|
||||
- Splits by character index (default 90/10) deterministically.
|
||||
- Writes 'train.txt' and 'val.txt'
|
||||
|
||||
How to run:
|
||||
python 03_train_val_split.py
|
||||
# optional:
|
||||
python -m doctest -v 03_train_val_split.py
|
||||
|
||||
Notes:
|
||||
- Keep 'data.txt' next to this script. If it's missing, small fallback is used
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
FALLBACK = "abcdefg\nhijklmn\nopqrst\nuvwx\nyz\n"
|
||||
|
||||
|
||||
def load_and_normalize(path: Path) -> str:
|
||||
"""Load text from path if it exists; else use FALLBACK. Normalize to LF.
|
||||
|
||||
Args:
|
||||
path: Path to 'data.txt'.
|
||||
|
||||
Returns:
|
||||
Text with only '\\n' line endings.
|
||||
|
||||
>>> load_and_normalize.__doc__ is not None
|
||||
True
|
||||
"""
|
||||
text = path.read_text(encoding="utf-8") if path.exists() else FALLBACK
|
||||
return text.replace("\r\n", "\n").replace("\r", "\n")
|
||||
|
||||
|
||||
def split_indices(n: int, train_ratio: float = 0.9) -> Tuple[int, int]:
|
||||
"""Return (train_end, val_start) indices for a 1D split of length n.
|
||||
|
||||
Args:
|
||||
n: Total number of characters.
|
||||
train_ratio: Fraction for train portion (0.0 < train_ratio < 1.0)
|
||||
|
||||
Returns:
|
||||
(train_end, val_start) where val_start == train_end.
|
||||
|
||||
>>> split_indices(100, 0.8)
|
||||
(80, 80)
|
||||
>>> split_indices(5, 0.6)
|
||||
(3, 3)
|
||||
"""
|
||||
assert 0.0 < train_ratio < 1.0, "train_ratio must be between 0 and 1"
|
||||
train_end = int(n * train_ratio)
|
||||
return train_end, train_end
|
||||
|
||||
|
||||
def main() -> None:
|
||||
data_path = Path("data.txt")
|
||||
text = load_and_normalize(data_path)
|
||||
n = len(text)
|
||||
tr_end, va_start = split_indices(n, 0.9)
|
||||
train, val = text[:tr_end], text[va_start:]
|
||||
|
||||
Path("train.txt").write_text(train, encoding="utf-8")
|
||||
Path("val.txt").write_text(val, encoding="utf-8")
|
||||
|
||||
print(f"Total chars: {n}")
|
||||
print(f"Train chars: {len(train)}")
|
||||
print(f"Val chars: {len(val)}")
|
||||
print("Wrote train.txt and val.txt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user