diff --git a/essos/coils.py b/essos/coils.py index abe58e5..7782fe8 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -518,8 +518,8 @@ def RotatedCurve(curve, phi, flip): if flip: rotmat = rotmat @ jnp.array( [[1, 0, 0], - [0, -1, 0], - [0, 0, -1]]) + [0, -1, 0], + [0, 0, -1]]) return curve @ rotmat @partial(jit, static_argnames=['nfp', 'stellsym']) @@ -559,4 +559,102 @@ def apply_symmetries_to_currents(base_currents, nfp, stellsym): for i in range(len(base_currents)): current = -base_currents[i] if flip else base_currents[i] currents.append(current) - return jnp.array(currents) \ No newline at end of file + return jnp.array(currents) + +def _resample_closed_curve_uniform_one(g: jnp.ndarray, n_segments: int) -> jnp.ndarray: + """ + One-curve arclength resample to n_segments points on t∈[0,1), piecewise linear. + g: (M,3) closed curve (first≈last not required; we close internally). + Returns: (n_segments,3) + """ + # Close the loop + g0 = g[0:1, :] + g_ext = jnp.concatenate([g, g0], axis=0) # (M+1,3) + seg = g_ext[1:] - g_ext[:-1] # (M,3) + seg_len = jnp.linalg.norm(seg, axis=1) # (M,) + cum = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(seg_len)], axis=0) # (M+1,) + total = cum[-1] + # Uniform targets in arclength (exclude total to avoid duplicate) + s_targets = jnp.linspace(0.0, total, n_segments, endpoint=False) # (n_segments,) + # For each s_t, find i with cum[i] <= s_t < cum[i+1] + idx = jnp.searchsorted(cum, s_targets, side='right') - 1 # (n_segments,) + idx = jnp.clip(idx, 0, seg.shape[0]-1) + s0 = cum[idx] + s1 = cum[idx+1] + w = (s_targets - s0) / jnp.maximum(s1 - s0, 1e-20) # (n_segments,) + p0 = g_ext[idx] + p1 = g_ext[idx+1] + return p0 + w[:, None] * (p1 - p0) # (n_segments,3) + +def _resample_closed_curve_uniform_batch(gammas: jnp.ndarray, n_segments: int) -> jnp.ndarray: + """ + Batch arclength resample. + gammas: (Ncoils, M, 3) (all curves same M; if not, pre-interp in index space). + Returns: (Ncoils, n_segments, 3) + """ + return vmap(_resample_closed_curve_uniform_one, in_axes=(0, None))(gammas, n_segments) + +@partial(jit, static_argnames=('order',)) +def _fit_real_fourier_batch(gamma_uni: jnp.ndarray, order: int) -> jnp.ndarray: + """ + gamma_uni: (Ncoils, Nseg, 3), samples at t_j = j/Nseg, j=0..Nseg-1 + Returns dofs: (Ncoils, 3, 2*order+1) with [a0, sin1, cos1, ..., sinK, cosK]. + """ + Ncoils, Nseg, _ = gamma_uni.shape # Nseg is static if n_segments was static upstream + Kmax = min(order, Nseg // 2) # <-- Python int (static) + + g = jnp.transpose(gamma_uni, (0, 2, 1)) # (Ncoils, 3, Nseg) + F = jnp.fft.rfft(g, axis=-1) / Nseg # (Ncoils, 3, Nseg//2 + 1) + + a0 = F[..., 0].real # (Ncoils, 3) + + # Static slice (OK under jit) + Fk = F[..., 1:1 + Kmax] # (Ncoils, 3, Kmax) + + cos_k = 2.0 * Fk.real # (Ncoils, 3, Kmax) + sin_k = -2.0 * Fk.imag # (Ncoils, 3, Kmax) + + # Pad to 'order' if needed (pad width is also static here) + if Kmax < order: + pad = order - Kmax + zshape = (cos_k.shape[0], cos_k.shape[1], pad) + z = jnp.zeros(zshape, dtype=gamma_uni.dtype) + cos_k = jnp.concatenate([cos_k, z], axis=-1) # (Ncoils, 3, order) + sin_k = jnp.concatenate([sin_k, z], axis=-1) # (Ncoils, 3, order) + + inter = jnp.empty((Ncoils, 3, 2*order), dtype=gamma_uni.dtype) + inter = inter.at[..., 0::2].set(sin_k) # sin₁, sin₂, ... + inter = inter.at[..., 1::2].set(cos_k) # cos₁, cos₂, ... + + dofs = jnp.concatenate([a0[..., None], inter], axis=-1) # (Ncoils, 3, 2*order+1) + return dofs + +@partial(jit, static_argnames=('order','n_segments','assume_uniform')) +def fit_dofs_from_coils( + coils_gamma: jnp.ndarray, + order: int, + n_segments: int, + assume_uniform: bool = False, +) -> tuple[jnp.ndarray, jnp.ndarray]: + """ + Fast path (batched + JIT + rFFT). + coils_gamma: (Ncoils, M, 3) JAX array. If M != n_segments and assume_uniform=True, + curves are uniformly subsampled in index space. If assume_uniform=False, + do arclength resampling (slower but accurate). + Returns: + dofs: (Ncoils, 3, 2*order+1) + gamma_resampled: (Ncoils, n_segments, 3) + """ + Ncoils, M, _ = coils_gamma.shape + if assume_uniform: + if M == n_segments: + gamma_uni = coils_gamma + else: + # uniform subsampling in index space (fast) + idx = jnp.floor(jnp.linspace(0, M, n_segments, endpoint=False)).astype(int) % M + gamma_uni = coils_gamma[:, idx, :] + else: + gamma_uni = _resample_closed_curve_uniform_batch(coils_gamma, n_segments) # arclength (vmapped) + + dofs = _fit_real_fourier_batch(gamma_uni, order) # rFFT-based fit + return dofs, gamma_uni \ No newline at end of file diff --git a/essos/dynamics.py b/essos/dynamics.py index 11ad64b..2de1e98 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -5,7 +5,7 @@ from jax.sharding import Mesh, PartitionSpec, NamedSharding from jax import jit, vmap, tree_util, random, lax, device_put from functools import partial -from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5, PIDController, Event, TqdmProgressMeter +from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5, PIDController, Event, TqdmProgressMeter, NoProgressMeter from diffrax import ControlTerm,UnsafeBrownianPath,MultiTerm,ItoMilstein,ClipStepSizeController #For collisions we need this to solve stochastic differential equation import diffrax from essos.coils import Coils @@ -501,6 +501,7 @@ def __init__(self, trajectories_input=None, initial_conditions=None, times_to_tr self.particles = particles self.species=species self.tag_gc=tag_gc + self.progress_meter = TqdmProgressMeter() # NoProgressMeter() # TqdmProgressMeter() if condition is None: self.condition = lambda t, y, args, **kwargs: False if isinstance(field, Vmec): @@ -694,7 +695,7 @@ def update_state(state, _): #stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.tol_step_size, atol=self.tol_step_size), max_steps=10000000000, event = Event(self.condition), - progress_meter=TqdmProgressMeter(), + progress_meter=self.progress_meter, ).ys elif self.model == 'GuidingCenterCollisionsMuAdaptative': import warnings @@ -720,7 +721,7 @@ def update_state(state, _): stepsize_controller=ClipStepSizeController(controller=PIDController(pcoeff=0.1, icoeff=0.3, dcoeff=0.0, rtol=self.rtol, atol=self.atol,dtmin=dt0,dtmax=1.e-4,force_dtmin=True),step_ts=self.times,store_rejected_steps=self.rejected_steps), max_steps=10000000000, event = Event(self.condition), - progress_meter=TqdmProgressMeter(), + progress_meter=self.progress_meter, ).ys elif self.model == 'GuidingCenterCollisionsMuFixed': import warnings @@ -744,7 +745,7 @@ def update_state(state, _): # adjoint=DirectAdjoint(), max_steps=10000000000, event = Event(self.condition), - progress_meter=TqdmProgressMeter(), + progress_meter=self.progress_meter, ).ys elif self.model == 'GuidingCenterCollisionsMuIto': import warnings @@ -768,7 +769,7 @@ def update_state(state, _): # adjoint=DirectAdjoint(), max_steps=10000000000, event = Event(self.condition), - progress_meter=TqdmProgressMeter(), + progress_meter=self.progress_meter, ).ys elif self.model == 'FullOrbitCollisions': import warnings @@ -794,7 +795,7 @@ def update_state(state, _): stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.tol_step_size, atol=self.tol_step_size,dtmin=dt0), max_steps=10000000000, event = Event(self.condition), - progress_meter=TqdmProgressMeter() + progress_meter=self.progress_meter, ).ys elif self.model == 'GuidingCenterAdaptative' : import warnings @@ -805,12 +806,12 @@ def update_state(state, _): t1=self.maxtime, dt0=self.timestep,#self.maxtime / self.timesteps, y0=initial_condition, - solver=diffrax.Tsit5(), + solver=diffrax.Dopri8(), args=self.args, saveat=SaveAt(ts=self.times), throw=False, # adjoint=DirectAdjoint(), - progress_meter=TqdmProgressMeter(), + progress_meter=self.progress_meter, stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.rtol, atol=self.atol), max_steps=10000000000, event = Event(self.condition) @@ -824,12 +825,12 @@ def update_state(state, _): t1=self.maxtime, dt0=self.timestep,#self.maxtime / self.timesteps, y0=initial_condition, - solver=diffrax.Tsit5(), + solver=diffrax.Dopri8(), args=self.args, saveat=SaveAt(ts=self.times), throw=False, # adjoint=DirectAdjoint(), - progress_meter=TqdmProgressMeter(), + progress_meter=self.progress_meter, stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.rtol, atol=self.atol), max_steps=10000000000, event = Event(self.condition) @@ -844,19 +845,19 @@ def update_state(state, _): t1=self.maxtime, dt0=self.timestep,#self.maxtime / self.timesteps, y0=initial_condition, - solver=diffrax.Tsit5(), + solver=diffrax.Dopri8(), args=self.args, saveat=SaveAt(ts=self.times), throw=False, # adjoint=DirectAdjoint(), - progress_meter=TqdmProgressMeter(), + progress_meter=self.progress_meter, max_steps=10000000000, event = Event(self.condition) ).ys return trajectory return jit(vmap(compute_trajectory,in_axes=(0,0)), in_shardings=(sharding,sharding_index), out_shardings=sharding)( - device_put(self.initial_conditions, sharding), device_put(self.particles.random_keys if self.particles else None, sharding_index)) + device_put(self.initial_conditions, sharding), device_put(self.particles.random_keys if self.particles else None, sharding_index)) #x=jax.device_put(self.initial_conditions, sharding) #y=jax.device_put(self.particles.random_keys, sharding_index) #sharded_fun = jax.jit(jax.shard_map(jax.vmap(compute_trajectory,in_axes=(0,0)), mesh=mesh, in_specs=(spec,spec_index), out_specs=spec)) diff --git a/essos/fields.py b/essos/fields.py index d9e28ee..4689e76 100644 --- a/essos/fields.py +++ b/essos/fields.py @@ -15,7 +15,7 @@ def __init__(self, coils): self.currents = coils.currents self.gamma = coils.gamma self.gamma_dash = coils.gamma_dash - #self.gamma_dashdash = coils.gamma_dashdash + self.gamma_dashdash = coils.gamma_dashdash self.coils_length=jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in self.gamma_dash]) self.coils_curvature= vmap(compute_curvature)(self.gamma_dash, coils.gamma_dashdash) self.r_axis=jnp.mean(jnp.sqrt(vmap(lambda dofs: dofs[0, 0]**2 + dofs[1, 0]**2)(self.coils.dofs_curves))) @@ -77,18 +77,58 @@ def kappa(self, points): def to_xyz(self, points): return points - +@jit +def d_dtheta_fft(f_theta): + ntheta = f_theta.shape[-1] + k = jnp.fft.fftfreq(ntheta, d=1.0/ntheta) # integer modes + Fk = jnp.fft.fft(f_theta, axis=-1) + dF = (1j * k) * Fk + return jnp.fft.ifft(dF, axis=-1).real * (2*jnp.pi) + +@jit +def d2_dtheta2_fft(f_theta): + ntheta = f_theta.shape[-1] + k = jnp.fft.fftfreq(ntheta, d=1.0/ntheta) # integer modes + Fk = jnp.fft.fft(f_theta, axis=-1) + d2F = -(k**2) * Fk + return jnp.fft.ifft(d2F, axis=-1).real * (2*jnp.pi)**2 + +@jit +def gamma_dash_from_gamma(gamma): + return jnp.stack([ + d_dtheta_fft(gamma[..., 0]), + d_dtheta_fft(gamma[..., 1]), + d_dtheta_fft(gamma[..., 2]), + ], axis=-1) + +@jit +def gamma_dashdash_from_gamma(gamma): + return jnp.stack([ + d2_dtheta2_fft(gamma[..., 0]), + d2_dtheta2_fft(gamma[..., 1]), + d2_dtheta2_fft(gamma[..., 2]), + ], axis=-1) class BiotSavart_from_gamma(): - def __init__(self, gamma,gamma_dash,gamma_dashdash, currents): + def __init__(self, gamma,gamma_dash=None,gamma_dashdash=None, currents=None): + if currents is None: + currents = jnp.ones(len(gamma)) + else: + currents = currents self.currents = currents self.gamma = gamma - self.gamma_dash = gamma_dash - #self.gamma_dashdash = gamma_dashdash - self.coils_length=jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in gamma_dash]) - self.coils_curvature= vmap(compute_curvature)(gamma_dash, gamma_dashdash) self.r_axis=jnp.average(jnp.linalg.norm(jnp.average(gamma,axis=1)[:,0:2],axis=1)) self.z_axis=jnp.average(jnp.average(gamma,axis=1)[:,2]) + if gamma_dash is not None: + self.gamma_dash = gamma_dash + else: + self.gamma_dash = gamma_dash_from_gamma(gamma) + self.coils_length=jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in self.gamma_dash]) + if gamma_dashdash is not None: + self.gamma_dashdash = gamma_dashdash + else: + self.gamma_dashdash = gamma_dashdash_from_gamma(gamma) + self.coils_curvature= vmap(compute_curvature)(self.gamma_dash, self.gamma_dashdash) @partial(jit, static_argnames=['self']) def sqrtg(self, points): @@ -350,8 +390,8 @@ def __init__(self, rc=jnp.array([1, 0.1]), zs=jnp.array([0, 0.1]), etabar=1.0, (self.R0, self.Z0, self.sigma, self.elongation, self.B_axis, self.grad_B_axis, self.axis_length, self.iota, self.iotaN, self.G0, self.helicity, self.X1c_untwisted, self.X1s_untwisted, self.Y1s_untwisted, self.Y1c_untwisted, self.normal_R, self.normal_phi, self.normal_z, self.binormal_R, self.binormal_phi, self.binormal_z, - self.L_grad_B, self.inv_L_grad_B, self.torsion, self.curvature) = parameters - + self.L_grad_B, self.inv_L_grad_B, self.torsion, self.curvature, self.varphi, self.R0p, self.Z0p) = parameters + @property def dofs(self): return self._dofs @@ -366,8 +406,8 @@ def dofs(self, new_dofs): (self.R0, self.Z0, self.sigma, self.elongation, self.B_axis, self.grad_B_axis, self.axis_length, self.iota, self.iotaN, self.G0, self.helicity, self.X1c_untwisted, self.X1s_untwisted, self.Y1s_untwisted, self.Y1c_untwisted, self.normal_R, self.normal_z, self.normal_phi, self.binormal_R, self.binormal_z, self.binormal_phi, - self.L_grad_B, self.inv_L_grad_B, self.torsion, self.curvature) = parameters - + self.L_grad_B, self.inv_L_grad_B, self.torsion, self.curvature, self.varphi, self.R0p, self.Z0p) = parameters + @property def x(self): return self._dofs @@ -611,7 +651,43 @@ def body_fun(i, x): return (R0, Z0, sigma, elongation, B_axis, grad_B_axis, axis_length, iota, iotaN, G0, helicity, X1c_untwisted, X1s_untwisted, Y1s_untwisted, Y1c_untwisted, normal_R, normal_phi, normal_z, binormal_R, binormal_phi, binormal_z, - L_grad_B, inv_L_grad_B, torsion, curvature) + L_grad_B, inv_L_grad_B, torsion, curvature, varphi, R0p, Z0p) + + @jit + def residual_phi0_of_theta_varphi_func(self, phi_0, r, theta, varphi): + # Residual = phi + nu - varphi = 0 + # Compute phi off axis + X_at_this_theta = r * (self.X1c_untwisted * jnp.cos(theta) + self.X1s_untwisted * jnp.sin(theta)) + Y_at_this_theta = r * (self.Y1c_untwisted * jnp.cos(theta) + self.Y1s_untwisted * jnp.sin(theta)) + _, _, phi = self.Frenet_to_cylindrical_1_point(phi_0, X_at_this_theta, Y_at_this_theta) + # phi = phi + 2 * jnp.pi * (phi < 0) - 2 * jnp.pi * (phi > 2 * jnp.pi) + # Compute nu = nu0 + r (nu1c cos theta + nu1s sin theta) + nu0 = self.interpolated_array_at_point(self.varphi-self.phi, phi_0) + X1c = self.interpolated_array_at_point(self.X1c_untwisted, phi_0) + X1s = self.interpolated_array_at_point(self.X1s_untwisted, phi_0) + Y1c = self.interpolated_array_at_point(self.Y1c_untwisted, phi_0) + Y1s = self.interpolated_array_at_point(self.Y1s_untwisted, phi_0) + bR = self.interpolated_array_at_point(self.binormal_R, phi_0) + bZ = self.interpolated_array_at_point(self.binormal_z, phi_0) + nR = self.interpolated_array_at_point(self.normal_R, phi_0) + nZ = self.interpolated_array_at_point(self.normal_z, phi_0) + R0 = self.interpolated_array_at_point(self.R0, phi_0) + R0p = self.interpolated_array_at_point(self.R0p, phi_0) + Z0p = self.interpolated_array_at_point(self.Z0p, phi_0) + nu1c = X1c * (bR * Z0p - bZ * R0p)/R0 + Y1c * (nZ * R0p - nR * Z0p)/R0 + nu1s = X1s * (bR * Z0p - bZ * R0p)/R0 + Y1s * (nZ * R0p - nR * Z0p)/R0 + nu = nu0 + r * (nu1c * jnp.cos(theta) + nu1s * jnp.sin(theta)) + # Return residual + return phi + nu - varphi + + @jit + def phi_of_theta_varphi(self, r, theta, varphi): + residual = partial(self.residual_phi0_of_theta_varphi_func, theta=theta, r=r, varphi=varphi) + phi_on_axis = lax.custom_root(residual, varphi, newton, lambda g, y: y / g(1.0)) + X_at_this_theta = r * (self.X1c_untwisted * jnp.cos(theta) + self.X1s_untwisted * jnp.sin(theta)) + Y_at_this_theta = r * (self.Y1c_untwisted * jnp.cos(theta) + self.Y1s_untwisted * jnp.sin(theta)) + _, _, phi_off_axis = self.Frenet_to_cylindrical_1_point(phi_on_axis, X_at_this_theta, Y_at_this_theta) + return phi_off_axis# + 2 * jnp.pi * (phi_off_axis < 0) - 2 * jnp.pi * (phi_off_axis > 2 * jnp.pi) @jit def interpolated_array_at_point(self,array,point): @@ -668,7 +744,7 @@ def Frenet_to_cylindrical_1_point(self, phi0, X_at_this_theta, Y_at_this_theta): return total_R, total_z, total_phi @partial(jit, static_argnames=['ntheta']) - def Frenet_to_cylindrical(self, r, ntheta=20): + def Frenet_to_cylindrical(self, r, ntheta=20, phi_is_varphi=False): nphi_conversion = self.nphi theta = jnp.linspace(0, 2 * jnp.pi, ntheta, endpoint=False) phi_conversion = jnp.linspace(0, 2 * jnp.pi / self.nfp, nphi_conversion, endpoint=False) @@ -680,9 +756,28 @@ def compute_for_theta(theta_j): Y_at_this_theta = r * (self.Y1c_untwisted * costheta + self.Y1s_untwisted * sintheta) def compute_for_phi(phi_target): - residual = partial(self.Frenet_to_cylindrical_residual_func, phi_target=phi_target, - X_at_this_theta=X_at_this_theta, Y_at_this_theta=Y_at_this_theta) + + def residual(z): + return jax.lax.cond( + phi_is_varphi, + # Branch A: solve for phi0 so that phi+nu-varphi = 0 + lambda _: self.residual_phi0_of_theta_varphi_func( + z, r=r, theta=theta_j, varphi=phi_target + ), + # Branch B: solve for phi so Frenet_to_cylindrical_residual_func = 0 + lambda _: self.Frenet_to_cylindrical_residual_func( + z, phi_target=phi_target, + X_at_this_theta=X_at_this_theta, + Y_at_this_theta=Y_at_this_theta + ), + operand=None + ) + # residual = partial(self.Frenet_to_cylindrical_residual_func, phi_target=phi_target, + # X_at_this_theta=X_at_this_theta, Y_at_this_theta=Y_at_this_theta) + # residual = partial(self.residual_phi0_of_theta_varphi_func, theta=theta_j, r=r, varphi=phi_target) + phi0_solution = lax.custom_root(residual, phi_target, newton, lambda g, y: y / g(1.0)) + final_R, final_Z, _ = self.Frenet_to_cylindrical_1_point(phi0_solution, X_at_this_theta, Y_at_this_theta) return final_R, final_Z, phi0_solution @@ -727,18 +822,33 @@ def compute_RBC_ZBS(m, n): RBC = RBC.at[:ntor, 0].set(0) return RBC, ZBS - - @partial(jit, static_argnames=['ntheta_fourier', 'mpol', 'ntor', 'ntheta', 'nphi']) - def get_boundary(self, r=0.1, ntheta=30, nphi=120, ntheta_fourier=20, mpol=5, ntor=5): - R_2D, Z_2D, _ = self.Frenet_to_cylindrical(r, ntheta=ntheta_fourier) + @partial(jit, static_argnames=['ntheta_fourier', 'mpol', 'ntor', 'ntheta', 'nphi', 'phi_is_varphi']) + def get_boundary(self, r=0.1, ntheta=30, nphi=120, ntheta_fourier=20, mpol=5, ntor=5, phi_is_varphi=False, phi_offset=0.0): + R_2D, Z_2D, _ = self.Frenet_to_cylindrical(r, ntheta=ntheta_fourier, phi_is_varphi=phi_is_varphi) RBC, ZBS = self.to_Fourier(R_2D, Z_2D, self.nfp, mpol=mpol, ntor=ntor) theta1D = jnp.linspace(0, 2 * jnp.pi, ntheta) - phi1D = jnp.linspace(0, 2 * jnp.pi, nphi) - phi2D, theta2D = jnp.meshgrid(phi1D, theta1D, indexing='ij') - + + # phi1D = jax.lax.cond( + # phi_is_varphi, + # lambda _: jnp.linspace(2*jnp.pi/nphi/2, 2*jnp.pi + 2*jnp.pi/nphi/2, nphi, endpoint=False), + # lambda _: jnp.linspace(0, 2 * jnp.pi, nphi), + # operand=None + # ) + # phi1D += phi_offset + phi1D = jnp.linspace(0, 2 * jnp.pi, nphi) + phi_offset + + phi2D_original, theta2D = jnp.meshgrid(phi1D, theta1D, indexing='ij') + + phi2D = jax.lax.cond( + phi_is_varphi, + lambda _: vmap(lambda theta_row, varphi_row: vmap(lambda theta, varphi: self.phi_of_theta_varphi(r, theta, varphi))(theta_row, varphi_row))(theta2D, phi2D_original), + lambda _: phi2D_original, + operand=None + ) + def compute_RZ(m, n): - angle = m * theta2D - n * self.nfp * phi2D + angle = m * theta2D - n * self.nfp * phi2D_original return RBC[n + ntor, m] * jnp.cos(angle), ZBS[n + ntor, m] * jnp.sin(angle) m_vals = jnp.arange(mpol + 1) @@ -747,8 +857,8 @@ def compute_RZ(m, n): R_2Dnew, Z_2Dnew = vmap(lambda m: vmap(lambda n: compute_RZ(m, n))(n_vals))(m_vals) R_2Dnew, Z_2Dnew = R_2Dnew.sum(axis=(0, 1)), Z_2Dnew.sum(axis=(0, 1)) - x_2D_plot = R_2Dnew.T * jnp.cos(phi1D) - y_2D_plot = R_2Dnew.T * jnp.sin(phi1D) + x_2D_plot = R_2Dnew.T * jnp.cos(phi2D.T) + y_2D_plot = R_2Dnew.T * jnp.sin(phi2D.T) z_2D_plot = Z_2Dnew.T return x_2D_plot, y_2D_plot, z_2D_plot, R_2Dnew.T diff --git a/examples/coils_from_BOOZ_XFORM.py b/examples/coils_from_BOOZ_XFORM.py new file mode 100644 index 0000000..4691848 --- /dev/null +++ b/examples/coils_from_BOOZ_XFORM.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3.11 +import os +number_of_processors_to_use = 6 # Parallelization, this should divide nfieldlines +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +import numpy as np +from time import time +import booz_xform as bx +import plotly.graph_objects as go +from essos.dynamics import Tracing +from essos.fields import BiotSavart +from simsopt.mhd import Vmec, Boozer +from jax import block_until_ready +import jax.numpy as jnp +from essos.coils import fit_dofs_from_coils, Curves, Coils +import matplotlib.pyplot as plt + +file_to_use = 'LandremanPaul2021_QA_reactorScale_lowres' + +ntheta = 41 +ncoils = 6 +tmax = 1100 +nfieldlines_per_core=1 +trace_tolerance = 1e-9 +num_steps = 22000 +order_Fourier_coils = 4 +current_on_each_coil = 2e8 +refine_nphi_for_surface_plot = 4 +radial_extension_of_the_surface = 0.0 +Poincare_plot_phi = jnp.array([0]) +shift_surface_plot_for_phi = jnp.pi +plot_fieldlines_constant_phi = False +show_coils_fitted_to_Fourier = False + +input_dir = os.path.join(os.path.dirname(__file__), 'input_files') +output_dir = os.path.join(os.path.dirname(__file__), 'output_files') +os.makedirs(output_dir, exist_ok=True) + +wout_filename = os.path.join(input_dir, 'wout_'+file_to_use+'.nc') +boozmn_filename = os.path.join(output_dir, 'boozmn_'+file_to_use+'.nc') + +print(f"Computing {boozmn_filename}") +vmec = Vmec(wout_filename, verbose=False) +b = Boozer(vmec, mpol=64, ntor=64, verbose=True) +time0 = time() +b.register([1]) +b.run() +# b.bx.write_boozmn(boozmn_filename) +b = b.bx +print(f"Computing Boozer harmonics took {time()-time0:.2f} seconds") + +current_on_each_coil = current_on_each_coil / ncoils*vmec.wout.Aminor_p**2/1.7**2 +nfieldlines = number_of_processors_to_use*nfieldlines_per_core +nphi = ncoils * 2 * b.nfp + +theta1D = np.linspace(0, 2 * np.pi, ntheta) +phi1D = jnp.linspace(2*jnp.pi/nphi/2, 2*jnp.pi + 2*jnp.pi/nphi/2, nphi, endpoint=False) +phi1D_surface = jnp.linspace(0, 2*jnp.pi, nphi*refine_nphi_for_surface_plot, endpoint=True) +varphi, theta = np.meshgrid(phi1D, theta1D) +varphi_surface, theta_surface = np.meshgrid(phi1D_surface, theta1D) + +R = np.zeros_like(theta) +R_surface = np.zeros_like(theta_surface) +Z = np.zeros_like(theta) +Z_surface = np.zeros_like(theta_surface) +nu = np.zeros_like(theta) +d_R_d_theta = np.zeros_like(theta) +d_R_d_theta_surface = np.zeros_like(theta_surface) +d_Z_d_theta = np.zeros_like(theta) +d_Z_d_theta_surface = np.zeros_like(theta_surface) + +phi1D_Boozerplot = np.linspace(0, 2 * np.pi / b.nfp / 2, nphi*refine_nphi_for_surface_plot) +phi_Boozerplot, theta_Boozerplot = np.meshgrid(phi1D_Boozerplot, theta1D) +modB_Boozerplot = np.zeros_like(theta_Boozerplot) + +js = None +for jmn in range(b.mnboz): + m = b.xm_b[jmn] + n = b.xn_b[jmn] + angle = m * theta - n * varphi + angle_surface = m * theta_surface - n * varphi_surface + sinangle = np.sin(angle) + sinangle_surface = np.sin(angle_surface) + cosangle = np.cos(angle) + cosangle_surface = np.cos(angle_surface) + R += b.rmnc_b[jmn, js] * cosangle + R_surface += b.rmnc_b[jmn, js] * cosangle_surface + Z += b.zmns_b[jmn, js] * sinangle + Z_surface += b.zmns_b[jmn, js] * sinangle_surface + nu += b.numns_b[jmn, js] * sinangle + d_R_d_theta += -m * b.rmnc_b[jmn, js] * sinangle + d_R_d_theta_surface += -m * b.rmnc_b[jmn, js] * sinangle_surface + d_Z_d_theta += m * b.zmns_b[jmn, js] * cosangle + d_Z_d_theta_surface += m * b.zmns_b[jmn, js] * cosangle_surface + cosangle_Boozerplot = np.cos(m * theta_Boozerplot - n * phi_Boozerplot) + modB_Boozerplot += b.bmnc_b[jmn, js] * np.cos(cosangle_Boozerplot) + +denom = np.sqrt(d_R_d_theta * d_R_d_theta + d_Z_d_theta * d_Z_d_theta) +denom_surface = np.sqrt(d_R_d_theta_surface * d_R_d_theta_surface + d_Z_d_theta_surface * d_Z_d_theta_surface) +R = R - radial_extension_of_the_surface * (d_Z_d_theta / denom) +R_surface = R_surface - radial_extension_of_the_surface * (d_Z_d_theta_surface / denom_surface) +Z = Z + radial_extension_of_the_surface * (d_R_d_theta / denom) +Z_surface = Z_surface + radial_extension_of_the_surface * (d_R_d_theta_surface / denom_surface) + +# Following the sign convention in the code, to convert from the +# Boozer toroidal angle to the standard toroidal angle, we +# *subtract* nu: +phi = varphi - nu +X = R * np.cos(phi) +Y = R * np.sin(phi) + +coils_gamma = np.zeros((ncoils, ntheta, 3)) +for i in range(ncoils): + coils_gamma[i, :, 0] = X[:, i] + coils_gamma[i, :, 1] = Y[:, i] + coils_gamma[i, :, 2] = Z[:, i] + +time0 = time() +dofs, gamma_uni = fit_dofs_from_coils(coils_gamma[:ncoils], order=order_Fourier_coils, n_segments=ntheta, assume_uniform=True) +curves = Curves(dofs=dofs, n_segments=ntheta, nfp=b.nfp, stellsym=True) +coils = Coils(curves=curves, currents=[-current_on_each_coil]*(ncoils)) +field_coils_DOFS = BiotSavart(coils) +print(f"Fitting coils took {time()-time0:.2f} seconds") + +data=[] + +color = "#C5B6A7" +# Hack to get a uniform surface color: +colorscale = [[0, color], [1, color]] +Xsurf = R_surface * np.cos(phi1D_surface) +Ysurf = R_surface * np.sin(phi1D_surface) +data.append(go.Surface(x=Xsurf, y=Ysurf, z=Z_surface, + colorscale=colorscale, + opacity=0.3, + showscale=False, # Turns off colorbar + lighting={"specular": 0.3, "diffuse":0.9})) + +line_width = 12 +line_marker = dict(color="#5B2222", width=line_width) +index = 0 +index = 0 +for i, j, k in zip(X.T, Y.T, Z.T): + index += 1 + showlegend = False + data.append(go.Scatter3d(x=i, y=j, z=k, mode='lines', line=line_marker, showlegend=showlegend))#, name=r'Constant $\varphi$ contours')) + +if show_coils_fitted_to_Fourier: + line_marker = dict(color='blue', width=line_width) + gamma_coils = np.transpose(curves.gamma, (1, 0 , 2)) + index = 0 + index = 0 + for i, j, k in zip(gamma_coils[:, :, 0].T, gamma_coils[:, :, 1].T, gamma_coils[:, :, 2].T): + index += 1 + showlegend = False + data.append(go.Scatter3d(x=i, y=j, z=k, mode='lines', line=line_marker, showlegend=showlegend, name='Coils fitted to Fourier')) + +if plot_fieldlines_constant_phi: + js_phi = b.compute_surfs[js] + R_phi = np.zeros_like(theta) + Z_phi = np.zeros_like(theta) + phi1D_phi = jnp.linspace(2*jnp.pi/nphi/2, 2*jnp.pi + 2*jnp.pi/nphi/2, nphi, endpoint=False) + phi_phi, _ = np.meshgrid(phi1D_phi, theta1D) + + for jmn in range(b.mnmax): + angle = b.xm[jmn] * theta - b.xn[jmn] * phi_phi + sinangle = np.sin(angle) + cosangle = np.cos(angle) + R_phi += b.rmnc[jmn, js_phi] * cosangle + Z_phi += b.zmns[jmn, js_phi] * sinangle + + X_phi = R_phi * np.cos(phi_phi) + Y_phi = R_phi * np.sin(phi_phi) + line_marker = dict(color='green', width=line_width) + + index = 0 + for i, j, k in zip(X_phi.T, Y_phi.T, Z_phi.T): + index += 1 + showlegend = False + data.append(go.Scatter3d(x=i, y=j, z=k, mode='lines', line=line_marker, showlegend=showlegend, name=r"Constant $\phi$ contours")) + + coils_gamma_phi = np.zeros((ncoils, ntheta, 3)) + for i in range(ncoils): + coils_gamma_phi[i, :, 0] = X_phi[:, i] + coils_gamma_phi[i, :, 1] = Y_phi[:, i] + coils_gamma_phi[i, :, 2] = Z_phi[:, i] + + time0 = time() + dofs_phi, gamma_uni_phi = fit_dofs_from_coils(coils_gamma_phi[:ncoils], order=order_Fourier_coils, n_segments=ntheta, assume_uniform=True) + curves_phi = Curves(dofs=dofs_phi, n_segments=ntheta, nfp=b.nfp, stellsym=True) + coils_phi = Coils(curves=curves_phi, currents=[-current_on_each_coil]*(ncoils)) + field_coils_phi = BiotSavart(coils_phi) + print(f"Fitting coils took {time()-time0:.2f} seconds") + +R0 = jnp.linspace(sum(vmec.wout.rmnc)[0], sum(vmec.wout.rmnc)[-1], nfieldlines) +Z0 = jnp.zeros(nfieldlines) +phi0 = jnp.zeros(nfieldlines) +initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T + +time0 = time() +tracing_coils_DOFS = block_until_ready(Tracing(field=field_coils_DOFS, model='FieldLineAdaptative', initial_conditions=initial_xyz, + maxtime=tmax, times_to_trace=num_steps, atol=trace_tolerance,rtol=trace_tolerance)) +print(f"ESSOS tracing coils_DOFS took {time()-time0:.2f} seconds") +trajectories_coils_DOFS = tracing_coils_DOFS.trajectories +if plot_fieldlines_constant_phi: + time0 = time() + tracing_coils_phi = block_until_ready(Tracing(field=field_coils_phi, model='FieldLineAdaptative', initial_conditions=initial_xyz, + maxtime=tmax, times_to_trace=num_steps, atol=trace_tolerance,rtol=trace_tolerance)) + print(f"ESSOS tracing coils_phi took {time()-time0:.2f} seconds") + trajectories_coils_phi = tracing_coils_phi.trajectories + +# Add fieldline traces from fitted coils +for traj in trajectories_coils_DOFS: + data.append(go.Scatter3d( + x=traj[:, 0], + y=traj[:, 1], + z=traj[:, 2], + mode='lines', + line=dict(color='black', width=0.2), + opacity=1.0, + name='Fieldline constant Boozer coils', + showlegend=False# if traj is not trajectories_coils_DOFS[0] else True + )) + +# Add fieldlines from phi coils +if plot_fieldlines_constant_phi: + for traj in trajectories_coils_phi: + data.append(go.Scatter3d( + x=traj[:, 0], + y=traj[:, 1], + z=traj[:, 2], + mode='lines', + line=dict(color='blue', width=0.2), + opacity=1.0, + name='Fieldline constant phi coils', + showlegend=False# if traj is not trajectories_coils_phi[0] else True + )) + +fig = go.Figure(data=data) + +# Turn off hover contours on the surface: +fig.update_traces(contours_x_highlight=False, + contours_y_highlight=False, + contours_z_highlight=False, + selector={"type":"surface"}) + +# Make x, y, z coordinate scales equal, and turn off more hover stuff +fig.update_layout(scene={"aspectmode": "data", + "xaxis_showspikes": False, + "yaxis_showspikes": False, + "zaxis_showspikes": False, + "xaxis_visible": False, + "yaxis_visible": False, + "zaxis_visible": False}, + hovermode=False, + margin={"l":0, "r":0, "t":25, "b":0}, + # title="Curves of constant poloidal or toroidal angle" + ) + +fig.show() + +# Now plot the 2D Poincare plot with Matplotlib (ax2 only) +fig2 = plt.figure(figsize=(6, 5)) +ax2 = fig2.add_subplot(111) + +tracing_coils_DOFS.poincare_plot(ax=ax2, show=False, shifts=Poincare_plot_phi/b.nfp/2, color='b', s=0.15) +if plot_fieldlines_constant_phi: + tracing_coils_phi.poincare_plot(ax=ax2, show=False, shifts=Poincare_plot_phi/b.nfp/2, color='r', s=0.15) + +Rsurf_phi0 = np.array([0.0]*ntheta) +Zsurf_phi0 = np.array([0.0]*ntheta) +for jmn in range(b.mnboz): + Rsurf_phi0 += (b.rmnc_b[jmn, js] * np.cos(b.xm_b[jmn] * theta1D - b.xn_b[jmn] * shift_surface_plot_for_phi))[0] + Zsurf_phi0 += (b.zmns_b[jmn, js] * np.sin(b.xm_b[jmn] * theta1D - b.xn_b[jmn] * shift_surface_plot_for_phi))[0] +ax2.plot(Rsurf_phi0, Zsurf_phi0, color='black', alpha=1.0, linewidth=2, label='Surface of Constant Boozer Angle') +ax2.set_xlabel('R (m)') +ax2.set_ylabel('Z (m)') +ax2.plot([], [], color='blue', label='Fieldlines') +if plot_fieldlines_constant_phi: + ax2.plot([], [], color='red', label='Fieldlines (constant phi)') + +# Plot VMEC flux surfaces for reference +# Match the surfaces of VMEC closest to the radii of the fieldlines traced +s_fieldlines = (jnp.linspace(sum(vmec.wout.rmnc)[0], sum(vmec.wout.rmnc)[-1], nfieldlines) - sum(vmec.wout.rmnc)[0])/ \ + (sum(vmec.wout.rmnc)[-1] - sum(vmec.wout.rmnc)[0]) +s_vmec = jnp.sqrt(jnp.linspace(0, 1, vmec.wout.ns)) +iradii = np.array([np.abs(s_vmec - s).argmin() for s in s_fieldlines]) +for iradius in range(nfieldlines): + R = [0]*ntheta + Z = [0]*ntheta + for imode, xnn in enumerate(vmec.wout.xn): + angle = vmec.wout.xm[imode]*theta1D - xnn*shift_surface_plot_for_phi + R += vmec.wout.rmnc[imode, iradii[iradius]]*np.cos(angle) + Z += vmec.wout.zmns[imode, iradii[iradius]]*np.sin(angle) + ax2.plot(R, Z, 'r--', linewidth=1.5, label='Surfaces of Constant Cylindrical Angle' if iradius ==0 else '_nolegend_') +ax2.legend() +plt.tight_layout() + +fig = plt.figure() +plt.contourf(phi_Boozerplot, theta_Boozerplot, modB_Boozerplot, levels=6) +plt.xlabel(r'Boozer toroidal angle $\varphi$') +plt.ylabel(r'Boozer poloidal angle $\theta$') +for i in range(ncoils): + plt.axvline(x=phi1D[i], color='black', linewidth=2.5) +plt.colorbar(label='|B| (T)') + +plt.savefig('modB_Boozerplot.png', dpi=300) \ No newline at end of file diff --git a/examples/coils_from_nearaxis.py b/examples/coils_from_nearaxis.py new file mode 100644 index 0000000..8518d23 --- /dev/null +++ b/examples/coils_from_nearaxis.py @@ -0,0 +1,218 @@ +import os +number_of_processors_to_use = 6 # Parallelization, this should divide nfieldlines +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +import jax.numpy as jnp +from jax import block_until_ready, vmap +import matplotlib.pyplot as plt +from essos.fields import near_axis, BiotSavart_from_gamma, BiotSavart +import plotly.graph_objects as go +from essos.dynamics import Tracing +from essos.coils import fit_dofs_from_coils, Curves, Coils +from time import time + +# Initialize Near-Axis field +rc = jnp.array([1, 0.045]) +zs = jnp.array([0, -0.045]) +etabar = -0.9 +nfp = 3 +nphi_internal_pyQSC = 51 +r_coils = 0.4 +r_surface = 0.2 +ntheta = 41 +ncoils = 4 +tmax = 800 +nfieldlines_per_core=1 +trace_tolerance = 1e-8 +num_steps = 22000 +order = 4 +current_on_each_coil = 2e8 +Poincare_plot_phi = jnp.array([0]) +plot_coils_without_Fourier_fit = False +plot_coils_on_2D = False +plot_difference_varphi_phi = False + +field_nearaxis = near_axis(rc=rc, zs=zs, etabar=etabar, nfp=nfp, nphi=nphi_internal_pyQSC) + +nfieldlines = number_of_processors_to_use*nfieldlines_per_core +current_on_each_coil = current_on_each_coil / ncoils*r_surface**2/1.7**2 +r_array = jnp.linspace(1e-5, r_surface, nfieldlines) +n_segments = ntheta + +print(f"Starting to create surfaces and coils for {nfieldlines} fieldlines and {ncoils} coils...") +time0 = time() +r_array = jnp.linspace(1e-5, r_surface, nfieldlines) +results = [field_nearaxis.get_boundary(r=r, ntheta=ntheta, nphi=nphi_internal_pyQSC) for r in r_array] +x_2D_surface_array, y_2D_surface_array, z_2D_surface_array, R_2D_surface_array = map(lambda arr: jnp.stack(arr), zip(*results)) +print(f"Creating surfaces of constant phi took {time()-time0:.2f} seconds") +time0 = time() +nphi = ncoils * 2 * nfp +x_2D_coils, y_2D_coils, z_2D_coils, R_2D_coils = field_nearaxis.get_boundary(r=r_coils, ntheta=ntheta, nphi=nphi, phi_is_varphi=True, phi_offset = 2*jnp.pi/nphi/2) +print(f"Creating surfaces of constant varphi took {time()-time0:.2f} seconds") + +time0 = time() +coils_gamma = jnp.zeros((ncoils * 2 * nfp, ntheta, 3)) +coil_i = 0 +for n in range(2*nfp): + phi_vals = (jnp.arange(ncoils) + 0.5) * (2 * jnp.pi) / ((2) * nfp * ncoils) + 2*jnp.pi/(2*nfp)*n + phi_idx = (phi_vals / (2*jnp.pi) * nphi).astype(int) % nphi + for i in phi_idx: + loop = jnp.stack([x_2D_coils[:, i], y_2D_coils[:, i], z_2D_coils[:, i]], axis=-1) + coils_gamma = coils_gamma.at[coil_i].set(loop) + coil_i += 1 +print(f"Creating coils_gamma took {time()-time0:.2f} seconds for {ncoils*2*nfp} coils") + + +time0 = time() +dofs, gamma_uni = fit_dofs_from_coils(coils_gamma[:ncoils], order=order, n_segments=n_segments, assume_uniform=True) +curves = Curves(dofs=dofs, n_segments=n_segments, nfp=nfp, stellsym=True) +coils = Coils(curves=curves, currents=[-current_on_each_coil]*(ncoils)) +field_coils_DOFS = BiotSavart(coils) +print(f"Fitting coils took {time()-time0:.2f} seconds") + +R0 = R_2D_surface_array[:,0,0] +Z0 = jnp.zeros(nfieldlines) +phi0 = jnp.zeros(nfieldlines) +initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T +if plot_coils_without_Fourier_fit: + time0 = time() + field_coils_gamma = BiotSavart_from_gamma(coils_gamma, currents=current_on_each_coil*jnp.ones(len(coils_gamma))) + tracing_coils_gamma = block_until_ready(Tracing(field=field_coils_gamma, model='FieldLineAdaptative', initial_conditions=initial_xyz, + maxtime=tmax, times_to_trace=num_steps, atol=trace_tolerance,rtol=trace_tolerance)) + print(f"ESSOS tracing coils_gamma took {time()-time0:.2f} seconds") + trajectories_coils_gamma = tracing_coils_gamma.trajectories +time0 = time() +tracing_coils_DOFS = block_until_ready(Tracing(field=field_coils_DOFS, model='FieldLineAdaptative', initial_conditions=initial_xyz, + maxtime=tmax, times_to_trace=num_steps, atol=trace_tolerance,rtol=trace_tolerance)) +print(f"ESSOS tracing coils_DOFS took {time()-time0:.2f} seconds") +trajectories_coils_DOFS = tracing_coils_DOFS.trajectories + +fig_plotly = go.Figure() + +color = "#C2AC95" +colorscale = [[0, color], [1, color]] +fig_plotly.add_surface( + x=x_2D_surface_array[-1], + y=y_2D_surface_array[-1], + z=z_2D_surface_array[-1], + opacity=0.3, + colorscale=colorscale, + showscale=False, + name='Surface', + lighting={"specular": 0.3, "diffuse":0.9}, + showlegend=False#True, +) + +if plot_coils_without_Fourier_fit: + for coil in coils_gamma: + fig_plotly.add_trace(go.Scatter3d( + x=coil[:, 0], + y=coil[:, 1], + z=coil[:, 2], + mode='lines', + line=dict(width=10, color='#b87333'), + name='Coil (Near-Axis)', + showlegend=False# if coil is not coils_gamma[0] else True + )) + + +line_width = 12 +line_marker = dict(color="#5B2222", width=line_width) +for i, curve_gamma in enumerate(curves.gamma): + color = "#93785A" + fig_plotly.add_trace(go.Scatter3d( + x=curve_gamma[:, 0], + y=curve_gamma[:, 1], + z=curve_gamma[:, 2], + mode='lines', + line=line_marker, + name='Coils', + showlegend=False#(i==0) + )) + +if plot_coils_without_Fourier_fit: + for traj in trajectories_coils_gamma: + fig_plotly.add_trace(go.Scatter3d( + x=traj[:, 0], + y=traj[:, 1], + z=traj[:, 2], + mode='lines', + line=dict(color='black', width=2), + name='Fieldline', + showlegend=False# if traj is not trajectories_coils_gamma[0] else True + )) + +for traj in trajectories_coils_DOFS: + fig_plotly.add_trace(go.Scatter3d( + x=traj[:, 0], + y=traj[:, 1], + z=traj[:, 2], + mode='lines', + line=dict(color='black', width=0.2), + name='Fieldline (Fitted Coils)', + showlegend=False# if traj is not trajectories_coils_DOFS[0] else True + )) + +# Turn off hover contours on the surface: +fig_plotly.update_traces(contours_x_highlight=False, + contours_y_highlight=False, + contours_z_highlight=False, + selector={"type":"surface"}) + +# Make x, y, z coordinate scales equal, and turn off more hover stuff +fig_plotly.update_layout(scene={"aspectmode": "data", + "xaxis_showspikes": False, + "yaxis_showspikes": False, + "zaxis_showspikes": False, + "xaxis_visible": False, + "yaxis_visible": False, + "zaxis_visible": False}, + hovermode=False, + margin={"l":0, "r":0, "t":25, "b":0}, + ) + +fig_plotly.show() + +# Now plot the 2D Poincare plot with Matplotlib +fig2 = plt.figure(figsize=(6, 5)) +ax = fig2.add_subplot(111) +if plot_coils_without_Fourier_fit: + tracing_coils_gamma.poincare_plot(ax=ax, show=False, shifts=Poincare_plot_phi/nfp/2, color='k', s=0.05) +tracing_coils_DOFS.poincare_plot(ax=ax, show=False, shifts=Poincare_plot_phi/nfp/2, color='b', s=0.05) + +for i in range(nfieldlines): + ax.plot(R_2D_surface_array[i,:,0], z_2D_surface_array[i,:,0], 'r--', linewidth=1.5, label='Surfaces of Constant Cylindrical Angle' if i==0 else '_nolegend_') +_, _, z_2D_at_coils, R_2D_at_coils = field_nearaxis.get_boundary(r=r_coils, ntheta=ntheta, nphi=nphi) +ax.plot(R_2D_at_coils[:,0], z_2D_at_coils[:,0], 'r--', linewidth=1.5, label='_nolegend_') + +x_2D_coil0, y_2D_coil0, z_2D_coil0, R_2D_coil0 = field_nearaxis.get_boundary(r=r_coils, ntheta=ntheta, nphi=nphi, phi_is_varphi=True) +ax.plot(R_2D_coil0[:,0], z_2D_coil0[:,0], color='black', alpha=1.0, linewidth=2, label='Surface of Constant Boozer Angle') +if plot_coils_on_2D: + for coil_number in range(ncoils): + if plot_coils_without_Fourier_fit: + R_coils_gamma = jnp.sqrt(coils_gamma[coil_number,:,0]**2 + coils_gamma[coil_number,:,1]**2) + ax.plot(R_coils_gamma, coils_gamma[coil_number,:,2], color='#b87333', linewidth=2, label='Coils from Near-Axis' if coil_number==0 else '_nolegend_') + R_curve = jnp.sqrt(curves.gamma[coil_number,:,0]**2 + curves.gamma[coil_number,:,1]**2) + ax.plot(R_curve, curves.gamma[coil_number,:,2], '-', color='blue', linewidth=2, label='Coils' if coil_number==0 else '_nolegend_') +if plot_coils_without_Fourier_fit: + ax.plot([], [], color='k', label='Fieldlines from Coils from Near-Axis') +ax.plot([], [], color='b', label='Fieldlines') +ax.set_xlabel('R (m)') +ax.set_ylabel('Z (m)') +ax.legend() +plt.tight_layout() + +if plot_difference_varphi_phi: + itheta = 0 # ntheta // 2 + theta1D = jnp.linspace(0, 2 * jnp.pi, ntheta) + varphi1D = jnp.linspace(0, 2 * jnp.pi / nfp, nphi) + varphi2D, theta2D = jnp.meshgrid(varphi1D, theta1D, indexing='ij') + phi2D = vmap(lambda theta_row, varphi_row: vmap(lambda theta, varphi: field_nearaxis.phi_of_theta_varphi(r_coils, theta, varphi))(theta_row, varphi_row))(theta2D, varphi2D) + plt.figure(figsize=(8,6)) + plt.plot(varphi2D[:,itheta], label='varphi') + plt.plot(phi2D[:,itheta], label='phi') + plt.legend() + plt.title(f'Conversion from varphi to phi at theta={itheta}') + plt.grid() + plt.tight_layout() + +plt.show() \ No newline at end of file diff --git a/examples/optimize_coils_and_nearaxis.py b/examples/optimize_coils_and_nearaxis.py index edb96be..00fa24e 100644 --- a/examples/optimize_coils_and_nearaxis.py +++ b/examples/optimize_coils_and_nearaxis.py @@ -1,3 +1,6 @@ +import os +number_of_processors_to_use = 4 # Parallelization, this should divide nfieldlines +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' from time import time import jax.numpy as jnp import matplotlib.pyplot as plt diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_adam.py b/examples/optimize_coils_particle_confinement_guidingcenter_adam.py index 80c2b8d..f5e66ab 100644 --- a/examples/optimize_coils_particle_confinement_guidingcenter_adam.py +++ b/examples/optimize_coils_particle_confinement_guidingcenter_adam.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt from essos.dynamics import Particles, Tracing from essos.coils import Coils, CreateEquallySpacedCurves,Curves -from essos.objective_functions import loss_particle_r_cross_max +from essos.objective_functions import loss_particle_r_cross_max_constraint from essos.objective_functions import loss_coil_curvature,loss_coil_length, loss_normB_axis_average from functools import partial import optax @@ -55,7 +55,7 @@ curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) -r_max_partial = partial(loss_particle_r_cross_max, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model = model,num_steps=num_steps) +r_max_partial = partial(loss_particle_r_cross_max_constraint, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model = model,num_steps=num_steps) params=coils_initial.x optimizer=optax.adabelief(learning_rate=0.003) @@ -111,8 +111,7 @@ def update(params,opt_state): ax4.set_ylabel('Z (m)')#ax4.legend() plt.tight_layout() # plt.savefig(f'opt_adam.pdf') -plt.show() - +plt.savefig('optimize_coils_particle_confinement_guidingcenter_adam.png', dpi=300) # # Save the coils to a json file # coils_optimized.to_json("stellarator_coils.json") # # Load the coils from a json file diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_adam_constrained.py b/examples/optimize_coils_particle_confinement_guidingcenter_adam_constrained.py deleted file mode 100644 index b878eb5..0000000 --- a/examples/optimize_coils_particle_confinement_guidingcenter_adam_constrained.py +++ /dev/null @@ -1,156 +0,0 @@ - -import os -number_of_processors_to_use = 1 # Parallelization, this should divide nparticles -os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' -from time import time -import jax -print(jax.devices()) -jax.config.update("jax_enable_x64", True) -import jax.numpy as jnp -import matplotlib.pyplot as plt -from essos.dynamics import Particles, Tracing -from essos.coils import Coils, CreateEquallySpacedCurves,Curves -from essos.optimization import optimize_loss_function -from essos.objective_functions import loss_particle_r_cross_final_new,loss_particle_r_cross_max,loss_particle_radial_drift,loss_particle_gamma_c -from essos.objective_functions import loss_coil_curvature,loss_coil_length,loss_normB_axis,loss_normB_axis_average -from functools import partial -import essos.alm_convex as alm -import optax - - -# Optimization parameters -target_B_on_axis = 5.7 -max_coil_length = 31 -max_coil_curvature = 0.4 -nparticles = number_of_processors_to_use*10 -order_Fourier_series_coils = 4 -number_coil_points = 80 -maximum_function_evaluations = 30 -maxtimes = [1.e-5] -num_steps=100 -number_coils_per_half_field_period = 3 -number_of_field_periods = 2 -model = 'GuidingCenterAdaptative' - -# Initialize coils -current_on_each_coil = 1.84e7 -major_radius_coils = 7.75 -minor_radius_coils = 4.45 -curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, - order=order_Fourier_series_coils, - R=major_radius_coils, r=minor_radius_coils, - n_segments=number_coil_points, - nfp=number_of_field_periods, stellsym=True) -coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) - -len_dofs_curves = len(jnp.ravel(coils_initial.dofs_curves)) -nfp = coils_initial.nfp -stellsym = coils_initial.stellsym -n_segments = coils_initial.n_segments -dofs_curves_shape = coils_initial.dofs_curves.shape -currents_scale = coils_initial.currents_scale - -# Initialize particles -phi_array = jnp.linspace(0, 2*jnp.pi, nparticles) -initial_xyz=jnp.array([major_radius_coils*jnp.cos(phi_array), major_radius_coils*jnp.sin(phi_array), 0*phi_array]).T -particles = Particles(initial_xyz=initial_xyz) - -t=maxtimes[0] -loss_partial = partial(loss_particle_gamma_c,particles=particles, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) -curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) -length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) -Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) -r_max_partial = partial(loss_particle_r_cross_max, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) - - -# Create the constraints -penalty = 1.05 #Intial penalty values -multiplier=1.0 #Initial lagrange multiplier values -sq_grad=0.0 #Initial square gradient parameter value for Mu adaptative -constraints = alm.combine( -alm.eq(curvature_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -alm.eq(length_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -alm.eq(Baxis_average_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -#alm.eq(r_max_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -) - - - -model_lagrange='Mu_Tolerance' #Options: Mu_Constant, Mu_Monotonic, Mu_Conditional,Mu_Adaptative -beta=2. #penalty update parameter -mu_max=1.e4 #Maximum penalty parameter allowed -alpha=0.99 # -gamma=1.e-2 -epsilon=1.e-8 -omega_tol=1. #grad_tolerance, associated with grad of lagrangian to main parameters -eta_tol=1.e-6 #contrained tolerances, associated with variation of contraints -optimizer=optax.adabelief(learning_rate=0.003,nesterov=True) - - -ALM=alm.ALM_model(optimizer,constraints,model_lagrange=model_lagrange,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) - -lagrange_params=constraints.init(coils_initial.x) -params = coils_initial.x, lagrange_params -opt_state,grad,info=ALM.init(params) -mu_average=alm.penalty_average(lagrange_params) -#omega=1.#1./mu_average -#eta=1000.#1./mu_average**0.1 -omega=1./mu_average -eta=1./mu_average**0.1 - -i=0 -while i<=maximum_function_evaluations and (jnp.linalg.norm(grad[0])>omega_tol or alm.norm_constraints(info[2])>eta_tol): - params, opt_state,grad,info,eta,omega = ALM.update(params,opt_state,grad,info,eta,omega) #One step of ALM optimization - #if i % 5 == 0: - #print(f'i: {i}, loss f: {info[0]:g}, infeasibility: {alm.total_infeasibility(info[1]):g}') - print(f'i: {i}, loss f: {info[0]:g},loss L: {info[1]:g}, infeasibility: {alm.total_infeasibility(info[2]):g}') - print('lagrange',params[1]) - i=i+1 - - -dofs_curves = jnp.reshape(params[0][:len_dofs_curves], (dofs_curves_shape)) -dofs_currents = params[0][len_dofs_curves:] -curves = Curves(dofs_curves, n_segments, nfp, stellsym) -new_coils = Coils(curves=curves, currents=dofs_currents*coils_initial.currents_scale) -params=new_coils.x -tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=t, model=model - ,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) -tracing_optimized = Tracing(field=new_coils, particles=particles, maxtime=t, model=model,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) - -#print('Final params',params) -#print(info[1]) -# Plot trajectories, before and after optimization -fig = plt.figure(figsize=(9, 8)) -ax1 = fig.add_subplot(221, projection='3d') -ax2 = fig.add_subplot(222, projection='3d') -ax3 = fig.add_subplot(223) -ax4 = fig.add_subplot(224) - -coils_initial.plot(ax=ax1, show=False) -tracing_initial.plot(ax=ax1, show=False) -for i, trajectory in enumerate(tracing_initial.trajectories): - ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') - -ax3.set_xlabel('R (m)') -ax3.set_ylabel('Z (m)') -#ax3.legend() -new_coils.plot(ax=ax2, show=False) -tracing_optimized.plot(ax=ax2, show=False) -for i, trajectory in enumerate(tracing_optimized.trajectories): - ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') -ax4.set_xlabel('R (m)') -ax4.set_ylabel('Z (m)')#ax4.legend() -plt.tight_layout() -plt.savefig(f'opt_constrained.pdf') - -# # Save the coils to a json file -# coils_optimized.to_json("stellarator_coils.json") -# # Load the coils from a json file -# from essos.coils import Coils_from_json -# coils = Coils_from_json("stellarator_coils.json") - -# # Save results in vtk format to analyze in Paraview -# tracing_initial.to_vtk('trajectories_initial') -#tracing_optimized.to_vtk('trajectories_final') -#coils_initial.to_vtk('coils_initial') -#new_coils.to_vtk('coils_optimized') \ No newline at end of file diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_jaxopt_constrained.py b/examples/optimize_coils_particle_confinement_guidingcenter_jaxopt_constrained.py deleted file mode 100644 index a9504d5..0000000 --- a/examples/optimize_coils_particle_confinement_guidingcenter_jaxopt_constrained.py +++ /dev/null @@ -1,156 +0,0 @@ - -import os -number_of_processors_to_use = 1 # Parallelization, this should divide nparticles -os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' -from time import time -import jax -print(jax.devices()) -jax.config.update("jax_enable_x64", True) -import jax.numpy as jnp -import matplotlib.pyplot as plt -from essos.dynamics import Particles, Tracing -from essos.coils import Coils, CreateEquallySpacedCurves,Curves -from essos.optimization import optimize_loss_function -from essos.objective_functions import loss_particle_r_cross_final_new,loss_particle_r_cross_max,loss_particle_radial_drift,loss_particle_gamma_c -from essos.objective_functions import loss_coil_curvature,loss_coil_length,loss_normB_axis,loss_normB_axis_average -from functools import partial -import essos.alm_convex as alm -import optax - - -# Optimization parameters -target_B_on_axis = 5.7 -max_coil_length = 31 -max_coil_curvature = 0.4 -nparticles = number_of_processors_to_use*10 -order_Fourier_series_coils = 4 -number_coil_points = 80 -maximum_function_evaluations = 10 -maxtimes = [2.e-5] -num_steps=100 -number_coils_per_half_field_period = 3 -number_of_field_periods = 2 -model = 'GuidingCenterAdaptative' - -# Initialize coils -current_on_each_coil = 1.84e7 -major_radius_coils = 7.75 -minor_radius_coils = 4.45 -curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, - order=order_Fourier_series_coils, - R=major_radius_coils, r=minor_radius_coils, - n_segments=number_coil_points, - nfp=number_of_field_periods, stellsym=True) -coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) - -len_dofs_curves = len(jnp.ravel(coils_initial.dofs_curves)) -nfp = coils_initial.nfp -stellsym = coils_initial.stellsym -n_segments = coils_initial.n_segments -dofs_curves_shape = coils_initial.dofs_curves.shape -currents_scale = coils_initial.currents_scale - -# Initialize particles -phi_array = jnp.linspace(0, 2*jnp.pi, nparticles) -initial_xyz=jnp.array([major_radius_coils*jnp.cos(phi_array), major_radius_coils*jnp.sin(phi_array), 0*phi_array]).T -particles = Particles(initial_xyz=initial_xyz) - -t=maxtimes[0] -loss_partial = partial(loss_particle_gamma_c,particles=particles, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) -curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) -length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) -Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) -r_max_partial = partial(loss_particle_r_cross_max, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) - - -# Create the constraints -penalty = 1.05 #Intial penalty values -multiplier=0.0 #Initial lagrange multiplier values -sq_grad=0.0 #Initial square gradient parameter value for Mu adaptative -constraints = alm.combine( -alm.eq(curvature_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -alm.eq(length_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -alm.eq(Baxis_average_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -alm.eq(r_max_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -) - - - -model_lagrange='Mu_Tolerance' #Options: Mu_Constant, Mu_Monotonic, Mu_Conditional,Mu_Adaptative -beta=2. #penalty update parameter -mu_max=1.e4 #Maximum penalty parameter allowed -alpha=0.99 # -gamma=1.e-2 -epsilon=1.e-8 -omega_tol=1.e-5 #grad_tolerance, associated with grad of lagrangian to main parameters -eta_tol=1.e-6 #contrained tolerances, associated with variation of contraints -optimizer='SLSQP' - - -ALM=alm.ALM_model_jaxopt(constraints,optimizer=optimizer,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) - -lagrange_params=constraints.init(coils_initial.x) -params = coils_initial.x, lagrange_params -lag_state,grad,info=ALM.init(params) -mu_average=alm.penalty_average(lagrange_params) -#omega=1.#1./mu_average -#eta=1000.#1./mu_average**0.1 -omega=1./mu_average -eta=1./mu_average**0.1 - -i=0 -while i<=maximum_function_evaluations and (jnp.linalg.norm(grad[0])>omega_tol or alm.norm_constraints(info[2])>eta_tol): - params, lag_state,grad,info,eta,omega = ALM.update(params,lag_state,grad,info,eta,omega) #One step of ALM optimization - #if i % 5 == 0: - #print(f'i: {i}, loss f: {info[0]:g}, infeasibility: {alm.total_infeasibility(info[1]):g}') - print(f'i: {i}, loss f: {info[0]:g},loss L: {info[1]:g}, infeasibility: {alm.total_infeasibility(info[2]):g}') - print('lagrange',params[1]) - i=i+1 - - -dofs_curves = jnp.reshape(params[0][:len_dofs_curves], (dofs_curves_shape)) -dofs_currents = params[0][len_dofs_curves:] -curves = Curves(dofs_curves, n_segments, nfp, stellsym) -new_coils = Coils(curves=curves, currents=dofs_currents*coils_initial.currents_scale) -params=new_coils.x -tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=t, model=model - ,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) -tracing_optimized = Tracing(field=new_coils, particles=particles, maxtime=t, model=model,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) - -#print('Final params',params) -#print(info[1]) -# Plot trajectories, before and after optimization -fig = plt.figure(figsize=(9, 8)) -ax1 = fig.add_subplot(221, projection='3d') -ax2 = fig.add_subplot(222, projection='3d') -ax3 = fig.add_subplot(223) -ax4 = fig.add_subplot(224) - -coils_initial.plot(ax=ax1, show=False) -tracing_initial.plot(ax=ax1, show=False) -for i, trajectory in enumerate(tracing_initial.trajectories): - ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') - -ax3.set_xlabel('R (m)') -ax3.set_ylabel('Z (m)') -#ax3.legend() -new_coils.plot(ax=ax2, show=False) -tracing_optimized.plot(ax=ax2, show=False) -for i, trajectory in enumerate(tracing_optimized.trajectories): - ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') -ax4.set_xlabel('R (m)') -ax4.set_ylabel('Z (m)')#ax4.legend() -plt.tight_layout() -plt.savefig(f'opt_constrained.pdf') - -# # Save the coils to a json file -# coils_optimized.to_json("stellarator_coils.json") -# # Load the coils from a json file -# from essos.coils import Coils_from_json -# coils = Coils_from_json("stellarator_coils.json") - -# # Save results in vtk format to analyze in Paraview -# tracing_initial.to_vtk('trajectories_initial') -#tracing_optimized.to_vtk('trajectories_final') -#coils_initial.to_vtk('coils_initial') -#new_coils.to_vtk('coils_optimized') \ No newline at end of file diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py b/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py index 26f3aa2..2ee7770 100644 --- a/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py +++ b/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py @@ -8,7 +8,7 @@ from essos.dynamics import Particles, Tracing from essos.coils import Coils, CreateEquallySpacedCurves,Curves from essos.optimization import optimize_loss_function -from essos.objective_functions import loss_particle_r_cross_final_new,loss_particle_r_cross_max,loss_particle_radial_drift,loss_particle_gamma_c +from essos.objective_functions import loss_particle_r_cross_final,loss_particle_radial_drift,loss_particle_gamma_c from essos.objective_functions import loss_coil_curvature_new,loss_coil_length_new,loss_normB_axis,loss_normB_axis_average from functools import partial import optax @@ -56,9 +56,9 @@ curvature_partial=partial(loss_coil_curvature_new, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) length_partial=partial(loss_coil_length_new, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) -r_max_partial = partial(loss_particle_r_cross_max, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model = model,num_steps=num_steps) +r_max_partial = partial(loss_particle_r_cross_final, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model = model,num_steps=num_steps) def total_loss(params): - return jnp.linalg.norm(jnp.concatenate((r_max_partial(params),length_partial(params),curvature_partial(params),Baxis_average_partial(params))))**2 + return jnp.linalg.norm(jnp.concatenate((jnp.ravel(r_max_partial(params)),length_partial(params),curvature_partial(params),Baxis_average_partial(params))))**2 params=coils_initial.x optimizer=optax.lbfgs() diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs_constrained.py b/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs_constrained.py deleted file mode 100644 index 3bfef5e..0000000 --- a/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs_constrained.py +++ /dev/null @@ -1,155 +0,0 @@ - -import os -number_of_processors_to_use = 1 # Parallelization, this should divide nparticles -os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' -from time import time -import jax -print(jax.devices()) -jax.config.update("jax_enable_x64", True) -import jax.numpy as jnp -import matplotlib.pyplot as plt -from essos.dynamics import Particles, Tracing -from essos.coils import Coils, CreateEquallySpacedCurves,Curves -from essos.optimization import optimize_loss_function -from essos.objective_functions import loss_particle_r_cross_final_new,loss_particle_r_cross_max,loss_particle_radial_drift,loss_particle_gamma_c -from essos.objective_functions import loss_coil_curvature,loss_coil_length,loss_normB_axis,loss_normB_axis_average -from functools import partial -import essos.alm_convex as alm -import optax - - -# Optimization parameters -target_B_on_axis = 5.7 -max_coil_length = 31 -max_coil_curvature = 0.4 -nparticles = number_of_processors_to_use*10 -order_Fourier_series_coils = 4 -number_coil_points = 80 -maximum_function_evaluations = 30 -maxtimes = [1.e-5] -num_steps=100 -number_coils_per_half_field_period = 3 -number_of_field_periods = 2 -model = 'GuidingCenterAdaptative' - -# Initialize coils -current_on_each_coil = 1.84e7 -major_radius_coils = 7.75 -minor_radius_coils = 4.45 -curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, - order=order_Fourier_series_coils, - R=major_radius_coils, r=minor_radius_coils, - n_segments=number_coil_points, - nfp=number_of_field_periods, stellsym=True) -coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) - -len_dofs_curves = len(jnp.ravel(coils_initial.dofs_curves)) -nfp = coils_initial.nfp -stellsym = coils_initial.stellsym -n_segments = coils_initial.n_segments -dofs_curves_shape = coils_initial.dofs_curves.shape -currents_scale = coils_initial.currents_scale - -# Initialize particles -phi_array = jnp.linspace(0, 2*jnp.pi, nparticles) -initial_xyz=jnp.array([major_radius_coils*jnp.cos(phi_array), major_radius_coils*jnp.sin(phi_array), 0*phi_array]).T -particles = Particles(initial_xyz=initial_xyz) - -t=maxtimes[0] -loss_partial = partial(loss_particle_gamma_c,particles=particles, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) -curvature_partial=partial(loss_coil_curvature, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_curvature=max_coil_curvature) -length_partial=partial(loss_coil_length, dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,max_coil_length=max_coil_length) -Baxis_average_partial=partial(loss_normB_axis_average,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) -r_max_partial = partial(loss_particle_r_cross_max, particles=particles,dofs_curves=coils_initial.dofs_curves, currents_scale=currents_scale, nfp=nfp, n_segments=n_segments, stellsym=stellsym,maxtime=t,model=model,num_steps=num_steps) - - -# Create the constraints -penalty = 1.05 #Intial penalty values -multiplier=1.0 #Initial lagrange multiplier values -sq_grad=0.0 #Initial square gradient parameter value for Mu adaptative -constraints = alm.combine( -alm.eq(curvature_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -alm.eq(length_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -alm.eq(Baxis_average_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -#alm.eq(r_max_partial, multiplier=multiplier,penalty=penalty,sq_grad=sq_grad), -) - - - -model_lagrange='Mu_Tolerance_LBFGS' #Options: Mu_Constant, Mu_Monotonic, Mu_Conditional,Mu_Adaptative -beta=2. #penalty update parameter -mu_max=1.e4 #Maximum penalty parameter allowed -alpha=0.99 # -gamma=1.e-2 -epsilon=1.e-8 -omega_tol=0.01 #grad_tolerance, associated with grad of lagrangian to main parameters -eta_tol=1.e-6 #contrained tolerances, associated with variation of contraints -optimizer=optax.lbfgs(linesearch=optax.scale_by_zoom_linesearch(max_linesearch_steps=15)) - -ALM=alm.ALM_model(optimizer,constraints,model_lagrange=model_lagrange,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol) - -lagrange_params=constraints.init(coils_initial.x) -params = coils_initial.x, lagrange_params -opt_state,grad,value,info=ALM.init(params) -mu_average=alm.penalty_average(lagrange_params) -#omega=1.#1./mu_average -#eta=1000.#1./mu_average**0.1 -omega=1./mu_average -eta=1./mu_average**0.1 - -i=0 -while i<=maximum_function_evaluations and (jnp.linalg.norm(grad[0])>omega_tol or alm.norm_constraints(info[2])>eta_tol): - params, opt_state, grad,value,info,eta,omega = ALM.update(params,opt_state,grad,value,info,eta,omega) #One step of ALM optimization - #if i % 5 == 0: - #print(f'i: {i}, loss f: {info[0]:g}, infeasibility: {alm.total_infeasibility(info[1]):g}') - print(f'i: {i}, loss f: {info[0]:g},loss L: {info[1]:g}, infeasibility: {alm.total_infeasibility(info[2]):g}') - print('lagrange',params[1]) - i=i+1 - - -dofs_curves = jnp.reshape(params[0][:len_dofs_curves], (dofs_curves_shape)) -dofs_currents = params[0][len_dofs_curves:] -curves = Curves(dofs_curves, n_segments, nfp, stellsym) -new_coils = Coils(curves=curves, currents=dofs_currents*coils_initial.currents_scale) -params=new_coils.x -tracing_initial = Tracing(field=coils_initial, particles=particles, maxtime=t, model=model - ,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) -tracing_optimized = Tracing(field=new_coils, particles=particles, maxtime=t, model=model,times_to_trace=200,timestep=1.e-8,atol=1.e-5,rtol=1.e-5) - -#print('Final params',params) -#print(info[1]) -# Plot trajectories, before and after optimization -fig = plt.figure(figsize=(9, 8)) -ax1 = fig.add_subplot(221, projection='3d') -ax2 = fig.add_subplot(222, projection='3d') -ax3 = fig.add_subplot(223) -ax4 = fig.add_subplot(224) - -coils_initial.plot(ax=ax1, show=False) -tracing_initial.plot(ax=ax1, show=False) -for i, trajectory in enumerate(tracing_initial.trajectories): - ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') - -ax3.set_xlabel('R (m)') -ax3.set_ylabel('Z (m)') -#ax3.legend() -new_coils.plot(ax=ax2, show=False) -tracing_optimized.plot(ax=ax2, show=False) -for i, trajectory in enumerate(tracing_optimized.trajectories): - ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') -ax4.set_xlabel('R (m)') -ax4.set_ylabel('Z (m)')#ax4.legend() -plt.tight_layout() -plt.savefig(f'opt_constrained.pdf') - -# # Save the coils to a json file -# coils_optimized.to_json("stellarator_coils.json") -# # Load the coils from a json file -# from essos.coils import Coils_from_json -# coils = Coils_from_json("stellarator_coils.json") - -# # Save results in vtk format to analyze in Paraview -# tracing_initial.to_vtk('trajectories_initial') -#tracing_optimized.to_vtk('trajectories_final') -#coils_initial.to_vtk('coils_initial') -#new_coils.to_vtk('coils_optimized') \ No newline at end of file diff --git a/examples/optimize_multiple_objectives.py b/examples/optimize_multiple_objectives.py index f38ed8e..b6cf47b 100644 --- a/examples/optimize_multiple_objectives.py +++ b/examples/optimize_multiple_objectives.py @@ -17,7 +17,7 @@ "max_coil_curvature": 0.0, }, opt_config={ - "n_trials": 2, + "n_trials": 4, "maximum_function_evaluations": 300, "tolerance_optimization": 1e-5, "optimizer_choices": ["adam", "amsgrad", "sgd"],