import logging

import torch

# Write every newly learned character to a dedicated log file.
logger = logging.getLogger("tokenizer")
logger.setLevel(logging.INFO)

fh = logging.FileHandler("learned_chars.log")
formatter = logging.Formatter('%(message)s')
fh.setFormatter(formatter)
logger.addHandler(fh)


class ChildTokenizer:
    """Character-level tokenizer that grows its vocabulary as it reads."""

    def __init__(self):
        self.char_to_id = {'<pad>': 0, '<unk>': 1}
        self.id_to_char = {0: '<pad>', 1: '<unk>'}
        self.next_id = 2

        # 🔤 Bootstrap with common characters
        for ch in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,!? ':;":
            self.char_to_id[ch] = self.next_id
            self.id_to_char[self.next_id] = ch
            self.next_id += 1

    def encode(self, text, return_tensors=False, freeze=False):
        """Map text to ids, learning unseen characters unless frozen."""
        ids = []
        for ch in text:
            if ch not in self.char_to_id:
                if freeze:
                    # Frozen vocabulary: unseen characters collapse to <unk>.
                    ids.append(self.char_to_id.get('<unk>', 1))
                    continue
                # Learn the new character and record it via the
                # learned_chars.log handler configured above.
                self.char_to_id[ch] = self.next_id
                self.id_to_char[self.next_id] = ch
                logger.info(f"learned {ch!r} -> {self.next_id}")
                self.next_id += 1
            ids.append(self.char_to_id[ch])
        return torch.tensor([ids], dtype=torch.long) if return_tensors else ids

    def decode(self, ids):
        """Map a sequence of ids back to text; unknown ids become <unk>."""
        return ''.join([self.id_to_char.get(i, '<unk>') for i in ids])

    def vocab_size(self):
        return self.next_id
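

# Usage sketch (illustrative, not part of the class): exercises vocabulary
# growth on unseen characters, the freeze behavior, and tensor output.
# The sample strings below are arbitrary examples.
if __name__ == "__main__":
    tok = ChildTokenizer()

    # 'é' and '9' are not in the bootstrap set, so they are learned here
    # (and logged to learned_chars.log).
    ids = tok.encode("Héllo, child 9!")
    assert tok.decode(ids) == "Héllo, child 9!"  # learned chars round-trip

    # With freeze=True the vocabulary stays fixed, so unseen characters
    # map to <unk> instead of being added.
    frozen = ChildTokenizer()
    frozen_ids = frozen.encode("Héllo", freeze=True)
    assert frozen.decode(frozen_ids) == "H<unk>llo"

    # return_tensors=True wraps the ids in a (1, seq_len) LongTensor,
    # ready to feed to a model.
    batch = tok.encode("hi", return_tensors=True)
    assert batch.shape == (1, 2)

    print("vocab size:", tok.vocab_size())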