diff --git a/src/dualip/run_solver.py b/src/dualip/run_solver.py index 31a3c1e..09889fa 100644 --- a/src/dualip/run_solver.py +++ b/src/dualip/run_solver.py @@ -1,4 +1,3 @@ -from dataclasses import fields from typing import Optional import torch @@ -11,36 +10,10 @@ from dualip.objectives.miplib import MIPLIB2017ObjectiveFunction from dualip.optimizers.agd import AcceleratedGradientDescent, SolverResult from dualip.types import ComputeArgs, ObjectiveArgs, SolverArgs +from dualip.utils.dist_utils import transfer_tensors_to_device from dualip.utils.mlflow_utils import MLflowConfig, log_hyperparameters, mlflow_run_context -def transfer_tensors_to_device(input_args: BaseInputArgs, device: str): - """ - Transfer all tensor fields in input_args to the specified device. - - Args: - input_args: The input arguments object - device: The target device (e.g., 'cuda:0', 'cpu') - - Returns: - A new instance of input_args with all tensors transferred to device - """ - # Get all field names from the dataclass - field_names = [field.name for field in fields(input_args)] - - # Create a dictionary of field values with tensors transferred to device - field_values = {} - for field_name in field_names: - value = getattr(input_args, field_name) - if isinstance(value, torch.Tensor): - field_values[field_name] = value.to(device) - else: - field_values[field_name] = value - - # Create a new instance of the same class with transferred tensors - return type(input_args)(**field_values) - - def build_objective( input_args: BaseInputArgs, solver_args: SolverArgs, compute_args: ComputeArgs, objective_args: ObjectiveArgs ): diff --git a/src/dualip/utils/dist_utils.py b/src/dualip/utils/dist_utils.py index 2249813..d3ef970 100644 --- a/src/dualip/utils/dist_utils.py +++ b/src/dualip/utils/dist_utils.py @@ -1,11 +1,34 @@ from collections import defaultdict +from dataclasses import fields import torch +from dualip.objectives.base import BaseInputArgs from dualip.projections.base import ProjectionEntry 
def transfer_tensors_to_device(input_args: BaseInputArgs, device: str) -> BaseInputArgs:
    """
    Return a copy of ``input_args`` whose tensor fields live on ``device``.

    Only top-level fields are inspected: a field whose value is a
    ``torch.Tensor`` is moved with ``Tensor.to(device)``; every other value
    (including containers that may themselves hold tensors) is passed to the
    new instance unchanged, by reference.

    Args:
        input_args: The input arguments dataclass instance.
        device: The target device (e.g., 'cuda:0', 'cpu').

    Returns:
        A new instance of the same dataclass built from the transferred
        values. Fields declared with ``init=False`` are not forwarded to the
        constructor; they take whatever value the dataclass assigns during
        construction (default or ``__post_init__``).

    Raises:
        TypeError: If ``input_args`` is not a dataclass.
    """
    init_kwargs = {}
    for spec in fields(input_args):
        if not spec.init:
            # The generated __init__ has no parameter for init=False fields;
            # forwarding one would raise TypeError.
            continue
        value = getattr(input_args, spec.name)
        init_kwargs[spec.name] = value.to(device) if isinstance(value, torch.Tensor) else value
    # Rebuild through the concrete class so subclasses of the base input-args
    # type round-trip to their own type.
    return type(input_args)(**init_kwargs)