Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 147 additions & 73 deletions tensorcircuit/timeevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
Analog time evolution engines
"""

from typing import Any, Tuple, Optional, Callable, List, Sequence
from typing import Any, Tuple, Optional, Callable, List, Sequence, Dict
from functools import partial
import warnings

import numpy as np

Expand Down Expand Up @@ -427,37 +428,58 @@ def _evol(t: Tensor) -> Tensor:
ed_evol = hamiltonian_evol


@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]})
def evol_local(
c: Circuit,
index: Sequence[int],
h_fun: Callable[..., Tensor],
t: float,
*args: Any,
**solver_kws: Any,
) -> Circuit:
"""
ode evolution of time dependent Hamiltonian on circuit of given indices
[only jax backend support for now]
def _solve_ode(
f: Callable[..., Tensor],
s: Tensor,
times: Tensor,
args: Any,
solver_kws: Dict[str, Any],
) -> Tensor:
rtol = solver_kws.get("rtol", 1e-12)
atol = solver_kws.get("atol", 1e-12)
ode_backend = solver_kws.get("ode_backend", "jaxode")
max_steps = solver_kws.get("max_steps", 10000)

:param c: _description_
:type c: Circuit
:param index: qubit sites to evolve
:type index: Sequence[int]
:param h_fun: h_fun should return a dense Hamiltonian matrix
with input arguments time and *args
:type h_fun: Callable[..., Tensor]
:param t: evolution time
:type t: float
:return: _description_
:rtype: Circuit
"""
s = c.state()
n = int(np.log2(s.shape[-1]) + 1e-7)
if isinstance(t, float):
t = backend.stack([0.0, t])
s1 = ode_evol_local(h_fun, s, t, index, None, *args, **solver_kws)
return type(c)(n, inputs=s1[-1])
ts = backend.convert_to_tensor(times)
ts = backend.cast(ts, dtype=rdtypestr)

if ode_backend == "jaxode":
from jax.experimental.ode import odeint

s1 = odeint(f, s, ts, rtol=rtol, atol=atol, mxstep=max_steps, *args)
return s1

import diffrax

# Ignore complex warning
warnings.simplefilter("ignore", category=UserWarning, append=True)

solver = solver_kws.get("solver", "Tsit5")
dt0 = solver_kws.get("dt0", 0.01)
all_solvers = {
"Dopri5": diffrax.Dopri5,
"Tsit5": diffrax.Tsit5,
"Dopri8": diffrax.Dopri8,
"Kvaerno5": diffrax.Kvaerno5,
}

# ODE
term = diffrax.ODETerm(lambda t, y, args: f(y, t, *args))

# solve ODE
s1 = diffrax.diffeqsolve(
terms=term,
solver=all_solvers[solver](),
t0=times[0],
t1=times[-1],
dt0=dt0,
y0=s,
saveat=diffrax.SaveAt(ts=times),
args=args,
stepsize_controller=diffrax.PIDController(rtol=rtol, atol=atol),
max_steps=max_steps,
).ys
return s1


def ode_evol_local(
Expand All @@ -475,6 +497,9 @@ def ode_evol_local(
This function solves the time-dependent Schrodinger equation using numerical ODE integration.
The Hamiltonian is applied only to a specific subset of qubits (indices) in the system.

The ode_backend parameter defaults to 'jaxode' (which uses `jax.experimental.ode.odeint` with a default solver
of 'Dopri5');if set to 'diffrax', it uses `diffrax.diffeqsolve` instead (with a default solver of 'Tsit5').

Note: This function currently only supports the JAX backend.

:param hamiltonian: A function that returns a dense Hamiltonian matrix for the specified
Expand All @@ -490,13 +515,20 @@ def ode_evol_local(
:type callback: Optional[Callable[..., Tensor]]
:param args: Additional arguments to pass to the Hamiltonian function.
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
ode_backend='jaxode'(default) uses `jax.experimental.ode.odeint`; ode_backend='diffrax'
uses `diffrax.diffeqsolve`.
rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would
like the numerical approximation to your equation.
The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
and only works when ode_backend='diffrax'.
dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'.
max_steps (default: 10000) The maximum number of steps to take before quitting the computation
unconditionally and only works when ode_backend='diffrax'.
:return: Evolved quantum states at the specified time points. If callback is provided,
returns the callback results; otherwise returns the state vectors.
:rtype: Tensor
"""
from jax.experimental.ode import odeint

s = initial_state
n = int(np.log2(backend.shape_tuple(initial_state)[-1]) + 1e-7)
l = len(index)

Expand All @@ -517,38 +549,11 @@ def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
y = contractor([y, h], output_edge_order=edges)
return backend.reshape(y.tensor, [-1])

