import torch
from sklearn.cluster import KMeans

from model.tokenizer import Tokenizer

tokenizer = Tokenizer()


def cluster_vocab(n_clusters=10):
    # Pair each vocabulary word with its token id.
    vocab_items = list(tokenizer.vocab.items())
    words, ids = zip(*vocab_items)

    # Look up a vector for every token id. Sized to max(ids) + 1 so the
    # lookup stays in range even if the ids are not contiguous. Note that
    # this embedding is freshly (randomly) initialised; load the trained
    # model's embedding weights here for meaningful clusters.
    embeds = torch.nn.Embedding(max(ids) + 1, 256)  # 256 = model embedding dim
    with torch.no_grad():
        vectors = embeds(torch.tensor(list(ids)))

    # Cluster the embedding vectors with k-means.
    kmeans = KMeans(n_clusters=n_clusters)
    labels = kmeans.fit_predict(vectors.cpu().numpy())

    # Group words by their assigned cluster label.
    clusters = {}
    for idx, label in enumerate(labels):
        clusters.setdefault(label, []).append(words[idx])

    return clusters
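

# A minimal usage sketch (an assumption, not part of the original file):
# it presumes Tokenizer exposes a dict-like `vocab` mapping words to ids,
# as cluster_vocab above already relies on. Prints a preview of each cluster.
if __name__ == "__main__":
    clusters = cluster_vocab(n_clusters=10)
    for label, cluster_words in sorted(clusters.items()):
        print(f"cluster {label}: {cluster_words[:5]}")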