Skip to content

Commit 70af873

Browse files
committed
split web/api/autograd/ into more modules
1 parent 5c073c6 commit 70af873

File tree

9 files changed

+966
-473
lines changed

9 files changed

+966
-473
lines changed

tidy3d/web/api/autograd/autograd.py

Lines changed: 72 additions & 473 deletions
Large diffs are not rendered by default.
Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
from __future__ import annotations
2+
3+
from collections import defaultdict
4+
5+
import numpy as np
6+
import xarray as xr
7+
8+
import tidy3d as td
9+
from tidy3d.components.autograd import AutogradFieldMap, get_static
10+
from tidy3d.components.autograd.constants import ADJOINT_FREQ_CHUNK_SIZE
11+
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
12+
from tidy3d.exceptions import AdjointError
13+
14+
from .utils import E_to_D, get_derivative_maps
15+
16+
17+
def setup_adj(
    data_fields_vjp: AutogradFieldMap,
    sim_data_orig: td.SimulationData,
    sim_fields_keys: list[tuple],
    max_num_adjoint_per_fwd: int,
) -> list[td.Simulation]:
    """Build the adjoint simulations that carry a VJP back through ``sim_data_orig``.

    Parameters
    ----------
    data_fields_vjp : AutogradFieldMap
        Mapping from data path to the incoming VJP value for that path.
    sim_data_orig : td.SimulationData
        Simulation data of the original run; provides the data layout and the simulation.
    sim_fields_keys : list[tuple]
        Paths of the traced simulation fields, forwarded to ``_with_adjoint_monitors``.
    max_num_adjoint_per_fwd : int
        Largest number of adjoint simulations that may be generated.

    Returns
    -------
    list[td.Simulation]
        The adjoint simulations, or an empty list when every VJP entry is close to zero.

    Raises
    ------
    AdjointError
        If a retained VJP entry contains NaN, or if more than ``max_num_adjoint_per_fwd``
        adjoint simulations would be needed.
    """

    td.log.info("Running custom vjp (adjoint) pipeline.")

    # drop entries that are numerically zero and strip tracers from the remaining ones
    nonzero_vjp = {}
    for path, value in data_fields_vjp.items():
        if np.allclose(value, 0):
            continue
        nonzero_vjp[path] = get_static(value)

    # NaN entries survive the zero filter above (allclose is False for NaN), so catch them here
    for path, value in nonzero_vjp.items():
        if np.any(np.isnan(value)):
            raise AdjointError(
                f"NaN values detected for data field {path} in the adjoint pipeline. This may be "
                f"due to NaN values in the simulation data or the computed value of your "
                f"objective function."
            )

    # nothing left to back-propagate, so there is no adjoint simulation to run
    if not nonzero_vjp:
        return []

    # begin from the complete data layout, then either insert the VJP for a path or zero it out
    field_map = sim_data_orig._strip_traced_fields(
        include_untraced_data_arrays=True, starting_path=("data",)
    )
    for path in field_map:
        if path not in nonzero_vjp:
            # NOTE(review): in-place scaling; confirm these arrays are not shared with
            # 'sim_data_orig', otherwise the original data would be zeroed as well.
            field_map[path] *= 0
            continue
        field_map[path] = nonzero_vjp[path]

    # simulation data whose '.data' holds the raw VJP values
    sim_data_vjp = sim_data_orig._insert_traced_fields(field_mapping=field_map)

    # the adjoint monitors are the ones appended after the original monitors
    simulation = sim_data_orig.simulation
    n_original_monitors = len(simulation.monitors)
    sim_with_adjoint_monitors = simulation._with_adjoint_monitors(sim_fields_keys)
    adjoint_monitors = sim_with_adjoint_monitors.monitors[n_original_monitors:]

    sims_adj = sim_data_vjp._make_adjoint_sims(
        data_vjp_paths=set(nonzero_vjp),
        adjoint_monitors=adjoint_monitors,
    )

    if len(sims_adj) > max_num_adjoint_per_fwd:
        raise AdjointError(
            f"Number of adjoint simulations ({len(sims_adj)}) exceeds the maximum allowed "
            f"({max_num_adjoint_per_fwd}) per forward simulation. This typically means that "
            "there are many frequencies and monitors in the simulation that are being differentiated "
            "w.r.t. in the objective function. To proceed, please double-check the simulation "
            "setup, increase the 'max_num_adjoint_per_fwd' parameter in the run function, and re-run."
        )

    return sims_adj
