-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_decoder.py
More file actions
121 lines (96 loc) · 4 KB
/
plot_decoder.py
File metadata and controls
121 lines (96 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import matplotlib.pyplot as plt
import numpy as np
import torch
from ST import SparseTransformer
from SAE import SparseAutoencoder
def plot_decoder_matrix_columns(model, X, num_cols=10, figsize=(15, 3), fixed_indices=None):
    """
    Plot feature vectors of a trained model as 28x28 grayscale images.

    For a SparseTransformer the matrix ``V`` is the fourth value returned by
    ``model(x)``; for a SparseAutoencoder it is ``model.W_d.weight.T``. Evenly
    spaced entries along the first axis of ``V`` are reshaped to 28x28,
    min-max normalized and shown in a near-square grid.

    Args:
        model: The trained Sparse Transformer or Sparse Autoencoder model.
            Must expose a ``device`` attribute.
        X: Dataset indexed with ``fixed_indices``; the result must support
            ``.to(device)`` (presumably a tensor of flattened MNIST images --
            confirm against the caller).
        num_cols: Number of feature vectors to plot (capped at ``V.shape[0]``).
        figsize: Figure size for the plot.
        fixed_indices: Fixed indices to use for selecting data. Defaults to the
            first 4096 samples, or all of ``X`` if it holds fewer.

    Raises:
        ValueError: If the model type is unsupported, or if there is nothing
            to plot (``num_cols < 1`` or ``V`` is empty).
    """
    with torch.no_grad():
        # Get fixed batch; clamp the default so small datasets do not raise
        # an out-of-range index error.
        if fixed_indices is None:
            fixed_indices = np.arange(min(4096, len(X)))
        x = X[fixed_indices].to(model.device)

        # Forward pass to get V matrix or decoder matrix
        if isinstance(model, SparseTransformer):
            _, _, _, V = model(x)
        elif isinstance(model, SparseAutoencoder):
            model(x)
            V = model.W_d.weight.T
        else:
            raise ValueError("Unsupported model type")

        # detach() keeps the numpy conversion valid even if V is still
        # attached to the autograd graph.
        V = V.detach().cpu().numpy()

    # Ensure num_cols does not exceed the number of available vectors
    num_available = V.shape[0]
    num_cols = min(num_cols, num_available)
    if num_cols < 1:
        raise ValueError("Nothing to plot: num_cols must be >= 1 and V must be non-empty")

    # Near-square grid large enough to hold num_cols panels
    n_rows = int(np.ceil(np.sqrt(num_cols)))
    n_cols = int(np.ceil(num_cols / n_rows))

    # squeeze=False always returns a 2-D array of Axes, so a 1x1 grid
    # (num_cols == 1) can be flattened like any other.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    fig.suptitle('Selected Columns from Decoder Matrix Reshaped as MNIST Images')
    axes = axes.ravel()

    # Select evenly spaced vectors to display
    step = num_available // num_cols
    selected_indices = np.arange(0, num_available, step)[:num_cols]

    for ax, idx in zip(axes, selected_indices):
        # Reshape the vector to a 28x28 image
        img = V[idx].reshape(28, 28)

        # Normalize to [0, 1]; a constant vector has zero range, so map it to
        # all zeros instead of dividing by zero.
        lo, hi = img.min(), img.max()
        span = hi - lo
        img = (img - lo) / span if span > 0 else np.zeros_like(img)

        ax.imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
        ax.axis('off')

    # Hide any unused subplots
    for ax in axes[len(selected_indices):]:
        ax.axis('off')

    plt.tight_layout(pad=0.5, w_pad=0.5, h_pad=0.5)
    plt.show()
def plot_decoder_matrix_stats(model, X, fixed_indices=None):
    """
    Plot statistics about the V matrix or decoder matrix.

    Shows two panels: a histogram of the L2 norms taken along axis 1 of ``V``
    (one norm per row), and a heatmap of ``|V|`` whose title reports the
    fraction of entries with absolute value at most 1e-3.

    For a SparseTransformer ``V`` is the fourth value returned by
    ``model(x)``; for a SparseAutoencoder it is ``model.W_d.weight.T``.

    Args:
        model: The trained Sparse Transformer or Sparse Autoencoder model.
            Must expose a ``device`` attribute.
        X: Dataset indexed with ``fixed_indices``; the result must support
            ``.to(device)``.
        fixed_indices: Fixed indices to use for selecting data. Defaults to the
            first 64 samples, or all of ``X`` if it holds fewer.

    Raises:
        ValueError: If the model type is unsupported.
    """
    with torch.no_grad():
        # Get fixed batch; clamp the default so small datasets do not raise
        # an out-of-range index error.
        if fixed_indices is None:
            fixed_indices = np.arange(min(64, len(X)))
        x = X[fixed_indices].to(model.device)

        # Forward pass to get V matrix or decoder matrix
        if isinstance(model, SparseTransformer):
            _, _, _, V = model(x)
        elif isinstance(model, SparseAutoencoder):
            model(x)
            V = model.W_d.weight.T
        else:
            raise ValueError("Unsupported model type")

        # detach() keeps the numpy conversion valid even if V is still
        # attached to the autograd graph.
        V = V.detach().cpu().numpy()

    # Calculate statistics
    abs_V = np.abs(V)
    norms = np.linalg.norm(V, axis=1)  # one L2 norm per row of V
    sparsity = np.mean(abs_V <= 1e-3)  # fraction of near-zero entries

    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Histogram of the per-row norms
    ax1.hist(norms, bins=50)
    ax1.set_title('Distribution of V Matrix Column Norms')
    ax1.set_xlabel('L2 Norm')
    ax1.set_ylabel('Count')

    # Heatmap of absolute values; 'nearest' avoids blending neighbouring cells
    im = ax2.imshow(abs_V, aspect='auto', cmap='coolwarm', interpolation='nearest')
    ax2.set_title(f'V Matrix Values (Sparsity: {sparsity:.2%})')
    ax2.set_xlabel('Input Dimension')
    ax2.set_ylabel('Feature Index')
    plt.colorbar(im, ax=ax2)

    plt.tight_layout()
    plt.show()