Ruby/utils/abstraction.py

import torch
from sklearn.cluster import KMeans

from ego.tokenizer import Tokenizer

tokenizer = Tokenizer()

SPECIAL_TOKENS = {"<pad>", "<unk>", "<start>", "<end>", "<sep>"}
def cluster_vocab(n_clusters=10):
    """Group vocabulary words into clusters, skipping special tokens."""
    vocab_items = [
        (word, idx)
        for word, idx in tokenizer.vocab.items()
        if word not in SPECIAL_TOKENS
    ]
    if len(vocab_items) < 2:
        return []  # Not enough real words to cluster

    words, _ids = zip(*vocab_items)

    # Placeholder one-hot "embeddings" (an identity matrix): every word is
    # equidistant from every other, so cluster assignments are essentially
    # arbitrary. Swap in real model vectors later for meaningful groupings.
    vectors = torch.eye(len(words), dtype=torch.float32)

    kmeans = KMeans(n_clusters=min(n_clusters, len(words)), n_init="auto")
    labels = kmeans.fit_predict(vectors.numpy())

    # Bucket words by their assigned cluster label.
    clusters = [[] for _ in range(labels.max() + 1)]
    for word, label in zip(words, labels):
        clusters[label].append(word)
    return clusters
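

# Minimal usage sketch, assuming the Tokenizer above exposes a populated
# `vocab` dict mapping word -> id (as the function relies on). With the
# one-hot placeholder vectors the groupings are arbitrary; this only
# demonstrates the plumbing end to end.
if __name__ == "__main__":
    for i, cluster in enumerate(cluster_vocab(n_clusters=5)):
        print(f"cluster {i}: {cluster[:10]}")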