19 changes: 6 additions & 13 deletions src/dualip/utils/sparse_utils.py
@@ -196,22 +196,15 @@ def apply_F_to_columns(
     # This is the highest number of non-zeroes of columns in the bucket
     L = int(lengths.max().item())
 
-    # fixed prefix/flat-index logic:
-    prefix = torch.cat(
-        [
-            torch.tensor([0], device=device, dtype=lengths.dtype),
-            torch.cumsum(lengths[:-1], dim=0),
-        ]
-    )
-
-    prefix_rep = prefix.repeat_interleave(lengths)  # shape (total,)
-    idx_in_col = torch.arange(total, device=device) - prefix_rep
-    offs = starts.repeat_interleave(lengths)
-    flat_indices = offs + idx_in_col
+    # Compute cols_rep once, then derive all other indices via indexing
+    # (avoids 2 extra repeat_interleave calls and a torch.cat)
+    cols_rep = torch.arange(K, device=device).repeat_interleave(lengths)  # (total,)
+    prefix = lengths.cumsum(0) - lengths  # shape (K,), avoids torch.cat
+    idx_in_col = torch.arange(total, device=device) - prefix[cols_rep]
+    flat_indices = starts[cols_rep] + idx_in_col
 
     # 2) build padded [L × K] block
     block = torch.zeros((L, K), device=device, dtype=dtype)
-    cols_rep = torch.arange(K, device=device).repeat_interleave(lengths)  # (total,)
     block[idx_in_col, cols_rep] = vals[flat_indices]
 
     # 3) apply the batched projection
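The removed and added blocks above compute the same exclusive prefix sum: `lengths.cumsum(0) - lengths` is the inclusive cumsum shifted down by one element, and gathering `prefix`/`starts` through `cols_rep` reproduces the old `repeat_interleave` expansions with plain indexing. A minimal standalone sketch (not part of the PR) that checks the identity, with `starts` and `lengths` written by hand for a single bucket:

    import torch

    lengths = torch.tensor([3, 1, 2])   # nonzeros per column in the bucket (K = 3)
    starts = torch.tensor([0, 3, 4])    # flat offset of each column's first value
    total = int(lengths.sum())          # 6 stored values in the bucket
    K = lengths.numel()

    # Old construction: prepend 0, cumsum over all lengths but the last.
    prefix_old = torch.cat([torch.tensor([0]), torch.cumsum(lengths[:-1], dim=0)])

    # New construction: inclusive cumsum shifted down by one element.
    prefix_new = lengths.cumsum(0) - lengths

    assert torch.equal(prefix_old, prefix_new)  # both are [0, 3, 4]

    # Gathering through cols_rep matches the old per-element expansions.
    cols_rep = torch.arange(K).repeat_interleave(lengths)      # [0, 0, 0, 1, 2, 2]
    idx_in_col = torch.arange(total) - prefix_new[cols_rep]    # [0, 1, 2, 0, 0, 1]
    flat_indices = starts[cols_rep] + idx_in_col               # [0, 1, 2, 3, 4, 5]
    assert torch.equal(prefix_new[cols_rep], prefix_old.repeat_interleave(lengths))
    assert torch.equal(starts[cols_rep], starts.repeat_interleave(lengths))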
120 changes: 119 additions & 1 deletion tests/test_sparse_utils.py
@@ -1,6 +1,12 @@
 import torch
 
-from dualip.utils.sparse_utils import hstack_csc, left_multiply_sparse, right_multiply_sparse, vstack_csc
+from dualip.utils.sparse_utils import (
+    apply_F_to_columns,
+    hstack_csc,
+    left_multiply_sparse,
+    right_multiply_sparse,
+    vstack_csc,
+)
 
 
 def test_vstack_csc():
@@ -86,6 +92,118 @@ def test_right_multiply_sparse():
     assert result_sparse.layout == torch.sparse_csc
 
 
+class TestApplyFToColumns:
+    """Tests for apply_F_to_columns with various bucket configurations."""
+
+    @staticmethod
+    def _apply_dense_columnwise(M_dense, F_batch):
+        """Reference: apply F_batch to each column independently via dense ops."""
+        result = M_dense.clone()
+        for j in range(M_dense.shape[1]):
+            col = M_dense[:, j]
+            nonzero_mask = col != 0
+            if nonzero_mask.any():
+                block = col[nonzero_mask].unsqueeze(1)
+                proj = F_batch(block).squeeze(1)
+                result[:, j] = 0.0
+                result[:, j][nonzero_mask] = proj
+        return result
+
+    def test_identity_function(self):
+        """F = identity should return the same matrix."""
+        M_dense = torch.tensor([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [4.0, 0.0, 5.0]])
+        M_sparse = M_dense.to_sparse_csc()
+        buckets = [torch.arange(M_dense.shape[1])]
+
+        result = apply_F_to_columns(M_sparse, lambda x: x, buckets)
+        assert torch.allclose(result.to_dense(), M_dense)
+
+    def test_scaling_function(self):
+        """F = 2x should double all values."""
+        M_dense = torch.tensor([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [4.0, 0.0, 5.0]])
+        M_sparse = M_dense.to_sparse_csc()
+        buckets = [torch.arange(M_dense.shape[1])]
+
+        result = apply_F_to_columns(M_sparse, lambda x: 2 * x, buckets)
+        assert torch.allclose(result.to_dense(), 2 * M_dense)
+
+    def test_multiple_buckets(self):
+        """Splitting columns across multiple buckets should give the same result."""
+        M_dense = torch.tensor(
+            [
+                [1.0, 0.0, 3.0, 0.0, 7.0],
+                [0.0, 2.0, 0.0, 4.0, 0.0],
+                [5.0, 0.0, 6.0, 0.0, 8.0],
+            ]
+        )
+        M_sparse = M_dense.to_sparse_csc()
+
+        def f(x):
+            return x * 0.5
+
+        expected = M_dense * 0.5
+
+        single = apply_F_to_columns(M_sparse, f, [torch.arange(5)])
+        multi = apply_F_to_columns(M_sparse, f, [torch.tensor([0, 2, 4]), torch.tensor([1, 3])])
+
+        assert torch.allclose(single.to_dense(), expected)
+        assert torch.allclose(multi.to_dense(), expected)
+
+    def test_varying_column_lengths(self):
+        """Columns with different numbers of nonzeros in the same bucket."""
+        M_dense = torch.tensor(
+            [
+                [1.0, 0.0, 3.0],
+                [2.0, 0.0, 0.0],
+                [3.0, 4.0, 0.0],
+                [4.0, 0.0, 0.0],
+            ]
+        )
+        M_sparse = M_dense.to_sparse_csc()
+
+        def f(x):
+            return x**2
+
+        buckets = [torch.arange(3)]
+
+        result = apply_F_to_columns(M_sparse, f, buckets)
+        expected = self._apply_dense_columnwise(M_dense, f)
+        assert torch.allclose(result.to_dense(), expected)
+
+    def test_output_tensor(self):
+        """Writing into a pre-allocated output tensor."""
+        M_dense = torch.tensor([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [4.0, 0.0, 5.0]])
+        M_sparse = M_dense.to_sparse_csc()
+        output = M_sparse.clone()
+        buckets = [torch.arange(3)]
+
+        apply_F_to_columns(M_sparse, lambda x: x * 3, buckets, output_tensor=output)
+        assert torch.allclose(output.to_dense(), 3 * M_dense)
+
+    def test_empty_bucket_skipped(self):
+        """Empty buckets should be harmlessly skipped."""
+        M_dense = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+        M_sparse = M_dense.to_sparse_csc()
+        buckets = [torch.tensor([], dtype=torch.long), torch.arange(2)]
+
+        result = apply_F_to_columns(M_sparse, lambda x: -x, buckets)
+        assert torch.allclose(result.to_dense(), -M_dense)
+
+    def test_clamp_projection(self):
+        """Column-wise clamp to [0, inf) (ReLU-like) preserves sparsity pattern."""
+        M_dense = torch.tensor([[1.0, 0.0, -3.0], [0.0, -2.0, 0.0], [-4.0, 0.0, 5.0]])
+        M_sparse = M_dense.to_sparse_csc()
+
+        def f(x):
+            return x.clamp(min=0)
+
+        buckets = [torch.arange(3)]
+
+        result = apply_F_to_columns(M_sparse, f, buckets)
+        expected = self._apply_dense_columnwise(M_dense, f)
+        assert torch.allclose(result.to_dense(), expected)
+
+
 def test_left_multiply_sparse():
     """Test left multiplication diag(v) @ M using dense tensor reference."""
     # Create a sparse matrix and diagonal vector
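For context, the padded-block round trip that the sparse_utils.py hunk implements can be sketched end to end. In this sketch (not part of the PR), `starts` and `lengths` are derived from the CSC `ccol_indices` and a single bucket holds every column; both are assumptions about the surrounding, unshown code:

    import torch

    M = torch.tensor([[1.0, 0.0, 3.0],
                      [2.0, 0.0, 0.0],
                      [3.0, 4.0, 0.0]]).to_sparse_csc()
    ccol, vals = M.ccol_indices(), M.values()

    cols = torch.arange(3)                  # one bucket containing every column
    starts = ccol[cols]                     # first stored index of each column
    lengths = ccol[cols + 1] - ccol[cols]   # nnz per column: [3, 1, 1]
    K, total, L = cols.numel(), int(lengths.sum()), int(lengths.max())

    # Index derivation, as in the new code path in the hunk above.
    cols_rep = torch.arange(K).repeat_interleave(lengths)
    prefix = lengths.cumsum(0) - lengths
    idx_in_col = torch.arange(total) - prefix[cols_rep]
    flat_indices = starts[cols_rep] + idx_in_col

    # Scatter into the padded [L x K] block, apply a batched F, scatter back.
    block = torch.zeros((L, K))
    block[idx_in_col, cols_rep] = vals[flat_indices]
    new_vals = vals.clone()
    new_vals[flat_indices] = (2.0 * block)[idx_in_col, cols_rep]  # F = 2x
    assert torch.equal(new_vals, 2.0 * vals)  # padding zeros never leak back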