Skip to content

Commit 7e39e4a

Browse files
committed
autograd config section
1 parent f058da8 commit 7e39e4a

File tree

16 files changed

+344
-156
lines changed

16 files changed

+344
-156
lines changed

tests/config/test_manager.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import numpy as np
34
import pytest
45

56

@@ -39,3 +40,38 @@ def test_builtin_profiles(profile, config_manager):
3940
config_manager.switch_profile(profile)
4041
web = config_manager.get_section("web")
4142
assert web.s3_region is not None
43+
44+
45+
def test_autograd_defaults(config_manager):
46+
autograd = config_manager.get_section("autograd")
47+
assert autograd.min_wvl_fraction == pytest.approx(5e-2)
48+
assert autograd.points_per_wavelength == 10
49+
assert autograd.monitor_interval_poly == (1, 1, 1)
50+
assert autograd.quadrature_sample_fraction == pytest.approx(0.4)
51+
assert autograd.gradient_precision == "single"
52+
assert autograd.max_traced_structures == 500
53+
assert autograd.max_adjoint_per_fwd == 10
54+
55+
56+
def test_autograd_update_section(config_manager):
57+
config_manager.update_section(
58+
"autograd",
59+
min_wvl_fraction=0.08,
60+
points_per_wavelength=12,
61+
solver_freq_chunk_size=3,
62+
gradient_precision="double",
63+
max_traced_structures=600,
64+
max_adjoint_per_fwd=7,
65+
)
66+
autograd = config_manager.get_section("autograd")
67+
assert autograd.min_wvl_fraction == pytest.approx(0.08)
68+
assert autograd.points_per_wavelength == 12
69+
assert autograd.solver_freq_chunk_size == 3
70+
assert autograd.gradient_precision == "double"
71+
assert autograd.max_traced_structures == 600
72+
assert autograd.max_adjoint_per_fwd == 7
73+
74+
from tidy3d.components.autograd import derivative_utils as autograd_utils
75+
76+
assert autograd_utils.GRADIENT_PRECISION == "double"
77+
assert autograd_utils.GRADIENT_DTYPE_FLOAT is np.float64

tests/test_components/autograd/test_autograd.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919

2020
import tidy3d as td
2121
import tidy3d.web as web
22-
from tidy3d.components.autograd.constants import (
23-
MAX_NUM_TRACED_STRUCTURES,
24-
MIN_WVL_FRACTION_CYLINDER_DISCRETIZE,
22+
from tidy3d.components.autograd.derivative_utils import (
2523
MINIMUM_SPACING_FRACTION,
24+
DerivativeInfo,
2625
)
27-
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
2826
from tidy3d.components.autograd.utils import is_tidy_box
2927
from tidy3d.components.data.data_array import DataArray
28+
from tidy3d.config import config
3029
from tidy3d.exceptions import AdjointError
3130
from tidy3d.plugins.polyslab import ComplexPolySlab
3231
from tidy3d.web import run, run_async
@@ -1251,7 +1250,8 @@ def test_too_many_traced_structures(monkeypatch, use_emulated_run):
12511250
def make_sim(*args):
12521251
structure = make_structures(*args)[structure_key]
12531252
return SIM_BASE.updated_copy(
1254-
structures=(MAX_NUM_TRACED_STRUCTURES + 1) * [structure], monitors=[monitor]
1253+
structures=(config.autograd.max_traced_structures + 1) * [structure],
1254+
monitors=[monitor],
12551255
)
12561256

12571257
def objective(*args):
@@ -1808,7 +1808,7 @@ def test_cylinder_discretization(eps_real):
18081808
):
18091809
cylinder = td.Cylinder(axis=2, length=info.wavelength_min, radius=2 * info.wavelength_min)
18101810

1811-
expected_wvl_mat = info.wavelength_min * MIN_WVL_FRACTION_CYLINDER_DISCRETIZE
1811+
expected_wvl_mat = info.wavelength_min * config.autograd.min_wvl_fraction
18121812
wvl_mat = cylinder._discretization_wavelength(derivative_info=info)
18131813

