Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 67 additions & 45 deletions essos/coil_perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import jax.numpy as jnp
from jax import jit, vmap
from jaxtyping import Array, Float # https://github.com/google/jaxtyping
from essos.coils import Curves,apply_symmetries_to_gammas
from essos.coils import Curves, Coils, CoilsFromGamma, fit_dofs_from_coils
from functools import partial


Expand Down Expand Up @@ -205,70 +205,92 @@ def get_sample(self, deriv):



def perturb_curves_systematic(curves: Curves,sampler:GaussianSampler, key=None):
def perturb_curves_systematic(curves, sampler:GaussianSampler, key=None):
"""
Apply a systematic perturbation to all the coils.
This means taht an independent perturbation is applied to the each unique coil
Then, the required symmetries are applied to the perturbed unique set of coils
Perturbations are applied to base curves and symmetries are reapplied.

Args:
curves: curves to be perturbed.
curves: Curves or CoilsFromGamma to be perturbed.
sampler: the gaussian sampler used to get the perturbations
key: the seed which will be splited to geenerate random
but reproducible pertubations
key: the seed which will be split to generate random
but reproducible perturbations

Returns:
The curves given as an input are modified and thus no return is done
"""
new_seeds=jax.random.split(key, num=curves.n_base_curves)
if sampler.n_derivs == 0:
if isinstance(curves, CoilsFromGamma):
# Systematic perturbation on base dofs only. Symmetry is applied by the class properties.
n_base_curves = curves.n_base_curves
new_seeds = jax.random.split(key, num=n_base_curves)
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
gamma_perturbations = apply_symmetries_to_gammas(perturbation[:,0,:,:], curves.nfp, curves.stellsym)
curves.gamma=curves.gamma + gamma_perturbations
elif sampler.n_derivs == 1:
curves.dofs_gamma = curves.dofs_gamma + perturbation[:, 0, :, :]
return

if isinstance(curves, Coils):
n_base_curves = curves.curves.n_base_curves
new_seeds = jax.random.split(key, num=n_base_curves)
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
gamma_perturbations = apply_symmetries_to_gammas(perturbation[:,0,:,:], curves.nfp, curves.stellsym)
gamma_perturbations_dash = apply_symmetries_to_gammas(perturbation[:,1,:,:], curves.nfp, curves.stellsym)
curves.gamma=curves.gamma + gamma_perturbations
curves.gamma_dash=curves.gamma_dash + gamma_perturbations_dash
elif sampler.n_derivs == 2:
base_gamma = Curves(curves.dofs_curves, curves.n_segments, nfp=1, stellsym=False).gamma
perturbed_base_gamma = base_gamma + perturbation[:, 0, :, :]
dofs_new, _ = fit_dofs_from_coils(perturbed_base_gamma, curves.order, curves.n_segments,assume_uniform=True)
curves.dofs_curves = dofs_new
Comment on lines +234 to +237
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In both systematic/statistical paths for Curves/Coils, fit_dofs_from_coils is called with the default assume_uniform=False, which triggers arclength resampling. Here the data comes from uniform quadpoints (and you’re perturbing those samples), so this can likely use assume_uniform=True to avoid an expensive resampling step and speed up perturbations significantly.

Copilot uses AI. Check for mistakes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems important

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed.

return

if isinstance(curves, Curves):
n_base_curves = curves.n_base_curves
new_seeds = jax.random.split(key, num=n_base_curves)
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
gamma_perturbations = apply_symmetries_to_gammas(perturbation[:,0,:,:], curves.nfp, curves.stellsym)
gamma_perturbations_dash = apply_symmetries_to_gammas(perturbation[:,1,:,:], curves.nfp, curves.stellsym)
gamma_perturbations_dashdash = apply_symmetries_to_gammas(perturbation[:,2,:,:], curves.nfp, curves.stellsym)
curves.gamma=curves.gamma + gamma_perturbations
curves.gamma_dash=curves.gamma_dash + gamma_perturbations_dash
curves.gamma_dashdash=curves.gamma_dashdash + gamma_perturbations_dashdash
base_gamma = Curves(curves.dofs, curves.n_segments, nfp=1, stellsym=False).gamma
perturbed_base_gamma = base_gamma + perturbation[:, 0, :, :]
dofs_new, _ = fit_dofs_from_coils(perturbed_base_gamma, curves.order, curves.n_segments,assume_uniform=True)
curves.dofs = dofs_new
return

raise TypeError(f"Unsupported type {type(curves)}. Expected Curves, Coils, or CoilsFromGamma.")
#return curves


def perturb_curves_statistic(curves: Curves,sampler:GaussianSampler, key=None):
def perturb_curves_statistic(curves, sampler:GaussianSampler, key=None):
"""
Apply a statistic perturbation to all the coils.
This means taht an independent perturbation is applied every coil
including repeated coils
Apply a statistical perturbation to all the coils.
This means that an independent perturbation is applied to every coil
including repeated coils.

Args:
curves: curves to be perturbed.
curves: curves to be perturbed (not modified).
sampler: the gaussian sampler used to get the perturbations
key: the seed which will be splited to geenerate random
but reproducible pertubations
key: the seed which will be split to generate random
but reproducible perturbations

Returns:
The curves given as an input are modified and thus no return is done
A new perturbed curves object of the same type as the input.
The original input object is not modified.

Note:
Statistical perturbations require disabling symmetry (nfp=1, stellsym=False)
since each coil gets an independent perturbation.
"""
new_seeds=jax.random.split(key, num=curves.gamma.shape[0])
if sampler.n_derivs == 0:
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
curves.gamma=curves.gamma + perturbation[:,0,:,:]
elif sampler.n_derivs == 1:
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
curves.gamma=curves.gamma + perturbation[:,0,:,:]
curves.gamma_dash=curves.gamma_dash + perturbation[:,1,:,:]
elif sampler.n_derivs == 2:
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
curves.gamma=curves.gamma + perturbation[:,0,:,:]
curves.gamma_dash=curves.gamma_dash + perturbation[:,1,:,:]
curves.gamma_dashdash=curves.gamma_dashdash + perturbation[:,2,:,:]
#return curves
n_curves = curves.gamma.shape[0]
new_seeds = jax.random.split(key, num=n_curves)
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
gamma_perturbed = curves.gamma + perturbation[:, 0, :, :]

if isinstance(curves, CoilsFromGamma):
# Statistical perturbation is independent for all coils, so we return a new object with no symmetry.
expanded_currents = curves.currents
return CoilsFromGamma(gamma_perturbed, currents=expanded_currents, nfp=1, stellsym=False)

if isinstance(curves, Coils):
# Capture the expanded currents and create new curves object with no symmetry.
expanded_currents = curves.currents
dofs_new, _ = fit_dofs_from_coils(gamma_perturbed, curves.order, curves.n_segments,assume_uniform=True)
new_curves = Curves(dofs_new, curves.n_segments, nfp=1, stellsym=False)
return Coils(curves=new_curves, currents=expanded_currents)

if isinstance(curves, Curves):
dofs_new, _ = fit_dofs_from_coils(gamma_perturbed, curves.order, curves.n_segments,assume_uniform=True)
return Curves(dofs_new, curves.n_segments, nfp=1, stellsym=False)

raise TypeError(f"Unsupported type {type(curves)}. Expected Curves, Coils, or CoilsFromGamma.")

Loading