From 76d62837a8cd1a7fba989180e5310cc3209a0a93 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Thu, 13 Nov 2025 14:35:50 +0100 Subject: [PATCH 1/3] Add checkpoint/resume functionality to EMD solver Implements pause and resume capabilities for the EMD (Earth Mover's Distance) solver, allowing long-running optimizations to be interrupted and continued from their exact state. Changes: - Modified EMD_wrap() signature to accept checkpoint parameters for saving and restoring complete internal solver state (flow, potentials, tree structure) - Added saveCheckpoint(), restoreCheckpoint(), and runFromCheckpoint() methods to NetworkSimplexSimple class in network_simplex_simple.h - Extended emd_c() Cython wrapper to handle checkpoint dictionaries with 12 fields (10 arrays + 2 scalar arc counts) - Added 'checkpoint' and 'return_checkpoint' parameters to emd() Python API - Includes search_arc_num and all_arc_num scalars in checkpoint to preserve initialization state required by start() method Tests: - Added test_emd_checkpoint() for basic save/resume functionality - Added test_emd_checkpoint_multiple() for multiple pause/resume cycles - Added test_emd_checkpoint_structure() to verify checkpoint field integrity --- ot/lp/EMD.h | 11 ++- ot/lp/EMD_wrapper.cpp | 46 +++++++-- ot/lp/_network_simplex.py | 68 ++++++++++++- ot/lp/emd_wrap.pyx | 169 ++++++++++++++++++++++++++++----- ot/lp/network_simplex_simple.h | 110 +++++++++++++++++++++ test/test_ot.py | 88 +++++++++++++++++ 6 files changed, 458 insertions(+), 34 deletions(-) diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index b56f0601b..0bd655236 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -29,7 +29,16 @@ enum ProblemType { MAX_ITER_REACHED }; -int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter); +int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, uint64_t maxIter, + int resume_mode=0, int return_checkpoint=0, + double* flow_state=nullptr, double* pi_state=nullptr, + signed char* state_state=nullptr, int* parent_state=nullptr, + int64_t* pred_state=nullptr, int* thread_state=nullptr, + int* rev_thread_state=nullptr, int* succ_num_state=nullptr, + int* last_succ_state=nullptr, signed char* forward_state=nullptr, + int64_t* search_arc_num_out=nullptr, int64_t* all_arc_num_out=nullptr); + int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads); diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 4aa5a6e72..e54dd0776 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -20,7 +20,14 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, - double* alpha, double* beta, double *cost, uint64_t maxIter) { + double* alpha, double* beta, double *cost, uint64_t maxIter, + int resume_mode, int return_checkpoint, + double* flow_state, double* pi_state, signed char* state_state, + int* parent_state, int64_t* pred_state, + int* thread_state, int* rev_thread_state, + int* succ_num_state, int* last_succ_state, + signed char* forward_state, + int64_t* search_arc_num_out, int64_t* all_arc_num_out) { // beware M and C are stored in row major C style!!! using namespace lemon; @@ -93,8 +100,29 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, // Solve the problem with the network simplex algorithm - - int ret=net.run(); + // If resume_mode=1 and checkpoint data provided, resume from checkpoint + // Otherwise do normal run + + int64_t search_arc_num_in = 0, all_arc_num_in = 0; + if (resume_mode == 1 && search_arc_num_out != nullptr && all_arc_num_out != nullptr) { + search_arc_num_in = *search_arc_num_out; + all_arc_num_in = *all_arc_num_out; + } + + int ret; + if (resume_mode == 1 && flow_state != nullptr) { + // Resume from checkpoint + ret = net.runFromCheckpoint( + flow_state, pi_state, state_state, + parent_state, pred_state, + thread_state, rev_thread_state, + succ_num_state, last_succ_state, forward_state, + search_arc_num_in, all_arc_num_in); + } else { + // Normal run + ret = net.run(); + } + uint64_t i, j; if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { *cost = 0; @@ -111,6 +139,15 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, } + // Save checkpoint if requested and arrays provided + if (return_checkpoint == 1 && flow_state != nullptr) { + net.saveCheckpoint( + flow_state, pi_state, state_state, + parent_state, pred_state, + thread_state, rev_thread_state, + succ_num_state, last_succ_state, forward_state, + search_arc_num_out, all_arc_num_out); + } return ret; } @@ -118,9 +155,6 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, - - - int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) { // beware M and C are stored in row major C style!!! diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index 492e4c7ac..effc03486 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -172,6 +172,8 @@ def emd( center_dual=True, numThreads=1, check_marginals=True, + checkpoint=None, + return_checkpoint=False, ): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -232,6 +234,15 @@ def emd( check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. + checkpoint: dict, optional (default=None) + Checkpoint data from a previous emd() call to resume computation. + The checkpoint must contain internal solver state including flow, + potentials, and tree structure. Obtain by calling emd() with + return_checkpoint=True. + return_checkpoint: bool, optional (default=False) + If True and log=True, includes complete internal solver state in the + returned log dictionary for checkpointing. This enables pausing and + resuming the optimization. Returns @@ -241,7 +252,8 @@ def emd( parameters log: dict, optional If input log is true, a dictionary containing the - cost and dual variables and exit status + cost and dual variables and exit status. If return_checkpoint=True, + also contains internal solver state for resuming computation. Examples @@ -321,7 +333,43 @@ def emd( numThreads = check_number_threads(numThreads) - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + checkpoint_data = None + if checkpoint is not None: + # Extract checkpoint arrays and convert to numpy (strip leading underscore) + checkpoint_data = { + "flow": nx.to_numpy(checkpoint["_flow"]) if "_flow" in checkpoint else None, + "pi": nx.to_numpy(checkpoint["_pi"]) if "_pi" in checkpoint else None, + "state": nx.to_numpy(checkpoint["_state"]) + if "_state" in checkpoint + else None, + "parent": nx.to_numpy(checkpoint["_parent"]) + if "_parent" in checkpoint + else None, + "pred": nx.to_numpy(checkpoint["_pred"]) if "_pred" in checkpoint else None, + "thread": nx.to_numpy(checkpoint["_thread"]) + if "_thread" in checkpoint + else None, + "rev_thread": nx.to_numpy(checkpoint["_rev_thread"]) + if "_rev_thread" in checkpoint + else None, + "succ_num": nx.to_numpy(checkpoint["_succ_num"]) + if "_succ_num" in checkpoint + else None, + "last_succ": nx.to_numpy(checkpoint["_last_succ"]) + if "_last_succ" in checkpoint + else None, + "forward": nx.to_numpy(checkpoint["_forward"]) + if "_forward" in checkpoint + else None, + "search_arc_num": int(checkpoint.get("search_arc_num", 0)), + "all_arc_num": int(checkpoint.get("all_arc_num", 0)), + } + # Filter out None values + checkpoint_data = {k: v for k, v in checkpoint_data.items() if v is not None} + + G, cost, u, v, result_code, checkpoint_out = emd_c( + a, b, M, numItermax, numThreads, checkpoint_data, int(return_checkpoint) + ) if center_dual: u, v = center_ot_dual(u, v, a, b) @@ -345,6 +393,22 @@ def emd( log["v"] = nx.from_numpy(v, type_as=type_as) log["warning"] = result_code_string log["result_code"] = result_code + + # Add checkpoint data if requested (preserve original dtypes, don't cast) + if return_checkpoint and checkpoint_out is not None: + log["_flow"] = checkpoint_out["flow"] + log["_pi"] = checkpoint_out["pi"] + log["_state"] = checkpoint_out["state"] + log["_parent"] = checkpoint_out["parent"] + log["_pred"] = checkpoint_out["pred"] + log["_thread"] = checkpoint_out["thread"] + log["_rev_thread"] = checkpoint_out["rev_thread"] + log["_succ_num"] = checkpoint_out["succ_num"] + log["_last_succ"] = checkpoint_out["last_succ"] + log["_forward"] = checkpoint_out["forward"] + log["search_arc_num"] = int(checkpoint_out["search_arc_num"]) + log["all_arc_num"] = int(checkpoint_out["all_arc_num"]) + return nx.from_numpy(G, type_as=type_as), log return nx.from_numpy(G, type_as=type_as) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 53df54fc3..c99cdc011 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -14,13 +14,21 @@ from ..utils import dist cimport cython cimport libc.math as math -from libc.stdint cimport uint64_t +from libc.stdint cimport uint64_t, int64_t import warnings cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil + int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, uint64_t maxIter, + int resume_mode, int return_checkpoint, + double* flow_state, double* pi_state, signed char* state_state, + int* parent_state, int64_t* pred_state, + int* thread_state, int* rev_thread_state, + int* succ_num_state, int* last_succ_state, + signed char* forward_state, + int64_t* search_arc_num_out, int64_t* all_arc_num_out) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -40,9 +48,16 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, + np.ndarray[double, ndim=1, mode="c"] b, + np.ndarray[double, ndim=2, mode="c"] M, + uint64_t max_iter, + int numThreads, + checkpoint_in=None, + int return_checkpoint=0): """ Solves the Earth Movers distance problem and returns the optimal transport matrix + with optional checkpoint support for pause/resume. gamm=emd(a,b,M) @@ -79,43 +94,147 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod max_iter : uint64_t The maximum number of iterations before stopping the optimization algorithm if it has not converged. + numThreads : int + Number of threads for parallel computation (1 = no OpenMP) + checkpoint_in : dict or None + Checkpoint data to resume from. Should contain flow, pi, state, parent, + pred, thread, rev_thread, succ_num, last_succ, forward arrays. + return_checkpoint : int + If 1, returns checkpoint data; if 0, returns None for checkpoint. Returns ------- gamma: (ns x nt) numpy.ndarray Optimal transportation matrix for the given parameters + cost : float + Optimal transport cost + alpha : (ns,) numpy.ndarray + Source dual potentials + beta : (nt,) numpy.ndarray + Target dual potentials + result_code : int + Result code (OPTIMAL, INFEASIBLE, UNBOUNDED, MAX_ITER_REACHED) + checkpoint_out : dict or None + Checkpoint data if return_checkpoint=1, None otherwise """ - cdef int n1= M.shape[0] - cdef int n2= M.shape[1] - cdef int nmax=n1+n2-1 + cdef int n1 = M.shape[0] + cdef int n2 = M.shape[1] + cdef int all_nodes = n1 + n2 + 1 + cdef int64_t max_arcs = n1 * n2 + 2 * (n1 + n2) cdef int result_code = 0 - cdef int nG=0 - - cdef double cost=0 - cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) - cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) - - cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0]) - - cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0) + cdef double cost = 0 + cdef int64_t search_arc_num = 0 + cdef int64_t all_arc_num = 0 + + cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2) + cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros([n1, n2]) + + # Checkpoint arrays (for both input and output) + cdef np.ndarray[double, ndim=1, mode="c"] flow_state + cdef np.ndarray[double, ndim=1, mode="c"] pi_state + cdef np.ndarray[signed char, ndim=1, mode="c"] state_state + cdef np.ndarray[int, ndim=1, mode="c"] parent_state + cdef np.ndarray[int64_t, ndim=1, mode="c"] pred_state + cdef np.ndarray[int, ndim=1, mode="c"] thread_state + cdef np.ndarray[int, ndim=1, mode="c"] rev_thread_state + cdef np.ndarray[int, ndim=1, mode="c"] succ_num_state + cdef np.ndarray[int, ndim=1, mode="c"] last_succ_state + cdef np.ndarray[signed char, ndim=1, mode="c"] forward_state + + cdef int resume_mode = 0 if not len(a): - a=np.ones((n1,))/n1 + a = np.ones((n1,)) / n1 if not len(b): - b=np.ones((n2,))/n2 - - # init OT matrix - G=np.zeros([n1, n2]) - - # calling the function + b = np.ones((n2,)) / n2 + + # Prepare checkpoint arrays + if checkpoint_in is not None: + resume_mode = 1 + flow_state = np.asarray(checkpoint_in['flow'], dtype=np.float64, order='C') + pi_state = np.asarray(checkpoint_in['pi'], dtype=np.float64, order='C') + state_state = np.asarray(checkpoint_in['state'], dtype=np.int8, order='C') + parent_state = np.asarray(checkpoint_in['parent'], dtype=np.int32, order='C') + pred_state = np.asarray(checkpoint_in['pred'], dtype=np.int64, order='C') + thread_state = np.asarray(checkpoint_in['thread'], dtype=np.int32, order='C') + rev_thread_state = np.asarray(checkpoint_in['rev_thread'], dtype=np.int32, order='C') + + # Sanity check: array sizes must match expected sizes + if flow_state.shape[0] != max_arcs or pi_state.shape[0] != all_nodes: + raise ValueError( + f"Checkpoint size mismatch: expected flow={max_arcs}, pi={all_nodes}, " + f"got flow={flow_state.shape[0]}, pi={pi_state.shape[0]}" + ) + succ_num_state = np.asarray(checkpoint_in['succ_num'], dtype=np.int32, order='C') + last_succ_state = np.asarray(checkpoint_in['last_succ'], dtype=np.int32, order='C') + forward_state = np.asarray(checkpoint_in['forward'], dtype=np.int8, order='C') + + # Extract the arc counts + search_arc_num = checkpoint_in['search_arc_num'] + all_arc_num = checkpoint_in['all_arc_num'] + else: + # Allocate empty arrays (will be filled if return_checkpoint=1) + flow_state = np.zeros(max_arcs, dtype=np.float64) + pi_state = np.zeros(all_nodes, dtype=np.float64) + state_state = np.zeros(max_arcs, dtype=np.int8) + parent_state = np.zeros(all_nodes, dtype=np.int32) + pred_state = np.zeros(all_nodes, dtype=np.int64) + thread_state = np.zeros(all_nodes, dtype=np.int32) + rev_thread_state = np.zeros(all_nodes, dtype=np.int32) + succ_num_state = np.zeros(all_nodes, dtype=np.int32) + last_succ_state = np.zeros(all_nodes, dtype=np.int32) + forward_state = np.zeros(all_nodes, dtype=np.int8) + + # Call C++ function with checkpoint support with nogil: if numThreads == 1: - result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) + result_code = EMD_wrap( + n1, n2, + a.data, b.data, M.data, + G.data, alpha.data, beta.data, + &cost, max_iter, + resume_mode, return_checkpoint, + flow_state.data, + pi_state.data, + state_state.data, + parent_state.data, + pred_state.data, + thread_state.data, + rev_thread_state.data, + succ_num_state.data, + last_succ_state.data, + forward_state.data, + &search_arc_num, + &all_arc_num + ) else: - result_code = EMD_wrap_omp(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter, numThreads) - return G, cost, alpha, beta, result_code + # For now, OpenMP version falls back to regular (not implemented yet) + result_code = EMD_wrap_omp(n1, n2, a.data, b.data, M.data, + G.data, alpha.data, beta.data, + &cost, max_iter, numThreads) + + # Build checkpoint output dict if requested + checkpoint_out = None + if return_checkpoint: + checkpoint_out = { + 'flow': flow_state, + 'pi': pi_state, + 'state': state_state, + 'parent': parent_state, + 'pred': pred_state, + 'thread': thread_state, + 'rev_thread': rev_thread_state, + 'succ_num': succ_num_state, + 'last_succ': last_succ_state, + 'forward': forward_state, + 'search_arc_num': search_arc_num, + 'all_arc_num': all_arc_num, + } + + return G, cost, alpha, beta, result_code, checkpoint_out @cython.boundscheck(False) diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 9612a8a24..388851f88 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -941,6 +941,116 @@ namespace lemon { } } + + /// This function saves the complete internal state of the solver, + /// including flow values, dual potentials, arc states, and the + /// spanning tree structure. This allows pausing and resuming + /// the optimization later. + + void saveCheckpoint( + double* flow_out, + double* pi_out, + signed char* state_out, + int* parent_out, + ArcsType* pred_out, + int* thread_out, + int* rev_thread_out, + int* succ_num_out, + int* last_succ_out, + signed char* forward_out, + ArcsType* search_arc_num_out, + ArcsType* all_arc_num_out) + { + // Copy internal state to output arrays + std::copy(_flow.begin(), _flow.end(), flow_out); + std::copy(_pi.begin(), _pi.end(), pi_out); + std::copy(_state.begin(), _state.end(), state_out); + std::copy(_parent.begin(), _parent.end(), parent_out); + std::copy(_pred.begin(), _pred.end(), pred_out); + std::copy(_thread.begin(), _thread.end(), thread_out); + std::copy(_rev_thread.begin(), _rev_thread.end(), rev_thread_out); + std::copy(_succ_num.begin(), _succ_num.end(), succ_num_out); + std::copy(_last_succ.begin(), _last_succ.end(), last_succ_out); + + // Convert bool vector to signed char + for (size_t i = 0; i < _forward.size(); i++) { + forward_out[i] = _forward[i] ? 1 : 0; + } + + // Save arc counts needed for start() + *search_arc_num_out = _search_arc_num; + *all_arc_num_out = _all_arc_num; + } + + + /// This function restores the complete internal state of the solver + /// from a previously saved checkpoint. + + void restoreCheckpoint( + double* flow_in, + double* pi_in, + signed char* state_in, + int* parent_in, + ArcsType* pred_in, + int* thread_in, + int* rev_thread_in, + int* succ_num_in, + int* last_succ_in, + signed char* forward_in, + ArcsType search_arc_num_in, + ArcsType all_arc_num_in) + { + // Copy from input arrays to internal state + std::copy(flow_in, flow_in + _flow.size(), _flow.begin()); + std::copy(pi_in, pi_in + _pi.size(), _pi.begin()); + std::copy(state_in, state_in + _state.size(), _state.begin()); + std::copy(parent_in, parent_in + _parent.size(), _parent.begin()); + std::copy(pred_in, pred_in + _pred.size(), _pred.begin()); + std::copy(thread_in, thread_in + _thread.size(), _thread.begin()); + std::copy(rev_thread_in, rev_thread_in + _rev_thread.size(), _rev_thread.begin()); + std::copy(succ_num_in, succ_num_in + _succ_num.size(), _succ_num.begin()); + std::copy(last_succ_in, last_succ_in + _last_succ.size(), _last_succ.begin()); + + // Convert signed char to bool vector + for (size_t i = 0; i < _forward.size(); i++) { + _forward[i] = (forward_in[i] != 0); + } + + // Restore root (it's always _node_num) + _root = _node_num; + + // Restore arc counts needed by start() + _search_arc_num = search_arc_num_in; + _all_arc_num = all_arc_num_in; + } + + + /// This function restores the solver state from a checkpoint and + /// continues the optimization from that point. It skips the normal + /// initialization phase and goes directly to the simplex iterations. + + ProblemType runFromCheckpoint( + double* flow_in, + double* pi_in, + signed char* state_in, + int* parent_in, + ArcsType* pred_in, + int* thread_in, + int* rev_thread_in, + int* succ_num_in, + int* last_succ_in, + signed char* forward_in, + ArcsType search_arc_num_in, + ArcsType all_arc_num_in) + { + // Restore state from checkpoint + restoreCheckpoint(flow_in, pi_in, state_in, parent_in, pred_in, + thread_in, rev_thread_in, succ_num_in, last_succ_in, forward_in, + search_arc_num_in, all_arc_num_in); + + return start(); + } + /// @} private: diff --git a/test/test_ot.py b/test/test_ot.py index e8217d54d..b32f97ec4 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -914,6 +914,94 @@ def test_dual_variables(): assert constraint_violation.max() < 1e-8 +def test_emd_checkpoint(): + # test checkpoint save and resume + n = 50 + rng = np.random.RandomState(42) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + M = rng.rand(n, n) + + G_ref, log_ref = ot.emd(a, b, M, numItermax=10000, log=True) + + G1, log1 = ot.emd(a, b, M, numItermax=500, log=True, return_checkpoint=True) + + if log1["result_code"] == 3: # MAX_ITER_REACHED ? + G2, log2 = ot.emd( + a, b, M, numItermax=10000, log=True, checkpoint=log1, return_checkpoint=True + ) + + np.testing.assert_allclose(log2["cost"], log_ref["cost"], rtol=1e-6) + np.testing.assert_allclose(G2, G_ref, rtol=1e-6) + + +def test_emd_checkpoint_multiple(): + # test multiple checkpoint cycles + n = 100 + rng = np.random.RandomState(123) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + M = rng.rand(n, n) + + G_ref, log_ref = ot.emd(a, b, M, numItermax=50000, log=True) + + # multiple checkpoint phases with increasing iteration budgets + max_iters = [100, 300, 600, 1000] + checkpoint = None + costs = [] + + for max_iter in max_iters: + G, log = ot.emd( + a, + b, + M, + numItermax=max_iter, + log=True, + checkpoint=checkpoint, + return_checkpoint=True, + ) + costs.append(log["cost"]) + + if log["result_code"] != 3: # converged + break + checkpoint = log + + # check cost decreases monotonically + for i in range(len(costs) - 1): + assert costs[i + 1] <= costs[i] + + # check final result matches reference + np.testing.assert_allclose(log["cost"], log_ref["cost"], rtol=1e-5) + + +def test_emd_checkpoint_structure(): + # test that checkpoint contains all required fields + n = 10 + a = ot.utils.unif(n) + b = ot.utils.unif(n) + M = np.random.rand(n, n) + + G, log = ot.emd(a, b, M, numItermax=10, log=True, return_checkpoint=True) + + required_fields = [ + "_flow", + "_pi", + "_state", + "_parent", + "_pred", + "_thread", + "_rev_thread", + "_succ_num", + "_last_succ", + "_forward", + "_search_arc_num", + "_all_arc_num", + ] + + for field in required_fields: + assert field in log, f"Missing checkpoint field: {field}" + + def check_duality_gap(a, b, M, G, u, v, cost): cost_dual = np.vdot(a, u) + np.vdot(b, v) # Check that dual and primal cost are equal From 31a382d92e35f4d21e46bc02a064f231c2d87278 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Mon, 17 Nov 2025 10:12:02 +0100 Subject: [PATCH 2/3] Add checkpoint/resume functionality to EMD solver : added parameter declaration to only have one now --- ot/lp/_network_simplex.py | 95 +++++++++++++++++++++++++-------------- test/test_ot.py | 44 ++++++++++-------- 2 files changed, 88 insertions(+), 51 deletions(-) diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index effc03486..6e93f2279 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -172,8 +172,7 @@ def emd( center_dual=True, numThreads=1, check_marginals=True, - checkpoint=None, - return_checkpoint=False, + warm_start=False, ): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -234,15 +233,10 @@ def emd( check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. - checkpoint: dict, optional (default=None) - Checkpoint data from a previous emd() call to resume computation. - The checkpoint must contain internal solver state including flow, - potentials, and tree structure. Obtain by calling emd() with - return_checkpoint=True. - return_checkpoint: bool, optional (default=False) - If True and log=True, includes complete internal solver state in the - returned log dictionary for checkpointing. This enables pausing and - resuming the optimization. + warm_start: bool or dict, optional (default=False) + If True, returns warm start data in the log for resuming computation. + If dict (from previous call with warm_start=True), resumes optimization + from the provided state. Requires log=True when saving state. Returns @@ -252,7 +246,7 @@ def emd( parameters log: dict, optional If input log is true, a dictionary containing the - cost and dual variables and exit status. If return_checkpoint=True, + cost and dual variables and exit status. If warm_start=True, also contains internal solver state for resuming computation. @@ -270,6 +264,13 @@ def emd( array([[0.5, 0. ], [0. , 0.5]]) + Warm start example for resuming optimization: + + >>> # First call - save warm start data + >>> G, log = ot.emd(a, b, M, numItermax=100, log=True, warm_start=True) + >>> # Resume from warm start + >>> G, log = ot.emd(a, b, M, numItermax=1000, log=True, warm_start=log) + .. _references-emd: References @@ -333,39 +334,67 @@ def emd( numThreads = check_number_threads(numThreads) + # Handle warm_start parameter checkpoint_data = None - if checkpoint is not None: - # Extract checkpoint arrays and convert to numpy (strip leading underscore) + return_checkpoint = False + + if isinstance(warm_start, dict): + # Resume from previous warm_start dict checkpoint_data = { - "flow": nx.to_numpy(checkpoint["_flow"]) if "_flow" in checkpoint else None, - "pi": nx.to_numpy(checkpoint["_pi"]) if "_pi" in checkpoint else None, - "state": nx.to_numpy(checkpoint["_state"]) - if "_state" in checkpoint + "flow": nx.to_numpy(warm_start.get("_flow", warm_start.get("flow"))) + if ("_flow" in warm_start or "flow" in warm_start) else None, - "parent": nx.to_numpy(checkpoint["_parent"]) - if "_parent" in checkpoint + "pi": nx.to_numpy(warm_start.get("_pi", warm_start.get("pi"))) + if ("_pi" in warm_start or "pi" in warm_start) else None, - "pred": nx.to_numpy(checkpoint["_pred"]) if "_pred" in checkpoint else None, - "thread": nx.to_numpy(checkpoint["_thread"]) - if "_thread" in checkpoint + "state": nx.to_numpy(warm_start.get("_state", warm_start.get("state"))) + if ("_state" in warm_start or "state" in warm_start) else None, - "rev_thread": nx.to_numpy(checkpoint["_rev_thread"]) - if "_rev_thread" in checkpoint + "parent": nx.to_numpy(warm_start.get("_parent", warm_start.get("parent"))) + if ("_parent" in warm_start or "parent" in warm_start) else None, - "succ_num": nx.to_numpy(checkpoint["_succ_num"]) - if "_succ_num" in checkpoint + "pred": nx.to_numpy(warm_start.get("_pred", warm_start.get("pred"))) + if ("_pred" in warm_start or "pred" in warm_start) else None, - "last_succ": nx.to_numpy(checkpoint["_last_succ"]) - if "_last_succ" in checkpoint + "thread": nx.to_numpy(warm_start.get("_thread", warm_start.get("thread"))) + if ("_thread" in warm_start or "thread" in warm_start) else None, - "forward": nx.to_numpy(checkpoint["_forward"]) - if "_forward" in checkpoint + "rev_thread": nx.to_numpy( + warm_start.get("_rev_thread", warm_start.get("rev_thread")) + ) + if ("_rev_thread" in warm_start or "rev_thread" in warm_start) + else None, + "succ_num": nx.to_numpy( + warm_start.get("_succ_num", warm_start.get("succ_num")) + ) + if ("_succ_num" in warm_start or "succ_num" in warm_start) else None, - "search_arc_num": int(checkpoint.get("search_arc_num", 0)), - "all_arc_num": int(checkpoint.get("all_arc_num", 0)), + "last_succ": nx.to_numpy( + warm_start.get("_last_succ", warm_start.get("last_succ")) + ) + if ("_last_succ" in warm_start or "last_succ" in warm_start) + else None, + "forward": nx.to_numpy( + warm_start.get("_forward", warm_start.get("forward")) + ) + if ("_forward" in warm_start or "forward" in warm_start) + else None, + "search_arc_num": int( + warm_start.get("search_arc_num", warm_start.get("_search_arc_num", 0)) + ), + "all_arc_num": int( + warm_start.get("all_arc_num", warm_start.get("_all_arc_num", 0)) + ), } # Filter out None values checkpoint_data = {k: v for k, v in checkpoint_data.items() if v is not None} + elif warm_start is True: + # Save warm_start data - requires log=True + if not log: + raise ValueError( + "warm_start=True requires log=True to return the warm start data" + ) + return_checkpoint = True G, cost, u, v, result_code, checkpoint_out = emd_c( a, b, M, numItermax, numThreads, checkpoint_data, int(return_checkpoint) diff --git a/test/test_ot.py b/test/test_ot.py index b32f97ec4..1a37dffbf 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -924,12 +924,10 @@ def test_emd_checkpoint(): G_ref, log_ref = ot.emd(a, b, M, numItermax=10000, log=True) - G1, log1 = ot.emd(a, b, M, numItermax=500, log=True, return_checkpoint=True) + G1, log1 = ot.emd(a, b, M, numItermax=500, log=True, warm_start=True) if log1["result_code"] == 3: # MAX_ITER_REACHED ? - G2, log2 = ot.emd( - a, b, M, numItermax=10000, log=True, checkpoint=log1, return_checkpoint=True - ) + G2, log2 = ot.emd(a, b, M, numItermax=10000, log=True, warm_start=log1) np.testing.assert_allclose(log2["cost"], log_ref["cost"], rtol=1e-6) np.testing.assert_allclose(G2, G_ref, rtol=1e-6) @@ -947,24 +945,34 @@ def test_emd_checkpoint_multiple(): # multiple checkpoint phases with increasing iteration budgets max_iters = [100, 300, 600, 1000] - checkpoint = None + warm_start_data = None costs = [] for max_iter in max_iters: - G, log = ot.emd( - a, - b, - M, - numItermax=max_iter, - log=True, - checkpoint=checkpoint, - return_checkpoint=True, - ) + if warm_start_data is None: + G, log = ot.emd( + a, + b, + M, + numItermax=max_iter, + log=True, + warm_start=True, + ) + else: + G, log = ot.emd( + a, + b, + M, + numItermax=max_iter, + log=True, + warm_start=warm_start_data, + ) costs.append(log["cost"]) if log["result_code"] != 3: # converged break - checkpoint = log + # Only use warm_start if checkpoint fields are present + warm_start_data = log if "_flow" in log else None # check cost decreases monotonically for i in range(len(costs) - 1): @@ -981,7 +989,7 @@ def test_emd_checkpoint_structure(): b = ot.utils.unif(n) M = np.random.rand(n, n) - G, log = ot.emd(a, b, M, numItermax=10, log=True, return_checkpoint=True) + G, log = ot.emd(a, b, M, numItermax=10, log=True, warm_start=True) required_fields = [ "_flow", @@ -994,8 +1002,8 @@ def test_emd_checkpoint_structure(): "_succ_num", "_last_succ", "_forward", - "_search_arc_num", - "_all_arc_num", + "search_arc_num", # scalars don't have underscore prefix + "all_arc_num", ] for field in required_fields: From b6736b54de15cbbd20f18fe00cb795e19e3b23fc Mon Sep 17 00:00:00 2001 From: nathanneike Date: Mon, 17 Nov 2025 10:22:11 +0100 Subject: [PATCH 3/3] One field for checkpoint, cleaner interface --- ot/lp/_network_simplex.py | 90 +++++++++++++++++++-------------------- test/test_ot.py | 33 ++++++++------ 2 files changed, 64 insertions(+), 59 deletions(-) diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index 6e93f2279..9ee21bbcc 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -247,7 +247,8 @@ def emd( log: dict, optional If input log is true, a dictionary containing the cost and dual variables and exit status. If warm_start=True, - also contains internal solver state for resuming computation. + also contains a "checkpoint" key with the internal solver state + for resuming computation. Examples @@ -268,6 +269,7 @@ def emd( >>> # First call - save warm start data >>> G, log = ot.emd(a, b, M, numItermax=100, log=True, warm_start=True) + >>> # log["checkpoint"] contains the solver state >>> # Resume from warm start >>> G, log = ot.emd(a, b, M, numItermax=1000, log=True, warm_start=log) @@ -340,51 +342,47 @@ def emd( if isinstance(warm_start, dict): # Resume from previous warm_start dict + # Check if checkpoint is nested under "checkpoint" key or at top level + if "checkpoint" in warm_start: + chkpt = warm_start["checkpoint"] + else: + chkpt = warm_start + checkpoint_data = { - "flow": nx.to_numpy(warm_start.get("_flow", warm_start.get("flow"))) - if ("_flow" in warm_start or "flow" in warm_start) + "flow": nx.to_numpy(chkpt.get("flow", chkpt.get("_flow"))) + if ("flow" in chkpt or "_flow" in chkpt) else None, - "pi": nx.to_numpy(warm_start.get("_pi", warm_start.get("pi"))) - if ("_pi" in warm_start or "pi" in warm_start) + "pi": nx.to_numpy(chkpt.get("pi", chkpt.get("_pi"))) + if ("pi" in chkpt or "_pi" in chkpt) else None, - "state": nx.to_numpy(warm_start.get("_state", warm_start.get("state"))) - if ("_state" in warm_start or "state" in warm_start) + "state": nx.to_numpy(chkpt.get("state", chkpt.get("_state"))) + if ("state" in chkpt or "_state" in chkpt) else None, - "parent": nx.to_numpy(warm_start.get("_parent", warm_start.get("parent"))) - if ("_parent" in warm_start or "parent" in warm_start) + "parent": nx.to_numpy(chkpt.get("parent", chkpt.get("_parent"))) + if ("parent" in chkpt or "_parent" in chkpt) else None, - "pred": nx.to_numpy(warm_start.get("_pred", warm_start.get("pred"))) - if ("_pred" in warm_start or "pred" in warm_start) + "pred": nx.to_numpy(chkpt.get("pred", chkpt.get("_pred"))) + if ("pred" in chkpt or "_pred" in chkpt) else None, - "thread": nx.to_numpy(warm_start.get("_thread", warm_start.get("thread"))) - if ("_thread" in warm_start or "thread" in warm_start) + "thread": nx.to_numpy(chkpt.get("thread", chkpt.get("_thread"))) + if ("thread" in chkpt or "_thread" in chkpt) else None, - "rev_thread": nx.to_numpy( - warm_start.get("_rev_thread", warm_start.get("rev_thread")) - ) - if ("_rev_thread" in warm_start or "rev_thread" in warm_start) + "rev_thread": nx.to_numpy(chkpt.get("rev_thread", chkpt.get("_rev_thread"))) + if ("rev_thread" in chkpt or "_rev_thread" in chkpt) else None, - "succ_num": nx.to_numpy( - warm_start.get("_succ_num", warm_start.get("succ_num")) - ) - if ("_succ_num" in warm_start or "succ_num" in warm_start) + "succ_num": nx.to_numpy(chkpt.get("succ_num", chkpt.get("_succ_num"))) + if ("succ_num" in chkpt or "_succ_num" in chkpt) else None, - "last_succ": nx.to_numpy( - warm_start.get("_last_succ", warm_start.get("last_succ")) - ) - if ("_last_succ" in warm_start or "last_succ" in warm_start) + "last_succ": nx.to_numpy(chkpt.get("last_succ", chkpt.get("_last_succ"))) + if ("last_succ" in chkpt or "_last_succ" in chkpt) else None, - "forward": nx.to_numpy( - warm_start.get("_forward", warm_start.get("forward")) - ) - if ("_forward" in warm_start or "forward" in warm_start) + "forward": nx.to_numpy(chkpt.get("forward", chkpt.get("_forward"))) + if ("forward" in chkpt or "_forward" in chkpt) else None, "search_arc_num": int( - warm_start.get("search_arc_num", warm_start.get("_search_arc_num", 0)) - ), - "all_arc_num": int( - warm_start.get("all_arc_num", warm_start.get("_all_arc_num", 0)) + chkpt.get("search_arc_num", chkpt.get("_search_arc_num", 0)) ), + "all_arc_num": int(chkpt.get("all_arc_num", chkpt.get("_all_arc_num", 0))), } # Filter out None values checkpoint_data = {k: v for k, v in checkpoint_data.items() if v is not None} @@ -425,18 +423,20 @@ def emd( # Add checkpoint data if requested (preserve original dtypes, don't cast) if return_checkpoint and checkpoint_out is not None: - log["_flow"] = checkpoint_out["flow"] - log["_pi"] = checkpoint_out["pi"] - log["_state"] = checkpoint_out["state"] - log["_parent"] = checkpoint_out["parent"] - log["_pred"] = checkpoint_out["pred"] - log["_thread"] = checkpoint_out["thread"] - log["_rev_thread"] = checkpoint_out["rev_thread"] - log["_succ_num"] = checkpoint_out["succ_num"] - log["_last_succ"] = checkpoint_out["last_succ"] - log["_forward"] = checkpoint_out["forward"] - log["search_arc_num"] = int(checkpoint_out["search_arc_num"]) - log["all_arc_num"] = int(checkpoint_out["all_arc_num"]) + log["checkpoint"] = { + "flow": checkpoint_out["flow"], + "pi": checkpoint_out["pi"], + "state": checkpoint_out["state"], + "parent": checkpoint_out["parent"], + "pred": checkpoint_out["pred"], + "thread": checkpoint_out["thread"], + "rev_thread": checkpoint_out["rev_thread"], + "succ_num": checkpoint_out["succ_num"], + "last_succ": checkpoint_out["last_succ"], + "forward": checkpoint_out["forward"], + "search_arc_num": int(checkpoint_out["search_arc_num"]), + "all_arc_num": int(checkpoint_out["all_arc_num"]), + } return nx.from_numpy(G, type_as=type_as), log return nx.from_numpy(G, type_as=type_as) diff --git a/test/test_ot.py b/test/test_ot.py index 1a37dffbf..f5c7976d1 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -971,8 +971,8 @@ def test_emd_checkpoint_multiple(): if log["result_code"] != 3: # converged break - # Only use warm_start if checkpoint fields are present - warm_start_data = log if "_flow" in log else None + # Only use warm_start if checkpoint is present + warm_start_data = log if "checkpoint" in log else None # check cost decreases monotonically for i in range(len(costs) - 1): @@ -991,23 +991,28 @@ def test_emd_checkpoint_structure(): G, log = ot.emd(a, b, M, numItermax=10, log=True, warm_start=True) + # Check that checkpoint key exists + assert "checkpoint" in log, "Missing checkpoint key in log" + + checkpoint = log["checkpoint"] + required_fields = [ - "_flow", - "_pi", - "_state", - "_parent", - "_pred", - "_thread", - "_rev_thread", - "_succ_num", - "_last_succ", - "_forward", - "search_arc_num", # scalars don't have underscore prefix + "flow", + "pi", + "state", + "parent", + "pred", + "thread", + "rev_thread", + "succ_num", + "last_succ", + "forward", + "search_arc_num", "all_arc_num", ] for field in required_fields: - assert field in log, f"Missing checkpoint field: {field}" + assert field in checkpoint, f"Missing checkpoint field: {field}" def check_duality_gap(a, b, M, G, u, v, cost):