18141814
assert np.isclose(expected_wvl_mat, wvl_mat), (

tidy3d/components/autograd/constants.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

tidy3d/components/autograd/derivative_utils.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@
1313
from tidy3d.constants import C_0, EPSILON_0, LARGE_NUMBER, MU_0
1414
from tidy3d.log import log
1515

16-
from .constants import (
17-
DEFAULT_WAVELENGTH_FRACTION,
18-
GRADIENT_DTYPE_COMPLEX,
19-
GRADIENT_DTYPE_FLOAT,
20-
MINIMUM_SPACING_FRACTION,
21-
)
2216
from .types import PathType
2317
from .utils import get_static
2418

@@ -27,6 +21,33 @@
2721
EpsType = Union[tidycomplex, FreqDataArray]
2822

2923

24+
GRADIENT_PRECISION = "single" # Options: "single", "double"
25+
GRADIENT_DTYPE_FLOAT = np.float32
26+
GRADIENT_DTYPE_COMPLEX = np.complex64
27+
28+
29+
def set_gradient_precision(precision: str) -> None:
30+
"""Update global gradient precision and derived dtypes."""
31+
32+
if precision not in {"single", "double"}:
33+
raise ValueError("gradient_precision must be 'single' or 'double'")
34+
35+
global GRADIENT_PRECISION, GRADIENT_DTYPE_FLOAT, GRADIENT_DTYPE_COMPLEX
36+
GRADIENT_PRECISION = precision
37+
if precision == "single":
38+
GRADIENT_DTYPE_FLOAT = np.float32
39+
GRADIENT_DTYPE_COMPLEX = np.complex64
40+
else:
41+
GRADIENT_DTYPE_FLOAT = np.float64
42+
GRADIENT_DTYPE_COMPLEX = np.complex128
43+
44+
45+
MINIMUM_SPACING_FRACTION = 1e-2
46+
47+
48+
set_gradient_precision("single")
49+
50+
3051
class LazyInterpolator:
3152
"""Lazy wrapper for interpolators that creates them on first access."""
3253

@@ -711,7 +732,7 @@ def _project_in_basis(
711732

712733
def adaptive_vjp_spacing(
713734
self,
714-
wl_fraction: float = DEFAULT_WAVELENGTH_FRACTION,
735+
wl_fraction: Optional[float] = None,
715736
min_allowed_spacing_fraction: float = MINIMUM_SPACING_FRACTION,
716737
) -> float:
717738
"""Compute adaptive spacing for finite-difference gradient evaluation.
@@ -721,8 +742,9 @@ def adaptive_vjp_spacing(
721742
722743
Parameters
723744
----------
724-
wl_fraction : float = 0.1
725-
Fraction of wavelength/skin depth to use as spacing.
745+
wl_fraction : float, optional
746+
Fraction of wavelength/skin depth to use as spacing. Defaults to the configured
747+
``autograd.default_wavelength_fraction`` when ``None``.
726748
min_allowed_spacing_fraction : float = 1e-2
727749
Minimum allowed spacing fraction of free space wavelength to
728750
prevent numerical issues.
@@ -732,6 +754,11 @@ def adaptive_vjp_spacing(
732754
float
733755
Adaptive spacing value for gradient evaluation.
734756
"""
757+
if wl_fraction is None:
758+
from tidy3d.config import config
759+
760+
wl_fraction = config.autograd.default_wavelength_fraction
761+
735762
# handle FreqDataArray or scalar eps_in
736763
if isinstance(self.eps_in, FreqDataArray):
737764
eps_real = np.asarray(self.eps_in.values, dtype=np.complex128).real

tidy3d/components/geometry/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
except ImportError:
1818
pass
1919

20+
import tidy3d.components.autograd.derivative_utils as autograd_utils
2021
from tidy3d.compat import _shapely_is_older_than
2122
from tidy3d.components.autograd import (
2223
AutogradFieldMap,
@@ -25,7 +26,6 @@
2526
TracedSize,
2627
get_static,
2728
)
28-
from tidy3d.components.autograd.constants import GRADIENT_DTYPE_FLOAT
2929
from tidy3d.components.autograd.derivative_utils import (
3030
DerivativeInfo,
3131
FieldData,
@@ -2684,7 +2684,7 @@ def _derivative_face_dielectric(
26842684
26852685
Parameters
26862686
----------
2687-
dtype : np.dtype = GRADIENT_DTYPE_FLOAT
2687+
dtype : np.dtype = autograd_utils.GRADIENT_DTYPE_FLOAT
26882688
Data type for interpolation coordinates and values.
26892689
26902690
dim_normal : str
@@ -2757,7 +2757,7 @@ def _derivative_face_pec(
27572757
27582758
Parameters
27592759
----------
2760-
dtype : np.dtype = GRADIENT_DTYPE_FLOAT
2760+
dtype : np.dtype = autograd_utils.GRADIENT_DTYPE_FLOAT
27612761
Data type for interpolation coordinates and values.
27622762
27632763
dim_normal : str
@@ -3613,7 +3613,7 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
36133613

36143614
# create interpolators once for all geometries to avoid redundant field data conversions
36153615
interpolators = derivative_info.interpolators or derivative_info.create_interpolators(
3616-
dtype=GRADIENT_DTYPE_FLOAT
3616+
dtype=autograd_utils.GRADIENT_DTYPE_FLOAT
36173617
)
36183618

36193619
for field_path in derivative_info.paths:

0 commit comments

Comments
 (0)