Ruby/utils/tokenizer.py

import json
import os
import re
import unicodedata


class HybridTokenizer:
    """Hybrid word/character tokenizer with vocab persistence.

    Known words are encoded as single word IDs; out-of-vocabulary words
    fall back to per-character IDs, so any input can be encoded.
    """

    def __init__(
        self,
        vocab_file,
        min_word_freq=5,
        max_vocab_size=10000,
    ):
        self.vocab_file = vocab_file
        self.min_word_freq = min_word_freq
        self.max_vocab_size = max_vocab_size
        if os.path.exists(vocab_file):
            # Reuse a previously persisted vocab.
            with open(vocab_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            self.word_to_id = data.get('word_to_id', {})
            self.char_to_id = data.get('char_to_id', {})
        else:
            # Reserve ID 0 for the unknown-token fallback.
            self.word_to_id = {'<unk>': 0}
            self.char_to_id = {}

    @staticmethod
    def _clean_text(text):
        # Normalize Unicode, collapse whitespace runs, drop unprintables.
        text = unicodedata.normalize('NFKC', text)
        text = re.sub(r'[\r\n\t]+', ' ', text)
        text = ''.join(ch for ch in text if ch.isprintable())
        return text

    @staticmethod
    def _normalize_word(word):
        # Preserve Title-case words, lowercase everything else.
        if word[0].isupper() and word[1:].islower():
            return word
        return word.lower()

    def build_vocab(self, texts):
        """Build word and character vocabs from a list of texts."""
        word_freq = {}
        char_set = set()
        for text in texts:
            text = self._clean_text(text)
            for word in text.split():
                norm = self._normalize_word(word)
                word_freq[norm] = word_freq.get(norm, 0) + 1
                char_set.update(norm)
        # Keep the most frequent words above the frequency threshold.
        words = [
            w for w, f in sorted(
                word_freq.items(),
                key=lambda x: x[1],
                reverse=True,
            )
            if f >= self.min_word_freq
        ]
        # Clamp so a full vocab yields an empty slice rather than a
        # negative one, which would silently drop words from the end.
        avail = max(0, self.max_vocab_size - len(self.word_to_id))
        # Assign fresh IDs past every existing word and char ID, so
        # rebuilding on top of a loaded vocab never produces collisions.
        next_id = 1 + max(
            max(self.word_to_id.values(), default=-1),
            max(self.char_to_id.values(), default=-1),
        )
        for w in words[:avail]:
            if w not in self.word_to_id:
                self.word_to_id[w] = next_id
                next_id += 1
        # Character IDs come after all word IDs; decode relies on the
        # two ranges being disjoint.
        for ch in sorted(char_set):
            if ch not in self.char_to_id:
                self.char_to_id[ch] = next_id
                next_id += 1
        # Guard against a bare filename whose dirname is empty.
        vocab_dir = os.path.dirname(self.vocab_file)
        if vocab_dir:
            os.makedirs(vocab_dir, exist_ok=True)
        with open(self.vocab_file, 'w', encoding='utf-8') as f:
            json.dump({
                'word_to_id': self.word_to_id,
                'char_to_id': self.char_to_id,
            }, f, ensure_ascii=False, indent=2)

    def encode(self, text):
        """Convert text into a list of token IDs."""
        text = self._clean_text(text)
        ids = []
        for word in text.split():
            norm = self._normalize_word(word)
            if norm in self.word_to_id:
                ids.append(self.word_to_id[norm])
            else:
                # Out-of-vocabulary word: fall back to character IDs,
                # mapping unknown characters to <unk>.
                for ch in norm:
                    ids.append(
                        self.char_to_id.get(ch, self.word_to_id['<unk>'])
                    )
        return ids

    def decode(self, ids):
        """Convert a list of token IDs back into text.

        Decoding is lossy: character-fallback tokens are joined with
        spaces like ordinary words, so original spacing is not restored.
        """
        inv_word = {v: k for k, v in self.word_to_id.items()}
        inv_char = {v: k for k, v in self.char_to_id.items()}
        tokens = []
        for i in ids:
            # Word and char IDs occupy disjoint ranges; check words first.
            if i in inv_word:
                tokens.append(inv_word[i])
            else:
                tokens.append(inv_char.get(i, '<unk>'))
        return ' '.join(tokens)
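

# Minimal usage sketch. The vocab path and tiny corpus below are
# hypothetical, chosen only to illustrate the word/char fallback split.
if __name__ == '__main__':
    tok = HybridTokenizer('vocab/tokenizer.json', min_word_freq=1)
    tok.build_vocab([
        'the cat sat on the mat',
        'the dog sat on the log',
    ])
    ids = tok.encode('the cat sat on the zebra')
    # 'zebra' is out of vocab, so it encodes via character fallback;
    # characters unseen during build_vocab map to <unk>.
    print(ids)
    print(tok.decode(ids))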