# NOTE: removed copy-paste artifact header ("26 lines / 645 B / Python") left by a file-viewer export.
import torch
from sklearn.cluster import KMeans

from model.tokenizer import Tokenizer
tokenizer = Tokenizer()
|
|
|
|
|
|
def cluster_vocab(n_clusters=10):
    """Group vocabulary words by KMeans-clustering their token ids.

    Parameters
    ----------
    n_clusters : int, default 10
        Desired number of clusters; capped at the vocabulary size.

    Returns
    -------
    list[list[str]]
        One list of words per cluster (some may be empty). Returns an
        empty list when the vocabulary is empty.

    NOTE(review): clustering on raw integer token ids assumes id ordering
    is meaningful -- confirm this is intended rather than clustering on
    embeddings.
    """
    vocab_items = list(tokenizer.vocab.items())

    if not vocab_items:
        # Empty vocabulary: nothing to cluster.
        return []

    words, ids = zip(*vocab_items)

    # KMeans expects a 2-D (n_samples, n_features) array; make ids a column.
    features = torch.tensor(ids, dtype=torch.float32).unsqueeze(1)

    # Never request more clusters than there are data points.
    k = min(n_clusters, len(words))
    kmeans = KMeans(n_clusters=k)
    labels = kmeans.fit_predict(features)

    # Size the output by k, not max(labels) + 1: if any trailing cluster
    # ends up empty, max(labels) + 1 would silently drop it and make the
    # result length depend on the particular fit.
    clusters = [[] for _ in range(k)]
    for word, label in zip(words, labels):
        clusters[label].append(word)

    return clusters