Skip to content

Update timeevol.py #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

Huang-Xu-Yang
Copy link

change two functions

change two functions
@refraction-ray
Copy link
Member

please also add tests in tests/test_timeevol.py accordingly

@refraction-ray
Copy link
Member

also please ensure the PR can pass local check: black, mypy, pytest

Copy link
Member

@refraction-ray refraction-ray left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall a nice work, some minor issues require addressing


max_steps = solver_kws.get("max_steps", 10000)

if (solver := solver_kws.get("solver", "Dopri5")) == "Dopri5":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why dopri5 corresponds the original odeint in jax, doesn't diffrax provides dopri5?


else:
import diffrax
import warnings
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warnings can be imported globally at the beginning of the file



def ode_evol_global(
hamiltonian: Callable[..., Tensor],
initial_state: Tensor,
times: Tensor,
callback: Optional[Callable[..., Tensor]] = None,
*args: Any,
**solver_kws: Any,
args: Optional[Sequence[int | float | complex]] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

return type(c)(n, inputs=s1[-1])


# def ode_evol_local(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete instead of comment

return type(c)(n, inputs=s1[-1])


# def ode_evol_global(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete instead of comment


assert tc.backend.shape_tuple(states0) == (30, 16)
assert tc.backend.shape_tuple(states1) == (30, 16)
assert tc.backend.shape_tuple(states2) == (30, 16)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better also np assert_allclose between them?

@@ -60,16 +60,37 @@ def local_hamiltonian(t, Omega, phi):
times = tc.backend.arange(0.0, 3.0, 0.1)

# Evolve with local Hamiltonian acting on qubit 1
states = tc.timeevol.ode_evol_local(
states0 = tc.timeevol.ode_evol_local(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for both test functions that use dependence of diffrax, you should first try import diffrax, if error then skip the test. Another way is to add diffrax in the requirements/requirements-extra.txt

new test; do not change API compatibility; try import diffrax, if error then skip the test
@@ -466,7 +488,7 @@ def ode_evol_local(
times: Tensor,
index: Sequence[int],
callback: Optional[Callable[..., Tensor]] = None,
*args: Any,
*args: int | float | complex,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better still use Any as args can be tensors

except ImportError:
pytest.skip("diffrax not installed, skipping test")

tc.set_dtype("complex128")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will pollute the precision of other tests, use highp pytest fixture instead

except ImportError:
pytest.skip("diffrax not installed, skipping test")

tc.set_dtype("complex128")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use highp pytest fixture

@@ -40,6 +40,13 @@ def h_square_sparse(t, b):


def test_ode_evol_local(jaxb):
try:
import diffrax
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add pylint ignore comment to ensure legal pylint check

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants