Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
9fc00df
Added LBFGSB
eduardolneto Aug 13, 2025
844db8d
Added other options
eduardolneto Aug 13, 2025
36a6d4d
Refactored alm.py and alm_convex.py into augmented_lagrangian.py
eduardolneto Aug 13, 2025
068b6f4
Refactored alm.py and alm_convex.py into augmented_lagrangian.py
eduardolneto Aug 13, 2025
720de82
Some changes to ALM, added Coil-Coil and Coil-Surface losses
eduardolneto Aug 15, 2025
8ee6630
Adding mu_0 as constant, BiotSavart_from_gamma to get field given gam…
eduardolneto Aug 16, 2025
eaccc12
Adding coil_perturbation.py
eduardolneto Aug 19, 2025
cd9a4a7
Adding systematic and statistic errors to coils
eduardolneto Aug 19, 2025
d18d8a4
Changing curves to coils in objective functions in case one needs to …
eduardolneto Aug 19, 2025
b3148f2
Adding examples for creating pertubed coils and for stochastic optimi…
eduardolneto Aug 20, 2025
b4dc32f
Solved updated main merge into branch
eduardolneto Aug 20, 2025
19bb461
re-add example files
eduardolneto Aug 20, 2025
c3afb17
Making changes on new examples
eduardolneto Aug 20, 2025
9ba6c42
Removing bug on optional function in coil_perturbation.py
eduardolneto Aug 20, 2025
994b980
Removing bug on optional function in coil_perturbation.py
eduardolneto Aug 20, 2025
f577b18
Removing bug on optional function in coil_perturbation.py
eduardolneto Aug 20, 2025
463a7a6
modifieng test for biot_savart initialization to comply wih changes
eduardolneto Aug 20, 2025
12671f5
modifieng test for biot_savart initialization to comply wih changes
eduardolneto Aug 20, 2025
4d77101
modifieng test for biot_savart initialization to comply wih changes
eduardolneto Aug 20, 2025
66226e9
Adjusting example stochastic optimization
eduardolneto Aug 20, 2025
16f0c5b
Adding tests for augmented_lagrangian.py
eduardolneto Aug 27, 2025
d417e58
Updating requirements
eduardolneto Aug 27, 2025
a720d59
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
70ec6e1
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
f9e1da6
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
a82f484
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
0a1ba28
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
b3e6787
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
c5dd639
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
0493d06
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
4faa30e
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
c4986ae
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
541d74b
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
023723f
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
c8bc9f7
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
4c31f16
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
6a57667
Updating augmented_lagrangian tests
eduardolneto Aug 27, 2025
97c7497
Updating augmented_lagrangian tests and adding test_coil_perturbation.py
eduardolneto Aug 27, 2025
fd44b1c
Updating augmented_lagrangian tests and adding test_coil_perturbation.py
eduardolneto Aug 27, 2025
16afaef
Updating augmented_lagrangian tests and adding test_coil_perturbation.py
eduardolneto Aug 27, 2025
5e5ac6f
Updating augmented_lagrangian tests and adding test_coil_perturbation.py
eduardolneto Aug 27, 2025
497af01
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
e478022
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
47777de
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
429c83d
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
bc5ddd8
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
e0cb957
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
84c9314
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
a71015b
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
63ba8a6
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
7c11681
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
e9603f3
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
57ee101
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
c115864
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
c2e9f58
Updating augmented_lagrangian tests and adding test_objective_functio…
eduardolneto Aug 27, 2025
a921c56
Updating test_coil_perturbation.py
eduardolneto Aug 27, 2025
67e567c
Updating test_coil_perturbation.py
eduardolneto Aug 27, 2025
f35c651
Updating test_coil_perturbation.py
eduardolneto Aug 27, 2025
1f826c8
Updating test_coil_perturbation.py
eduardolneto Aug 27, 2025
a51b7c8
Merge remote-tracking branch 'origin/main' into en/ALM_pull
eduardolneto Aug 27, 2025
d04f0d0
Adding lost fracion objective function and example optimize_coils_par…
eduardolneto Aug 28, 2025
d253dc9
Adding lost fracion objective function and example optimize_coils_par…
eduardolneto Aug 28, 2025
60eddb7
Adding lost fracion objective function and example optimize_coils_par…
eduardolneto Aug 28, 2025
c5c37ff
Adding correction to an if in dynamics.py
eduardolneto Aug 29, 2025
d0433b7
Clearing the comments/description in the coil_perturbation.py module
eduardolneto Aug 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
721 changes: 721 additions & 0 deletions essos/augmented_lagrangian.py

Large diffs are not rendered by default.

274 changes: 274 additions & 0 deletions essos/coil_perturbation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, vmap
from jaxtyping import Array, Float # https://github.com/google/jaxtyping
from essos.coils import Curves,apply_symmetries_to_gammas
from functools import partial


#def ldl_decomposition(A):
# """
# Performs LDLᵀ decomposition on a symmetric positive-definite matrix A.
# A = L D Lᵀ where:
# - L is lower triangular with unit diagonal
# - D is diagonal
#
# Args:
# A: (n, n) symmetric matrix
#
# Returns:
# L: (n, n) lower-triangular matrix with unit diagonal
# D: (n,) diagonal elements of D
# """
# n = A.shape[0]
# L = jnp.eye(n)
# D = jnp.zeros(n)

# def body_fun(k, val):
# L, D = val

# Compute D[k]
# D_k = A[k, k] - jnp.sum((L[k, :k] ** 2) * D[:k])
# D = D.at[k].set(D_k)

# def inner_body(i, L):
# L_ik = (A[i, k] - jnp.sum(L[i, :k] * L[k, :k] * D[:k])) / D_k
# return L.at[i, k].set(L_ik)

# Update column k of L below diagonal
# L = jax.lax.fori_loop(k + 1, n, inner_body, L)

# return (L, D)

# L, D = jax.lax.fori_loop(0, n, body_fun, (L, D))

# return L, D


@jit
def matrix_sqrt_via_spectral(A):
"""Compute matrix square root of SPD matrix A via spectral decomposition."""
eigvals, Q = jnp.linalg.eigh(A) # A = Q Λ Q^T

# Ensure numerical stability (clip small negatives to 0)
eigvals = jnp.clip(eigvals, a_min=0)

sqrt_eigvals = jnp.sqrt(eigvals)
sqrt_A = Q @ jnp.diag(sqrt_eigvals) @ Q.T

return sqrt_A

#This is based on SIMSOPT's GaussianSampler, but with some modifications to make it work with JAX.
#Note: I am not sure this should be kept as a class, but it is for now to keep the interface similar to SIMSOPT.
class GaussianSampler():
r"""
Generate a periodic gaussian process on the interval [0, 1] on a given list of quadrature points.
The process has standard deviation ``sigma`` a correlation length scale ``length_scale``.
Large values of ``length_scale`` correspond to smooth processes, small values result in highly oscillatory
functions.
Also has the ability to sample the derivatives of the function.

We consider the kernel

.. math::

\kappa(d) = \sigma^2 \exp(-d^2/l^2)

and then consider a Gaussian process with covariance

.. math::

Cov(X(s), X(t)) = \sum_{i=-\infty}^\infty \sigma^2 \exp(-(s-t+i)^2/l^2)

the sum is used to make the kernel periodic and in practice the infinite sum is truncated.

Args:
points: the quadrature points along which the perturbation should be computed.
sigma: standard deviation of the underlying gaussian process
(measure for the magnitude of the perturbation).
length_scale: length scale of the underlying gaussian process
(measure for the smoothness of the perturbation).
n_derivs: number of derivatives to calculate, right now maximum is up to 2
"""

points: Array
sigma: Float
length_scale: Float
n_derivs: int

def __init__(self,points: Array, sigma: Float, length_scale: Float, n_derivs: int = 0):
self.points=points
self.sigma=sigma
self.length_scale=length_scale
self.n_derivs=n_derivs


@partial(jit, static_argnames=['self'])
def kernel_periodicity(self,x, y):
aux_periodicity=jnp.arange(-5, 6)
def kernel(x, y,i):
return self.sigma**2*jnp.exp(-(x-y+i)**2/(2.*self.length_scale**2))

return jnp.sum(jax.vmap(kernel,in_axes=(None,None,0))(x,y,aux_periodicity))

@partial(jit, static_argnames=['self'])
def d_kernel_periodicity_dx(self,x, y):
return jax.grad(self.kernel_periodicity, argnums=0)(x, y)

@partial(jit, static_argnames=['self'])
def d_kernel_periodicity_dxdx(self,x, y):
return jax.grad(self.d_kernel_periodicity_dx, argnums=0)(x, y)

