Skip to content

Conversation

@groberts-flex
Copy link
Contributor

@groberts-flex groberts-flex commented Nov 19, 2025

@marcorudolphflex @yaugenst-flex

After the interface reconfiguration, I split the PR for these custom autograd hooks into two so that hopefully it's easier to review! This one is for the user_vjp which allows someone to override the internal vjp calculation for a structure geometry or medium. The other hook is done as well, but I'll save it for after this one is done with review!

Based on the other review, I updated the interface to ideally be a little more straightforward to use and less cumbersome. The specification of paths in the user_vjp is not required unless you want it to only apply to a specific path in the structure. It can also be specified as just a single user_vjp value in run_async_custom if you want the same one to apply to all of the simulations (instead of having to manually broadcast it). I think there are other helper functions that could be added in the future that might make things even easier like applying a certain user_vjp for all structures with a specific geometry type, but I'll leave those for a future upgrade.

Greptile Summary

  • Adds user_vjp parameter to autograd run functions enabling custom gradient calculations for specific structures
  • Implements VJP lookup mechanism in backward pass to route computation through user-defined functions when specified
  • Extends DerivativeInfo with updated_epsilon helper for finite difference gradient computations in custom VJPs

Confidence Score: 4/5

  • This PR is safe to merge with minor documentation improvements recommended
  • Score reflects solid implementation with comprehensive test coverage and proper error handling. The core logic correctly routes custom VJP functions through the gradient computation pipeline. Minor documentation issues with inline docstrings and incomplete comments don't impact functionality but should be addressed for maintainability.
  • Pay attention to tidy3d/components/structure.py and tidy3d/web/api/autograd/backward.py for the docstring formatting issues

Important Files Changed

Filename Overview
tidy3d/web/api/autograd/types.py New file introducing UserVJPConfig dataclass for custom gradient computation and SetupRunResult for run preparation
tidy3d/web/api/autograd/autograd.py Adds user_vjp parameter throughout run functions with validation and broadcasting logic for single/batch simulations
tidy3d/web/api/autograd/backward.py Implements user VJP lookup mechanism and updated_epsilon helper function for finite difference gradient computations
tidy3d/components/structure.py Extends _compute_derivatives to accept optional vjp_fns dict for custom gradient paths per geometry/medium field

Sequence Diagram

sequenceDiagram
    participant User
    participant run_custom
    participant _run_primitive
    participant setup_fwd
    participant _run_tidy3d
    participant _run_bwd
    participant postprocess_adj
    participant Structure
    participant UserVJP

    User->>run_custom: "call with simulation and user_vjp"
    run_custom->>_run_primitive: "pass user_vjp to primitive"
    _run_primitive->>setup_fwd: "setup forward simulation"
    setup_fwd-->>_run_primitive: "combined simulation"
    _run_primitive->>_run_tidy3d: "run forward simulation"
    _run_tidy3d-->>_run_primitive: "simulation data"
    
    Note over _run_bwd: Backward pass triggered
    _run_bwd->>postprocess_adj: "compute gradients with user_vjp"
    postprocess_adj->>postprocess_adj: "build user_vjp_lookup dict"
    postprocess_adj->>Structure: "_compute_derivatives with vjp_fns"
    
    alt user VJP exists for path
        Structure->>UserVJP: "call user-defined vjp function"
        UserVJP-->>Structure: "custom gradients"
    else default path
        Structure->>Structure: "call internal gradient method"
        Structure-->>Structure: "standard gradients"
    end
    
    Structure-->>postprocess_adj: "gradient values"
    postprocess_adj-->>_run_bwd: "VJP field map"
    _run_bwd-->>User: "gradients for optimization"
Loading

