-
Notifications
You must be signed in to change notification settings - Fork 66
Add user_vjp hook and custom run functions to allow overriding the internal vjp #3015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
| >>> b = Sphere(center=(1,2,3), radius=2) | ||
| """ | ||
|
|
||
| radius: TracedSize1D = pydantic.Field( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@marcorudolphflex - good point about losing the non-negativity constraint here. I noticed we also have this case in Cylinder. One thing is that if I try and specify a traced negative value, then it gets set to 0 and I am not sure where that is happening but figured @yaugenst-flex may know!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just saw that TracedSize1D is also nonNegativeFloat. But shouldn't it be TracedPositiveFloat? But we also allow 0 radius currently for Circular...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The traced type handling is pretty loose. It's handled a bit better on the pydantic v2 branch so I wouldn't worry about this for now because it's a bit involved.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
13 files reviewed, 6 comments
Edit Code Review Agent Settings | Greptile
React with 👍 or 👎 to share your feedback on this new summary format
tidy3d/web/api/autograd/types.py
Outdated
| path_key: typing.Optional[str] = None | ||
| """Path key this is relevant for. If not specified, assume the supplied function applies for all keys.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: docstring says "Path key this is relevant for" but doesn't explain what a path key is or provide examples. For user-facing API, add documentation explaining that path_key should be a tuple like ('geometry', 'radius') or ('medium', 'permittivity') to target specific fields, or None to apply to all paths for this structure.
Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/web/api/autograd/types.py
Line: 19:20
Comment:
**style:** docstring says "Path key this is relevant for" but doesn't explain what a path key is or provide examples. For user-facing API, add documentation explaining that `path_key` should be a tuple like `('geometry', 'radius')` or `('medium', 'permittivity')` to target specific fields, or `None` to apply to all paths for this structure.
How can I resolve this? If you propose a fix, please make it concise.
tidy3d/web/api/autograd/backward.py
Outdated
| if path is None: | ||
| for match_path in get_all_paths(structure_index): | ||
| user_vjp_lookup.setdefault(structure_index, {})[match_path[0:2]] = vjp_fn | ||
| else: | ||
| user_vjp_lookup.setdefault(structure_index, {})[path] = vjp_fn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: when path is None, this loops through get_all_paths(structure_index) and sets the same vjp_fn for match_path[0:2] for each path. However, match_path[0:2] would be something like ('geometry', 'radius'), and this could result in the same key being set multiple times if there are multiple paths with the same first two elements. Consider using a set or checking if the key already exists to avoid redundant assignments.
Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/web/api/autograd/backward.py
Line: 136:140
Comment:
**style:** when `path` is `None`, this loops through `get_all_paths(structure_index)` and sets the same `vjp_fn` for `match_path[0:2]` for each path. However, `match_path[0:2]` would be something like `('geometry', 'radius')`, and this could result in the same key being set multiple times if there are multiple paths with the same first two elements. Consider using a set or checking if the key already exists to avoid redundant assignments.
How can I resolve this? If you propose a fix, please make it concise.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is ok because if we have the same path for the same structure, we would want the same vjp function to apply
tidy3d/components/structure.py
Outdated
| ) -> AutogradFieldMap: | ||
| """Compute adjoint gradients given the forward and adjoint fields provided in derivative_info.""" | ||
| """vjp_fns provide alternate derivative computation paths for the geometry or medium derivatives.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax: docstrings use triple-quoted strings on the same line as code, which won't be recognized as docstrings by documentation tools. Move these to proper docstring format or convert to inline comments with #.
| ) -> AutogradFieldMap: | |
| """Compute adjoint gradients given the forward and adjoint fields provided in derivative_info.""" | |
| """vjp_fns provide alternate derivative computation paths for the geometry or medium derivatives.""" | |
| def _compute_derivatives( | |
| self, | |
| derivative_info: DerivativeInfo, | |
| vjp_fns: typing.Optional[dict[tuple[str, str], typing.Callable[..., typing.Any]]] = None, | |
| ) -> AutogradFieldMap: | |
| """Compute adjoint gradients given the forward and adjoint fields provided in derivative_info. | |
| Parameters | |
| ---------- | |
| derivative_info : DerivativeInfo | |
| Contains field data and parameters for gradient computation. | |
| vjp_fns : dict[tuple[str, str], Callable], optional | |
| Alternate derivative computation paths for geometry or medium derivatives. | |
| Keys are tuples like ('geometry', 'radius') or ('medium', 'permittivity'). | |
| """ |
Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/components/structure.py
Line: 365:367
Comment:
**syntax:** docstrings use triple-quoted strings on the same line as code, which won't be recognized as docstrings by documentation tools. Move these to proper docstring format or convert to inline comments with `#`.
```suggestion
def _compute_derivatives(
self,
derivative_info: DerivativeInfo,
vjp_fns: typing.Optional[dict[tuple[str, str], typing.Callable[..., typing.Any]]] = None,
) -> AutogradFieldMap:
"""Compute adjoint gradients given the forward and adjoint fields provided in derivative_info.
Parameters
----------
derivative_info : DerivativeInfo
Contains field data and parameters for gradient computation.
vjp_fns : dict[tuple[str, str], Callable], optional
Alternate derivative computation paths for geometry or medium derivatives.
Keys are tuples like ('geometry', 'radius') or ('medium', 'permittivity').
"""
```
How can I resolve this? If you propose a fix, please make it concise.| structure_fields_map[med_or_geo].append(field_path) | ||
|
|
||
| # loop through sub fields, compute VJPs, and store in the derivative map {path -> vjp_value} | ||
| derivative_map = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: comment says "the first level of integration would be to" but is incomplete. Either complete the thought or remove the incomplete comment.
Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/components/structure.py
Line: 385:385
Comment:
**style:** comment says "the first level of integration would be to" but is incomplete. Either complete the thought or remove the incomplete comment.
How can I resolve this? If you propose a fix, please make it concise.
tidy3d/web/api/autograd/backward.py
Outdated
| def get_all_paths(match_structure_index: int) -> tuple[tuple[str, str, int]]: | ||
| """Get all the paths that may appear in autograd for this structure index. This allows a""" | ||
| """user_vjp to be called for all autograd paths for the structure.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax: docstrings here use triple quotes on same line as code, similar issue. These won't be recognized as proper docstrings.
| def get_all_paths(match_structure_index: int) -> tuple[tuple[str, str, int]]: | |
| """Get all the paths that may appear in autograd for this structure index. This allows a""" | |
| """user_vjp to be called for all autograd paths for the structure.""" | |
| def get_all_paths(match_structure_index: int) -> tuple[tuple[str, str, int]]: | |
| """Get all the paths that may appear in autograd for this structure index. | |
| This allows a user_vjp to be called for all autograd paths for the structure. | |
| """ |
Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/web/api/autograd/backward.py
Line: 118:120
Comment:
**syntax:** docstrings here use triple quotes on same line as code, similar issue. These won't be recognized as proper docstrings.
```suggestion
def get_all_paths(match_structure_index: int) -> tuple[tuple[str, str, int]]:
"""Get all the paths that may appear in autograd for this structure index.
This allows a user_vjp to be called for all autograd paths for the structure.
"""
```
How can I resolve this? If you propose a fix, please make it concise.
tidy3d/web/api/autograd/backward.py
Outdated
| def updated_epsilon_full( | ||
| replacement_geometry: GeometryType, | ||
| adjoint_frequencies: typing.Optional[FreqDataArray] = adjoint_frequencies, | ||
| structure_index: typing.Optional[int] = structure_index, | ||
| eps_box: typing.Optional[Box] = eps_fwd.monitor.geometry, | ||
| ) -> ScalarFieldDataArray: | ||
| # Return the simulation permittivity for eps_box after replacing the geometry | ||
| # for this structure with a new geometry. This is helpful for carrying out finite |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: inline comment uses # style but spans multiple lines, making it look like a docstring. Since this is a nested function definition, either use a proper docstring or keep it as a single-line comment. Also, the default arguments capture closure variables which could be error-prone if these functions are called outside their intended scope.
Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/web/api/autograd/backward.py
Line: 254:261
Comment:
**style:** inline comment uses `#` style but spans multiple lines, making it look like a docstring. Since this is a nested function definition, either use a proper docstring or keep it as a single-line comment. Also, the default arguments capture closure variables which could be error-prone if these functions are called outside their intended scope.
How can I resolve this? If you propose a fix, please make it concise.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
used functools partial to bind the arguments instead
Diff CoverageDiff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/web/api/autograd/autograd.pyLines 472-485 472 if fn_arg is None:
473 return fn_arg
474
475 if isinstance(fn_arg, base_type):
! 476 expanded = dict.fromkeys(sim_dict.keys(), fn_arg)
! 477 return expanded
478
479 expanded = {}
480 if not isinstance(fn_arg, type(orig_sim_arg)):
! 481 raise AdjointError(
482 f"{fn_arg_name} type ({type(fn_arg)}) should match simulations type ({type(simulations)})"
483 )
484
485 if isinstance(orig_sim_arg, dict):Lines 485-493 485 if isinstance(orig_sim_arg, dict):
486 check_keys = fn_arg.keys() == sim_dict.keys()
487
488 if not check_keys:
! 489 raise AdjointError(f"{fn_arg_name} keys do not match simulations keys")
490
491 for key, val in fn_arg.items():
492 if isinstance(val, base_type):
493 expanded[key] = (val,)Lines 495-503 495 expanded[key] = val
496
497 elif isinstance(orig_sim_arg, (list, tuple)):
498 if not (len(fn_arg) == len(orig_sim_arg)):
! 499 raise AdjointError(
500 f"{fn_arg_name} is not the same length as simulations ({len(expanded)} vs. {len(simulations)})"
501 )
502
503 for idx, key in enumerate(sim_dict.keys()):Lines 1142-1150 1142
1143 # Compute VJP contribution
1144 task_user_vjp = user_vjp.get(task_name)
1145 if isinstance(task_user_vjp, UserVJPConfig):
! 1146 task_user_vjp = (task_user_vjp,)
1147
1148 vjp_results[adj_task_name] = postprocess_adj(
1149 sim_data_adj=sim_data_adj,
1150 sim_data_orig=sim_data_orig,tidy3d/web/api/autograd/backward.pyLines 259-270 259 ) -> ScalarFieldDataArray:
260 # Return the simulation permittivity for eps_box after replacing the geometry
261 # for this structure with a new geometry. This is helpful for carrying out finite
262 # difference permittivity computations
! 263 sim_orig = sim_data_orig.simulation
! 264 sim_orig_grid_spec = td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid)
265
! 266 update_sim = sim_orig.updated_copy(
267 structures=[
268 sim_orig.structures[idx].updated_copy(geometry=replacement_geometry)
269 if idx == structure_index
270 else sim_orig.structures[idx]Lines 272-285 272 ],
273 grid_spec=sim_orig_grid_spec,
274 )
275
! 276 eps_by_f = [
277 update_sim.epsilon(box=eps_box, coord_key="centers", freq=f)
278 for f in adjoint_frequencies
279 ]
280
! 281 return xr.concat(eps_by_f, dim="f").assign_coords(f=adjoint_frequencies)
282
283 # get chunk size - if None, process all frequencies as one chunk
284 freq_chunk_size = config.adjoint.solver_freq_chunk_size
285 n_freqs = len(adjoint_frequencies)Lines 347-355 347 select_adjoint_freqs: typing.Optional[FreqDataArray] = select_adjoint_freqs,
348 updated_epsilon_full: typing.Optional[typing.Callable] = updated_epsilon_full,
349 ) -> ScalarFieldDataArray:
350 # Get permittivity function for a subset of frequencies
! 351 return updated_epsilon_full(replacement_geometry).sel(f=select_adjoint_freqs)
352
353 common_kwargs = {
354 "E_der_map": E_der_map_chunk,
355 "D_der_map": D_der_map_chunk,Lines 384-398 384 vjp_chunk = structure._compute_derivatives(derivative_info_struct, vjp_fns=vjp_fns)
385
386 for path, value in vjp_chunk.items():
387 if path in vjp_value_map:
! 388 existing = vjp_value_map[path]
! 389 if isinstance(existing, (list, tuple)) and isinstance(value, (list, tuple)):
! 390 vjp_value_map[path] = type(existing)(
391 x + y for x, y in zip(existing, value)
392 )
393 else:
! 394 vjp_value_map[path] = existing + value
395 else:
396 vjp_value_map[path] = value
397
398 # store vjps in output map |
tidy3d/components/structure.py
Outdated
| def _compute_derivatives( | ||
| self, | ||
| derivative_info: DerivativeInfo, | ||
| vjp_fns: typing.Optional[dict[tuple[str, str], typing.Callable[..., typing.Any]]] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isn't it tuple[str, ...]?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, thank you!
tidy3d/web/api/autograd/types.py
Outdated
|
|
||
|
|
||
| @dataclass | ||
| class UserVJPConfig: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if we should initialize instances of this class from the simulation. Also it would be not bad to have more validators on the structure index and the path_key in this context.
So f. e. using it like
sim.get_user_vjp(structure_index=0, compute_derivatives=func, path_key=("key",)) # -> raises if index out of bounds, func signature wrong or or path key invalid
And I would extend the validation in run (which is yet on the compute_derivatives func?) to the index/path.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the suggestion! I went with adding more validation to the custom_vjp before it is used. We now check that the structure index and path key exist in the traced structures for anything specified by custom_vjp (unless None is used for the path in which case we apply to all relevant paths). I also added checking of the function signature for the compute_derivatives function. Hopefully this adds enough protection and error checking to catch most cases!
yaugenst-flex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @groberts-flex this is very nice to have! As discussed, I think we should change the name to "custom" instead of "user" VJP. Left a couple of other comments but overall looks good!
| >>> b = Sphere(center=(1,2,3), radius=2) | ||
| """ | ||
|
|
||
| radius: TracedSize1D = pydantic.Field( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The traced type handling is pretty loose. It's handled a bit better on the pydantic v2 branch so I wouldn't worry about this for now because it's a bit involved.
tidy3d/web/api/autograd/autograd.py
Outdated
| ) | ||
| from .types import ( | ||
| SetupRunResult, | ||
| UserVJPConfig, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's remove the line break here
tidy3d/web/api/autograd/types.py
Outdated
| compute_derivatives: typing.Callable | ||
| """Function that computes the vjp for the structure given the same arguments | ||
| that the internal _compute_derivatives function gets.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be expanded a bit, specifically in terms of the required function signature.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point! expanded this!
tidy3d/web/api/autograd/types.py
Outdated
| """Function that computes the vjp for the structure given the same arguments | ||
| that the internal _compute_derivatives function gets.""" | ||
|
|
||
| path_key: typing.Optional[str] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't structure look these up as tuple?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, good catch!
8c5546d to
5cc6cda
Compare
…p arguments to provide hook into gradient computation for custom vjp calculation.
5cc6cda to
116571c
Compare
@marcorudolphflex @yaugenst-flex
After the interface reconfiguration, I split the PR for these custom autograd hooks into two so that hopefully it's easier to review! This one is for the user_vjp which allows someone to override the internal vjp calculation for a structure geometry or medium. The other hook is done as well, but I'll save it for after this one is done with review!
Based on the other review, I updated the interface to ideally be a little more straightforward to use and less cumbersome. The specification of paths in the user_vjp is not required unless you want it to only apply to a specific path in the structure. It can also be specified as just a single user_vjp value in run_async_custom if you want the same one to apply to all of the simulations (instead of having to manually broadcast it). I think there are other helper functions that could be added in the future that might make things even easier like applying a certain user_vjp for all structures with a specific geometry type, but I'll leave those for a future upgrade.
Greptile Summary
user_vjpparameter to autograd run functions enabling custom gradient calculations for specific structuresDerivativeInfowithupdated_epsilonhelper for finite difference gradient computations in custom VJPsConfidence Score: 4/5
tidy3d/components/structure.pyandtidy3d/web/api/autograd/backward.pyfor the docstring formatting issuesImportant Files Changed
UserVJPConfigdataclass for custom gradient computation andSetupRunResultfor run preparationuser_vjpparameter throughout run functions with validation and broadcasting logic for single/batch simulationsupdated_epsilonhelper function for finite difference gradient computations_compute_derivativesto accept optionalvjp_fnsdict for custom gradient paths per geometry/medium fieldSequence Diagram
sequenceDiagram participant User participant run_custom participant _run_primitive participant setup_fwd participant _run_tidy3d participant _run_bwd participant postprocess_adj participant Structure participant UserVJP User->>run_custom: "call with simulation and user_vjp" run_custom->>_run_primitive: "pass user_vjp to primitive" _run_primitive->>setup_fwd: "setup forward simulation" setup_fwd-->>_run_primitive: "combined simulation" _run_primitive->>_run_tidy3d: "run forward simulation" _run_tidy3d-->>_run_primitive: "simulation data" Note over _run_bwd: Backward pass triggered _run_bwd->>postprocess_adj: "compute gradients with user_vjp" postprocess_adj->>postprocess_adj: "build user_vjp_lookup dict" postprocess_adj->>Structure: "_compute_derivatives with vjp_fns" alt user VJP exists for path Structure->>UserVJP: "call user-defined vjp function" UserVJP-->>Structure: "custom gradients" else default path Structure->>Structure: "call internal gradient method" Structure-->>Structure: "standard gradients" end Structure-->>postprocess_adj: "gradient values" postprocess_adj-->>_run_bwd: "VJP field map" _run_bwd-->>User: "gradients for optimization"