Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 1 addition & 28 deletions src/dualip/run_solver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import fields
from typing import Optional

import torch
Expand All @@ -11,36 +10,10 @@
from dualip.objectives.miplib import MIPLIB2017ObjectiveFunction
from dualip.optimizers.agd import AcceleratedGradientDescent, SolverResult
from dualip.types import ComputeArgs, ObjectiveArgs, SolverArgs
from dualip.utils.dist_utils import transfer_tensors_to_device
from dualip.utils.mlflow_utils import MLflowConfig, log_hyperparameters, mlflow_run_context


def transfer_tensors_to_device(input_args: BaseInputArgs, device: str):
"""
Transfer all tensor fields in input_args to the specified device.

Args:
input_args: The input arguments object
device: The target device (e.g., 'cuda:0', 'cpu')

Returns:
A new instance of input_args with all tensors transferred to device
"""
# Get all field names from the dataclass
field_names = [field.name for field in fields(input_args)]

# Create a dictionary of field values with tensors transferred to device
field_values = {}
for field_name in field_names:
value = getattr(input_args, field_name)
if isinstance(value, torch.Tensor):
field_values[field_name] = value.to(device)
else:
field_values[field_name] = value

# Create a new instance of the same class with transferred tensors
return type(input_args)(**field_values)


def build_objective(
input_args: BaseInputArgs, solver_args: SolverArgs, compute_args: ComputeArgs, objective_args: ObjectiveArgs
):
Expand Down
23 changes: 23 additions & 0 deletions src/dualip/utils/dist_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,34 @@
from collections import defaultdict
from dataclasses import fields

import torch

from dualip.objectives.base import BaseInputArgs
from dualip.projections.base import ProjectionEntry
from dualip.utils.sparse_utils import split_csc_by_cols


def transfer_tensors_to_device(input_args: BaseInputArgs, device: str):
"""
Transfer all tensor fields in input_args to the specified device.

Args:
input_args: The input arguments dataclass.
device: The target device (e.g., 'cuda:0', 'cpu').

Returns:
A new instance of the same dataclass with all tensors transferred to device.
"""
field_values = {}
for field in fields(input_args):
value = getattr(input_args, field.name)
if isinstance(value, torch.Tensor):
field_values[field.name] = value.to(device)
else:
field_values[field.name] = value
return type(input_args)(**field_values)


def global_to_local_projection_map(global_map: dict[str, ProjectionEntry], local_cols: list[int]) -> dict[str, dict]:
"""
Given a global projection_map and the list of global col indices for a split,
Expand Down
Loading