""" SentencePiece tokenizer trainer """ import sentencepiece as spm from pathlib import Path from typing import List, Optional import tempfile def train_tokenizer( input_files: List[str], model_prefix: str, vocab_size: int = 32000, model_type: str = "bpe", # or "unigram" character_coverage: float = 0.9995, num_threads: int = 4, user_defined_symbols: Optional[List[str]] = None, max_sentence_length: int = 16384, shuffle_input_sentence: bool = True, seed_sentencepiece_size: int = 1000000, **kwargs ) -> str: """ Train a SentencePiece tokenizer Args: input_files: List of text file paths for training model_prefix: Output model path prefix (will create .model and .vocab files) vocab_size: Target vocabulary size model_type: 'bpe' or 'unigram' character_coverage: Character coverage (0.9995 for multilingual, 1.0 for single language) num_threads: Number of threads for training user_defined_symbols: Optional list of user-defined symbols to add max_sentence_length: Maximum sentence length shuffle_input_sentence: Whether to shuffle input sentences seed_sentencepiece_size: Number of sentences to use for initial seed **kwargs: Additional arguments to pass to SentencePiece trainer Returns: Path to trained model file """ # Validate input files for f in input_files: if not Path(f).exists(): raise FileNotFoundError(f"Input file not found: {f}") # Prepare training arguments train_args = { 'input': ','.join(input_files), 'model_prefix': model_prefix, 'vocab_size': vocab_size, 'model_type': model_type, 'character_coverage': character_coverage, 'num_threads': num_threads, 'max_sentence_length': max_sentence_length, 'shuffle_input_sentence': shuffle_input_sentence, 'seed_sentencepiece_size': seed_sentencepiece_size, # Special tokens 'pad_id': 0, 'unk_id': 1, 'bos_id': 2, 'eos_id': 3, 'pad_piece': '', 'unk_piece': '', 'bos_piece': '', 'eos_piece': '', # User-defined symbols (e.g., for special control tokens) 'user_defined_symbols': user_defined_symbols or [], # Normalization 'normalization_rule_name': 'nmt_nfkc_cf', # Standard normalization 'remove_extra_whitespaces': True, 'split_by_unicode_script': True, 'split_by_whitespace': True, 'split_by_number': True, 'split_digits': True, 'byte_fallback': True, # Handle unknown bytes } # Add any additional kwargs train_args.update(kwargs) # Train the model print(f"Training {model_type.upper()} tokenizer with vocab size {vocab_size}...") print(f"Input files: {len(input_files)} file(s)") print(f"Output: {model_prefix}.model") spm.SentencePieceTrainer.Train(**{k: str(v) if isinstance(v, list) else v for k, v in train_args.items()}) model_path = f"{model_prefix}.model" # Verify the model was created if not Path(model_path).exists(): raise RuntimeError(f"Model training failed - {model_path} not created") # Print vocab info sp = spm.SentencePieceProcessor() sp.Load(model_path) print(f"✓ Tokenizer trained successfully!") print(f" Vocabulary size: {sp.vocab_size()}") print(f" BOS token: {sp.IdToPiece(sp.bos_id())} (ID: {sp.bos_id()})") print(f" EOS token: {sp.IdToPiece(sp.eos_id())} (ID: {sp.eos_id()})") print(f" PAD token: {sp.IdToPiece(sp.pad_id())} (ID: {sp.pad_id()})") print(f" UNK token: {sp.IdToPiece(sp.unk_id())} (ID: {sp.unk_id()})") return model_path def train_from_text( texts: List[str], model_prefix: str, vocab_size: int = 32000, model_type: str = "bpe", **kwargs ) -> str: """ Train tokenizer directly from list of texts (without needing files) Args: texts: List of text strings model_prefix: Output model path prefix vocab_size: Target vocabulary size model_type: 'bpe' or 
def train_from_text(
    texts: List[str],
    model_prefix: str,
    vocab_size: int = 32000,
    model_type: str = "bpe",
    **kwargs
) -> str:
    """
    Train a tokenizer directly from a list of texts (no input files needed).

    Args:
        texts: List of text strings.
        model_prefix: Output model path prefix.
        vocab_size: Target vocabulary size.
        model_type: 'bpe' or 'unigram'.
        **kwargs: Additional arguments forwarded to train_tokenizer().

    Returns:
        Path to the trained model file.
    """
    # Write texts to a temporary file, one sentence per line
    with tempfile.NamedTemporaryFile(
        mode='w', delete=False, suffix='.txt', encoding='utf-8'
    ) as f:
        for text in texts:
            f.write(text.strip() + '\n')
        temp_file = f.name

    try:
        # Train using the temporary file
        model_path = train_tokenizer(
            input_files=[temp_file],
            model_prefix=model_prefix,
            vocab_size=vocab_size,
            model_type=model_type,
            **kwargs
        )
    finally:
        # Clean up the temporary file
        Path(temp_file).unlink(missing_ok=True)

    return model_path
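
# Illustrative usage, not part of the original module: the corpus path, output prefix,
# and control token below are hypothetical placeholders. train_from_text() can be used
# the same way when the corpus is already held in memory as a list of strings.
if __name__ == "__main__":
    trained_model = train_tokenizer(
        input_files=["corpus.txt"],              # assumed to exist
        model_prefix="spm_bpe_32k",              # writes spm_bpe_32k.model / .vocab
        vocab_size=32000,
        model_type="bpe",
        user_defined_symbols=["<|endoftext|>"],  # example control token
    )
    inspect_tokenization(trained_model, "SentencePiece works on raw, untokenized text.")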