@partial(jit, static_argnames=['self'])
def d_kernel_periodicity_dxdxdx(self,x, y):
return jax.grad(self.d_kernel_periodicity_dxdx, argnums=0)(x, y)

@partial(jit, static_argnames=['self'])
def d_kernel_periodicity_dxdxdxdx(self,x, y):
return jax.grad(self.d_kernel_periodicity_dxdxdx, argnums=0)(x, y)


@partial(jit, static_argnames=['self'])
def compute_covariance_matrix(self):
final_mat= jax.vmap(jax.vmap(self.kernel_periodicity,in_axes=(None,0)),in_axes=(0,None))(self.points, self.points)
return matrix_sqrt_via_spectral(final_mat)


@partial(jit, static_argnames=['self'])
def compute_covariance_matrix_and_first_derivatives(self):
cov_mat= jax.vmap(jax.vmap(self.kernel_periodicity,in_axes=(None,0)),in_axes=(0,None))(self.points, self.points)
dcov_mat_dx= jax.vmap(jax.vmap(self.d_kernel_periodicity_dx,in_axes=(None,0)),in_axes=(0,None))(self.points, self.points)
dcov_mat_dxdx= jax.vmap(jax.vmap(self.d_kernel_periodicity_dxdx,in_axes=(None,0)),in_axes=(0,None))(self.points, self.points)
final_mat = jnp.concatenate((jnp.concatenate((cov_mat, dcov_mat_dx),axis=0),jnp.concatenate((-dcov_mat_dx,dcov_mat_dxdx),axis=0 )), axis=1)
return matrix_sqrt_via_spectral(final_mat)

@partial(jit, static_argnames=['self'])
def compute_covariance_matrix_and_second_derivatives(self):
cov_mat= jax.vmap(jax.vmap(self.kernel_periodicity,in_axes=(None,0)),in_axes=(0,None))(self.points, self.points)
dcov_mat_dx= jax.vmap(jax.vmap(self.d_kernel_periodicity_dx,in_axes=(None,0)),in_axes=(0,None))(self.points, self.points)
dcov_mat_dxdx= jax.vmap(jax.vmap(self.d_kernel_periodicity_dxdx,in_axes=(None,0)),in_axes=(0,None))(self.points, self.points)
dcov_mat_dxdxdx= jax.vmap(jax.vmap(self.d_kernel_periodicity_dxdxdx,in_axes=(None,0)),in_axes=(0,None))(self.points, self.points)
dcov_mat_dxdxdxdx= jax.vmap(jax.vmap(self.d_kernel_periodicity_dxdxdxdx,in_axes=(None,0)),in_axes=(0,None))(self.points, self.points)
final_mat= jnp.concatenate((jnp.concatenate((cov_mat, dcov_mat_dx,dcov_mat_dxdx),axis=0),
jnp.concatenate((-dcov_mat_dx,dcov_mat_dxdx,-dcov_mat_dxdxdx),axis=0),
jnp.concatenate((dcov_mat_dxdx,-dcov_mat_dxdxdx,dcov_mat_dxdxdxdx),axis=0 )), axis=1)
return matrix_sqrt_via_spectral(final_mat)

@partial(jit, static_argnames=['self'])
def get_covariance_matrix(self):
if self.n_derivs ==0:
return self.compute_covariance_matrix()
elif self.n_derivs ==1:
return self.compute_covariance_matrix_and_first_derivatives()
elif self.n_derivs ==2:
return self.compute_covariance_matrix_and_second_derivatives()


@partial(jit, static_argnames=['self'])
def draw_sample(self, key=0):

n = len(self.points)
z = jax.random.normal(key=key,shape=(len(self.points)*(self.n_derivs+1), 3))
L=self.get_covariance_matrix()
curve_and_derivs = jnp.matmul(L,z)
if self.n_derivs ==0:
return jnp.reshape(jnp.matmul(L,z),(1,len(self.points),3))
elif self.n_derivs ==1:
return jnp.reshape(jnp.matmul(L,z),(2,len(self.points),3))
elif self.n_derivs ==2:
return jnp.reshape(jnp.matmul(L,z),(3,len(self.points),3))



class PerturbationSample():
def __init__(self, sampler, key=0, sample=None):
self.sampler = sampler
self.key = key # If not None, most likely fail with serialization
if sample:
self._sample = sample
else:
self.resample()