>>> b = Sphere(center=(1,2,3), radius=2)
"""

radius: TracedSize1D = pydantic.Field(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marcorudolphflex - good point about losing the non-negativity constraint here. I noticed we also have this case in Cylinder. One thing is that if I try and specify a traced negative value, then it gets set to 0 and I am not sure where that is happening but figured @yaugenst-flex may know!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just saw that TracedSize1D is also nonNegativeFloat. But shouldn't it be TracedPositiveFloat? But we also allow 0 radius currently for Circular...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The traced type handling is pretty loose. It's handled a bit better on the pydantic v2 branch so I wouldn't worry about this for now because it's a bit involved.

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

13 files reviewed, 6 comments

Edit Code Review Agent Settings | Greptile
React with 👍 or 👎 to share your feedback on this new summary format

Comment on lines 19 to 20
path_key: typing.Optional[str] = None
"""Path key this is relevant for. If not specified, assume the supplied function applies for all keys."""
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: docstring says "Path key this is relevant for" but doesn't explain what a path key is or provide examples. For user-facing API, add documentation explaining that path_key should be a tuple like ('geometry', 'radius') or ('medium', 'permittivity') to target specific fields, or None to apply to all paths for this structure.

Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/web/api/autograd/types.py
Line: 19:20

Comment:
**style:** docstring says "Path key this is relevant for" but doesn't explain what a path key is or provide examples. For user-facing API, add documentation explaining that `path_key` should be a tuple like `('geometry', 'radius')` or `('medium', 'permittivity')` to target specific fields, or `None` to apply to all paths for this structure.

How can I resolve this? If you propose a fix, please make it concise.

Comment on lines 136 to 140
if path is None:
for match_path in get_all_paths(structure_index):
user_vjp_lookup.setdefault(structure_index, {})[match_path[0:2]] = vjp_fn
else:
user_vjp_lookup.setdefault(structure_index, {})[path] = vjp_fn
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: when path is None, this loops through get_all_paths(structure_index) and sets the same vjp_fn for match_path[0:2] for each path. However, match_path[0:2] would be something like ('geometry', 'radius'), and this could result in the same key being set multiple times if there are multiple paths with the same first two elements. Consider using a set or checking if the key already exists to avoid redundant assignments.

Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/web/api/autograd/backward.py
Line: 136:140

Comment:
**style:** when `path` is `None`, this loops through `get_all_paths(structure_index)` and sets the same `vjp_fn` for `match_path[0:2]` for each path. However, `match_path[0:2]` would be something like `('geometry', 'radius')`, and this could result in the same key being set multiple times if there are multiple paths with the same first two elements. Consider using a set or checking if the key already exists to avoid redundant assignments.

How can I resolve this? If you propose a fix, please make it concise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is ok because if we have the same path for the same structure, we would want the same vjp function to apply

Comment on lines 365 to 367
) -> AutogradFieldMap:
"""Compute adjoint gradients given the forward and adjoint fields provided in derivative_info."""
"""vjp_fns provide alternate derivative computation paths for the geometry or medium derivatives."""
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

syntax: docstrings use triple-quoted strings on the same line as code, which won't be recognized as docstrings by documentation tools. Move these to proper docstring format or convert to inline comments with #.

Suggested change
) -> AutogradFieldMap:
"""Compute adjoint gradients given the forward and adjoint fields provided in derivative_info."""
"""vjp_fns provide alternate derivative computation paths for the geometry or medium derivatives."""
def _compute_derivatives(
self,
derivative_info: DerivativeInfo,
vjp_fns: typing.Optional[dict[tuple[str, str], typing.Callable[..., typing.Any]]] = None,
) -> AutogradFieldMap:
"""Compute adjoint gradients given the forward and adjoint fields provided in derivative_info.
Parameters
----------
derivative_info : DerivativeInfo
Contains field data and parameters for gradient computation.
vjp_fns : dict[tuple[str, str], Callable], optional
Alternate derivative computation paths for geometry or medium derivatives.
Keys are tuples like ('geometry', 'radius') or ('medium', 'permittivity').
"""
Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/components/structure.py
Line: 365:367

Comment:
**syntax:** docstrings use triple-quoted strings on the same line as code, which won't be recognized as docstrings by documentation tools. Move these to proper docstring format or convert to inline comments with `#`.

```suggestion
    def _compute_derivatives(
        self,
        derivative_info: DerivativeInfo,
        vjp_fns: typing.Optional[dict[tuple[str, str], typing.Callable[..., typing.Any]]] = None,
    ) -> AutogradFieldMap:
        """Compute adjoint gradients given the forward and adjoint fields provided in derivative_info.
        
        Parameters
        ----------
        derivative_info : DerivativeInfo
            Contains field data and parameters for gradient computation.
        vjp_fns : dict[tuple[str, str], Callable], optional
            Alternate derivative computation paths for geometry or medium derivatives.
            Keys are tuples like ('geometry', 'radius') or ('medium', 'permittivity').
        """
```

How can I resolve this? If you propose a fix, please make it concise.

structure_fields_map[med_or_geo].append(field_path)

# loop through sub fields, compute VJPs, and store in the derivative map {path -> vjp_value}
derivative_map = {}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: comment says "the first level of integration would be to" but is incomplete. Either complete the thought or remove the incomplete comment.

Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/components/structure.py
Line: 385:385

Comment:
**style:** comment says "the first level of integration would be to" but is incomplete. Either complete the thought or remove the incomplete comment.

How can I resolve this? If you propose a fix, please make it concise.

Comment on lines 118 to 120
def get_all_paths(match_structure_index: int) -> tuple[tuple[str, str, int]]:
"""Get all the paths that may appear in autograd for this structure index. This allows a"""
"""user_vjp to be called for all autograd paths for the structure."""
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

syntax: docstrings here use triple quotes on same line as code, similar issue. These won't be recognized as proper docstrings.

Suggested change
def get_all_paths(match_structure_index: int) -> tuple[tuple[str, str, int]]:
"""Get all the paths that may appear in autograd for this structure index. This allows a"""
"""user_vjp to be called for all autograd paths for the structure."""
def get_all_paths(match_structure_index: int) -> tuple[tuple[str, str, int]]:
"""Get all the paths that may appear in autograd for this structure index.
This allows a user_vjp to be called for all autograd paths for the structure.
"""
Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/web/api/autograd/backward.py
Line: 118:120

Comment:
**syntax:** docstrings here use triple quotes on same line as code, similar issue. These won't be recognized as proper docstrings.

```suggestion
    def get_all_paths(match_structure_index: int) -> tuple[tuple[str, str, int]]:
        """Get all the paths that may appear in autograd for this structure index.
        
        This allows a user_vjp to be called for all autograd paths for the structure.
        """
```

How can I resolve this? If you propose a fix, please make it concise.

Comment on lines 254 to 261
def updated_epsilon_full(
replacement_geometry: GeometryType,
adjoint_frequencies: typing.Optional[FreqDataArray] = adjoint_frequencies,
structure_index: typing.Optional[int] = structure_index,
eps_box: typing.Optional[Box] = eps_fwd.monitor.geometry,
) -> ScalarFieldDataArray:
# Return the simulation permittivity for eps_box after replacing the geometry
# for this structure with a new geometry. This is helpful for carrying out finite
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: inline comment uses # style but spans multiple lines, making it look like a docstring. Since this is a nested function definition, either use a proper docstring or keep it as a single-line comment. Also, the default arguments capture closure variables which could be error-prone if these functions are called outside their intended scope.

Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/web/api/autograd/backward.py
Line: 254:261

Comment:
**style:** inline comment uses `#` style but spans multiple lines, making it look like a docstring. Since this is a nested function definition, either use a proper docstring or keep it as a single-line comment. Also, the default arguments capture closure variables which could be error-prone if these functions are called outside their intended scope.

How can I resolve this? If you propose a fix, please make it concise.

Copy link
Contributor Author

@groberts-flex groberts-flex Nov 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

used functools partial to bind the arguments instead

@github-actions
Copy link
Contributor

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/autograd/derivative_utils.py (100%)
  • tidy3d/components/geometry/primitives.py (100%)
  • tidy3d/components/structure.py (100%)
  • tidy3d/plugins/smatrix/run.py (100%)
  • tidy3d/web/api/autograd/autograd.py (91.9%): Missing lines 476-477,481,489,499,1146
  • tidy3d/web/api/autograd/backward.py (77.8%): Missing lines 263-264,266,276,281,351,388-390,394
  • tidy3d/web/api/autograd/types.py (100%)

Summary

  • Total: 172 lines
  • Missing: 16 lines
  • Coverage: 90%

tidy3d/web/api/autograd/autograd.py

Lines 472-485

  472         if fn_arg is None:
  473             return fn_arg
  474 
  475         if isinstance(fn_arg, base_type):
! 476             expanded = dict.fromkeys(sim_dict.keys(), fn_arg)
! 477             return expanded
  478 
  479         expanded = {}
  480         if not isinstance(fn_arg, type(orig_sim_arg)):
! 481             raise AdjointError(
  482                 f"{fn_arg_name} type ({type(fn_arg)}) should match simulations type ({type(simulations)})"
  483             )
  484 
  485         if isinstance(orig_sim_arg, dict):

Lines 485-493

  485         if isinstance(orig_sim_arg, dict):
  486             check_keys = fn_arg.keys() == sim_dict.keys()
  487 
  488             if not check_keys:
! 489                 raise AdjointError(f"{fn_arg_name} keys do not match simulations keys")
  490 
  491             for key, val in fn_arg.items():
  492                 if isinstance(val, base_type):
  493                     expanded[key] = (val,)

Lines 495-503

  495                     expanded[key] = val
  496 
  497         elif isinstance(orig_sim_arg, (list, tuple)):
  498             if not (len(fn_arg) == len(orig_sim_arg)):
! 499                 raise AdjointError(
  500                     f"{fn_arg_name} is not the same length as simulations ({len(expanded)} vs. {len(simulations)})"
  501                 )
  502 
  503             for idx, key in enumerate(sim_dict.keys()):

Lines 1142-1150

  1142 
  1143                 # Compute VJP contribution
  1144                 task_user_vjp = user_vjp.get(task_name)
  1145                 if isinstance(task_user_vjp, UserVJPConfig):
! 1146                     task_user_vjp = (task_user_vjp,)
  1147 
  1148                 vjp_results[adj_task_name] = postprocess_adj(
  1149                     sim_data_adj=sim_data_adj,
  1150                     sim_data_orig=sim_data_orig,

tidy3d/web/api/autograd/backward.py

Lines 259-270

  259         ) -> ScalarFieldDataArray:
  260             # Return the simulation permittivity for eps_box after replacing the geometry
  261             # for this structure with a new geometry. This is helpful for carrying out finite
  262             # difference permittivity computations
! 263             sim_orig = sim_data_orig.simulation
! 264             sim_orig_grid_spec = td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid)
  265 
! 266             update_sim = sim_orig.updated_copy(
  267                 structures=[
  268                     sim_orig.structures[idx].updated_copy(geometry=replacement_geometry)
  269                     if idx == structure_index
  270                     else sim_orig.structures[idx]

Lines 272-285

  272                 ],
  273                 grid_spec=sim_orig_grid_spec,
  274             )
  275 
! 276             eps_by_f = [
  277                 update_sim.epsilon(box=eps_box, coord_key="centers", freq=f)
  278                 for f in adjoint_frequencies
  279             ]
  280 
! 281             return xr.concat(eps_by_f, dim="f").assign_coords(f=adjoint_frequencies)
  282 
  283         # get chunk size - if None, process all frequencies as one chunk
  284         freq_chunk_size = config.adjoint.solver_freq_chunk_size
  285         n_freqs = len(adjoint_frequencies)

Lines 347-355

  347                 select_adjoint_freqs: typing.Optional[FreqDataArray] = select_adjoint_freqs,
  348                 updated_epsilon_full: typing.Optional[typing.Callable] = updated_epsilon_full,
  349             ) -> ScalarFieldDataArray:
  350                 # Get permittivity function for a subset of frequencies
! 351                 return updated_epsilon_full(replacement_geometry).sel(f=select_adjoint_freqs)
  352 
  353             common_kwargs = {
  354                 "E_der_map": E_der_map_chunk,
  355                 "D_der_map": D_der_map_chunk,

Lines 384-398

  384                 vjp_chunk = structure._compute_derivatives(derivative_info_struct, vjp_fns=vjp_fns)
  385 
  386                 for path, value in vjp_chunk.items():
  387                     if path in vjp_value_map:
! 388                         existing = vjp_value_map[path]
! 389                         if isinstance(existing, (list, tuple)) and isinstance(value, (list, tuple)):
! 390                             vjp_value_map[path] = type(existing)(
  391                                 x + y for x, y in zip(existing, value)
  392                             )
  393                         else:
! 394                             vjp_value_map[path] = existing + value
  395                     else:
  396                         vjp_value_map[path] = value
  397 
  398         # store vjps in output map

def _compute_derivatives(
self,
derivative_info: DerivativeInfo,
vjp_fns: typing.Optional[dict[tuple[str, str], typing.Callable[..., typing.Any]]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it tuple[str, ...]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, thank you!



@dataclass
class UserVJPConfig:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should initialize instances of this class from the simulation. Also it would be not bad to have more validators on the structure index and the path_key in this context.
So f. e. using it like
sim.get_user_vjp(structure_index=0, compute_derivatives=func, path_key=("key",)) # -> raises if index out of bounds, func signature wrong or or path key invalid
And I would extend the validation in run (which is yet on the compute_derivatives func?) to the index/path.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the suggestion! I went with adding more validation to the custom_vjp before it is used. We now check that the structure index and path key exist in the traced structures for anything specified by custom_vjp (unless None is used for the path in which case we apply to all relevant paths). I also added checking of the function signature for the compute_derivatives function. Hopefully this adds enough protection and error checking to catch most cases!

Copy link
Collaborator

@yaugenst-flex yaugenst-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @groberts-flex this is very nice to have! As discussed, I think we should change the name to "custom" instead of "user" VJP. Left a couple of other comments but overall looks good!

>>> b = Sphere(center=(1,2,3), radius=2)
"""

radius: TracedSize1D = pydantic.Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The traced type handling is pretty loose. It's handled a bit better on the pydantic v2 branch so I wouldn't worry about this for now because it's a bit involved.

)
from .types import (
SetupRunResult,
UserVJPConfig,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's remove the line break here

Comment on lines 15 to 17
compute_derivatives: typing.Callable
"""Function that computes the vjp for the structure given the same arguments
that the internal _compute_derivatives function gets."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be expanded a bit, specifically in terms of the required function signature.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point! expanded this!

"""Function that computes the vjp for the structure given the same arguments
that the internal _compute_derivatives function gets."""

path_key: typing.Optional[str] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't structure look these up as tuple?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, good catch!

@groberts-flex groberts-flex force-pushed the groberts-flex/custom_user_vjp_hook_FXC-3730 branch from 8c5546d to 5cc6cda Compare November 26, 2025 21:50
…p arguments to provide hook into gradient computation for custom vjp calculation.
@groberts-flex groberts-flex force-pushed the groberts-flex/custom_user_vjp_hook_FXC-3730 branch from 5cc6cda to 116571c Compare November 26, 2025 21:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants