diff --git a/src/dualip/utils/sparse_utils.py b/src/dualip/utils/sparse_utils.py index 36a0948..a09287d 100644 --- a/src/dualip/utils/sparse_utils.py +++ b/src/dualip/utils/sparse_utils.py @@ -285,9 +285,9 @@ def split_csc_by_cols(M: torch.Tensor, split_sizes: List[int]) -> List[torch.Ten end_nnz = int(col_ptr[end_col].item()) # slice out the per-block CSC arrays - sub_col_ptr = col_ptr[start_col : (end_col + 1)] - col_ptr[start_col] - sub_row_idx = row_idx[start_nnz:end_nnz] - sub_vals = vals[start_nnz:end_nnz] + sub_col_ptr = (col_ptr[start_col : (end_col + 1)] - col_ptr[start_col]).clone() + sub_row_idx = row_idx[start_nnz:end_nnz].clone() + sub_vals = vals[start_nnz:end_nnz].clone() # build the sub‐matrix M_block = torch.sparse_csc_tensor(sub_col_ptr, sub_row_idx, sub_vals, size=(m, width))