81+
82+
83+
def _compute_eps_array(medium, frequencies):
    """Return the mean of ``medium.eps_model`` at each frequency as a DataArray over ``f``."""
    # 'eps_model' may return several values per frequency; they are averaged to one number
    mean_eps = np.array([np.mean(medium.eps_model(freq)) for freq in frequencies])
    data_array_cls = td.components.data.data_array.DataArray
    return data_array_cls(data=mean_eps, dims=("f",), coords={"f": frequencies})
89+
90+
91+
def _slice_field_data(
92+
field_data: dict, freqs: np.ndarray, component_indicator: str | None = None
93+
) -> dict:
94+
"""Slice field data dictionary along frequency dimension."""
95+
if component_indicator:
96+
return {k: v.sel(f=freqs) for k, v in field_data.items() if component_indicator in k}
97+
else:
98+
return {k: v.sel(f=freqs) for k, v in field_data.items()}
99+
100+
101+
def _sel_freqs_or_none(array, freqs):
    """Select ``freqs`` along the ``f`` dimension of ``array``; ``None`` is passed through."""
    if array is None:
        return None
    return array.sel(f=freqs)


def _epsilon_on_plane(simulation, plane, frequencies):
    """Evaluate ``simulation.epsilon`` on ``plane`` per frequency and stack along ``f``."""
    eps_per_freq = [
        simulation.epsilon(box=plane, coord_key="centers", freq=freq) for freq in frequencies
    ]
    return xr.concat(eps_per_freq, dim="f").assign_coords(f=frequencies)


def _auto_permittivity(sim_data_orig, structure, structure_index, eps_fwd, frequencies):
    """Return ``(eps_no_structure, eps_inf_structure)`` for a non-box structure.

    Both entries are ``None`` for ``td.Box`` geometries. ``eps_inf_structure`` is also ``None``
    when the structure medium is PEC.
    """
    if isinstance(structure.geometry, td.Box):
        return None, None

    sim_orig = sim_data_orig.simulation
    plane_eps = eps_fwd.monitor.geometry

    # reuse the grid of the original simulation for the auxiliary simulations below
    grid_spec = td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid)

    # permittivity without this structure
    structs_no_struct = list(sim_orig.structures)
    structs_no_struct.pop(structure_index)
    sim_no_structure = sim_orig.updated_copy(
        structures=structs_no_struct, monitors=[], sources=[], grid_spec=grid_spec
    )
    eps_no_structure = _epsilon_on_plane(sim_no_structure, plane_eps, frequencies)

    if structure.medium.is_pec:
        return eps_no_structure, None

    # permittivity with infinite structure: the structure medium becomes the background and
    # only the structures listed after this one are kept
    structs_inf_struct = list(sim_orig.structures)[structure_index + 1 :]
    sim_inf_structure = sim_orig.updated_copy(
        structures=structs_inf_struct,
        medium=structure.medium,
        monitors=[],
        sources=[],
        grid_spec=grid_spec,
    )
    eps_inf_structure = _epsilon_on_plane(sim_inf_structure, plane_eps, frequencies)

    return eps_no_structure, eps_inf_structure


def _accumulate_vjp(vjp_value_map, vjp_chunk):
    """Add the per-chunk VJP values in ``vjp_chunk`` into ``vjp_value_map`` in place."""
    for path, value in vjp_chunk.items():
        if path not in vjp_value_map:
            vjp_value_map[path] = value
            continue

        val = vjp_value_map[path]
        if isinstance(val, (list, tuple)) and isinstance(value, (list, tuple)):
            # sequences are summed element-wise (a plain '+' would concatenate them)
            vjp_value_map[path] = type(val)(x + y for x, y in zip(val, value))
        else:
            vjp_value_map[path] += value


