diff --git a/src/dualip/utils/sparse_utils.py b/src/dualip/utils/sparse_utils.py
index a09287d..f8c8af8 100644
--- a/src/dualip/utils/sparse_utils.py
+++ b/src/dualip/utils/sparse_utils.py
@@ -196,22 +196,15 @@ def apply_F_to_columns(
     # This is the highest number of non-zeroes of columns in the bucket
     L = int(lengths.max().item())
 
-    # fixed prefix/flat-index logic:
-    prefix = torch.cat(
-        [
-            torch.tensor([0], device=device, dtype=lengths.dtype),
-            torch.cumsum(lengths[:-1], dim=0),
-        ]
-    )
-
-    prefix_rep = prefix.repeat_interleave(lengths)  # shape (total,)
-    idx_in_col = torch.arange(total, device=device) - prefix_rep
-    offs = starts.repeat_interleave(lengths)
-    flat_indices = offs + idx_in_col
+    # Compute cols_rep once, then derive all other indices via indexing
+    # (avoids 2 extra repeat_interleave calls and a torch.cat)
+    cols_rep = torch.arange(K, device=device).repeat_interleave(lengths)  # (total,)
+    prefix = lengths.cumsum(0) - lengths  # shape (K,), avoids torch.cat
+    idx_in_col = torch.arange(total, device=device) - prefix[cols_rep]
+    flat_indices = starts[cols_rep] + idx_in_col
 
     # 2) build padded [L × K] block
     block = torch.zeros((L, K), device=device, dtype=dtype)
-    cols_rep = torch.arange(K, device=device).repeat_interleave(lengths)  # (total,)
     block[idx_in_col, cols_rep] = vals[flat_indices]
 
     # 3) apply the batched projection
diff --git a/tests/test_sparse_utils.py b/tests/test_sparse_utils.py
index 3100d3c..531237d 100644
--- a/tests/test_sparse_utils.py
+++ b/tests/test_sparse_utils.py
@@ -1,6 +1,12 @@
 import torch
 
-from dualip.utils.sparse_utils import hstack_csc, left_multiply_sparse, right_multiply_sparse, vstack_csc
+from dualip.utils.sparse_utils import (
+    apply_F_to_columns,
+    hstack_csc,
+    left_multiply_sparse,
+    right_multiply_sparse,
+    vstack_csc,
+)
 
 
 def test_vstack_csc():
@@ -86,6 +92,118 @@ def test_right_multiply_sparse():
     assert result_sparse.layout == torch.sparse_csc
 
 
+class TestApplyFToColumns:
+    """Tests for apply_F_to_columns with various bucket configurations."""
+
+    @staticmethod
+    def _apply_dense_columnwise(M_dense, F_batch):
+        """Reference: apply F_batch to each column independently via dense ops."""
+        result = M_dense.clone()
+        for j in range(M_dense.shape[1]):
+            col = M_dense[:, j]
+            nonzero_mask = col != 0
+            if nonzero_mask.any():
+                block = col[nonzero_mask].unsqueeze(1)
+                proj = F_batch(block).squeeze(1)
+                result[:, j] = 0.0
+                result[:, j][nonzero_mask] = proj
+        return result
+
+    def test_identity_function(self):
+        """F = identity should return the same matrix."""
+        M_dense = torch.tensor([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [4.0, 0.0, 5.0]])
+        M_sparse = M_dense.to_sparse_csc()
+        buckets = [torch.arange(M_dense.shape[1])]
+
+        result = apply_F_to_columns(M_sparse, lambda x: x, buckets)
+        assert torch.allclose(result.to_dense(), M_dense)
+
+    def test_scaling_function(self):
+        """F = 2x should double all values."""
+        M_dense = torch.tensor([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [4.0, 0.0, 5.0]])
+        M_sparse = M_dense.to_sparse_csc()
+        buckets = [torch.arange(M_dense.shape[1])]
+
+        result = apply_F_to_columns(M_sparse, lambda x: 2 * x, buckets)
+        assert torch.allclose(result.to_dense(), 2 * M_dense)
+
+    def test_multiple_buckets(self):
+        """Splitting columns across multiple buckets should give the same result."""
+        M_dense = torch.tensor(
+            [
+                [1.0, 0.0, 3.0, 0.0, 7.0],
+                [0.0, 2.0, 0.0, 4.0, 0.0],
+                [5.0, 0.0, 6.0, 0.0, 8.0],
+            ]
+        )
+        M_sparse = M_dense.to_sparse_csc()
+
+        def f(x):
+            return x * 0.5
+
+        expected = M_dense * 0.5
+
+        single = apply_F_to_columns(M_sparse, f, [torch.arange(5)])
+        multi = apply_F_to_columns(M_sparse, f, [torch.tensor([0, 2, 4]), torch.tensor([1, 3])])
+
+        assert torch.allclose(single.to_dense(), expected)
+        assert torch.allclose(multi.to_dense(), expected)
+
+    def test_varying_column_lengths(self):
+        """Columns with different numbers of nonzeros in the same bucket."""
+        M_dense = torch.tensor(
+            [
+                [1.0, 0.0, 3.0],
+                [2.0, 0.0, 0.0],
+                [3.0, 4.0, 0.0],
+                [4.0, 0.0, 0.0],
+            ]
+        )
+        M_sparse = M_dense.to_sparse_csc()
+
+        def f(x):
+            return x**2
+
+        buckets = [torch.arange(3)]
+
+        result = apply_F_to_columns(M_sparse, f, buckets)
+        expected = self._apply_dense_columnwise(M_dense, f)
+        assert torch.allclose(result.to_dense(), expected)
+
+    def test_output_tensor(self):
+        """Writing into a pre-allocated output tensor."""
+        M_dense = torch.tensor([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [4.0, 0.0, 5.0]])
+        M_sparse = M_dense.to_sparse_csc()
+        output = M_sparse.clone()
+        buckets = [torch.arange(3)]
+
+        apply_F_to_columns(M_sparse, lambda x: x * 3, buckets, output_tensor=output)
+        assert torch.allclose(output.to_dense(), 3 * M_dense)
+
+    def test_empty_bucket_skipped(self):
+        """Empty buckets should be harmlessly skipped."""
+        M_dense = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+        M_sparse = M_dense.to_sparse_csc()
+        buckets = [torch.tensor([], dtype=torch.long), torch.arange(2)]
+
+        result = apply_F_to_columns(M_sparse, lambda x: -x, buckets)
+        assert torch.allclose(result.to_dense(), -M_dense)
+
+    def test_clamp_projection(self):
+        """Column-wise clamp to [0, inf) (ReLU-like) preserves sparsity pattern."""
+        M_dense = torch.tensor([[1.0, 0.0, -3.0], [0.0, -2.0, 0.0], [-4.0, 0.0, 5.0]])
+        M_sparse = M_dense.to_sparse_csc()
+
+        def f(x):
+            return x.clamp(min=0)
+
+        buckets = [torch.arange(3)]
+
+        result = apply_F_to_columns(M_sparse, f, buckets)
+        expected = self._apply_dense_columnwise(M_dense, f)
+        assert torch.allclose(result.to_dense(), expected)
+
+
 def test_left_multiply_sparse():
     """Test left multiplication diag(v) @ M using dense tensor reference."""
     # Create a sparse matrix and diagonal vector
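
For context, a minimal standalone sketch of the index derivation the sparse_utils change relies on: it checks, on a toy bucket, that the new cols_rep/prefix indexing reproduces the old torch.cat + repeat_interleave formulation. The lengths and starts values below are illustrative only and mirror (but are not taken from) the per-bucket variables inside apply_F_to_columns.

import torch

# Toy bucket of 3 columns holding 2, 1, and 3 stored values (illustrative values only).
lengths = torch.tensor([2, 1, 3])
starts = torch.tensor([0, 2, 3])  # column start offsets into the flat CSC value array
K = lengths.numel()
total = int(lengths.sum())

# Old formulation: torch.cat plus three repeat_interleave calls.
prefix_old = torch.cat([torch.tensor([0]), torch.cumsum(lengths[:-1], dim=0)])
prefix_rep = prefix_old.repeat_interleave(lengths)
idx_in_col_old = torch.arange(total) - prefix_rep
flat_old = starts.repeat_interleave(lengths) + idx_in_col_old

# New formulation: a single repeat_interleave, then plain tensor indexing.
cols_rep = torch.arange(K).repeat_interleave(lengths)
prefix_new = lengths.cumsum(0) - lengths  # exclusive prefix sum, no torch.cat
idx_in_col_new = torch.arange(total) - prefix_new[cols_rep]
flat_new = starts[cols_rep] + idx_in_col_new

# Both derivations yield identical per-column positions and flat value indices.
assert torch.equal(idx_in_col_old, idx_in_col_new)
assert torch.equal(flat_old, flat_new)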