def resample(self):
self._sample = self.sampler.draw_sample(self.key)

def get_sample(self, deriv):
"""
Get the perturbation (if ``deriv=0``) or its ``deriv``-th derivative.
"""
assert isinstance(deriv, int)
if deriv >= len(self._sample):
raise ValueError("""The sample on has {len(self._sample)-1} derivatives.
Adjust the `n_derivs` parameter of the sampler to access higher derivatives.""")
return self._sample[deriv]



def perturb_curves_systematic(curves: Curves,sampler:GaussianSampler, key=None):
"""
Apply a systematic perturbation to all the coils.
This means taht an independent perturbation is applied to the each unique coil
Then, the required symmetries are applied to the perturbed unique set of coils

Args:
curves: curves to be perturbed.
sampler: the gaussian sampler used to get the perturbations
key: the seed which will be splited to geenerate random
but reproducible pertubations

Returns:
The curves given as an input are modified and thus no return is done
"""
new_seeds=jax.random.split(key, num=curves.n_base_curves)
if sampler.n_derivs == 0:
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
gamma_perturbations = apply_symmetries_to_gammas(perturbation[:,0,:,:], curves.nfp, curves.stellsym)
curves.gamma=curves.gamma + gamma_perturbations
elif sampler.n_derivs == 1:
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
gamma_perturbations = apply_symmetries_to_gammas(perturbation[:,0,:,:], curves.nfp, curves.stellsym)
gamma_perturbations_dash = apply_symmetries_to_gammas(perturbation[:,1,:,:], curves.nfp, curves.stellsym)
curves.gamma=curves.gamma + gamma_perturbations
curves.gamma_dash=curves.gamma_dash + gamma_perturbations_dash
elif sampler.n_derivs == 2:
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
gamma_perturbations = apply_symmetries_to_gammas(perturbation[:,0,:,:], curves.nfp, curves.stellsym)
gamma_perturbations_dash = apply_symmetries_to_gammas(perturbation[:,1,:,:], curves.nfp, curves.stellsym)
gamma_perturbations_dashdash = apply_symmetries_to_gammas(perturbation[:,2,:,:], curves.nfp, curves.stellsym)
curves.gamma=curves.gamma + gamma_perturbations
curves.gamma_dash=curves.gamma_dash + gamma_perturbations_dash
curves.gamma_dashdash=curves.gamma_dashdash + gamma_perturbations_dashdash
#return curves


def perturb_curves_statistic(curves: Curves,sampler:GaussianSampler, key=None):
"""
Apply a statistic perturbation to all the coils.
This means taht an independent perturbation is applied every coil
including repeated coils

Args:
curves: curves to be perturbed.
sampler: the gaussian sampler used to get the perturbations
key: the seed which will be splited to geenerate random
but reproducible pertubations

Returns:
The curves given as an input are modified and thus no return is done
"""
new_seeds=jax.random.split(key, num=curves.gamma.shape[0])
if sampler.n_derivs == 0:
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
curves.gamma=curves.gamma + perturbation[:,0,:,:]
elif sampler.n_derivs == 1:
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
curves.gamma=curves.gamma + perturbation[:,0,:,:]
curves.gamma_dash=curves.gamma_dash + perturbation[:,1,:,:]
elif sampler.n_derivs == 2:
perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds)
curves.gamma=curves.gamma + perturbation[:,0,:,:]
curves.gamma_dash=curves.gamma_dash + perturbation[:,1,:,:]
curves.gamma_dashdash=curves.gamma_dashdash + perturbation[:,2,:,:]
#return curves

48 changes: 43 additions & 5 deletions essos/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, dofs: jnp.ndarray, n_segments: int = 100, nfp: int = 1, stell
self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym)
self.quadpoints = jnp.linspace(0, 1, self.n_segments, endpoint=False)
self._set_gamma()
self.n_base_curves=dofs.shape[0]

def __str__(self):
return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\
Expand Down Expand Up @@ -152,13 +153,27 @@ def stellsym(self, new_stellsym):
def gamma(self):
return self._gamma

@gamma.setter
def gamma(self, new_gamma):
self._gamma = new_gamma

@property
def gamma_dash(self):
return self._gamma_dash

@gamma_dash.setter
def gamma_dash(self, new_gamma_dash):
self._gamma_dash = new_gamma_dash