def postprocess_adj(
    sim_data_adj: td.SimulationData,
    sim_data_orig: td.SimulationData,
    sim_data_fwd: td.SimulationData,
    sim_fields_keys: list[tuple],
) -> AutogradFieldMap:
    """Postprocess data from the adjoint simulation into the VJP for the original sim fields.

    Parameters
    ----------
    sim_data_adj : td.SimulationData
        Data of the adjoint simulation.
    sim_data_orig : td.SimulationData
        Data of the original simulation; supplies the simulation, its medium and its bounds.
    sim_data_fwd : td.SimulationData
        Data of the forward simulation; supplies the structures and the forward fields.
    sim_fields_keys : list[tuple]
        Traced paths. The second item of each tuple is the structure index and the items
        after it form the path inside that structure.

    Returns
    -------
    AutogradFieldMap
        Mapping from ``("structures", structure_index, *structure_path)`` to the VJP value
        summed over all frequency chunks. Empty when ``sim_fields_keys`` is empty.
    """

    # map of index into 'structures' to the list of paths we need vjps for
    sim_vjp_map = defaultdict(list)
    for _, structure_index, *structure_path in sim_fields_keys:
        sim_vjp_map[structure_index].append(tuple(structure_path))

    # store the derivative values given the forward and adjoint data
    sim_fields_vjp = {}
    for structure_index, structure_paths in sim_vjp_map.items():
        # grab the forward and adjoint data
        fld_fwd = sim_data_fwd._get_adjoint_data(structure_index, data_type="fld")
        eps_fwd = sim_data_fwd._get_adjoint_data(structure_index, data_type="eps")
        fld_adj = sim_data_adj._get_adjoint_data(structure_index, data_type="fld")
        eps_adj = sim_data_adj._get_adjoint_data(structure_index, data_type="eps")

        # post normalize the adjoint fields if a single, broadband source
        adj_flds_normed = {
            key: val * sim_data_adj.simulation.post_norm
            for key, val in fld_adj.field_components.items()
        }
        fld_adj = fld_adj.updated_copy(**adj_flds_normed)

        # maps of the E_fwd * E_adj and D_fwd * D_adj, each as td.FieldData & 'Ex', 'Ey', 'Ez'
        der_maps = get_derivative_maps(
            fld_fwd=fld_fwd,
            eps_fwd=eps_fwd,
            fld_adj=fld_adj,
            eps_adj=eps_adj,
        )
        E_der_map = der_maps["E"]
        D_der_map = der_maps["D"]
        H_der_map = der_maps["H"]

        H_info_exists = H_der_map is not None

        # both displacement fields are built with the forward permittivity
        D_fwd = E_to_D(fld_fwd, eps_fwd)
        D_adj = E_to_D(fld_adj, eps_fwd)

        # compute the derivatives for this structure
        structure = sim_data_fwd.simulation.structures[structure_index]

        # compute epsilon arrays for all frequencies
        adjoint_frequencies = np.array(fld_adj.monitor.freqs)

        eps_in = _compute_eps_array(structure.medium, adjoint_frequencies)
        eps_out = _compute_eps_array(sim_data_orig.simulation.medium, adjoint_frequencies)

        # handle background medium if present
        if structure.background_medium:
            eps_background = _compute_eps_array(structure.background_medium, adjoint_frequencies)
        else:
            eps_background = None

        # auto permittivity detection for non-box geometries
        eps_no_structure, eps_inf_structure = _auto_permittivity(
            sim_data_orig, structure, structure_index, eps_fwd, adjoint_frequencies
        )

        # compute bounds intersection
        struct_bounds = structure.geometry.bounds
        rmin_struct, rmax_struct = struct_bounds
        rmin_sim, rmax_sim = sim_data_orig.simulation.bounds
        rmin_intersect = tuple(max(a, b) for a, b in zip(rmin_sim, rmin_struct))
        rmax_intersect = tuple(min(a, b) for a, b in zip(rmax_sim, rmax_struct))
        bounds_intersect = (rmin_intersect, rmax_intersect)

        # get chunk size - if None, process all frequencies as one chunk
        freq_chunk_size = ADJOINT_FREQ_CHUNK_SIZE
        n_freqs = len(adjoint_frequencies)
        if freq_chunk_size is None:
            freq_chunk_size = n_freqs

        # process in chunks
        vjp_value_map = {}

        for chunk_start in range(0, n_freqs, freq_chunk_size):
            chunk_end = min(chunk_start + freq_chunk_size, n_freqs)
            select_adjoint_freqs = adjoint_frequencies[chunk_start:chunk_end]

            # slice field data for current chunk
            E_der_map_chunk = _slice_field_data(E_der_map.field_components, select_adjoint_freqs)
            D_der_map_chunk = _slice_field_data(D_der_map.field_components, select_adjoint_freqs)
            E_fwd_chunk = _slice_field_data(
                fld_fwd.field_components, select_adjoint_freqs, component_indicator="E"
            )
            E_adj_chunk = _slice_field_data(
                fld_adj.field_components, select_adjoint_freqs, component_indicator="E"
            )
            D_fwd_chunk = _slice_field_data(D_fwd.field_components, select_adjoint_freqs)
            D_adj_chunk = _slice_field_data(D_adj.field_components, select_adjoint_freqs)
            eps_data_chunk = _slice_field_data(eps_fwd.field_components, select_adjoint_freqs)

            # the H data is optional and stays None when no H derivative map is available
            H_der_map_chunk = None
            H_fwd_chunk = None
            H_adj_chunk = None

            if H_info_exists:
                H_der_map_chunk = _slice_field_data(
                    H_der_map.field_components, select_adjoint_freqs
                )
                H_fwd_chunk = _slice_field_data(
                    fld_fwd.field_components, select_adjoint_freqs, component_indicator="H"
                )
                H_adj_chunk = _slice_field_data(
                    fld_adj.field_components, select_adjoint_freqs, component_indicator="H"
                )

            # create derivative info with sliced data
            derivative_info = DerivativeInfo(
                paths=structure_paths,
                E_der_map=E_der_map_chunk,
                D_der_map=D_der_map_chunk,
                H_der_map=H_der_map_chunk,
                E_fwd=E_fwd_chunk,
                E_adj=E_adj_chunk,
                D_fwd=D_fwd_chunk,
                D_adj=D_adj_chunk,
                H_fwd=H_fwd_chunk,
                H_adj=H_adj_chunk,
                eps_data=eps_data_chunk,
                eps_in=eps_in.sel(f=select_adjoint_freqs),
                eps_out=eps_out.sel(f=select_adjoint_freqs),
                eps_background=_sel_freqs_or_none(eps_background, select_adjoint_freqs),
                frequencies=select_adjoint_freqs,  # only chunk frequencies
                eps_no_structure=_sel_freqs_or_none(eps_no_structure, select_adjoint_freqs),
                eps_inf_structure=_sel_freqs_or_none(eps_inf_structure, select_adjoint_freqs),
                bounds=struct_bounds,
                bounds_intersect=bounds_intersect,
                simulation_bounds=sim_data_orig.simulation.bounds,
                is_medium_pec=structure.medium.is_pec,
            )

            # compute derivatives for chunk and accumulate results
            vjp_chunk = structure._compute_derivatives(derivative_info)
            _accumulate_vjp(vjp_value_map, vjp_chunk)

        # store vjps in output map
        for structure_path, vjp_value in vjp_value_map.items():
            sim_path = ("structures", structure_index, *structure_path)
            sim_fields_vjp[sim_path] = vjp_value

    return sim_fields_vjp
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from __future__ import annotations
2+
3+
# keys used to index the auxiliary data dictionary (re-exported in autograd.py for tests)
AUX_KEY_SIM_DATA_ORIGINAL = "sim_data"
AUX_KEY_SIM_DATA_FWD = "sim_data_fwd_adjoint"
AUX_KEY_FWD_TASK_ID = "task_id_fwd"
AUX_KEY_SIM_ORIGINAL = "sim_original"

