import json
import os
import re
import unicodedata


class HybridTokenizer:
    """Hybrid word/character tokenizer with vocab persistence.

    Frequent words get word-level IDs; any word outside the word vocab
    falls back to per-character IDs, with unknown characters mapped to
    the '<unk>' ID.
    """

    def __init__(
        self,
        vocab_file,
        min_word_freq=5,
        max_vocab_size=10000
    ):
        self.vocab_file = vocab_file
        if os.path.exists(vocab_file):
            # Reuse a previously saved vocabulary.
            with open(vocab_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            self.word_to_id = data.get('word_to_id', {})
            self.char_to_id = data.get('char_to_id', {})
        else:
            # Fresh vocabulary; ID 0 is reserved for the unknown token.
            self.word_to_id = {'<unk>': 0}
            self.char_to_id = {}
        self.min_word_freq = min_word_freq
        self.max_vocab_size = max_vocab_size

    @staticmethod
    def _clean_text(text):
        # Normalize Unicode, replace line breaks/tabs with spaces, and
        # drop non-printable characters.
        text = unicodedata.normalize('NFKC', text)
        text = re.sub(r'[\r\n\t]+', ' ', text)
        text = ''.join(ch for ch in text if ch.isprintable())
        return text

    @staticmethod
    def _normalize_word(word):
        # Preserve Title-case words, lowercase everything else. Shared by
        # build_vocab() and encode() so both normalize identically.
        if word[0].isupper() and word[1:].islower():
            return word
        return word.lower()

    def build_vocab(self, texts):
        """Build word and character vocabs from a list of texts."""
        word_freq = {}
        char_set = set()

        for text in texts:
            text = self._clean_text(text)
            for word in text.split():
                norm = self._normalize_word(word)
                word_freq[norm] = word_freq.get(norm, 0) + 1
                char_set.update(norm)

        # Keep the most frequent words that clear the frequency threshold.
        words = [
            w for w, f in sorted(
                word_freq.items(),
                key=lambda x: x[1],
                reverse=True
            ) if f >= self.min_word_freq
        ]
        # Clamp so a pre-populated vocab never produces a negative slice.
        avail = max(0, self.max_vocab_size - len(self.word_to_id))
        for w in words[:avail]:
            if w not in self.word_to_id:
                self.word_to_id[w] = len(self.word_to_id)

        # Assign character IDs after all word IDs so the two ranges never
        # overlap; decode() relies on this.
        idx = len(self.word_to_id)
        for ch in sorted(char_set):
            if ch not in self.char_to_id:
                self.char_to_id[ch] = idx
                idx += 1

        # Persist the vocab; guard against a bare filename with no directory.
        vocab_dir = os.path.dirname(self.vocab_file)
        if vocab_dir:
            os.makedirs(vocab_dir, exist_ok=True)
        with open(self.vocab_file, 'w', encoding='utf-8') as f:
            json.dump({
                'word_to_id': self.word_to_id,
                'char_to_id': self.char_to_id
            }, f, ensure_ascii=False, indent=2)

    def encode(self, text):
        """Convert text into a list of token IDs."""
        text = self._clean_text(text)
        ids = []
        for word in text.split():
            norm = self._normalize_word(word)
            if norm in self.word_to_id:
                # In-vocabulary word: a single word-level ID.
                ids.append(self.word_to_id[norm])
            else:
                # Out-of-vocabulary word: one ID per character, with unseen
                # characters mapped to <unk>.
                for ch in norm:
                    ids.append(
                        self.char_to_id.get(ch, self.word_to_id['<unk>'])
                    )
        return ids

    def decode(self, ids):
        """Convert a list of token IDs back into text.

        Words that were encoded character by character come back with a
        space between each character, since no word-boundary marker is
        stored.
        """
        inv_word = {v: k for k, v in self.word_to_id.items()}
        inv_char = {v: k for k, v in self.char_to_id.items()}
        tokens = []
        for i in ids:
            if i in inv_word:
                tokens.append(inv_word[i])
            else:
                tokens.append(inv_char.get(i, '<unk>'))
        return ' '.join(tokens)
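

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. The vocab path
    # 'vocab.json' (written to the current directory), the toy corpus, and
    # min_word_freq=1 are illustrative values only.
    tok = HybridTokenizer('vocab.json', min_word_freq=1, max_vocab_size=100)
    tok.build_vocab([
        'the cat sat on the mat',
        'the dog sat on the log',
    ])
    ids = tok.encode('the dog sat on the cats')
    print(ids)
    # 'cats' is out of vocabulary, so it is encoded character by character
    # and decodes with spaces between its characters.
    print(tok.decode(ids))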