@property
def gamma_dashdash(self):
return self._gamma_dashdash

@gamma_dashdash.setter
def gamma_dashdash(self, new_gamma_dashdash):
self._gamma_dashdash = new_gamma_dashdash

@property
def length(self):
Expand Down Expand Up @@ -238,7 +253,7 @@ def to_simsopt(self):
coils = coils_via_symmetries(cuves_simsopt, currents_simsopt, self.nfp, self.stellsym)
return [c.curve for c in coils]

def plot(self, ax=None, show=True, plot_derivative=False, close=False, axis_equal=True, **kwargs):
def plot(self, ax=None, show=True, plot_derivative=False, close=False, axis_equal=True,color="brown", linewidth=3,label=None,**kwargs):
def rep(data):
if close:
return jnp.concatenate((data, [data[0]]))
Expand All @@ -248,6 +263,7 @@ def rep(data):
if ax is None or ax.name != "3d":
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
label_count=0
for gamma, gammadash in zip(self.gamma, self.gamma_dash):
x = rep(gamma[:, 0])
y = rep(gamma[:, 1])
Expand All @@ -256,9 +272,13 @@ def rep(data):
xt = rep(gammadash[:, 0])
yt = rep(gammadash[:, 1])
zt = rep(gammadash[:, 2])
ax.plot(x, y, z, **kwargs, color='brown', linewidth=3)
if label_count == 0:
ax.plot(x, y, z, **kwargs, color=color, linewidth=linewidth,label=label)
label_count += 1
else:
ax.plot(x, y, z, **kwargs, color=color, linewidth=linewidth)
if plot_derivative:
ax.quiver(x, y, z, 0.1 * xt, 0.1 * yt, 0.1 * zt, arrow_length_ratio=0.1, color="r")
ax.quiver(x, y, z, 0.1 * xt, 0.1 * yt, 0.1 * zt, arrow_length_ratio=0.1, color='r')
if axis_equal:
fix_matplotlib_3d(ax)
if show:
Expand Down Expand Up @@ -388,6 +408,10 @@ def __add__(self, other):
else:
raise TypeError(f"Invalid argument type. Got {type(other)}, expected Coils.")

def __exclude_coil__(self, index):
return Coils(Curves(jnp.concatenate((self.curves[:index], self.curves[index+1:])), self.n_segments, 1, False), jnp.concatenate((self.currents[:index], self.currents[index+1:])))


def __contains__(self, other):
if isinstance(other, Coils):
return jnp.all(jnp.isin(other.dofs, self.dofs)) and jnp.all(jnp.isin(other.dofs_currents, self.dofs_currents))
Expand Down Expand Up @@ -494,8 +518,8 @@ def RotatedCurve(curve, phi, flip):
if flip:
rotmat = rotmat @ jnp.array(
[[1, 0, 0],
[0, -1, 0],
[0, 0, -1]])
[0, -1, 0],
[0, 0, -1]])
return curve @ rotmat

@partial(jit, static_argnames=['nfp', 'stellsym'])
Expand All @@ -512,6 +536,20 @@ def apply_symmetries_to_curves(base_curves, nfp, stellsym):
curves.append(rotcurve.T)
return jnp.array(curves)

@partial(jit, static_argnames=['nfp', 'stellsym'])
def apply_symmetries_to_gammas(base_gammas, nfp, stellsym):
flip_list = [False, True] if stellsym else [False]
gammas = []
for k in range(0, nfp):
for flip in flip_list:
for i in range(len(base_gammas)):
if k == 0 and not flip:
gammas.append(base_gammas[i])
else:
rotcurve = RotatedCurve(base_gammas[i], 2*jnp.pi*k/nfp, flip)
gammas.append(rotcurve)
return jnp.array(gammas)

@partial(jit, static_argnames=['nfp', 'stellsym'])
def apply_symmetries_to_currents(base_currents, nfp, stellsym):
flip_list = [False, True] if stellsym else [False]
Expand Down
3 changes: 2 additions & 1 deletion essos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
BOLTZMANN=1.380649e-23
HBAR=1.0545718176461565e-34
ELECTRON_MASS=9.1093837139e-31
SPEED_OF_LIGHT=2.99792458e8
SPEED_OF_LIGHT=2.99792458e8
mu_0= 1.2566370614359173e-06 #N A^-2
Loading