# server-side auxiliary files to upload/download
# NOTE(review): only the VJP file carries an "output/" prefix; confirm this is intentional
SIM_VJP_FILE = "output/autograd_sim_vjp.hdf5"
SIM_FIELDS_KEYS_FILE = "autograd_sim_fields_keys.hdf5"

# default behaviors
# presumably selects local instead of server-side gradient computation; verify against callers
LOCAL_GRADIENT = False

# directory to store adjoint data for local gradient calculation, relative to the run path
LOCAL_ADJOINT_DIR = "adjoint_data"
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from __future__ import annotations
2+
3+
# keys used to index the auxiliary data dictionary (re-exported in autograd.py for tests)
AUX_KEY_SIM_DATA_ORIGINAL = "sim_data"
AUX_KEY_SIM_DATA_FWD = "sim_data_fwd_adjoint"
AUX_KEY_FWD_TASK_ID = "task_id_fwd"
AUX_KEY_SIM_ORIGINAL = "sim_original"

# server-side auxiliary files to upload/download
# NOTE(review): only the VJP file carries an "output/" prefix; confirm this is intentional
SIM_VJP_FILE = "output/autograd_sim_vjp.hdf5"
SIM_FIELDS_KEYS_FILE = "autograd_sim_fields_keys.hdf5"

# default behaviors
# presumably selects local instead of server-side gradient computation; verify against callers
LOCAL_GRADIENT = False

# directory to store adjoint data for local gradient calculation, relative to the run path
LOCAL_ADJOINT_DIR = "adjoint_data"

0 commit comments

Comments
 (0)