From d9f67b9dd81c7db51b37728a3629ee5e1d6e83af Mon Sep 17 00:00:00 2001 From: Aida Date: Mon, 9 Mar 2026 11:30:35 -0700 Subject: [PATCH] fix OOM error originating from data splitting --- src/dualip/utils/sparse_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dualip/utils/sparse_utils.py b/src/dualip/utils/sparse_utils.py index 36a0948..a09287d 100644 --- a/src/dualip/utils/sparse_utils.py +++ b/src/dualip/utils/sparse_utils.py @@ -285,9 +285,9 @@ def split_csc_by_cols(M: torch.Tensor, split_sizes: List[int]) -> List[torch.Ten end_nnz = int(col_ptr[end_col].item()) # slice out the per-block CSC arrays - sub_col_ptr = col_ptr[start_col : (end_col + 1)] - col_ptr[start_col] - sub_row_idx = row_idx[start_nnz:end_nnz] - sub_vals = vals[start_nnz:end_nnz] + sub_col_ptr = (col_ptr[start_col : (end_col + 1)] - col_ptr[start_col]).clone() + sub_row_idx = row_idx[start_nnz:end_nnz].clone() + sub_vals = vals[start_nnz:end_nnz].clone() # build the sub‐matrix M_block = torch.sparse_csc_tensor(sub_col_ptr, sub_row_idx, sub_vals, size=(m, width))