import torch
from sklearn.cluster import KMeans

from model.tokenizer import Tokenizer

tokenizer = Tokenizer()

# Special tokens carry no semantic content, so they are excluded from clustering.
SPECIAL_TOKENS = {"<pad>", "<unk>", "<start>", "<end>", "<sep>"}


def cluster_vocab(n_clusters=10):
    # Keep only real vocabulary words (and their ids), dropping special tokens.
    vocab_items = [(word, idx) for word, idx in tokenizer.vocab.items() if word not in SPECIAL_TOKENS]

    if len(vocab_items) < 2:
        return []  # Not enough real words to cluster

    words, ids = zip(*vocab_items)

    # Placeholder one-hot embeddings (identity matrix); swap in real model vectors later.
    vectors = torch.eye(len(words), dtype=torch.float32)
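
    # Possible extension (sketch, not part of the original file): if the tokenizer's
    # ids index into a trained embedding table, the one-hot placeholders above could
    # be replaced with real vectors. `model.embedding` below is a hypothetical
    # attribute; point it at wherever your model keeps its nn.Embedding.
    #
    #   id_tensor = torch.tensor(ids, dtype=torch.long)
    #   vectors = model.embedding.weight[id_tensor].detach().cpu()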

    # KMeans cannot produce more clusters than samples, so cap n_clusters.
    kmeans = KMeans(n_clusters=min(n_clusters, len(words)), n_init="auto")
    labels = kmeans.fit_predict(vectors.numpy())

    # Group words by their assigned cluster label.
    clusters = [[] for _ in range(max(labels) + 1)]
    for word, label in zip(words, labels):
        clusters[label].append(word)

    return clusters
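

# Minimal usage sketch (assumption: Tokenizer() above loads a populated `vocab`
# dict mapping words to ids, and this file is run directly rather than imported).
if __name__ == "__main__":
    for i, cluster in enumerate(cluster_vocab(n_clusters=5)):
        print(f"cluster {i}: {cluster[:10]}")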