-
Notifications
You must be signed in to change notification settings - Fork 11
Update timeevol.py #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
change two functions
please also add tests in tests/test_timeevol.py accordingly |
also please ensure the PR can pass local check: black, mypy, pytest |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall a nice work, some minor issues require addressing
tensorcircuit/timeevol.py
Outdated
|
||
max_steps = solver_kws.get("max_steps", 10000) | ||
|
||
if (solver := solver_kws.get("solver", "Dopri5")) == "Dopri5": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why dopri5 corresponds the original odeint in jax, doesn't diffrax provides dopri5?
tensorcircuit/timeevol.py
Outdated
|
||
else: | ||
import diffrax | ||
import warnings |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warnings can be imported globally at the beginning of the file
tensorcircuit/timeevol.py
Outdated
|
||
|
||
def ode_evol_global( | ||
hamiltonian: Callable[..., Tensor], | ||
initial_state: Tensor, | ||
times: Tensor, | ||
callback: Optional[Callable[..., Tensor]] = None, | ||
*args: Any, | ||
**solver_kws: Any, | ||
args: Optional[Sequence[int | float | complex]] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
tensorcircuit/timeevol.py
Outdated
return type(c)(n, inputs=s1[-1]) | ||
|
||
|
||
# def ode_evol_local( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete instead of comment
tensorcircuit/timeevol.py
Outdated
return type(c)(n, inputs=s1[-1]) | ||
|
||
|
||
# def ode_evol_global( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete instead of comment
tests/test_timeevol.py
Outdated
|
||
assert tc.backend.shape_tuple(states0) == (30, 16) | ||
assert tc.backend.shape_tuple(states1) == (30, 16) | ||
assert tc.backend.shape_tuple(states2) == (30, 16) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better also np assert_allclose between them?
@@ -60,16 +60,37 @@ def local_hamiltonian(t, Omega, phi): | |||
times = tc.backend.arange(0.0, 3.0, 0.1) | |||
|
|||
# Evolve with local Hamiltonian acting on qubit 1 | |||
states = tc.timeevol.ode_evol_local( | |||
states0 = tc.timeevol.ode_evol_local( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for both test functions that use dependence of diffrax, you should first try import diffrax, if error then skip the test. Another way is to add diffrax in the requirements/requirements-extra.txt
tensorcircuit/timeevol.py
Outdated
@@ -466,7 +488,7 @@ def ode_evol_local( | |||
times: Tensor, | |||
index: Sequence[int], | |||
callback: Optional[Callable[..., Tensor]] = None, | |||
*args: Any, | |||
*args: int | float | complex, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better still use Any as args can be tensors
tests/test_timeevol.py
Outdated
except ImportError: | ||
pytest.skip("diffrax not installed, skipping test") | ||
|
||
tc.set_dtype("complex128") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this will pollute the precision of other tests, use highp
pytest fixture instead
tests/test_timeevol.py
Outdated
except ImportError: | ||
pytest.skip("diffrax not installed, skipping test") | ||
|
||
tc.set_dtype("complex128") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use highp
pytest fixture
tests/test_timeevol.py
Outdated
@@ -40,6 +40,13 @@ def h_square_sparse(t, b): | |||
|
|||
|
|||
def test_ode_evol_local(jaxb): | |||
try: | |||
import diffrax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add pylint ignore comment to ensure legal pylint check
change two functions