From 56775d486ca7d7dbf18ce051727e8e7f30da0716 Mon Sep 17 00:00:00 2001 From: Oli Clive-Griffin Date: Fri, 5 Dec 2025 13:29:24 +0000 Subject: [PATCH 001/500] wip: Inline edge normalization and add backend user tracking --- spd/app/backend/lib/edge_normalization.py | 48 -------- spd/app/backend/routers/graphs.py | 25 +++- spd/app/backend/routers/runs.py | 9 ++ spd/app/backend/schemas.py | 1 + spd/app/frontend/src/App.svelte | 116 ++++++++++-------- .../local-attr/PromptCardTabs.svelte | 2 +- spd/app/frontend/src/lib/api.ts | 7 ++ spd/attributions/compute.py | 0 spd/data.py | 46 +++++++ 9 files changed, 153 insertions(+), 101 deletions(-) delete mode 100644 spd/app/backend/lib/edge_normalization.py delete mode 100644 spd/attributions/compute.py diff --git a/spd/app/backend/lib/edge_normalization.py b/spd/app/backend/lib/edge_normalization.py deleted file mode 100644 index a382bf5e8..000000000 --- a/spd/app/backend/lib/edge_normalization.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Edge normalization utilities for attribution graphs.""" - -import copy -from collections import defaultdict - -from spd.app.backend.compute import Edge - - -def normalize_edges_by_target(edges: list[Edge]) -> list[Edge]: - """Normalize edges so incoming edges to each target node sum to 1. - - Groups edges by target node (target:s_out:c_out_idx) and normalizes - the absolute values of incoming edges to sum to 1, preserving signs. - - Args: - edges: List of Edge dataclasses. - - Returns: - New list of edges with normalized values. - """ - if not edges: - return edges - - # Group edges by target node - edges_by_target: dict[str, list[tuple[int, Edge]]] = defaultdict(list) - for i, edge in enumerate(edges): - edges_by_target[str(edge.target)].append((i, edge)) - - # Normalize each group - normalized = copy.copy(edges) # Shallow copy of list - for _target_key, edge_group in edges_by_target.items(): - # Sum of absolute values - total_abs = sum(edge.strength**2 for _, edge in edge_group) - if total_abs == 0: - continue - - # Normalize: val -> val / total_abs (preserves sign) - for idx, edge in edge_group: - new_val = edge.strength / total_abs - # Create new Edge with updated strength - normalized[idx] = Edge( - source=edge.source, - target=edge.target, - strength=new_val, - is_cross_seq=edge.is_cross_seq, - ) - - return normalized diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 6557afe31..335a09592 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -1,9 +1,11 @@ """Graph computation endpoints for tokenization and attribution graphs.""" import json +import math import queue import threading from collections.abc import Generator +from itertools import groupby from typing import Annotated, Any import torch @@ -11,13 +13,13 @@ from fastapi.responses import JSONResponse, StreamingResponse from spd.app.backend.compute import ( + Edge, LocalAttributionResult, OptimizedLocalAttributionResult, compute_local_attributions, compute_local_attributions_optimized, ) from spd.app.backend.dependencies import DepLoadedRun -from spd.app.backend.lib.edge_normalization import normalize_edges_by_target from spd.app.backend.optim_cis.run_optim_cis import OptimCIConfig from spd.app.backend.schemas import ( EdgeData, @@ -117,7 +119,7 @@ def generate() -> Generator[str]: edges = edges[:GLOBAL_EDGE_LIMIT] if normalize: - edges = normalize_edges_by_target(edges) + edges = _normalize_edges_by_target(edges) edges_typed = [ EdgeData(src=str(e.source), tgt=str(e.target), val=e.strength) for e in edges @@ -152,6 +154,23 @@ def generate() -> Generator[str]: return StreamingResponse(generate(), media_type="text/event-stream") +def _normalize_edges_by_target(edges: list[Edge]) -> list[Edge]: + out_edges = [] + for _, incoming_edges in groupby(edges, key=lambda e: e.target): + incoming_edges = list(incoming_edges) # list() because we iterate over the group twice + incoming_strength = math.sqrt(sum(edge.strength**2 for edge in incoming_edges)) + for edge in incoming_edges: + out_edges.append( + Edge( + source=edge.source, + target=edge.target, + is_cross_seq=edge.is_cross_seq, + strength=edge.strength / incoming_strength, + ) + ) + return out_edges + + @router.post("/optimized/stream") @log_errors def compute_graph_optimized_stream( @@ -240,7 +259,7 @@ def generate() -> Generator[str]: edges = edges[:GLOBAL_EDGE_LIMIT] if normalize: - edges = normalize_edges_by_target(edges) + edges = _normalize_edges_by_target(edges) edges_typed = [ EdgeData(src=str(e.source), tgt=str(e.target), val=e.strength) for e in edges diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index f2ca720a2..60100f9e5 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -1,5 +1,6 @@ """Run management endpoints.""" +import getpass from urllib.parse import unquote import torch @@ -149,6 +150,7 @@ def get_status(manager: DepStateManager) -> LoadedRun | None: has_prompts=prompt_count > 0, prompt_count=prompt_count, context_length=context_length, + backend_user=getpass.getuser(), ) @@ -157,3 +159,10 @@ def get_status(manager: DepStateManager) -> LoadedRun | None: def health_check() -> dict[str, str]: """Health check endpoint.""" return {"status": "ok"} + + +@router.get("/whoami") +@log_errors +def whoami() -> dict[str, str]: + """Return the current backend user.""" + return {"user": getpass.getuser()} diff --git a/spd/app/backend/schemas.py b/spd/app/backend/schemas.py index 436c224b8..f919e97aa 100644 --- a/spd/app/backend/schemas.py +++ b/spd/app/backend/schemas.py @@ -185,6 +185,7 @@ class LoadedRun(BaseModel): has_prompts: bool prompt_count: int context_length: int + backend_user: str class SubcomponentMetadata(BaseModel): diff --git a/spd/app/frontend/src/App.svelte b/spd/app/frontend/src/App.svelte index bc2fffd26..3acb3cf71 100644 --- a/spd/app/frontend/src/App.svelte +++ b/spd/app/frontend/src/App.svelte @@ -16,6 +16,7 @@ let loadedRun = $state(null); let backendError = $state(null); + let backendUser = $state(null); async function loadStatus() { if (loadingTrainRun) return; @@ -64,7 +65,10 @@ } } - onMount(loadStatus); + onMount(() => { + loadStatus(); + api.getWhoami().then((user) => (backendUser = user)); + }); let activeTab = $state<"activation-contexts" | "local-attributions" | null>(null); let showConfig = $state(false); @@ -73,6 +77,7 @@
+ user: {backendUser ?? "..."}
+ {#if loadedRun} +
(showConfig = true)} + onmouseleave={() => (showConfig = false)} + > + + {#if showConfig} +
+
{loadedRun.config_yaml}
+
+ {/if} +
+ {/if} -
- {#if loadedRun} - - -
(showConfig = true)} - onmouseleave={() => (showConfig = false)} - > - - {#if showConfig} -
-
{loadedRun.config_yaml}
-
- {/if} -
- {/if}
+ {#if loadedRun} + + {/if} +
{#if backendError}
@@ -163,11 +170,12 @@ .top-bar { display: flex; align-items: center; + justify-content: flex-start; gap: var(--space-4); padding: var(--space-2) var(--space-3); background: var(--bg-surface); border-bottom: 1px solid var(--border-default); - flex-shrink: 0; + /* flex-shrink: 0; */ } .run-input { @@ -216,29 +224,24 @@ border-color: var(--accent-primary-dim); } - .run-input button { + .load-button { padding: var(--space-1) var(--space-3); - background: var(--accent-primary); - color: white; - border: none; + color: var(--text-primary); + border: 1px solid var(--border-default); font-weight: 500; white-space: nowrap; } - .run-input button:hover:not(:disabled) { - background: var(--accent-primary-dim); + .load-button:hover:not(:disabled) { + background: var(--accent-primary); + color: white; } - .run-input button:disabled { + .load-button:disabled { background: var(--border-default); color: var(--text-muted); } - .tab-navigation { - display: flex; - gap: var(--space-1); - } - .tab-button { padding: var(--space-1) var(--space-3); background: var(--bg-elevated); @@ -261,7 +264,6 @@ .config-wrapper { position: relative; - margin-left: auto; } .config-button { @@ -277,7 +279,7 @@ } .config-button:hover { - border-color: var(--border-strong); + background: var(--bg-inset); color: var(--text-primary); } @@ -306,6 +308,22 @@ word-wrap: break-word; } + .backend-user { + font-size: var(--text-sm); + font-family: var(--font-mono); + color: var(--text-muted); + white-space: nowrap; + } + + .tab-bar { + display: flex; + gap: var(--space-2); + padding: var(--space-3); + background: var(--bg-surface); + border-bottom: 1px solid var(--border-default); + flex-shrink: 0; + } + .main-content { flex: 1; min-width: 0; diff --git a/spd/app/frontend/src/components/local-attr/PromptCardTabs.svelte b/spd/app/frontend/src/components/local-attr/PromptCardTabs.svelte index 221642241..a3a8d785f 100644 --- a/spd/app/frontend/src/components/local-attr/PromptCardTabs.svelte +++ b/spd/app/frontend/src/components/local-attr/PromptCardTabs.svelte @@ -27,7 +27,7 @@
{/each} - +
diff --git a/spd/app/frontend/src/components/LocalAttributionsGraph.svelte b/spd/app/frontend/src/components/LocalAttributionsGraph.svelte index e68bace5a..2d55c7309 100644 --- a/spd/app/frontend/src/components/LocalAttributionsGraph.svelte +++ b/spd/app/frontend/src/components/LocalAttributionsGraph.svelte @@ -37,6 +37,7 @@ onPinnedNodesChange: (nodes: PinnedNode[]) => void; onLoadComponentDetail: (layer: string, cIdx: number) => void; onEdgeCountChange?: (count: number) => void; + onStageNode?: (layer: string, seqPos: number, componentIdx: number) => void; }; let { @@ -52,6 +53,7 @@ onPinnedNodesChange, onLoadComponentDetail, onEdgeCountChange, + onStageNode, }: Props = $props(); // UI state @@ -753,7 +755,21 @@ handleNodeMouseLeave(); }} > -

{hoveredNode.layer}:{hoveredNode.cIdx}

+
+

{hoveredNode.layer}:{hoveredNode.cIdx}

+ {#if onStageNode && hoveredNode.layer !== "output" && hoveredNode.layer !== "wte"} + + {/if} +
diff --git a/spd/app/frontend/src/components/LocalAttributionsTab.svelte b/spd/app/frontend/src/components/LocalAttributionsTab.svelte index 6b4341e95..9834b3c13 100644 --- a/spd/app/frontend/src/components/LocalAttributionsTab.svelte +++ b/spd/app/frontend/src/components/LocalAttributionsTab.svelte @@ -1,6 +1,7 @@ @@ -157,7 +163,7 @@ class:active={activeTab === "intervention"} onclick={() => (activeTab = "intervention")} > - Intervention {#if stagedNodes.length > 0}{stagedNodes.length}{/if} + Intervention {#if pinnedNodes.length > 0}{pinnedNodes.length}{/if} + Staged Nodes ({pinnedNodes.length}) +
+ + +
+
- {#each pinnedNodes as pinned (`${pinned.layer}:${pinned.cIdx}`)} + {#each pinnedNodes as pinned, idx (`${pinned.layer}:${pinned.seqIdx}:${pinned.cIdx}-${idx}`)} {@const detail = componentDetailsCache[`${pinned.layer}:${pinned.cIdx}`]} {@const isLoading = !detail && pinned.layer !== "output"} + {@const token = getTokenAtPosition(pinned.seqIdx)}
- {pinned.layer}:{pinned.cIdx} +
+ {pinned.layer}:{pinned.seqIdx}:{pinned.cIdx} + "{token}" +
@@ -69,18 +94,38 @@ margin-bottom: var(--space-2); } + .header-actions { + display: flex; + gap: var(--space-2); + } + .pinned-container-header button { background: var(--bg-elevated); border: 1px solid var(--border-default); color: var(--text-secondary); } - .pinned-container-header button:hover { + .pinned-container-header button:hover:not(:disabled) { background: var(--bg-inset); color: var(--text-primary); border-color: var(--border-strong); } + .run-btn { + background: var(--accent-primary) !important; + color: white !important; + border-color: var(--accent-primary) !important; + } + + .run-btn:hover:not(:disabled) { + background: var(--accent-primary-dim) !important; + } + + .run-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + } + .pinned-items { display: flex; flex-direction: row; @@ -105,6 +150,12 @@ border-bottom: 1px solid var(--border-subtle); } + .node-info { + display: flex; + flex-direction: column; + gap: var(--space-1); + } + .pinned-header strong { font-family: var(--font-mono); font-size: var(--text-base); @@ -112,6 +163,12 @@ font-weight: 600; } + .token-preview { + font-family: var(--font-mono); + font-size: var(--text-sm); + color: var(--text-muted); + } + .unpin-btn { background: var(--status-negative); color: white; diff --git a/spd/app/frontend/src/lib/localAttributionsTypes.ts b/spd/app/frontend/src/lib/localAttributionsTypes.ts index ddf791366..82f93d58e 100644 --- a/spd/app/frontend/src/lib/localAttributionsTypes.ts +++ b/spd/app/frontend/src/lib/localAttributionsTypes.ts @@ -90,6 +90,7 @@ export type NodePosition = { export type PinnedNode = { layer: string; + seqIdx: number; cIdx: number; }; From bca6883bc67dc475efe12d5b644badb4ee287002 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 8 Dec 2025 17:41:14 +0000 Subject: [PATCH 017/500] interventions is nice now, filtering cleaner --- spd/app/backend/routers/graphs.py | 80 +- spd/app/frontend/src/App.svelte | 80 +- .../src/components/InterventionTab.svelte | 486 ------------ .../components/LocalAttributionsGraph.svelte | 22 +- .../components/LocalAttributionsTab.svelte | 357 ++++++--- .../local-attr/InterventionsView.svelte | 698 ++++++++++++++++++ .../components/local-attr/PromptPicker.svelte | 18 +- ...tsPanel.svelte => StagedNodesPanel.svelte} | 87 ++- .../src/components/local-attr/types.ts | 13 +- 9 files changed, 1083 insertions(+), 758 deletions(-) delete mode 100644 spd/app/frontend/src/components/InterventionTab.svelte create mode 100644 spd/app/frontend/src/components/local-attr/InterventionsView.svelte rename spd/app/frontend/src/components/local-attr/{PinnedComponentsPanel.svelte => StagedNodesPanel.svelte} (63%) diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 9f1c21d10..582626391 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -168,19 +168,15 @@ def generate() -> Generator[str]: ), ) - # Apply normalization for response - edges = _normalize_edges(raw_edges, normalize) - if len(edges) > GLOBAL_EDGE_LIMIT: - edges.sort(key=lambda e: abs(e.strength), reverse=True) - edges = edges[:GLOBAL_EDGE_LIMIT] - - edges_typed = [_edge_to_edge_data(e) for e in edges] - node_importance, max_abs_attr = compute_edge_stats(edges_typed) + # Process edges for response + edges_data, node_importance, max_abs_attr = process_edges_for_response( + raw_edges, normalize, num_tokens=len(token_ids), is_optimized=False + ) response_data = GraphData( id=prompt_id, tokens=token_strings, - edges=edges_typed, + edges=edges_data, outputProbs=raw_output_probs, nodeImportance=node_importance, maxAbsAttr=max_abs_attr, @@ -360,20 +356,15 @@ def generate() -> Generator[str]: ), ) - # Apply normalization for response - edges = _remove_non_final_output_nodes(raw_edges, len(token_ids)) - edges = _normalize_edges(edges, normalize) - if len(edges) > GLOBAL_EDGE_LIMIT: - edges.sort(key=lambda e: abs(e.strength), reverse=True) - edges = edges[:GLOBAL_EDGE_LIMIT] - - edges_typed = [_edge_to_edge_data(e) for e in edges] - node_importance, max_abs_attr = compute_edge_stats(edges_typed) + # Process edges for response + edges_data, node_importance, max_abs_attr = process_edges_for_response( + raw_edges, normalize, num_tokens=len(token_ids), is_optimized=True + ) response_data = GraphDataWithOptimization( id=prompt_id, tokens=token_strings, - edges=edges_typed, + edges=edges_data, outputProbs=raw_output_probs, nodeImportance=node_importance, maxAbsAttr=max_abs_attr, @@ -397,9 +388,38 @@ def generate() -> Generator[str]: return StreamingResponse(generate(), media_type="text/event-stream") -def _remove_non_final_output_nodes(edges: list[Edge], num_tokens: int) -> list[Edge]: - """Remove edges that are not final output nodes.""" - return [edge for edge in edges if edge.target.seq_pos == num_tokens - 1] +def process_edges_for_response( + edges: list[Edge], + normalize: NormalizeType, + num_tokens: int, + is_optimized: bool, + edge_limit: int = GLOBAL_EDGE_LIMIT, +) -> tuple[list[EdgeData], dict[str, float], float]: + """Single source of truth for edge processing pipeline. + + Applies filtering, normalization, limiting, and computes stats. + Guarantees identical processing for compute and retrieval paths. + + Args: + edges: Raw edges from computation or database + normalize: Normalization type ("none", "target", "layer") + num_tokens: Number of tokens in the prompt (for filtering) + is_optimized: Whether this is an optimized graph (applies additional filtering) + edge_limit: Maximum number of edges to return + + Returns: + (edges_data, node_importance, max_abs_attr) + """ + if is_optimized: + final_seq_pos = num_tokens - 1 + edges = [edge for edge in edges if edge.target.seq_pos == final_seq_pos] + edges = _normalize_edges(edges, normalize) + if len(edges) > edge_limit: + print(f"[WARNING] Edge limit {edge_limit} exceeded ({len(edges)} edges), truncating") + edges = sorted(edges, key=lambda e: abs(e.strength), reverse=True)[:edge_limit] + edges_data = [_edge_to_edge_data(e) for e in edges] + node_importance, max_abs_attr = compute_edge_stats(edges_data) + return edges_data, node_importance, max_abs_attr @router.get("/{prompt_id}") @@ -423,18 +443,15 @@ def get_graphs( token_strings = [loaded.token_strings[t] for t in prompt.token_ids] stored_graphs = db.get_graphs(prompt_id) + num_tokens = len(prompt.token_ids) results: list[GraphData | GraphDataWithOptimization] = [] for graph in stored_graphs: - # Normalize and convert to API format - edges = _normalize_edges(graph.edges, normalize) - if len(edges) > GLOBAL_EDGE_LIMIT: - edges.sort(key=lambda e: abs(e.strength), reverse=True) - edges = edges[:GLOBAL_EDGE_LIMIT] - edges_data = [_edge_to_edge_data(e) for e in edges] - - node_importance, max_abs_attr = compute_edge_stats(edges_data) + is_optimized = graph.optimization_params is not None + edges_data, node_importance, max_abs_attr = process_edges_for_response( + graph.edges, normalize, num_tokens, is_optimized + ) - if graph.optimization_params is None: + if not is_optimized: # Standard graph results.append( GraphData( @@ -448,6 +465,7 @@ def get_graphs( ) else: # Optimized graph + assert graph.optimization_params is not None assert graph.optimization_stats is not None results.append( GraphDataWithOptimization( diff --git a/spd/app/frontend/src/App.svelte b/spd/app/frontend/src/App.svelte index 131ccd617..cefffef41 100644 --- a/spd/app/frontend/src/App.svelte +++ b/spd/app/frontend/src/App.svelte @@ -2,10 +2,8 @@ // import { RenderScan } from "svelte-render-scan"; import type { LoadedRun } from "./lib/api"; import * as api from "./lib/api"; - import type { PinnedNode } from "./lib/localAttributionsTypes"; import ActivationContextsTab from "./components/ActivationContextsTab.svelte"; - import InterventionTab from "./components/InterventionTab.svelte"; import LocalAttributionsTab from "./components/LocalAttributionsTab.svelte"; import { onMount } from "svelte"; @@ -69,39 +67,8 @@ api.getWhoami().then((user) => (backendUser = user)); }); - let activeTab = $state<"activation-contexts" | "local-attributions" | "intervention" | null>(null); + let activeTab = $state<"prompts" | "activation-contexts" | null>(null); let showConfig = $state(false); - - // Pinned/staged nodes (shared between LocalAttributionsTab and InterventionTab) - let pinnedNodes = $state([]); - let interventionText = $state(""); - - function handlePinnedNodesChange(nodes: PinnedNode[]) { - pinnedNodes = nodes; - } - - function clearPinnedNodes() { - pinnedNodes = []; - } - - function addPinnedNode(node: PinnedNode) { - // Avoid duplicates - const exists = pinnedNodes.some( - (n) => n.layer === node.layer && n.seqIdx === node.seqIdx && n.cIdx === node.cIdx - ); - if (!exists) { - pinnedNodes = [...pinnedNodes, node]; - } - } - - function removePinnedNode(index: number) { - pinnedNodes = pinnedNodes.filter((_, i) => i !== index); - } - - function goToIntervention(text: string) { - interventionText = text; - activeTab = "intervention"; - } @@ -153,17 +120,10 @@
- - {/if} - -
- e.key === "Enter" && parseAndAddNode()} - /> - - {#if nodeEntryError} - {nodeEntryError} - {/if} -
-
- -
- - -
- -
- -
- - - - {#if error} -
{error}
- {/if} - - - {#if result} -
-

Results

-
- - - - - - - - - - {#each result.input_tokens as inputToken, pos (pos)} - - - - - - {/each} - -
PosInput TokenTop Predictions (next token)
{pos} - {inputToken} - - {#each result.predictions_per_position[pos] as pred, i (pred.token_id)} - - {pred.token} - {formatProb(pred.prob)} - - {/each} -
-
-
- {/if} - - - diff --git a/spd/app/frontend/src/components/LocalAttributionsGraph.svelte b/spd/app/frontend/src/components/LocalAttributionsGraph.svelte index f96864ad1..c3e24c659 100644 --- a/spd/app/frontend/src/components/LocalAttributionsGraph.svelte +++ b/spd/app/frontend/src/components/LocalAttributionsGraph.svelte @@ -31,10 +31,10 @@ componentGap: number; layerGap: number; activationContextsSummary: ActivationContextsSummary | null; - pinnedNodes: PinnedNode[]; + stagedNodes: PinnedNode[]; componentDetailsCache: Record; componentDetailsLoading: Record; - onPinnedNodesChange: (nodes: PinnedNode[]) => void; + onStagedNodesChange: (nodes: PinnedNode[]) => void; onLoadComponentDetail: (layer: string, cIdx: number) => void; onEdgeCountChange?: (count: number) => void; }; @@ -46,10 +46,10 @@ componentGap, layerGap, activationContextsSummary, - pinnedNodes, + stagedNodes, componentDetailsCache, componentDetailsLoading, - onPinnedNodesChange, + onStagedNodesChange, onLoadComponentDetail, onEdgeCountChange, }: Props = $props(); @@ -344,10 +344,10 @@ // Check if an edge is connected to any pinned node (exact match including seqIdx) function isEdgeConnectedToPinnedNode(src: string, tgt: string): boolean { - if (pinnedNodes.length === 0) return false; + if (stagedNodes.length === 0) return false; const [srcLayer, srcSeqIdx, srcCIdx] = src.split(":"); const [tgtLayer, tgtSeqIdx, tgtCIdx] = tgt.split(":"); - for (const pinned of pinnedNodes) { + for (const pinned of stagedNodes) { if ( (srcLayer === pinned.layer && +srcSeqIdx === pinned.seqIdx && +srcCIdx === pinned.cIdx) || (tgtLayer === pinned.layer && +tgtSeqIdx === pinned.seqIdx && +tgtCIdx === pinned.cIdx) @@ -406,7 +406,7 @@ }); function isNodePinned(layer: string, seqIdx: number, cIdx: number): boolean { - return pinnedNodes.some((p) => p.layer === layer && p.seqIdx === seqIdx && p.cIdx === cIdx); + return stagedNodes.some((p) => p.layer === layer && p.seqIdx === seqIdx && p.cIdx === cIdx); } // Check if a node key should be highlighted @@ -415,7 +415,7 @@ function isKeyHighlighted(key: string): boolean { const [layer, seqIdx, cIdx] = key.split(":"); // Exact match for pinned nodes - if (pinnedNodes.some((p) => p.layer === layer && p.seqIdx === +seqIdx && p.cIdx === +cIdx)) { + if (stagedNodes.some((p) => p.layer === layer && p.seqIdx === +seqIdx && p.cIdx === +cIdx)) { return true; } // For hover: highlight all nodes with same component (across all positions) @@ -501,11 +501,11 @@ } function handleNodeClick(layer: string, seqIdx: number, cIdx: number) { - const idx = pinnedNodes.findIndex((p) => p.layer === layer && p.seqIdx === seqIdx && p.cIdx === cIdx); + const idx = stagedNodes.findIndex((p) => p.layer === layer && p.seqIdx === seqIdx && p.cIdx === cIdx); if (idx >= 0) { - onPinnedNodesChange(pinnedNodes.filter((_, i) => i !== idx)); + onStagedNodesChange(stagedNodes.filter((_, i) => i !== idx)); } else { - onPinnedNodesChange([...pinnedNodes, { layer, seqIdx, cIdx }]); + onStagedNodesChange([...stagedNodes, { layer, seqIdx, cIdx }]); } hoveredNode = null; } diff --git a/spd/app/frontend/src/components/LocalAttributionsTab.svelte b/spd/app/frontend/src/components/LocalAttributionsTab.svelte index 52c5e2346..4a239ca3d 100644 --- a/spd/app/frontend/src/components/LocalAttributionsTab.svelte +++ b/spd/app/frontend/src/components/LocalAttributionsTab.svelte @@ -8,23 +8,18 @@ PinnedNode, PromptPreview, } from "../lib/localAttributionsTypes"; - import PinnedComponentsPanel from "./local-attr/PinnedComponentsPanel.svelte"; + import type { InterventionResponse } from "../lib/interventionTypes"; + import StagedNodesPanel from "./local-attr/StagedNodesPanel.svelte"; import ComputeProgressOverlay from "./local-attr/ComputeProgressOverlay.svelte"; + import InterventionsView from "./local-attr/InterventionsView.svelte"; import PromptCardHeader from "./local-attr/PromptCardHeader.svelte"; import PromptCardTabs from "./local-attr/PromptCardTabs.svelte"; import PromptPicker from "./local-attr/PromptPicker.svelte"; - import type { StoredGraph, ComputeOptions, LoadingState, OptimizeConfig, PromptCard } from "./local-attr/types"; + import type { StoredGraph, ComputeOptions, LoadingState, OptimizeConfig, PromptCard, Intervention } from "./local-attr/types"; import ViewControls from "./local-attr/ViewControls.svelte"; import LocalAttributionsGraph from "./LocalAttributionsGraph.svelte"; - // Props - pinnedNodes are managed in App.svelte (shared with InterventionTab) - type Props = { - pinnedNodes: PinnedNode[]; - onPinnedNodesChange: (nodes: PinnedNode[]) => void; - onGoToIntervention: (text: string) => void; - }; - - let { pinnedNodes, onPinnedNodesChange, onGoToIntervention }: Props = $props(); + // No props - all state managed internally now // Server state let loadedRun = $state(null); @@ -39,7 +34,7 @@ // Prompt picker state let showPromptPicker = $state(false); - let filterByPinned = $state(false); + let filterByStaged = $state(false); let filteredPrompts = $state([]); let filterLoading = $state(false); let isAddingCustomPrompt = $state(false); @@ -49,6 +44,9 @@ let loadingState = $state(null); let computeError = $state(null); + // Intervention loading state + let runningIntervention = $state(false); + // Graph generation state let generatingGraphs = $state(false); let generateProgress = $state(0); @@ -69,7 +67,7 @@ // Compute options let computeOptions = $state({ maxMeanCI: 1.0, - normalizeEdges: "layer", // kept for compute, but view uses normalizeEdges state + normalizeEdges: "layer", useOptimized: false, optimizeConfig: { labelTokenText: "", @@ -82,7 +80,7 @@ }, }); - // Component details cache (shared across graphs, persists across graph switches) + // Component details cache (shared across graphs) let componentDetailsCache = $state>({}); let componentDetailsLoading = $state>({}); @@ -211,6 +209,7 @@ id: `graph-${idx}-${Date.now()}`, label, data, + stagedNodes: [], }; }); } catch (e) { @@ -225,6 +224,8 @@ isCustom, graphs, activeGraphId: graphs.length > 0 ? graphs[0].id : null, + interventions: [], + activeView: "graph", }; promptCards = [...promptCards, newCard]; activeCardId = cardId; @@ -284,6 +285,78 @@ computeOptions.optimizeConfig = { ...computeOptions.optimizeConfig, ...partial }; } + // Update staged nodes for the active graph + function handleStagedNodesChange(nodes: PinnedNode[]) { + if (!activeCard || !activeGraph) return; + promptCards = promptCards.map((card) => { + if (card.id !== activeCard.id) return card; + return { + ...card, + graphs: card.graphs.map((g) => + g.id === activeGraph.id ? { ...g, stagedNodes: nodes } : g + ), + }; + }); + + // Re-filter if needed + if (filterByStaged) { + filterPromptsByStaged(); + } + } + + // Switch between graph and interventions view + function handleViewChange(view: "graph" | "interventions") { + if (!activeCard) return; + promptCards = promptCards.map((card) => + card.id === activeCard.id ? { ...card, activeView: view } : card, + ); + } + + // Run intervention and add to prompt's list + async function handleRunIntervention() { + if (!activeCard || !activeGraph || activeGraph.stagedNodes.length === 0) return; + + runningIntervention = true; + try { + const nodes = activeGraph.stagedNodes.map((n) => ({ + layer: n.layer, + seq_pos: n.seqIdx, + component_idx: n.cIdx, + })); + const text = activeCard.tokens.join(""); + const result: InterventionResponse = await mainApi.runIntervention(text, nodes); + + const intervention: Intervention = { + id: `${Date.now()}-${Math.random().toString(36).slice(2, 6)}`, + timestamp: Date.now(), + nodes: [...activeGraph.stagedNodes], + result, + }; + + // Add intervention and switch to interventions view + promptCards = promptCards.map((card) => { + if (card.id !== activeCard.id) return card; + return { + ...card, + interventions: [...card.interventions, intervention], + activeView: "interventions", + }; + }); + } catch (e) { + console.error("Intervention failed:", e); + } finally { + runningIntervention = false; + } + } + + // Clear interventions for the active card + function handleClearInterventions() { + if (!activeCard) return; + promptCards = promptCards.map((card) => + card.id === activeCard.id ? { ...card, interventions: [] } : card, + ); + } + async function computeGraphForCard() { if (!activeCard || !activeCard.tokenIds || loadingCardId) return; @@ -293,7 +366,6 @@ const optConfig = computeOptions.optimizeConfig; const isOptimized = computeOptions.useOptimized; - // Set up stages - optimized has 2 stages, standard has 1 if (isOptimized) { loadingState = { stages: [ @@ -328,11 +400,9 @@ (progress) => { if (!loadingState) return; if (progress.stage === "graph") { - // Graph computation stage loadingState.currentStage = 1; loadingState.stages[1].progress = progress.current / progress.total; } else { - // Optimization stage loadingState.stages[0].progress = progress.current / progress.total; } }, @@ -357,7 +427,7 @@ if (card.id !== activeCard.id) return card; return { ...card, - graphs: [...card.graphs, { id: graphId, label, data }], + graphs: [...card.graphs, { id: graphId, label, data, stagedNodes: [] }], activeGraphId: graphId, }; }); @@ -369,15 +439,16 @@ } } - async function filterPromptsByPinned() { - if (pinnedNodes.length === 0) { + async function filterPromptsByStaged() { + const stagedNodes = activeGraph?.stagedNodes ?? []; + if (stagedNodes.length === 0) { filteredPrompts = []; return; } filterLoading = true; try { - const components = pinnedNodes.map((p) => `${p.layer}:${p.cIdx}`); + const components = stagedNodes.map((p) => `${p.layer}:${p.cIdx}`); const result = await attrApi.searchPrompts(components, "all"); filteredPrompts = result.results; } catch { @@ -388,25 +459,17 @@ } function handleFilterToggle() { - filterByPinned = !filterByPinned; - if (filterByPinned && pinnedNodes.length > 0) { - filterPromptsByPinned(); - } - } - - function handlePinnedNodesChangeWithFilter(nodes: PinnedNode[]) { - onPinnedNodesChange(nodes); - if (filterByPinned) { - filterPromptsByPinned(); + filterByStaged = !filterByStaged; + const stagedNodes = activeGraph?.stagedNodes ?? []; + if (filterByStaged && stagedNodes.length > 0) { + filterPromptsByStaged(); } } async function handleNormalizeChange(value: attrApi.NormalizeType) { normalizeEdges = value; - // Also sync to compute options so new computes use the same normalization computeOptions.normalizeEdges = value; - // Re-fetch all graphs for all cards with new normalization const updatedCards = await Promise.all( promptCards.map(async (card) => { if (card.graphs.length === 0) return card; @@ -422,6 +485,7 @@ id: `graph-${idx}-${Date.now()}`, label, data, + stagedNodes: [] as PinnedNode[], }; }); return { @@ -485,8 +549,8 @@ {#if activeCard} - - - {#if activeGraph?.data.optimization} -
- Target: "{activeGraph.data.optimization.label_str}" @ {( - activeGraph.data.optimization.label_prob * 100 - ).toFixed(1)}% - L0: - {activeGraph.data.optimization.l0_total.toFixed(0)} active -
- {/if} + +
+ + +
- {#if computeError} -
- {computeError} - -
- {/if} + {#if activeCard.activeView === "graph"} + + + {#if activeGraph?.data.optimization} +
+ Target: "{activeGraph.data.optimization.label_str}" @ {( + activeGraph.data.optimization.label_prob * 100 + ).toFixed(1)}% + L0: + {activeGraph.data.optimization.l0_total.toFixed(0)} active +
+ {/if} -
- {#if loadingCardId === activeCard.id && loadingState} - + {#if computeError} +
+ {computeError} + +
{/if} - {#if activeGraph} - (topK = v)} - onLayoutChange={(v) => (nodeLayout = v)} - onComponentGapChange={(v) => (componentGap = v)} - onLayerGapChange={(v) => (layerGap = v)} - onNormalizeChange={handleNormalizeChange} - /> - {#key activeGraph.id} - + {#if loadingCardId === activeCard.id && loadingState} + + {/if} + + {#if activeGraph} + (topK = v)} + onLayoutChange={(v) => (nodeLayout = v)} + onComponentGapChange={(v) => (componentGap = v)} + onLayerGapChange={(v) => (layerGap = v)} + onNormalizeChange={handleNormalizeChange} + /> + {#key activeGraph.id} + (filteredEdgeCount = count)} + /> + {/key} + (filteredEdgeCount = count)} + outputProbs={activeGraph.data.outputProbs} + tokens={activeGraph.data.tokens} + {runningIntervention} + onStagedNodesChange={handleStagedNodesChange} + onRunIntervention={handleRunIntervention} /> - {/key} - - {:else if !loadingCardId} -
-

Click Compute to generate the attribution graph

-
- {/if} -
+ {:else if !loadingCardId} +
+

Click Compute to generate the attribution graph

+
+ {/if} + + {:else} + + + {/if} {:else}

Click + Add Prompt to get started

@@ -639,6 +738,53 @@ background: var(--bg-inset); } + .view-tabs { + display: flex; + gap: var(--space-1); + margin-bottom: var(--space-2); + } + + .view-tab { + padding: var(--space-1) var(--space-3); + background: var(--bg-elevated); + border: 1px solid var(--border-default); + font-size: var(--text-sm); + font-weight: 500; + color: var(--text-secondary); + display: inline-flex; + align-items: center; + gap: var(--space-1); + } + + .view-tab:hover { + color: var(--text-primary); + border-color: var(--border-strong); + background: var(--bg-surface); + } + + .view-tab.active { + color: white; + background: var(--accent-primary); + border-color: var(--accent-primary); + } + + .view-tab .badge { + display: inline-flex; + align-items: center; + justify-content: center; + min-width: 16px; + height: 16px; + padding: 0 4px; + font-size: var(--text-xs); + font-weight: 600; + background: rgba(255, 255, 255, 0.2); + border-radius: 8px; + } + + .view-tab.active .badge { + background: rgba(255, 255, 255, 0.3); + } + .optim-results { display: flex; gap: var(--space-4); @@ -656,7 +802,6 @@ flex: 1; position: relative; min-height: 400px; - /* border: 1px solid var(--border-default); */ } .graph-area.loading { diff --git a/spd/app/frontend/src/components/local-attr/InterventionsView.svelte b/spd/app/frontend/src/components/local-attr/InterventionsView.svelte new file mode 100644 index 000000000..fabd2910d --- /dev/null +++ b/spd/app/frontend/src/components/local-attr/InterventionsView.svelte @@ -0,0 +1,698 @@ + + +
+ {#if interventions.length === 0} +
+

No interventions yet.

+

Stage nodes from the Graph view and click "Run Intervention"

+
+ {:else} +
+ {interventions.length} intervention{interventions.length === 1 ? "" : "s"} + +
+ +
+ {#each interventions as intervention, idx (intervention.id)} + {@const layout = computeLayout(intervention.nodes, intervention.result.input_tokens.length)} +
+
+ Intervention #{idx + 1} + {formatTime(intervention.timestamp)} +
+ + +
+ +
+ + + + Logits + + {#each Array(MAX_PREDICTIONS) as _, rank (rank)} + + #{rank + 1} + + {/each} + + + {#each Object.entries(layout.layerYPositions) as [layer, y] (layer)} + {@const yCenter = y + COMPONENT_SIZE / 2} + + {getRowLabel(layer)} + + {/each} + +
+ + +
+ + + + {#each intervention.result.input_tokens as token, pos (pos)} + {@const x = layout.seqXStarts[pos]} + {@const w = layout.seqWidths[pos]} + {@const cx = x + w / 2} + + + + [{pos}] + "{token}" + + + {#each Array(MAX_PREDICTIONS) as _, rank (rank)} + {@const pred = intervention.result.predictions_per_position[pos][rank]} + {@const cellY = LOGITS_HEADER_HEIGHT + rank * LOGITS_ROW_HEIGHT} + + {#if pred} + "{pred.token}" + {formatProb(pred.prob)} + {:else} + - + {/if} + {/each} + + + + + + {token} + [{pos}] + {/each} + + + + + + + {#each intervention.nodes as node (`${node.layer}:${node.seqIdx}:${node.cIdx}`)} + {@const pos = layout.nodePositions[`${node.layer}:${node.seqIdx}:${node.cIdx}`]} + {#if pos} + handleNodeMouseEnter(e, node.layer, node.seqIdx, node.cIdx)} + onmouseleave={handleNodeMouseLeave} + > + + + + {/if} + {/each} + + +
+
+
+ {/each} +
+ {/if} + + + {#if hoveredNode} + {@const summary = activationContextsSummary?.[hoveredNode.layer]?.find( + (s) => s.subcomponent_idx === hoveredNode?.cIdx + )} + {@const detail = componentDetailsCache[`${hoveredNode.layer}:${hoveredNode.cIdx}`]} + {@const isLoading = componentDetailsLoading[`${hoveredNode.layer}:${hoveredNode.cIdx}`] ?? false} + +
(isHoveringTooltip = true)} + onmouseleave={() => { + isHoveringTooltip = false; + handleNodeMouseLeave(); + }} + > +

{hoveredNode.layer}:{hoveredNode.seqIdx}:{hoveredNode.cIdx}

+ +
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/local-attr/PromptPicker.svelte b/spd/app/frontend/src/components/local-attr/PromptPicker.svelte index 9f0b01f6f..fcb4ceb6e 100644 --- a/spd/app/frontend/src/components/local-attr/PromptPicker.svelte +++ b/spd/app/frontend/src/components/local-attr/PromptPicker.svelte @@ -4,8 +4,8 @@ type Props = { prompts: PromptPreview[]; filteredPrompts: PromptPreview[]; - pinnedNodes: PinnedNode[]; - filterByPinned: boolean; + stagedNodes: PinnedNode[]; + filterByStaged: boolean; filterLoading: boolean; generatingGraphs: boolean; generateProgress: number; @@ -22,8 +22,8 @@ let { prompts, filteredPrompts, - pinnedNodes, - filterByPinned, + stagedNodes, + filterByStaged, filterLoading, generatingGraphs, generateProgress, @@ -39,7 +39,7 @@ let customText = $state(""); - const displayedPrompts = $derived(filterByPinned ? filteredPrompts : prompts); + const displayedPrompts = $derived(filterByStaged ? filteredPrompts : prompts); async function handleAddCustom() { if (!customText.trim() || isAddingCustomPrompt) return; @@ -75,11 +75,11 @@ {#if filterLoading} ... @@ -95,7 +95,7 @@ {/each} {#if displayedPrompts.length === 0}
- {filterByPinned ? "No matching prompts" : "No prompts yet"} + {filterByStaged ? "No matching prompts" : "No prompts yet"}
{/if}
diff --git a/spd/app/frontend/src/components/local-attr/PinnedComponentsPanel.svelte b/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte similarity index 63% rename from spd/app/frontend/src/components/local-attr/PinnedComponentsPanel.svelte rename to spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte index c85397322..3ebd01921 100644 --- a/spd/app/frontend/src/components/local-attr/PinnedComponentsPanel.svelte +++ b/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte @@ -3,23 +3,31 @@ import ComponentDetailCard from "./ComponentDetailCard.svelte"; type Props = { - pinnedNodes: PinnedNode[]; + stagedNodes: PinnedNode[]; componentDetailsCache: Record; outputProbs: Record; tokens: string[]; - onPinnedNodesChange: (nodes: PinnedNode[]) => void; - onGoToIntervention: (text: string) => void; + runningIntervention: boolean; + onStagedNodesChange: (nodes: PinnedNode[]) => void; + onRunIntervention: () => void; }; - let { pinnedNodes, componentDetailsCache, outputProbs, tokens, onPinnedNodesChange, onGoToIntervention }: Props = - $props(); + let { + stagedNodes, + componentDetailsCache, + outputProbs, + tokens, + runningIntervention, + onStagedNodesChange, + onRunIntervention, + }: Props = $props(); function clearAll() { - onPinnedNodesChange([]); + onStagedNodesChange([]); } - function unpinNode(pinned: PinnedNode) { - onPinnedNodesChange(pinnedNodes.filter((p) => p !== pinned)); + function unstageNode(node: PinnedNode) { + onStagedNodesChange(stagedNodes.filter((n) => n !== node)); } function getTokenAtPosition(seqIdx: number): string { @@ -28,42 +36,41 @@ } return "?"; } - - function handleGoToIntervention() { - const text = tokens.join(""); - onGoToIntervention(text); - } -{#if pinnedNodes.length > 0} -
-
- Staged Nodes ({pinnedNodes.length}) +{#if stagedNodes.length > 0} +
+
+ Staged Nodes ({stagedNodes.length})
-
-
- {#each pinnedNodes as pinned, idx (`${pinned.layer}:${pinned.seqIdx}:${pinned.cIdx}-${idx}`)} - {@const detail = componentDetailsCache[`${pinned.layer}:${pinned.cIdx}`]} - {@const isLoading = !detail && pinned.layer !== "output"} - {@const token = getTokenAtPosition(pinned.seqIdx)} -
-
+
+ {#each stagedNodes as node, idx (`${node.layer}:${node.seqIdx}:${node.cIdx}-${idx}`)} + {@const detail = componentDetailsCache[`${node.layer}:${node.cIdx}`]} + {@const isLoading = !detail && node.layer !== "output"} + {@const token = getTokenAtPosition(node.seqIdx)} +
+
- {pinned.layer}:{pinned.seqIdx}:{pinned.cIdx} + {node.layer}:{node.seqIdx}:{node.cIdx} "{token}"
- +
- .pinned-container { + .staged-container { margin-top: var(--space-4); background: var(--bg-surface); border: 1px solid var(--border-default); padding: var(--space-3); } - .pinned-container-header { + .staged-container-header { font-size: var(--text-sm); font-family: var(--font-sans); font-weight: 600; @@ -99,13 +106,13 @@ gap: var(--space-2); } - .pinned-container-header button { + .staged-container-header button { background: var(--bg-elevated); border: 1px solid var(--border-default); color: var(--text-secondary); } - .pinned-container-header button:hover:not(:disabled) { + .staged-container-header button:hover:not(:disabled) { background: var(--bg-inset); color: var(--text-primary); border-color: var(--border-strong); @@ -126,14 +133,14 @@ cursor: not-allowed; } - .pinned-items { + .staged-items { display: flex; flex-direction: row; gap: var(--space-3); overflow-x: auto; } - .pinned-item { + .staged-item { flex-shrink: 0; min-width: 300px; max-width: 400px; @@ -142,7 +149,7 @@ background: var(--bg-elevated); } - .pinned-header { + .staged-header { display: flex; justify-content: space-between; align-items: center; @@ -156,7 +163,7 @@ gap: var(--space-1); } - .pinned-header strong { + .staged-header strong { font-family: var(--font-mono); font-size: var(--text-base); color: var(--accent-primary); @@ -169,14 +176,14 @@ color: var(--text-muted); } - .unpin-btn { + .unstage-btn { background: var(--status-negative); color: white; border: none; padding: var(--space-1) var(--space-2); } - .unpin-btn:hover { + .unstage-btn:hover { background: var(--status-negative-bright); } diff --git a/spd/app/frontend/src/components/local-attr/types.ts b/spd/app/frontend/src/components/local-attr/types.ts index 06c04634b..a80a397e8 100644 --- a/spd/app/frontend/src/components/local-attr/types.ts +++ b/spd/app/frontend/src/components/local-attr/types.ts @@ -1,10 +1,19 @@ -import type { GraphData } from "../../lib/localAttributionsTypes"; +import type { GraphData, PinnedNode } from "../../lib/localAttributionsTypes"; +import type { InterventionResponse } from "../../lib/interventionTypes"; import type { NormalizeType } from "../../lib/localAttributionsApi"; export type StoredGraph = { id: string; label: string; data: GraphData; + stagedNodes: PinnedNode[]; +}; + +export type Intervention = { + id: string; + timestamp: number; + nodes: PinnedNode[]; // snapshot of nodes used + result: InterventionResponse; }; export type PromptCard = { @@ -15,6 +24,8 @@ export type PromptCard = { isCustom: boolean; graphs: StoredGraph[]; activeGraphId: string | null; + interventions: Intervention[]; + activeView: "graph" | "interventions"; }; export type OptimizeConfig = { From 844f3ce3496b31acff0ef80474d7e954159d64a9 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 8 Dec 2025 17:55:25 +0000 Subject: [PATCH 018/500] token uplift --- spd/app/backend/lib/activation_contexts.py | 145 ++++++++++++++---- .../backend/routers/activation_contexts.py | 1 + spd/app/backend/routers/graphs.py | 17 +- spd/app/backend/routers/runs.py | 11 ++ spd/app/backend/schemas.py | 8 +- spd/app/backend/state.py | 1 + .../local-attr/ComponentDetailCard.svelte | 33 +++- .../local-attr/StagedNodesPanel.svelte | 16 +- .../src/lib/localAttributionsTypes.ts | 6 +- tests/app/test_server_api.py | 1 + 10 files changed, 190 insertions(+), 49 deletions(-) diff --git a/spd/app/backend/lib/activation_contexts.py b/spd/app/backend/lib/activation_contexts.py index a2938a408..4d4039dde 100644 --- a/spd/app/backend/lib/activation_contexts.py +++ b/spd/app/backend/lib/activation_contexts.py @@ -34,6 +34,7 @@ def get_activations_data( tokenizer: PreTrainedTokenizerBase, train_loader: DataLoader[Int[Tensor, "B S"]], token_strings: dict[int, str], + token_base_rates: dict[int, float], importance_threshold: float, n_batches: int, n_tokens_either_side: int, @@ -64,9 +65,9 @@ def get_activations_data( component_activation_tokens = defaultdict[str, defaultdict[int, dict[int, int]]]( lambda: defaultdict(lambda: defaultdict(int)) ) - # - the number of times each token is predicted when component fires - component_predicted_tokens = defaultdict[str, defaultdict[int, dict[int, int]]]( - lambda: defaultdict(lambda: defaultdict(int)) + # - accumulated probability mass for each predicted token when component fires + component_predicted_probs = defaultdict[str, defaultdict[int, dict[int, float]]]( + lambda: defaultdict(lambda: defaultdict(float)) ) # - the sum of causal importances C = cm.C @@ -111,8 +112,8 @@ def get_activations_data( sampling=config.sampling, ).lower_leaky - # Get predicted tokens (argmax of logits at each position) - predicted_token_ids: Int[Tensor, "B S"] = logits.argmax(dim=-1) + # Get softmax probabilities for predicted token lift calculation + pred_probs: Float[Tensor, "B S V"] = torch.softmax(logits, dim=-1) for module_idx, (module_name, ci_val) in enumerate(ci_vals.items()): pbar.update(1) @@ -192,8 +193,8 @@ def get_activations_data( # Get token IDs at active position for token counting active_token_ids = window_token_ids_np[:, n_tokens_either_side] - # Get predicted tokens at each firing position - firing_predicted_tokens = predicted_token_ids[batch_idx, seq_idx].cpu().numpy() + # Get prediction probabilities at each firing position (full vocab) + firing_pred_probs: Float[Tensor, "n_firings V"] = pred_probs[batch_idx, seq_idx] # Process by component - group firings and use batch add unique_components = np.unique(comp_idx_np) @@ -211,13 +212,16 @@ def get_activations_data( for tok_id, count in zip(unique_tokens, token_counts, strict=True): component_activation_tokens[module_name][c_idx_int][int(tok_id)] += int(count) - # Update predicted token counts for this component - predicted_for_component = firing_predicted_tokens[mask_c] - unique_predicted, predicted_counts = np.unique( - predicted_for_component, return_counts=True - ) - for tok_id, count in zip(unique_predicted, predicted_counts, strict=True): - component_predicted_tokens[module_name][c_idx_int][int(tok_id)] += int(count) + # Accumulate predicted token probability mass for this component + probs_for_component: Float[Tensor, "n_c V"] = firing_pred_probs[ + torch.from_numpy(mask_c).to(firing_pred_probs.device) + ] + prob_sums: Float[Tensor, " V"] = probs_for_component.sum(dim=0) + prob_sums_cpu = prob_sums.cpu() + for tok_id in range(prob_sums_cpu.shape[0]): + prob = float(prob_sums_cpu[tok_id]) + if prob > 1e-6: # Skip negligible probabilities + component_predicted_probs[module_name][c_idx_int][tok_id] += prob # Apply position separation filter for example diversity only if separation_tokens > 0: @@ -252,7 +256,7 @@ def get_activations_data( model_ctxs: dict[str, list[SubcomponentActivationContexts]] = {} for module_name in component_activation_tokens: module_acts = component_activation_tokens[module_name] - module_predicted = component_predicted_tokens[module_name] + module_predicted_probs = component_predicted_probs[module_name] module_examples = examples[module_name] module_activation_counts = component_activation_counts[module_name] module_mean_cis = (component_sum_cis[module_name] / n_toks_seen).tolist() @@ -264,10 +268,13 @@ def get_activations_data( token_strings=token_strings, component_activation_count=module_activation_counts[component_idx], ) - predicted_tokens, predicted_probs = _get_component_predicted_tokens( - component_predicted_counts=module_predicted[component_idx], - token_strings=token_strings, - component_activation_count=module_activation_counts[component_idx], + predicted_tokens, predicted_lifts, predicted_firing_probs, predicted_base_probs = ( + _get_component_predicted_tokens( + component_prob_sums=module_predicted_probs[component_idx], + token_strings=token_strings, + token_base_rates=token_base_rates, + component_activation_count=module_activation_counts[component_idx], + ) ) example_tokens, example_ci, example_active_pos, example_active_ci = module_examples[ component_idx @@ -283,7 +290,9 @@ def get_activations_data( pr_recalls=pr_recalls, pr_precisions=pr_precisions, predicted_tokens=predicted_tokens, - predicted_probs=predicted_probs, + predicted_lifts=predicted_lifts, + predicted_firing_probs=predicted_firing_probs, + predicted_base_probs=predicted_base_probs, ) module_subcomponent_ctxs.append(subcomponent_ctx) module_subcomponent_ctxs.sort(key=lambda x: x.mean_ci, reverse=True) @@ -293,6 +302,58 @@ def get_activations_data( return ModelActivationContexts(layers=model_ctxs) +def compute_token_base_rates( + cm: ComponentModel, + train_loader: DataLoader[Int[Tensor, "B S"]], + n_batches: int, + onprogress: Callable[[float], None] | None = None, +) -> dict[int, float]: + """Compute E[P(token)] across the dataset - the base rate probability for each token. + + For each position in the dataset, we compute softmax(logits) and accumulate + the probability mass for each token. The result is normalized by total positions. + + Returns: + Dict mapping token_id -> mean probability across all positions + """ + logger.info(f"Computing token base rates over {n_batches} batches") + device = next(cm.parameters()).device + + # Accumulate probability mass per token + token_prob_sums: dict[int, float] = {} + n_positions = 0 + + train_iter = iter(train_loader) + for i in tqdm.tqdm(range(n_batches), desc="Computing base rates"): + batch: Int[Tensor, "B S"] = extract_batch_data(next(train_iter)).to(device) + B, S = batch.shape + n_positions += B * S + + with torch.no_grad(): + output_with_cache = cm(batch, cache_type="input") + logits = output_with_cache.output + probs: Float[Tensor, "B S V"] = torch.softmax(logits, dim=-1) + + # Sum probabilities across batch and sequence + prob_sums: Float[Tensor, " V"] = probs.sum(dim=(0, 1)) + + # Accumulate into dict (move to CPU once) + prob_sums_cpu = prob_sums.cpu() + for token_id in range(prob_sums_cpu.shape[0]): + prob = float(prob_sums_cpu[token_id]) + if prob > 0: + token_prob_sums[token_id] = token_prob_sums.get(token_id, 0.0) + prob + + if onprogress: + onprogress((i + 1) / n_batches) + + # Normalize by total positions to get mean probability + base_rates = {tok_id: prob_sum / n_positions for tok_id, prob_sum in token_prob_sums.items()} + + logger.info(f"Computed base rates for {len(base_rates)} tokens over {n_positions} positions") + return base_rates + + def _apply_position_separation( seq_positions: NDArray[np.int64], batch_indices: NDArray[np.int64], @@ -468,25 +529,45 @@ def _get_component_token_pr( def _get_component_predicted_tokens( - component_predicted_counts: dict[int, int], + component_prob_sums: dict[int, float], token_strings: dict[int, str], + token_base_rates: dict[int, float], component_activation_count: int, -) -> tuple[list[str], list[float]]: - """Return columnar data: (tokens, probs) sorted by probability descending. +) -> tuple[list[str], list[float], list[float], list[float]]: + """Return columnar data: (tokens, lifts, firing_probs, base_probs) sorted by lift descending. - prob = P(predicted_token = X | component fires) + firing_prob = E[P(token) | component fires] + base_prob = E[P(token)] (from token_base_rates) + lift = firing_prob / base_prob """ tokens: list[str] = [] - probs: list[float] = [] + lifts: list[float] = [] + firing_probs: list[float] = [] + base_probs: list[float] = [] + + for token_id, prob_sum in component_prob_sums.items(): + firing_prob = prob_sum / component_activation_count + base_prob = token_base_rates.get(token_id, 0.0) + + if base_prob < 1e-9: + # Avoid division by zero; skip tokens with essentially no base rate + logger.warning( + f"Token {token_id} ({token_strings.get(token_id, '?')}) has near-zero base rate, skipping" + ) + continue + + lift = firing_prob / base_prob - for token_id, count in component_predicted_counts.items(): - prob = round(count / component_activation_count, 3) tokens.append(token_strings[token_id]) - probs.append(prob) + lifts.append(round(lift, 2)) + firing_probs.append(round(firing_prob, 4)) + base_probs.append(round(base_prob, 4)) - # Sort by probability descending - sorted_indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True) + # Sort by lift descending + sorted_indices = sorted(range(len(lifts)), key=lambda i: lifts[i], reverse=True) tokens = [tokens[i] for i in sorted_indices] - probs = [probs[i] for i in sorted_indices] + lifts = [lifts[i] for i in sorted_indices] + firing_probs = [firing_probs[i] for i in sorted_indices] + base_probs = [base_probs[i] for i in sorted_indices] - return tokens, probs + return tokens, lifts, firing_probs, base_probs diff --git a/spd/app/backend/routers/activation_contexts.py b/spd/app/backend/routers/activation_contexts.py index d49800bfd..b37a9b9c6 100644 --- a/spd/app/backend/routers/activation_contexts.py +++ b/spd/app/backend/routers/activation_contexts.py @@ -137,6 +137,7 @@ def compute_thread() -> None: tokenizer=loaded.tokenizer, train_loader=train_loader, token_strings=loaded.token_strings, + token_base_rates=loaded.token_base_rates, importance_threshold=importance_threshold, n_batches=n_batches, n_tokens_either_side=n_tokens_either_side, diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 582626391..aa8920532 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -63,7 +63,7 @@ def tokenize_text(text: str, loaded: DepLoadedRun) -> TokenizeResponse: NormalizeType = Literal["none", "target", "layer"] -def compute_edge_stats(edges: list[EdgeData]) -> tuple[dict[str, float], float]: +def compute_edge_stats(edges: list[Edge]) -> tuple[dict[str, float], float]: """Compute node importance and max absolute edge value. Returns: @@ -72,10 +72,12 @@ def compute_edge_stats(edges: list[EdgeData]) -> tuple[dict[str, float], float]: importance: dict[str, float] = {} max_abs_attr = 0.0 for edge in edges: - val_sq = edge.val * edge.val - importance[edge.src] = importance.get(edge.src, 0.0) + val_sq - importance[edge.tgt] = importance.get(edge.tgt, 0.0) + val_sq - abs_val = abs(edge.val) + val_sq = edge.strength * edge.strength + src_key = str(edge.source) + tgt_key = str(edge.target) + importance[src_key] = importance.get(src_key, 0.0) + val_sq + importance[tgt_key] = importance.get(tgt_key, 0.0) + val_sq + abs_val = abs(edge.strength) if abs_val > max_abs_attr: max_abs_attr = abs_val return importance, max_abs_attr @@ -414,11 +416,12 @@ def process_edges_for_response( final_seq_pos = num_tokens - 1 edges = [edge for edge in edges if edge.target.seq_pos == final_seq_pos] edges = _normalize_edges(edges, normalize) + node_importance, max_abs_attr = compute_edge_stats(edges) + # Clip to edge limit for response if len(edges) > edge_limit: print(f"[WARNING] Edge limit {edge_limit} exceeded ({len(edges)} edges), truncating") - edges = sorted(edges, key=lambda e: abs(e.strength), reverse=True)[:edge_limit] + edges = sorted(edges, key=lambda e: abs(e.strength), reverse=True)[:edge_limit] edges_data = [_edge_to_edge_data(e) for e in edges] - node_importance, max_abs_attr = compute_edge_stats(edges_data) return edges_data, node_importance, max_abs_attr diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index 1ca9a77e2..25d44a4fa 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -11,6 +11,7 @@ from spd.app.backend.compute import get_sources_by_target from spd.app.backend.dependencies import DepStateManager +from spd.app.backend.lib.activation_contexts import compute_token_base_rates from spd.app.backend.schemas import LoadedRun from spd.app.backend.state import RunState from spd.app.backend.utils import build_token_lookup, log_errors @@ -115,6 +116,15 @@ def load_run(wandb_path: str, context_length: int, manager: DepStateManager): buffer_size=task_config.buffer_size, global_seed=spd_config.seed, ) + + # Compute token base rates for predicted token lift calculation + logger.info(f"[API] Computing token base rates for run {run.id}") + token_base_rates = compute_token_base_rates( + cm=model, + train_loader=train_loader, + n_batches=100, + ) + manager.run_state = RunState( run=run, model=model, @@ -124,6 +134,7 @@ def load_run(wandb_path: str, context_length: int, manager: DepStateManager): token_strings=token_strings, train_loader=train_loader, context_length=context_length, + token_base_rates=token_base_rates, ) logger.info(f"[API] Run {run.id} loaded on {DEVICE}") diff --git a/spd/app/backend/schemas.py b/spd/app/backend/schemas.py index 291a24b84..5a48cc995 100644 --- a/spd/app/backend/schemas.py +++ b/spd/app/backend/schemas.py @@ -170,10 +170,12 @@ class SubcomponentActivationContexts(BaseModel): pr_recalls: list[float] # [n_unique_tokens] pr_precisions: list[float] # [n_unique_tokens] - # Predicted token stats - P(predicted_token | component fires) - # Sorted by probability descending + # Predicted token stats - lift = E[P(token)|fires] / E[P(token)] + # Sorted by lift descending predicted_tokens: list[str] # [n_unique_predicted] - predicted_probs: list[float] # [n_unique_predicted] - P(token predicted | component fires) + predicted_lifts: list[float] # [n_unique_predicted] - lift (firing_prob / base_prob) + predicted_firing_probs: list[float] # [n_unique_predicted] - E[P(token) | component fires] + predicted_base_probs: list[float] # [n_unique_predicted] - E[P(token)] base rate class ModelActivationContexts(BaseModel): diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index 5a5cac985..f9e301725 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -30,6 +30,7 @@ class RunState: token_strings: dict[int, str] train_loader: DataLoader[Any] context_length: int + token_base_rates: dict[int, float] # token_id -> E[P(token)] across dataset activation_contexts_cache: ModelActivationContexts | None = None diff --git a/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte b/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte index 8405cde75..d5a2621aa 100644 --- a/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte +++ b/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte @@ -141,7 +141,10 @@
{#if detail.pr_tokens?.length > 0}
-

Top Input Tokens

+

+ Token Precision + P(firing | token) +

{#each detail.pr_tokens.slice(0, 10) as token, i (i)} @@ -155,15 +158,22 @@ {/if} - {#if detail.predicted_tokens?.length} + {#if detail.predicted_tokens.length > 0}
-

Top Predicted

+

Prediction uplift

{#each detail.predicted_tokens.slice(0, 10) as token, i (i)} - + {/each} @@ -303,6 +313,21 @@ text-align: right; } + .lift-cell { + white-space: nowrap; + } + + .lift-value { + color: var(--text-primary); + font-weight: 500; + } + + .lift-detail { + color: var(--text-muted); + font-size: var(--text-xs); + margin-left: var(--space-1); + } + .data-table code { color: var(--text-primary); font-family: var(--font-mono); diff --git a/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte b/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte index 3ebd01921..ec4ca28f5 100644 --- a/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte +++ b/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte @@ -36,6 +36,10 @@ } return "?"; } + + // Validation: can't run intervention with embedding (wte) or output nodes + const hasInvalidNodes = $derived(stagedNodes.some((n) => n.layer === "wte" || n.layer === "output")); + const canRunIntervention = $derived(stagedNodes.length > 0 && !hasInvalidNodes && !runningIntervention); {#if stagedNodes.length > 0} @@ -43,10 +47,13 @@
Staged Nodes ({stagedNodes.length})
+ {#if hasInvalidNodes} + Remove wte/output nodes to run + {/if} @@ -103,9 +110,16 @@ .header-actions { display: flex; + align-items: center; gap: var(--space-2); } + .validation-warning { + font-size: var(--text-xs); + font-family: var(--font-mono); + color: var(--status-negative-bright); + } + .staged-container-header button { background: var(--bg-elevated); border: 1px solid var(--border-default); diff --git a/spd/app/frontend/src/lib/localAttributionsTypes.ts b/spd/app/frontend/src/lib/localAttributionsTypes.ts index 82f93d58e..c6c4949b3 100644 --- a/spd/app/frontend/src/lib/localAttributionsTypes.ts +++ b/spd/app/frontend/src/lib/localAttributionsTypes.ts @@ -58,8 +58,10 @@ export type ComponentDetail = { pr_tokens: string[]; pr_recalls: number[]; pr_precisions: number[]; - predicted_tokens?: string[]; - predicted_probs?: number[]; + predicted_tokens: string[]; + predicted_lifts: number[]; + predicted_firing_probs: number[]; + predicted_base_probs: number[]; }; export type SearchResult = { diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 0ea3d038c..a7121f97f 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -157,6 +157,7 @@ def app_with_state(): config=config, token_strings=token_strings, train_loader=train_loader, + token_base_rates={}, ) manager = StateManager.get() From 52c3c904b7b3ba00310e870b4cd986a1987babe3 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 9 Dec 2025 11:17:04 +0000 Subject: [PATCH 019/500] hypothesis input --- .../backend/routers/activation_contexts.py | 38 +++++ spd/app/backend/schemas.py | 20 +++ .../local-attr/ComponentDetailCard.svelte | 138 +++++++++++++++++- .../frontend/src/lib/localAttributionsApi.ts | 15 ++ .../src/lib/localAttributionsTypes.ts | 6 + 5 files changed, 215 insertions(+), 2 deletions(-) diff --git a/spd/app/backend/routers/activation_contexts.py b/spd/app/backend/routers/activation_contexts.py index b37a9b9c6..a7f86a587 100644 --- a/spd/app/backend/routers/activation_contexts.py +++ b/spd/app/backend/routers/activation_contexts.py @@ -14,6 +14,8 @@ from spd.app.backend.lib.activation_contexts import get_activations_data from spd.app.backend.schemas import ( ActivationContextsGenerationConfig, + ComponentProbeRequest, + ComponentProbeResponse, HarvestMetadata, ModelActivationContexts, SubcomponentActivationContexts, @@ -187,3 +189,39 @@ def generate() -> Generator[str]: thread.join() return StreamingResponse(generate(), media_type="text/event-stream") + + +@router.post("/probe") +@log_errors +def probe_component( + request: ComponentProbeRequest, + loaded: DepLoadedRun, +) -> ComponentProbeResponse: + """Probe a component's CI values on custom text. + + Fast endpoint for testing hypotheses about component activation. + Only requires a single forward pass. + """ + import torch + + from spd.app.backend.compute import compute_ci_only + from spd.utils.distributed_utils import get_device + + device = get_device() + + token_ids = loaded.tokenizer.encode(request.text, add_special_tokens=False) + assert len(token_ids) > 0, "Text produced no tokens" + + tokens_tensor = torch.tensor([token_ids], device=device) + + result = compute_ci_only( + model=loaded.model, + tokens=tokens_tensor, + sampling=loaded.config.sampling, + ) + + ci_tensor = result.ci_lower_leaky[request.layer] + ci_values = ci_tensor[0, :, request.component_idx].tolist() + token_strings = [loaded.token_strings[t] for t in token_ids] + + return ComponentProbeResponse(tokens=token_strings, ci_values=ci_values) diff --git a/spd/app/backend/schemas.py b/spd/app/backend/schemas.py index 5a48cc995..675b86fa7 100644 --- a/spd/app/backend/schemas.py +++ b/spd/app/backend/schemas.py @@ -243,3 +243,23 @@ class InterventionResponse(BaseModel): input_tokens: list[str] predictions_per_position: list[list[TokenPrediction]] + + +# ============================================================================= +# Component Probe Models +# ============================================================================= + + +class ComponentProbeRequest(BaseModel): + """Request to probe a component's CI on custom text.""" + + text: str + layer: str + component_idx: int + + +class ComponentProbeResponse(BaseModel): + """Response with CI values for a component on custom text.""" + + tokens: list[str] + ci_values: list[float] diff --git a/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte b/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte index d5a2621aa..d2f33bffe 100644 --- a/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte +++ b/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte @@ -1,6 +1,12 @@ + +
+
Test Custom Text
+ + {#if probeLoading} +

Loading...

+ {:else if probeError} +

{probeError}

+ {:else if probeResult && probeResult.tokens.length > 0} +
+ {#each probeResult.tokens as tok, i (i)}{tok}{/each} +
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/LocalAttributionsTab.svelte b/spd/app/frontend/src/components/LocalAttributionsTab.svelte index 4a239ca3d..8dfa7672c 100644 --- a/spd/app/frontend/src/components/LocalAttributionsTab.svelte +++ b/spd/app/frontend/src/components/LocalAttributionsTab.svelte @@ -5,17 +5,14 @@ ActivationContextsSummary, ComponentDetail, GraphData, - PinnedNode, PromptPreview, } from "../lib/localAttributionsTypes"; - import type { InterventionResponse } from "../lib/interventionTypes"; - import StagedNodesPanel from "./local-attr/StagedNodesPanel.svelte"; import ComputeProgressOverlay from "./local-attr/ComputeProgressOverlay.svelte"; import InterventionsView from "./local-attr/InterventionsView.svelte"; import PromptCardHeader from "./local-attr/PromptCardHeader.svelte"; import PromptCardTabs from "./local-attr/PromptCardTabs.svelte"; import PromptPicker from "./local-attr/PromptPicker.svelte"; - import type { StoredGraph, ComputeOptions, LoadingState, OptimizeConfig, PromptCard, Intervention } from "./local-attr/types"; + import type { StoredGraph, ComputeOptions, LoadingState, OptimizeConfig, PromptCard } from "./local-attr/types"; import ViewControls from "./local-attr/ViewControls.svelte"; import LocalAttributionsGraph from "./LocalAttributionsGraph.svelte"; @@ -34,7 +31,6 @@ // Prompt picker state let showPromptPicker = $state(false); - let filterByStaged = $state(false); let filteredPrompts = $state([]); let filterLoading = $state(false); let isAddingCustomPrompt = $state(false); @@ -196,22 +192,35 @@ async function addPromptCard(promptId: number, tokens: string[], tokenIds: number[], isCustom: boolean) { const cardId = `${Date.now()}-${Math.random().toString(36).slice(2, 6)}`; - // Fetch stored graphs for this prompt + // Fetch stored graphs for this prompt (includes composer selection and intervention runs) let graphs: StoredGraph[] = []; try { const storedGraphs = await attrApi.getGraphs(promptId, normalizeEdges); - graphs = storedGraphs.map((data, idx) => { + graphs = await Promise.all(storedGraphs.map(async (data, idx) => { const isOptimized = !!data.optimization; const label = isOptimized ? `Optimized (${data.optimization!.steps} steps)` : "Standard"; + + // Load intervention runs for this graph + const runs = await mainApi.getInterventionRuns(data.id); + + // Initialize composer selection: from DB or default to all nodes + const allNodeKeys = Object.keys(data.nodeImportance); + const composerSelection = data.composerSelection + ? new Set(data.composerSelection) + : new Set(allNodeKeys); + return { id: `graph-${idx}-${Date.now()}`, + dbId: data.id, label, data, - stagedNodes: [], + composerSelection, + interventionRuns: runs, + activeRunId: null, }; - }); + })); } catch (e) { console.warn("Failed to fetch graphs:", e); } @@ -224,7 +233,6 @@ isCustom, graphs, activeGraphId: graphs.length > 0 ? graphs[0].id : null, - interventions: [], activeView: "graph", }; promptCards = [...promptCards, newCard]; @@ -285,61 +293,62 @@ computeOptions.optimizeConfig = { ...computeOptions.optimizeConfig, ...partial }; } - // Update staged nodes for the active graph - function handleStagedNodesChange(nodes: PinnedNode[]) { + // Switch between graph and interventions view + function handleViewChange(view: "graph" | "interventions") { + if (!activeCard) return; + promptCards = promptCards.map((card) => + card.id === activeCard.id ? { ...card, activeView: view } : card, + ); + } + + // Update composer selection for the active graph + async function handleComposerSelectionChange(selection: Set) { if (!activeCard || !activeGraph) return; + + // Update local state immediately promptCards = promptCards.map((card) => { if (card.id !== activeCard.id) return card; return { ...card, graphs: card.graphs.map((g) => - g.id === activeGraph.id ? { ...g, stagedNodes: nodes } : g + g.id === activeGraph.id ? { ...g, composerSelection: selection, activeRunId: null } : g ), }; }); - // Re-filter if needed - if (filterByStaged) { - filterPromptsByStaged(); + // Persist to backend (fire and forget with error handling) + try { + await mainApi.updateComposerSelection(activeGraph.dbId, Array.from(selection)); + } catch (e) { + console.error("Failed to save composer selection:", e); } } - // Switch between graph and interventions view - function handleViewChange(view: "graph" | "interventions") { - if (!activeCard) return; - promptCards = promptCards.map((card) => - card.id === activeCard.id ? { ...card, activeView: view } : card, - ); - } - - // Run intervention and add to prompt's list + // Run intervention and save to DB async function handleRunIntervention() { - if (!activeCard || !activeGraph || activeGraph.stagedNodes.length === 0) return; + if (!activeCard || !activeGraph || activeGraph.composerSelection.size === 0) return; runningIntervention = true; try { - const nodes = activeGraph.stagedNodes.map((n) => ({ - layer: n.layer, - seq_pos: n.seqIdx, - component_idx: n.cIdx, - })); const text = activeCard.tokens.join(""); - const result: InterventionResponse = await mainApi.runIntervention(text, nodes); + const selectedNodes = Array.from(activeGraph.composerSelection); - const intervention: Intervention = { - id: `${Date.now()}-${Math.random().toString(36).slice(2, 6)}`, - timestamp: Date.now(), - nodes: [...activeGraph.stagedNodes], - result, - }; + const run = await mainApi.runAndSaveIntervention({ + graph_id: activeGraph.dbId, + text, + selected_nodes: selectedNodes, + }); - // Add intervention and switch to interventions view + // Add run to local state and select it promptCards = promptCards.map((card) => { if (card.id !== activeCard.id) return card; return { ...card, - interventions: [...card.interventions, intervention], - activeView: "interventions", + graphs: card.graphs.map((g) => + g.id === activeGraph.id + ? { ...g, interventionRuns: [...g.interventionRuns, run], activeRunId: run.id } + : g + ), }; }); } catch (e) { @@ -349,12 +358,52 @@ } } - // Clear interventions for the active card - function handleClearInterventions() { - if (!activeCard) return; - promptCards = promptCards.map((card) => - card.id === activeCard.id ? { ...card, interventions: [] } : card, - ); + // Select a run and restore its selection state + function handleSelectRun(runId: number) { + if (!activeCard || !activeGraph) return; + + const run = activeGraph.interventionRuns.find((r) => r.id === runId); + if (!run) return; + + // Restore selection from the run + promptCards = promptCards.map((card) => { + if (card.id !== activeCard.id) return card; + return { + ...card, + graphs: card.graphs.map((g) => + g.id === activeGraph.id + ? { ...g, composerSelection: new Set(run.selected_nodes), activeRunId: runId } + : g + ), + }; + }); + } + + // Delete an intervention run + async function handleDeleteRun(runId: number) { + if (!activeCard || !activeGraph) return; + + try { + await mainApi.deleteInterventionRun(runId); + + promptCards = promptCards.map((card) => { + if (card.id !== activeCard.id) return card; + return { + ...card, + graphs: card.graphs.map((g) => { + if (g.id !== activeGraph.id) return g; + const newRuns = g.interventionRuns.filter((r) => r.id !== runId); + return { + ...g, + interventionRuns: newRuns, + activeRunId: g.activeRunId === runId ? null : g.activeRunId, + }; + }), + }; + }); + } catch (e) { + console.error("Failed to delete run:", e); + } } async function computeGraphForCard() { @@ -423,11 +472,22 @@ const graphId = `${Date.now()}-${Math.random().toString(36).slice(2, 6)}`; const label = isOptimized ? `Optimized (${optConfig.steps} steps)` : "Standard"; + // Initialize composer selection to all nodes + const allNodeKeys = Object.keys(data.nodeImportance); + promptCards = promptCards.map((card) => { if (card.id !== activeCard.id) return card; return { ...card, - graphs: [...card.graphs, { id: graphId, label, data, stagedNodes: [] }], + graphs: [...card.graphs, { + id: graphId, + dbId: data.id, + label, + data, + composerSelection: new Set(allNodeKeys), + interventionRuns: [], + activeRunId: null, + }], activeGraphId: graphId, }; }); @@ -439,33 +499,6 @@ } } - async function filterPromptsByStaged() { - const stagedNodes = activeGraph?.stagedNodes ?? []; - if (stagedNodes.length === 0) { - filteredPrompts = []; - return; - } - - filterLoading = true; - try { - const components = stagedNodes.map((p) => `${p.layer}:${p.cIdx}`); - const result = await attrApi.searchPrompts(components, "all"); - filteredPrompts = result.results; - } catch { - filteredPrompts = []; - } finally { - filterLoading = false; - } - } - - function handleFilterToggle() { - filterByStaged = !filterByStaged; - const stagedNodes = activeGraph?.stagedNodes ?? []; - if (filterByStaged && stagedNodes.length > 0) { - filterPromptsByStaged(); - } - } - async function handleNormalizeChange(value: attrApi.NormalizeType) { normalizeEdges = value; computeOptions.normalizeEdges = value; @@ -476,18 +509,31 @@ try { const storedGraphs = await attrApi.getGraphs(card.promptId, normalizeEdges); - const graphs = storedGraphs.map((data, idx) => { + const graphs = await Promise.all(storedGraphs.map(async (data, idx) => { const isOptimized = !!data.optimization; const label = isOptimized ? `Optimized (${data.optimization!.steps} steps)` : "Standard"; + + // Load intervention runs + const runs = await mainApi.getInterventionRuns(data.id); + + // Initialize composer selection + const allNodeKeys = Object.keys(data.nodeImportance); + const composerSelection = data.composerSelection + ? new Set(data.composerSelection) + : new Set(allNodeKeys); + return { id: `graph-${idx}-${Date.now()}`, + dbId: data.id, label, data, - stagedNodes: [] as PinnedNode[], + composerSelection, + interventionRuns: runs, + activeRunId: null, }; - }); + })); return { ...card, graphs, @@ -549,8 +595,8 @@ {}} onGenerate={handleGeneratePrompts} onClose={() => (showPromptPicker = false)} /> @@ -580,10 +626,11 @@ class="view-tab" class:active={activeCard.activeView === "interventions"} onclick={() => handleViewChange("interventions")} + disabled={!activeGraph} > Interventions - {#if activeCard.interventions.length > 0} - {activeCard.interventions.length} + {#if activeGraph && activeGraph.interventionRuns.length > 0} + {activeGraph.interventionRuns.length} {/if}
@@ -648,39 +695,35 @@ {componentGap} {layerGap} {activationContextsSummary} - stagedNodes={activeGraph.stagedNodes} + stagedNodes={[]} {componentDetailsCache} {componentDetailsLoading} - onStagedNodesChange={handleStagedNodesChange} + onStagedNodesChange={() => {}} onLoadComponentDetail={loadComponentDetail} onEdgeCountChange={(count) => (filteredEdgeCount = count)} /> {/key} - {:else if !loadingCardId}

Click Compute to generate the attribution graph

{/if}
- {:else} + {:else if activeGraph} {/if} {:else} diff --git a/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte b/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte index d2f33bffe..89463fe23 100644 --- a/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte +++ b/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte @@ -1,14 +1,9 @@
- {#if interventions.length === 0} -
-

No interventions yet.

-

Stage nodes from the Graph view and click "Run Intervention"

-
- {:else} -
- {interventions.length} intervention{interventions.length === 1 ? "" : "s"} - + +
+
+ Composer + {selectedCount} / {activeNodeCount} nodes selected
-
- {#each interventions as intervention, idx (intervention.id)} - {@const layout = computeLayout(intervention.nodes, intervention.result.input_tokens.length)} -
-
- Intervention #{idx + 1} - {formatTime(intervention.timestamp)} -
+
+
+ + + {topK} +
+
+ + + +
+
- -
- -
- - - - Logits - - {#each Array(MAX_PREDICTIONS) as _, rank (rank)} - - #{rank + 1} - - {/each} - - - {#each Object.entries(layout.layerYPositions) as [layer, y] (layer)} - {@const yCenter = y + COMPONENT_SIZE / 2} - - {getRowLabel(layer)} - - {/each} - -
+
+ Click to toggle • Shift+click to solo +
- -
+
+ + + {#each tokens as token, pos (pos)} + {@const x = layout.seqXStarts[pos]} + {@const w = layout.seqWidths[pos]} + {@const cx = x + w / 2} + "{token}" + [{pos}] + {/each} + + + {#each Object.entries(layout.layerYPositions) as [layer, y] (layer)} + {getRowLabel(layer)} + {/each} + + + + {#each filteredEdges as edge (`${edge.src}-${edge.tgt}`)} + {@const path = getEdgePath(edge.src, edge.tgt)} + {#if path} + + {/if} + {/each} + + + + + {#each activeNodes as nodeKey (nodeKey)} + {@const pos = layout.nodePositions[nodeKey]} + {@const [layer, seqIdx, cIdx] = nodeKey.split(":")} + {@const selected = isNodeSelected(nodeKey)} + {@const isOutput = layer === "output"} + {@const outputEntry = isOutput ? graph.data.outputProbs[`${seqIdx}:${cIdx}`] : null} + {#if pos} + - - - {#each intervention.result.input_tokens as token, pos (pos)} - {@const x = layout.seqXStarts[pos]} - {@const w = layout.seqWidths[pos]} - {@const cx = x + w / 2} - - - + handleNodeMouseEnter(e, layer, +seqIdx, +cIdx)} + onmouseleave={handleNodeMouseLeave} + onclick={(e) => handleNodeClick(e, nodeKey)} + > + + + {#if isOutput && outputEntry} [{pos}] - "{token}" - - - {#each Array(MAX_PREDICTIONS) as _, rank (rank)} - {@const pred = intervention.result.predictions_per_position[pos][rank]} - {@const cellY = LOGITS_HEADER_HEIGHT + rank * LOGITS_ROW_HEIGHT} - - {#if pred} - "{pred.token}" - {formatProb(pred.prob)} - {:else} - - - {/if} - {/each} + >"{outputEntry.token}" + {/if} + + {/if} + {/each} + + +
+
- - - - - {token} - [{pos}] - {/each} + +
+
+ Run History + {graph.interventionRuns.length} runs +
- - + {#if graph.interventionRuns.length === 0} +
+

No runs yet

+

Select nodes and click Run

+
+ {:else} +
+ {#each graph.interventionRuns.slice().reverse() as run (run.id)} + {@const isActive = graph.activeRunId === run.id} +
onSelectRun(run.id)} + onkeydown={(e) => e.key === "Enter" && onSelectRun(run.id)} + > +
+ {formatTime(run.created_at)} + {run.selected_nodes.length} nodes + +
- - - {#each intervention.nodes as node (`${node.layer}:${node.seqIdx}:${node.cIdx}`)} - {@const pos = layout.nodePositions[`${node.layer}:${node.seqIdx}:${node.cIdx}`]} - {#if pos} - handleNodeMouseEnter(e, node.layer, node.seqIdx, node.cIdx)} - onmouseleave={handleNodeMouseLeave} - > - - - - {/if} + +
+
{token}{detail.predicted_probs?.[i]?.toFixed(3)} + {detail.predicted_lifts[i].toFixed(1)}x + + ({(detail.predicted_firing_probs[i] * 100).toFixed(1)}% vs {( + detail.predicted_base_probs[i] * 100 + ).toFixed(1)}%) + +
+ + + {#each run.result.input_tokens as token, idx (idx)} + + {/each} + + + + {#each Array(Math.min(3, MAX_PREDICTIONS)) as _, rank (rank)} + + {#each run.result.predictions_per_position as preds, idx (idx)} + {@const pred = preds[rank]} + + {/each} + {/each} - - + +
+ "{token}" +
+ {#if pred} + "{pred.token}" + {formatProb(pred.prob)} + {:else} + - + {/if} +
-
- {/each} -
- {/if} + {/each} +
+ {/if} +
{#if hoveredNode} @@ -528,10 +594,7 @@ class="node-tooltip" style="left: {tooltipPos.x}px; top: {tooltipPos.y}px;" onmouseenter={() => (isHoveringTooltip = true)} - onmouseleave={() => { - isHoveringTooltip = false; - handleNodeMouseLeave(); - }} + onmouseleave={() => { isHoveringTooltip = false; handleNodeMouseLeave(); }} >

{hoveredNode.layer}:{hoveredNode.seqIdx}:{hoveredNode.cIdx}

.interventions-view { - flex: 1; display: flex; - flex-direction: column; + flex: 1; min-height: 0; + gap: var(--space-4); } - .empty-state { + /* Composer Panel */ + .composer-panel { + flex: 2; display: flex; - flex: 1; flex-direction: column; - align-items: center; - justify-content: center; - color: var(--text-muted); - text-align: center; - padding: var(--space-4); - font-family: var(--font-sans); + min-width: 0; background: var(--bg-surface); + border: 1px solid var(--border-default); + padding: var(--space-3); } - .empty-state p { - margin: var(--space-1) 0; + .composer-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: var(--space-2); + padding-bottom: var(--space-2); + border-bottom: 1px solid var(--border-subtle); } - .empty-state .hint { + .composer-header .title { + font-weight: 600; + font-family: var(--font-sans); + color: var(--text-primary); + } + + .node-count { font-size: var(--text-sm); font-family: var(--font-mono); + color: var(--text-muted); } - .interventions-header { + .composer-controls { + display: flex; + flex-direction: column; + gap: var(--space-2); + margin-bottom: var(--space-2); + } + + .topk-control { display: flex; - justify-content: space-between; align-items: center; - padding: var(--space-2) 0; + gap: var(--space-2); font-size: var(--text-sm); - font-family: var(--font-sans); + font-family: var(--font-mono); color: var(--text-secondary); } - .clear-btn { + .topk-control label { + white-space: nowrap; + } + + .topk-control input[type="range"] { + flex: 1; + min-width: 100px; + max-width: 200px; + } + + .topk-value { + min-width: 40px; + text-align: right; + color: var(--text-primary); + } + + .button-group { + display: flex; + gap: var(--space-2); + } + + .button-group button { padding: var(--space-1) var(--space-2); background: var(--bg-elevated); border: 1px solid var(--border-default); @@ -596,82 +696,213 @@ font-size: var(--text-sm); } - .clear-btn:hover { - background: var(--status-negative); - color: white; - border-color: var(--status-negative); + .button-group button:hover:not(:disabled) { + background: var(--bg-inset); + border-color: var(--border-strong); } - .interventions-list { + .run-btn { + background: var(--accent-primary) !important; + color: white !important; + border-color: var(--accent-primary) !important; + } + + .run-btn:hover:not(:disabled) { + background: var(--accent-primary-dim) !important; + } + + .run-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + } + + .composer-hint { + font-size: var(--text-xs); + font-family: var(--font-mono); + color: var(--text-muted); + margin-bottom: var(--space-2); + } + + .composer-graph { flex: 1; - overflow-y: auto; - display: flex; - flex-direction: column; - gap: var(--space-3); + overflow: auto; + background: var(--bg-inset); + border: 1px solid var(--border-subtle); + } + + .node-group { + cursor: pointer; + } + + .node-group .node { + transition: opacity 0.1s, fill 0.1s; + } + + .node-group:hover .node { + opacity: 1 !important; + filter: brightness(1.2); } - .intervention-card { + /* History Panel */ + .history-panel { + flex: 1; + display: flex; + flex-direction: column; + min-width: 300px; + max-width: 400px; background: var(--bg-surface); border: 1px solid var(--border-default); padding: var(--space-3); } - .intervention-header { + .history-header { display: flex; justify-content: space-between; align-items: center; - margin-bottom: var(--space-3); + margin-bottom: var(--space-2); padding-bottom: var(--space-2); border-bottom: 1px solid var(--border-subtle); } - .intervention-title { + .history-header .title { font-weight: 600; font-family: var(--font-sans); color: var(--text-primary); } - .intervention-time { + .run-count { font-size: var(--text-sm); font-family: var(--font-mono); color: var(--text-muted); } - /* Unified visualization */ - .unified-viz-wrapper { + .empty-history { + flex: 1; display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + color: var(--text-muted); + text-align: center; + } + + .empty-history p { + margin: var(--space-1) 0; + } + + .empty-history .hint { + font-size: var(--text-sm); + font-family: var(--font-mono); + } + + .runs-list { + flex: 1; + overflow-y: auto; + display: flex; + flex-direction: column; + gap: var(--space-2); + } + + .run-card { + background: var(--bg-elevated); border: 1px solid var(--border-default); - background: var(--bg-surface); - overflow: hidden; + padding: var(--space-2); + cursor: pointer; + transition: border-color 0.1s; } - .labels-column { - flex-shrink: 0; - background: var(--bg-surface); - border-right: 1px solid var(--border-default); + .run-card:hover { + border-color: var(--border-strong); + } + + .run-card.active { + border-color: var(--accent-primary); + background: var(--bg-inset); + } + + .run-header { + display: flex; + align-items: center; + gap: var(--space-2); + margin-bottom: var(--space-2); + font-size: var(--text-sm); + font-family: var(--font-mono); + } + + .run-time { + color: var(--text-secondary); + } + + .run-nodes { + color: var(--text-muted); + margin-left: auto; + } + + .delete-btn { + padding: 2px 6px; + background: transparent; + border: none; + color: var(--text-muted); + font-size: var(--text-xs); } - .viz-content { + .delete-btn:hover { + color: var(--status-negative); + } + + /* Mini logits table */ + .logits-mini { overflow-x: auto; - flex: 1; + } + + .logits-mini table { + width: 100%; + border-collapse: collapse; + font-size: var(--text-xs); + font-family: var(--font-mono); + } + + .logits-mini th, + .logits-mini td { + padding: 2px 4px; + text-align: center; + border: 1px solid var(--border-subtle); + max-width: 60px; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } + + .logits-mini th { background: var(--bg-surface); + color: var(--text-secondary); } - .node-group { - cursor: pointer; + .logits-mini .token-text { + font-size: 9px; } - .node { - transition: opacity 0.1s; + .logits-mini td { + background: var(--bg-inset); + color: var(--text-muted); } - .node.highlighted { - stroke: var(--accent-primary); - stroke-width: 2px; - filter: brightness(1.2); - opacity: 1 !important; + .logits-mini td.has-pred { + background: var(--bg-surface); + } + + .pred-token { + display: block; + color: var(--text-primary); + } + + .pred-prob { + display: block; + font-size: 8px; + color: var(--text-muted); } + /* Tooltip */ .node-tooltip { position: fixed; padding: var(--space-3); @@ -691,7 +922,6 @@ font-family: var(--font-mono); color: var(--accent-primary); font-weight: 600; - letter-spacing: 0.02em; border-bottom: 1px solid var(--border-subtle); padding-bottom: var(--space-2); } diff --git a/spd/app/frontend/src/components/local-attr/types.ts b/spd/app/frontend/src/components/local-attr/types.ts index a80a397e8..27ef55f94 100644 --- a/spd/app/frontend/src/components/local-attr/types.ts +++ b/spd/app/frontend/src/components/local-attr/types.ts @@ -1,19 +1,16 @@ -import type { GraphData, PinnedNode } from "../../lib/localAttributionsTypes"; -import type { InterventionResponse } from "../../lib/interventionTypes"; +import type { GraphData } from "../../lib/localAttributionsTypes"; +import type { InterventionRunSummary } from "../../lib/interventionTypes"; import type { NormalizeType } from "../../lib/localAttributionsApi"; export type StoredGraph = { id: string; + dbId: number; // database ID for API calls label: string; data: GraphData; - stagedNodes: PinnedNode[]; -}; - -export type Intervention = { - id: string; - timestamp: number; - nodes: PinnedNode[]; // snapshot of nodes used - result: InterventionResponse; + // Composer state for interventions + composerSelection: Set; // currently selected node keys + interventionRuns: InterventionRunSummary[]; // persisted runs + activeRunId: number | null; // which run is selected (for restoring selection) }; export type PromptCard = { @@ -24,7 +21,6 @@ export type PromptCard = { isCustom: boolean; graphs: StoredGraph[]; activeGraphId: string | null; - interventions: Intervention[]; activeView: "graph" | "interventions"; }; diff --git a/spd/app/frontend/src/lib/api.ts b/spd/app/frontend/src/lib/api.ts index 89a7f4766..c6a3b7f14 100644 --- a/spd/app/frontend/src/lib/api.ts +++ b/spd/app/frontend/src/lib/api.ts @@ -154,7 +154,12 @@ export async function getComponentDetail(layer: string, componentIdx: number): P } // Intervention types -import type { InterventionNode, InterventionResponse } from "./interventionTypes"; +import type { + InterventionNode, + InterventionResponse, + InterventionRunSummary, + RunInterventionRequest, +} from "./interventionTypes"; export async function runIntervention( text: string, @@ -172,3 +177,56 @@ export async function runIntervention( } return (await response.json()) as InterventionResponse; } + +/** Run an intervention and save the result */ +export async function runAndSaveIntervention( + request: RunInterventionRequest, +): Promise { + const response = await fetch(`${API_URL}/api/intervention/run`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(request), + }); + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || "Failed to run intervention"); + } + return (await response.json()) as InterventionRunSummary; +} + +/** Get all intervention runs for a graph */ +export async function getInterventionRuns(graphId: number): Promise { + const response = await fetch(`${API_URL}/api/intervention/runs/${graphId}`); + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || "Failed to get intervention runs"); + } + return (await response.json()) as InterventionRunSummary[]; +} + +/** Delete an intervention run */ +export async function deleteInterventionRun(runId: number): Promise { + const response = await fetch(`${API_URL}/api/intervention/runs/${runId}`, { + method: "DELETE", + }); + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || "Failed to delete intervention run"); + } +} + +/** Update composer selection state for a graph */ +export async function updateComposerSelection( + graphId: number, + selection: string[] | null, +): Promise { + const response = await fetch(`${API_URL}/api/intervention/composer`, { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ graph_id: graphId, selection }), + }); + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || "Failed to update composer selection"); + } +} diff --git a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index c2e1c6537..1e4bb8238 100644 --- a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -17,3 +17,19 @@ export type InterventionResponse = { input_tokens: string[]; predictions_per_position: TokenPrediction[][]; }; + +/** Persisted intervention run from the server */ +export type InterventionRunSummary = { + id: number; + selected_nodes: string[]; // node keys (layer:seq:cIdx) + result: InterventionResponse; + created_at: string; +}; + +/** Request to run and save an intervention */ +export type RunInterventionRequest = { + graph_id: number; + text: string; + selected_nodes: string[]; + top_k?: number; +}; diff --git a/spd/app/frontend/src/lib/localAttributionsTypes.ts b/spd/app/frontend/src/lib/localAttributionsTypes.ts index b32fdc971..38c3b4379 100644 --- a/spd/app/frontend/src/lib/localAttributionsTypes.ts +++ b/spd/app/frontend/src/lib/localAttributionsTypes.ts @@ -28,6 +28,7 @@ export type GraphData = { nodeImportance: Record; // node key -> sum of squared edge values maxAbsAttr: number; // max absolute edge value optimization?: OptimizationResult; + composerSelection?: string[] | null; // selected node keys, null = all selected }; export type OptimizationResult = { @@ -58,10 +59,11 @@ export type ComponentDetail = { pr_tokens: string[]; pr_recalls: number[]; pr_precisions: number[]; - predicted_tokens: string[]; - predicted_lifts: number[]; - predicted_firing_probs: number[]; - predicted_base_probs: number[]; + // TODO: Re-enable token uplift after performance optimization + // predicted_tokens: string[]; + // predicted_lifts: number[]; + // predicted_firing_probs: number[]; + // predicted_base_probs: number[]; }; export type SearchResult = { diff --git a/spd/app/frontend/vite.config.ts b/spd/app/frontend/vite.config.ts index bfec3a4ab..e9c93e7fd 100644 --- a/spd/app/frontend/vite.config.ts +++ b/spd/app/frontend/vite.config.ts @@ -4,7 +4,7 @@ import { svelte } from "@sveltejs/vite-plugin-svelte"; // https://vite.dev/config/ export default defineConfig({ plugins: [svelte()], - // server: { - // hmr: false, - // }, + server: { + hmr: false, + }, }); diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index a7121f97f..f9ea201c4 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -157,7 +157,8 @@ def app_with_state(): config=config, token_strings=token_strings, train_loader=train_loader, - token_base_rates={}, + # TODO: Re-enable token uplift after performance optimization + # token_base_rates={}, ) manager = StateManager.get() From 1f1594fc7c75b64c41e6e03134fca05fc74702ac Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 9 Dec 2025 12:04:19 +0000 Subject: [PATCH 022/500] improve act harvest defaults --- .../components/ActivationContextsTab.svelte | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/spd/app/frontend/src/components/ActivationContextsTab.svelte b/spd/app/frontend/src/components/ActivationContextsTab.svelte index f3618bcf7..475de3688 100644 --- a/spd/app/frontend/src/components/ActivationContextsTab.svelte +++ b/spd/app/frontend/src/components/ActivationContextsTab.svelte @@ -7,12 +7,12 @@ let progress = $state(null); // Configuration parameters - let nBatches = $state(1); - let batchSize = $state(1); - let nTokensEitherSide = $state(10); + let nBatches = $state(10); + let batchSize = $state(16); + let nTokensEitherSide = $state(8); + let separationTokens = $state(8); let importanceThreshold = $state(0.0); - let topkExamples = $state(1000); - let separationTokens = $state(0); + let topkExamples = $state(200); async function loadContexts() { loading = true; @@ -89,6 +89,18 @@ />
+
+ + +
+
- -
- - -
- {#if loadedRun}
-
- - {#if loading && progress} -
-
- Processing... {(progress.progress * 100).toFixed(1)}% +
+
+ + +
+
+ + +
+
+ + +
+
+ +
-
-
+
+ + +
+
+ +
- {:else if loading} -
Loading...
- {/if} + {#if loading && progress} +
+
+
+
+ {(progress.progress * 100).toFixed(0)}% +
+ {/if} +
- {#if harvestMetadata} - - {/if} +
+ {#if loading && !progress} +
Loading...
+ {:else if harvestMetadata} + + {:else} +
+

No activation contexts loaded

+

Click Harvest to generate contexts from training data

+
+ {/if} +
diff --git a/spd/app/frontend/src/components/ActivationContextsViewer.svelte b/spd/app/frontend/src/components/ActivationContextsViewer.svelte index 2b1144712..b9f7f1e6d 100644 --- a/spd/app/frontend/src/components/ActivationContextsViewer.svelte +++ b/spd/app/frontend/src/components/ActivationContextsViewer.svelte @@ -110,51 +110,51 @@ }); -
- - -
- -
- - - of {totalPages} - -
+
+
+
+ + +
+ + +
-{#if loadingComponent} -
Loading component data...
-{:else if currentComponent && currentMetadata} -
-

- Subcomponent {currentMetadata.subcomponent_idx} (Mean CI: {currentMetadata.mean_ci < 0.001 - ? currentMetadata.mean_ci.toExponential(2) - : currentMetadata.mean_ci.toFixed(3)}) -

- - - - {#if densities != null} -
-
-
- Tokens - {currentComponent.pr_tokens.length > N_TOKENS_TO_DISPLAY - ? `(top ${N_TOKENS_TO_DISPLAY} of ${currentComponent.pr_tokens.length})` - : ""} -
-
+ {#if loadingComponent} +
Loading component data...
+ {:else if currentComponent && currentMetadata} +
+

+ Subcomponent {currentMetadata.subcomponent_idx} (Mean CI: {currentMetadata.mean_ci < 0.001 + ? currentMetadata.mean_ci.toExponential(2) + : currentMetadata.mean_ci.toFixed(3)}) +

+ + {#if densities != null} +
+
+
+ Tokens + {currentComponent.pr_tokens.length > N_TOKENS_TO_DISPLAY + ? `(top ${N_TOKENS_TO_DISPLAY} of ${currentComponent.pr_tokens.length})` + : ""} +
-
-
- {#each densities as { token, recall, precision } (`${token}-${recall}-${precision}`)} - {@const value = metricMode === "recall" ? recall : precision} -
- {token} -
-
+
+ {#each densities as { token, recall, precision } (`${token}-${recall}-${precision}`)} + {@const value = metricMode === "recall" ? recall : precision} +
+ {token} +
+
+
+ {(value * 100).toFixed(1)}%
- {(value * 100).toFixed(1)}% -
- {/each} + {/each} +
-
- {/if} - - -
-{/if} + {/if} + + + + +
+ {/if} +
diff --git a/spd/app/frontend/src/components/LocalAttributionsTab.svelte b/spd/app/frontend/src/components/LocalAttributionsTab.svelte index c24345ecb..f3b26651a 100644 --- a/spd/app/frontend/src/components/LocalAttributionsTab.svelte +++ b/spd/app/frontend/src/components/LocalAttributionsTab.svelte @@ -194,17 +194,12 @@ // Load intervention runs for this graph const runs = await mainApi.getInterventionRuns(data.id); - // Initialize composer selection: from DB or default to all interventable nodes - const composerSelection = data.composerSelection - ? filterInterventableNodes(data.composerSelection) - : filterInterventableNodes(Object.keys(data.nodeImportance)); - return { id: `graph-${idx}-${Date.now()}`, dbId: data.id, label, data, - composerSelection, + composerSelection: filterInterventableNodes(Object.keys(data.nodeImportance)), interventionRuns: runs, activeRunId: null, }; @@ -289,10 +284,9 @@ } // Update composer selection for the active graph - async function handleComposerSelectionChange(selection: SvelteSet) { + function handleComposerSelectionChange(selection: SvelteSet) { if (!activeCard || !activeGraph) return; - // Update local state immediately promptCards = promptCards.map((card) => { if (card.id !== activeCard.id) return card; return { @@ -302,13 +296,6 @@ ), }; }); - - // Persist to backend (fire and forget with error handling) - try { - await mainApi.updateComposerSelection(activeGraph.dbId, Array.from(selection)); - } catch (e) { - console.error("Failed to save composer selection:", e); - } } // Run intervention and save to DB @@ -504,17 +491,12 @@ // Load intervention runs const runs = await mainApi.getInterventionRuns(data.id); - // Initialize composer selection (only interventable nodes) - const composerSelection = data.composerSelection - ? filterInterventableNodes(data.composerSelection) - : filterInterventableNodes(Object.keys(data.nodeImportance)); - return { id: `graph-${idx}-${Date.now()}`, dbId: data.id, label, data, - composerSelection, + composerSelection: filterInterventableNodes(Object.keys(data.nodeImportance)), interventionRuns: runs, activeRunId: null, }; diff --git a/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte b/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte index 8631f4bd3..2e0e62835 100644 --- a/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte +++ b/spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte @@ -262,6 +262,8 @@ } .examples-scroll-container { + overflow-x: auto; + overflow-y: clip; scrollbar-width: none; } diff --git a/spd/app/frontend/src/components/local-attr/InterventionsView.svelte b/spd/app/frontend/src/components/local-attr/InterventionsView.svelte index a402c801e..f3d16e8e9 100644 --- a/spd/app/frontend/src/components/local-attr/InterventionsView.svelte +++ b/spd/app/frontend/src/components/local-attr/InterventionsView.svelte @@ -10,7 +10,7 @@ } from "../../lib/localAttributionsTypes"; import { colors, getEdgeColor, getOutputHeaderColor } from "../../lib/colors"; import { lerp } from "./graphUtils"; - import ComponentDetailCard from "./ComponentDetailCard.svelte"; + import NodeTooltip from "./NodeTooltip.svelte"; // Layout constants const COMPONENT_SIZE = 8; @@ -688,33 +688,20 @@ {#if hoveredNode} - {@const summary = activationContextsSummary?.[hoveredNode.layer]?.find( - (s) => s.subcomponent_idx === hoveredNode?.cIdx, - )} - {@const detail = componentDetailsCache[`${hoveredNode.layer}:${hoveredNode.cIdx}`]} - {@const isLoading = componentDetailsLoading[`${hoveredNode.layer}:${hoveredNode.cIdx}`] ?? false} - -
(isHoveringTooltip = true)} - onmouseleave={() => { + (isHoveringTooltip = true)} + onMouseLeave={() => { isHoveringTooltip = false; handleNodeMouseLeave(); }} - > -

{hoveredNode.layer}:{hoveredNode.seqIdx}:{hoveredNode.cIdx}

- -
+ /> {/if}
@@ -1022,28 +1009,4 @@ font-size: 8px; color: var(--text-muted); } - - /* Tooltip */ - .node-tooltip { - position: fixed; - padding: var(--space-3); - background: var(--bg-elevated); - border: 1px solid var(--border-strong); - z-index: 1000; - pointer-events: auto; - font-family: var(--font-mono); - max-width: 400px; - max-height: 400px; - overflow-y: auto; - } - - .node-tooltip h3 { - margin: 0 0 var(--space-2) 0; - font-size: var(--text-base); - font-family: var(--font-mono); - color: var(--accent-primary); - font-weight: 600; - border-bottom: 1px solid var(--border-subtle); - padding-bottom: var(--space-2); - } diff --git a/spd/app/frontend/src/lib/api.ts b/spd/app/frontend/src/lib/api.ts index 066ccfd4e..da4332421 100644 --- a/spd/app/frontend/src/lib/api.ts +++ b/spd/app/frontend/src/lib/api.ts @@ -213,15 +213,3 @@ export async function deleteInterventionRun(runId: number): Promise { } } -/** Update composer selection state for a graph */ -export async function updateComposerSelection(graphId: number, selection: string[] | null): Promise { - const response = await fetch(`${API_URL}/api/intervention/composer`, { - method: "PUT", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ graph_id: graphId, selection }), - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to update composer selection"); - } -} diff --git a/spd/app/frontend/src/lib/localAttributionsTypes.ts b/spd/app/frontend/src/lib/localAttributionsTypes.ts index 38c8681ec..461a56746 100644 --- a/spd/app/frontend/src/lib/localAttributionsTypes.ts +++ b/spd/app/frontend/src/lib/localAttributionsTypes.ts @@ -28,7 +28,6 @@ export type GraphData = { nodeImportance: Record; // node key -> sum of squared edge values maxAbsAttr: number; // max absolute edge value optimization?: OptimizationResult; - composerSelection?: string[] | null; // node keys, null = never set (defaults to all interventable) }; export type OptimizationResult = { From 845d585696dc4c0b6f9b2709aa4b5d034723f4d9 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 9 Dec 2025 14:38:17 +0000 Subject: [PATCH 031/500] reintroduce node pinning --- .../components/LocalAttributionsTab.svelte | 30 +++++++- .../local-attr/StagedNodesPanel.svelte | 76 ++++++------------- 2 files changed, 49 insertions(+), 57 deletions(-) diff --git a/spd/app/frontend/src/components/LocalAttributionsTab.svelte b/spd/app/frontend/src/components/LocalAttributionsTab.svelte index f3b26651a..f7ba37ab1 100644 --- a/spd/app/frontend/src/components/LocalAttributionsTab.svelte +++ b/spd/app/frontend/src/components/LocalAttributionsTab.svelte @@ -7,6 +7,7 @@ type ActivationContextsSummary, type ComponentDetail, type GraphData, + type PinnedNode, type PromptPreview, } from "../lib/localAttributionsTypes"; import ComputeProgressOverlay from "./local-attr/ComputeProgressOverlay.svelte"; @@ -14,6 +15,7 @@ import PromptCardHeader from "./local-attr/PromptCardHeader.svelte"; import PromptCardTabs from "./local-attr/PromptCardTabs.svelte"; import PromptPicker from "./local-attr/PromptPicker.svelte"; + import StagedNodesPanel from "./local-attr/StagedNodesPanel.svelte"; import type { StoredGraph, ComputeOptions, LoadingState, OptimizeConfig, PromptCard } from "./local-attr/types"; import ViewControls from "./local-attr/ViewControls.svelte"; import LocalAttributionsGraph from "./LocalAttributionsGraph.svelte"; @@ -86,6 +88,9 @@ let componentDetailsCache = $state>({}); let componentDetailsLoading = $state>({}); + // Pinned nodes for attributions graph + let pinnedNodes = $state([]); + async function loadComponentDetail(layer: string, cIdx: number) { const cacheKey = `${layer}:${cIdx}`; if (componentDetailsCache[cacheKey] || componentDetailsLoading[cacheKey]) return; @@ -101,6 +106,16 @@ } } + function handlePinnedNodesChange(nodes: PinnedNode[]) { + pinnedNodes = nodes; + // Load component details for any newly pinned nodes + for (const node of nodes) { + if (node.layer !== "wte" && node.layer !== "output") { + loadComponentDetail(node.layer, node.cIdx); + } + } + } + // Tokenize label text when it changes let labelTokenizeTimeout: ReturnType | null = null; $effect(() => { @@ -563,7 +578,7 @@ {}} + onStagedNodesChange={handlePinnedNodesChange} onLoadComponentDetail={loadComponentDetail} onEdgeCountChange={(count) => (filteredEdgeCount = count)} /> {/key} +
{:else} diff --git a/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte b/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte index 3e9e47c0c..afc611d2b 100644 --- a/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte +++ b/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte @@ -1,25 +1,25 @@ {#if stagedNodes.length > 0}
- Staged Nodes ({stagedNodes.length}) -
- {#if hasInvalidNodes} - Remove wte/output nodes to run - {/if} - - -
+ Pinned Components ({stagedNodes.length}) +
{#each stagedNodes as node, idx (`${node.layer}:${node.seqIdx}:${node.cIdx}-${idx}`)} - {@const detail = componentDetailsCache[`${node.layer}:${node.cIdx}`]} - {@const isLoading = !detail && node.layer !== "output"} + {@const cacheKey = `${node.layer}:${node.cIdx}`} + {@const detail = componentDetailsCache[cacheKey]} + {@const isLoading = componentDetailsLoading[cacheKey] ?? false} + {@const summary = activationContextsSummary?.[node.layer]?.find((s) => s.subcomponent_idx === node.cIdx)} {@const token = getTokenAtPosition(node.seqIdx)}
@@ -74,9 +64,11 @@
@@ -87,9 +79,9 @@ From 65e09a936c6a08b9463caea55f4650335187dc37 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 9 Dec 2025 15:07:47 +0000 Subject: [PATCH 032/500] tidy up component hover window handling --- .../ActivationContextsPagedTable.svelte | 5 - .../src/components/ComponentProbeInput.svelte | 3 +- .../local-attr/ComponentDetailCard.svelte | 332 ------------------ .../local-attr/ComponentNodeCard.svelte | 234 ++++++++++++ .../components/local-attr/NodeTooltip.svelte | 148 ++++++++ .../local-attr/OutputNodeCard.svelte | 128 +++++++ .../local-attr/StagedNodesPanel.svelte | 72 +++- 7 files changed, 564 insertions(+), 358 deletions(-) delete mode 100644 spd/app/frontend/src/components/local-attr/ComponentDetailCard.svelte create mode 100644 spd/app/frontend/src/components/local-attr/ComponentNodeCard.svelte create mode 100644 spd/app/frontend/src/components/local-attr/NodeTooltip.svelte create mode 100644 spd/app/frontend/src/components/local-attr/OutputNodeCard.svelte diff --git a/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte b/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte index bab07c06b..29c018a1b 100644 --- a/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte +++ b/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte @@ -255,10 +255,5 @@ line-height: 1.8; color: var(--text-primary); white-space: nowrap; - border-bottom: 1px solid var(--border-subtle); - } - - .example-item:last-child { - border-bottom: none; } diff --git a/spd/app/frontend/src/components/ComponentProbeInput.svelte b/spd/app/frontend/src/components/ComponentProbeInput.svelte index c96e1a21d..df71e670e 100644 --- a/spd/app/frontend/src/components/ComponentProbeInput.svelte +++ b/spd/app/frontend/src/components/ComponentProbeInput.svelte @@ -85,10 +85,9 @@ diff --git a/spd/app/frontend/src/components/local-attr/ComponentNodeCard.svelte b/spd/app/frontend/src/components/local-attr/ComponentNodeCard.svelte new file mode 100644 index 000000000..875028bc0 --- /dev/null +++ b/spd/app/frontend/src/components/local-attr/ComponentNodeCard.svelte @@ -0,0 +1,234 @@ + + +
+

+ {#if seqIdx !== undefined} + Position: + {seqIdx} + {/if} + {#if summary} + {#if seqIdx !== undefined}|{/if} + Mean CI: + {summary.mean_ci.toFixed(4)} + {/if} +

+ + + + {#if detail} + {#if detail.example_tokens.length > 0} + {#if showFullTable} + +
+

Activating Examples ({detail.example_tokens.length})

+ {#if compact} + + {/if} +
+ + {:else} + +

Top Activating Examples

+
+ {#each detail.example_tokens.slice(0, COMPACT_MAX_EXAMPLES) as tokens, i (i)} +
+ +
+ {/each} +
+ {#if detail.example_tokens.length > COMPACT_MAX_EXAMPLES} + + {/if} + {/if} + {/if} + +
+ {#if tokenPrecisionsSorted.length > 0} +
+

+ Token Precision + P(firing | token) +

+ + + {#each tokenPrecisionsSorted.slice(0, 10) as { token, precision } (token)} + + + + + {/each} + +
{token}{precision.toFixed(3)}
+
+ {/if} +
+ {:else if isLoading} +

Loading details...

+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/local-attr/NodeTooltip.svelte b/spd/app/frontend/src/components/local-attr/NodeTooltip.svelte new file mode 100644 index 000000000..591213e49 --- /dev/null +++ b/spd/app/frontend/src/components/local-attr/NodeTooltip.svelte @@ -0,0 +1,148 @@ + + + +
+

{hoveredNode.layer}:{hoveredNode.seqIdx}:{hoveredNode.cIdx}

+ {#if isWte} +
+
"{inputToken}"
+

+ Position: {hoveredNode.seqIdx} +

+
+ {:else if isOutput} + + {:else} + {@const cacheKey = `${hoveredNode.layer}:${hoveredNode.cIdx}`} + {@const detail = componentDetailsCache[cacheKey] ?? null} + {@const isLoading = componentDetailsLoading[cacheKey] ?? false} + {@const summary = findComponentSummary(hoveredNode.layer, hoveredNode.cIdx)} + {#if detail} + + {:else} + + {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/local-attr/OutputNodeCard.svelte b/spd/app/frontend/src/components/local-attr/OutputNodeCard.svelte new file mode 100644 index 000000000..bee47f226 --- /dev/null +++ b/spd/app/frontend/src/components/local-attr/OutputNodeCard.svelte @@ -0,0 +1,128 @@ + + +
+ {#if singlePosEntry} +
+
"{escapeHtml(singlePosEntry.token)}"
+
{(singlePosEntry.prob * 100).toFixed(1)}% probability
+
+

+ Position: + {seqIdx} | + Vocab ID: + {cIdx} +

+ {:else if allPositions} +

"{allPositions[0].token}"

+ + + {#each allPositions as pos (pos.seqIdx)} + + + + + {/each} + +
Pos {pos.seqIdx}{(pos.prob * 100).toFixed(2)}%
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte b/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte index afc611d2b..bda4c49c0 100644 --- a/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte +++ b/spd/app/frontend/src/components/local-attr/StagedNodesPanel.svelte @@ -1,6 +1,7 @@ @@ -47,11 +56,9 @@
{#each stagedNodes as node, idx (`${node.layer}:${node.seqIdx}:${node.cIdx}-${idx}`)} - {@const cacheKey = `${node.layer}:${node.cIdx}`} - {@const detail = componentDetailsCache[cacheKey]} - {@const isLoading = componentDetailsLoading[cacheKey] ?? false} - {@const summary = activationContextsSummary?.[node.layer]?.find((s) => s.subcomponent_idx === node.cIdx)} {@const token = getTokenAtPosition(node.seqIdx)} + {@const isOutput = node.layer === "output"} + {@const isWte = node.layer === "wte"}
@@ -61,16 +68,36 @@
- + {#if isWte} +

Input embedding at position {node.seqIdx}

+ {:else if isOutput} + + {:else} + {@const cacheKey = `${node.layer}:${node.cIdx}`} + {@const detail = componentDetailsCache[cacheKey] ?? null} + {@const isLoading = componentDetailsLoading[cacheKey] ?? false} + {@const summary = findComponentSummary(node.layer, node.cIdx)} + {#if detail} + + {:else} + + {/if} + {/if}
{/each}
@@ -164,4 +191,11 @@ color: var(--text-primary); border-color: var(--border-strong); } + + .wte-info { + margin: var(--space-2) 0 0 0; + font-size: var(--text-sm); + font-family: var(--font-mono); + color: var(--text-secondary); + } From 6eb5000f97c01802c5807b5ce61c8d34e6d84d15 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 9 Dec 2025 15:12:12 +0000 Subject: [PATCH 033/500] wip: Make seqIdx required in ComponentNodeCard props --- .../src/components/local-attr/ComponentNodeCard.svelte | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/spd/app/frontend/src/components/local-attr/ComponentNodeCard.svelte b/spd/app/frontend/src/components/local-attr/ComponentNodeCard.svelte index 875028bc0..4198bbf6b 100644 --- a/spd/app/frontend/src/components/local-attr/ComponentNodeCard.svelte +++ b/spd/app/frontend/src/components/local-attr/ComponentNodeCard.svelte @@ -7,7 +7,7 @@ type Props = { layer: string; cIdx: number; - seqIdx?: number; + seqIdx: number; summary: ComponentSummary | null; compact: boolean; } & ( @@ -36,10 +36,8 @@

- {#if seqIdx !== undefined} - Position: - {seqIdx} - {/if} + Position: + {seqIdx} {#if summary} {#if seqIdx !== undefined}|{/if} Mean CI: From 2c7efa36d021fa1cebb82019fd215e8ed1ac2833 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 10 Dec 2025 12:11:59 +0000 Subject: [PATCH 034/500] Do post-hoc ci-threshold filtering (#291) * merge origin/feature/app-fixes-1 to local * Simplify * Fix tests * Fix lint * Fix bug * Add test for ci filtering * Fix ci threshold defaults * Fix >= to > ci_threshold * Grey out when loading * Fix lint * fix lint * Fix lint * Use SvelteSet instead of Set everywhere * Remove unintended 'if ci_threshold <= 0' * Raise sql IntegrityError instead of ignoring conflicts * Apply ci_threshold to optimized only when desired * Minor pr fixes * Fix lint * Remove unused 'Max CI' * Simplify db * Remove l0s from db and fix graph structure * Make ViewSettings which is specific to a graph * Address PR comments --- spd/app/backend/compute.py | 83 ++++---- spd/app/backend/db/database.py | 170 +++++++-------- spd/app/backend/optim_cis/run_optim_cis.py | 19 +- spd/app/backend/routers/graphs.py | 137 ++++++++++--- .../components/LocalAttributionsTab.svelte | 194 ++++++++++++------ .../local-attr/ComponentNodeCard.svelte | 5 +- .../components/local-attr/NodeTooltip.svelte | 14 +- .../local-attr/PromptCardHeader.svelte | 20 +- .../local-attr/StagedNodesPanel.svelte | 8 +- .../components/local-attr/ViewControls.svelte | 50 +++++ .../src/components/local-attr/types.ts | 16 +- spd/app/frontend/src/lib/api.ts | 1 - .../frontend/src/lib/localAttributionsApi.ts | 7 +- .../src/lib/localAttributionsTypes.ts | 6 +- tests/app/test_server_api.py | 111 +++++++++- 15 files changed, 567 insertions(+), 274 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 91b79d11a..141caec5b 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -33,12 +33,15 @@ def compute_layer_alive_info( layer_name: str, ci_lower_leaky: dict[str, Tensor], output_probs: Float[Tensor, "1 seq vocab"] | None, - ci_threshold: float, output_prob_threshold: float, n_seq: int, device: str, ) -> LayerAliveInfo: - """Compute alive info for a layer. Handles regular, wte, and output layers.""" + """Compute alive info for a layer. Handles regular, wte, and output layers. + + For CI layers, all components with CI > 0 are considered alive. + Filtering by CI threshold is done at display time, not computation time. + """ if layer_name == "wte": # WTE: single pseudo-component, always alive at all positions alive_mask = torch.ones(n_seq, 1, device=device, dtype=torch.bool) @@ -46,12 +49,12 @@ def compute_layer_alive_info( elif layer_name == "output": assert output_probs is not None assert output_probs.shape[0] == 1 - alive_mask = output_probs[0] >= output_prob_threshold + alive_mask = output_probs[0] > output_prob_threshold alive_c_idxs = torch.where(alive_mask.any(dim=0))[0].tolist() else: ci = ci_lower_leaky[layer_name] assert ci.shape[0] == 1 - alive_mask = ci[0] >= ci_threshold + alive_mask = ci[0] > 0.0 alive_c_idxs = torch.where(alive_mask.any(dim=0))[0].tolist() return LayerAliveInfo(alive_mask, alive_c_idxs) @@ -84,15 +87,7 @@ class LocalAttributionResult: edges: list[Edge] output_probs: Float[Tensor, "seq vocab"] # Softmax probabilities for output logits - - -@dataclass -class OptimizationStats: - """Statistics from CI optimization.""" - - label_prob: float # P(label_token) with optimized CI mask - l0_total: float # Total L0 across all layers - l0_per_layer: dict[str, float] # L0 per layer + node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val @dataclass @@ -101,7 +96,8 @@ class OptimizedLocalAttributionResult: edges: list[Edge] output_probs: Float[Tensor, "seq vocab"] - stats: OptimizationStats + label_prob: float # P(label_token) with optimized CI mask + node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val def is_kv_to_o_pair(in_layer: str, out_layer: str) -> bool: @@ -226,7 +222,6 @@ def compute_edges_from_ci( tokens: Float[Tensor, "1 seq"], ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], sources_by_target: dict[str, list[str]], - ci_threshold: float, output_prob_threshold: float, device: str, show_progress: bool, @@ -235,7 +230,8 @@ def compute_edges_from_ci( """Core edge computation from pre-computed CI values. Computes gradient-based attribution edges between components using the - provided CI values for masking and thresholding. + provided CI values for masking. All components with CI > 0 are included; + filtering by CI threshold is done at display time. Use compute_local_attributions() for automatic CI computation, or compute_local_attributions_optimized() for optimized sparse CI values. @@ -283,7 +279,12 @@ def wte_hook( alive_info: dict[str, LayerAliveInfo] = {} for layer in all_layers: alive_info[layer] = compute_layer_alive_info( - layer, ci_lower_leaky, output_probs, ci_threshold, output_prob_threshold, n_seq, device + layer_name=layer, + ci_lower_leaky=ci_lower_leaky, + output_probs=output_probs, + output_prob_threshold=output_prob_threshold, + n_seq=n_seq, + device=device, ) edges: list[Edge] = [] @@ -365,14 +366,14 @@ def wte_hook( if pbar is not None: pbar.close() - return LocalAttributionResult(edges=edges, output_probs=output_probs) + node_ci_vals = extract_node_ci_vals(ci_lower_leaky) + return LocalAttributionResult(edges=edges, output_probs=output_probs, node_ci_vals=node_ci_vals) def compute_local_attributions( model: ComponentModel, tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], - ci_threshold: float, output_prob_threshold: float, sampling: SamplingType, device: str, @@ -397,7 +398,6 @@ def compute_local_attributions( tokens=tokens, ci_lower_leaky=ci.lower_leaky, sources_by_target=sources_by_target, - ci_threshold=ci_threshold, output_prob_threshold=output_prob_threshold, device=device, show_progress=show_progress, @@ -411,7 +411,6 @@ def compute_local_attributions_optimized( label_token: int, sources_by_target: dict[str, list[str]], optim_config: OptimCIConfig, - ci_threshold: float, output_prob_threshold: float, device: str, show_progress: bool, @@ -421,6 +420,9 @@ def compute_local_attributions_optimized( Runs CI optimization to find a minimal sparse mask that preserves the model's prediction of label_token, then computes edges. + + L0 stats are computed dynamically at display time from node_ci_vals, + not here at computation time. """ ci_params = optimize_ci_values( model=model, @@ -432,13 +434,6 @@ def compute_local_attributions_optimized( ) ci_outputs = ci_params.create_ci_outputs(model, device) - # Compute optimization stats - l0_per_layer: dict[str, float] = {} - for layer_name, ci_tensor in ci_outputs.lower_leaky.items(): - # L0 = count of components with CI > threshold, averaged over sequence - l0_per_layer[layer_name] = float((ci_tensor > ci_threshold).float().sum().item()) - l0_total = sum(l0_per_layer.values()) - # Get label probability with optimized CI mask with torch.no_grad(): mask_infos = make_mask_infos(ci_outputs.lower_leaky, routing_masks="all") @@ -446,12 +441,6 @@ def compute_local_attributions_optimized( probs = torch.softmax(logits[0, -1, :], dim=-1) label_prob = float(probs[label_token].item()) - stats = OptimizationStats( - label_prob=label_prob, - l0_total=l0_total, - l0_per_layer=l0_per_layer, - ) - # Signal transition to graph computation stage if on_progress is not None: on_progress(0, 1, "graph") @@ -461,7 +450,6 @@ def compute_local_attributions_optimized( tokens=tokens, ci_lower_leaky=ci_outputs.lower_leaky, sources_by_target=sources_by_target, - ci_threshold=ci_threshold, output_prob_threshold=output_prob_threshold, device=device, show_progress=show_progress, @@ -471,7 +459,8 @@ def compute_local_attributions_optimized( return OptimizedLocalAttributionResult( edges=result.edges, output_probs=result.output_probs, - stats=stats, + label_prob=label_prob, + node_ci_vals=result.node_ci_vals, ) @@ -513,6 +502,28 @@ def compute_ci_only( return CIOnlyResult(ci_lower_leaky=ci.lower_leaky, output_probs=output_probs) +def extract_node_ci_vals( + ci_lower_leaky: dict[str, Float[Tensor, "1 seq n_components"]], +) -> dict[str, float]: + """Extract per-node CI values from CI tensors. + + Args: + ci_lower_leaky: Dict mapping layer name to CI tensor [1, seq, n_components]. + + Returns: + Dict mapping "layer:seq:c_idx" to CI value. + """ + node_ci_vals: dict[str, float] = {} + for layer_name, ci_tensor in ci_lower_leaky.items(): + n_seq = ci_tensor.shape[1] + n_components = ci_tensor.shape[2] + for seq_pos in range(n_seq): + for c_idx in range(n_components): + key = f"{layer_name}:{seq_pos}:{c_idx}" + node_ci_vals[key] = float(ci_tensor[0, seq_pos, c_idx].item()) + return node_ci_vals + + def extract_active_from_ci( ci_lower_leaky: dict[str, Float[Tensor, "1 seq n_components"]], output_probs: Float[Tensor, "1 seq vocab"], diff --git a/spd/app/backend/db/database.py b/spd/app/backend/db/database.py index 3647ef065..ceea35c21 100644 --- a/spd/app/backend/db/database.py +++ b/spd/app/backend/db/database.py @@ -49,14 +49,6 @@ class OptimizationParams(BaseModel): pnorm: float -class OptimizationStats(BaseModel): - """Statistics from optimized graph computation.""" - - label_prob: float - l0_total: float - l0_per_layer: dict[str, float] - - class StoredGraph(BaseModel): """A stored attribution graph.""" @@ -65,15 +57,18 @@ class StoredGraph(BaseModel): id: int = -1 # -1 for unsaved graphs, set by DB on save edges: list[Edge] output_probs: dict[str, OutputProbability] + node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val (required for all graphs) optimization_params: OptimizationParams | None = None - optimization_stats: OptimizationStats | None = None + label_prob: float | None = ( + None # P(label_token) with optimized CI mask, only for optimized graphs + ) class InterventionRunRecord(BaseModel): """A stored intervention run.""" id: int - cached_graph_id: int + graph_id: int selected_nodes: list[str] # node keys that were selected result_json: str # JSON-encoded InterventionResponse created_at: str @@ -86,7 +81,8 @@ class LocalAttrDB: - runs: One row per SPD run (keyed by wandb_path) - activation_contexts: Component metadata + generation config, 1:1 with runs - prompts: One row per stored prompt (token sequence), keyed by run_id - - component_activations: Inverted index mapping components to prompts + - original_component_seq_max_activations: Inverted index mapping components to prompts by a + component's max activation for that prompt Attribution graphs (edges) are computed on-demand at serve time, not stored. """ @@ -146,7 +142,7 @@ def init_schema(self) -> None: is_custom INTEGER NOT NULL DEFAULT 0 ); - CREATE TABLE IF NOT EXISTS component_activations ( + CREATE TABLE IF NOT EXISTS original_component_seq_max_activations ( prompt_id INTEGER NOT NULL REFERENCES prompts(id), component_key TEXT NOT NULL, max_ci REAL NOT NULL, @@ -156,11 +152,11 @@ def init_schema(self) -> None: CREATE INDEX IF NOT EXISTS idx_prompts_run_id ON prompts(run_id); CREATE INDEX IF NOT EXISTS idx_component_key - ON component_activations(component_key); + ON original_component_seq_max_activations(component_key); CREATE INDEX IF NOT EXISTS idx_prompt_id - ON component_activations(prompt_id); + ON original_component_seq_max_activations(prompt_id); - CREATE TABLE IF NOT EXISTS cached_graphs ( + CREATE TABLE IF NOT EXISTS graphs ( id INTEGER PRIMARY KEY AUTOINCREMENT, prompt_id INTEGER NOT NULL REFERENCES prompts(id), is_optimized INTEGER NOT NULL, @@ -176,35 +172,36 @@ def init_schema(self) -> None: edges_data BLOB NOT NULL, output_probs_data BLOB NOT NULL, + -- Node CI values: "layer:seq:c_idx" -> ci_val (required for all graphs) + node_ci_vals TEXT NOT NULL, + -- Optimization stats (NULL for standard graphs) label_prob REAL, - l0_total REAL, - l0_per_layer TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); - CREATE UNIQUE INDEX IF NOT EXISTS idx_cached_graphs_standard - ON cached_graphs(prompt_id) + CREATE UNIQUE INDEX IF NOT EXISTS idx_graphs_standard + ON graphs(prompt_id) WHERE is_optimized = 0; - CREATE UNIQUE INDEX IF NOT EXISTS idx_cached_graphs_optimized - ON cached_graphs(prompt_id, label_token, imp_min_coeff, ce_loss_coeff, steps, pnorm) + CREATE UNIQUE INDEX IF NOT EXISTS idx_graphs_optimized + ON graphs(prompt_id, label_token, imp_min_coeff, ce_loss_coeff, steps, pnorm) WHERE is_optimized = 1; - CREATE INDEX IF NOT EXISTS idx_cached_graphs_prompt - ON cached_graphs(prompt_id); + CREATE INDEX IF NOT EXISTS idx_graphs_prompt + ON graphs(prompt_id); CREATE TABLE IF NOT EXISTS intervention_runs ( id INTEGER PRIMARY KEY AUTOINCREMENT, - cached_graph_id INTEGER NOT NULL REFERENCES cached_graphs(id), + graph_id INTEGER NOT NULL REFERENCES graphs(id), selected_nodes TEXT NOT NULL, -- JSON array of node keys result TEXT NOT NULL, -- JSON InterventionResponse created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_intervention_runs_graph - ON intervention_runs(cached_graph_id); + ON intervention_runs(graph_id); """) conn.commit() @@ -337,7 +334,7 @@ def add_prompts( if component_rows: conn.executemany( - """INSERT INTO component_activations + """INSERT INTO original_component_seq_max_activations (prompt_id, component_key, max_ci, positions) VALUES (?, ?, ?, ?)""", component_rows, ) @@ -377,7 +374,7 @@ def add_custom_prompt( ] if component_rows: conn.executemany( - """INSERT INTO component_activations + """INSERT INTO original_component_seq_max_activations (prompt_id, component_key, max_ci, positions) VALUES (?, ?, ?, ?)""", component_rows, ) @@ -453,7 +450,7 @@ def find_prompts_with_components( if require_all: query = f""" SELECT ca.prompt_id - FROM component_activations ca + FROM original_component_seq_max_activations ca JOIN prompts p ON ca.prompt_id = p.id WHERE p.run_id = ? AND ca.component_key IN ({placeholders}) GROUP BY ca.prompt_id @@ -463,7 +460,7 @@ def find_prompts_with_components( else: query = f""" SELECT DISTINCT ca.prompt_id - FROM component_activations ca + FROM original_component_seq_max_activations ca JOIN prompts p ON ca.prompt_id = p.id WHERE p.run_id = ? AND ca.component_key IN ({placeholders}) """ @@ -472,19 +469,22 @@ def find_prompts_with_components( return [row["prompt_id"] for row in rows] # ------------------------------------------------------------------------- - # Cached graph operations + # Graph operations # ------------------------------------------------------------------------- def save_graph( self, prompt_id: int, graph: StoredGraph, - ) -> None: + ) -> int: """Save a computed graph for a prompt. Args: prompt_id: The prompt ID. graph: The graph to save. + + Returns: + The database ID of the saved graph. """ conn = self._get_conn() @@ -495,48 +495,57 @@ def save_graph( probs_json = json.dumps({k: v.model_dump() for k, v in graph.output_probs.items()}) probs_compressed = gzip.compress(probs_json.encode("utf-8")) + node_ci_vals_json = json.dumps(graph.node_ci_vals) is_optimized = 1 if graph.optimization_params else 0 + # Extract optimization-specific values (NULL for standard graphs) + label_token = None + imp_min_coeff = None + ce_loss_coeff = None + steps = None + pnorm = None + label_prob = None + if graph.optimization_params: - assert graph.optimization_stats is not None, ( - "optimization_stats required for optimized graphs" - ) - conn.execute( - """INSERT INTO cached_graphs + assert graph.label_prob is not None, "label_prob required for optimized graphs" + label_token = graph.optimization_params.label_token + imp_min_coeff = graph.optimization_params.imp_min_coeff + ce_loss_coeff = graph.optimization_params.ce_loss_coeff + steps = graph.optimization_params.steps + pnorm = graph.optimization_params.pnorm + label_prob = graph.label_prob + + try: + cursor = conn.execute( + """INSERT INTO graphs (prompt_id, is_optimized, label_token, imp_min_coeff, ce_loss_coeff, steps, pnorm, - edges_data, output_probs_data, - label_prob, l0_total, l0_per_layer) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - prompt_id, - is_optimized, - graph.optimization_params.label_token, - graph.optimization_params.imp_min_coeff, - graph.optimization_params.ce_loss_coeff, - graph.optimization_params.steps, - graph.optimization_params.pnorm, - edges_compressed, - probs_compressed, - graph.optimization_stats.label_prob, - graph.optimization_stats.l0_total, - json.dumps(graph.optimization_stats.l0_per_layer), - ), - ) - else: - conn.execute( - """INSERT INTO cached_graphs - (prompt_id, is_optimized, edges_data, output_probs_data) - VALUES (?, ?, ?, ?)""", + edges_data, output_probs_data, node_ci_vals, + label_prob) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( prompt_id, is_optimized, + label_token, + imp_min_coeff, + ce_loss_coeff, + steps, + pnorm, edges_compressed, probs_compressed, + node_ci_vals_json, + label_prob, ), ) - - conn.commit() + conn.commit() + graph_id = cursor.lastrowid + assert graph_id is not None + return graph_id + except sqlite3.IntegrityError as e: + raise ValueError( + f"Graph already exists for prompt_id={prompt_id}. " + "Use get_graphs() to retrieve existing graph or delete it first." + ) from e def get_graphs(self, prompt_id: int) -> list[StoredGraph]: """Retrieve all stored graphs for a prompt. @@ -550,10 +559,10 @@ def get_graphs(self, prompt_id: int) -> list[StoredGraph]: conn = self._get_conn() rows = conn.execute( - """SELECT id, is_optimized, edges_data, output_probs_data, + """SELECT id, is_optimized, edges_data, output_probs_data, node_ci_vals, label_token, imp_min_coeff, ce_loss_coeff, steps, pnorm, - label_prob, l0_total, l0_per_layer - FROM cached_graphs + label_prob + FROM graphs WHERE prompt_id = ? ORDER BY is_optimized, created_at""", (prompt_id,), @@ -575,8 +584,10 @@ def _edge_from_dict(d: dict[str, Any]) -> Edge: probs_json = json.loads(gzip.decompress(row["output_probs_data"]).decode("utf-8")) output_probs = {k: OutputProbability(**v) for k, v in probs_json.items()} + node_ci_vals: dict[str, float] = json.loads(row["node_ci_vals"]) + opt_params: OptimizationParams | None = None - opt_stats: OptimizationStats | None = None + label_prob: float | None = None if row["is_optimized"]: opt_params = OptimizationParams( @@ -586,19 +597,16 @@ def _edge_from_dict(d: dict[str, Any]) -> Edge: steps=row["steps"], pnorm=row["pnorm"], ) - opt_stats = OptimizationStats( - label_prob=row["label_prob"], - l0_total=row["l0_total"], - l0_per_layer=json.loads(row["l0_per_layer"]), - ) + label_prob = row["label_prob"] results.append( StoredGraph( id=row["id"], edges=edges, output_probs=output_probs, + node_ci_vals=node_ci_vals, optimization_params=opt_params, - optimization_stats=opt_stats, + label_prob=label_prob, ) ) @@ -607,7 +615,7 @@ def _edge_from_dict(d: dict[str, Any]) -> Edge: def delete_graphs_for_prompt(self, prompt_id: int) -> int: """Delete all graphs for a prompt. Returns the number of deleted rows.""" conn = self._get_conn() - cursor = conn.execute("DELETE FROM cached_graphs WHERE prompt_id = ?", (prompt_id,)) + cursor = conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) conn.commit() return cursor.rowcount @@ -615,7 +623,7 @@ def delete_graphs_for_run(self, run_id: int) -> int: """Delete all graphs for all prompts in a run. Returns the number of deleted rows.""" conn = self._get_conn() cursor = conn.execute( - """DELETE FROM cached_graphs + """DELETE FROM graphs WHERE prompt_id IN (SELECT id FROM prompts WHERE run_id = ?)""", (run_id,), ) @@ -635,7 +643,7 @@ def save_intervention_run( """Save an intervention run. Args: - graph_id: The cached graph ID this run belongs to. + graph_id: The graph ID this run belongs to. selected_nodes: List of node keys that were selected. result_json: JSON-encoded InterventionResponse. @@ -644,7 +652,7 @@ def save_intervention_run( """ conn = self._get_conn() cursor = conn.execute( - """INSERT INTO intervention_runs (cached_graph_id, selected_nodes, result) + """INSERT INTO intervention_runs (graph_id, selected_nodes, result) VALUES (?, ?, ?)""", (graph_id, json.dumps(selected_nodes), result_json), ) @@ -657,16 +665,16 @@ def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: """Get all intervention runs for a graph. Args: - graph_id: The cached graph ID. + graph_id: The graph ID. Returns: List of intervention run records, ordered by creation time. """ conn = self._get_conn() rows = conn.execute( - """SELECT id, cached_graph_id, selected_nodes, result, created_at + """SELECT id, graph_id, selected_nodes, result, created_at FROM intervention_runs - WHERE cached_graph_id = ? + WHERE graph_id = ? ORDER BY created_at""", (graph_id,), ).fetchall() @@ -674,7 +682,7 @@ def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: return [ InterventionRunRecord( id=row["id"], - cached_graph_id=row["cached_graph_id"], + graph_id=row["graph_id"], selected_nodes=json.loads(row["selected_nodes"]), result_json=row["result"], created_at=row["created_at"], @@ -691,8 +699,6 @@ def delete_intervention_run(self, run_id: int) -> None: def delete_intervention_runs_for_graph(self, graph_id: int) -> int: """Delete all intervention runs for a graph. Returns count deleted.""" conn = self._get_conn() - cursor = conn.execute( - "DELETE FROM intervention_runs WHERE cached_graph_id = ?", (graph_id,) - ) + cursor = conn.execute("DELETE FROM intervention_runs WHERE graph_id = ?", (graph_id,)) conn.commit() return cursor.rowcount diff --git a/spd/app/backend/optim_cis/run_optim_cis.py b/spd/app/backend/optim_cis/run_optim_cis.py index dc21136d9..fe362b9cb 100644 --- a/spd/app/backend/optim_cis/run_optim_cis.py +++ b/spd/app/backend/optim_cis/run_optim_cis.py @@ -32,14 +32,13 @@ class AliveComponentInfo: def compute_alive_info( ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], - ci_threshold: float, ) -> AliveComponentInfo: - """Compute which (position, component) pairs are alive based on initial CI values.""" + """Compute which (position, component) pairs are alive (CI > 0).""" alive_masks: dict[str, Bool[Tensor, "1 seq C"]] = {} alive_counts: dict[str, list[int]] = {} for layer_name, ci in ci_lower_leaky.items(): - mask = ci > ci_threshold + mask = ci > 0.0 alive_masks[layer_name] = mask # Count alive components per position: mask is [1, seq, C], sum over C counts_per_pos = mask[0].sum(dim=-1) # [seq] @@ -227,8 +226,6 @@ class OptimCIConfig: imp_min_config: ImportanceMinimalityLossConfig ce_loss_coeff: float - # CI thresholds and sampling - ci_threshold: float sampling: SamplingType ce_kl_rounding_threshold: float @@ -274,20 +271,12 @@ def optimize_ci_values( target_out = output_with_cache.output.detach() # Compute alive info and create optimizable parameters - alive_info = compute_alive_info(initial_ci_outputs.lower_leaky, config.ci_threshold) + alive_info = compute_alive_info(initial_ci_outputs.lower_leaky) ci_params = create_optimizable_ci_params( alive_info=alive_info, initial_pre_sigmoid=initial_ci_outputs.pre_sigmoid, ) - # Log initial alive counts - total_alive = sum(sum(counts) for counts in alive_info.alive_counts.values()) - print(f"\nAlive components (CI > {config.ci_threshold}):") - for layer_name, counts in alive_info.alive_counts.items(): - layer_total = sum(counts) - print(f" {layer_name}: {layer_total} total across {len(counts)} positions") - print(f" Total: {total_alive}") - weight_deltas = model.calc_weight_deltas() params = ci_params.get_parameters() @@ -328,7 +317,7 @@ def optimize_ci_values( label_prob = F.softmax(out[0, -1, :], dim=-1)[label_token] if step % config.log_freq == 0 or step == config.steps - 1: - l0_stats = compute_l0_stats(ci_outputs, config.ci_threshold) + l0_stats = compute_l0_stats(ci_outputs, ci_alive_threshold=0.0) # Compute CE/KL metrics for final token only with torch.no_grad(): diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 18626b7f1..427624a1f 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -21,7 +21,6 @@ ) from spd.app.backend.db.database import ( OptimizationParams, - OptimizationStats, StoredGraph, ) from spd.app.backend.dependencies import DepLoadedRun, DepStateManager @@ -63,6 +62,52 @@ def tokenize_text(text: str, loaded: DepLoadedRun) -> TokenizeResponse: NormalizeType = Literal["none", "target", "layer"] +def filter_edges_by_ci_threshold( + edges: list[Edge], + ci_threshold: float, + node_ci_vals: dict[str, float], +) -> list[Edge]: + """Filter edges by removing those where source or target has CI < ci_threshold. + + Args: + edges: List of edges to filter + ci_threshold: Threshold for filtering + node_ci_vals: CI values per node (layer:seq:c_idx -> ci_val) + + Returns: + Filtered list of edges + """ + return [ + edge + for edge in edges + if node_ci_vals.get(str(edge.source), 0.0) >= ci_threshold + and node_ci_vals.get(str(edge.target), 0.0) >= ci_threshold + ] + + +def compute_l0_from_node_ci_vals( + node_ci_vals: dict[str, float], + ci_threshold: float, +) -> tuple[float, dict[str, float]]: + """Compute L0 stats dynamically from node CI values. + + Args: + node_ci_vals: CI values per node (layer:seq:c_idx -> ci_val) + ci_threshold: Threshold for counting a component as active + + Returns: + (l0_total, l0_per_layer) where l0_per_layer maps layer name to count + """ + l0_per_layer: dict[str, float] = {} + for key, ci_val in node_ci_vals.items(): + if ci_val > ci_threshold: + # Key format: "layer:seq:c_idx" - extract layer name + layer = key.rsplit(":", 2)[0] + l0_per_layer[layer] = l0_per_layer.get(layer, 0.0) + 1.0 + l0_total = sum(l0_per_layer.values()) + return l0_total, l0_per_layer + + def compute_edge_stats(edges: list[Edge]) -> tuple[dict[str, float], float]: """Compute node importance and max absolute edge value. @@ -90,9 +135,9 @@ def compute_graph_stream( normalize: Annotated[NormalizeType, Query()], loaded: DepLoadedRun, manager: DepStateManager, + ci_threshold: Annotated[float, Query()], ): """Compute attribution graph for a prompt with streaming progress.""" - ci_threshold = 1e-6 output_prob_threshold = 0.01 db = manager.db @@ -115,7 +160,6 @@ def compute_thread() -> None: model=loaded.model, tokens=tokens_tensor, sources_by_target=loaded.sources_by_target, - ci_threshold=ci_threshold, output_prob_threshold=output_prob_threshold, sampling=loaded.config.sampling, device=DEVICE, @@ -161,22 +205,28 @@ def generate() -> Generator[str]: token=loaded.token_strings[c_idx], ) - # Save graph (raw, unnormalized edges) - db.save_graph( + # Store all edges (unfiltered, unnormalized) with CI values + graph_id = db.save_graph( prompt_id=prompt_id, graph=StoredGraph( edges=raw_edges, output_probs=raw_output_probs, + node_ci_vals=result.node_ci_vals, ), ) # Process edges for response edges_data, node_importance, max_abs_attr = process_edges_for_response( - raw_edges, normalize, num_tokens=len(token_ids), is_optimized=False + edges=raw_edges, + normalize=normalize, + num_tokens=len(token_ids), + ci_threshold=ci_threshold, + node_ci_vals=result.node_ci_vals, + is_optimized=False, ) response_data = GraphData( - id=prompt_id, + id=graph_id, tokens=token_strings, edges=edges_data, outputProbs=raw_output_probs, @@ -246,10 +296,10 @@ def compute_graph_optimized_stream( output_prob_threshold: Annotated[float, Query(ge=0, le=1)], loaded: DepLoadedRun, manager: DepStateManager, + ci_threshold: Annotated[float, Query()], ): """Compute optimized attribution graph for a prompt with streaming progress.""" lr = 1e-2 - ci_threshold = 1e-6 db = manager.db prompt = db.get_prompt(prompt_id) @@ -280,7 +330,6 @@ def compute_graph_optimized_stream( log_freq=max(1, steps // 4), imp_min_config=ImportanceMinimalityLossConfig(coeff=imp_min_coeff, pnorm=pnorm), ce_loss_coeff=ce_loss_coeff, - ci_threshold=ci_threshold, sampling=loaded.config.sampling, ce_kl_rounding_threshold=0.5, ) @@ -298,7 +347,6 @@ def compute_thread() -> None: label_token=label_token, sources_by_target=loaded.sources_by_target, optim_config=optim_config, - ci_threshold=ci_threshold, output_prob_threshold=output_prob_threshold, device=DEVICE, show_progress=False, @@ -343,28 +391,34 @@ def generate() -> Generator[str]: token=loaded.token_strings[c_idx], ) - # Save graph (raw, unnormalized edges) - db.save_graph( + # Store all edges (unfiltered, unnormalized) with CI values + graph_id = db.save_graph( prompt_id=prompt_id, graph=StoredGraph( edges=raw_edges, output_probs=raw_output_probs, + node_ci_vals=result.node_ci_vals, optimization_params=opt_params, - optimization_stats=OptimizationStats( - label_prob=result.stats.label_prob, - l0_total=result.stats.l0_total, - l0_per_layer=result.stats.l0_per_layer, - ), + label_prob=result.label_prob, ), ) - # Process edges for response edges_data, node_importance, max_abs_attr = process_edges_for_response( - raw_edges, normalize, num_tokens=len(token_ids), is_optimized=True + edges=raw_edges, + normalize=normalize, + num_tokens=len(token_ids), + ci_threshold=ci_threshold, + node_ci_vals=result.node_ci_vals, + is_optimized=True, + ) + + l0_total, l0_per_layer = compute_l0_from_node_ci_vals( + node_ci_vals=result.node_ci_vals, + ci_threshold=ci_threshold, ) response_data = GraphDataWithOptimization( - id=prompt_id, + id=graph_id, tokens=token_strings, edges=edges_data, outputProbs=raw_output_probs, @@ -376,9 +430,9 @@ def generate() -> Generator[str]: imp_min_coeff=imp_min_coeff, ce_loss_coeff=ce_loss_coeff, steps=steps, - label_prob=result.stats.label_prob, - l0_total=result.stats.l0_total, - l0_per_layer=result.stats.l0_per_layer, + label_prob=result.label_prob, + l0_total=l0_total, + l0_per_layer=l0_per_layer, ), ) complete_data = {"type": "complete", "data": response_data.model_dump()} @@ -394,6 +448,8 @@ def process_edges_for_response( edges: list[Edge], normalize: NormalizeType, num_tokens: int, + ci_threshold: float, + node_ci_vals: dict[str, float], is_optimized: bool, edge_limit: int = GLOBAL_EDGE_LIMIT, ) -> tuple[list[EdgeData], dict[str, float], float]: @@ -406,6 +462,8 @@ def process_edges_for_response( edges: Raw edges from computation or database normalize: Normalization type ("none", "target", "layer") num_tokens: Number of tokens in the prompt (for filtering) + ci_threshold: Threshold for filtering edges by CI + node_ci_vals: CI values per node (layer:seq:c_idx -> ci_val) is_optimized: Whether this is an optimized graph (applies additional filtering) edge_limit: Maximum number of edges to return @@ -415,8 +473,15 @@ def process_edges_for_response( if is_optimized: final_seq_pos = num_tokens - 1 edges = [edge for edge in edges if edge.target.seq_pos == final_seq_pos] - edges = _normalize_edges(edges, normalize) - node_importance, max_abs_attr = compute_edge_stats(edges) + + edges = filter_edges_by_ci_threshold( + edges=edges, + ci_threshold=ci_threshold, + node_ci_vals=node_ci_vals, + ) + + edges = _normalize_edges(edges=edges, normalize=normalize) + node_importance, max_abs_attr = compute_edge_stats(edges=edges) # Clip to edge limit for response if len(edges) > edge_limit: print(f"[WARNING] Edge limit {edge_limit} exceeded ({len(edges)} edges), truncating") @@ -430,6 +495,7 @@ def process_edges_for_response( def get_graphs( prompt_id: int, normalize: Annotated[NormalizeType, Query()], + ci_threshold: Annotated[float, Query(ge=0)], loaded: DepLoadedRun, manager: DepStateManager, ) -> list[GraphData | GraphDataWithOptimization]: @@ -451,7 +517,12 @@ def get_graphs( for graph in stored_graphs: is_optimized = graph.optimization_params is not None edges_data, node_importance, max_abs_attr = process_edges_for_response( - graph.edges, normalize, num_tokens, is_optimized + edges=graph.edges, + normalize=normalize, + num_tokens=num_tokens, + ci_threshold=ci_threshold, + node_ci_vals=graph.node_ci_vals, + is_optimized=is_optimized, ) if not is_optimized: @@ -469,7 +540,13 @@ def get_graphs( else: # Optimized graph assert graph.optimization_params is not None - assert graph.optimization_stats is not None + assert graph.label_prob is not None + + l0_total, l0_per_layer = compute_l0_from_node_ci_vals( + node_ci_vals=graph.node_ci_vals, + ci_threshold=ci_threshold, + ) + results.append( GraphDataWithOptimization( id=graph.id, @@ -484,9 +561,9 @@ def get_graphs( imp_min_coeff=graph.optimization_params.imp_min_coeff, ce_loss_coeff=graph.optimization_params.ce_loss_coeff, steps=graph.optimization_params.steps, - label_prob=graph.optimization_stats.label_prob, - l0_total=graph.optimization_stats.l0_total, - l0_per_layer=graph.optimization_stats.l0_per_layer, + label_prob=graph.label_prob, + l0_total=l0_total, + l0_per_layer=l0_per_layer, ), ) ) diff --git a/spd/app/frontend/src/components/LocalAttributionsTab.svelte b/spd/app/frontend/src/components/LocalAttributionsTab.svelte index f7ba37ab1..b6dd26bf1 100644 --- a/spd/app/frontend/src/components/LocalAttributionsTab.svelte +++ b/spd/app/frontend/src/components/LocalAttributionsTab.svelte @@ -16,7 +16,14 @@ import PromptCardTabs from "./local-attr/PromptCardTabs.svelte"; import PromptPicker from "./local-attr/PromptPicker.svelte"; import StagedNodesPanel from "./local-attr/StagedNodesPanel.svelte"; - import type { StoredGraph, ComputeOptions, LoadingState, OptimizeConfig, PromptCard } from "./local-attr/types"; + import type { + StoredGraph, + ComputeOptions, + LoadingState, + OptimizeConfig, + PromptCard, + ViewSettings, + } from "./local-attr/types"; import ViewControls from "./local-attr/ViewControls.svelte"; import LocalAttributionsGraph from "./LocalAttributionsGraph.svelte"; @@ -58,20 +65,25 @@ let generateProgress = $state(0); let generateCount = $state(0); - // Activation contexts - passed as props from App.svelte + // Refetching state (for CI threshold/normalize changes) - tracks which graph is being refetched + let refetchingGraphId = $state(null); - // View controls - let topK = $state(200); - let nodeLayout = $state<"importance" | "shuffled" | "jittered">("importance"); - let componentGap = $state(4); - let layerGap = $state(30); + // Default view settings for new graphs + const defaultViewSettings: ViewSettings = { + topK: 200, + nodeLayout: "importance", + componentGap: 4, + layerGap: 30, + normalizeEdges: "layer", + ciThreshold: 0, + }; + + // Edge count is derived from the graph rendering, not stored per-graph let filteredEdgeCount = $state(null); - let normalizeEdges = $state("layer"); // Compute options let computeOptions = $state({ - maxMeanCI: 1.0, - normalizeEdges: "layer", + ciThreshold: 0, useOptimized: false, optimizeConfig: { labelTokenText: "", @@ -200,7 +212,11 @@ // Fetch stored graphs for this prompt (includes composer selection and intervention runs) let graphs: StoredGraph[] = []; try { - const storedGraphs = await attrApi.getGraphs(promptId, normalizeEdges); + const storedGraphs = await attrApi.getGraphs( + promptId, + defaultViewSettings.normalizeEdges, + defaultViewSettings.ciThreshold, + ); graphs = await Promise.all( storedGraphs.map(async (data, idx) => { const isOptimized = !!data.optimization; @@ -214,6 +230,7 @@ dbId: data.id, label, data, + viewSettings: { ...defaultViewSettings }, composerSelection: filterInterventableNodes(Object.keys(data.nodeImportance)), interventionRuns: runs, activeRunId: null, @@ -428,12 +445,13 @@ { promptId: activeCard.promptId, labelToken: optConfig.labelTokenId, - normalize: computeOptions.normalizeEdges, + normalize: defaultViewSettings.normalizeEdges, impMinCoeff: optConfig.impMinCoeff, ceLossCoeff: optConfig.ceLossCoeff, steps: optConfig.steps, pnorm: optConfig.pnorm, outputProbThreshold: 0.01, + ciThreshold: defaultViewSettings.ciThreshold, }, (progress) => { if (!loadingState) return; @@ -449,7 +467,8 @@ data = await attrApi.computeGraphStreaming( { promptId: activeCard.promptId, - normalize: computeOptions.normalizeEdges, + normalize: defaultViewSettings.normalizeEdges, + ciThreshold: defaultViewSettings.ciThreshold, }, (progress) => { if (!loadingState) return; @@ -472,6 +491,7 @@ dbId: data.id, label, data, + viewSettings: { ...defaultViewSettings }, composerSelection: filterInterventableNodes(Object.keys(data.nodeImportance)), interventionRuns: [], activeRunId: null, @@ -488,47 +508,83 @@ } } - async function handleNormalizeChange(value: attrApi.NormalizeType) { - normalizeEdges = value; - computeOptions.normalizeEdges = value; - - const updatedCards = await Promise.all( - promptCards.map(async (card) => { - if (card.graphs.length === 0) return card; - - try { - const storedGraphs = await attrApi.getGraphs(card.promptId, normalizeEdges); - const graphs = await Promise.all( - storedGraphs.map(async (data, idx) => { - const isOptimized = !!data.optimization; - const label = isOptimized ? `Optimized (${data.optimization!.steps} steps)` : "Standard"; - - // Load intervention runs - const runs = await mainApi.getInterventionRuns(data.id); - - return { - id: `graph-${idx}-${Date.now()}`, - dbId: data.id, - label, - data, - composerSelection: filterInterventableNodes(Object.keys(data.nodeImportance)), - interventionRuns: runs, - activeRunId: null, - }; - }), - ); + // Refetch graph data when normalize or ciThreshold changes (these affect server-side filtering) + async function refetchActiveGraphData() { + if (!activeCard || !activeGraph) return; + + const { normalizeEdges, ciThreshold } = activeGraph.viewSettings; + refetchingGraphId = activeGraph.id; + try { + const storedGraphs = await attrApi.getGraphs(activeCard.promptId, normalizeEdges, ciThreshold); + const matchingData = storedGraphs.find((g) => g.id === activeGraph.dbId); + + if (!matchingData) { + throw new Error("Could not find matching graph data after refetch"); + } + + promptCards = promptCards.map((card) => { + if (card.id !== activeCard.id) return card; + return { + ...card, + graphs: card.graphs.map((g) => { + if (g.id !== activeGraph.id) return g; + return { + ...g, + data: matchingData, + composerSelection: filterInterventableNodes(Object.keys(matchingData.nodeImportance)), + }; + }), + }; + }); + } catch (e) { + console.warn("Failed to refetch graph:", e); + } finally { + refetchingGraphId = null; + } + } + + function updateActiveGraphViewSettings(partial: Partial) { + if (!activeCard || !activeGraph) return; + + promptCards = promptCards.map((card) => { + if (card.id !== activeCard.id) return card; + return { + ...card, + graphs: card.graphs.map((g) => { + if (g.id !== activeGraph.id) return g; return { - ...card, - graphs, - activeGraphId: graphs.length > 0 ? graphs[0].id : null, + ...g, + viewSettings: { ...g.viewSettings, ...partial }, }; - } catch (e) { - console.warn("Failed to re-fetch graphs for card:", card.id, e); - return card; - } - }), - ); - promptCards = updatedCards; + }), + }; + }); + } + + async function handleNormalizeChange(value: attrApi.NormalizeType) { + updateActiveGraphViewSettings({ normalizeEdges: value }); + await refetchActiveGraphData(); + } + + async function handleCiThresholdChange(value: number) { + updateActiveGraphViewSettings({ ciThreshold: value }); + await refetchActiveGraphData(); + } + + function handleTopKChange(value: number) { + updateActiveGraphViewSettings({ topK: value }); + } + + function handleLayoutChange(value: "importance" | "shuffled" | "jittered") { + updateActiveGraphViewSettings({ nodeLayout: value }); + } + + function handleComponentGapChange(value: number) { + updateActiveGraphViewSettings({ componentGap: value }); + } + + function handleLayerGapChange(value: number) { + updateActiveGraphViewSettings({ layerGap: value }); } async function handleGeneratePrompts(nPrompts: number) { @@ -640,7 +696,8 @@ > L0: - {activeGraph.data.optimization.l0_total.toFixed(0)} active

{/if} @@ -658,25 +715,28 @@ {/if} (topK = v)} - onLayoutChange={(v) => (nodeLayout = v)} - onComponentGapChange={(v) => (componentGap = v)} - onLayerGapChange={(v) => (layerGap = v)} + normalizeEdges={activeGraph.viewSettings.normalizeEdges} + ciThreshold={activeGraph.viewSettings.ciThreshold} + ciThresholdLoading={refetchingGraphId === activeGraph.id} + onTopKChange={handleTopKChange} + onLayoutChange={handleLayoutChange} + onComponentGapChange={handleComponentGapChange} + onLayerGapChange={handleLayerGapChange} onNormalizeChange={handleNormalizeChange} + onCiThresholdChange={handleCiThresholdChange} /> {#key activeGraph.id} - import type { ActivationContextsSummary, ComponentDetail, ComponentSummary, OutputProbEntry } from "../../lib/localAttributionsTypes"; + import type { + ActivationContextsSummary, + ComponentDetail, + ComponentSummary, + OutputProbEntry, + } from "../../lib/localAttributionsTypes"; import ComponentNodeCard from "./ComponentNodeCard.svelte"; import OutputNodeCard from "./OutputNodeCard.svelte"; @@ -47,7 +52,9 @@ const inputToken = $derived.by(() => { if (!isWte) return null; if (hoveredNode.seqIdx >= tokens.length) { - throw new Error(`NodeTooltip: seqIdx ${hoveredNode.seqIdx} out of bounds for tokens length ${tokens.length}`); + throw new Error( + `NodeTooltip: seqIdx ${hoveredNode.seqIdx} out of bounds for tokens length ${tokens.length}`, + ); } return tokens[hoveredNode.seqIdx]; }); @@ -65,7 +72,8 @@
"{inputToken}"

- Position: {hoveredNode.seqIdx} + Position: + {hoveredNode.seqIdx}

{:else if isOutput} diff --git a/spd/app/frontend/src/components/local-attr/PromptCardHeader.svelte b/spd/app/frontend/src/components/local-attr/PromptCardHeader.svelte index b6a42eb3e..d307b2c94 100644 --- a/spd/app/frontend/src/components/local-attr/PromptCardHeader.svelte +++ b/spd/app/frontend/src/components/local-attr/PromptCardHeader.svelte @@ -47,23 +47,7 @@
- - ? + ?