# tokenizers/word_tokenizer.py

import pickle
import re
from collections import Counter


class WordTokenizer:
    """Word-level tokenizer with a fixed-size vocabulary learned from text."""

    def __init__(self, vocab_size=50000):
        self.vocab_size = vocab_size
        # IDs 0 and 1 are reserved for the padding and unknown-word tokens.
        self.word_to_id = {"<PAD>": 0, "<UNK>": 1}
        self.id_to_word = {0: "<PAD>", 1: "<UNK>"}
    def fit(self, texts):
        # Accept a single string or an iterable of strings; the original
        # called texts.lower() directly, which fails on a list.
        if isinstance(texts, str):
            texts = [texts]
        words = []
        for text in texts:
            words.extend(re.findall(r"\b\w+\b", text.lower()))
        # Reserve two ids for <PAD> and <UNK>; keep the most frequent words.
        freq = Counter(words).most_common(self.vocab_size - 2)
        for idx, (word, _) in enumerate(freq, start=2):
            self.word_to_id[word] = idx
            self.id_to_word[idx] = word
    def encode(self, text):
        # Unknown words fall back to the <UNK> id (1).
        return [self.word_to_id.get(word, 1)
                for word in re.findall(r"\b\w+\b", text.lower())]

    def decode(self, tokens):
        # Unknown ids are rendered as the <UNK> token.
        return " ".join(self.id_to_word.get(token, "<UNK>") for token in tokens)
    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump({
                "vocab_size": self.vocab_size,
                "word_to_id": self.word_to_id,
                "id_to_word": self.id_to_word,
            }, f)
    @classmethod
    def load(cls, path):
        with open(path, "rb") as f:
            data = pickle.load(f)
        obj = cls(data["vocab_size"])
        obj.word_to_id = data["word_to_id"]
        obj.id_to_word = data["id_to_word"]
        return obj
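

# A minimal usage sketch (not part of the original module): fits the tokenizer
# on a tiny toy corpus, round-trips a sentence, and checks save/load. The
# corpus and the path "tokenizer.pkl" are arbitrary examples.
if __name__ == "__main__":
    corpus = [
        "the quick brown fox jumps over the lazy dog",
        "the dog barks at the fox",
    ]
    tok = WordTokenizer(vocab_size=100)
    tok.fit(corpus)

    ids = tok.encode("the fox jumps")
    print(ids)              # exact ids depend on frequency/insertion order
    print(tok.decode(ids))  # "the fox jumps"

    tok.save("tokenizer.pkl")
    restored = WordTokenizer.load("tokenizer.pkl")
    assert restored.encode("the fox jumps") == ids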