Ruby/model/abstraction.py

import torch
from sklearn.cluster import KMeans

from model.tokenizer import Tokenizer

tokenizer = Tokenizer()


def cluster_vocab(n_clusters=10):
    """Group vocabulary words into n_clusters buckets via k-means over embeddings."""
    vocab_items = list(tokenizer.vocab.items())
    words, ids = zip(*vocab_items)

    # Embedding dimension matches the model size (256).
    # NOTE: this table is freshly (randomly) initialized, not loaded from the
    # trained model, so clusters reflect random geometry unless trained
    # weights are loaded here.
    embeds = torch.nn.Embedding(len(ids), 256)

    with torch.no_grad():
        # Assumes token ids are contiguous in [0, len(ids)).
        vectors = embeds(torch.tensor(list(ids)))

    kmeans = KMeans(n_clusters=n_clusters)
    labels = kmeans.fit_predict(vectors.cpu().numpy())

    clusters = {}
    for idx, label in enumerate(labels):
        clusters.setdefault(int(label), []).append(words[idx])
    return clusters
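

# A minimal usage sketch (hypothetical; assumes model.tokenizer.Tokenizer is
# importable and its vocab maps words to contiguous integer ids):
if __name__ == "__main__":
    clusters = cluster_vocab(n_clusters=5)
    for label, cluster_words in sorted(clusters.items()):
        # Show a small sample of words from each cluster.
        print(f"cluster {label}: {cluster_words[:8]}")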