# RubyOld/tokenizer.py
import torch
import logging

# Dedicated logger that appends every newly learned character to learned_chars.log
logger = logging.getLogger("tokenizer")
logger.setLevel(logging.INFO)
fh = logging.FileHandler("learned_chars.log")
formatter = logging.Formatter('%(message)s')
fh.setFormatter(formatter)
logger.addHandler(fh)
class ChildTokenizer:
    """Character-level tokenizer that grows its vocabulary as it reads."""

    def __init__(self):
        self.char_to_id = {'<pad>': 0, '<unk>': 1}
        self.id_to_char = {0: '<pad>', 1: '<unk>'}
        self.next_id = 2
        # 🔤 Bootstrap with common characters
        for ch in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,!? ':;":
            self.char_to_id[ch] = self.next_id
            self.id_to_char[self.next_id] = ch
            self.next_id += 1

    def encode(self, text, return_tensors=False, freeze=False):
        """Map text to ids; unseen characters are learned unless freeze=True."""
        ids = []
        for ch in text:
            if ch not in self.char_to_id:
                if freeze:
                    # Frozen vocabulary: map unseen characters to <unk>
                    ids.append(self.char_to_id['<unk>'])
                    continue
                # Learn the new character and record it in learned_chars.log
                self.char_to_id[ch] = self.next_id
                self.id_to_char[self.next_id] = ch
                logger.info(f"learned {ch!r} -> {self.next_id}")
                self.next_id += 1
            ids.append(self.char_to_id[ch])
        return torch.tensor([ids], dtype=torch.long) if return_tensors else ids

    def decode(self, ids):
        """Map ids back to text; int() guards against 0-dim tensor elements."""
        return ''.join(self.id_to_char.get(int(i), '<unk>') for i in ids)

    def vocab_size(self):
        return self.next_id
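

# Minimal usage sketch (added, not part of the original module), assuming only
# the class above: round-trips text through encode/decode, shows freeze=True
# mapping unseen characters to <unk>, and return_tensors=True yielding a
# (1, seq_len) LongTensor.
if __name__ == "__main__":
    tok = ChildTokenizer()
    ids = tok.encode("Hello, world!")                # all chars in bootstrap set
    print(tok.decode(ids))                           # -> "Hello, world!"
    frozen = tok.encode("héllo", freeze=True)        # 'é' unseen -> <unk> (id 1)
    print(tok.decode(frozen))                        # -> "h<unk>llo"
    batch = tok.encode("Hi!", return_tensors=True)   # tensor of shape (1, 3)
    print(batch.shape, tok.vocab_size())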