import os
import json
class Sensory:
    """Whitespace tokenizer with an optionally growable vocabulary.

    Keeps two mirrored maps: ``stoi`` (token -> id) and ``itos``
    (id -> token), seeded with the special tokens ``<pad>`` (0) and
    ``<unk>`` (1).
    """

    def __init__(self):
        # Seed the forward map, then mirror it to build the reverse map.
        self.stoi = {"<pad>": 0, "<unk>": 1}
        self.itos = {i: w for w, i in self.stoi.items()}

    def encode(self, text: str, grow: bool = True) -> list[int]:
        """Map *text* to a list of token ids.

        When ``grow`` is True, unseen tokens are assigned fresh ids;
        otherwise they map to the ``<unk>`` id.
        """
        out: list[int] = []
        # str.split() with no arguments already discards leading and
        # trailing whitespace, so no explicit strip() is required.
        for word in text.split():
            idx = self.stoi.get(word)
            if idx is None:
                if grow:
                    # The next free id is simply the current vocab size.
                    idx = len(self.stoi)
                    self.stoi[word] = idx
                    self.itos[idx] = word
                else:
                    idx = self.stoi["<unk>"]
            out.append(idx)
        return out

    def decode(self, ids: list[int]) -> str:
        """Render ids back to a space-joined string; unknown ids become ``<unk>``."""
        words = [self.itos.get(i, "<unk>") for i in ids]
        return " ".join(words)

    def save_vocab(self, path: str = "vocab.json") -> None:
        """Dump stoi+itos to disk as UTF-8 JSON."""
        # JSON object keys must be strings, so stringify the itos keys.
        payload = {
            "stoi": self.stoi,
            "itos": {str(k): v for k, v in self.itos.items()},
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    def load_vocab(self, path: str = "vocab.json") -> None:
        """Load stoi+itos from *path* if the file exists."""
        if not os.path.isfile(path):
            # Best-effort load: a missing vocab file is not an error.
            return
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        self.stoi = data["stoi"]
        # Keys were stringified for JSON; restore them to ints.
        self.itos = {int(k): v for k, v in data["itos"].items()}