diff --git a/.gitignore b/.gitignore index 82ec0a7..c8a733f 100644 --- a/.gitignore +++ b/.gitignore @@ -199,7 +199,7 @@ cython_debug/ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore # and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder -# .vscode/ +.vscode/ # Ruff stuff: .ruff_cache/ @@ -216,4 +216,4 @@ __marimo__/ .streamlit/secrets.toml # Data/Material that should not be synced -data.txt \ No newline at end of file +*.txt \ No newline at end of file diff --git a/03_train_val_split.py b/03_train_val_split.py new file mode 100644 index 0000000..ead5142 --- /dev/null +++ b/03_train_val_split.py @@ -0,0 +1,79 @@ +# 03_train_val_split.py +""" +Create a train/val split from data.txt or a small fallback, and save to disk. + +What it does: +- Loads UTF-8 text, normalizes newlines to "\\n". +- Splits by character index (default 90/10) deterministically. +- Writes 'train.txt' and 'val.txt' + +How to run: + python 03_train_val_split.py + # optional: + python -m doctest -v 03_train_val_split.py + +Notes: +- Keep 'data.txt' next to this script. If it's missing, small fallback is used +""" + +from __future__ import annotations +from pathlib import Path +from typing import Tuple + +FALLBACK = "abcdefg\nhijklmn\nopqrst\nuvwx\nyz\n" + + +def load_and_normalize(path: Path) -> str: + """Load text from path if it exists; else use FALLBACK. Normalize to LF. + + Args: + path: Path to 'data.txt'. + + Returns: + Text with only '\\n' line endings. + + >>> load_and_normalize.__doc__ is not None + True + """ + text = path.read_text(encoding="utf-8") if path.exists() else FALLBACK + return text.replace("\r\n", "\n").replace("\r", "\n") + + +def split_indices(n: int, train_ratio: float = 0.9) -> Tuple[int, int]: + """Return (train_end, val_start) indices for a 1D split of length n. + + Args: + n: Total number of characters. + train_ratio: Fraction for train portion (0.0 < train_ratio < 1.0) + + Returns: + (train_end, val_start) where val_start == train_end. + + >>> split_indices(100, 0.8) + (80, 80) + >>> split_indices(5, 0.6) + (3, 3) + """ + assert 0.0 < train_ratio < 1.0, "train_ratio must be between 0 and 1" + train_end = int(n * train_ratio) + return train_end, train_end + + +def main() -> None: + data_path = Path("data.txt") + text = load_and_normalize(data_path) + n = len(text) + tr_end, va_start = split_indices(n, 0.9) + train, val = text[:tr_end], text[va_start:] + + Path("train.txt").write_text(train, encoding="utf-8") + Path("val.txt").write_text(val, encoding="utf-8") + + print(f"Total chars: {n}") + print(f"Train chars: {len(train)}") + print(f"Val chars: {len(val)}") + print("Wrote train.txt and val.txt") + + +if __name__ == "__main__": + main()