import torch
from sklearn.cluster import KMeans

from model.tokenizer import Tokenizer

tokenizer = Tokenizer()


def cluster_vocab(n_clusters=10):
    # Split the vocab mapping into parallel tuples of token strings and ids.
    vocab_items = list(tokenizer.vocab.items())
    words, ids = zip(*vocab_items)

    # Instead of embeddings, make small random vectors manually.
    vectors = torch.randn(len(ids), 64)  # 64-d random vectors for clustering

    # Never request more clusters than there are points; n_init is set
    # explicitly for consistent behavior across sklearn versions.
    kmeans = KMeans(n_clusters=min(n_clusters, len(ids)), n_init=10)
    # .cpu() is a no-op here, but keeps this safe if vectors ever live on GPU.
    labels = kmeans.fit_predict(vectors.cpu().numpy())

    # Group token strings by their assigned cluster label.
    clusters = {}
    for idx, label in enumerate(labels):
        clusters.setdefault(label, []).append(words[idx])

    return clusters
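

# A minimal usage sketch (assumption: model.tokenizer.Tokenizer exposes a
# dict-like `vocab` mapping token strings to ids, as used above). With random
# vectors the grouping is arbitrary; this just demonstrates the output shape.
if __name__ == "__main__":
    clusters = cluster_vocab(n_clusters=5)
    for label, members in sorted(clusters.items()):
        print(f"cluster {label}: {members[:10]}")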