Ruby/brain_map.py
2025-05-04 17:32:25 -04:00

93 lines
2.5 KiB
Python

import numpy as np
from flask import Blueprint, render_template, jsonify, request
# Blueprint that groups the brain-map routes and points Flask at this
# package's own template and static directories.
bp = Blueprint(
    name='brain_map',
    import_name=__name__,
    static_folder='static',
    template_folder='templates',
)

# Placeholder for the live system object. It is expected to be assigned from
# body.py; while it is still None the /data route returns an empty graph.
system = None
@bp.route('/graph')
def graph():
    """Serve the page rendered from the ``graph.html`` template."""
    page = render_template('graph.html')
    return page
def _node_color(hue):
    """Return the color dict of a node, derived from a single HSL hue."""
    return {
        'background': f'hsl({hue},80%,40%)',
        'border': f'hsl({hue},60%,30%)',
        'highlight': {
            'background': f'hsl({hue},100%,50%)',
            'border': f'hsl({hue},80%,40%)',
        },
    }


def _build_graph(sim, items, min_degree, max_nodes, edges_per_node=3):
    """Turn a similarity matrix into node and edge lists.

    Args:
        sim: Square array; ``sim[i]`` holds the similarity of index ``i`` to
            every other index.
        items: ``(token, index)`` pairs whose indices are ints valid for
            ``sim``.
        min_degree: Nodes with fewer incident edges are dropped.
        max_nodes: Upper bound on the number of nodes returned; values below
            zero are treated as zero.
        edges_per_node: Number of *new* edges each index tries to add while
            walking its neighbours from most to least similar.

    Returns:
        A ``(nodes, edges)`` tuple of lists of plain dicts.
    """
    # NOTE(review): if several tokens share one index they produce several
    # nodes with the same id -- confirm that stoi is one-to-one.
    deg = {idx: 0 for _, idx in items}
    unique_pairs = set()

    for _, i in items:
        added = 0
        # Negating the row makes argsort yield the most similar index first.
        for raw_j in np.argsort(-sim[i]):
            if added >= edges_per_node:
                break
            j = int(raw_j)
            # Skip self-similarity and indices that have no token.
            if j == i or j not in deg:
                continue
            # Store pairs as (low, high) so an undirected edge is kept once.
            pair = (min(i, j), max(i, j))
            if pair in unique_pairs:
                continue
            unique_pairs.add(pair)
            deg[i] += 1
            deg[j] += 1
            added += 1

    # Keep sufficiently connected nodes, best connected first (ties broken
    # by index), and cap how many are returned.
    kept = [(tok, idx) for tok, idx in items if deg[idx] >= min_degree]
    kept.sort(key=lambda item: (-deg[item[1]], item[1]))
    kept = kept[:max(max_nodes, 0)]
    kept_ids = {idx for _, idx in kept}

    # ``or 1`` covers both an empty selection and a selection whose nodes all
    # have degree 0, either of which would otherwise divide by zero below.
    max_deg = max((deg[idx] for _, idx in kept), default=0) or 1

    # Hue is 0 for the highest-degree node and approaches 240 as the degree
    # falls towards zero.
    nodes = [
        {
            'id': idx,
            'label': tok,
            'color': _node_color(int((1 - deg[idx] / max_deg) * 240)),
        }
        for tok, idx in kept
    ]

    # Only edges whose two endpoints both survived the filtering are emitted.
    edges = [
        {'from': a, 'to': b}
        for a, b in unique_pairs
        if a in kept_ids and b in kept_ids
    ]
    return nodes, edges


@bp.route('/data')
def data():
    """Return the token-similarity graph as JSON.

    Query parameters:
        min_degree: Minimum number of edges a node needs to be included
            (default 1).
        max_nodes: Maximum number of nodes returned (default 200).

    Values that cannot be parsed as integers fall back to the defaults.

    Returns:
        A JSON object ``{"nodes": [...], "edges": [...]}``. Both lists are
        empty while ``system`` has not been set.
    """
    if system is None:
        return jsonify({"nodes": [], "edges": []})

    # 1) embeddings -> cosine similarities (one row of ``emb`` per index)
    emb = system.brain.token_emb.weight.detach().cpu().numpy()
    vocab_size = emb.shape[0]
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    # The epsilon keeps all-zero rows from dividing by zero.
    emb_norm = emb / (norms + 1e-8)
    sim = emb_norm.dot(emb_norm.T)

    # 2) filters; ``type=int`` returns the default instead of raising when
    #    the query value is not a valid integer.
    min_degree = request.args.get('min_degree', 1, type=int)
    max_nodes = request.args.get('max_nodes', 200, type=int)

    # 3) tokens whose index addresses a row of the embedding matrix; the
    #    index is normalised to a plain int once, here.
    items = [
        (tok, int(idx))
        for tok, idx in system.sensory.stoi.items()
        if 0 <= idx < vocab_size
    ]

    # 4-7) edges, degree filtering, colored nodes
    nodes, edges = _build_graph(sim, items, min_degree, max_nodes)
    return jsonify({'nodes': nodes, 'edges': edges})