diff --git a/tensorcircuit/timeevol.py b/tensorcircuit/timeevol.py index b473f2a8..c480bcdc 100644 --- a/tensorcircuit/timeevol.py +++ b/tensorcircuit/timeevol.py @@ -2,8 +2,9 @@ Analog time evolution engines """ -from typing import Any, Tuple, Optional, Callable, List, Sequence +from typing import Any, Tuple, Optional, Callable, List, Sequence, Dict from functools import partial +import warnings import numpy as np @@ -427,37 +428,58 @@ def _evol(t: Tensor) -> Tensor: ed_evol = hamiltonian_evol -@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]}) -def evol_local( - c: Circuit, - index: Sequence[int], - h_fun: Callable[..., Tensor], - t: float, - *args: Any, - **solver_kws: Any, -) -> Circuit: - """ - ode evolution of time dependent Hamiltonian on circuit of given indices - [only jax backend support for now] +def _solve_ode( + f: Callable[..., Tensor], + s: Tensor, + times: Tensor, + args: Any, + solver_kws: Dict[str, Any], +) -> Tensor: + rtol = solver_kws.get("rtol", 1e-12) + atol = solver_kws.get("atol", 1e-12) + ode_backend = solver_kws.get("ode_backend", "jaxode") + max_steps = solver_kws.get("max_steps", 10000) - :param c: _description_ - :type c: Circuit - :param index: qubit sites to evolve - :type index: Sequence[int] - :param h_fun: h_fun should return a dense Hamiltonian matrix - with input arguments time and *args - :type h_fun: Callable[..., Tensor] - :param t: evolution time - :type t: float - :return: _description_ - :rtype: Circuit - """ - s = c.state() - n = int(np.log2(s.shape[-1]) + 1e-7) - if isinstance(t, float): - t = backend.stack([0.0, t]) - s1 = ode_evol_local(h_fun, s, t, index, None, *args, **solver_kws) - return type(c)(n, inputs=s1[-1]) + ts = backend.convert_to_tensor(times) + ts = backend.cast(ts, dtype=rdtypestr) + + if ode_backend == "jaxode": + from jax.experimental.ode import odeint + + s1 = odeint(f, s, ts, rtol=rtol, atol=atol, mxstep=max_steps, *args) + return s1 + + import diffrax + + # Ignore complex warning + warnings.simplefilter("ignore", category=UserWarning, append=True) + + solver = solver_kws.get("solver", "Tsit5") + dt0 = solver_kws.get("dt0", 0.01) + all_solvers = { + "Dopri5": diffrax.Dopri5, + "Tsit5": diffrax.Tsit5, + "Dopri8": diffrax.Dopri8, + "Kvaerno5": diffrax.Kvaerno5, + } + + # ODE + term = diffrax.ODETerm(lambda t, y, args: f(y, t, *args)) + + # solve ODE + s1 = diffrax.diffeqsolve( + terms=term, + solver=all_solvers[solver](), + t0=times[0], + t1=times[-1], + dt0=dt0, + y0=s, + saveat=diffrax.SaveAt(ts=times), + args=args, + stepsize_controller=diffrax.PIDController(rtol=rtol, atol=atol), + max_steps=max_steps, + ).ys + return s1 def ode_evol_local( @@ -475,6 +497,9 @@ def ode_evol_local( This function solves the time-dependent Schrodinger equation using numerical ODE integration. The Hamiltonian is applied only to a specific subset of qubits (indices) in the system. + The ode_backend parameter defaults to 'jaxode' (which uses `jax.experimental.ode.odeint` with a default solver + of 'Dopri5');if set to 'diffrax', it uses `diffrax.diffeqsolve` instead (with a default solver of 'Tsit5'). + Note: This function currently only supports the JAX backend. :param hamiltonian: A function that returns a dense Hamiltonian matrix for the specified @@ -490,13 +515,20 @@ def ode_evol_local( :type callback: Optional[Callable[..., Tensor]] :param args: Additional arguments to pass to the Hamiltonian function. :param solver_kws: Additional keyword arguments to pass to the ODE solver. + ode_backend='jaxode'(default) uses `jax.experimental.ode.odeint`; ode_backend='diffrax' + uses `diffrax.diffeqsolve`. + rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would + like the numerical approximation to your equation. + The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'} + and only works when ode_backend='diffrax'. + dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'. + max_steps (default: 10000) The maximum number of steps to take before quitting the computation + unconditionally and only works when ode_backend='diffrax'. :return: Evolved quantum states at the specified time points. If callback is provided, returns the callback results; otherwise returns the state vectors. :rtype: Tensor """ - from jax.experimental.ode import odeint - s = initial_state n = int(np.log2(backend.shape_tuple(initial_state)[-1]) + 1e-7) l = len(index) @@ -517,38 +549,11 @@ def f(y: Tensor, t: Tensor, *args: Any) -> Tensor: y = contractor([y, h], output_edge_order=edges) return backend.reshape(y.tensor, [-1]) - ts = backend.convert_to_tensor(times) - ts = backend.cast(ts, dtype=rdtypestr) - s1 = odeint(f, s, ts, *args, **solver_kws) - if not callback: - return s1 - return backend.stack([callback(s1[i]) for i in range(len(s1))]) + s1 = _solve_ode(f, initial_state, times, args, solver_kws) - -@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]}) -def evol_global( - c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any -) -> Circuit: - """ - ode evolution of time dependent Hamiltonian on circuit of all qubits - [only jax backend support for now] - - :param c: _description_ - :type c: Circuit - :param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix - with input arguments time and *args - :type h_fun: Callable[..., Tensor] - :param t: _description_ - :type t: float - :return: _description_ - :rtype: Circuit - """ - s = c.state() - n = c._nqubits - if isinstance(t, float): - t = backend.stack([0.0, t]) - s1 = ode_evol_global(h_fun, s, t, None, *args, **solver_kws) - return type(c)(n, inputs=s1[-1]) + if callback is None: + return s1 + return backend.stack([callback(a_state) for a_state in s1]) def ode_evol_global( @@ -564,7 +569,10 @@ def ode_evol_global( This function solves the time-dependent Schrodinger equation using numerical ODE integration. The Hamiltonian is applied to the full system and should be provided in sparse matrix format - for efficiency. + for efficiency. + + The ode_backend parameter defaults to 'jaxode' (which uses `jax.experimental.ode.odeint` with a default solver + of 'Dopri5');if set to 'diffrax', it uses `diffrax.diffeqsolve` instead (with a default solver of 'Tsit5'). Note: This function currently only supports the JAX backend. @@ -578,25 +586,91 @@ def ode_evol_global( :param callback: Optional function to apply to the state at each time step. :type callback: Optional[Callable[..., Tensor]] :param args: Additional arguments to pass to the Hamiltonian function. + :type args: tuple | list :param solver_kws: Additional keyword arguments to pass to the ODE solver. + ode_backend='jaxode'(default) uses `jax.experimental.ode.odeint`; ode_backend='diffrax' + uses `diffrax.diffeqsolve`. + rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would + like the numerical approximation to your equation. + The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'} + and only works when ode_backend='diffrax'. + dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'. + max_steps (default: 10000) The maximum number of steps to take before quitting the computation + unconditionally and only works when ode_backend='diffrax'. + :type solver_kws: dict :return: Evolved quantum states at the specified time points. If callback is provided, returns the callback results; otherwise returns the state vectors. :rtype: Tensor """ - from jax.experimental.ode import odeint - - s = initial_state - ts = backend.convert_to_tensor(times) - ts = backend.cast(ts, dtype=rdtypestr) def f(y: Tensor, t: Tensor, *args: Any) -> Tensor: h = -1.0j * hamiltonian(t, *args) return backend.sparse_dense_matmul(h, y) - s1 = odeint(f, s, ts, *args, **solver_kws) - if not callback: + s1 = _solve_ode(f, initial_state, times, args, solver_kws) + + if callback is None: return s1 - return backend.stack([callback(s1[i]) for i in range(len(s1))]) + return backend.stack([callback(a_state) for a_state in s1]) + + +@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]}) +def evol_local( + c: Circuit, + index: Sequence[int], + h_fun: Callable[..., Tensor], + t: float, + *args: Any, + **solver_kws: Any, +) -> Circuit: + """ + ode evolution of time dependent Hamiltonian on circuit of given indices + [only jax backend support for now] + + :param c: _description_ + :type c: Circuit + :param index: qubit sites to evolve + :type index: Sequence[int] + :param h_fun: h_fun should return a dense Hamiltonian matrix + with input arguments time and *args + :type h_fun: Callable[..., Tensor] + :param t: evolution time + :type t: float + :return: _description_ + :rtype: Circuit + """ + s = c.state() + n = int(np.log2(s.shape[-1]) + 1e-7) + if isinstance(t, float): + t = backend.stack([0.0, t]) + s1 = ode_evol_local(h_fun, s, t, index, None, *args, **solver_kws) + return type(c)(n, inputs=s1[-1]) + + +@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]}) +def evol_global( + c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any +) -> Circuit: + """ + ode evolution of time dependent Hamiltonian on circuit of all qubits + [only jax backend support for now] + + :param c: _description_ + :type c: Circuit + :param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix + with input arguments time and *args + :type h_fun: Callable[..., Tensor] + :param t: _description_ + :type t: float + :return: _description_ + :rtype: Circuit + """ + s = c.state() + n = c._nqubits + if isinstance(t, float): + t = backend.stack([0.0, t]) + s1 = ode_evol_global(h_fun, s, t, None, *args, **solver_kws) + return type(c)(n, inputs=s1[-1]) def chebyshev_evol( diff --git a/tests/test_timeevol.py b/tests/test_timeevol.py index 854cb85a..676c3893 100644 --- a/tests/test_timeevol.py +++ b/tests/test_timeevol.py @@ -11,7 +11,12 @@ import tensorcircuit as tc -def test_circuit_ode_evol(jaxb): +def test_circuit_ode_evol(highp, jaxb): + try: + import diffrax # pylint: disable=unused-import + except ImportError: + pytest.skip("diffrax not installed, skipping test") + def h_square(t, b): return (tc.backend.sign(t - 1.0) + 1) / 2 * b * tc.gates.x().tensor @@ -33,13 +38,22 @@ def h_square_sparse(t, b): c.cx(0, 1) c.h(2) c = tc.timeevol.evol_global( - c, h_square_sparse, 2.0, tc.backend.convert_to_tensor(0.2) + c, + h_square_sparse, + 2.0, + tc.backend.convert_to_tensor(0.2), + ode_backend="diffrax", ) c.rx(1, theta=np.pi - 0.4) np.testing.assert_allclose(c.expectation_ps(z=[1]), 1.0, atol=1e-5) -def test_ode_evol_local(jaxb): +def test_ode_evol_local(highp, jaxb): + try: + import diffrax # pylint: disable=unused-import + except ImportError: + pytest.skip("diffrax not installed, skipping test") + def local_hamiltonian(t, Omega, phi): angle = phi * t coeff = Omega * tc.backend.cos(2.0 * t) # Amplitude modulation @@ -60,7 +74,18 @@ def local_hamiltonian(t, Omega, phi): times = tc.backend.arange(0.0, 3.0, 0.1) # Evolve with local Hamiltonian acting on qubit 1 - states = tc.timeevol.ode_evol_local( + states0 = tc.timeevol.ode_evol_local( + local_hamiltonian, + psi0, + times, + [1], # Apply to qubit 1 + None, + 1.0, + 2.0, # Omega=1.0, phi=2.0 + solver="Tsit5", + ode_backend="diffrax", + ) + states1 = tc.timeevol.ode_evol_local( local_hamiltonian, psi0, times, @@ -68,11 +93,34 @@ def local_hamiltonian(t, Omega, phi): None, 1.0, 2.0, # Omega=1.0, phi=2.0 + atol=1.0e-13, + rtol=1.0e-15, ) - assert tc.backend.shape_tuple(states) == (30, 16) + states2 = tc.timeevol.ode_evol_local( + local_hamiltonian, + psi0, + times, + [1], # Apply to qubit 1 + None, + 1.0, + 2.0, # Omega=1.0, phi=2.0 + solver="Dopri8", + atol=1.0e-13, + rtol=1.0e-13, + ode_backend="diffrax", + dt0=0.005, + ) + + np.testing.assert_allclose(states2, states1, atol=1e-10, rtol=0.0) + np.testing.assert_allclose(states0, states1, atol=1e-10, rtol=0.0) + +def test_ode_evol_global(highp, jaxb): + try: + import diffrax # pylint: disable=unused-import + except ImportError: + pytest.skip("diffrax not installed, skipping test") -def test_ode_evol_global(jaxb): # Create a time-dependent transverse field Hamiltonian # H(t) = -∑ᵢ Jᵢ(t) ZᵢZᵢ₊₁ - ∑ᵢ hᵢ(t) Xᵢ @@ -121,7 +169,7 @@ def zobs(state): x_ham = tc.quantum.PauliStringSum2COO([[1, 0, 0, 0], [0, 1, 0, 0]], [1, 1]) # Example with parameterized Hamiltonian and optimization - def parametrized_hamiltonian(t, params): + def parametrized_hamiltonian(t, *params): # params = [J0, J1, h0, h1] - parameters to optimize J_t = params[0] + params[1] * tc.backend.sin(2.0 * t) h_t = params[2] + params[3] * tc.backend.cos(1.5 * t) @@ -142,13 +190,17 @@ def objective_function(params): psi0, tc.backend.convert_to_tensor([0, 1.0]), None, - params, + *params, + atol=1.0e-15, + rtol=1.0e-15, + solver="Kvaerno5", + ode_backend="diffrax", ) # Measure ZZ correlation at final time final_state = states[-1] return tc.backend.real(zz_correlation(final_state)) - print(objective_function(tc.backend.ones([4]))) + print(objective_function(tc.backend.ones(4))) @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])