diff --git a/essos/augmented_lagrangian.py b/essos/augmented_lagrangian.py new file mode 100644 index 0000000..222c9c5 --- /dev/null +++ b/essos/augmented_lagrangian.py @@ -0,0 +1,721 @@ + +"""ALM (Augmented Lagrangian Method) using JAX and optimizers from OPTAX/JAXOPT/OPTIMISTIX inspired by mdmm_jax github repository""" + +from typing import Any, Callable, NamedTuple +import os +import jax +from jax import jit +import jax.numpy as jnp +from functools import partial +import optax +import jaxopt +import optimistix + +class LagrangeMultiplier(NamedTuple): + """A class containing constrain parameters for Augmented Lagrangian Method""" + value: Any + penalty: Any + sq_grad: Any #For updating squared gradient in case of adaptative penalty and multiplier evolution + + + + +#This is used for the usual augmented lagrangian form +def update_method(params,updates,eta,omega,model_mu='Constant',beta=2.0,mu_max=1.e4,alpha=0.99,gamma=1.e-2,epsilon=1.e-8,eta_tol=1.e-4,omega_tol=1.e-6): + """Different methods for updating multipliers and penalties + """ + + + pred = lambda x: isinstance(x, LagrangeMultiplier) + if model_mu=='Constant': + #jax.debug.print('{m}', m=model_mu) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(y.value,0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) + elif model_mu=='Mu_Monotonic': + #jax.debug.print('{m}', m=model_mu) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*y.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) + elif model_mu=='Mu_Conditional_True': + #jax.debug.print('True {m}', m=model_mu) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*y.value,0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) + elif model_mu=='Mu_Conditional_False': + #jax.debug.print('False {m}', m=model_mu) + return jax.jax.tree_util.tree_map(lambda x,y: 
LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) + elif model_mu=='Mu_Tolerance_True': + #jax.debug.print('Standard True {m}', m=model_mu) + mu_average=penalty_average(params) + #eta=eta/mu_average**(0.1) + #omega=omega/mu_average + eta=jnp.maximum(eta/mu_average**(0.1),eta_tol) + omega=jnp.maximum(omega/mu_average,omega_tol) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*y.value,0.0*x.value,0.0*x.value),params,updates,is_leaf=pred),eta,omega + elif model_mu=='Mu_Tolerance_False': + #jax.debug.print('Standard False {m}', m=model_mu) + mu_average=penalty_average(params) + #eta=1./mu_average**(0.1) + #omega=1./mu_average + eta=jnp.maximum(1./mu_average**(0.1),eta_tol) + #jax.debug.print('HMMMMMM mu_av {m}', m=mu_average) + #jax.debug.print('HMMMMMM eta {m}', m=eta) + omega=jnp.maximum(1./mu_average,omega_tol) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega + #return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega + elif model_mu=='Mu_Adaptative': + #jax.debug.print('True {m}', m=model_mu) + #Note that y.penalty is the derivative with respect to mu and so it is 0.5*C(x)**2, like the derivative with respect to lambda is C(x) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon)*y.value,-x.penalty+gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon),-x.sq_grad+alpha*x.sq_grad+(1.-alpha)*y.penalty*2.),params,updates,is_leaf=pred) + + + +#This is used for the squared form of the augmented Lagrangioan +def 
update_method_squared(params,updates,eta,omega,model_mu='Constant',beta=2.0,mu_max=1.e4,alpha=0.99,gamma=1.e-2,epsilon=1.e-8,eta_tol=1.e-4,omega_tol=1.e-6): + """Different methods for updating multipliers and penalties) + """ + + + pred = lambda x: isinstance(x, LagrangeMultiplier) + if model_mu=='Constant': + #jax.debug.print('{m}', m=model_mu) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier((y.value-x.value/x.penalty),0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) + elif model_mu=='Mu_Monotonic': + #jax.debug.print('{m}', m=model_mu) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*(y.value-x.value/x.penalty),-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) + elif model_mu=='Mu_Conditional_True': + #jax.debug.print('True {m}', m=model_mu) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*(y.value-x.value/x.penalty),0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) + elif model_mu=='Mu_Conditional_False': + #jax.debug.print('False {m}', m=model_mu) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) + elif model_mu=='Mu_Tolerance_True': + #jax.debug.print('Squared True {m}', m=model_mu) + mu_average=penalty_average(params) + #eta=eta/mu_average**(0.1) + #omega=omega/mu_average + eta=jnp.maximum(eta/mu_average**(0.1),eta_tol) + omega=jnp.maximum(omega/mu_average,omega_tol) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*(y.value-x.value/x.penalty),0.0*x.value,0.0*x.value),params,updates,is_leaf=pred),eta,omega + elif model_mu=='Mu_Tolerance_False': + #jax.debug.print('Squared False {m}', m=model_mu) + mu_average=penalty_average(params) + #eta=1./mu_average**(0.1) + #omega=1./mu_average + eta=jnp.maximum(1./mu_average**(0.1),eta_tol) + omega=jnp.maximum(1./mu_average,omega_tol) + return 
jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega + #return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega + elif model_mu=='Mu_Adaptative': + #jax.debug.print('True {m}', m=model_mu) + #Note that y.penalty is the derivative with respect to mu and so it is 0.5*C(x)**2, like the derivative with respect to lambda is C(x) + return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon)*(y.value-x.value/x.penalty),-x.penalty+gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon),-x.sq_grad+alpha*x.sq_grad+(1.-alpha)*(y.penalty*2.+(x.value/x.penalty)**2)),params,updates,is_leaf=pred) + + + + +def lagrange_update(model_lagrangian='Standard'): + """A gradient transformation for Optax that prepares an MDMM gradient + descent ascent update from a normal gradient descent update. + + It should be used like this with a base optimizer: + optimizer = optax.chain( + optax.sgd(1e-3), + mdmm_jax.optax_prepare_update(), + ) + + Returns: + An Optax gradient transformation that converts a gradient descent update + into a gradient descent ascent update. 
+ """ + def init_fn(params): + del params + return optax.EmptyState() + + def update_fn(lagrange_params,updates, state,eta,omega, params=None,model_mu='Constant',beta=2.,mu_max=1.e4,alpha=0.99,gamma=1.e-2,epsilon=1.e-8,eta_tol=1.e-4,omega_tol=1.e-6): + del params + if model_lagrangian=='Standard' : + return update_method(lagrange_params,updates,eta,omega,model_mu=model_mu,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol), state + elif model_lagrangian=='Squared' : + return update_method_squared(lagrange_params,updates,eta,omega,model_mu=model_mu,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol), state + else: + print('Lagrangian model not available please select Standard or Squared ') + os._exit(0) + + return optax.GradientTransformation(init_fn, update_fn) + + + + + +class Constraint(NamedTuple): + """A pair of pure functions implementing a constraint. + + Attributes: + init: A pure function which, when called with an example instance of + the arguments to the constraint functions, returns a pytree + containing the constraint's learnable parameters. + loss: A pure function which, when called with the the learnable + parameters returned by init() followed by the arguments to the + constraint functions, returns the loss value for the constraint. + """ + init: Callable + loss: Callable + + +def eq(fun,model_lagrangian='Standard', multiplier=0.0,penalty=1.,sq_grad=0., weight=1., reduction=jnp.sum): + """Represents an equality constraint, g(x) = 0. + + Args: + fun: The constraint function, a differentiable function of your + parameters which should output zero when satisfied and smoothly + increasingly far from zero values for increasing levels of + constraint violation. + damping: Sets the damping (oscillation reduction) strength. + weight: Weights the loss from the constraint relative to the primary + loss function's value. 
+ reduction: The function that is used to aggregate the constraints + if the constraint function outputs more than one element. + + Returns: + An (init_fn, loss_fn) constraint tuple for the equality constraint. + """ + + def init_fn(*args, **kwargs): + return {'lambda': LagrangeMultiplier(multiplier+jnp.zeros_like(fun(*args, **kwargs)),penalty+jnp.zeros_like(fun(*args, **kwargs)),sq_grad+jnp.zeros_like(fun(*args, **kwargs)))} + + if model_lagrangian=='Standard': + def loss_fn(params, *args, **kwargs): + inf = fun(*args, **kwargs) + return weight * reduction(-params['lambda'].value * inf + params['lambda'].penalty* inf ** 2 / 2), inf + elif model_lagrangian=='Squared': + def loss_fn(params, *args, **kwargs): + inf = fun(*args, **kwargs) + return weight * reduction(-params['lambda'].value * inf + params['lambda'].penalty* inf ** 2 / 2+ params['lambda'].value**2 /(2.*params['lambda'].penalty)), inf + + return Constraint(init_fn, loss_fn) + + +def ineq(fun, model_lagrangian='Standard', multiplier=0.,penalty=1., sq_grad=0.,weight=1., reduction=jnp.sum): + """Represents an inequality constraint, h(x) >= 0, which uses a slack + variable internally to convert it to an equality constraint. + + Args: + fun: The constraint function, a differentiable function of your + parameters which should output greater than or equal to zero when + satisfied and smoothly increasingly negative values for increasing + levels of constraint violation. + damping: Sets the damping (oscillation reduction) strength. + weight: Weights the loss from the constraint relative to the primary + loss function's value. + reduction: The function that is used to aggregate the constraints + if the constraint function outputs more than one element. + + Returns: + An (init_fn, loss_fn) constraint tuple for the inequality constraint. 
+ """ + + def init_fn(*args, **kwargs): + out = fun(*args, **kwargs) + return {'lambda': LagrangeMultiplier(multiplier+jnp.zeros_like(fun(*args, **kwargs)),penalty+jnp.zeros_like(fun(*args, **kwargs)),sq_grad+jnp.zeros_like(fun(*args, **kwargs))), + 'slack': jax.nn.relu(out) ** 0.5} + + if model_lagrangian=='Standard': + def loss_fn(params, *args, **kwargs): + inf = fun(*args, **kwargs) - params['slack'] ** 2 + return weight * reduction(-params['lambda'].value * inf + params['lambda'].penalty * inf ** 2 / 2), inf + elif model_lagrangian=='Squared': + def loss_fn(params, *args, **kwargs): + inf = fun(*args, **kwargs) - params['slack'] ** 2 + return weight * reduction(-params['lambda'].value * inf + params['lambda'].penalty * inf ** 2 / 2+ params['lambda'].value**2 /(2.*params['lambda'].penalty)), inf + + return Constraint(init_fn, loss_fn) + + +def combine(*args): + """Combines multiple constraint tuples into a single constraint tuple. + + Args: + *args: A series of constraint (init_fn, loss_fn) tuples. + + Returns: + A single (init_fn, loss_fn) tuple that wraps the input constraints. 
+ """ + init_fns, loss_fns = zip(*args) + + def init_fn(*args, **kwargs): + return tuple(fn(*args, **kwargs) for fn in init_fns) + + def loss_fn(params, *args, **kwargs): + outs = [fn(p, *args, **kwargs) for p, fn in zip(params, loss_fns)] + return sum(x[0] for x in outs), tuple(x[1] for x in outs) + + return Constraint(init_fn, loss_fn) + + + +####These are auxilair functions to do operations on the lagrange multiplier parameters and on auxiliar loss information +def total_infeasibility(tree): + return jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(jnp.abs(y)), tree, jnp.array(0.)) + +#def norm_constraints(tree): +# return jnp.sqrt(jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(y**2), tree, jnp.array(0.))) + +def norm_constraints(tree): + flat=jax.flatten_util.ravel_pytree(tree)[0] + return jnp.linalg.norm(flat) + +def infty_norm_constraints(tree): + flat=jax.flatten_util.ravel_pytree(tree)[0] + return jnp.max(flat) + +def penalty_average(tree): + pred = lambda x: isinstance(x, LagrangeMultiplier) + penalty=jax.tree_util.tree_map(lambda x: x.penalty,tree,is_leaf=pred) + penalty=jax.flatten_util.ravel_pytree(penalty) + return jnp.average(penalty[0]) + + + + + + + + +#Augmented lagrangian method classes +class ALM(NamedTuple): + init: Callable + update: Callable + + +#This can use optax gradient descent optimizers with different mu updating methods +def ALM_model_optax(optimizer: optax.GradientTransformation, #an optimizer from OPTAX + constraints: Constraint, #List of constraints + loss= lambda x: 0., #function which represents the loss (Callable, default 0.) 
+ model_lagrangian='Standard' , #Model to use for updating lagrange multipliers + model_mu='Constant' , #Model to use for updating lagrange multipliers + beta=2.0, + mu_max=1.e4, + alpha=0.99, + gamma=1.e-2, + epsilon=1.e-8, + eta_tol=1.e-4, + omega_tol=1.e-6, + **kargs, #Extra key arguments for loss +): + + + if model_mu=='Mu_Tolerance_LBFGS': + @jax.jit + def init_fn(params,**kargs): + main_params,lagrange_params=params + main_state = optimizer.init(main_params) + lag_state=lagrange_update(model_lagrangian=model_lagrangian).init(lagrange_params) + opt_state=main_state,lag_state + value,grad=jax.value_and_grad(lagrangian,has_aux=True,argnums=(0,1))(main_params,lagrange_params,**kargs) + return opt_state,grad,value[0],value[1] + else: + @jax.jit + def init_fn(params,**kargs): + main_params,lagrange_params=params + main_state = optimizer.init(main_params) + lag_state=lagrange_update(model_lagrangian=model_lagrangian).init(lagrange_params) + opt_state=main_state,lag_state + grad,info=jax.grad(lagrangian,has_aux=True,argnums=(0,1))(main_params,lagrange_params,**kargs) + return opt_state,grad,info + + # Define the Augmented lagrangian + if model_lagrangian=='Standard': + def lagrangian(main_params,lagrange_params,**kargs): + main_loss = jnp.linalg.norm(loss(main_params,**kargs)) #The norm here is to ensure we have a scalr from the loss which should be a vector + mdmm_loss, inf = constraints.loss(lagrange_params, main_params) + return main_loss+mdmm_loss, (main_loss,main_loss+mdmm_loss, inf) + + # Augmented Lagrangian + def lagrangian_lbfgs(main_params,lagrange_params,**kargs): + main_loss = jnp.linalg.norm(loss(main_params,**kargs)) + mdmm_loss, _ = constraints.loss(lagrange_params, main_params) + return main_loss+mdmm_loss + + elif model_lagrangian=='Squared': + def lagrangian(main_params,lagrange_params,**kargs): + main_loss = jnp.square(jnp.linalg.norm(loss(main_params,**kargs))) + #Here we take the square because the term appearing in this Lagrangian + mdmm_loss, 
inf = constraints.loss(lagrange_params, main_params) + return main_loss+mdmm_loss, (main_loss,main_loss+mdmm_loss, inf) + + # Augmented Lagrangian + def lagrangian_lbfgs(main_params,lagrange_params,**kargs): + #Here we take the square because the term appearing in this Lagrangian + main_loss = jnp.square(jnp.linalg.norm(loss(main_params,**kargs))) + mdmm_loss, _ = constraints.loss(lagrange_params, main_params) + return main_loss+mdmm_loss + + if model_mu=='Mu_Conditional': + # Do the optimization step + @partial(jit, static_argnums=(6,7,8,9,10,11,12,13)) + def update_fn(params, opt_state,grad,info,eta,omega,model_lagrangian=model_lagrangian,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol,**kargs): + main_state,lag_state=opt_state + main_params,lagrange_params=params + main_updates, main_state = optimizer.update(grad[0], main_state) + main_params = optax.apply_updates(main_params, main_updates) + params=main_params,lagrange_params + grad,info = jax.grad(lagrangian,has_aux=True,argnums=(0,1))(main_params,lagrange_params,**kargs) + true_func=partial(lagrange_update(model_lagrangian=model_lagrangian).update,model_mu='Mu_Conditional_True',beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + false_func=partial(lagrange_update(model_lagrangian=model_lagrangian).update,model_mu='Mu_Conditional_False',beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + lag_updates, lag_state = jax.lax.cond(norm_constraints(info[2]) omega + + def minimization_loop(state): + params,main_state,grad,info=state + main_params,lagrange_params=params + #jax.debug.print('Loop omega: {omega}', omega=omega) + #jax.debug.print('Loop grad: {grad}', grad=jnp.linalg.norm(grad[0])) + main_updates, main_state = optimizer.update(grad[0], main_state) + main_params = optax.apply_updates(main_params, main_updates) + params=main_params,lagrange_params + 
grad,info = jax.grad(lagrangian,has_aux=True,argnums=(0,1))(main_params,lagrange_params,**kargs) + state=params,main_state,grad,info + return state + + params,main_state,grad,info=jax.lax.while_loop(condition,minimization_loop,state) + main_params,lagrange_params=params + true_func=partial(lagrange_update(model_lagrangian=model_lagrangian).update,model_mu='Mu_Tolerance_True',beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + false_func=partial(lagrange_update(model_lagrangian=model_lagrangian).update,model='Mu_Tolerance_False',beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + lag_updates, lag_state = jax.lax.cond(norm_constraints(info[2]) omega + + def minimization_loop(state): + params,main_state,grad,value,info=state + main_params,lagrange_params=params + #jax.debug.print('Loop omega: {omega}', omega=omega) + #jax.debug.print('Loop grad: {grad}', grad=jnp.linalg.norm(grad[0])) + main_updates, main_state = optimizer.update(grad[0], main_state,params=main_params,value=value,grad=grad[0],value_fn=lagrangian_lbfgs,lagrange_params=lagrange_params) + main_params = optax.apply_updates(main_params, main_updates) + params=main_params,lagrange_params + value,grad = jax.value_and_grad(lagrangian,has_aux=True,argnums=(0,1))(main_params,lagrange_params,**kargs) + #Here info is in value[1] + state=params,main_state,grad,value[0],value[1] + return state + + params,main_state,grad,value,info=jax.lax.while_loop(condition,minimization_loop,state) + main_params,lagrange_params=params + true_func=partial(lagrange_update(model_lagrangian=model_lagrangian).update,model_mu='Mu_Tolerance_True',beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + 
false_func=partial(lagrange_update(model_lagrangian=model_lagrangian).update,model_mu='Mu_Tolerance_False',beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + lag_updates, lag_state = jax.lax.cond(norm_constraints(info[2])= len(self._sample): + raise ValueError("""The sample on has {len(self._sample)-1} derivatives. + Adjust the `n_derivs` parameter of the sampler to access higher derivatives.""") + return self._sample[deriv] + + + +def perturb_curves_systematic(curves: Curves,sampler:GaussianSampler, key=None): + """ + Apply a systematic perturbation to all the coils. + This means taht an independent perturbation is applied to the each unique coil + Then, the required symmetries are applied to the perturbed unique set of coils + + Args: + curves: curves to be perturbed. + sampler: the gaussian sampler used to get the perturbations + key: the seed which will be splited to geenerate random + but reproducible pertubations + + Returns: + The curves given as an input are modified and thus no return is done + """ + new_seeds=jax.random.split(key, num=curves.n_base_curves) + if sampler.n_derivs == 0: + perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds) + gamma_perturbations = apply_symmetries_to_gammas(perturbation[:,0,:,:], curves.nfp, curves.stellsym) + curves.gamma=curves.gamma + gamma_perturbations + elif sampler.n_derivs == 1: + perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds) + gamma_perturbations = apply_symmetries_to_gammas(perturbation[:,0,:,:], curves.nfp, curves.stellsym) + gamma_perturbations_dash = apply_symmetries_to_gammas(perturbation[:,1,:,:], curves.nfp, curves.stellsym) + curves.gamma=curves.gamma + gamma_perturbations + curves.gamma_dash=curves.gamma_dash + gamma_perturbations_dash + elif sampler.n_derivs == 2: + perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds) + gamma_perturbations = apply_symmetries_to_gammas(perturbation[:,0,:,:], 
curves.nfp, curves.stellsym) + gamma_perturbations_dash = apply_symmetries_to_gammas(perturbation[:,1,:,:], curves.nfp, curves.stellsym) + gamma_perturbations_dashdash = apply_symmetries_to_gammas(perturbation[:,2,:,:], curves.nfp, curves.stellsym) + curves.gamma=curves.gamma + gamma_perturbations + curves.gamma_dash=curves.gamma_dash + gamma_perturbations_dash + curves.gamma_dashdash=curves.gamma_dashdash + gamma_perturbations_dashdash + #return curves + + +def perturb_curves_statistic(curves: Curves,sampler:GaussianSampler, key=None): + """ + Apply a statistic perturbation to all the coils. + This means taht an independent perturbation is applied every coil + including repeated coils + + Args: + curves: curves to be perturbed. + sampler: the gaussian sampler used to get the perturbations + key: the seed which will be splited to geenerate random + but reproducible pertubations + + Returns: + The curves given as an input are modified and thus no return is done + """ + new_seeds=jax.random.split(key, num=curves.gamma.shape[0]) + if sampler.n_derivs == 0: + perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds) + curves.gamma=curves.gamma + perturbation[:,0,:,:] + elif sampler.n_derivs == 1: + perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds) + curves.gamma=curves.gamma + perturbation[:,0,:,:] + curves.gamma_dash=curves.gamma_dash + perturbation[:,1,:,:] + elif sampler.n_derivs == 2: + perturbation = jax.vmap(sampler.draw_sample, in_axes=(0))(new_seeds) + curves.gamma=curves.gamma + perturbation[:,0,:,:] + curves.gamma_dash=curves.gamma_dash + perturbation[:,1,:,:] + curves.gamma_dashdash=curves.gamma_dashdash + perturbation[:,2,:,:] + #return curves + diff --git a/essos/coils.py b/essos/coils.py index cc1e715..abe58e5 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -45,6 +45,7 @@ def __init__(self, dofs: jnp.ndarray, n_segments: int = 100, nfp: int = 1, stell self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, 
self.stellsym) self.quadpoints = jnp.linspace(0, 1, self.n_segments, endpoint=False) self._set_gamma() + self.n_base_curves=dofs.shape[0] def __str__(self): return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ @@ -152,13 +153,27 @@ def stellsym(self, new_stellsym): def gamma(self): return self._gamma + @gamma.setter + def gamma(self, new_gamma): + self._gamma = new_gamma + @property def gamma_dash(self): return self._gamma_dash + + @gamma_dash.setter + def gamma_dash(self, new_gamma_dash): + self._gamma_dash = new_gamma_dash + + @property def gamma_dashdash(self): return self._gamma_dashdash + + @gamma_dashdash.setter + def gamma_dashdash(self, new_gamma_dashdash): + self._gamma_dashdash = new_gamma_dashdash @property def length(self): @@ -238,7 +253,7 @@ def to_simsopt(self): coils = coils_via_symmetries(cuves_simsopt, currents_simsopt, self.nfp, self.stellsym) return [c.curve for c in coils] - def plot(self, ax=None, show=True, plot_derivative=False, close=False, axis_equal=True, **kwargs): + def plot(self, ax=None, show=True, plot_derivative=False, close=False, axis_equal=True,color="brown", linewidth=3,label=None,**kwargs): def rep(data): if close: return jnp.concatenate((data, [data[0]])) @@ -248,6 +263,7 @@ def rep(data): if ax is None or ax.name != "3d": fig = plt.figure() ax = fig.add_subplot(projection='3d') + label_count=0 for gamma, gammadash in zip(self.gamma, self.gamma_dash): x = rep(gamma[:, 0]) y = rep(gamma[:, 1]) @@ -256,9 +272,13 @@ def rep(data): xt = rep(gammadash[:, 0]) yt = rep(gammadash[:, 1]) zt = rep(gammadash[:, 2]) - ax.plot(x, y, z, **kwargs, color='brown', linewidth=3) + if label_count == 0: + ax.plot(x, y, z, **kwargs, color=color, linewidth=linewidth,label=label) + label_count += 1 + else: + ax.plot(x, y, z, **kwargs, color=color, linewidth=linewidth) if plot_derivative: - ax.quiver(x, y, z, 0.1 * xt, 0.1 * yt, 0.1 * zt, arrow_length_ratio=0.1, color="r") + ax.quiver(x, y, z, 0.1 * xt, 0.1 * yt, 0.1 * zt, 
arrow_length_ratio=0.1, color='r') if axis_equal: fix_matplotlib_3d(ax) if show: @@ -388,6 +408,10 @@ def __add__(self, other): else: raise TypeError(f"Invalid argument type. Got {type(other)}, expected Coils.") + def __exclude_coil__(self, index): + return Coils(Curves(jnp.concatenate((self.curves[:index], self.curves[index+1:])), self.n_segments, 1, False), jnp.concatenate((self.currents[:index], self.currents[index+1:]))) + + def __contains__(self, other): if isinstance(other, Coils): return jnp.all(jnp.isin(other.dofs, self.dofs)) and jnp.all(jnp.isin(other.dofs_currents, self.dofs_currents)) @@ -494,8 +518,8 @@ def RotatedCurve(curve, phi, flip): if flip: rotmat = rotmat @ jnp.array( [[1, 0, 0], - [0, -1, 0], - [0, 0, -1]]) + [0, -1, 0], + [0, 0, -1]]) return curve @ rotmat @partial(jit, static_argnames=['nfp', 'stellsym']) @@ -512,6 +536,20 @@ def apply_symmetries_to_curves(base_curves, nfp, stellsym): curves.append(rotcurve.T) return jnp.array(curves) +@partial(jit, static_argnames=['nfp', 'stellsym']) +def apply_symmetries_to_gammas(base_gammas, nfp, stellsym): + flip_list = [False, True] if stellsym else [False] + gammas = [] + for k in range(0, nfp): + for flip in flip_list: + for i in range(len(base_gammas)): + if k == 0 and not flip: + gammas.append(base_gammas[i]) + else: + rotcurve = RotatedCurve(base_gammas[i], 2*jnp.pi*k/nfp, flip) + gammas.append(rotcurve) + return jnp.array(gammas) + @partial(jit, static_argnames=['nfp', 'stellsym']) def apply_symmetries_to_currents(base_currents, nfp, stellsym): flip_list = [False, True] if stellsym else [False] diff --git a/essos/constants.py b/essos/constants.py index 8aa8b40..ef55118 100644 --- a/essos/constants.py +++ b/essos/constants.py @@ -9,4 +9,5 @@ BOLTZMANN=1.380649e-23 HBAR=1.0545718176461565e-34 ELECTRON_MASS=9.1093837139e-31 -SPEED_OF_LIGHT=2.99792458e8 \ No newline at end of file +SPEED_OF_LIGHT=2.99792458e8 +mu_0= 1.2566370614359173e-06 #N A^-2 \ No newline at end of file diff --git 
a/essos/dynamics.py b/essos/dynamics.py index 77900cf..11ad64b 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -517,7 +517,7 @@ def condition_Vmec(t, y, args, **kwargs): s, _, _, _ = y return s-1 self.condition = condition_Vmec - elif isinstance(field,BiotSavart) and isinstance(boundary,SurfaceClassifier): + elif (isinstance(field, Coils) or isinstance(self.field, BiotSavart)) and isinstance(boundary,SurfaceClassifier): if model == 'GuidingCenterCollisionsMuIto' or model == 'GuidingCenterCollisionsMuFixed' or model == 'GuidingCenterCollisionsMuAdaptative' or model=='GuidingCenterCollisions': def condition_BioSavart(t, y, args, **kwargs): xx, yy, zz, _,_ = y diff --git a/essos/fields.py b/essos/fields.py index 0789d2a..d9e28ee 100644 --- a/essos/fields.py +++ b/essos/fields.py @@ -1,5 +1,7 @@ import jax jax.config.update("jax_enable_x64", True) +from jax import vmap +from essos.coils import compute_curvature import jax.numpy as jnp from functools import partial from jax import jit, jacfwd, grad, vmap, tree_util, lax @@ -13,6 +15,9 @@ def __init__(self, coils): self.currents = coils.currents self.gamma = coils.gamma self.gamma_dash = coils.gamma_dash + #self.gamma_dashdash = coils.gamma_dashdash + self.coils_length=jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in self.gamma_dash]) + self.coils_curvature= vmap(compute_curvature)(self.gamma_dash, coils.gamma_dashdash) self.r_axis=jnp.mean(jnp.sqrt(vmap(lambda dofs: dofs[0, 0]**2 + dofs[1, 0]**2)(self.coils.dofs_curves))) self.z_axis=jnp.mean(vmap(lambda dofs: dofs[2, 0])(self.coils.dofs_curves)) @@ -74,6 +79,74 @@ def to_xyz(self, points): +class BiotSavart_from_gamma(): + def __init__(self, gamma,gamma_dash,gamma_dashdash, currents): + self.currents = currents + self.gamma = gamma + self.gamma_dash = gamma_dash + #self.gamma_dashdash = gamma_dashdash + self.coils_length=jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in gamma_dash]) + self.coils_curvature= 
vmap(compute_curvature)(gamma_dash, gamma_dashdash) + self.r_axis=jnp.average(jnp.linalg.norm(jnp.average(gamma,axis=1)[:,0:2],axis=1)) + self.z_axis=jnp.average(jnp.average(gamma,axis=1)[:,2]) + + @partial(jit, static_argnames=['self']) + def sqrtg(self, points): + return 1. + + @partial(jit, static_argnames=['self']) + def B(self, points): + dif_R = (jnp.array(points)-self.gamma).T + dB = jnp.cross(self.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0)/jnp.linalg.norm(dif_R, axis=0)**3 + dB_sum = jnp.einsum("i,bai", self.currents*1e-7, dB, optimize="greedy") + return jnp.mean(dB_sum, axis=0) + + @partial(jit, static_argnames=['self']) + def B_covariant(self, points): + return self.B(points) + + @partial(jit, static_argnames=['self']) + def B_contravariant(self, points): + return self.B(points) + + @partial(jit, static_argnames=['self']) + def AbsB(self, points): + return jnp.linalg.norm(self.B(points)) + + @partial(jit, static_argnames=['self']) + def dB_by_dX(self, points): + return jacfwd(self.B)(points) + + + @partial(jit, static_argnames=['self']) + def dAbsB_by_dX(self, points): + return grad(self.AbsB)(points) + + @partial(jit, static_argnames=['self']) + def grad_B_covariant(self, points): + return jacfwd(self.B_covariant)(points) + + @partial(jit, static_argnames=['self']) + def curl_B(self, points): + grad_B_cov=self.grad_B_covariant(points) + return jnp.array([grad_B_cov[2][1] -grad_B_cov[1][2], + grad_B_cov[0][2] -grad_B_cov[2][0], + grad_B_cov[1][0] -grad_B_cov[0][1]])/self.sqrtg(points) + + @partial(jit, static_argnames=['self']) + def curl_b(self, points): + return self.curl_B(points)/self.AbsB(points)+jnp.cross(self.B_covariant(points),jnp.array(self.dAbsB_by_dX(points)))/self.AbsB(points)**2/self.sqrtg(points) + + @partial(jit, static_argnames=['self']) + def kappa(self, points): + return -jnp.cross(self.B_contravariant(points),self.curl_b(points))*self.sqrtg(points)/self.AbsB(points) + + @partial(jit, static_argnames=['self']) + def to_xyz(self, 
points): + return points + + + class Vmec(): def __init__(self, wout_filename, ntheta=50, nphi=50, close=True, range_torus='full torus'): self.wout_filename = wout_filename diff --git a/essos/objective_functions.py b/essos/objective_functions.py index a74f7a4..a9d040f 100644 --- a/essos/objective_functions.py +++ b/essos/objective_functions.py @@ -4,23 +4,54 @@ from jax import jit, vmap from functools import partial from essos.dynamics import Tracing -from essos.fields import BiotSavart +from essos.fields import BiotSavart,BiotSavart_from_gamma from essos.surfaces import BdotN_over_B, BdotN -from essos.coils import Curves, Coils +from essos.coils import Curves, Coils,compute_curvature from essos.optimization import new_nearaxis_from_x_and_old_nearaxis -import optax +from essos.constants import mu_0 +from essos.coil_perturbation import perturb_curves_systematic, perturb_curves_statistic -def field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments=60, stellsym=True): + +def pertubred_field_from_dofs(x,key,sampler,dofs_curves,currents_scale,nfp,n_segments=60, stellsym=True): + coils = perturbed_coils_from_dofs(x,key,sampler,dofs_curves,currents_scale,nfp=nfp,n_segments=n_segments, stellsym=stellsym) + field = BiotSavart(coils) + return field + +def perturbed_coils_from_dofs(x,key,sampler,dofs_curves,currents_scale,nfp,n_segments=60, stellsym=True): len_dofs_curves_ravelled = len(jnp.ravel(dofs_curves)) dofs_curves = jnp.reshape(x[:len_dofs_curves_ravelled], dofs_curves.shape) dofs_currents = x[len_dofs_curves_ravelled:] - curves = Curves(dofs_curves, n_segments, nfp, stellsym) coils = Coils(curves=curves, currents=dofs_currents*currents_scale) + #Split once the key/seed given for one pertubred stellarator + split_keys = jax.random.split(jax.random.key(key), 2) + #Internally the following functions will then further split the two keys avoiding repeating keys + perturb_curves_systematic(coils, sampler, key=split_keys[0]) + perturb_curves_statistic(coils, sampler, 
key=split_keys[1]) + return coils + +def field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments=60, stellsym=True): + coils = coils_from_dofs(x,dofs_curves,currents_scale,nfp=nfp,n_segments=n_segments, stellsym=stellsym) field = BiotSavart(coils) return field +def coils_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments=60, stellsym=True): + len_dofs_curves_ravelled = len(jnp.ravel(dofs_curves)) + dofs_curves = jnp.reshape(x[:len_dofs_curves_ravelled], dofs_curves.shape) + dofs_currents = x[len_dofs_curves_ravelled:] + curves = Curves(dofs_curves, n_segments, nfp, stellsym) + coils = Coils(curves=curves, currents=dofs_currents*currents_scale) + return coils + +def curves_from_dofs(x,dofs_curves,nfp,n_segments=60, stellsym=True): + len_dofs_curves_ravelled = len(jnp.ravel(dofs_curves)) + dofs_curves = jnp.reshape(x[:len_dofs_curves_ravelled], dofs_curves.shape) + dofs_currents = x[len_dofs_curves_ravelled:] + + curves = Curves(dofs_curves, n_segments, nfp, stellsym) + return curves + @partial(jit, static_argnums=(1, 4, 5, 6, 7, 8)) @@ -156,7 +187,7 @@ def loss_particle_gamma_c(x,particles,dofs_curves, currents_scale, nfp,n_segment #return jnp.sum(jnp.square((2./jnp.pi*jnp.absolute(jnp.arctan2(jnp.average(v_r_cross,axis=1),jnp.average(v_theta,axis=1)))))) return jnp.max(2./jnp.pi*jnp.absolute(jnp.arctan2(jnp.average(v_r_cross,axis=1),jnp.average(v_theta,axis=1)))) -def loss_particle_r_cross_final_new(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): +def loss_particle_r_cross_final(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) particles.to_full_orbit(field) tracing = Tracing(field=field, model=model, particles=particles, 
maxtime=maxtime, @@ -165,24 +196,72 @@ def loss_particle_r_cross_final_new(x,particles,dofs_curves, currents_scale, nfp R_axis=tracing.field.r_axis Z_axis=tracing.field.z_axis r_cross=jnp.sqrt(jnp.square(jnp.sqrt(jnp.square(xyz[:,:,0])+jnp.square(xyz[:,:,1]))-R_axis+1.e-12)+jnp.square(xyz[:,:,2]-Z_axis+1.e-12)) - return jnp.linlag.norm((jnp.average(r_cross,axis=1))) + return jnp.linalg.norm((jnp.average(r_cross,axis=1))) -def loss_particle_r_cross_max(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): +def loss_particle_r_cross_max_constraint(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,target_r=0.4,maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) - particles.to_full_orbit(field) + #particles.to_full_orbit(field) tracing = Tracing(field=field, model=model, particles=particles, maxtime=maxtime, timestep=1.e-8,times_to_trace=num_steps, atol=trace_tolerance,rtol=trace_tolerance,boundary=boundary) xyz = tracing.trajectories[:,:, :3] R_axis=tracing.field.r_axis Z_axis=tracing.field.z_axis r_cross=jnp.sqrt(jnp.square(jnp.sqrt(jnp.square(xyz[:,:,0])+jnp.square(xyz[:,:,1]))-R_axis+1.e-12)+jnp.square(xyz[:,:,2]-Z_axis+1.e-12)) - return jnp.ravel(jnp.max(r_cross,axis=1)) + return jnp.maximum(r_cross-target_r,0.0) -def loss_lost_fraction(field, particles, maxtime=1e-5, num_steps=100, trace_tolerance=1e-5, model='GuidingCenterAdaptative',timestep=1.e-8,boundary=None): - particles.to_full_orbit(field) + +def loss_Br(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): + field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) + #particles.to_full_orbit(field) + 
tracing = Tracing(field=field, model=model, particles=particles, maxtime=maxtime, + timestep=1.e-8,times_to_trace=num_steps, atol=trace_tolerance,rtol=trace_tolerance,boundary=boundary) + xyz = tracing.trajectories[:,:, :3] + R_axis=tracing.field.r_axis + Z_axis=tracing.field.z_axis + fac_xy=jnp.sqrt(jnp.square(xyz[:,:,0])+jnp.square(xyz[:,:,1])) + r_cross=jnp.sqrt(jnp.square(fac_xy-R_axis+1.e-12)+jnp.square(xyz[:,:,2]-Z_axis+1.e-12)) + dr_cross_dx=(fac_xy-R_axis+1.e-12)*xyz[:,:,0]/(r_cross*fac_xy+1.e-12) + dr_cross_dy=(fac_xy-R_axis+1.e-12)*xyz[:,:,1]/(r_cross*fac_xy+1.e-12) + dr_cross_dz=(xyz[:,:,2]-Z_axis+1.e-12)/(r_cross+1.e-12) + B_particle=jax.vmap(jax.vmap(field.B_covariant,in_axes=0),in_axes=0)(xyz) + B_r=jnp.multiply(B_particle[:,:,0],dr_cross_dx)+jnp.multiply(B_particle[:,:,1],dr_cross_dy)+jnp.multiply(B_particle[:,:,2],dr_cross_dz) + return jnp.sum(jnp.abs(B_r)) + + +def loss_iota(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,target_iota=0.5,maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): + field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) + #particles.to_full_orbit(field) tracing = Tracing(field=field, model=model, particles=particles, maxtime=maxtime, - timestep=timestep,times_to_trace=num_steps, atol=trace_tolerance,rtol=trace_tolerance,boundary=boundary) - lost_fraction = tracing.loss_fraction + timestep=1.e-8,times_to_trace=num_steps, atol=trace_tolerance,rtol=trace_tolerance,boundary=boundary) + xyz = tracing.trajectories[:,:, :3] + R_axis=tracing.field.r_axis + Z_axis=tracing.field.z_axis + #theta=jnp.arctan2(xyz[:,:,2]-Z_axis+1.e-12, jnp.sqrt(xyz[:,:,0]**2+xyz[:,:,1]**2)-R_axis+1.e-12) + fac_xy=jnp.sqrt(jnp.square(xyz[:,:,0])+jnp.square(xyz[:,:,1])) + dtheta_dx=-(xyz[:,:,2]-Z_axis+1.e-12)*xyz[:,:,0]/(jnp.square(fac_xy-R_axis+1.e-12)+jnp.square(xyz[:,:,2]-Z_axis+1.e-12)+1.e-12) + 
dtheta_dy=-(xyz[:,:,2]-Z_axis+1.e-12)*xyz[:,:,1]/(jnp.square(fac_xy-R_axis+1.e-12)+jnp.square(xyz[:,:,2]-Z_axis+1.e-12)+1.e-12) + dtheta_dz=(fac_xy-R_axis+1.e-12)/(jnp.square(fac_xy-R_axis+1.e-12)+jnp.square(xyz[:,:,2]-Z_axis+1.e-12)+1.e-12) + dphi_dx=-(xyz[:,:,1])/(fac_xy**2+1.e-12) + dphi_dy=xyz[:,:,0]/(fac_xy**2+1.e-12) + B_particle=jax.vmap(jax.vmap(tracing.field.B_covariant,in_axes=0),in_axes=0)(xyz) + B_theta=jnp.multiply(B_particle[:,:,0],dtheta_dx)+jnp.multiply(B_particle[:,:,1],dtheta_dy)+jnp.multiply(B_particle[:,:,2],dtheta_dz) + B_phi=jnp.multiply(B_particle[:,:,0],dphi_dx)+jnp.multiply(B_particle[:,:,1],dphi_dy) + return jnp.sum(jnp.maximum(target_iota-B_theta/B_phi,0.0)) + +#final lost fraction +def loss_lost_fraction(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True, maxtime=1e-5, num_steps=300, trace_tolerance=1e-5,timestep=1.e-7, model='GuidingCenterAdaptative',boundary=None): + field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) + particles.to_full_orbit(field) + tracing = Tracing(field=field, model=model, particles=particles, maxtime=maxtime,timestep=timestep,times_to_trace=num_steps, atol=trace_tolerance,rtol=trace_tolerance,boundary=boundary) + lost_fraction = tracing.loss_fractions[-1] + return lost_fraction + +#lost fraction at every saved time snapshot (which is given by num_steps) +def loss_lost_fraction_times(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True, maxtime=1e-5, num_steps=300, trace_tolerance=1e-5,timestep=1.e-7, model='GuidingCenterAdaptative',boundary=None): + field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) + particles.to_full_orbit(field) + tracing = Tracing(field=field, model=model, particles=particles, maxtime=maxtime,timestep=timestep,times_to_trace=num_steps, atol=trace_tolerance,rtol=trace_tolerance,boundary=boundary) + lost_fraction = tracing.loss_fractions return lost_fraction # @partial(jit, static_argnums=(0, 1)) 
@@ -195,24 +274,24 @@ def normB_axis(field, npoints=15,target_B_on_axis=5.7): # @partial(jit, static_argnums=(0)) #def loss_coil_length(field,max_coil_length=31): -# coil_length=jnp.ravel(field.coils.length) +# coil_length=jnp.ravel(field.coils_length) # return jnp.array([jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])]))]) # @partial(jit, static_argnums=(0)) #def loss_coil_curvature(field,max_coil_curvature=0.4): -# coil_curvature=jnp.mean(field.coils.curvature, axis=1) +# coil_curvature=jnp.mean(field.coils_curvature, axis=1) # return jnp.array([jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])]))]) # @partial(jit, static_argnums=(0)) def loss_coil_length(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,max_coil_length=31): field=field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments,stellsym) - coil_length=jnp.ravel(field.coils.length) + coil_length=jnp.ravel(field.coils_length) return jnp.ravel(jnp.array([jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])]))])) # @partial(jit, static_argnums=(0)) def loss_coil_curvature(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,max_coil_curvature=0.4): field=field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments,stellsym) - coil_curvature=jnp.mean(field.coils.curvature, axis=1) + coil_curvature=jnp.mean(field.coils_curvature, axis=1) return jnp.ravel(jnp.array([jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])]))])) # @partial(jit, static_argnums=(0, 1)) @@ -229,7 +308,27 @@ def loss_normB_axis_average(x,dofs_curves,currents_scale,nfp,n_segments=60,stell R_axis=field.r_axis phi_array = jnp.linspace(0, 2 * jnp.pi, npoints) B_axis = vmap(lambda phi: field.AbsB(jnp.array([R_axis * jnp.cos(phi), R_axis * jnp.sin(phi), 0])))(phi_array) - return jnp.absolute(jnp.average(B_axis)-target_B_on_axis) + return jnp.array([jnp.absolute(jnp.average(B_axis)-target_B_on_axis)]) + + + +# @partial(jit, static_argnums=(0)) 
+def loss_coil_curvature_new(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,max_coil_curvature=0.4): + field=field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments,stellsym) + coil_curvature=jnp.mean(field.coils_curvature, axis=1) + return jnp.maximum(coil_curvature-max_coil_curvature,0.0) + +# @partial(jit, static_argnums=(0)) +def loss_coil_length_new(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,max_coil_length=31): + field=field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments,stellsym) + coil_length=jnp.ravel(field.coils_length) + return jnp.maximum(coil_length-max_coil_length,0.0) + + + + + + @partial(jit, static_argnums=(1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,14)) def loss_optimize_coils_for_particle_confinement(x, particles, dofs_curves, currents_scale, nfp, max_coil_curvature=0.5, @@ -271,4 +370,297 @@ def loss_BdotN(x, vmec, dofs_curves, currents_scale, nfp, max_coil_length=42, coil_length_loss = jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) coil_curvature_loss = jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) - return bdotn_over_b_loss+coil_length_loss+coil_curvature_loss \ No newline at end of file + return bdotn_over_b_loss+coil_length_loss+coil_curvature_loss + +@partial(jit, static_argnums=(1, 4, 5, 6)) +def loss_BdotN_only(x, vmec, dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True): + field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) + + bdotn_over_b = BdotN_over_B(vmec.surface, field) + + bdotn_over_b_loss = jnp.sum(jnp.abs(bdotn_over_b)) + + return bdotn_over_b_loss + +@partial(jit, static_argnums=(1, 4, 5, 6,7)) +def loss_BdotN_only_constraint(x, vmec, dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,target_tol=1.e-6): + field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) + + bdotn_over_b = BdotN_over_B(vmec.surface, field) + + bdotn_over_b_loss = 
jnp.sqrt(jnp.sum(jnp.maximum(jnp.square(bdotn_over_b)-target_tol,0.0))) + #bdotn_over_b_loss = jnp.sqrt(0.5*jnp.maximum(jnp.square(bdotn_over_b)-target_tol,0.0)) + return bdotn_over_b_loss + + +@partial(jit, static_argnums=(1,2,3, 6, 7, 8)) +def loss_BdotN_only_stochastic(x,sampler,N_samples, vmec, dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True): + keys= jnp.arange(N_samples) + def perturbed_bdotn_over_b(x,key,sampler,dofs_curves, currents_scale, nfp, n_segments, stellsym): + perturbed_field = pertubred_field_from_dofs(x,key,sampler, dofs_curves, currents_scale, nfp, n_segments, stellsym) + bdotn_over_b = BdotN_over_B(vmec.surface, perturbed_field) + return jnp.sum(jnp.abs(bdotn_over_b)) + #Average over the N_samples + expected_loss=jnp.average(jax.vmap(perturbed_bdotn_over_b, in_axes=(None,0,None,None,None,None,None,None))(x, keys,sampler, dofs_curves, currents_scale, nfp, n_segments, stellsym),axis=0) + return expected_loss + + +@partial(jit, static_argnums=(1,2,3, 6, 7, 8,9)) +def loss_BdotN_only_constraint_stochastic(x,sampler,N_samples, vmec, dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,target_tol=1.e-6): + keys= jnp.arange(N_samples) + def perturbed_bdotn_over_b(x,key,sampler,dofs_curves, currents_scale, nfp, n_segments, stellsym): + perturbed_field = pertubred_field_from_dofs(x,key,sampler, dofs_curves, currents_scale, nfp, n_segments, stellsym) + bdotn_over_b = BdotN_over_B(vmec.surface, perturbed_field) + return jnp.square(bdotn_over_b) + #Average over the N_samples + expected_loss=jnp.average(jax.vmap(perturbed_bdotn_over_b, in_axes=(None,0,None,None,None,None,None,None))(x, keys,sampler, dofs_curves, currents_scale, nfp, n_segments, stellsym),axis=0) + + constrained_expected_loss = jnp.sqrt(jnp.sum(jnp.maximum(expected_loss-target_tol,0.0))) + #bdotn_over_b_loss = jnp.sqrt(0.5*jnp.maximum(jnp.square(bdotn_over_b)-target_tol,0.0)) + return constrained_expected_loss + + + +#This is thr quickest way to get coil-surface 
# This is the quickest way to get coil-surface distance (but I guess not the most efficient way for large sizes).
# In that case we would do the candidates method from simsopt entirely
def loss_cs_distance(x,surface,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,min_distance_cs=1.3):
    """Total coil-surface distance penalty summed over every coil.

    Rebuilds the coils from the flat dof vector ``x`` and evaluates the
    simsopt-style curve-surface penalty of each coil against ``surface``.

    Returns:
        Scalar: sum of the per-coil penalties (zero when every coil stays
        at least ``min_distance_cs`` away from the surface).
    """
    coils = coils_from_dofs(x, dofs_curves, currents_scale, nfp, n_segments, stellsym)
    per_coil = jax.vmap(cs_distance_pure, in_axes=(0, 0, None, None, None))(
        coils.gamma, coils.gamma_dash, surface.gamma, surface.unitnormal, min_distance_cs)
    return jnp.sum(per_coil)

# Same as above but one entry per coil (useful when one wants to target each coil-surface pair individually)
def loss_cs_distance_array(x,surface,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,min_distance_cs=1.3):
    """Per-coil coil-surface distance penalties, flattened to a 1-D array."""
    coils = coils_from_dofs(x, dofs_curves, currents_scale, nfp, n_segments, stellsym)
    per_coil = jax.vmap(cs_distance_pure, in_axes=(0, 0, None, None, None))(
        coils.gamma, coils.gamma_dash, surface.gamma, surface.unitnormal, min_distance_cs)
    return per_coil.flatten()

#This is the quickest way to get coil-coil distance (but I guess not the most efficient way for large sizes).
# In that case we would do the candidates method from simsopt entirely
def loss_cc_distance(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,min_distance_cc=0.7,downsample=1):
    """Total coil-coil distance penalty summed over all distinct coil pairs."""
    coils = coils_from_dofs(x, dofs_curves, currents_scale, nfp, n_segments, stellsym)
    # Pairwise penalty for every ordered coil pair; the strict upper triangle
    # then counts each unordered pair exactly once and skips self-pairs.
    pair_penalties = jax.vmap(jax.vmap(cc_distance_pure, in_axes=(0, 0, None, None, None, None)),
                              in_axes=(None, None, 0, 0, None, None))(
        coils.gamma, coils.gamma_dash, coils.gamma, coils.gamma_dash, min_distance_cc, downsample)
    return jnp.sum(jnp.triu(pair_penalties, k=1))

# Same as above but one entry per coil pair (useful when one wants to target the pairs individually)
def loss_cc_distance_array(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,min_distance_cc=0.7,downsample=1):
    """Per-pair coil-coil distance penalties, flattened to a 1-D array.

    NOTE(review): the boolean mask below also drops pairs whose penalty is
    exactly zero, so the output length is data-dependent (not jit-friendly)
    — confirm callers only use this eagerly.
    """
    coils = coils_from_dofs(x, dofs_curves, currents_scale, nfp, n_segments, stellsym)
    pair_penalties = jnp.triu(jax.vmap(jax.vmap(cc_distance_pure, in_axes=(0, 0, None, None, None, None)),
                                       in_axes=(None, None, 0, 0, None, None))(
        coils.gamma, coils.gamma_dash, coils.gamma, coils.gamma_dash, min_distance_cc, downsample), k=1)
    return pair_penalties[pair_penalties != 0.0].flatten()


# One curve-to-curve distance (reused from Simsopt, no changes were necessary)
def cc_distance_pure(gamma1, l1, gamma2, l2, minimum_distance, downsample=1):
    """Curve-curve distance penalty between two curves.

    Quadratic penalty on every point pair that comes closer than
    ``minimum_distance``, weighted by the arclength elements of both curves
    and normalized by the number of point pairs.

    Args:
        gamma1: points along the first curve, shape (n1, 3).
        l1: tangent vectors along the first curve, shape (n1, 3).
        gamma2: points along the second curve, shape (n2, 3).
        l2: tangent vectors along the second curve, shape (n2, 3).
        minimum_distance (float): minimum allowed distance between the curves.
        downsample (int, default=1): keep only every ``downsample``-th
            quadrature point to speed up the evaluation; values that do not
            divide the point count yield a nonuniform subset and some loss of
            accuracy.

    Returns:
        float: the curve-curve distance penalty value.
    """
    pts1, pts2 = gamma1[::downsample, :], gamma2[::downsample, :]
    tan1, tan2 = l1[::downsample, :], l2[::downsample, :]
    pair_dists = jnp.sqrt(jnp.sum((pts1[:, None, :] - pts2[None, :, :]) ** 2, axis=2))
    # Product of arclength weights for every point pair.
    weight = jnp.linalg.norm(tan1, axis=1)[:, None] * jnp.linalg.norm(tan2, axis=1)[None, :]
    shortfall = jnp.maximum(minimum_distance - pair_dists, 0) ** 2
    return jnp.sum(weight * shortfall) / (pts1.shape[0] * pts2.shape[0])


# One coil-to-surface distance (reused from Simsopt, no changes were necessary)
def cs_distance_pure(gammac, lc, gammas, ns, minimum_distance):
    """Curve-surface distance penalty between a curve and a surface.

    Args:
        gammac: points along the curve, shape (nc, 3).
        lc: tangent vectors along the curve, shape (nc, 3).
        gammas: points on the surface, shape (ns, 3).
        ns: surface normal vectors, shape (ns, 3).
        minimum_distance (float): minimum allowed curve-surface distance.

    Returns:
        float: the curve-surface distance penalty value.
    """
    pair_dists = jnp.sqrt(jnp.sum((gammac[:, None, :] - gammas[None, :, :]) ** 2, axis=2))
    weight = jnp.linalg.norm(lc, axis=1)[:, None] * jnp.linalg.norm(ns, axis=1)[None, :]
    return jnp.mean(weight * jnp.maximum(minimum_distance - pair_dists, 0) ** 2)


#This is the quickest way to get coil-coil distance
+# In that case we would do the candidates method from simsopt entirely +def loss_linking_mnumber(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,downsample=1): + coils=coils_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) + #Since the quadpoints are the same for every curve then we can calculate the increment is constant for every curve + # (needs change if quadpoints are allowed to be different) + dphi=coils.quadpoints[1]-coils.quadpoints[0] + result=jnp.sum(jnp.triu(jax.vmap(jax.vmap(linking_number_pure,in_axes=(0,0,None,None,None)), + in_axes=(None,None,0,0,None))(coils.gamma[:,0:-1:downsample,:], + coils.gamma_dash[:,0:-1:downsample,:], + coils.gamma[:,0:-1:downsample,:], + coils.gamma_dash[:,0:-1:downsample,:], + dphi),k=1)) + return result + + +#This is thr quickest way to get coil-coil distance (but I guess not the most efficient way for large sizes). +# In that case we would do the candidates method from simsopt entirely +def loss_linking_mnumber_constarint(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,downsample=1): + coils=coils_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) + #Since the quadpoints are the same for every curve then we can calculate the increment is constant for every curve + # (needs change if quadpoints are allowed to be different) + dphi=coils.quadpoints[1]-coils.quadpoints[0] + result=jnp.triu(jax.vmap(jax.vmap(linking_number_pure,in_axes=(0,0,None,None,None)), + in_axes=(None,None,0,0,None))(coils.gamma[:,0:-1:downsample,:], + coils.gamma_dash[:,0:-1:downsample,:], + coils.gamma[:,0:-1:downsample,:], + coils.gamma_dash[:,0:-1:downsample,:], + dphi)+1.e-18,k=1) + #The 1.e-18 above is just to get all the correct values in the following mask + return result[result != 0.0].flatten() + +def linking_number_pure(gamma1, lc1, gamma2, lc2,dphi): + linking_number_ij=jnp.sum(jnp.abs(jax.vmap(integrand_linking_number, in_axes=(0, 0, 0, 0,None,None))(gamma1, lc1, gamma2, 
lc2,dphi,dphi)/ (4*jnp.pi))) + return linking_number_ij + +def integrand_linking_number(r1,dr1,r2,dr2,dphi1,dphi2): + """ + Compute the integrand for the linking number between two curves. + + Args: + r1 (array-like): Points along the first curve. + dr1 (array-like): Tangent vectors along the first curve. + r2 (array-like): Points along the second curve. + dr2 (array-like): Tangent vectors along the second curve. + dphi1 (array-like): increments of quadpoints 1 + dphi2 (array-like): increments of quadpoints 2 + + Returns: + float: The integrand value for the linking number. + """ + return jnp.dot((r1-r2), jnp.cross(dr1, dr2)) / jnp.linalg.norm(r1-r2)**3*dphi1*dphi2 + + + +#Loss function penalizing force on coils using Landremann-Hurwitz method +def loss_lorentz_force_coils(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,p=1,threshold=0.5e+6): + coils=coils_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments, stellsym) + curves_indeces=jnp.arange(coils.gamma.shape[0]) + #We want to calculate tangeng cross [B_self + B_mutual] for each coil + #B_self is the self-field of the coil, B_mutual is the field from the other coils + force_penalty=jax.vmap(lp_force_pure,in_axes=(0,None,None,None,None,None,None,None))(curves_indeces,coils.gamma, + coils.gamma_dash,coils.gamma_dashdash,coils.currents,coils.quadpoints,p, threshold) + return force_penalty + + + + + + +def lp_force_pure(index,gamma, gamma_dash,gamma_dashdash,currents,quadpoints,p, threshold): + """Pure function for minimizing the Lorentz force on a coil. 
+ """ + regularization = regularization_circ(1./jnp.average(compute_curvature( gamma_dash.at[index].get(), gamma_dashdash.at[index].get()))) + B_mutual=jax.vmap(BiotSavart_from_gamma(jnp.roll(gamma, -index, axis=0)[1:], + jnp.roll(gamma_dash, -index, axis=0)[1:], + jnp.roll(gamma_dashdash, -index, axis=0)[1:], + jnp.roll(currents, -index, axis=0)[1:]).B,in_axes=0)(gamma[index]) + B_self = B_regularized_pure(gamma.at[index].get(),gamma_dash.at[index].get(), gamma_dashdash.at[index].get(), quadpoints, currents[index], regularization) + gammadash_norm = jnp.linalg.norm(gamma_dash.at[index].get(), axis=1)[:, None] + tangent = gamma_dash.at[index].get() / gammadash_norm + force = jnp.cross(currents.at[index].get() * tangent, B_self + B_mutual) + force_norm = jnp.linalg.norm(force, axis=1)[:, None] + return (jnp.sum(jnp.maximum(force_norm - threshold, 0)**p * gammadash_norm))*(1./p) + + + +def B_regularized_singularity_term(rc_prime, rc_prime_prime, regularization): + """The term in the regularized Biot-Savart law in which the near-singularity + has been integrated analytically. + + regularization corresponds to delta * a * b for rectangular x-section, or to + a²/√e for circular x-section. + + A prefactor of μ₀ I / (4π) is not included. + + The derivatives rc_prime, rc_prime_prime refer to an angle that goes up to + 2π, not up to 1. + """ + norm_rc_prime = jnp.linalg.norm(rc_prime, axis=1) + return jnp.cross(rc_prime, rc_prime_prime) * (0.5 * (-2 + jnp.log(64 * norm_rc_prime * norm_rc_prime / regularization)) / (norm_rc_prime**3))[:, None] + + +def B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization): + # The factors of 2π in the next few lines come from the fact that simsopt + # uses a curve parameter that goes up to 1 rather than 2π. 
+ phi = quadpoints * 2 * jnp.pi + rc = gamma + rc_prime = gammadash / 2 / jnp.pi + rc_prime_prime = gammadashdash / 4 / jnp.pi**2 + n_quad = phi.shape[0] + dphi = 2 * jnp.pi / n_quad + analytic_term = B_regularized_singularity_term(rc_prime, rc_prime_prime, regularization) + dr = rc[:, None] - rc[None, :] + first_term = jnp.cross(rc_prime[None, :], dr) / ((jnp.sum(dr * dr, axis=2) + regularization) ** 1.5)[:, :, None] + cos_fac = 2 - 2 * jnp.cos(phi[None, :] - phi[:, None]) + denominator2 = cos_fac * jnp.sum(rc_prime * rc_prime, axis=1)[:, None] + regularization + factor2 = 0.5 * cos_fac / denominator2**1.5 + second_term = jnp.cross(rc_prime_prime, rc_prime)[:, None, :] * factor2[:, :, None] + integral_term = dphi * jnp.sum(first_term + second_term, 1) + return current * mu_0 / (4 * jnp.pi) * (analytic_term + integral_term) + + + +def regularization_circ(a): + """Regularization for a circular conductor""" + return a**2 / jnp.sqrt(jnp.e) + + +def regularization_rect(a, b): + """Regularization for a rectangular conductor""" + return a * b * rectangular_xsection_delta(a, b) + +def rectangular_xsection_k(a, b): + """Auxiliary function for field in rectangular conductor""" + return (4 * b) / (3 * a) * jnp.arctan(a/b) + (4*a)/(3*b)*jnp.arctan(b/a)+ (b**2)/(6*a**2)*jnp.log(b/a) + (a**2)/(6*b**2)*jnp.log(a/b) - (a**4 - 6*a**2*b**2 + b**4)/(6*a**2*b**2)*jnp.log(a/b+b/a) + + +def rectangular_xsection_delta(a, b): + """Auxiliary function for field in rectangular conductor""" + return jnp.exp(-25/6 + rectangular_xsection_k(a, b)) + + +#def loss_BdotN_only_with_perturbation(x, vmec, dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True, N_stells=10): +# """ +# Compute the loss function for BdotN with a perturbation applied to the BdotN value.): +# field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) +# +# bdotn_over_b = BdotN_over_B(vmec.surface, field) +# +# # Apply perturbation to the BdotN value +# bdotn_over_b += perturbation +# +# 
bdotn_over_b_loss = jnp.sum(jnp.abs(bdotn_over_b)) + +# return bdotn_over_b_loss \ No newline at end of file diff --git a/examples/create_perturbed_coils.py b/examples/create_perturbed_coils.py new file mode 100644 index 0000000..b5109ad --- /dev/null +++ b/examples/create_perturbed_coils.py @@ -0,0 +1,76 @@ + +import os +number_of_processors_to_use = 8 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax +print(jax.devices()) +jax.config.update("jax_enable_x64", True) +import jax.numpy as jnp +import matplotlib.pyplot as plt +from essos.coils import Coils, CreateEquallySpacedCurves,Curves +from functools import partial +from essos.coil_perturbation import GaussianSampler +from essos.coil_perturbation import perturb_curves_statistic,perturb_curves_systematic + + + + +# Coils parameters +order_Fourier_series_coils = 4 +number_coil_points = 80 +number_coils_per_half_field_period = 3 +number_of_field_periods = 2 + +# Initialize coils +current_on_each_coil = 1.84e7 +major_radius_coils = 7.75 +minor_radius_coils = 4.45 +curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) +coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) + + + +g=GaussianSampler(coils_initial.quadpoints,sigma=0.2,length_scale=0.1,n_derivs=2) + +#Split the key for reproducibility +key=0 +split_keys=jax.random.split(jax.random.key(key), num=2) +#Add systematic error +coils_sys = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) +perturb_curves_systematic(coils_sys, g, key=split_keys[0]) +# Add statistical error +coils_stat = Coils(curves=curves, 
currents=[current_on_each_coil]*number_coils_per_half_field_period) +perturb_curves_statistic(coils_stat, g, key=split_keys[1]) +# Add both systematic and statistical errors +coils_perturbed = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) +perturb_curves_systematic(coils_perturbed, g, key=split_keys[0]) +perturb_curves_statistic(coils_perturbed, g, key=split_keys[1]) + + +fig = plt.figure(figsize=(9, 8)) +ax1 = fig.add_subplot(111, projection='3d') +coils_initial.plot(ax=ax1, show=False,color='brown',linewidth=1,label='Initial coils') +coils_sys.plot(ax=ax1, show=False,color='blue',linewidth=1,label='Systematic perturbation') +coils_stat.plot(ax=ax1, show=False,color='green',linewidth=1,label='Statistical perturbation') +coils_perturbed.plot(ax=ax1, show=False,color='magenta',linewidth=1,label='Perturbed coils') +plt.legend() +plt.show() + + + +# # Save the coils to a json file +# coils_optimized.to_json("stellarator_coils.json") +# # Load the coils from a json file +# from essos.coils import Coils_from_json +# coils = Coils_from_json("stellarator_coils.json") + +# # Save results in vtk format to analyze in Paraview +# tracing_initial.to_vtk('trajectories_initial') +#tracing_optimized.to_vtk('trajectories_final') +#coils_initial.to_vtk('coils_initial') +#new_coils.to_vtk('coils_optimized') diff --git a/examples/input_files/input.toroidal_surface b/examples/input_files/input.toroidal_surface new file mode 100644 index 0000000..3a133b2 --- /dev/null +++ b/examples/input_files/input.toroidal_surface @@ -0,0 +1,14 @@ +!----- Runtime Parameters ----- +&INDATA + LASYM = F + NFP = 0001 + MPOL = 002 + NTOR = 002 +!----- Boundary Parameters (n,m) ----- + RBC( 000,000) = 7.75 ZBS( 000,000) = 0 + RBC( 001,000) = 0.000001 ZBS( 001,000) = -0.000001 + RBC(-001,001) = 0.000001 ZBS(-001,001) = 0.000001 + RBC( 000,001) = 2.5 ZBS( 000,001) = 2.5 + RBC( 001,001) = 0.000001 ZBS( 001,001) = 0.000001 + RBC(-002,002) = 1E-7 ZBS(-002,002) = 1E-7 
+/ diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_adam_constrained.py b/examples/optimize_coils_particle_confinement_guidingcenter_adam_constrained.py new file mode 100644 index 0000000..b878eb5 --- /dev/null +++ b/examples/optimize_coils_particle_confinement_guidingcenter_adam_constrained.py @@ -0,0 +1,156 @@ + +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax +print(jax.devices()) +jax.config.update("jax_enable_x64", True) +import jax.numpy as jnp +import matplotlib.pyplot as plt +from essos.dynamics import Particles, Tracing +from essos.coils import Coils, CreateEquallySpacedCurves,Curves +from essos.optimization import optimize_loss_function +from essos.objective_functions import loss_particle_r_cross_final_new,loss_particle_r_cross_max,loss_particle_radial_drift,loss_particle_gamma_c +from essos.objective_functions import loss_coil_curvature,loss_coil_length,loss_normB_axis,loss_normB_axis_average +from functools import partial +import essos.alm_convex as alm +import optax + + +# Optimization parameters +target_B_on_axis = 5.7 +max_coil_length = 31 +max_coil_curvature = 0.4 +nparticles = number_of_processors_to_use*10 +order_Fourier_series_coils = 4 +number_coil_points = 80 +maximum_function_evaluations = 30 +maxtimes = [1.e-5] +num_steps=100 +number_coils_per_half_field_period = 3 +number_of_field_periods = 2 +model = 'GuidingCenterAdaptative' + +# Initialize coils +current_on_each_coil = 1.84e7 +major_radius_coils = 7.75 +minor_radius_coils = 4.45 +curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) +coils_initial = Coils(curves=curves, 
currents=[current_on_each_coil]*number_coils_per_half_field_period) + +len_dofs_curves = len(jnp.ravel(coils_initial.dofs_curves)) +nfp = coils_initial.nfp +stellsym = coils_initial.stellsym +n_segments = coils_initial.n_segments +dofs_curves_shape = coils_initial.dofs_curves.shape +currents_scale = coils_initial.currents_scale + +# Initialize particles +phi_array = jnp.linspace(0, 2*jnp.pi, nparticles) +initial_xyz=jnp.array([major_radius_coils*jnp.cos(phi_array), major_radius_coils*jnp.sin(phi_array), 0*phi_array]).T +particles = Particles(initial_xyz=initial_xyz) + +t=maxtimes[0] +loss_partial = partial(loss_particle_gamma_c,particles=particles, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) +curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) +length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) +Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) +r_max_partial = partial(loss_particle_r_cross_max, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) + + +# Create the constraints +penalty = 1.05 #Intial penalty values +multiplier=1.0 #Initial lagrange multiplier values +sq_grad=0.0 #Initial square gradient parameter value for Mu adaptative +constraints = alm.combine( +alm.eq(curvature_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), 
+alm.eq(length_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(Baxis_average_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +#alm.eq(r_max_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +) + + + +model_lagrange='Mu_Tolerance' #Options: Mu_Constant, Mu_Monotonic, Mu_Conditional,Mu_Adaptative +beta=2. #penalty update parameter +mu_max=1.e4 #Maximum penalty parameter allowed +alpha=0.99 # +gamma=1.e-2 +epsilon=1.e-8 +omega_tol=1. #grad_tolerance, associated with grad of lagrangian to main parameters +eta_tol=1.e-6 #constraint tolerances, associated with variation of constraints +optimizer=optax.adabelief(learning_rate=0.003,nesterov=True) + + +ALM=alm.ALM_model(optimizer,constraints,model_lagrange=model_lagrange,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + +lagrange_params=constraints.init(coils_initial.x) +params = coils_initial.x, lagrange_params +opt_state,grad,info=ALM.init(params) +mu_average=alm.penalty_average(lagrange_params) +#omega=1.#1./mu_average +#eta=1000.#1./mu_average**0.1 +omega=1./mu_average +eta=1./mu_average**0.1 + +i=0 +while i<=maximum_function_evaluations and (jnp.linalg.norm(grad[0])>omega_tol or alm.norm_constraints(info[2])>eta_tol): + params, opt_state,grad,info,eta,omega = ALM.update(params,opt_state,grad,info,eta,omega) #One step of ALM optimization + #if i % 5 == 0: + #print(f'i: {i}, loss f: {info[0]:g}, infeasibility: {alm.total_infeasibility(info[1]):g}') + print(f'i: {i}, loss f: {info[0]:g},loss L: {info[1]:g}, infeasibility: {alm.total_infeasibility(info[2]):g}') + print('lagrange',params[1]) + i=i+1 + + +dofs_curves = jnp.reshape(params[0][:len_dofs_curves], (dofs_curves_shape)) +dofs_currents = params[0][len_dofs_curves:] +curves = Curves(dofs_curves, n_segments, nfp, stellsym) +new_coils = Coils(curves=curves, currents=dofs_currents*coils_initial.currents_scale) +params=new_coils.x +tracing_initial =
Tracing(field=coils_initial, particles=particles, maxtime=t, model=model + ,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) +tracing_optimized = Tracing(field=new_coils, particles=particles, maxtime=t, model=model,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) + +#print('Final params',params) +#print(info[1]) +# Plot trajectories, before and after optimization +fig = plt.figure(figsize=(9, 8)) +ax1 = fig.add_subplot(221, projection='3d') +ax2 = fig.add_subplot(222, projection='3d') +ax3 = fig.add_subplot(223) +ax4 = fig.add_subplot(224) + +coils_initial.plot(ax=ax1, show=False) +tracing_initial.plot(ax=ax1, show=False) +for i, trajectory in enumerate(tracing_initial.trajectories): + ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') + +ax3.set_xlabel('R (m)') +ax3.set_ylabel('Z (m)') +#ax3.legend() +new_coils.plot(ax=ax2, show=False) +tracing_optimized.plot(ax=ax2, show=False) +for i, trajectory in enumerate(tracing_optimized.trajectories): + ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') +ax4.set_xlabel('R (m)') +ax4.set_ylabel('Z (m)')#ax4.legend() +plt.tight_layout() +plt.savefig(f'opt_constrained.pdf') + +# # Save the coils to a json file +# coils_optimized.to_json("stellarator_coils.json") +# # Load the coils from a json file +# from essos.coils import Coils_from_json +# coils = Coils_from_json("stellarator_coils.json") + +# # Save results in vtk format to analyze in Paraview +# tracing_initial.to_vtk('trajectories_initial') +#tracing_optimized.to_vtk('trajectories_final') +#coils_initial.to_vtk('coils_initial') +#new_coils.to_vtk('coils_optimized') \ No newline at end of file diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py b/examples/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py new file mode 100644 index 0000000..d764f8b --- /dev/null +++ 
b/examples/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py @@ -0,0 +1,174 @@ + +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax +print(jax.devices()) +jax.config.update("jax_enable_x64", True) +import jax.numpy as jnp +import matplotlib.pyplot as plt +from essos.surfaces import SurfaceRZFourier +from essos.dynamics import Particles, Tracing +from essos.coils import Coils, CreateEquallySpacedCurves,Curves +from essos.objective_functions import loss_particle_r_cross_max_constraint,loss_particle_gamma_c +from essos.objective_functions import loss_coil_curvature,loss_coil_length,loss_normB_axis_average,loss_Br,loss_iota +from functools import partial +import essos.augmented_lagrangian as alm + + + + + +# Optimization parameters +target_B_on_axis = 5.7 +max_coil_length = 31 +max_coil_curvature = 0.4 +nparticles = number_of_processors_to_use*10 +order_Fourier_series_coils = 4 +number_coil_points = 80 +maximum_function_evaluations = 1 +maxtimes = [1.e-5] +num_steps=100 +number_coils_per_half_field_period = 3 +number_of_field_periods = 2 +model = 'GuidingCenter' + +# Initialize coils +current_on_each_coil = 1.84e7 +major_radius_coils = 7.75 +minor_radius_coils = 4.45 +curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) +coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) + + +len_dofs_curves = len(jnp.ravel(coils_initial.dofs_curves)) +nfp = coils_initial.nfp +stellsym = coils_initial.stellsym +n_segments = coils_initial.n_segments +dofs_curves_shape = coils_initial.dofs_curves.shape +currents_scale = coils_initial.currents_scale + + 
+# Initialize particles +phi_array = jnp.linspace(0, 2*jnp.pi, nparticles) +initial_xyz=jnp.array([major_radius_coils*jnp.cos(phi_array), major_radius_coils*jnp.sin(phi_array), 0*phi_array]).T +particles = Particles(initial_xyz=initial_xyz) + +t=maxtimes[0] +loss_partial = partial(loss_particle_gamma_c,particles=particles, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) +curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) +length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) +Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) +r_max_partial = partial(loss_particle_r_cross_max_constraint,target_r=0.4, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) +iota_partial = partial(loss_iota,target_iota=0.5, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) +Br_partial = partial(loss_Br, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) + + +# Create the constraints +penalty = 1. 
#Initial penalty values +multiplier=0.5 #Initial lagrange multiplier values +sq_grad=0.0 #Initial square gradient parameter value for Mu adaptative +model_lagrangian='Standard' #Use standard augmented Lagrangian suitable for bounded optimizers +#Since we are using LBFGS-B from jaxopt, model_mu will be updated with tolerances so we do not need to define the model + +#Construct constraints +constraints = alm.combine( +alm.eq(curvature_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(length_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(Baxis_average_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(r_max_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +#alm.eq(Br_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +#alm.eq(iota_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +) + + + +beta=2. #penalty update parameter +mu_max=1.e4 #Maximum penalty parameter allowed +alpha=0.99 #These are parameters only used if gradient descent and adaptative mu +gamma=1.e-2 +epsilon=1.e-8 +omega_tol=0.0001 #desired grad_tolerance, associated with grad of lagrangian to main parameters +eta_tol=0.001 #desired constraint tolerance, associated with variation of constraints + + + +#If loss=cost_function(x) is not prescribed, f(x)=0 is considered +ALM=alm.ALM_model_jaxopt_lbfgsb(constraints,loss=loss_partial,model_lagrangian=model_lagrangian,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + +#Initializing lagrange multipliers +lagrange_params=constraints.init(coils_initial.x) +#parameters are a tuple of the primal/main optimisation parameters and the lagrange multipliers +params = coils_initial.x, lagrange_params +#This is just to initialize an empty state for the lagrange multiplier update and get some information +lag_state,grad,info=ALM.init(params) + +#Initializing first tolerances for the inner minimisation loop
iteration +mu_average=alm.penalty_average(lagrange_params) +#omega=1.#1./mu_average +#eta=1000.#1./mu_average**0.1 +omega=1./mu_average +eta=1./mu_average**0.1 + +i=0 +while i<=maximum_function_evaluations and (jnp.linalg.norm(grad[0])>omega_tol or alm.norm_constraints(info[2])>eta_tol): + #One step of ALM optimization + params, lag_state,grad,info,eta,omega = ALM.update(params,lag_state,grad,info,eta,omega) + #if i % 5 == 0: + #print(f'i: {i}, loss f: {info[0]:g}, infeasibility: {alm.total_infeasibility(info[1]):g}') + print(f'i: {i}, loss f: {info[0]:g},loss L: {info[1]:g}, infeasibility: {alm.total_infeasibility(info[2]):g}') + #print('lagrange',params[1]) + i=i+1 + + +dofs_curves = jnp.reshape(params[0][:len_dofs_curves], (dofs_curves_shape)) +dofs_currents = params[0][len_dofs_curves:] +curves = Curves(dofs_curves, n_segments, nfp, stellsym) +new_coils = Coils(curves=curves, currents=dofs_currents*coils_initial.currents_scale) +params=new_coils.x +tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=t, model=model + ,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) +tracing_optimized = Tracing(field=new_coils, particles=particles, maxtime=t, model=model,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) + +#print('Final params',params) +#print(info[1]) +# Plot trajectories, before and after optimization +fig = plt.figure(figsize=(9, 8)) +ax1 = fig.add_subplot(221, projection='3d') +ax2 = fig.add_subplot(222, projection='3d') +ax3 = fig.add_subplot(223) +ax4 = fig.add_subplot(224) + +coils_initial.plot(ax=ax1, show=False) +tracing_initial.plot(ax=ax1, show=False) +for i, trajectory in enumerate(tracing_initial.trajectories): + ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') + +ax3.set_xlabel('R (m)') +ax3.set_ylabel('Z (m)') +#ax3.legend() +new_coils.plot(ax=ax2, show=False) +tracing_optimized.plot(ax=ax2, show=False) +for i, trajectory in 
enumerate(tracing_optimized.trajectories): + ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') +ax4.set_xlabel('R (m)') +ax4.set_ylabel('Z (m)')#ax4.legend() +plt.tight_layout() +plt.savefig(f'opt_constrained.pdf') + +# # Save the coils to a json file +# coils_optimized.to_json("stellarator_coils.json") +# # Load the coils from a json file +# from essos.coils import Coils_from_json +# coils = Coils_from_json("stellarator_coils.json") + +# # Save results in vtk format to analyze in Paraview +# tracing_initial.to_vtk('trajectories_initial') +#tracing_optimized.to_vtk('trajectories_final') +#coils_initial.to_vtk('coils_initial') +#new_coils.to_vtk('coils_optimized') diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_jaxopt_constrained.py b/examples/optimize_coils_particle_confinement_guidingcenter_jaxopt_constrained.py new file mode 100644 index 0000000..a9504d5 --- /dev/null +++ b/examples/optimize_coils_particle_confinement_guidingcenter_jaxopt_constrained.py @@ -0,0 +1,156 @@ + +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax +print(jax.devices()) +jax.config.update("jax_enable_x64", True) +import jax.numpy as jnp +import matplotlib.pyplot as plt +from essos.dynamics import Particles, Tracing +from essos.coils import Coils, CreateEquallySpacedCurves,Curves +from essos.optimization import optimize_loss_function +from essos.objective_functions import loss_particle_r_cross_final_new,loss_particle_r_cross_max,loss_particle_radial_drift,loss_particle_gamma_c +from essos.objective_functions import loss_coil_curvature,loss_coil_length,loss_normB_axis,loss_normB_axis_average +from functools import partial +import essos.alm_convex as alm +import optax + + +# Optimization parameters +target_B_on_axis = 5.7 +max_coil_length = 31 
+max_coil_curvature = 0.4 +nparticles = number_of_processors_to_use*10 +order_Fourier_series_coils = 4 +number_coil_points = 80 +maximum_function_evaluations = 10 +maxtimes = [2.e-5] +num_steps=100 +number_coils_per_half_field_period = 3 +number_of_field_periods = 2 +model = 'GuidingCenterAdaptative' + +# Initialize coils +current_on_each_coil = 1.84e7 +major_radius_coils = 7.75 +minor_radius_coils = 4.45 +curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) +coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) + +len_dofs_curves = len(jnp.ravel(coils_initial.dofs_curves)) +nfp = coils_initial.nfp +stellsym = coils_initial.stellsym +n_segments = coils_initial.n_segments +dofs_curves_shape = coils_initial.dofs_curves.shape +currents_scale = coils_initial.currents_scale + +# Initialize particles +phi_array = jnp.linspace(0, 2*jnp.pi, nparticles) +initial_xyz=jnp.array([major_radius_coils*jnp.cos(phi_array), major_radius_coils*jnp.sin(phi_array), 0*phi_array]).T +particles = Particles(initial_xyz=initial_xyz) + +t=maxtimes[0] +loss_partial = partial(loss_particle_gamma_c,particles=particles, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) +curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) +length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) 
+Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) +r_max_partial = partial(loss_particle_r_cross_max, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) + + +# Create the constraints +penalty = 1.05 #Intial penalty values +multiplier=0.0 #Initial lagrange multiplier values +sq_grad=0.0 #Initial square gradient parameter value for Mu adaptative +constraints = alm.combine( +alm.eq(curvature_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(length_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(Baxis_average_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(r_max_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +) + + + +model_lagrange='Mu_Tolerance' #Options: Mu_Constant, Mu_Monotonic, Mu_Conditional,Mu_Adaptative +beta=2. 
#penalty update parameter +mu_max=1.e4 #Maximum penalty parameter allowed +alpha=0.99 # +gamma=1.e-2 +epsilon=1.e-8 +omega_tol=1.e-5 #grad_tolerance, associated with grad of lagrangian to main parameters +eta_tol=1.e-6 #contrained tolerances, associated with variation of contraints +optimizer='SLSQP' + + +ALM=alm.ALM_model_jaxopt(constraints,optimizer=optimizer,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + +lagrange_params=constraints.init(coils_initial.x) +params = coils_initial.x, lagrange_params +lag_state,grad,info=ALM.init(params) +mu_average=alm.penalty_average(lagrange_params) +#omega=1.#1./mu_average +#eta=1000.#1./mu_average**0.1 +omega=1./mu_average +eta=1./mu_average**0.1 + +i=0 +while i<=maximum_function_evaluations and (jnp.linalg.norm(grad[0])>omega_tol or alm.norm_constraints(info[2])>eta_tol): + params, lag_state,grad,info,eta,omega = ALM.update(params,lag_state,grad,info,eta,omega) #One step of ALM optimization + #if i % 5 == 0: + #print(f'i: {i}, loss f: {info[0]:g}, infeasibility: {alm.total_infeasibility(info[1]):g}') + print(f'i: {i}, loss f: {info[0]:g},loss L: {info[1]:g}, infeasibility: {alm.total_infeasibility(info[2]):g}') + print('lagrange',params[1]) + i=i+1 + + +dofs_curves = jnp.reshape(params[0][:len_dofs_curves], (dofs_curves_shape)) +dofs_currents = params[0][len_dofs_curves:] +curves = Curves(dofs_curves, n_segments, nfp, stellsym) +new_coils = Coils(curves=curves, currents=dofs_currents*coils_initial.currents_scale) +params=new_coils.x +tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=t, model=model + ,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) +tracing_optimized = Tracing(field=new_coils, particles=particles, maxtime=t, model=model,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) + +#print('Final params',params) +#print(info[1]) +# Plot trajectories, before and after optimization +fig = plt.figure(figsize=(9, 8)) +ax1 = 
fig.add_subplot(221, projection='3d') +ax2 = fig.add_subplot(222, projection='3d') +ax3 = fig.add_subplot(223) +ax4 = fig.add_subplot(224) + +coils_initial.plot(ax=ax1, show=False) +tracing_initial.plot(ax=ax1, show=False) +for i, trajectory in enumerate(tracing_initial.trajectories): + ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') + +ax3.set_xlabel('R (m)') +ax3.set_ylabel('Z (m)') +#ax3.legend() +new_coils.plot(ax=ax2, show=False) +tracing_optimized.plot(ax=ax2, show=False) +for i, trajectory in enumerate(tracing_optimized.trajectories): + ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') +ax4.set_xlabel('R (m)') +ax4.set_ylabel('Z (m)')#ax4.legend() +plt.tight_layout() +plt.savefig(f'opt_constrained.pdf') + +# # Save the coils to a json file +# coils_optimized.to_json("stellarator_coils.json") +# # Load the coils from a json file +# from essos.coils import Coils_from_json +# coils = Coils_from_json("stellarator_coils.json") + +# # Save results in vtk format to analyze in Paraview +# tracing_initial.to_vtk('trajectories_initial') +#tracing_optimized.to_vtk('trajectories_final') +#coils_initial.to_vtk('coils_initial') +#new_coils.to_vtk('coils_optimized') \ No newline at end of file diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py b/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py index a63242d..26f3aa2 100644 --- a/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py +++ b/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py @@ -7,8 +7,9 @@ import matplotlib.pyplot as plt from essos.dynamics import Particles, Tracing from essos.coils import Coils, CreateEquallySpacedCurves,Curves -from essos.objective_functions import loss_particle_r_cross_max -from essos.objective_functions import loss_coil_curvature,loss_coil_length, loss_normB_axis_average +from essos.optimization import 
optimize_loss_function +from essos.objective_functions import loss_particle_r_cross_final_new,loss_particle_r_cross_max,loss_particle_radial_drift,loss_particle_gamma_c +from essos.objective_functions import loss_coil_curvature_new,loss_coil_length_new,loss_normB_axis,loss_normB_axis_average from functools import partial import optax @@ -17,7 +18,7 @@ target_B_on_axis = 5.7 max_coil_length = 31 max_coil_curvature = 0.4 -nparticles = number_of_processors_to_use*1 +nparticles = number_of_processors_to_use*10 order_Fourier_series_coils = 4 number_coil_points = 80 maximum_function_evaluations = 3 @@ -52,12 +53,12 @@ t=maxtimes[0] -curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) -length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) +curvature_partial=partial(loss_coil_curvature_new, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) +length_partial=partial(loss_coil_length_new, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) r_max_partial = partial(loss_particle_r_cross_max, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model = model,num_steps=num_steps) def total_loss(params): - return 
jnp.linalg.norm(curvature_partial(params)+length_partial(params)+Baxis_average_partial(params))**2 + return jnp.linalg.norm(jnp.concatenate((r_max_partial(params),length_partial(params),curvature_partial(params),Baxis_average_partial(params))))**2 params=coils_initial.x optimizer=optax.lbfgs() diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs_constrained.py b/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs_constrained.py new file mode 100644 index 0000000..3bfef5e --- /dev/null +++ b/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs_constrained.py @@ -0,0 +1,155 @@ + +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax +print(jax.devices()) +jax.config.update("jax_enable_x64", True) +import jax.numpy as jnp +import matplotlib.pyplot as plt +from essos.dynamics import Particles, Tracing +from essos.coils import Coils, CreateEquallySpacedCurves,Curves +from essos.optimization import optimize_loss_function +from essos.objective_functions import loss_particle_r_cross_final_new,loss_particle_r_cross_max,loss_particle_radial_drift,loss_particle_gamma_c +from essos.objective_functions import loss_coil_curvature,loss_coil_length,loss_normB_axis,loss_normB_axis_average +from functools import partial +import essos.alm_convex as alm +import optax + + +# Optimization parameters +target_B_on_axis = 5.7 +max_coil_length = 31 +max_coil_curvature = 0.4 +nparticles = number_of_processors_to_use*10 +order_Fourier_series_coils = 4 +number_coil_points = 80 +maximum_function_evaluations = 30 +maxtimes = [1.e-5] +num_steps=100 +number_coils_per_half_field_period = 3 +number_of_field_periods = 2 +model = 'GuidingCenterAdaptative' + +# Initialize coils +current_on_each_coil = 1.84e7 +major_radius_coils = 7.75 +minor_radius_coils = 4.45 +curves = 
CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) +coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) + +len_dofs_curves = len(jnp.ravel(coils_initial.dofs_curves)) +nfp = coils_initial.nfp +stellsym = coils_initial.stellsym +n_segments = coils_initial.n_segments +dofs_curves_shape = coils_initial.dofs_curves.shape +currents_scale = coils_initial.currents_scale + +# Initialize particles +phi_array = jnp.linspace(0, 2*jnp.pi, nparticles) +initial_xyz=jnp.array([major_radius_coils*jnp.cos(phi_array), major_radius_coils*jnp.sin(phi_array), 0*phi_array]).T +particles = Particles(initial_xyz=initial_xyz) + +t=maxtimes[0] +loss_partial = partial(loss_particle_gamma_c,particles=particles, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) +curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) +length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) +Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) +r_max_partial = partial(loss_particle_r_cross_max, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) + + +# Create the constraints +penalty = 1.05 
#Intial penalty values +multiplier=1.0 #Initial lagrange multiplier values +sq_grad=0.0 #Initial square gradient parameter value for Mu adaptative +constraints = alm.combine( +alm.eq(curvature_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(length_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(Baxis_average_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +#alm.eq(r_max_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +) + + + +model_lagrange='Mu_Tolerance_LBFGS' #Options: Mu_Constant, Mu_Monotonic, Mu_Conditional,Mu_Adaptative +beta=2. #penalty update parameter +mu_max=1.e4 #Maximum penalty parameter allowed +alpha=0.99 # +gamma=1.e-2 +epsilon=1.e-8 +omega_tol=0.01 #grad_tolerance, associated with grad of lagrangian to main parameters +eta_tol=1.e-6 #contrained tolerances, associated with variation of contraints +optimizer=optax.lbfgs(linesearch=optax.scale_by_zoom_linesearch(max_linesearch_steps=15)) + +ALM=alm.ALM_model(optimizer,constraints,model_lagrange=model_lagrange,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + +lagrange_params=constraints.init(coils_initial.x) +params = coils_initial.x, lagrange_params +opt_state,grad,value,info=ALM.init(params) +mu_average=alm.penalty_average(lagrange_params) +#omega=1.#1./mu_average +#eta=1000.#1./mu_average**0.1 +omega=1./mu_average +eta=1./mu_average**0.1 + +i=0 +while i<=maximum_function_evaluations and (jnp.linalg.norm(grad[0])>omega_tol or alm.norm_constraints(info[2])>eta_tol): + params, opt_state, grad,value,info,eta,omega = ALM.update(params,opt_state,grad,value,info,eta,omega) #One step of ALM optimization + #if i % 5 == 0: + #print(f'i: {i}, loss f: {info[0]:g}, infeasibility: {alm.total_infeasibility(info[1]):g}') + print(f'i: {i}, loss f: {info[0]:g},loss L: {info[1]:g}, infeasibility: {alm.total_infeasibility(info[2]):g}') + print('lagrange',params[1]) + i=i+1 + + 
+dofs_curves = jnp.reshape(params[0][:len_dofs_curves], (dofs_curves_shape)) +dofs_currents = params[0][len_dofs_curves:] +curves = Curves(dofs_curves, n_segments, nfp, stellsym) +new_coils = Coils(curves=curves, currents=dofs_currents*coils_initial.currents_scale) +params=new_coils.x +tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=t, model=model + ,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) +tracing_optimized = Tracing(field=new_coils, particles=particles, maxtime=t, model=model,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) + +#print('Final params',params) +#print(info[1]) +# Plot trajectories, before and after optimization +fig = plt.figure(figsize=(9, 8)) +ax1 = fig.add_subplot(221, projection='3d') +ax2 = fig.add_subplot(222, projection='3d') +ax3 = fig.add_subplot(223) +ax4 = fig.add_subplot(224) + +coils_initial.plot(ax=ax1, show=False) +tracing_initial.plot(ax=ax1, show=False) +for i, trajectory in enumerate(tracing_initial.trajectories): + ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') + +ax3.set_xlabel('R (m)') +ax3.set_ylabel('Z (m)') +#ax3.legend() +new_coils.plot(ax=ax2, show=False) +tracing_optimized.plot(ax=ax2, show=False) +for i, trajectory in enumerate(tracing_optimized.trajectories): + ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') +ax4.set_xlabel('R (m)') +ax4.set_ylabel('Z (m)')#ax4.legend() +plt.tight_layout() +plt.savefig(f'opt_constrained.pdf') + +# # Save the coils to a json file +# coils_optimized.to_json("stellarator_coils.json") +# # Load the coils from a json file +# from essos.coils import Coils_from_json +# coils = Coils_from_json("stellarator_coils.json") + +# # Save results in vtk format to analyze in Paraview +# tracing_initial.to_vtk('trajectories_initial') +#tracing_optimized.to_vtk('trajectories_final') +#coils_initial.to_vtk('coils_initial') 
+#new_coils.to_vtk('coils_optimized') \ No newline at end of file diff --git a/examples/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py b/examples/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py new file mode 100644 index 0000000..6d35c88 --- /dev/null +++ b/examples/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py @@ -0,0 +1,180 @@ + +import os +number_of_processors_to_use = 8 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax +print(jax.devices()) +jax.config.update("jax_enable_x64", True) +import jax.numpy as jnp +import matplotlib.pyplot as plt +from essos.surfaces import SurfaceRZFourier, SurfaceClassifier +from essos.dynamics import Particles, Tracing +from essos.coils import Coils, CreateEquallySpacedCurves,Curves +from essos.objective_functions import loss_lost_fraction,loss_lost_fraction_times +from essos.objective_functions import loss_coil_curvature,loss_coil_length,loss_normB_axis_average,loss_Br,loss_iota +from functools import partial +import essos.augmented_lagrangian as alm + + + + + +# Optimization parameters +target_B_on_axis = 5.7 +max_coil_length = 31 +max_coil_curvature = 0.4 +nparticles = number_of_processors_to_use*1 +order_Fourier_series_coils = 4 +number_coil_points = 80 +maximum_function_evaluations = 10 +maxtimes = [1.e-2] +timestep=1.e-8 +num_steps=100 +number_coils_per_half_field_period = 3 +number_of_field_periods = 2 +model = 'GuidingCenterAdaptative' + +# Initialize coils +current_on_each_coil = 1.84e7 +major_radius_coils = 7.75 +minor_radius_coils = 4.45 +curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) +coils_initial = 
Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) + + +len_dofs_curves = len(jnp.ravel(coils_initial.dofs_curves)) +nfp = coils_initial.nfp +stellsym = coils_initial.stellsym +n_segments = coils_initial.n_segments +dofs_curves_shape = coils_initial.dofs_curves.shape +currents_scale = coils_initial.currents_scale + +ntheta=30 +nphi=30 +input = os.path.join(os.path.dirname(__name__),'input_files','input.toroidal_surface') +surface= SurfaceRZFourier(input, ntheta=ntheta, nphi=nphi, range_torus='full torus') +timeI=time() +boundary=SurfaceClassifier(surface,h=0.1) +print(f"ESSOS boundary took {time()-timeI:.2f} seconds") +#print('Final params',params) +#print(info[1]) +# Plot trajectories, before and after optimization + + +# Initialize particles +phi_array = jnp.linspace(0, 2*jnp.pi, nparticles) +initial_xyz=jnp.array([major_radius_coils*jnp.cos(phi_array), major_radius_coils*jnp.sin(phi_array), 0*phi_array]).T +particles = Particles(initial_xyz=initial_xyz) + +t=maxtimes[0] +loss_partial = partial(loss_lost_fraction,particles=particles, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,timestep=timestep,model=model,num_steps=num_steps,boundary=boundary) +jax.grad(loss_partial)(coils_initial.x) + + +curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) +length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) +Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) + +# Create the constraints +penalty = 1. 
#Initial penalty values +multiplier=0.5 #Initial lagrange multiplier values +sq_grad=0.0 #Initial square gradient parameter value for Mu adaptative +model_lagrangian='Standard' #Use standard augmented lagrangian suitable for bounded optimizers +#Since we are using LBFGS-B from jaxopt, model_mu will be updated with tolerances so we do not need to define the model + +#Construct constraints +constraints = alm.combine( +alm.eq(curvature_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(length_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(Baxis_average_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(loss_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +) + + + +beta=2. #penalty update parameter +mu_max=1.e4 #Maximum penalty parameter allowed +alpha=0.99 #These are parameters only used if gradient descent and adaptative mu +gamma=1.e-2 +epsilon=1.e-8 +omega_tol=0.0001 #desired grad_tolerance, associated with grad of lagrangian to main parameters +eta_tol=0.001 #desired constraint tolerance, associated with variation of constraints + + + +#If loss=cost_function(x) is not prescribed, f(x)=0 is considered +ALM=alm.ALM_model_jaxopt_lbfgsb(constraints,model_lagrangian=model_lagrangian,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + +#Initializing lagrange multipliers +lagrange_params=constraints.init(coils_initial.x) +#parameters are a tuple of the primal/main optimisation parameters and the lagrange multipliers +params = coils_initial.x, lagrange_params +#This is just to initialize an empty state for the lagrange multiplier update and get some information +lag_state,grad,info=ALM.init(params) + +#Initializing first tolerances for the inner minimisation loop iteration +mu_average=alm.penalty_average(lagrange_params) +#omega=1.#1./mu_average +#eta=1000.#1./mu_average**0.1 +omega=1./mu_average +eta=1./mu_average**0.1 + +i=0 +while 
i<=maximum_function_evaluations and (jnp.linalg.norm(grad[0])>omega_tol or alm.norm_constraints(info[2])>eta_tol): + #One step of ALM optimization + params, lag_state,grad,info,eta,omega = ALM.update(params,lag_state,grad,info,eta,omega) + print(f'i: {i}, loss f: {info[0]:g},loss L: {info[1]:g}, infeasibility: {alm.total_infeasibility(info[2]):g}') + #print('lagrange',params[1]) + i=i+1 + + +dofs_curves = jnp.reshape(params[0][:len_dofs_curves], (dofs_curves_shape)) +dofs_currents = params[0][len_dofs_curves:] +curves = Curves(dofs_curves, n_segments, nfp, stellsym) +new_coils = Coils(curves=curves, currents=dofs_currents*coils_initial.currents_scale) +params=new_coils.x +tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=t, model=model,times_to_trace=num_steps,timestep=timestep,boundary=boundary) +tracing_optimized = Tracing(field=new_coils, particles=particles, maxtime=t, model=model,times_to_trace=num_steps,timestep=timestep,boundary=boundary) + +#print('Final params',params) +#print(info[1]) +# Plot trajectories, before and after optimization +fig = plt.figure(figsize=(9, 8)) +ax1 = fig.add_subplot(221, projection='3d') +ax2 = fig.add_subplot(222, projection='3d') +ax3 = fig.add_subplot(223) +ax4 = fig.add_subplot(224) + +coils_initial.plot(ax=ax1, show=False) +tracing_initial.plot(ax=ax1, show=False) +for i, trajectory in enumerate(tracing_initial.trajectories): + ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') + +ax3.set_xlabel('R (m)') +ax3.set_ylabel('Z (m)') +#ax3.legend() +new_coils.plot(ax=ax2, show=False) +tracing_optimized.plot(ax=ax2, show=False) +for i, trajectory in enumerate(tracing_optimized.trajectories): + ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') +ax4.set_xlabel('R (m)') +ax4.set_ylabel('Z (m)')#ax4.legend() +plt.tight_layout() +plt.show() + +# # Save the coils to a json file +# 
coils_optimized.to_json("stellarator_coils.json") +# # Load the coils from a json file +# from essos.coils import Coils_from_json +# coils = Coils_from_json("stellarator_coils.json") + +# # Save results in vtk format to analyze in Paraview +# tracing_initial.to_vtk('trajectories_initial') +#tracing_optimized.to_vtk('trajectories_final') +#coils_initial.to_vtk('coils_initial') +#new_coils.to_vtk('coils_optimized') diff --git a/examples/optimize_coils_vmec_surface.py b/examples/optimize_coils_vmec_surface.py index 2ded4be..57324b2 100644 --- a/examples/optimize_coils_vmec_surface.py +++ b/examples/optimize_coils_vmec_surface.py @@ -46,8 +46,13 @@ max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,) print(f"Optimization took {time()-time0:.2f} seconds") + BdotN_over_B_initial = BdotN_over_B(vmec.surface, BiotSavart(coils_initial)) BdotN_over_B_optimized = BdotN_over_B(vmec.surface, BiotSavart(coils_optimized)) +curvature=jnp.mean(BiotSavart(coils_optimized).coils.curvature, axis=1) +length=jnp.max(jnp.ravel(BiotSavart(coils_optimized).coils.length)) +print(f"Mean curvature: ",curvature) +print(f"Length:", length) print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}") diff --git a/examples/optimize_coils_vmec_surface_augmented_lagrangian.py b/examples/optimize_coils_vmec_surface_augmented_lagrangian.py new file mode 100644 index 0000000..23f08b3 --- /dev/null +++ b/examples/optimize_coils_vmec_surface_augmented_lagrangian.py @@ -0,0 +1,189 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide ntheta*nphi +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax.numpy as jnp +import matplotlib.pyplot as plt +from essos.surfaces import BdotN_over_B +from essos.coils import Coils, CreateEquallySpacedCurves,Curves +from essos.fields 
import Vmec, BiotSavart +from essos.objective_functions import loss_BdotN_only_constraint,loss_coil_curvature_new,loss_coil_length_new,loss_BdotN_only +from essos.objective_functions import loss_coil_curvature,loss_coil_length +from essos.objective_functions import loss_BdotN +from essos.optimization import optimize_loss_function + +import essos.augmented_lagrangian as alm +from functools import partial + +# Optimization parameters +maximum_function_evaluations=10 +max_coil_length = 40 +max_coil_curvature = 0.5 +bdotn_tol=1.e-6 +order_Fourier_series_coils = 6 +number_coil_points = order_Fourier_series_coils*10 +number_coils_per_half_field_period = 4 +ntheta=32 +nphi=32 +#Tolerance for no normal (no ALM) optimization +tolerance_optimization = 1e-5 + +# Initialize VMEC field +vmec = Vmec(os.path.join(os.path.dirname(__name__), 'input_files', + 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc'), + ntheta=ntheta, nphi=nphi, range_torus='half period') + +# Initialize coils +current_on_each_coil = 1 +number_of_field_periods = vmec.nfp +major_radius_coils = vmec.r_axis +minor_radius_coils = vmec.r_axis/1.5 +curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) +coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) + +len_dofs_curves = len(jnp.ravel(coils_initial.dofs_curves)) +nfp = coils_initial.nfp +stellsym = coils_initial.stellsym +n_segments = coils_initial.n_segments +dofs_curves = coils_initial.dofs_curves +currents_scale = coils_initial.currents_scale +dofs_curves_shape = coils_initial.dofs_curves.shape + + + + +# Create the constraints +penalty = 0.1 #Intial penalty values +multiplier=0.5 #Initial lagrange multiplier values +sq_grad=0.0 #Initial square gradient parameter value for Mu adaptative +model_lagrangian='Standard' #Use 
standard augmented lagrangian suitable for bounded optimizers +#Since we are using LBFGS-B from jaxopt, model_mu will be updated with tolerances so we do not need to define the model + + +curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) +length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) +bdotn_partial=partial(loss_BdotN_only_constraint, vmec=vmec, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym,target_tol=bdotn_tol) +bdotn_only_partial=partial(loss_BdotN_only, vmec=vmec, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym) + +#Construct constraints +constraints = alm.combine( +alm.eq(curvature_partial,model_lagrangian=model_lagrangian, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(length_partial,model_lagrangian=model_lagrangian, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(bdotn_partial,model_lagrangian=model_lagrangian, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad) +) + + + +beta=2. 
#penalty update parameter +mu_max=1.e4 #Maximum penalty parameter allowed +alpha=0.99 #These are parameters only used if gradient descent and adaaptative mu +gamma=1.e-2 +epsilon=1.e-8 +omega_tol=1.e-7 #desired grad_tolerance, associated with grad of lagrangian to main parameters +eta_tol=1.e-7 #desired contraint tolerance, associated with variation of contraints + + + +#If loss=cost_function(x) is not prescribed, f(x)=0 is considered +ALM=alm.ALM_model_jaxopt_lbfgsb(constraints,model_lagrangian=model_lagrangian,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + +#Initializing lagrange multipliers +lagrange_params=constraints.init(coils_initial.x) +#parameters are a tuple of the primal/main optimisation parameters and the lagrange multipliers +params = coils_initial.x, lagrange_params +#This is just to initialize an empty state for the lagrange multiplier update and get some information +lag_state,grad,info=ALM.init(params) + +#Initializing first tolerances for the inner minimisation loop iteration +mu_average=alm.penalty_average(lagrange_params) +omega=1./mu_average +eta=1./mu_average**0.1 + + + + + +# Optimize coils +print(f'Optimizing coils with {maximum_function_evaluations} function evaluations no ALM.') +time0 = time() +coils_optimized = optimize_loss_function(loss_BdotN, initial_dofs=coils_initial.x, coils=coils_initial, tolerance_optimization=tolerance_optimization, + maximum_function_evaluations=maximum_function_evaluations, vmec=vmec, + max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,) +print(f"Optimization took {time()-time0:.2f} seconds") + + + + + +# Optimize coils +print(f'Optimizing coils with {maximum_function_evaluations} function evaluations using ALM.') +time0 = time() + + +i=0 +while i<=maximum_function_evaluations and (jnp.linalg.norm(grad[0])>omega_tol or alm.norm_constraints(info[2])>eta_tol): + #One step of ALM optimization + params, lag_state,grad,info,eta,omega = 
ALM.update(params,lag_state,grad,info,eta,omega) + #if i % 5 == 0: + #print(f'i: {i}, loss f: {info[0]:g}, infeasibility: {alm.total_infeasibility(info[1]):g}') + print(f'i: {i}, loss f: {info[0]:g},loss L: {info[1]:g}, infeasibility: {alm.total_infeasibility(info[2]):g}') + #print('lagrange',params[1]) + i=i+1 + + + +dofs_curves = jnp.reshape(params[0][:len_dofs_curves], (dofs_curves_shape)) +dofs_currents = params[0][len_dofs_curves:] +curves = Curves(dofs_curves, n_segments, nfp, stellsym) +coils_optimized_alm = Coils(curves=curves, currents=dofs_currents*coils_initial.currents_scale) + +print(f"Optimization took {time()-time0:.2f} seconds") + + +BdotN_over_B_initial = BdotN_over_B(vmec.surface, BiotSavart(coils_initial)) +BdotN_over_B_optimized = BdotN_over_B(vmec.surface, BiotSavart(coils_optimized)) +curvature=jnp.mean(BiotSavart(coils_optimized).coils.curvature, axis=1) +length=jnp.max(jnp.ravel(BiotSavart(coils_optimized).coils.length)) +BdotN_over_B_optimized_alm = BdotN_over_B(vmec.surface, BiotSavart(coils_optimized_alm)) +curvature_alm=jnp.mean(BiotSavart(coils_optimized_alm).coils.curvature, axis=1) +length_alm=jnp.max(jnp.ravel(BiotSavart(coils_optimized_alm).coils.length)) + + +print(f"Maximum allowed curvature target: ",max_coil_curvature) +print(f"Maximum allowed length target: ",max_coil_length) +print(f"Mean curvature without ALM: ",curvature) +print(f"Length withou ALM:", length) +print(f"Mean curvature with ALM: ",curvature_alm) +print(f"Length with ALM:", length_alm) +print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") +print(f"Maximum BdotN/B after optimization without ALM: {jnp.max(BdotN_over_B_optimized):.2e}") +print(f"Maximum BdotN/B after optimization with ALM: {jnp.max(BdotN_over_B_optimized_alm):.2e}") +# Plot coils, before and after optimization +fig = plt.figure(figsize=(8, 4)) +ax1 = fig.add_subplot(121, projection='3d') +ax2 = fig.add_subplot(122, projection='3d') +coils_initial.plot(ax=ax1, 
show=False) +vmec.surface.plot(ax=ax1, show=False) +coils_optimized.plot(ax=ax2, show=False, label='Optimized no ALM') +coils_optimized_alm.plot(ax=ax2, show=False,color='orange', label='Optimized with ALM') +vmec.surface.plot(ax=ax2, show=False) +plt.legend() +plt.tight_layout() +plt.show() + +# # Save the coils to a json file +# coils_optimized.to_json("stellarator_coils.json") +# # Load the coils from a json file +# from essos.coils import Coils_from_json +# coils = Coils_from_json("stellarator_coils.json") + +# # Save results in vtk format to analyze in Paraview +# from essos.fields import BiotSavart +# vmec.surface.to_vtk('surface_initial', field=BiotSavart(coils_initial)) +# vmec.surface.to_vtk('surface_final', field=BiotSavart(coils_optimized)) +# coils_initial.to_vtk('coils_initial') +# coils_optimized.to_vtk('coils_optimized') \ No newline at end of file diff --git a/examples/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py b/examples/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py new file mode 100644 index 0000000..fdf423a --- /dev/null +++ b/examples/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py @@ -0,0 +1,178 @@ +import os +number_of_processors_to_use = 8 # Parallelization, this should divide ntheta*nphi +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax.numpy as jnp +import matplotlib.pyplot as plt +from essos.surfaces import BdotN_over_B +from essos.coils import Coils, CreateEquallySpacedCurves,Curves +from essos.fields import Vmec, BiotSavart +from essos.objective_functions import loss_BdotN_only_constraint_stochastic,loss_coil_curvature_new,loss_coil_length_new,loss_BdotN_only_stochastic +from essos.objective_functions import loss_coil_curvature,loss_coil_length +from essos.coil_perturbation import GaussianSampler + +import essos.augmented_lagrangian as alm +from functools import partial + +# Optimization parameters 
+maximum_function_evaluations=10 +max_coil_length = 40 +max_coil_curvature = 0.5 +bdotn_tol=1.e-6 +order_Fourier_series_coils = 6 +number_coil_points = order_Fourier_series_coils*10 +number_coils_per_half_field_period = 4 +ntheta=32 +nphi=32 + + + + + +# Initialize VMEC field +vmec = Vmec(os.path.join(os.path.dirname(__name__), 'input_files', + 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc'), + ntheta=ntheta, nphi=nphi, range_torus='full torus') + +# Initialize coils +current_on_each_coil = 1 +number_of_field_periods = vmec.nfp +major_radius_coils = vmec.r_axis +minor_radius_coils = vmec.r_axis/1.5 +curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) +coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) + +len_dofs_curves = len(jnp.ravel(coils_initial.dofs_curves)) +nfp = coils_initial.nfp +stellsym = coils_initial.stellsym +n_segments = coils_initial.n_segments +dofs_curves = coils_initial.dofs_curves +currents_scale = coils_initial.currents_scale +dofs_curves_shape = coils_initial.dofs_curves.shape + + + +#Sampling parameters +sigma=0.01 +length_scale=0.4*jnp.pi +n_derivs=2 +N_samples=10 #Number of samples for the stochastic perturbation +#Create a Gaussian sampler for perturbation +#This sampler will be used to perturb the coils +sampler=GaussianSampler(coils_initial.quadpoints,sigma=sigma,length_scale=length_scale,n_derivs=n_derivs) + + + + +# Create the constraints +penalty = 0.1 #Intial penalty values +multiplier=0.5 #Initial lagrange multiplier values +sq_grad=0.0 #Initial square gradient parameter value for Mu adaptative +model_lagrangian='Standard' #Use standard augmented lagragian suitable for bounded optimizers +#Since we are using LBFGS-B from jaxopt, model_mu will be updated with tolerances so we do not need to 
define the model + +curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) +length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) +bdotn_partial=partial(loss_BdotN_only_constraint_stochastic,sampler=sampler,N_samples=N_samples, vmec=vmec, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym,target_tol=bdotn_tol) +bdotn_only_partial=partial(loss_BdotN_only_stochastic,sampler=sampler,N_samples=N_samples, vmec=vmec, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym) + +#Construct constraints +constraints = alm.combine( +alm.eq(curvature_partial,model_lagrangian=model_lagrangian, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(length_partial,model_lagrangian=model_lagrangian, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), +alm.eq(bdotn_partial,model_lagrangian=model_lagrangian, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad) +) + + + +beta=2. 
#penalty update parameter +mu_max=1.e4 #Maximum penalty parameter allowed +alpha=0.99 #These are parameters only used if gradient descent and adaaptative mu +gamma=1.e-2 +epsilon=1.e-8 +omega_tol=1.e-7 #desired grad_tolerance, associated with grad of lagrangian to main parameters +eta_tol=1.e-7 #desired contraint tolerance, associated with variation of contraints + + + +#If loss=cost_function(x) is not prescribed, f(x)=0 is considered +ALM=alm.ALM_model_jaxopt_lbfgsb(constraints,model_lagrangian=model_lagrangian,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) + +#Initializing lagrange multipliers +lagrange_params=constraints.init(coils_initial.x) +#parameters are a tuple of the primal/main optimisation parameters and the lagrange multipliers +params = coils_initial.x, lagrange_params +#This is just to initialize an empty state for the lagrange multiplier update and get some information +lag_state,grad,info=ALM.init(params) + +#Initializing first tolerances for the inner minimisation loop iteration +mu_average=alm.penalty_average(lagrange_params) +omega=1./mu_average +eta=1./mu_average**0.1 + + + + +# Optimize coils +print(f'Optimizing coils with {maximum_function_evaluations} function evaluations using stochastic and ALM.') +time0 = time() + + +i=0 +while i<=maximum_function_evaluations and (jnp.linalg.norm(grad[0])>omega_tol or alm.norm_constraints(info[2])>eta_tol): + #One step of ALM optimization + params, lag_state,grad,info,eta,omega = ALM.update(params,lag_state,grad,info,eta,omega) + #if i % 5 == 0: + #print(f'i: {i}, loss f: {info[0]:g}, infeasibility: {alm.total_infeasibility(info[1]):g}') + print(f'i: {i}, loss f: {info[0]:g},loss L: {info[1]:g}, infeasibility: {alm.total_infeasibility(info[2]):g}') + #print('lagrange',params[1]) + i=i+1 + + + +dofs_curves = jnp.reshape(params[0][:len_dofs_curves], (dofs_curves_shape)) +dofs_currents = params[0][len_dofs_curves:] +curves = Curves(dofs_curves, 
n_segments, nfp, stellsym) +coils_optimized = Coils(curves=curves, currents=dofs_currents*coils_initial.currents_scale) + +print(f"Stochastic optimization with ALM took {time()-time0:.2f} seconds") + + +BdotN_over_B_initial = BdotN_over_B(vmec.surface, BiotSavart(coils_initial)) +BdotN_over_B_optimized = BdotN_over_B(vmec.surface, BiotSavart(coils_optimized)) +curvature=jnp.mean(BiotSavart(coils_optimized).coils.curvature, axis=1) +length=jnp.max(jnp.ravel(BiotSavart(coils_optimized).coils.length)) +print(f"Mean curvature: ",curvature) +print(f"Length:", length) +print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") +print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}") +print(f"Average BdotN/B before optimization: {jnp.average(jnp.absolute(BdotN_over_B_initial)):.2e}") +print(f"Average BdotN/B after optimization: {jnp.average(jnp.absolute(BdotN_over_B_optimized)):.2e}") +# Plot coils, before and after optimization +fig = plt.figure(figsize=(8, 4)) +ax1 = fig.add_subplot(121, projection='3d') +ax2 = fig.add_subplot(122, projection='3d') +coils_initial.plot(ax=ax1, show=False) +vmec.surface.plot(ax=ax1, show=False) +coils_optimized.plot(ax=ax2, show=False) +vmec.surface.plot(ax=ax2, show=False) +plt.tight_layout() +plt.show() + +# # Save the coils to a json file +# coils_optimized.to_json("stellarator_coils.json") +# # Load the coils from a json file +# from essos.coils import Coils_from_json +# coils = Coils_from_json("stellarator_coils.json") + +# # Save results in vtk format to analyze in Paraview +# from essos.fields import BiotSavart +# vmec.surface.to_vtk('surface_initial', field=BiotSavart(coils_initial)) +# vmec.surface.to_vtk('surface_final', field=BiotSavart(coils_optimized)) +# coils_initial.to_vtk('coils_initial') +# coils_optimized.to_vtk('coils_optimized') \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 49907c7..1f770eb 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -27,7 +27,7 @@ keywords = ["Plasma", "Simulation", "JAX"] -dependencies = [ "jax", "jaxlib", "tqdm", "matplotlib", "diffrax", "optax", "scipy", "jaxkd", "netcdf4"] +dependencies = [ "jax", "jaxlib", "tqdm", "matplotlib", "diffrax", "optax", "jaxopt", "optimistix", "scipy", "jaxkd", "netcdf4"] requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index f7c6b4a..aea8487 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,8 @@ tqdm matplotlib diffrax optax +jaxopt +optimistix scipy jaxkd netcdf4 diff --git a/tests/test_augmented_lagrangian.py b/tests/test_augmented_lagrangian.py new file mode 100644 index 0000000..6be5c41 --- /dev/null +++ b/tests/test_augmented_lagrangian.py @@ -0,0 +1,247 @@ +import unittest +import pytest +import jax +import jax.numpy as jnp +import optax + +from essos.augmented_lagrangian import ( + LagrangeMultiplier, + update_method, + update_method_squared, + eq, + ineq, + combine, + total_infeasibility, + norm_constraints, + infty_norm_constraints, + penalty_average, + Constraint, + ALM, + lagrange_update, + ALM_model_optax, + ALM_model_jaxopt_lbfgsb, + ALM_model_jaxopt_LevenbergMarquardt, + ALM_model_jaxopt_lbfgs, + ALM_model_optimistix_LevenbergMarquardt, +) + +class TestAugmentedLagrangian(unittest.TestCase): + + def test_lagrange_multiplier(self): + lm = LagrangeMultiplier(value=1.0, penalty=2.0, sq_grad=3.0) + self.assertEqual(lm.value, 1.0) + self.assertEqual(lm.penalty, 2.0) + self.assertEqual(lm.sq_grad, 3.0) + + def test_update_method_all_modes(self): + params = LagrangeMultiplier(jnp.array([1.]), jnp.array([2.]), jnp.array([0.])) + updates = LagrangeMultiplier(jnp.array([0.5]), jnp.array([0.]), jnp.array([0.])) + for mode in [ + 'Constant', 'Mu_Monotonic', 'Mu_Conditional_True', 'Mu_Conditional_False', + 'Mu_Tolerance_True', 'Mu_Tolerance_False', 'Mu_Adaptative' + ]: + if 'Tolerance' in mode: + result, eta, omega = update_method(params, updates, 1.0, 1.0, model_mu=mode) + 
self.assertIsInstance(result, LagrangeMultiplier) + self.assertIsInstance(eta, jnp.ndarray) + self.assertIsInstance(omega, jnp.ndarray) + else: + result = update_method(params, updates, 1.0, 1.0, model_mu=mode) + self.assertIsInstance(result, LagrangeMultiplier) + + def test_update_method_squared_all_modes(self): + params = LagrangeMultiplier(jnp.array([1.]), jnp.array([2.]), jnp.array([0.])) + updates = LagrangeMultiplier(jnp.array([0.5]), jnp.array([0.]), jnp.array([0.])) + for mode in [ + 'Constant', 'Mu_Monotonic', 'Mu_Conditional_True', 'Mu_Conditional_False', + 'Mu_Tolerance_True', 'Mu_Tolerance_False', 'Mu_Adaptative' + ]: + if 'Tolerance' in mode: + result, eta, omega = update_method_squared(params, updates, 1.0, 1.0, model_mu=mode) + self.assertIsInstance(result, LagrangeMultiplier) + self.assertIsInstance(eta, jnp.ndarray) + self.assertIsInstance(omega, jnp.ndarray) + else: + result = update_method_squared(params, updates, 1.0, 1.0, model_mu=mode) + self.assertIsInstance(result, LagrangeMultiplier) + + def test_eq_and_ineq_constraint(self): + def fun(x): return x - 2 + eq_constraint = eq(fun) + ineq_constraint = ineq(fun) + params_eq = eq_constraint.init(jnp.array([3.])) + params_ineq = ineq_constraint.init(jnp.array([3.])) + eq_constraint.loss(params_eq, jnp.array([3.])) + ineq_constraint.loss(params_ineq, jnp.array([3.])) + + def test_eq_and_ineq_constraint_squared(self): + def fun(x): return x - 2 + eq_constraint = eq(fun, model_lagrangian='Squared') + ineq_constraint = ineq(fun, model_lagrangian='Squared') + params_eq = eq_constraint.init(jnp.array([3.])) + params_ineq = ineq_constraint.init(jnp.array([3.])) + eq_constraint.loss(params_eq, jnp.array([3.])) + ineq_constraint.loss(params_ineq, jnp.array([3.])) + + def test_combine_constraints(self): + def fun1(x): return x - 1 + def fun2(x): return x + 1 + c1 = eq(fun1) + c2 = eq(fun2) + combined = combine(c1, c2) + params = combined.init(jnp.array([2.])) + combined.loss(params, jnp.array([2.])) + + def 
test_combine_multiple_constraints(self): + def fun1(x): return x - 1 + def fun2(x): return x + 1 + def fun3(x): return x * 2 + c1 = eq(fun1) + c2 = eq(fun2) + c3 = eq(fun3) + combined = combine(c1, c2, c3) + params = combined.init(jnp.array([2.])) + combined.loss(params, jnp.array([2.])) + + def test_total_infeasibility(self): + tree = {'a': jnp.array([1.0, -2.0]), 'b': jnp.array([3.0])} + result = total_infeasibility(tree) + self.assertAlmostEqual(float(result), 6.0) + + def test_norm_constraints(self): + tree = {'a': jnp.array([3.0, 4.0])} + result = norm_constraints(tree) + self.assertAlmostEqual(float(result), 5.0) + + def test_infty_norm_constraints(self): + tree = {'a': jnp.array([1.0, -5.0, 3.0])} + result = infty_norm_constraints(tree) + self.assertAlmostEqual(float(result), 3.0) + + def test_penalty_average(self): + tree = {'a': LagrangeMultiplier(jnp.array([1.0]), jnp.array([2.0]), jnp.array([0.0]))} + result = penalty_average(tree) + self.assertAlmostEqual(float(result), 2.0) + + def test_constraint_namedtuple(self): + def fun(x): return x - 1 + c = eq(fun) + self.assertIsInstance(c, Constraint) + params = c.init(jnp.array([2.])) + c.loss(params, jnp.array([2.])) + + def test_alm_namedtuple(self): + def dummy_init(*args, **kwargs): return None + def dummy_update(*args, **kwargs): return None + alm = ALM(dummy_init, dummy_update) + self.assertIsInstance(alm, ALM) + self.assertTrue(callable(alm.init)) + self.assertTrue(callable(alm.update)) + + def test_lagrange_update_gradient_transformation_and_update(self): + gt = lagrange_update('Standard') + self.assertTrue(hasattr(gt, 'init')) + self.assertTrue(hasattr(gt, 'update')) + # Call init and update with dummy data + params = {'x': jnp.array([1.0])} + lagrange_params = LagrangeMultiplier(jnp.array([0.0]), jnp.array([1.0]), jnp.array([0.0])) + updates = LagrangeMultiplier(jnp.array([-0.5]), jnp.array([1.0]), jnp.array([1.0])) + state = gt.init(params) + # eta, omega, etc. 
are required by update_fn signature + eta = {'x': jnp.array([0.0])} + omega = {'x': jnp.array([0.0])} + gt.update(lagrange_params, updates, state, eta, omega, params=params) + gt2 = lagrange_update('Squared') + state2 = gt2.init(params) + gt2.update(lagrange_params, updates, state2, eta, omega, params=params) + + def test_eq_constraint_init_kwargs(self): + def fun(x, y=0): return x + y - 2 + constraint = eq(fun) + params = constraint.init(jnp.array([3.]), y=1) + self.assertIn('lambda', params) + + def test_ineq_constraint_init_kwargs(self): + def fun(x, y=0): return x + y - 2 + constraint = ineq(fun) + params = constraint.init(jnp.array([3.]), y=1) + self.assertIn('lambda', params) + self.assertIn('slack', params) + + # ---- ALM model tests ---- + + def test_ALM_model_optax_init_and_update(self): + optimizer = optax.sgd(1e-3) + def fun(x): return x - 1 + constraint = eq(fun) + main_params = jnp.array([6.0,2.0]) + lagrange_params = constraint.init(main_params) + params = main_params,lagrange_params + alm = ALM_model_optax(optimizer, constraint,model_mu='Mu_Conditional') + self.assertIsInstance(alm, ALM) + # Call init and update + state,grad,info = alm.init(params) + # Simulate a gradient step + eta = jnp.array(1.0) + omega = jnp.array(1.0) + alm.update(params, state,grad,info,eta,omega) + + def test_ALM_model_jaxopt_lbfgsb_init_and_update(self): + def fun(x): return x - 1 + constraint = eq(fun) + main_params = jnp.array([6.0,2.0]) + lagrange_params = constraint.init(main_params) + params = main_params,lagrange_params + alm = ALM_model_jaxopt_lbfgsb(constraint) + self.assertIsInstance(alm, ALM) + state,grad,info = alm.init(params) + eta = jnp.array(1.0) + omega = jnp.array(1.0) + alm.update(params, state,grad,info,eta,omega) + + + def test_ALM_model_jaxopt_LevenbergMarquardt_init_and_update(self): + def fun(x): return x - 1 + constraint = eq(fun) + main_params = jnp.array([6.0,2.0]) + lagrange_params = constraint.init(main_params) + params = 
main_params,lagrange_params + alm = ALM_model_jaxopt_LevenbergMarquardt(constraint) + self.assertIsInstance(alm, ALM) + state,grad,info = alm.init(params) + eta = jnp.array(1.0) + omega = jnp.array(1.0) + alm.update(params, state,grad,info,eta,omega) + + + + def test_ALM_model_jaxopt_lbfgs_init_and_update(self): + def fun(x): return x - 1 + constraint = eq(fun) + main_params = jnp.array([6.0,2.0]) + lagrange_params = constraint.init(main_params) + params = main_params,lagrange_params + alm = ALM_model_jaxopt_lbfgs(constraint) + self.assertIsInstance(alm, ALM) + state,grad,info = alm.init(params) + eta = jnp.array(1.0) + omega = jnp.array(1.0) + alm.update(params, state,grad,info,eta,omega) + + +# def test_ALM_model_optimistix_LevenbergMarquardt_init_and_update(self): +# def fun(x): return x - 1 +# constraint = eq(fun) +# main_params = jnp.array([6.0,2.0]) +# lagrange_params = constraint.init(main_params) +# params = main_params,lagrange_params +# alm = ALM_model_optimistix_LevenbergMarquardt(constraint) +# self.assertIsInstance(alm, ALM) +# state,grad,info = alm.init(params) +# eta = jnp.array(1.0) +# omega = jnp.array(1.0) +# alm.update(params, state,grad,info,eta,omega) + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/test_coil_perturbation.py b/tests/test_coil_perturbation.py new file mode 100644 index 0000000..a33821e --- /dev/null +++ b/tests/test_coil_perturbation.py @@ -0,0 +1,127 @@ +import unittest +import jax +import jax.numpy as jnp +import numpy as np + +from essos.coil_perturbation import ( + #ldl_decomposition, + matrix_sqrt_via_spectral, + GaussianSampler, + PerturbationSample, + perturb_curves_systematic, + perturb_curves_statistic, +) + +# Dummy Curves and apply_symmetries_to_gammas for testing +class DummyCurves: + def __init__(self, n_base_curves=2, nfp=1, stellsym=True, n_points=5, n_derivs=2): + self.n_base_curves = n_base_curves + self.nfp = nfp + self.stellsym = stellsym + self.gamma = 
jnp.zeros((n_base_curves, n_points, 3)) + self.gamma_dash = jnp.zeros((n_base_curves, n_points, 3)) + self.gamma_dashdash = jnp.zeros((n_base_curves, n_points, 3)) + +def dummy_apply_symmetries_to_gammas(gamma, nfp, stellsym): + # Just return the input for testing + return gamma + +# Patch apply_symmetries_to_gammas in the tested module +import essos.coil_perturbation +essos.coil_perturbation.apply_symmetries_to_gammas = dummy_apply_symmetries_to_gammas + +class TestCoilPerturbation(unittest.TestCase): + #def test_ldl_decomposition(self): + # A = jnp.array([[4.0, 2.0], [2.0, 3.0]]) + # L, D = ldl_decomposition(A) + # # Check shapes + # self.assertEqual(L.shape, (2, 2)) + # self.assertEqual(D.shape, (2,)) + # # Check that A ≈ L @ jnp.diag(D) @ L.T + # A_recon = L @ jnp.diag(D) @ L.T + # np.testing.assert_allclose(A, A_recon, atol=1e-6) + + def test_matrix_sqrt_via_spectral(self): + A = jnp.array([[4.0, 2.0], [2.0, 3.0]]) + sqrt_A = matrix_sqrt_via_spectral(A) + # sqrt_A @ sqrt_A ≈ A + A_recon = sqrt_A @ sqrt_A + np.testing.assert_allclose(A, A_recon, atol=1e-6) + + def test_gaussian_sampler_covariances_and_draw(self): + points = jnp.linspace(0, 1, 5) + sampler0 = GaussianSampler(points, sigma=1.0, length_scale=0.5, n_derivs=0) + sampler1 = GaussianSampler(points, sigma=1.0, length_scale=0.5, n_derivs=1) + sampler2 = GaussianSampler(points, sigma=1.0, length_scale=0.5, n_derivs=2) + # Covariance matrices + cov0 = sampler0.get_covariance_matrix() + cov1 = sampler1.get_covariance_matrix() + cov2 = sampler2.get_covariance_matrix() + self.assertEqual(cov0.shape[0], 5) + self.assertEqual(cov1.shape[0], 10) + self.assertEqual(cov2.shape[0], 15) + # Draw samples + key = jax.random.PRNGKey(0) + sample0 = sampler0.draw_sample(key) + sample1 = sampler1.draw_sample(key) + sample2 = sampler2.draw_sample(key) + self.assertEqual(sample0.shape, (1, 5, 3)) + self.assertEqual(sample1.shape, (2, 5, 3)) + self.assertEqual(sample2.shape, (3, 5, 3)) + + def 
test_gaussian_sampler_kernels(self): + points = jnp.linspace(0, 1, 3) + sampler = GaussianSampler(points, sigma=1.0, length_scale=0.5, n_derivs=2) + # Test kernel and derivatives + val = sampler.kernel_periodicity(0.1, 0.2) + dval = sampler.d_kernel_periodicity_dx(0.1, 0.2) + ddval = sampler.d_kernel_periodicity_dxdx(0.1, 0.2) + dddval = sampler.d_kernel_periodicity_dxdxdx(0.1, 0.2) + ddddval = sampler.d_kernel_periodicity_dxdxdxdx(0.1, 0.2) + self.assertIsInstance(val, jnp.ndarray) + self.assertIsInstance(dval, jnp.ndarray) + self.assertIsInstance(ddval, jnp.ndarray) + self.assertIsInstance(dddval, jnp.ndarray) + self.assertIsInstance(ddddval, jnp.ndarray) + + def test_perturbation_sample(self): + points = jnp.linspace(0, 1, 5) + sampler = GaussianSampler(points, sigma=1.0, length_scale=0.5, n_derivs=1) + key = jax.random.PRNGKey(0) + ps = PerturbationSample(sampler, key) + # get_sample for deriv=0 and deriv=1 + s0 = ps.get_sample(0) + s1 = ps.get_sample(1) + self.assertEqual(s0.shape, (5, 3)) + self.assertEqual(s1.shape, (5, 3)) + # resample + ps.resample() + # get_sample with too high deriv should raise + with self.assertRaises(ValueError): + ps.get_sample(2) + + def test_perturb_curves_systematic(self): + points = jnp.linspace(0, 1, 5) + sampler0 = GaussianSampler(points, sigma=1.0, length_scale=0.5, n_derivs=0) + sampler1 = GaussianSampler(points, sigma=1.0, length_scale=0.5, n_derivs=1) + sampler2 = GaussianSampler(points, sigma=1.0, length_scale=0.5, n_derivs=2) + key = jax.random.PRNGKey(0) + for sampler in [sampler0, sampler1, sampler2]: + curves = DummyCurves(n_base_curves=2, nfp=1, stellsym=True, n_points=5) + perturb_curves_systematic(curves, sampler, key) + # Just check that gamma arrays are still the right shape + self.assertEqual(curves.gamma.shape, (2, 5, 3)) + + def test_perturb_curves_statistic(self): + points = jnp.linspace(0, 1, 5) + sampler0 = GaussianSampler(points, sigma=1.0, length_scale=0.5, n_derivs=0) + sampler1 = GaussianSampler(points, 
sigma=1.0, length_scale=0.5, n_derivs=1) + sampler2 = GaussianSampler(points, sigma=1.0, length_scale=0.5, n_derivs=2) + key = jax.random.PRNGKey(0) + for sampler in [sampler0, sampler1, sampler2]: + curves = DummyCurves(n_base_curves=2, nfp=1, stellsym=True, n_points=5) + perturb_curves_statistic(curves, sampler, key) + self.assertEqual(curves.gamma.shape, (2, 5, 3)) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_fields.py b/tests/test_fields.py index bccb672..74ea977 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -8,6 +8,7 @@ def __init__(self): self.currents = jnp.array([1.0, 2.0, 3.0]) self.gamma = random.uniform(random.PRNGKey(0), (3, 3, 3)) self.gamma_dash = random.uniform(random.PRNGKey(0), (3, 3, 3)) + self.gamma_dashdash = random.uniform(random.PRNGKey(0), (3, 3, 3)) self.dofs_curves = random.uniform(random.PRNGKey(0), (3, 3, 3)) def test_biot_savart_initialization(): diff --git a/tests/test_objective_functions.py b/tests/test_objective_functions.py new file mode 100644 index 0000000..c2d6eed --- /dev/null +++ b/tests/test_objective_functions.py @@ -0,0 +1,261 @@ +import unittest +from unittest.mock import MagicMock, patch +import jax.numpy as jnp + +import essos.objective_functions as objf + +class DummyField: + def __init__(self): + self.R0 = jnp.array([1.]) + self.Z0 = jnp.array([0.]) + self.phi = jnp.array([0.]) + self.B_axis = jnp.array([[1., 0., 0.]]) + self.grad_B_axis = jnp.array([[0., 0., 0.]]) + self.r_axis = 1.0 + self.z_axis = 0.0 + self.AbsB = MagicMock(return_value=5.7) + self.B = MagicMock(return_value=jnp.array([1., 0., 0.])) + self.dB_by_dX = MagicMock(return_value=jnp.array([0., 0., 0.])) + self.B_covariant = MagicMock(return_value=jnp.array([1., 0., 0.])) + self.coils_length = jnp.array([30.]) + self.coils_curvature = jnp.ones((2, 10)) + self.gamma = jnp.zeros((2, 10, 3)) + self.gamma_dash = jnp.ones((2, 10, 3)) + self.gamma_dashdash = jnp.ones((2, 10, 3)) + 
self.currents = jnp.ones(2) + self.quadpoints = jnp.linspace(0, 1, 10) + self.x = jnp.zeros((10,1)) + +class DummyCoils(DummyField): + def __init__(self): + super().__init__() + +class DummyCurves: + def __init__(self, *args, **kwargs): + pass + +class DummyParticles: + def __init__(self): + self.to_full_orbit = MagicMock() + self.trajectories = jnp.zeros((2, 10, 3)) + +class DummyTracing: + def __init__(self, *args, **kwargs): + self.trajectories = jnp.zeros((2, 10, 3)) + self.field = DummyField() + self.loss_fractions = jnp.array([0.1,0.2,1.]) + self.times_to_trace = 10 + self.maxtime = 1e-5 + +class DummyVmec: + def __init__(self): + self.surface = MagicMock() + +class DummySurface: + def __init__(self): + self.gamma = jnp.zeros((10, 3)) + self.unitnormal = jnp.ones((10, 3)) + +def dummy_sampler(*args, **kwargs): + return 0 + +def dummy_new_nearaxis_from_x_and_old_nearaxis(x, field_nearaxis): + class DummyNearAxis: + elongation = jnp.array([1.]) + iota = 1.0 + x = jnp.array([1.]) + R0 = jnp.array([1.]) + Z0 = jnp.array([0.]) + phi = jnp.array([0.]) + B_axis = jnp.array([[1., 0., 0.]]) + grad_B_axis = jnp.array([[0., 0., 0.]]) + return DummyNearAxis() + +class TestObjectiveFunctions(unittest.TestCase): + def setUp(self): + self.x = jnp.ones(12) + self.dofs_curves = jnp.ones((2, 3)) + self.currents_scale = 1.0 + self.nfp = 1 + self.n_segments = 10 + self.stellsym = True + self.key = 0 + self.sampler = dummy_sampler + self.field = DummyField() + self.coils = DummyCoils() + self.curves = DummyCurves() + self.particles = DummyParticles() + self.tracing = DummyTracing() + self.vmec = DummyVmec() + self.surface = DummySurface() + + @patch('essos.objective_functions.Curves', return_value=DummyCurves()) + @patch('essos.objective_functions.Coils', return_value=DummyCoils()) + @patch('essos.objective_functions.BiotSavart', return_value=DummyField()) + @patch('essos.objective_functions.perturb_curves_systematic') + 
@patch('essos.objective_functions.perturb_curves_statistic') + def test_perturbed_field_and_coils_from_dofs(self, pcs, pcss, bs, coils, curves): + objf.pertubred_field_from_dofs(self.x, self.key, self.sampler, self.dofs_curves, self.currents_scale, self.nfp) + objf.perturbed_coils_from_dofs(self.x, self.key, self.sampler, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.Curves', return_value=DummyCurves()) + @patch('essos.objective_functions.Coils', return_value=DummyCoils()) + @patch('essos.objective_functions.BiotSavart', return_value=DummyField()) + def test_field_and_coils_from_dofs(self, bs, coils, curves): + objf.field_from_dofs(self.x, self.dofs_curves, self.currents_scale, self.nfp) + objf.coils_from_dofs(self.x, self.dofs_curves, self.currents_scale, self.nfp) + objf.curves_from_dofs(self.x, self.dofs_curves, self.nfp) + + @patch('essos.objective_functions.field_from_dofs', return_value=DummyField()) + def test_loss_coil_length_and_curvature(self, ffd): + objf.loss_coil_length(self.x, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_coil_curvature(self.x, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_coil_length_new(self.x, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_coil_curvature_new(self.x, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.field_from_dofs', return_value=DummyField()) + def test_loss_normB_axis(self, ffd): + objf.loss_normB_axis(self.x, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_normB_axis_average(self.x, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.field_from_dofs', return_value=DummyField()) + def test_loss_particle_functions(self, ffd): + with patch('essos.objective_functions.Tracing', return_value=self.tracing): + objf.loss_particle_radial_drift(self.x, self.particles, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_particle_alpha_drift(self.x, 
self.particles, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_particle_gamma_c(self.x, self.particles, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_particle_r_cross_final(self.x, self.particles, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_particle_r_cross_max_constraint(self.x, self.particles, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_Br(self.x, self.particles, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_iota(self.x, self.particles, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.field_from_dofs', return_value=DummyField()) + def test_loss_lost_fraction(self, ffd): + with patch('essos.objective_functions.Tracing', return_value=self.tracing): + objf.loss_lost_fraction(self.field, self.particles, self.dofs_curves, self.currents_scale, self.nfp) + + def test_normB_axis(self): + objf.normB_axis(self.field) + + @patch('essos.objective_functions.field_from_dofs', return_value=DummyField()) + @patch('essos.objective_functions.new_nearaxis_from_x_and_old_nearaxis', side_effect=dummy_new_nearaxis_from_x_and_old_nearaxis) + def test_loss_coils_for_nearaxis_and_loss_coils_and_nearaxis(self, nna, ffd): + objf.loss_coils_for_nearaxis(self.x, self.field, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_coils_and_nearaxis(jnp.ones(13), self.field, self.dofs_curves, self.currents_scale, self.nfp) + + def test_difference_B_gradB_onaxis(self): + objf.difference_B_gradB_onaxis(self.field, self.field) + + @patch('essos.objective_functions.Curves', return_value=DummyCurves()) + @patch('essos.objective_functions.Coils', return_value=DummyCoils()) + @patch('essos.objective_functions.BiotSavart', return_value=DummyField()) + @patch('essos.objective_functions.BdotN_over_B', return_value=jnp.ones(10)) + def test_loss_bdotn_over_b(self, bdotn, bs, coils, curves): + objf.loss_bdotn_over_b(self.x, self.vmec, self.dofs_curves, self.currents_scale, self.nfp) + + 
@patch('essos.objective_functions.field_from_dofs', return_value=DummyField()) + @patch('essos.objective_functions.BdotN_over_B', return_value=jnp.ones(10)) + def test_loss_BdotN(self, bdotn, ffd): + objf.loss_BdotN(self.x, self.vmec, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.field_from_dofs', return_value=DummyField()) + @patch('essos.objective_functions.BdotN_over_B', return_value=jnp.ones(10)) + def test_loss_BdotN_only(self, bdotn, ffd): + objf.loss_BdotN_only(self.x, self.vmec, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.field_from_dofs', return_value=DummyField()) + @patch('essos.objective_functions.BdotN_over_B', return_value=jnp.ones(10)) + def test_loss_BdotN_only_constraint(self, bdotn, ffd): + objf.loss_BdotN_only_constraint(self.x, self.vmec, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.BdotN_over_B', return_value=jnp.ones(10)) + @patch('essos.objective_functions.pertubred_field_from_dofs', return_value=DummyField()) + def test_loss_BdotN_only_stochastic(self, perturbed, bdotn): + objf.loss_BdotN_only_stochastic(self.x, self.sampler, 2, self.vmec, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.BdotN_over_B', return_value=jnp.ones(10)) + @patch('essos.objective_functions.pertubred_field_from_dofs', return_value=DummyField()) + def test_loss_BdotN_only_constraint_stochastic(self, perturbed, bdotn): + objf.loss_BdotN_only_constraint_stochastic(self.x, self.sampler, 2, self.vmec, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.coils_from_dofs', return_value=DummyCoils()) + def test_loss_cs_distance_and_array(self, cfd): + objf.loss_cs_distance(self.x, self.surface, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_cs_distance_array(self.x, self.surface, self.dofs_curves, self.currents_scale, self.nfp) + + 
@patch('essos.objective_functions.coils_from_dofs', return_value=DummyCoils()) + def test_loss_cc_distance_and_array(self, cfd): + objf.loss_cc_distance(self.x, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_cc_distance_array(self.x, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.coils_from_dofs', return_value=DummyCoils()) + def test_loss_linking_mnumber_and_constraint(self, cfd): + objf.loss_linking_mnumber(self.x, self.dofs_curves, self.currents_scale, self.nfp) + objf.loss_linking_mnumber_constarint(self.x, self.dofs_curves, self.currents_scale, self.nfp) + + def test_cc_distance_pure(self): + gamma1 = jnp.ones((10, 3))*3. + l1 = jnp.ones((10, 3)) + gamma2 = jnp.ones((10, 3))*4. + l2 = jnp.ones((10, 3))*6. + objf.cc_distance_pure(gamma1, l1, gamma2, l2, 1.0) + + def test_cs_distance_pure(self): + gammac = jnp.ones((10, 3))*7. + lc = jnp.ones((10, 3)) + gammas = jnp.ones((10, 3))*9. + ns = jnp.ones((10, 3))*10. + objf.cs_distance_pure(gammac, lc, gammas, ns, 1.0) + + @patch('essos.objective_functions.coils_from_dofs', return_value=DummyCoils()) + def test_loss_lorentz_force_coils(self, cfd): + objf.loss_lorentz_force_coils(self.x, self.dofs_curves, self.currents_scale, self.nfp) + + @patch('essos.objective_functions.compute_curvature', return_value=1.0) + @patch('essos.objective_functions.BiotSavart_from_gamma', return_value=MagicMock(B=MagicMock(return_value=jnp.array([1., 0., 0.])))) + def test_lp_force_pure(self, bsg, cc): + gamma = jnp.ones((2, 10, 3))*2. + gamma_dash = jnp.ones((2, 10, 3))*3. 
+        gamma_dashdash = jnp.ones((2, 10, 3))
+        currents = jnp.ones(2)
+        quadpoints = jnp.linspace(0, 1, 10)
+        objf.lp_force_pure(0, gamma, gamma_dash, gamma_dashdash, currents, quadpoints, 1, 1e6)
+
+    def test_B_regularized_singularity_term(self):
+        rc_prime = jnp.ones((10, 3))
+        rc_prime_prime = jnp.ones((10, 3))
+        objf.B_regularized_singularity_term(rc_prime, rc_prime_prime, 1.0)
+
+    def test_B_regularized_pure(self):
+        gamma = jnp.ones((10, 3))*4.
+        gammadash = jnp.ones((10, 3))
+        gammadashdash = jnp.ones((10, 3))
+        quadpoints = jnp.linspace(0, 1, 10)
+        current = 1.0
+        regularization = 1.0
+        objf.B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization)
+
+    def test_regularization_circ(self):
+        self.assertTrue(objf.regularization_circ(2.0) > 0)
+
+    def test_regularization_rect_and_k_and_delta(self):
+        a, b = 2.0, 1.0
+        objf.regularization_rect(a, b)
+        objf.rectangular_xsection_k(a, b)
+        objf.rectangular_xsection_delta(a, b)
+
+    def test_linking_number_pure_and_integrand(self):
+        gamma1 = jnp.ones((10, 3))*4.
+        lc1 = jnp.ones((10, 3))*2.
+        gamma2 = jnp.ones((10, 3))*6.
+        lc2 = jnp.ones((10, 3))*5.
+        dphi = 0.1
+        objf.linking_number_pure(gamma1, lc1, gamma2, lc2, dphi)
+        r1 = jnp.zeros(3)
+        dr1 = jnp.ones(3)
+        r2 = jnp.zeros(3)
+        dr2 = jnp.ones(3)
+        objf.integrand_linking_number(r1, dr1, r2, dr2, dphi, dphi)
+
+if __name__ == "__main__":
+    unittest.main()