ts = backend.convert_to_tensor(times)
ts = backend.cast(ts, dtype=rdtypestr)
s1 = odeint(f, s, ts, *args, **solver_kws)
if not callback:
return s1
return backend.stack([callback(s1[i]) for i in range(len(s1))])
s1 = _solve_ode(f, initial_state, times, args, solver_kws)


@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]})
def evol_global(
c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any
) -> Circuit:
"""
ode evolution of time dependent Hamiltonian on circuit of all qubits
[only jax backend support for now]

:param c: _description_
:type c: Circuit
:param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
with input arguments time and *args
:type h_fun: Callable[..., Tensor]
:param t: _description_
:type t: float
:return: _description_
:rtype: Circuit
"""
s = c.state()
n = c._nqubits
if isinstance(t, float):
t = backend.stack([0.0, t])
s1 = ode_evol_global(h_fun, s, t, None, *args, **solver_kws)
return type(c)(n, inputs=s1[-1])
if callback is None:
return s1
return backend.stack([callback(a_state) for a_state in s1])


def ode_evol_global(
Expand All @@ -564,7 +569,10 @@ def ode_evol_global(

This function solves the time-dependent Schrodinger equation using numerical ODE integration.
The Hamiltonian is applied to the full system and should be provided in sparse matrix format
for efficiency.
for efficiency.

The ode_backend parameter defaults to 'jaxode' (which uses `jax.experimental.ode.odeint` with a default solver
of 'Dopri5');if set to 'diffrax', it uses `diffrax.diffeqsolve` instead (with a default solver of 'Tsit5').

Note: This function currently only supports the JAX backend.

Expand All @@ -578,25 +586,91 @@ def ode_evol_global(
:param callback: Optional function to apply to the state at each time step.
:type callback: Optional[Callable[..., Tensor]]
:param args: Additional arguments to pass to the Hamiltonian function.
:type args: tuple | list
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
ode_backend='jaxode'(default) uses `jax.experimental.ode.odeint`; ode_backend='diffrax'
uses `diffrax.diffeqsolve`.
rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would
like the numerical approximation to your equation.
The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
and only works when ode_backend='diffrax'.
dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'.
max_steps (default: 10000) The maximum number of steps to take before quitting the computation
unconditionally and only works when ode_backend='diffrax'.
:type solver_kws: dict
:return: Evolved quantum states at the specified time points. If callback is provided,
returns the callback results; otherwise returns the state vectors.
:rtype: Tensor
"""
from jax.experimental.ode import odeint

s = initial_state
ts = backend.convert_to_tensor(times)
ts = backend.cast(ts, dtype=rdtypestr)

def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
h = -1.0j * hamiltonian(t, *args)
return backend.sparse_dense_matmul(h, y)

s1 = odeint(f, s, ts, *args, **solver_kws)
if not callback:
s1 = _solve_ode(f, initial_state, times, args, solver_kws)

if callback is None:
return s1
return backend.stack([callback(s1[i]) for i in range(len(s1))])
return backend.stack([callback(a_state) for a_state in s1])


@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]})
def evol_local(
c: Circuit,
index: Sequence[int],
h_fun: Callable[..., Tensor],
t: float,
*args: Any,
**solver_kws: Any,
) -> Circuit:
"""
ode evolution of time dependent Hamiltonian on circuit of given indices
[only jax backend support for now]

:param c: _description_
:type c: Circuit
:param index: qubit sites to evolve
:type index: Sequence[int]
:param h_fun: h_fun should return a dense Hamiltonian matrix
with input arguments time and *args
:type h_fun: Callable[..., Tensor]
:param t: evolution time
:type t: float
:return: _description_
:rtype: Circuit
"""
s = c.state()
n = int(np.log2(s.shape[-1]) + 1e-7)
if isinstance(t, float):
t = backend.stack([0.0, t])
s1 = ode_evol_local(h_fun, s, t, index, None, *args, **solver_kws)
return type(c)(n, inputs=s1[-1])


@partial(arg_alias, alias_dict={"h_fun": ["hamiltonian"], "t": ["times"]})
def evol_global(
c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any
) -> Circuit:
"""
ode evolution of time dependent Hamiltonian on circuit of all qubits
[only jax backend support for now]

:param c: _description_
:type c: Circuit
:param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
with input arguments time and *args
:type h_fun: Callable[..., Tensor]
:param t: _description_
:type t: float
:return: _description_
:rtype: Circuit
"""
s = c.state()
n = c._nqubits
if isinstance(t, float):
t = backend.stack([0.0, t])
s1 = ode_evol_global(h_fun, s, t, None, *args, **solver_kws)
return type(c)(n, inputs=s1[-1])


def chebyshev_evol(
Expand Down
Loading