import torch
from sklearn.cluster import KMeans

from model.tokenizer import Tokenizer

tokenizer = Tokenizer()


def cluster_vocab(n_clusters=10):
    vocab_items = list(tokenizer.vocab.items())
    words, ids = zip(*vocab_items)

    # Fresh embedding table, 256-dim to match the model size.
    # Note: the weights are randomly initialized here (not loaded from a trained
    # checkpoint), and the lookup assumes token ids run contiguously from 0 to len(vocab) - 1.
    embeds = torch.nn.Embedding(len(ids), 256)
    with torch.no_grad():
        vectors = embeds(torch.tensor(list(ids)))

    # Cluster the embedding vectors and group the vocabulary words by cluster label.
    kmeans = KMeans(n_clusters=n_clusters)
    labels = kmeans.fit_predict(vectors.cpu().numpy())

    clusters = {}
    for idx, label in enumerate(labels):
        clusters.setdefault(label, []).append(words[idx])
    return clusters
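

# A minimal usage sketch (an illustrative addition, not part of the original module):
# print a handful of tokens from each cluster to eyeball the groupings.
# Assumes the tokenizer's vocab keys are printable tokens.
if __name__ == "__main__":
    clusters = cluster_vocab(n_clusters=10)
    for label, members in sorted(clusters.items()):
        print(f"Cluster {label}: {', '.join(str(w) for w in members[:8])}")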