import os

class Tokenizer:
    """Whitespace tokenizer with a growing, file-backed vocabulary.

    Text is split on whitespace and each word is mapped to an integer id.
    Words not yet in the vocabulary receive a fresh id the first time they
    are seen, and the vocabulary is persisted to ``vocab_path`` as
    ``token<TAB>id`` lines so later instances reuse the same ids.

    ``<START>`` (id 0) and ``<END>`` (id 1) are pre-seeded in the
    vocabulary; :meth:`tokenize` never emits them itself.
    """

    def __init__(self, vocab_path: str = "tokenizer_vocab.txt") -> None:
        """Create a tokenizer and merge in any vocabulary stored at *vocab_path*."""
        self.vocab_path = vocab_path
        # token -> id, and the inverse id -> token used by detokenize().
        self.vocab = {"<START>": 0, "<END>": 1}
        self.inv_vocab = {0: "<START>", 1: "<END>"}
        self.load_vocab()

    def load_vocab(self) -> None:
        """Merge the vocabulary stored at ``self.vocab_path`` into memory.

        A missing file is not an error; the in-memory vocabulary is left
        unchanged. Blank lines are skipped. An entry in the file overrides
        the in-memory id of the same token.

        Raises:
            ValueError: if a non-blank line is not ``token<TAB>integer``.
        """
        try:
            f = open(self.vocab_path, "r", encoding="utf-8")
        except FileNotFoundError:
            return
        with f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                # Split on the last tab so the id is always the final field.
                token, sep, idx = line.rpartition("\t")
                if not sep:
                    raise ValueError(
                        f"{self.vocab_path}:{lineno}: expected 'token<TAB>id', "
                        f"got {line!r}"
                    )
                try:
                    self.vocab[token] = int(idx)
                except ValueError as err:
                    raise ValueError(
                        f"{self.vocab_path}:{lineno}: invalid id {idx!r}"
                    ) from err
        # Rebuild the inverse map from scratch so it agrees with self.vocab.
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

    def save_vocab(self) -> None:
        """Write the vocabulary to ``self.vocab_path`` as ``token<TAB>id`` lines.

        The data is first written to a sibling ``.tmp`` file and then moved
        over the target with :func:`os.replace`, so an interrupted write does
        not truncate the existing vocabulary file.
        """
        tmp_path = f"{self.vocab_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.writelines(f"{token}\t{idx}\n" for token, idx in self.vocab.items())
        os.replace(tmp_path, self.vocab_path)

    def tokenize(self, text: str) -> list:
        """Return the ids of the whitespace-separated words in *text*.

        Unseen words are added to the vocabulary with fresh ids. The
        vocabulary file is rewritten only when at least one word was added.
        """
        tokens = []
        # Next free id is one past the largest id in use. Using len(vocab)
        # would collide with existing ids when loaded ids are not contiguous.
        next_id = max(self.vocab.values(), default=-1) + 1
        added = False
        for word in text.split():
            idx = self.vocab.get(word)
            if idx is None:
                idx = next_id
                next_id += 1
                self.vocab[word] = idx
                self.inv_vocab[idx] = word
                added = True
            tokens.append(idx)
        if added:
            self.save_vocab()
        return tokens

    def detokenize(self, tokens) -> str:
        """Join the words for the ids in *tokens* with single spaces.

        Ids with no entry in the vocabulary are rendered as ``<UNK>``.
        """
        return " ".join(self.inv_vocab.get(t, "<UNK>") for t in tokens)