diff --git a/PyPIC3D/J.py b/PyPIC3D/J.py deleted file mode 100644 index 2a8620d..0000000 --- a/PyPIC3D/J.py +++ /dev/null @@ -1,632 +0,0 @@ -import jax -from jax import jit -import jax.numpy as jnp -from functools import partial -from jax import lax -# import external libraries - -from PyPIC3D.utils import digital_filter, wrap_around, bilinear_filter -from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights - -@partial(jit, static_argnames=("filter",)) -def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): - """ - Compute the current density from the charge density and particle velocities. - - Args: - particles (list): List of particle species, each with methods to get charge, subcell position, resolution, and index. - rho (ndarray): Charge density array. - J (tuple): Current density arrays (Jx, Jy, Jz) for the x, y, and z directions respectively. - constants (dict): Dictionary containing physical constants. - - Returns: - tuple: Updated current density arrays (Jx, Jy, Jz) for the x, y, and z directions respectively. - """ - - if grid is None: - grid = world['grids']['center'] - - dx = world['dx'] - dy = world['dy'] - dz = world['dz'] - Nx = world['Nx'] - Ny = world['Ny'] - Nz = world['Nz'] - # get the world parameters - - Jx, Jy, Jz = J - x_active = Jx.shape[0] != 1 - y_active = Jx.shape[1] != 1 - z_active = Jx.shape[2] != 1 - # infer effective dimensionality from the current-grid shape - - # unpack the values of J - Jx = Jx.at[:, :, :].set(0) - Jy = Jy.at[:, :, :].set(0) - Jz = Jz.at[:, :, :].set(0) - # initialize the current arrays as 0 - J = (Jx, Jy, Jz) - # initialize the current density as a tuple - - for species in particles: - shape_factor = species.get_shape() - charge = species.get_charge() - dq = charge / (dx * dy * dz) - # calculate the charge density contribution per particle - x, y, z = species.get_forward_position() - vx, vy, vz = species.get_velocity() - # get the particles positions and velocities - - x = x - vx * world['dt'] / 2 - y = y - vy * world['dt'] / 2 - z = z - vz * world['dt'] / 2 - # step back to half time step positions for proper time staggering - - x0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (x - grid[0][0]) / dx).astype(int), - lambda _: jnp.round( (x - grid[0][0]) / dx).astype(int), - operand=None - ) - y0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (y - grid[1][0]) / dy).astype(int), - lambda _: jnp.round( (y - grid[1][0]) / dy).astype(int), - operand=None - ) - z0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (z - grid[2][0]) / dz).astype(int), - lambda _: jnp.round( (z - grid[2][0]) / dz).astype(int), - operand=None - ) - # calculate the nearest grid point based on shape factor - - deltax_node = (x - grid[0][0]) - (x0 * dx) - deltay_node = (y - grid[1][0]) - (y0 * dy) - deltaz_node = (z - grid[2][0]) - (z0 * dz) - # Calculate the difference between the particle position and the nearest grid point - - deltax_face = (x - grid[0][0]) - (x0 + 0.5) * dx - deltay_face = (y - grid[1][0]) - (y0 + 0.5) * dy - deltaz_face = (z - grid[2][0]) - (z0 + 0.5) * dz - # Calculate the difference between the particle position and the nearest staggered cell face - - x0 = wrap_around(x0, Nx) - y0 = wrap_around(y0, Ny) - z0 = wrap_around(z0, Nz) - # wrap around the grid points for periodic boundary conditions - x1 = wrap_around(x0+1, Nx) - y1 = wrap_around(y0+1, Ny) - z1 = wrap_around(z0+1, Nz) - # calculate the right grid point - x_minus1 = x0 - 1 - y_minus1 = y0 - 1 - z_minus1 = z0 - 1 - # calculate the left grid point - xpts = [x_minus1, x0, x1] - ypts = [y_minus1, y0, y1] - zpts = [z_minus1, z0, z1] - # place all the points in a list - - x_weights_node, y_weights_node, z_weights_node = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights( deltax_node, deltay_node, deltaz_node, dx, dy, dz), - lambda _: get_second_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz), - operand=None - ) - - x_weights_face, y_weights_face, z_weights_face = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights( deltax_face, deltay_face, deltaz_face, dx, dy, dz), - lambda _: get_second_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz), - operand=None - ) - # get the weights for node and face positions - - xpts = jnp.asarray(xpts) # (Sx, Np) - ypts = jnp.asarray(ypts) # (Sy, Np) - zpts = jnp.asarray(zpts) # (Sz, Np) - - x_weights_face = jnp.asarray(x_weights_face) # (Sx, Np) - y_weights_face = jnp.asarray(y_weights_face) # (Sy, Np) - z_weights_face = jnp.asarray(z_weights_face) # (Sz, Np) - - x_weights_node = jnp.asarray(x_weights_node) # (Sx, Np) - y_weights_node = jnp.asarray(y_weights_node) # (Sy, Np) - z_weights_node = jnp.asarray(z_weights_node) # (Sz, Np) - - # Keep full shape-factor computation but collapse inactive axes to an - # effective stencil of size 1 to avoid redundant deposition work. - if x_active: - xpts_eff = xpts - x_weights_node_eff = x_weights_node - x_weights_face_eff = x_weights_face - else: - xpts_eff = jnp.zeros((1, xpts.shape[1]), dtype=xpts.dtype) - x_weights_node_eff = jnp.sum(x_weights_node, axis=0, keepdims=True) - x_weights_face_eff = jnp.sum(x_weights_face, axis=0, keepdims=True) - - if y_active: - ypts_eff = ypts - y_weights_node_eff = y_weights_node - y_weights_face_eff = y_weights_face - else: - ypts_eff = jnp.zeros((1, ypts.shape[1]), dtype=ypts.dtype) - y_weights_node_eff = jnp.sum(y_weights_node, axis=0, keepdims=True) - y_weights_face_eff = jnp.sum(y_weights_face, axis=0, keepdims=True) - - if z_active: - zpts_eff = zpts - z_weights_node_eff = z_weights_node - z_weights_face_eff = z_weights_face - else: - zpts_eff = jnp.zeros((1, zpts.shape[1]), dtype=zpts.dtype) - z_weights_node_eff = jnp.sum(z_weights_node, axis=0, keepdims=True) - z_weights_face_eff = jnp.sum(z_weights_face, axis=0, keepdims=True) - - ii, jj, kk = jnp.meshgrid( - jnp.arange(xpts_eff.shape[0]), - jnp.arange(ypts_eff.shape[0]), - jnp.arange(zpts_eff.shape[0]), - indexing="ij", - ) - combos = jnp.stack([ii.ravel(), jj.ravel(), kk.ravel()], axis=1) # (Sx*Sy*Sz, 3) - - def idx_and_dJ_values(idx): - i, j, k = idx - # unpack the stencil indices - ix = xpts_eff[i, ...] - iy = ypts_eff[j, ...] - iz = zpts_eff[k, ...] - # get the grid indices for this stencil point - valx = (dq * vx) * x_weights_face_eff[i, ...] * y_weights_node_eff[j, ...] * z_weights_node_eff[k, ...] - valy = (dq * vy) * x_weights_node_eff[i, ...] * y_weights_face_eff[j, ...] * z_weights_node_eff[k, ...] - valz = (dq * vz) * x_weights_node_eff[i, ...] * y_weights_node_eff[j, ...] * z_weights_face_eff[k, ...] - # calculate the current contributions for this stencil point - return ix, iy, iz, valx, valy, valz - - ix, iy, iz, valx, valy, valz = jax.vmap(idx_and_dJ_values)(combos) # each: (M, Np) - # vectorized computation of indices and current contributions - - Jx = Jx.at[(ix, iy, iz)].add(valx, mode="drop") - Jy = Jy.at[(ix, iy, iz)].add(valy, mode="drop") - Jz = Jz.at[(ix, iy, iz)].add(valz, mode="drop") - # deposit the current contributions into the global J arrays - - def filter_func(J_, filter): - J_ = jax.lax.cond( - filter == 'bilinear', - lambda J_: bilinear_filter(J_), - lambda J_: J_, - operand=J_ - ) - - # alpha = constants['alpha'] - # J_ = jax.lax.cond( - # filter == 'digital', - # lambda J_: digital_filter(J_, alpha), - # lambda J_: J_, - # operand=J_ - # ) - return J_ - # define a filtering function - - Jx = filter_func(Jx, filter) - Jy = filter_func(Jy, filter) - Jz = filter_func(Jz, filter) - # apply the selected filter to each component of J - J = (Jx, Jy, Jz) - - return J - -def _roll_old_weights_to_new_frame(old_w_list, shift): - """ - old_w_list: list of 5 arrays, each (Np,) - shift: (Np,) integer = old_i0 - new_i0 (expected in {-1,0,1} for Esirkepov) - Returns a list of 5 arrays rolled per particle so old weights align with new-cell frame. - """ - old_w = jnp.stack(old_w_list, axis=0) # (5, Np) - - def roll_one_particle(w5, s): - return jnp.roll(w5, -s, axis=0) - - rolled = jax.vmap(roll_one_particle, in_axes=(1, 0), out_axes=1)(old_w, shift) # (5,Np) - return [rolled[i, :] for i in range(5)] - - -def Esirkepov_current(particles, J, constants, world, grid=None, filter=None): - """ - Local per-particle Esirkepov deposition that works for 1D/2D/3D by setting inactive dims to size 1. - J is a tuple (Jx,Jy,Jz) arrays shaped (Nx,Ny,Nz). - """ - if grid is None: - grid = world['grids']['center'] - - Jx, Jy, Jz = J - Nx, Ny, Nz = Jx.shape - dx, dy, dz, dt = world["dx"], world["dy"], world["dz"], world["dt"] - xmin, ymin, zmin = grid[0][0], grid[1][0], grid[2][0] - - # zero current arrays - Jx = Jx.at[:, :, :].set(0) - Jy = Jy.at[:, :, :].set(0) - Jz = Jz.at[:, :, :].set(0) - - x_active = (Nx != 1) - y_active = (Ny != 1) - z_active = (Nz != 1) - # determine which axis are null - - for species in particles: - q = species.get_charge() - x, y, z = species.get_forward_position() - vx, vy, vz = species.get_velocity() - shape_factor = species.get_shape() - N_particles = species.get_number_of_particles() - # get the particle properties - - old_x = x - vx * dt - old_y = y - vy * dt - old_z = z - vz * dt - # calculate old positions from new positions and velocities - - x0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (x - xmin) / dx).astype(int), - lambda _: jnp.round( (x - xmin) / dx).astype(int), - operand=None - ) - y0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (y - ymin) / dy).astype(int), - lambda _: jnp.round( (y - ymin) / dy).astype(int), - operand=None - ) - z0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (z - zmin) / dz).astype(int), - lambda _: jnp.round( (z - zmin) / dz).astype(int), - operand=None - ) # calculate the nearest grid point based on shape factor for new positions - - old_x0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (old_x - xmin) / dx).astype(int), - lambda _: jnp.round( (old_x - xmin) / dx).astype(int), - operand=None - ) - old_y0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (old_y - ymin) / dy).astype(int), - lambda _: jnp.round( (old_y - ymin) / dy).astype(int), - operand=None - ) - old_z0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (old_z - zmin) / dz).astype(int), - lambda _: jnp.round( (old_z - zmin) / dz).astype(int), - operand=None - ) # calculate the nearest grid point based on shape factor for old positions - - deltax = (x - xmin) - x0 * dx - deltay = (y - ymin) - y0 * dy - deltaz = (z - zmin) - z0 * dz - # get the difference between the particle position and the nearest grid point - old_deltax = (old_x - xmin) - old_x0 * dx - old_deltay = (old_y - ymin) - old_y0 * dy - old_deltaz = (old_z - zmin) - old_z0 * dz - # get the difference between the particle position and the nearest grid point - - shift_x = x0 - old_x0 - shift_y = y0 - old_y0 - shift_z = z0 - old_z0 - # calculate the shift between old and new grid points - - x0 = wrap_around(x0, Nx) - y0 = wrap_around(y0, Ny) - z0 = wrap_around(z0, Nz) - # wrap around the grid points for periodic boundary conditions - x1 = wrap_around(x0+1, Nx) - y1 = wrap_around(y0+1, Ny) - z1 = wrap_around(z0+1, Nz) - # calculate the right grid point - x2 = wrap_around(x0+2, Nx) - y2 = wrap_around(y0+2, Ny) - z2 = wrap_around(z0+2, Nz) - # calculate the second right grid point - x_minus1 = wrap_around(x0 - 1, Nx) - y_minus1 = wrap_around(y0 - 1, Ny) - z_minus1 = wrap_around(z0 - 1, Nz) - # calculate the left grid point - x_minus2 = wrap_around(x0 - 2, Nx) - y_minus2 = wrap_around(y0 - 2, Ny) - z_minus2 = wrap_around(z0 - 2, Nz) - # calculate the second left grid point - - xpts = [x_minus2, x_minus1, x0, x1, x2] - ypts = [y_minus2, y_minus1, y0, y1, y2] - zpts = [z_minus2, z_minus1, z0, z1, z2] - # place all the points in a list - - xw, yw, zw = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), - lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), - operand=None, - ) - # get the weights for the new positions - oxw, oyw, ozw = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights(old_deltax, old_deltay, old_deltaz, dx, dy, dz), - lambda _: get_second_order_weights(old_deltax, old_deltay, old_deltaz, dx, dy, dz), - operand=None, - ) # get the weights for the old positions - - tmp = jnp.zeros_like(xw[0]) - # build the temporary zero array for padding - - xw = [tmp, xw[0], xw[1], xw[2], tmp] - yw = [tmp, yw[0], yw[1], yw[2], tmp] - zw = [tmp, zw[0], zw[1], zw[2], tmp] - # pad the weights to 5 points for consistency - - oxw = [tmp, oxw[0], oxw[1], oxw[2], tmp] - oyw = [tmp, oyw[0], oyw[1], oyw[2], tmp] - ozw = [tmp, ozw[0], ozw[1], ozw[2], tmp] - # pad the old weights to 5 points for consistency - - oxw = _roll_old_weights_to_new_frame(oxw, shift_x) - oyw = _roll_old_weights_to_new_frame(oyw, shift_y) - ozw = _roll_old_weights_to_new_frame(ozw, shift_z) - - # --- build Esirkepov W on compact stencil --- - if x_active and y_active and z_active: - Wx_, Wy_, Wz_ = get_3D_esirkepov_weights(xw, yw, zw, oxw, oyw, ozw, N_particles) - elif (x_active and y_active and (not z_active)) or (x_active and z_active and (not y_active)) or (y_active and z_active and (not x_active)): - null_dim = lax.cond( - not x_active, - lambda _: 0, - lambda _: lax.cond( - not y_active, - lambda _: 1, - lambda _: 2, - operand=None, - ), - operand=None, - ) - # determine which dimension is inactive - - Wx_, Wy_, Wz_ = get_2D_esirkepov_weights(xw, yw, zw, oxw, oyw, ozw, N_particles, null_dim=null_dim) - elif x_active and (not y_active) and (not z_active): - # 1D in x: Esirkepov reduces to 1D continuity; - Wx_, Wy_, Wz_ = get_1D_esirkepov_weights(xw, yw, zw, oxw, oyw, ozw, N_particles, dim=0) - elif y_active and (not x_active) and (not z_active): - Wx_, Wy_, Wz_ = get_1D_esirkepov_weights(xw, yw, zw, oxw, oyw, ozw, N_particles, dim=1) - elif z_active and (not x_active) and (not y_active): - Wx_, Wy_, Wz_ = get_1D_esirkepov_weights(xw, yw, zw, oxw, oyw, ozw, N_particles, dim=2) - - dJx = jax.lax.cond( - x_active, - lambda _: -(q / (dy * dz)) / dt * jnp.ones(N_particles), - lambda _: q * vx / (dx * dy * dz) * jnp.ones(N_particles), - operand=None, - ) - - dJy = jax.lax.cond( - y_active, - lambda _: -(q / (dx * dz)) / dt * jnp.ones(N_particles), - lambda _: q * vy / (dx * dy * dz) * jnp.ones(N_particles), - operand=None, - ) - - dJz = jax.lax.cond( - z_active, - lambda _: -(q / (dx * dy)) / dt * jnp.ones(N_particles), - lambda _: q * vz / (dx * dy * dz) * jnp.ones(N_particles), - operand=None, - ) - # calculate prefactors for current deposition - - # local “difference RHS” - Fx = dJx * Wx_ # (Sx,Sy,Sz,Np) - Fy = dJy * Wy_ - Fz = dJz * Wz_ - - Jx_loc = jnp.zeros_like(Fx) - Jy_loc = jnp.zeros_like(Fy) - Jz_loc = jnp.zeros_like(Fz) - - # Using Backward Finite Difference approach for prefix sum ################################# - # Jx currents - Jx_loc = jnp.cumsum(Fx, axis=0) - # Jy currents - Jy_loc = jnp.cumsum(Fy, axis=1) - # Jz currents - Jz_loc = jnp.cumsum(Fz, axis=2) - # This assumes 5 cells in each dimension for the stencil, but 6 faces (so 5 differences). - # This should give periodic wrap around J(1) = J(6) = 0 as required. - ################################################################################################ - if x_active: - for i in range(5): - for j in range(5): - for k in range(5): - Jx = Jx.at[xpts[i], ypts[j], zpts[k]].add(Jx_loc[i, j, k, :], mode="drop") - # deposit Jx using Esirkepov weights - else: - for i in range(5): - for j in range(5): - for k in range(5): - Jx = Jx.at[xpts[i], ypts[j], zpts[k]].add(Fx[i, j, k, :], mode="drop") - # deposit Jx using midpoint weights for inactive dimension - - if y_active: - for i in range(5): - for j in range(5): - for k in range(5): - Jy = Jy.at[xpts[i], ypts[j], zpts[k]].add(Jy_loc[i, j, k, :], mode="drop") - # deposit Jy using Esirkepov weights - else: - for i in range(5): - for j in range(5): - for k in range(5): - Jy = Jy.at[xpts[i], ypts[j], zpts[k]].add(Fy[i, j, k, :], mode="drop") - # deposit Jy using midpoint weights for inactive dimension - - if z_active: - for i in range(5): - for j in range(5): - for k in range(5): - Jz = Jz.at[xpts[i], ypts[j], zpts[k]].add(Jz_loc[i, j, k, :], mode="drop") - # deposit Jz using Esirkepov weights - - else: - for i in range(5): - for j in range(5): - for k in range(5): - Jz = Jz.at[xpts[i], ypts[j], zpts[k]].add(Fz[i, j, k, :], mode="drop") - # deposit Jz using midpoint weights for inactive dimension - - - return (Jx, Jy, Jz) - - -def get_3D_esirkepov_weights(x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights, N_particles, null_dim=None): - - Wx_ = jnp.zeros( (len(x_weights),len(y_weights),len(z_weights), N_particles) ) - Wy_ = jnp.zeros_like( Wx_) - Wz_ = jnp.zeros_like( Wx_) - - - for i in range(len(x_weights)): - for j in range(len(y_weights)): - for k in range(len(z_weights)): - Wx_ = Wx_.at[i,j,k,:].set( (x_weights[i] - old_x_weights[i]) * ( 1/3 * (y_weights[j] * z_weights[k] + old_y_weights[j] * old_z_weights[k]) \ - + 1/6 * (y_weights[j] * old_z_weights[k] + old_y_weights[j] * z_weights[k]) ) ) - - Wy_ = Wy_.at[i,j,k,:].set( (y_weights[j] - old_y_weights[j]) * ( 1/3 * (x_weights[i] * z_weights[k] + old_x_weights[i] * old_z_weights[k]) \ - + 1/6 * (x_weights[i] * old_z_weights[k] + old_x_weights[i] * z_weights[k]) ) ) - - Wz_ = Wz_.at[i,j,k,:].set( (z_weights[k] - old_z_weights[k]) * ( 1/3 * (x_weights[i] * y_weights[j] + old_x_weights[i] * old_y_weights[j]) \ - + 1/6 * (x_weights[i] * old_y_weights[j] + old_x_weights[i] * y_weights[j]) ) ) - - return Wx_, Wy_, Wz_ - -def get_2D_esirkepov_weights(x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights, N_particles, null_dim=2): - d_Sx = [] - d_Sy = [] - d_Sz = [] - - d_Sx = [ x_weights[i] - old_x_weights[i] for i in range(len(x_weights)) ] - d_Sy = [ y_weights[i] - old_y_weights[i] for i in range(len(y_weights)) ] - d_Sz = [ z_weights[i] - old_z_weights[i] for i in range(len(z_weights)) ] - - Wx_ = jnp.zeros( (len(x_weights),len(y_weights),len(z_weights), N_particles) ) - Wy_ = jnp.zeros_like( Wx_) - Wz_ = jnp.zeros_like( Wx_) - # initialize the weight arrays - - # XY Plane - def xy_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): - for i in range(len(x_weights)): - for j in range(len(y_weights)): - Wx_ = Wx_.at[i,j,2,:].set( 1/2 * d_Sx[i] * ( y_weights[j] + old_y_weights[j] ) ) - Wy_ = Wy_.at[i,j,2,:].set( 1/2 * d_Sy[j] * ( x_weights[i] + old_x_weights[i] ) ) - Wz_ = Wz_.at[i,j,2,:].set( 1/3 * ( x_weights[i] * y_weights[j] + old_x_weights[i] * old_y_weights[j] ) \ - + 1/6 * ( x_weights[i] * old_y_weights[j] + old_x_weights[i] * y_weights[j] ) ) - # Weights if the 2D plane is in the XY plane - - return Wx_, Wy_, Wz_ - - - # XZ Plane - def xz_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): - for i in range(len(x_weights)): - for k in range(len(z_weights)): - Wx_ = Wx_.at[i,2,k,:].set( 1/2 * d_Sx[i] * ( z_weights[k] + old_z_weights[k] ) ) - Wy_ = Wy_.at[i,2,k,:].set( 1/3 * ( x_weights[i] * z_weights[k] + old_x_weights[i] * old_z_weights[k] ) \ - + 1/6 * ( x_weights[i] * old_z_weights[k] + old_x_weights[i] * z_weights[k] ) ) - Wz_ = Wz_.at[i,2,k,:].set( 1/2 * d_Sz[k] * ( x_weights[i] + old_x_weights[i] ) ) - # Weights if the 2D plane is in the XZ plane - return Wx_, Wy_, Wz_ - - - # YZ Plane - def yz_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): - for j in range(len(y_weights)): - for k in range(len(z_weights)): - Wx_ = Wx_.at[2,j,k,:].set( 1/3 * ( y_weights[j] * z_weights[k] + old_y_weights[j] * old_z_weights[k] ) \ - + 1/6 * ( y_weights[j] * old_z_weights[k] + old_y_weights[j] * z_weights[k] ) ) - Wy_ = Wy_.at[2,j,k,:].set( 1/2 * d_Sy[j] * ( z_weights[k] + old_z_weights[k] ) ) - Wz_ = Wz_.at[2,j,k,:].set( 1/2 * d_Sz[k] * ( y_weights[j] + old_y_weights[j] ) ) - # Weights if the 2D plane is in the YZ plane - return Wx_, Wy_, Wz_ - - - Wx_, Wy_, Wz_ = lax.cond( - null_dim == 0, - lambda _: yz_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), - lambda _: lax.cond( - null_dim == 1, - lambda _: xz_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), - lambda _: xy_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), - operand=None - ), - operand=None - ) - - return Wx_, Wy_, Wz_ - - -def get_1D_esirkepov_weights(x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights, N_particles, dim=0): - - Wx_ = jnp.zeros( (len(x_weights),len(y_weights),len(z_weights), N_particles) ) - Wy_ = jnp.zeros_like( Wx_) - Wz_ = jnp.zeros_like( Wx_) - # initialize the weight arrays - - def x_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): - for i in range(len(x_weights)): - Wx_ = Wx_.at[i, 2, 2, :].set( (x_weights[i] - old_x_weights[i]) ) - # get the weights for x direction - Wy_ = Wy_.at[i, 2, 2, :].set( (x_weights[i] + old_x_weights[i]) / 2 ) - Wz_ = Wz_.at[i, 2, 2, :].set( (x_weights[i] + old_x_weights[i]) / 2 ) - # use a midpoint average for inactive directions - # weights if x direction is active - return Wx_, Wy_, Wz_ - - def y_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): - for j in range(len(y_weights)): - Wy_ = Wy_.at[2, j, 2, :].set( (y_weights[j] - old_y_weights[j]) ) - # weights for y direction - Wx_ = Wx_.at[2, j, 2, :].set( (y_weights[j] + old_y_weights[j]) / 2 ) - Wz_ = Wz_.at[2, j, 2, :].set( (y_weights[j] + old_y_weights[j]) / 2 ) - # use a midpoint average for inactive directions - # weights if y direction is active - return Wx_, Wy_, Wz_ - - def z_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): - for k in range(len(z_weights)): - Wz_ = Wz_.at[2, 2, k, :].set( (z_weights[k] - old_z_weights[k]) ) - # weights for z direction - Wx_ = Wx_.at[2, 2, k, :].set( (z_weights[k] + old_z_weights[k]) / 2 ) - Wy_ = Wy_.at[2, 2, k, :].set( (z_weights[k] + old_z_weights[k]) / 2 ) - # use a midpoint average for inactive directions - # weights if z direction is active - return Wx_, Wy_, Wz_ - - Wx_, Wy_, Wz_ = lax.cond( - dim == 0, - lambda _: x_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), - lambda _: lax.cond( - dim == 1, - lambda _: y_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), - lambda _: z_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), - operand=None, - ), - operand=None, - ) - # determine which dimension is active and calculate weights accordingly - - - return Wx_, Wy_, Wz_ diff --git a/PyPIC3D/__init__.py b/PyPIC3D/__init__.py index bea02e0..5f18f46 100644 --- a/PyPIC3D/__init__.py +++ b/PyPIC3D/__init__.py @@ -6,17 +6,28 @@ import matplotlib.pyplot as plt # import external libraries -jax.config.update('jax_platform_name', 'cpu') - from . import errors from . import boundaryconditions from . import initialization -from . import particle -from .diagnostics import plotting from . import utils -from .solvers import pstd -from .solvers import fdtd from . import boris -from . import rho from . import evolve -from .solvers import vector_potential \ No newline at end of file + +from .solvers import vector_potential +from .solvers import pstd +from .solvers import fdtd +from .solvers import first_order_yee +from .solvers import electrostatic_yee + +from .particles import particle_initialization +from .particles import species_class + +from .deposition import shapes +from .deposition import Esirkepov +from .deposition import J_from_rhov +from .deposition import rho + +from .diagnostics import plotting +from .diagnostics import fluid_quantities +from .diagnostics import openPMD +from .diagnostics import vtk \ No newline at end of file diff --git a/PyPIC3D/__main__.py b/PyPIC3D/__main__.py index 775213c..cdf5461 100644 --- a/PyPIC3D/__main__.py +++ b/PyPIC3D/__main__.py @@ -40,7 +40,7 @@ compute_mass_density ) -from PyPIC3D.rho import compute_rho +from PyPIC3D.deposition.rho import compute_rho # Importing functions from the PyPIC3D package diff --git a/PyPIC3D/boris.py b/PyPIC3D/boris.py index 7510f90..98fc7b4 100644 --- a/PyPIC3D/boris.py +++ b/PyPIC3D/boris.py @@ -2,7 +2,7 @@ from jax import jit import jax.numpy as jnp -from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights +from PyPIC3D.deposition.shapes import get_first_order_weights, get_second_order_weights from PyPIC3D.utils import wrap_around @jit diff --git a/PyPIC3D/deposition/Esirkepov.py b/PyPIC3D/deposition/Esirkepov.py new file mode 100644 index 0000000..f7b2ebe --- /dev/null +++ b/PyPIC3D/deposition/Esirkepov.py @@ -0,0 +1,379 @@ +import jax +from jax import jit +import jax.numpy as jnp +from functools import partial +from jax import lax + +from PyPIC3D.utils import digital_filter, wrap_around, bilinear_filter +from PyPIC3D.deposition.shapes import get_first_order_weights, get_second_order_weights + + +def _roll_old_weights_to_new_frame(old_w_list, shift): + """Roll old weights into the new-cell frame for Esirkepov deposition.""" + old_w = jnp.stack(old_w_list, axis=0) + + def roll_one_particle(w5, s): + return jnp.roll(w5, -s, axis=0) + + rolled = jax.vmap(roll_one_particle, in_axes=(1, 0), out_axes=1)(old_w, shift) + return [rolled[i, :] for i in range(5)] + + +def Esirkepov_current(particles, J, constants, world, grid=None, filter=None): + """Esirkepov current deposition supporting 1D/2D/3D via inactive dims.""" + if grid is None: + grid = world["grids"]["center"] + + Jx, Jy, Jz = J + Nx, Ny, Nz = Jx.shape + dx, dy, dz, dt = world["dx"], world["dy"], world["dz"], world["dt"] + xmin, ymin, zmin = grid[0][0], grid[1][0], grid[2][0] + + Jx = Jx.at[:, :, :].set(0) + Jy = Jy.at[:, :, :].set(0) + Jz = Jz.at[:, :, :].set(0) + + x_active = Nx != 1 + y_active = Ny != 1 + z_active = Nz != 1 + + for species in particles: + q = species.get_charge() + x, y, z = species.get_forward_position() + vx, vy, vz = species.get_velocity() + shape_factor = species.get_shape() + N_particles = species.get_number_of_particles() + + old_x = x - vx * dt + old_y = y - vy * dt + old_z = z - vz * dt + + x0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((x - xmin) / dx).astype(int), + lambda _: jnp.round((x - xmin) / dx).astype(int), + operand=None, + ) + y0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((y - ymin) / dy).astype(int), + lambda _: jnp.round((y - ymin) / dy).astype(int), + operand=None, + ) + z0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((z - zmin) / dz).astype(int), + lambda _: jnp.round((z - zmin) / dz).astype(int), + operand=None, + ) + + old_x0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((old_x - xmin) / dx).astype(int), + lambda _: jnp.round((old_x - xmin) / dx).astype(int), + operand=None, + ) + old_y0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((old_y - ymin) / dy).astype(int), + lambda _: jnp.round((old_y - ymin) / dy).astype(int), + operand=None, + ) + old_z0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((old_z - zmin) / dz).astype(int), + lambda _: jnp.round((old_z - zmin) / dz).astype(int), + operand=None, + ) + + deltax = (x - xmin) - x0 * dx + deltay = (y - ymin) - y0 * dy + deltaz = (z - zmin) - z0 * dz + old_deltax = (old_x - xmin) - old_x0 * dx + old_deltay = (old_y - ymin) - old_y0 * dy + old_deltaz = (old_z - zmin) - old_z0 * dz + + shift_x = x0 - old_x0 + shift_y = y0 - old_y0 + shift_z = z0 - old_z0 + + x0 = wrap_around(x0, Nx) + y0 = wrap_around(y0, Ny) + z0 = wrap_around(z0, Nz) + x1 = wrap_around(x0 + 1, Nx) + y1 = wrap_around(y0 + 1, Ny) + z1 = wrap_around(z0 + 1, Nz) + x2 = wrap_around(x0 + 2, Nx) + y2 = wrap_around(y0 + 2, Ny) + z2 = wrap_around(z0 + 2, Nz) + x_minus1 = wrap_around(x0 - 1, Nx) + y_minus1 = wrap_around(y0 - 1, Ny) + z_minus1 = wrap_around(z0 - 1, Nz) + x_minus2 = wrap_around(x0 - 2, Nx) + y_minus2 = wrap_around(y0 - 2, Ny) + z_minus2 = wrap_around(z0 - 2, Nz) + + xpts = [x_minus2, x_minus1, x0, x1, x2] + ypts = [y_minus2, y_minus1, y0, y1, y2] + zpts = [z_minus2, z_minus1, z0, z1, z2] + + xw, yw, zw = jax.lax.cond( + shape_factor == 1, + lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), + lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), + operand=None, + ) + oxw, oyw, ozw = jax.lax.cond( + shape_factor == 1, + lambda _: get_first_order_weights(old_deltax, old_deltay, old_deltaz, dx, dy, dz), + lambda _: get_second_order_weights(old_deltax, old_deltay, old_deltaz, dx, dy, dz), + operand=None, + ) + + tmp = jnp.zeros_like(xw[0]) + + xw = [tmp, xw[0], xw[1], xw[2], tmp] + yw = [tmp, yw[0], yw[1], yw[2], tmp] + zw = [tmp, zw[0], zw[1], zw[2], tmp] + + oxw = [tmp, oxw[0], oxw[1], oxw[2], tmp] + oyw = [tmp, oyw[0], oyw[1], oyw[2], tmp] + ozw = [tmp, ozw[0], ozw[1], ozw[2], tmp] + + oxw = _roll_old_weights_to_new_frame(oxw, shift_x) + oyw = _roll_old_weights_to_new_frame(oyw, shift_y) + ozw = _roll_old_weights_to_new_frame(ozw, shift_z) + + if x_active and y_active and z_active: + Wx_, Wy_, Wz_ = get_3D_esirkepov_weights(xw, yw, zw, oxw, oyw, ozw, N_particles) + elif (x_active and y_active and (not z_active)) or (x_active and z_active and (not y_active)) or ( + y_active and z_active and (not x_active) + ): + null_dim = lax.cond( + not x_active, + lambda _: 0, + lambda _: lax.cond( + not y_active, + lambda _: 1, + lambda _: 2, + operand=None, + ), + operand=None, + ) + + Wx_, Wy_, Wz_ = get_2D_esirkepov_weights( + xw, yw, zw, oxw, oyw, ozw, N_particles, null_dim=null_dim + ) + elif x_active and (not y_active) and (not z_active): + Wx_, Wy_, Wz_ = get_1D_esirkepov_weights(xw, yw, zw, oxw, oyw, ozw, N_particles, dim=0) + elif y_active and (not x_active) and (not z_active): + Wx_, Wy_, Wz_ = get_1D_esirkepov_weights(xw, yw, zw, oxw, oyw, ozw, N_particles, dim=1) + elif z_active and (not x_active) and (not y_active): + Wx_, Wy_, Wz_ = get_1D_esirkepov_weights(xw, yw, zw, oxw, oyw, ozw, N_particles, dim=2) + + dJx = jax.lax.cond( + x_active, + lambda _: -(q / (dy * dz)) / dt * jnp.ones(N_particles), + lambda _: q * vx / (dx * dy * dz) * jnp.ones(N_particles), + operand=None, + ) + + dJy = jax.lax.cond( + y_active, + lambda _: -(q / (dx * dz)) / dt * jnp.ones(N_particles), + lambda _: q * vy / (dx * dy * dz) * jnp.ones(N_particles), + operand=None, + ) + + dJz = jax.lax.cond( + z_active, + lambda _: -(q / (dx * dy)) / dt * jnp.ones(N_particles), + lambda _: q * vz / (dx * dy * dz) * jnp.ones(N_particles), + operand=None, + ) + + Fx = dJx * Wx_ + Fy = dJy * Wy_ + Fz = dJz * Wz_ + + Jx_loc = jnp.zeros_like(Fx) + Jy_loc = jnp.zeros_like(Fy) + Jz_loc = jnp.zeros_like(Fz) + + Jx_loc = jnp.cumsum(Fx, axis=0) + Jy_loc = jnp.cumsum(Fy, axis=1) + Jz_loc = jnp.cumsum(Fz, axis=2) + + if x_active: + for i in range(5): + for j in range(5): + for k in range(5): + Jx = Jx.at[xpts[i], ypts[j], zpts[k]].add(Jx_loc[i, j, k, :], mode="drop") + else: + for i in range(5): + for j in range(5): + for k in range(5): + Jx = Jx.at[xpts[i], ypts[j], zpts[k]].add(Fx[i, j, k, :], mode="drop") + + if y_active: + for i in range(5): + for j in range(5): + for k in range(5): + Jy = Jy.at[xpts[i], ypts[j], zpts[k]].add(Jy_loc[i, j, k, :], mode="drop") + else: + for i in range(5): + for j in range(5): + for k in range(5): + Jy = Jy.at[xpts[i], ypts[j], zpts[k]].add(Fy[i, j, k, :], mode="drop") + + if z_active: + for i in range(5): + for j in range(5): + for k in range(5): + Jz = Jz.at[xpts[i], ypts[j], zpts[k]].add(Jz_loc[i, j, k, :], mode="drop") + else: + for i in range(5): + for j in range(5): + for k in range(5): + Jz = Jz.at[xpts[i], ypts[j], zpts[k]].add(Fz[i, j, k, :], mode="drop") + + return (Jx, Jy, Jz) + + +def get_3D_esirkepov_weights( + x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights, N_particles, null_dim=None +): + Wx_ = jnp.zeros((len(x_weights), len(y_weights), len(z_weights), N_particles)) + Wy_ = jnp.zeros_like(Wx_) + Wz_ = jnp.zeros_like(Wx_) + + for i in range(len(x_weights)): + for j in range(len(y_weights)): + for k in range(len(z_weights)): + Wx_ = Wx_.at[i, j, k, :].set( + (x_weights[i] - old_x_weights[i]) + * ( + 1 / 3 * (y_weights[j] * z_weights[k] + old_y_weights[j] * old_z_weights[k]) + + 1 / 6 * (y_weights[j] * old_z_weights[k] + old_y_weights[j] * z_weights[k]) + ) + ) + + Wy_ = Wy_.at[i, j, k, :].set( + (y_weights[j] - old_y_weights[j]) + * ( + 1 / 3 * (x_weights[i] * z_weights[k] + old_x_weights[i] * old_z_weights[k]) + + 1 / 6 * (x_weights[i] * old_z_weights[k] + old_x_weights[i] * z_weights[k]) + ) + ) + + Wz_ = Wz_.at[i, j, k, :].set( + (z_weights[k] - old_z_weights[k]) + * ( + 1 / 3 * (x_weights[i] * y_weights[j] + old_x_weights[i] * old_y_weights[j]) + + 1 / 6 * (x_weights[i] * old_y_weights[j] + old_x_weights[i] * y_weights[j]) + ) + ) + + return Wx_, Wy_, Wz_ + + +def get_2D_esirkepov_weights( + x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights, N_particles, null_dim=2 +): + d_Sx = [x_weights[i] - old_x_weights[i] for i in range(len(x_weights))] + d_Sy = [y_weights[i] - old_y_weights[i] for i in range(len(y_weights))] + d_Sz = [z_weights[i] - old_z_weights[i] for i in range(len(z_weights))] + + Wx_ = jnp.zeros((len(x_weights), len(y_weights), len(z_weights), N_particles)) + Wy_ = jnp.zeros_like(Wx_) + Wz_ = jnp.zeros_like(Wx_) + + def xy_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): + for i in range(len(x_weights)): + for j in range(len(y_weights)): + Wx_ = Wx_.at[i, j, 2, :].set(1 / 2 * d_Sx[i] * (y_weights[j] + old_y_weights[j])) + Wy_ = Wy_.at[i, j, 2, :].set(1 / 2 * d_Sy[j] * (x_weights[i] + old_x_weights[i])) + Wz_ = Wz_.at[i, j, 2, :].set( + 1 / 3 * (x_weights[i] * y_weights[j] + old_x_weights[i] * old_y_weights[j]) + + 1 / 6 * (x_weights[i] * old_y_weights[j] + old_x_weights[i] * y_weights[j]) + ) + return Wx_, Wy_, Wz_ + + def xz_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): + for i in range(len(x_weights)): + for k in range(len(z_weights)): + Wx_ = Wx_.at[i, 2, k, :].set(1 / 2 * d_Sx[i] * (z_weights[k] + old_z_weights[k])) + Wy_ = Wy_.at[i, 2, k, :].set( + 1 / 3 * (x_weights[i] * z_weights[k] + old_x_weights[i] * old_z_weights[k]) + + 1 / 6 * (x_weights[i] * old_z_weights[k] + old_x_weights[i] * z_weights[k]) + ) + Wz_ = Wz_.at[i, 2, k, :].set(1 / 2 * d_Sz[k] * (x_weights[i] + old_x_weights[i])) + return Wx_, Wy_, Wz_ + + def yz_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): + for j in range(len(y_weights)): + for k in range(len(z_weights)): + Wx_ = Wx_.at[2, j, k, :].set( + 1 / 3 * (y_weights[j] * z_weights[k] + old_y_weights[j] * old_z_weights[k]) + + 1 / 6 * (y_weights[j] * old_z_weights[k] + old_y_weights[j] * z_weights[k]) + ) + Wy_ = Wy_.at[2, j, k, :].set(1 / 2 * d_Sy[j] * (z_weights[k] + old_z_weights[k])) + Wz_ = Wz_.at[2, j, k, :].set(1 / 2 * d_Sz[k] * (y_weights[j] + old_y_weights[j])) + return Wx_, Wy_, Wz_ + + Wx_, Wy_, Wz_ = lax.cond( + null_dim == 0, + lambda _: yz_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), + lambda _: lax.cond( + null_dim == 1, + lambda _: xz_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), + lambda _: xy_plane(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), + operand=None, + ), + operand=None, + ) + + return Wx_, Wy_, Wz_ + + +def get_1D_esirkepov_weights( + x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights, N_particles, dim=0 +): + Wx_ = jnp.zeros((len(x_weights), len(y_weights), len(z_weights), N_particles)) + Wy_ = jnp.zeros_like(Wx_) + Wz_ = jnp.zeros_like(Wx_) + + def x_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): + for i in range(len(x_weights)): + Wx_ = Wx_.at[i, 2, 2, :].set((x_weights[i] - old_x_weights[i])) + Wy_ = Wy_.at[i, 2, 2, :].set((x_weights[i] + old_x_weights[i]) / 2) + Wz_ = Wz_.at[i, 2, 2, :].set((x_weights[i] + old_x_weights[i]) / 2) + return Wx_, Wy_, Wz_ + + def y_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): + for j in range(len(y_weights)): + Wy_ = Wy_.at[2, j, 2, :].set((y_weights[j] - old_y_weights[j])) + Wx_ = Wx_.at[2, j, 2, :].set((y_weights[j] + old_y_weights[j]) / 2) + Wz_ = Wz_.at[2, j, 2, :].set((y_weights[j] + old_y_weights[j]) / 2) + return Wx_, Wy_, Wz_ + + def z_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights): + for k in range(len(z_weights)): + Wz_ = Wz_.at[2, 2, k, :].set((z_weights[k] - old_z_weights[k])) + Wx_ = Wx_.at[2, 2, k, :].set((z_weights[k] + old_z_weights[k]) / 2) + Wy_ = Wy_.at[2, 2, k, :].set((z_weights[k] + old_z_weights[k]) / 2) + return Wx_, Wy_, Wz_ + + Wx_, Wy_, Wz_ = lax.cond( + dim == 0, + lambda _: x_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), + lambda _: lax.cond( + dim == 1, + lambda _: y_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), + lambda _: z_active(Wx_, Wy_, Wz_, x_weights, y_weights, z_weights, old_x_weights, old_y_weights, old_z_weights), + operand=None, + ), + operand=None, + ) + + return Wx_, Wy_, Wz_ diff --git a/PyPIC3D/deposition/J_from_rhov.py b/PyPIC3D/deposition/J_from_rhov.py new file mode 100644 index 0000000..c1d977b --- /dev/null +++ b/PyPIC3D/deposition/J_from_rhov.py @@ -0,0 +1,192 @@ +import jax +from jax import jit +import jax.numpy as jnp +from functools import partial +from jax import lax + +from PyPIC3D.utils import digital_filter, wrap_around, bilinear_filter +from PyPIC3D.deposition.shapes import get_first_order_weights, get_second_order_weights + + +@partial(jit, static_argnames=("filter",)) +def J_from_rhov(particles, J, constants, world, grid=None, filter="bilinear"): + """Compute current density (Jx,Jy,Jz) by depositing particle velocities.""" + + if grid is None: + grid = world["grids"]["center"] + + dx = world["dx"] + dy = world["dy"] + dz = world["dz"] + Nx = world["Nx"] + Ny = world["Ny"] + Nz = world["Nz"] + + Jx, Jy, Jz = J + x_active = Jx.shape[0] != 1 + y_active = Jx.shape[1] != 1 + z_active = Jx.shape[2] != 1 + + Jx = Jx.at[:, :, :].set(0) + Jy = Jy.at[:, :, :].set(0) + Jz = Jz.at[:, :, :].set(0) + + for species in particles: + shape_factor = species.get_shape() + charge = species.get_charge() + dq = charge / (dx * dy * dz) + + x, y, z = species.get_forward_position() + vx, vy, vz = species.get_velocity() + + x = x - vx * world["dt"] / 2 + y = y - vy * world["dt"] / 2 + z = z - vz * world["dt"] / 2 + + x0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((x - grid[0][0]) / dx).astype(int), + lambda _: jnp.round((x - grid[0][0]) / dx).astype(int), + operand=None, + ) + y0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((y - grid[1][0]) / dy).astype(int), + lambda _: jnp.round((y - grid[1][0]) / dy).astype(int), + operand=None, + ) + z0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((z - grid[2][0]) / dz).astype(int), + lambda _: jnp.round((z - grid[2][0]) / dz).astype(int), + operand=None, + ) + + deltax_node = (x - grid[0][0]) - (x0 * dx) + deltay_node = (y - grid[1][0]) - (y0 * dy) + deltaz_node = (z - grid[2][0]) - (z0 * dz) + + deltax_face = (x - grid[0][0]) - (x0 + 0.5) * dx + deltay_face = (y - grid[1][0]) - (y0 + 0.5) * dy + deltaz_face = (z - grid[2][0]) - (z0 + 0.5) * dz + + x0 = wrap_around(x0, Nx) + y0 = wrap_around(y0, Ny) + z0 = wrap_around(z0, Nz) + x1 = wrap_around(x0 + 1, Nx) + y1 = wrap_around(y0 + 1, Ny) + z1 = wrap_around(z0 + 1, Nz) + x_minus1 = x0 - 1 + y_minus1 = y0 - 1 + z_minus1 = z0 - 1 + + xpts = [x_minus1, x0, x1] + ypts = [y_minus1, y0, y1] + zpts = [z_minus1, z0, z1] + + x_weights_node, y_weights_node, z_weights_node = jax.lax.cond( + shape_factor == 1, + lambda _: get_first_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz), + lambda _: get_second_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz), + operand=None, + ) + + x_weights_face, y_weights_face, z_weights_face = jax.lax.cond( + shape_factor == 1, + lambda _: get_first_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz), + lambda _: get_second_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz), + operand=None, + ) + + xpts = jnp.asarray(xpts) + ypts = jnp.asarray(ypts) + zpts = jnp.asarray(zpts) + + x_weights_face = jnp.asarray(x_weights_face) + y_weights_face = jnp.asarray(y_weights_face) + z_weights_face = jnp.asarray(z_weights_face) + + x_weights_node = jnp.asarray(x_weights_node) + y_weights_node = jnp.asarray(y_weights_node) + z_weights_node = jnp.asarray(z_weights_node) + + if x_active: + xpts_eff = xpts + x_weights_node_eff = x_weights_node + x_weights_face_eff = x_weights_face + else: + xpts_eff = jnp.zeros((1, xpts.shape[1]), dtype=xpts.dtype) + x_weights_node_eff = jnp.sum(x_weights_node, axis=0, keepdims=True) + x_weights_face_eff = jnp.sum(x_weights_face, axis=0, keepdims=True) + + if y_active: + ypts_eff = ypts + y_weights_node_eff = y_weights_node + y_weights_face_eff = y_weights_face + else: + ypts_eff = jnp.zeros((1, ypts.shape[1]), dtype=ypts.dtype) + y_weights_node_eff = jnp.sum(y_weights_node, axis=0, keepdims=True) + y_weights_face_eff = jnp.sum(y_weights_face, axis=0, keepdims=True) + + if z_active: + zpts_eff = zpts + z_weights_node_eff = z_weights_node + z_weights_face_eff = z_weights_face + else: + zpts_eff = jnp.zeros((1, zpts.shape[1]), dtype=zpts.dtype) + z_weights_node_eff = jnp.sum(z_weights_node, axis=0, keepdims=True) + z_weights_face_eff = jnp.sum(z_weights_face, axis=0, keepdims=True) + + ii, jj, kk = jnp.meshgrid( + jnp.arange(xpts_eff.shape[0]), + jnp.arange(ypts_eff.shape[0]), + jnp.arange(zpts_eff.shape[0]), + indexing="ij", + ) + combos = jnp.stack([ii.ravel(), jj.ravel(), kk.ravel()], axis=1) + + def idx_and_dJ_values(idx): + i, j, k = idx + ix = xpts_eff[i, ...] + iy = ypts_eff[j, ...] + iz = zpts_eff[k, ...] + valx = ( + (dq * vx) + * x_weights_face_eff[i, ...] + * y_weights_node_eff[j, ...] + * z_weights_node_eff[k, ...] + ) + valy = ( + (dq * vy) + * x_weights_node_eff[i, ...] + * y_weights_face_eff[j, ...] + * z_weights_node_eff[k, ...] + ) + valz = ( + (dq * vz) + * x_weights_node_eff[i, ...] + * y_weights_node_eff[j, ...] + * z_weights_face_eff[k, ...] + ) + return ix, iy, iz, valx, valy, valz + + ix, iy, iz, valx, valy, valz = jax.vmap(idx_and_dJ_values)(combos) + + Jx = Jx.at[(ix, iy, iz)].add(valx, mode="drop") + Jy = Jy.at[(ix, iy, iz)].add(valy, mode="drop") + Jz = Jz.at[(ix, iy, iz)].add(valz, mode="drop") + + def filter_func(J_, filter): + J_ = jax.lax.cond( + filter == "bilinear", + lambda J_: bilinear_filter(J_), + lambda J_: J_, + operand=J_, + ) + return J_ + + Jx = filter_func(Jx, filter) + Jy = filter_func(Jy, filter) + Jz = filter_func(Jz, filter) + + return (Jx, Jy, Jz) \ No newline at end of file diff --git a/PyPIC3D/deposition/rho.py b/PyPIC3D/deposition/rho.py new file mode 100644 index 0000000..cb364b7 --- /dev/null +++ b/PyPIC3D/deposition/rho.py @@ -0,0 +1,84 @@ +import jax +from jax import jit +import jax.numpy as jnp + +from PyPIC3D.utils import digital_filter, wrap_around +from PyPIC3D.deposition.shapes import get_first_order_weights, get_second_order_weights + + +@jit +def compute_rho(particles, rho, world, constants): + """Compute the charge density (rho) on the vertex grid.""" + dx = world["dx"] + dy = world["dy"] + dz = world["dz"] + grid = world["grids"]["vertex"] + Nx, Ny, Nz = rho.shape + + rho = jnp.zeros_like(rho) + + for species in particles: + shape_factor = species.get_shape() + q = species.get_charge() + dq = q / dx / dy / dz + x, y, z = species.get_position() + + x0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((x - grid[0][0]) / dx).astype(int), + lambda _: jnp.round((x - grid[0][0]) / dx).astype(int), + operand=None, + ) + + y0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((y - grid[1][0]) / dy).astype(int), + lambda _: jnp.round((y - grid[1][0]) / dy).astype(int), + operand=None, + ) + + z0 = jax.lax.cond( + shape_factor == 1, + lambda _: jnp.floor((z - grid[2][0]) / dz).astype(int), + lambda _: jnp.round((z - grid[2][0]) / dz).astype(int), + operand=None, + ) + + deltax = x - (x0 * dx + grid[0][0]) + deltay = y - (y0 * dy + grid[1][0]) + deltaz = z - (z0 * dz + grid[2][0]) + + x0 = wrap_around(x0, Nx) + y0 = wrap_around(y0, Ny) + z0 = wrap_around(z0, Nz) + + x1 = wrap_around(x0 + 1, Nx) + y1 = wrap_around(y0 + 1, Ny) + z1 = wrap_around(z0 + 1, Nz) + + x_minus1 = x0 - 1 + y_minus1 = y0 - 1 + z_minus1 = z0 - 1 + + xpts = [x_minus1, x0, x1] + ypts = [y_minus1, y0, y1] + zpts = [z_minus1, z0, z1] + + x_weights, y_weights, z_weights = jax.lax.cond( + shape_factor == 1, + lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), + lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), + operand=None, + ) + + for i in range(3): + for j in range(3): + for k in range(3): + rho = rho.at[xpts[i], ypts[j], zpts[k]].add( + dq * x_weights[i] * y_weights[j] * z_weights[k], mode="drop" + ) + + alpha = constants["alpha"] + rho = digital_filter(rho, alpha) + + return rho \ No newline at end of file diff --git a/PyPIC3D/deposition/shapes.py b/PyPIC3D/deposition/shapes.py new file mode 100644 index 0000000..5f10353 --- /dev/null +++ b/PyPIC3D/deposition/shapes.py @@ -0,0 +1,54 @@ +from jax import jit +import jax.numpy as jnp + + +@jit +def get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz): + """Calculate the second-order weights for particle deposition. + + Args: + deltax, deltay, deltaz: Particle offsets from the nearest grid point. + dx, dy, dz: Grid spacings. + + Returns: + (x_weights, y_weights, z_weights): each a length-3 list of weights. + """ + Sx0 = (3 / 4) - (deltax / dx) ** 2 + Sy0 = (3 / 4) - (deltay / dy) ** 2 + Sz0 = (3 / 4) - (deltaz / dz) ** 2 + + Sx1 = (1 / 2) * ((1 / 2) + (deltax / dx)) ** 2 + Sy1 = (1 / 2) * ((1 / 2) + (deltay / dy)) ** 2 + Sz1 = (1 / 2) * ((1 / 2) + (deltaz / dz)) ** 2 + + Sx_minus1 = (1 / 2) * ((1 / 2) - (deltax / dx)) ** 2 + Sy_minus1 = (1 / 2) * ((1 / 2) - (deltay / dy)) ** 2 + Sz_minus1 = (1 / 2) * ((1 / 2) - (deltaz / dz)) ** 2 + + x_weights = [Sx_minus1, Sx0, Sx1] + y_weights = [Sy_minus1, Sy0, Sy1] + z_weights = [Sz_minus1, Sz0, Sz1] + + return x_weights, y_weights, z_weights + + +@jit +def get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz): + """Calculate the first-order (CIC) weights for particle deposition.""" + Sx0 = jnp.asarray(1 - deltax / dx) + Sy0 = jnp.asarray(1 - deltay / dy) + Sz0 = jnp.asarray(1 - deltaz / dz) + + Sx1 = jnp.asarray(deltax / dx) + Sy1 = jnp.asarray(deltay / dy) + Sz1 = jnp.asarray(deltaz / dz) + + Sx_minus1 = jnp.zeros_like(Sx0) + Sy_minus1 = jnp.zeros_like(Sy0) + Sz_minus1 = jnp.zeros_like(Sz0) + + x_weights = [Sx_minus1, Sx0, Sx1] + y_weights = [Sy_minus1, Sy0, Sy1] + z_weights = [Sz_minus1, Sz0, Sz1] + + return x_weights, y_weights, z_weights diff --git a/PyPIC3D/diagnostics/fluid_quantities.py b/PyPIC3D/diagnostics/fluid_quantities.py index 781f275..1d25a26 100644 --- a/PyPIC3D/diagnostics/fluid_quantities.py +++ b/PyPIC3D/diagnostics/fluid_quantities.py @@ -1,5 +1,5 @@ -from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights +from PyPIC3D.deposition.shapes import get_first_order_weights, get_second_order_weights from PyPIC3D.utils import wrap_around import jax diff --git a/PyPIC3D/initialization.py b/PyPIC3D/initialization.py index c07e6f1..4833d30 100644 --- a/PyPIC3D/initialization.py +++ b/PyPIC3D/initialization.py @@ -9,7 +9,8 @@ import jax.numpy as jnp #from memory_profiler import profile -from PyPIC3D.particle import ( + +from PyPIC3D.particles.particle_initialization import ( load_particles_from_toml ) @@ -38,7 +39,7 @@ write_openpmd_initial_particles, write_openpmd_initial_fields ) -from PyPIC3D.flat_particles import ( +from PyPIC3D.particles.flat_particles import ( to_flat_particles, check_flat_compat ) @@ -47,9 +48,8 @@ time_loop_electrodynamic, time_loop_electrostatic, time_loop_vector_potential ) -from PyPIC3D.J import ( - J_from_rhov, Esirkepov_current -) +from PyPIC3D.deposition.Esirkepov import Esirkepov_current +from PyPIC3D.deposition.J_from_rhov import J_from_rhov from PyPIC3D.solvers.vector_potential import initialize_vector_potential diff --git a/PyPIC3D/particle.py b/PyPIC3D/particle.py deleted file mode 100644 index 6c43860..0000000 --- a/PyPIC3D/particle.py +++ /dev/null @@ -1,750 +0,0 @@ -import numpy as np -import jax -from jax import jit -from functools import partial -import jax.numpy as jnp -from jax.tree_util import register_pytree_node_class - -from PyPIC3D.utils import vth_to_T, plasma_frequency, debye_length, T_to_vth - -def grab_particle_keys(config): - """ - Extracts and returns a list of keys from the given configuration dictionary - that start with the prefix 'particle'. - - Args: - config (dict): A dictionary containing configuration keys and values. - - Returns: - list: A list of keys from the configuration dictionary that start with 'particle'. - """ - particle_keys = [] - for key in config.keys(): - if key[:8] == 'particle': - particle_keys.append(key) - return particle_keys - -def load_particles_from_toml(config, simulation_parameters, world, constants): - """ - Load particle data from a TOML file and initialize particle species. - Args: - config (dict): Dictionary containing configuration keys and values. - simulation_parameters (dict): Dictionary containing simulation parameters. - world (dict): Dictionary containing world parameters such as 'x_wind', 'y_wind', 'z_wind', 'dx', 'dy', 'dz'. - constants (dict): Dictionary containing constants such as 'kb'. - Returns: - list: A list of particle_species objects initialized with the data from the TOML file. - - The function reads particle configuration from the provided TOML file, initializes particle properties such as - position, velocity, charge, mass, and temperature. It also handles loading initial positions and velocities from - external sources if specified in the TOML file. The particles are then appended to a list and returned. - """ - - x_wind = world['x_wind'] - y_wind = world['y_wind'] - z_wind = world['z_wind'] - # get the world dimensions - Nx = world['Nx'] - Ny = world['Ny'] - Nz = world['Nz'] - # get the number of grid points in each dimension - dx = world['dx'] - dy = world['dy'] - dz = world['dz'] - dt = world['dt'] - # get spatial and temporal resolution - kb = constants['kb'] - eps = constants['eps'] - C = constants['C'] - # get the constants - - i = 0 - # initialize the random number generator key - # this is used to generate random numbers for the initial positions and velocities of the particles - # it is incremented by 3 for each particle species to ensure different random numbers for each species - particles = [] - particle_keys = grab_particle_keys(config) - # get the particle keys from the config dictionary - - weight = compute_macroparticle_weight(config, particle_keys, simulation_parameters, world, constants) - # scale the particle weight by the debye length to prevent numerical heating - # this is done by computing the total debye length of the plasma and scaling the particle weight accordingly - - - for toml_key in particle_keys: - key1, key2, key3 = jax.random.key(i), jax.random.key(i+1), jax.random.key(i+2) - i += 3 - # build the particle random number generator keys - particle_name = config[toml_key]['name'] - print(f"\nInitializing particle species: {particle_name}") - charge=config[toml_key]['charge'] - mass=config[toml_key]['mass'] - - if 'N_particles' in config[toml_key]: - N_particles=config[toml_key]['N_particles'] - N_per_cell = N_particles / (world['Nx'] * world['Ny'] * world['Nz']) - elif "N_per_cell" in config[toml_key]: - N_per_cell = config[toml_key]["N_per_cell"] - N_particles = int(N_per_cell * world['Nx'] * world['Ny'] * world['Nz']) - # set the number of particles in the species - - if 'temperature' in config[toml_key]: - T=config[toml_key]['temperature'] - vth = T_to_vth(T, mass, kb) - elif 'vth' in config[toml_key]: - vth = config[toml_key]['vth'] - T = vth_to_T(vth, mass, kb) - else: - T = 1.0 - vth = T_to_vth(T, mass, kb) - # set the temperature of the particle species - - Tx = read_value('Tx', toml_key, config, T) - Ty = read_value('Ty', toml_key, config, T) - Tz = read_value('Tz', toml_key, config, T) - # set the temperature of the particle species in each dimension - - xmin = read_value('xmin', toml_key, config, -x_wind / 2) - xmax = read_value('xmax', toml_key, config, x_wind / 2) - ymin = read_value('ymin', toml_key, config, -y_wind / 2) - ymax = read_value('ymax', toml_key, config, y_wind / 2) - zmin = read_value('zmin', toml_key, config, -z_wind / 2) - zmax = read_value('zmax', toml_key, config, z_wind / 2) - # set the bounds for the particle species - x, y, z, vx, vy, vz = initial_particles(N_per_cell, N_particles, xmin, xmax, ymin, ymax, zmin, zmax, mass, Tx, Ty, Tz, kb, key1, key2, key3) - # initialize the positions and velocities of the particles - - x_bc = 'periodic' - if 'x_bc' in config[toml_key]: - assert config[toml_key]['x_bc'] in ['periodic', 'reflecting'], f"Invalid x boundary condition: {config[toml_key]['x_bc']}" - x_bc = config[toml_key]['x_bc'] - y_bc = 'periodic' - if 'y_bc' in config[toml_key]: - assert config[toml_key]['y_bc'] in ['periodic', 'reflecting'], f"Invalid y boundary condition: {config[toml_key]['y_bc']}" - y_bc = config[toml_key]['y_bc'] - z_bc = 'periodic' - if 'z_bc' in config[toml_key]: - assert config[toml_key]['z_bc'] in ['periodic', 'reflecting'], f"Invalid z boundary condition: {config[toml_key]['z_bc']}" - z_bc = config[toml_key]['z_bc'] - # set the boundary conditions for the particle species - - x = load_initial_positions('initial_x', config, toml_key, x, N_particles, dx, Nx, key1) - y = load_initial_positions('initial_y', config, toml_key, y, N_particles, dy, Ny, key2) - z = load_initial_positions('initial_z', config, toml_key, z, N_particles, dz, Nz, key3) - # load the initial positions of the particles from the toml file, if specified - # otherwise, use the initialized positions - vx = load_initial_velocities('initial_vx', config, toml_key, vx, N_particles) - vy = load_initial_velocities('initial_vy', config, toml_key, vy, N_particles) - vz = load_initial_velocities('initial_vz', config, toml_key, vz, N_particles) - # load the initial velocities of the particles from the toml file, if specified - # otherwise, use the initialized velocities - - # Calculate the temperature from the velocities if not explicitly set - if 'temperature' not in config[toml_key]: - T = (mass / (3 * kb * N_particles)) * ( - jnp.sum(vx ** 2) + jnp.sum(vy ** 2) + jnp.sum(vz ** 2) - ) - - if "weight" in config[toml_key]: - weight = config[toml_key]['weight'] - # set the weight of the particles, if specified in the toml file - elif 'ds_per_debye' in config[toml_key]: # account for anisotropic grid spacings via ds2 - ds_per_debye = config[toml_key]['ds_per_debye'] - - ds2 = 0 - for d in [dx, dy, dz]: - if d != 1: - ds2 += d**2 - - if ds2 == 0: - raise ValueError( - "Invalid configuration for 'ds_per_debye': at least one of dx, dy, dz must differ from 1." - ) - weight = (x_wind*y_wind*z_wind * eps * kb * T) / (N_particles * charge**2 * ds_per_debye**2 * ds2) - # weight the particles by the debye length and the number of particles - - update_pos = read_value('update_pos', toml_key, config, True) - update_v = read_value('update_v', toml_key, config, True) - update_vx = read_value('update_vx', toml_key, config, True) - update_vy = read_value('update_vy', toml_key, config, True) - update_vz = read_value('update_vz', toml_key, config, True) - update_x = read_value('update_x', toml_key, config, True) - update_y = read_value('update_y', toml_key, config, True) - update_z = read_value('update_z', toml_key, config, True) - - particle = particle_species( - name=particle_name, - N_particles=N_particles, - charge=charge, - mass=mass, - T=T, - x1=x, - x2=y, - x3=z, - v1=vx, - v2=vy, - v3=vz, - xwind=x_wind, - ywind=y_wind, - zwind=z_wind, - dx=dx, - dy=dy, - dz=dz, - weight=weight, - x_bc=x_bc, - y_bc=y_bc, - z_bc=z_bc, - update_vx=update_vx, - update_vy=update_vy, - update_vz=update_vz, - update_x=update_x, - update_y=update_y, - update_z=update_z, - update_pos=update_pos, - update_v=update_v, - shape=simulation_parameters['shape_factor'], - dt=dt - ) - particles.append(particle) - - pf = plasma_frequency(particle, world, constants) - dl = debye_length(particle, world, constants) - print(f"Number of particles: {N_particles}") - print(f"Number of particles per cell: {N_per_cell}") - print(f"x, y, z boundary conditions: {x_bc}, {y_bc}, {z_bc}") - print(f"Charge: {charge}") - print(f"Mass: {mass}") - print(f"Temperature: {T}") - print(f"Thermal Velocity: {vth}") - print(f"Particle Kinetic Energy: {particle.kinetic_energy()}") - print(f"Particle Species Plasma Frequency: {pf}") - print(f"Time Steps Per Plasma Period: {(1 / (dt * pf) )}") - print(f"Particle Species Debye Length: {dl}") - print(f"Particle Weight: {weight}") - print(f"Particle Species Scaled Charge: {particle.get_charge()}") - print(f"Particle Species Scaled Mass: {particle.get_mass()}") - - return particles - - -def read_value(param, key, config, default_value): - """ - Reads a value from a nested dictionary structure and returns it if it exists; - otherwise, returns a default value. - - Args: - param (str): The parameter name to look for in the nested dictionary. - key (str): The key in the outer dictionary where the nested dictionary is located. - config (dict): The configuration dictionary containing nested dictionaries. - default_value (Any): The value to return if the parameter is not found. - - Returns: - Any: The value associated with `param` in `config[key]` if it exists, - otherwise `default_value`. - """ - if param in config[key]: - print(f'Reading user defined {param}') - return config[key][param] - else: - return default_value - - -def load_initial_positions(param, config, key, default, N_particles, ds, ns, key1): - """ - Load initial positions for particles based on the provided configuration. - - This function checks if a specific parameter exists in the configuration - under the given key. If the parameter exists and is a string, it loads - the data from an external source. If the parameter exists and is a number, - it creates an array filled with that value. If the parameter does not - exist, it returns the default value. - - Args: - param (str): The name of the parameter to look for in the configuration. - config (dict): The configuration dictionary containing parameters and values. - key (str): The key in the configuration dictionary under which the parameter is stored. - default (Any): The default value to return if the parameter is not found. - N_particles (int): The number of particles, used to determine the size of the array. - ds (float): The spatial resolution, used to add noise to the particle positions. - key1 (jax.random.PRNGKey): The random key for generating random numbers. - - Returns: - jax.numpy.ndarray or Any: An array of particle positions if the parameter is found, - or the default value if the parameter is not found. - """ - if param in config[key]: - if isinstance(config[key][param], str): - print(f"Loading {param} from external source: {config[key][param]}") - return jnp.load(config[key][param]) - # if the value is a string, load it from an external source - else: - #return jnp.full(N_particles, config[key][param]) - val = config[key][param] - if ns == 1: - return val * jnp.ones(N_particles) - else: - return jax.random.uniform(key1, shape=(N_particles,), minval=val-(ds/2), maxval=val+(ds/2)) - # if the value is a number, fill the array with that value with some noise in the subcell position - else: - return default - # return the default value if the parameter is not found - -def load_initial_velocities(param, config, key, default, N_particles): - """ - Load initial velocities for particles based on the provided configuration. - - This function checks if a specific parameter exists in the configuration - dictionary under the given key. Depending on the type of the parameter's - value, it either loads data from an external source or initializes an array - with a specified value. If the parameter is not found, a default value is returned. - - Args: - param (str): The name of the parameter to look for in the configuration. - config (dict): A dictionary containing configuration data. - key (str): The key in the configuration dictionary where the parameter is located. - default (float or jnp.ndarray): The default value to return if the parameter is not found. - N_particles (int): The number of particles, used to determine the size of the array. - - Returns: - jnp.ndarray: An array of initial velocities for the particles. If the parameter - is a string, the array is loaded from an external source. If the parameter is - a number, the array is filled with that value plus the default. If the parameter - is not found, the default value is returned. - """ - if param in config[key]: - if isinstance(config[key][param], str): - print(f"Loading {param} from external source: {config[key][param]}") - return jnp.load(config[key][param]) - # if the value is a string, load it from an external source - else: - return jnp.full(N_particles, config[key][param]) + default - # if the value is a number, fill the array with that value - else: - return default - # return the default value if the parameter is not found - -def compute_macroparticle_weight(config, particle_keys, simulation_parameters, world, constants): - - x_wind = world['x_wind'] - y_wind = world['y_wind'] - z_wind = world['z_wind'] - # get the world dimensions - dx = world['dx'] - dy = world['dy'] - dz = world['dz'] - # get the world resolution - kb = constants['kb'] - eps = constants['eps'] - # get the constants - - if simulation_parameters['ds_per_debye']: # scale the particle weight by the debye length to prevent numerical heating - ds_per_debye = simulation_parameters['ds_per_debye'] - # get the number of grid points per debye length - inverse_total_debye = 0 - - for toml_key in particle_keys: - N_particles = config[toml_key]['N_particles'] - charge = config[toml_key]['charge'] - mass = config[toml_key]['mass'] - # get the charge and mass of the particle species - if 'temperature' in config[toml_key]: - T=config[toml_key]['temperature'] - elif 'vth' in config[toml_key]: - T = vth_to_T(config[toml_key]['vth'], mass, kb) - # get the temperature of the particle species - - inverse_total_debye += jnp.sqrt( N_particles / (x_wind * y_wind * z_wind) / (eps * kb * T) ) * jnp.abs(charge) - # get the inverse debye length before macroparticle weighting - - ds2 = 0 - for d in [dx, dy, dz]: - if d != 1: - ds2 += d**2 - - - weight = 1 / (ds2) / (ds_per_debye**2) / inverse_total_debye - # weight the particles by the total debye length of the plasma - - else: - weight = 1.0 # default to single particle weight - - return weight - -def initial_particles(N_per_cell, N_particles, minx, maxx, miny, maxy, minz, maxz, mass, Tx, Ty, Tz, kb, key1, key2, key3): - """ - Initializes the velocities and positions of the particles. - - Args: - N_particles (int): The number of particles. - minx (float): The minimum value for the x-coordinate of the particles' positions. - maxx (float): The maximum value for the x-coordinate of the particles' positions. - miny (float): The minimum value for the y-coordinate of the particles' positions. - maxy (float): The maximum value for the y-coordinate of the particles' positions. - minz (float): The minimum value for the z-coordinate of the particles' positions. - maxz (float): The maximum value for the z-coordinate of the particles' positions. - mass (float): The mass of the particles. - T (float): The temperature of the system. - kb (float): The Boltzmann constant. - key (jax.random.PRNGKey): The random key for generating random numbers. - - Returns: - x (jax.numpy.ndarray): The x-coordinates of the particles' positions. - y (jax.numpy.ndarray): The y-coordinates of the particles' positions. - z (jax.numpy.ndarray): The z-coordinates of the particles' positions. - v_x (numpy.ndarray): The x-component of the particles' velocities. - v_y (numpy.ndarray): The y-component of the particles' velocities. - v_z (numpy.ndarray): The z-component of the particles' velocities. - """ - - # if N_per_cell < 1: - x = jax.random.uniform(key1, shape = (N_particles,), minval=minx, maxval=maxx) - y = jax.random.uniform(key2, shape = (N_particles,), minval=miny, maxval=maxy) - z = jax.random.uniform(key3, shape = (N_particles,), minval=minz, maxval=maxz) - # initialize the positions of the particles - # else: - # x = jnp.repeat(jax.random.uniform(key1, shape=(N_particles // N_per_cell,), minval=minx, maxval=maxx), N_per_cell) - # y = jnp.repeat(jax.random.uniform(key2, shape=(N_particles // N_per_cell,), minval=miny, maxval=maxy), N_per_cell) - # z = jnp.repeat(jax.random.uniform(key3, shape=(N_particles // N_per_cell,), minval=minz, maxval=maxz), N_per_cell) - # initialize the positions of the particles, giving every N_per_cell particles the same position - #std = jnp.sqrt( kb * T / mass ) - std_x = T_to_vth( Tx, mass, kb ) - std_y = T_to_vth( Ty, mass, kb ) - std_z = T_to_vth( Tz, mass, kb ) - v_x = np.random.normal(0, std_x, N_particles) - v_y = np.random.normal(0, std_y, N_particles) - v_z = np.random.normal(0, std_z, N_particles) - # initialize the particles with a maxwell boltzmann distribution. - return x, y, z, v_x, v_y, v_z - -@jit -def compute_index(x, dx, window): - """ - Compute the index of a position in a discretized space. - - Args: - x (float or ndarray): The position(s) to compute the index for. - dx (float): The discretization step size. - - Returns: - int or ndarray: The computed index/indices as integer(s). - """ - scaled_x = x + window/2 - return jnp.floor( scaled_x / dx).astype(int) - - -@partial(jit, static_argnames=("periodic", "reflecting")) -def apply_axis_boundary_condition(x, v, wind, half_wind, periodic, reflecting): - """ - Apply boundary conditions to particle positions and velocities along a single axis. - This function handles three types of boundary conditions: periodic, reflecting, and open. - Particles that cross the boundary are treated according to the specified condition type. - Args: - x (jnp.ndarray): Particle positions along the axis. - v (jnp.ndarray): Particle velocities along the axis. - wind (float): The width of the domain (full extent from -half_wind to half_wind). - half_wind (float): Half the domain width, used to determine boundary positions. - periodic (bool): If True, apply periodic boundary conditions (particles wrap around). - reflecting (bool): If True, apply reflecting boundary conditions (particle velocities reverse). - Only evaluated if periodic is False. - Returns: - tuple: A tuple of (x_out, v_out) where: - - x_out (jnp.ndarray): Updated particle positions after applying boundary conditions. - - v_out (jnp.ndarray): Updated particle velocities after applying boundary conditions. - Notes: - - Periodic BC: Particles that exceed ±half_wind are wrapped to the opposite side. - - Reflecting BC: Particles at boundaries have their velocities reversed (x position unchanged). - - Open BC (default): Particles and velocities pass through unchanged. - """ - - def periodic_bc(state): - x_in, v_in = state - x_out = x_in + wind * (x_in < -half_wind) - wind * (x_in > half_wind) - return x_out, v_in - # if periodic is True, apply periodic boundary conditions by wrapping positions around the domain - - def reflecting_bc(state): - x_in, v_in = state - v_out = jnp.where((x_in >= half_wind) | (x_in <= -half_wind), -v_in, v_in) - return x_in, v_out - # if reflecting is True, apply reflecting boundary conditions by reversing velocities at the boundaries - - def identity_bc(state): - return state - # if neither periodic nor reflecting, return the positions and velocities unchanged - - return jax.lax.cond( - periodic, - periodic_bc, - lambda state: jax.lax.cond(reflecting, reflecting_bc, identity_bc, state), - (x, v), - ) - -@register_pytree_node_class -class particle_species: - """ - Class representing a species of particles in a simulation. - - Attributes: - name (str): Name of the particle species. - N_particles (int): Number of particles in the species. - charge (float): Charge of each particle. - mass (float): Mass of each particle. - weight (float): Weighting factor for the particles. - T (float): Temperature of the particle species. - v1, v2, v3 (array-like): Velocity components of the particles. - x1, x2, x3 (array-like): Position components of the particles. - dx, dy, dz (float): Spatial resolution in each dimension. - x_wind, y_wind, z_wind (float): Domain size in each dimension. - zeta1, zeta2, eta1, eta2, xi1, xi2 (float): Subcell positions for charge conservation. - bc (str): Boundary condition type ('periodic' or 'reflecting'). - update_x, update_y, update_z (bool): Flags to update position in respective dimensions. - update_vx, update_vy, update_vz (bool): Flags to update velocity in respective dimensions. - update_pos (bool): Flag to update particle positions. - update_v (bool): Flag to update particle velocities. - shape (int): Shape factor for the particles (1 for first order, 2 for second order) - - Methods: - get_name(): Returns the name of the particle species. - get_charge(): Returns the total charge of the particles. - get_number_of_particles(): Returns the number of particles in the species. - get_temperature(): Returns the temperature of the particle species. - get_velocity(): Returns the velocity components of the particles. - get_position(): Returns the position components of the particles. - get_mass(): Returns the total mass of the particles. - get_subcell_position(): Returns the subcell positions for charge conservation. - get_resolution(): Returns the spatial resolution in each dimension. - get_shape(): Returns the shape factor of the particles. - get_index(): Computes and returns the particle indices in the grid. - set_velocity(v1, v2, v3): Sets the velocity components of the particles. - set_position(x1, x2, x3): Sets the position components of the particles. - set_mass(mass): Sets the mass of the particles. - set_weight(weight): Sets the weight of the particles. - calc_subcell_position(): Calculates and returns the subcell positions. - kinetic_energy(): Computes and returns the kinetic energy of the particles. - momentum(): Computes and returns the momentum of the particles. - periodic_boundary_condition(x_wind, y_wind, z_wind): Applies periodic boundary conditions. - reflecting_boundary_condition(x_wind, y_wind, z_wind): Applies reflecting boundary conditions. - update_position(dt): Updates the positions of the particles based on their velocities and boundary conditions. - tree_flatten(): Flattens the object for serialization. - tree_unflatten(aux_data, children): Reconstructs the object from flattened data. - """ - - - - def __init__(self, name, N_particles, charge, mass, T, v1, v2, v3, x1, x2, x3, \ - xwind, ywind, zwind, dx, dy, dz, weight=1, x_bc="periodic", y_bc="periodic", \ - z_bc="periodic", update_x=True, update_y=True, update_z=True, \ - update_vx=True, update_vy=True, update_vz=True, update_pos=True, update_v=True, shape=1, dt = 0): - self.name = name - self.N_particles = N_particles - self.charge = charge - self.mass = mass - self.weight = weight - self.T = T - self.v1 = v1 - self.v2 = v2 - self.v3 = v3 - self.dx = dx - self.dy = dy - self.dz = dz - self.x_wind = xwind - self.y_wind = ywind - self.z_wind = zwind - self.half_x_wind = 0.5 * xwind - self.half_y_wind = 0.5 * ywind - self.half_z_wind = 0.5 * zwind - self.x_bc = x_bc - self.y_bc = y_bc - self.z_bc = z_bc - self.x_periodic = x_bc == 'periodic' - self.x_reflecting = x_bc == 'reflecting' - self.y_periodic = y_bc == 'periodic' - self.y_reflecting = y_bc == 'reflecting' - self.z_periodic = z_bc == 'periodic' - self.z_reflecting = z_bc == 'reflecting' - # boundary conditions for each dimension - self.update_x = update_x - self.update_y = update_y - self.update_z = update_z - self.update_vx = update_vx - self.update_vy = update_vy - self.update_vz = update_vz - self.update_pos = update_pos - self.update_v = update_v - self.shape = shape - self.dt = dt - - self.x1 = x1 - self.x2 = x2 - self.x3 = x3 - - def get_name(self): - return self.name - - def get_charge(self): - return self.charge*self.weight - - def get_number_of_particles(self): - return self.N_particles - - def get_temperature(self): - return self.T - - def get_velocity(self): - return self.v1, self.v2, self.v3 - - def get_forward_position(self): - return self.x1, self.x2, self.x3 - - def get_position(self): - x1_back = self.x1 - self.v1 * self.dt / 2 - x2_back = self.x2 - self.v2 * self.dt / 2 - x3_back = self.x3 - self.v3 * self.dt / 2 - # moving the forward positions back half a time step to get x_t positions - - # boundary conditions - if self.x_bc == 'periodic': - x1_back = jnp.where(x1_back > self.x_wind/2, x1_back - self.x_wind, \ - jnp.where(x1_back < -self.x_wind/2, x1_back + self.x_wind, x1_back)) - # apply boundary conditions to the x position of the particles - - if self.y_bc == 'periodic': - x2_back = jnp.where(x2_back > self.y_wind/2, x2_back - self.y_wind, \ - jnp.where(x2_back < -self.y_wind/2, x2_back + self.y_wind, x2_back)) - # apply boundary conditions to the y position of the particles - - if self.z_bc == 'periodic': - x3_back = jnp.where(x3_back > self.z_wind/2, x3_back - self.z_wind, \ - jnp.where(x3_back < -self.z_wind/2, x3_back + self.z_wind, x3_back)) - # apply boundary conditions to the z position of the particles - - return x1_back, x2_back, x3_back - - def get_mass(self): - return self.mass*self.weight - - def get_resolution(self): - return self.dx, self.dy, self.dz - - def get_shape(self): - return self.shape - - def get_index(self): - return compute_index(self.x1, self.dx, self.x_wind), compute_index(self.x2, self.dy, self.y_wind), compute_index(self.x3, self.dz, self.z_wind) - - def set_velocity(self, v1, v2, v3): - if self.update_v: - if self.update_vx: - self.v1 = v1 - if self.update_vy: - self.v2 = v2 - if self.update_vz: - self.v3 = v3 - - def set_position(self, x1, x2, x3): - self.x1 = x1 - self.x2 = x2 - self.x3 = x3 - - def set_mass(self, mass): - self.mass = mass - - def set_weight(self, weight): - self.weight = weight - - def get_weight(self): - return self.weight - - def kinetic_energy(self): - v2 = jnp.square(self.v1) + jnp.square(self.v2) + jnp.square(self.v3) - # compute the square of the velocity - return 0.5 * self.weight * self.mass * jnp.sum(v2) - - def momentum(self): - return self.mass * self.weight * jnp.sum(jnp.sqrt(self.v1**2 + self.v2**2 + self.v3**2)) - - def boundary_conditions(self): - x1, x2, x3 = self.x1, self.x2, self.x3 - v1, v2, v3 = self.v1, self.v2, self.v3 - - x1, v1 = apply_axis_boundary_condition(x1, v1, self.x_wind, self.half_x_wind, self.x_periodic, self.x_reflecting) - x2, v2 = apply_axis_boundary_condition(x2, v2, self.y_wind, self.half_y_wind, self.y_periodic, self.y_reflecting) - x3, v3 = apply_axis_boundary_condition(x3, v3, self.z_wind, self.half_z_wind, self.z_periodic, self.z_reflecting) - - self.x1, self.x2, self.x3 = x1, x2, x3 - self.v1, self.v2, self.v3 = v1, v2, v3 - - def update_position(self): - if self.update_pos: - if self.update_x: - self.x1 = self.x1 + self.v1 * self.dt - # update the x position of the particles - - if self.update_y: - self.x2 = self.x2 + self.v2 * self.dt - # update the y position of the particles - - if self.update_z: - self.x3 = self.x3 + self.v3 * self.dt - # update the z position of the particles - - def tree_flatten(self): - children = ( - self.v1, self.v2, self.v3, \ - self.x1, self.x2, self.x3, \ - ) - - aux_data = ( - self.name, self.N_particles, self.charge, self.mass, self.T, \ - self.x_wind, self.y_wind, self.z_wind, self.dx, self.dy, self.dz, \ - self.weight, self.x_bc, self.y_bc, self.z_bc, self.update_pos, self.update_v, \ - self.update_x, self.update_y, self.update_z, self.update_vx, self.update_vy, \ - self.update_vz, self.shape, self.dt - ) - return children, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, children): - v1, v2, v3, x1, x2, x3 = children - - - name, N_particles, charge, mass, T, x_wind, y_wind, z_wind, dx, dy, \ - dz, weight, x_bc, y_bc, z_bc, update_pos, update_v, update_x, update_y, update_z, \ - update_vx, update_vy, update_vz, shape, dt = aux_data - - - obj = cls( - name=name, - N_particles=N_particles, - charge=charge, - mass=mass, - T=T, - x1=x1, - x2=x2, - x3=x3, - v1=v1, - v2=v2, - v3=v3, - xwind=x_wind, - ywind=y_wind, - zwind=z_wind, - dx=dx, - dy=dy, - dz=dz, - weight=weight, - x_bc=x_bc, - y_bc=y_bc, - z_bc=z_bc, - update_x=update_x, - update_y=update_y, - update_z=update_z, - update_vx=update_vx, - update_vy=update_vy, - update_vz=update_vz, - update_pos=update_pos, - update_v=update_v, - shape=shape, - dt=dt - ) - - return obj diff --git a/PyPIC3D/flat_particles.py b/PyPIC3D/particles/flat_particles.py similarity index 89% rename from PyPIC3D/flat_particles.py rename to PyPIC3D/particles/flat_particles.py index e16b04f..d0ff9c7 100644 --- a/PyPIC3D/flat_particles.py +++ b/PyPIC3D/particles/flat_particles.py @@ -99,15 +99,27 @@ def get_position(self): half_y = self.y_wind / 2 half_z = self.z_wind / 2 - x1_back = jnp.where(x1_back > half_x, x1_back - self.x_wind, jnp.where(x1_back < -half_x, x1_back + self.x_wind, x1_back)) - x2_back = jnp.where(x2_back > half_y, x2_back - self.y_wind, jnp.where(x2_back < -half_y, x2_back + self.y_wind, x2_back)) - x3_back = jnp.where(x3_back > half_z, x3_back - self.z_wind, jnp.where(x3_back < -half_z, x3_back + self.z_wind, x3_back)) + x1_back = jnp.where( + x1_back > half_x, + x1_back - self.x_wind, + jnp.where(x1_back < -half_x, x1_back + self.x_wind, x1_back), + ) + x2_back = jnp.where( + x2_back > half_y, + x2_back - self.y_wind, + jnp.where(x2_back < -half_y, x2_back + self.y_wind, x2_back), + ) + x3_back = jnp.where( + x3_back > half_z, + x3_back - self.z_wind, + jnp.where(x3_back < -half_z, x3_back + self.z_wind, x3_back), + ) return x1_back, x2_back, x3_back def get_mass(self): return self.mass * self.weight - + def get_weight(self): return self.weight @@ -141,9 +153,21 @@ def boundary_conditions(self): half_y = self.y_wind / 2 half_z = self.z_wind / 2 - self.x1 = jnp.where(self.x1 > half_x, self.x1 - self.x_wind, jnp.where(self.x1 < -half_x, self.x1 + self.x_wind, self.x1)) - self.x2 = jnp.where(self.x2 > half_y, self.x2 - self.y_wind, jnp.where(self.x2 < -half_y, self.x2 + self.y_wind, self.x2)) - self.x3 = jnp.where(self.x3 > half_z, self.x3 - self.z_wind, jnp.where(self.x3 < -half_z, self.x3 + self.z_wind, self.x3)) + self.x1 = jnp.where( + self.x1 > half_x, + self.x1 - self.x_wind, + jnp.where(self.x1 < -half_x, self.x1 + self.x_wind, self.x1), + ) + self.x2 = jnp.where( + self.x2 > half_y, + self.x2 - self.y_wind, + jnp.where(self.x2 < -half_y, self.x2 + self.y_wind, self.x2), + ) + self.x3 = jnp.where( + self.x3 > half_z, + self.x3 - self.z_wind, + jnp.where(self.x3 < -half_z, self.x3 + self.z_wind, self.x3), + ) def tree_flatten(self): children = (self.x1, self.x2, self.x3, self.v1, self.v2, self.v3) diff --git a/PyPIC3D/particles/particle_initialization.py b/PyPIC3D/particles/particle_initialization.py new file mode 100644 index 0000000..cfc54be --- /dev/null +++ b/PyPIC3D/particles/particle_initialization.py @@ -0,0 +1,315 @@ +import numpy as np +import jax +import jax.numpy as jnp + +from PyPIC3D.utils import vth_to_T, plasma_frequency, debye_length, T_to_vth +from PyPIC3D.particles.species_class import particle_species + + +def grab_particle_keys(config): + """Return keys in a TOML config that start with 'particle'.""" + particle_keys = [] + for key in config.keys(): + if key[:8] == "particle": + particle_keys.append(key) + return particle_keys + + +def read_value(param, key, config, default_value): + if param in config[key]: + print(f"Reading user defined {param}") + return config[key][param] + return default_value + + +def load_initial_positions(param, config, key, default, N_particles, ds, ns, key1): + if param in config[key]: + if isinstance(config[key][param], str): + print(f"Loading {param} from external source: {config[key][param]}") + return jnp.load(config[key][param]) + val = config[key][param] + if ns == 1: + return val * jnp.ones(N_particles) + return jax.random.uniform( + key1, shape=(N_particles,), minval=val - (ds / 2), maxval=val + (ds / 2) + ) + return default + + +def load_initial_velocities(param, config, key, default, N_particles): + if param in config[key]: + if isinstance(config[key][param], str): + print(f"Loading {param} from external source: {config[key][param]}") + return jnp.load(config[key][param]) + return jnp.full(N_particles, config[key][param]) + default + return default + + +def compute_macroparticle_weight(config, particle_keys, simulation_parameters, world, constants): + x_wind = world["x_wind"] + y_wind = world["y_wind"] + z_wind = world["z_wind"] + dx = world["dx"] + dy = world["dy"] + dz = world["dz"] + kb = constants["kb"] + eps = constants["eps"] + + if simulation_parameters["ds_per_debye"]: + ds_per_debye = simulation_parameters["ds_per_debye"] + inverse_total_debye = 0 + + for toml_key in particle_keys: + N_particles = config[toml_key]["N_particles"] + charge = config[toml_key]["charge"] + mass = config[toml_key]["mass"] + if "temperature" in config[toml_key]: + T = config[toml_key]["temperature"] + elif "vth" in config[toml_key]: + T = vth_to_T(config[toml_key]["vth"], mass, kb) + + inverse_total_debye += ( + jnp.sqrt(N_particles / (x_wind * y_wind * z_wind) / (eps * kb * T)) + * jnp.abs(charge) + ) + + ds2 = 0 + for d in [dx, dy, dz]: + if d != 1: + ds2 += d**2 + + if ds2 == 0: + raise ValueError( + "Invalid configuration for 'ds_per_debye': at least one of dx, dy, dz must differ from 1." + ) + weight = 1 / (ds2) / (ds_per_debye**2) / inverse_total_debye + else: + weight = 1.0 + + return weight + + +def initial_particles( + N_per_cell, + N_particles, + minx, + maxx, + miny, + maxy, + minz, + maxz, + mass, + Tx, + Ty, + Tz, + kb, + key1, + key2, + key3, +): + x = jax.random.uniform(key1, shape=(N_particles,), minval=minx, maxval=maxx) + y = jax.random.uniform(key2, shape=(N_particles,), minval=miny, maxval=maxy) + z = jax.random.uniform(key3, shape=(N_particles,), minval=minz, maxval=maxz) + + std_x = T_to_vth(Tx, mass, kb) + std_y = T_to_vth(Ty, mass, kb) + std_z = T_to_vth(Tz, mass, kb) + vx = np.random.normal(0, std_x, N_particles) + vy = np.random.normal(0, std_y, N_particles) + vz = np.random.normal(0, std_z, N_particles) + + return x, y, z, vx, vy, vz + + +def load_particles_from_toml(config, simulation_parameters, world, constants): + x_wind = world["x_wind"] + y_wind = world["y_wind"] + z_wind = world["z_wind"] + Nx = world["Nx"] + Ny = world["Ny"] + Nz = world["Nz"] + dx = world["dx"] + dy = world["dy"] + dz = world["dz"] + dt = world["dt"] + kb = constants["kb"] + eps = constants["eps"] + + i = 0 + particles = [] + particle_keys = grab_particle_keys(config) + + weight = compute_macroparticle_weight( + config, particle_keys, simulation_parameters, world, constants + ) + + for toml_key in particle_keys: + key1, key2, key3 = jax.random.key(i), jax.random.key(i + 1), jax.random.key(i + 2) + i += 3 + + particle_name = config[toml_key]["name"] + print(f"\nInitializing particle species: {particle_name}") + charge = config[toml_key]["charge"] + mass = config[toml_key]["mass"] + + if "N_particles" in config[toml_key]: + N_particles = config[toml_key]["N_particles"] + N_per_cell = N_particles / (world["Nx"] * world["Ny"] * world["Nz"]) + elif "N_per_cell" in config[toml_key]: + N_per_cell = config[toml_key]["N_per_cell"] + N_particles = int(N_per_cell * world["Nx"] * world["Ny"] * world["Nz"]) + + if "temperature" in config[toml_key]: + T = config[toml_key]["temperature"] + vth = T_to_vth(T, mass, kb) + elif "vth" in config[toml_key]: + vth = config[toml_key]["vth"] + T = vth_to_T(vth, mass, kb) + else: + T = 1.0 + vth = T_to_vth(T, mass, kb) + + Tx = read_value("Tx", toml_key, config, T) + Ty = read_value("Ty", toml_key, config, T) + Tz = read_value("Tz", toml_key, config, T) + + xmin = read_value("xmin", toml_key, config, -x_wind / 2) + xmax = read_value("xmax", toml_key, config, x_wind / 2) + ymin = read_value("ymin", toml_key, config, -y_wind / 2) + ymax = read_value("ymax", toml_key, config, y_wind / 2) + zmin = read_value("zmin", toml_key, config, -z_wind / 2) + zmax = read_value("zmax", toml_key, config, z_wind / 2) + + x, y, z, vx, vy, vz = initial_particles( + N_per_cell, + N_particles, + xmin, + xmax, + ymin, + ymax, + zmin, + zmax, + mass, + Tx, + Ty, + Tz, + kb, + key1, + key2, + key3, + ) + + x_bc = "periodic" + if "x_bc" in config[toml_key]: + assert config[toml_key]["x_bc"] in ["periodic", "reflecting"], ( + f"Invalid x boundary condition: {config[toml_key]['x_bc']}" + ) + x_bc = config[toml_key]["x_bc"] + + y_bc = "periodic" + if "y_bc" in config[toml_key]: + assert config[toml_key]["y_bc"] in ["periodic", "reflecting"], ( + f"Invalid y boundary condition: {config[toml_key]['y_bc']}" + ) + y_bc = config[toml_key]["y_bc"] + + z_bc = "periodic" + if "z_bc" in config[toml_key]: + assert config[toml_key]["z_bc"] in ["periodic", "reflecting"], ( + f"Invalid z boundary condition: {config[toml_key]['z_bc']}" + ) + z_bc = config[toml_key]["z_bc"] + + x = load_initial_positions("initial_x", config, toml_key, x, N_particles, dx, Nx, key1) + y = load_initial_positions("initial_y", config, toml_key, y, N_particles, dy, Ny, key2) + z = load_initial_positions("initial_z", config, toml_key, z, N_particles, dz, Nz, key3) + + vx = load_initial_velocities("initial_vx", config, toml_key, vx, N_particles) + vy = load_initial_velocities("initial_vy", config, toml_key, vy, N_particles) + vz = load_initial_velocities("initial_vz", config, toml_key, vz, N_particles) + + if "temperature" not in config[toml_key]: + T = (mass / (3 * kb * N_particles)) * ( + jnp.sum(vx**2) + jnp.sum(vy**2) + jnp.sum(vz**2) + ) + + if "weight" in config[toml_key]: + weight = config[toml_key]["weight"] + elif "ds_per_debye" in config[toml_key]: + ds_per_debye = config[toml_key]["ds_per_debye"] + + ds2 = 0 + for d in [dx, dy, dz]: + if d != 1: + ds2 += d**2 + + if ds2 == 0: + raise ValueError( + "Invalid configuration for 'ds_per_debye': at least one of dx, dy, dz must differ from 1." + ) + weight = ( + x_wind * y_wind * z_wind * eps * kb * T + ) / (N_particles * charge**2 * ds_per_debye**2 * ds2) + + update_pos = read_value("update_pos", toml_key, config, True) + update_v = read_value("update_v", toml_key, config, True) + update_vx = read_value("update_vx", toml_key, config, True) + update_vy = read_value("update_vy", toml_key, config, True) + update_vz = read_value("update_vz", toml_key, config, True) + update_x = read_value("update_x", toml_key, config, True) + update_y = read_value("update_y", toml_key, config, True) + update_z = read_value("update_z", toml_key, config, True) + + particle = particle_species( + name=particle_name, + N_particles=N_particles, + charge=charge, + mass=mass, + T=T, + x1=x, + x2=y, + x3=z, + v1=vx, + v2=vy, + v3=vz, + xwind=x_wind, + ywind=y_wind, + zwind=z_wind, + dx=dx, + dy=dy, + dz=dz, + weight=weight, + x_bc=x_bc, + y_bc=y_bc, + z_bc=z_bc, + update_vx=update_vx, + update_vy=update_vy, + update_vz=update_vz, + update_x=update_x, + update_y=update_y, + update_z=update_z, + update_pos=update_pos, + update_v=update_v, + shape=simulation_parameters["shape_factor"], + dt=dt, + ) + particles.append(particle) + + pf = plasma_frequency(particle, world, constants) + dl = debye_length(particle, world, constants) + print(f"Number of particles: {N_particles}") + print(f"Number of particles per cell: {N_per_cell}") + print(f"x, y, z boundary conditions: {x_bc}, {y_bc}, {z_bc}") + print(f"Charge: {charge}") + print(f"Mass: {mass}") + print(f"Temperature: {T}") + print(f"Thermal Velocity: {vth}") + print(f"Particle Kinetic Energy: {particle.kinetic_energy()}") + print(f"Particle Species Plasma Frequency: {pf}") + print(f"Time Steps Per Plasma Period: {(1 / (dt * pf))}") + print(f"Particle Species Debye Length: {dl}") + print(f"Particle Weight: {weight}") + print(f"Particle Species Scaled Charge: {particle.get_charge()}") + print(f"Particle Species Scaled Mass: {particle.get_mass()}") + + return particles diff --git a/PyPIC3D/particles/species_class.py b/PyPIC3D/particles/species_class.py new file mode 100644 index 0000000..901199e --- /dev/null +++ b/PyPIC3D/particles/species_class.py @@ -0,0 +1,336 @@ +import jax +from jax import jit +from functools import partial +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class + + +@jit +def compute_index(x, dx, window): + """Compute grid-cell index for a position.""" + scaled_x = x + window / 2 + return jnp.floor(scaled_x / dx).astype(int) + + +@partial(jit, static_argnames=("periodic", "reflecting")) +def apply_axis_boundary_condition(x, v, wind, half_wind, periodic, reflecting): + """Apply boundary conditions to particle positions/velocities along one axis.""" + + def periodic_bc(state): + x_in, v_in = state + x_out = x_in + wind * (x_in < -half_wind) - wind * (x_in > half_wind) + return x_out, v_in + + def reflecting_bc(state): + x_in, v_in = state + v_out = jnp.where((x_in >= half_wind) | (x_in <= -half_wind), -v_in, v_in) + return x_in, v_out + + def identity_bc(state): + return state + + return jax.lax.cond( + periodic, + periodic_bc, + lambda state: jax.lax.cond(reflecting, reflecting_bc, identity_bc, state), + (x, v), + ) + + +@register_pytree_node_class +class particle_species: + """A particle species (positions/velocities + metadata) stored as a JAX pytree.""" + + def __init__( + self, + name, + N_particles, + charge, + mass, + T, + v1, + v2, + v3, + x1, + x2, + x3, + xwind, + ywind, + zwind, + dx, + dy, + dz, + weight=1, + x_bc="periodic", + y_bc="periodic", + z_bc="periodic", + update_x=True, + update_y=True, + update_z=True, + update_vx=True, + update_vy=True, + update_vz=True, + update_pos=True, + update_v=True, + shape=1, + dt=0, + ): + self.name = name + self.N_particles = N_particles + self.charge = charge + self.mass = mass + self.weight = weight + self.T = T + self.v1 = v1 + self.v2 = v2 + self.v3 = v3 + self.dx = dx + self.dy = dy + self.dz = dz + self.x_wind = xwind + self.y_wind = ywind + self.z_wind = zwind + self.half_x_wind = 0.5 * xwind + self.half_y_wind = 0.5 * ywind + self.half_z_wind = 0.5 * zwind + self.x_bc = x_bc + self.y_bc = y_bc + self.z_bc = z_bc + self.x_periodic = x_bc == "periodic" + self.x_reflecting = x_bc == "reflecting" + self.y_periodic = y_bc == "periodic" + self.y_reflecting = y_bc == "reflecting" + self.z_periodic = z_bc == "periodic" + self.z_reflecting = z_bc == "reflecting" + self.update_x = update_x + self.update_y = update_y + self.update_z = update_z + self.update_vx = update_vx + self.update_vy = update_vy + self.update_vz = update_vz + self.update_pos = update_pos + self.update_v = update_v + self.shape = shape + self.dt = dt + + self.x1 = x1 + self.x2 = x2 + self.x3 = x3 + + def get_name(self): + return self.name + + def get_charge(self): + return self.charge * self.weight + + def get_number_of_particles(self): + return self.N_particles + + def get_temperature(self): + return self.T + + def get_velocity(self): + return self.v1, self.v2, self.v3 + + def get_forward_position(self): + return self.x1, self.x2, self.x3 + + def get_position(self): + x1_back = self.x1 - self.v1 * self.dt / 2 + x2_back = self.x2 - self.v2 * self.dt / 2 + x3_back = self.x3 - self.v3 * self.dt / 2 + + if self.x_bc == "periodic": + x1_back = jnp.where( + x1_back > self.x_wind / 2, + x1_back - self.x_wind, + jnp.where(x1_back < -self.x_wind / 2, x1_back + self.x_wind, x1_back), + ) + + if self.y_bc == "periodic": + x2_back = jnp.where( + x2_back > self.y_wind / 2, + x2_back - self.y_wind, + jnp.where(x2_back < -self.y_wind / 2, x2_back + self.y_wind, x2_back), + ) + + if self.z_bc == "periodic": + x3_back = jnp.where( + x3_back > self.z_wind / 2, + x3_back - self.z_wind, + jnp.where(x3_back < -self.z_wind / 2, x3_back + self.z_wind, x3_back), + ) + + return x1_back, x2_back, x3_back + + def get_mass(self): + return self.mass * self.weight + + def get_weight(self): + return self.weight + + def get_resolution(self): + return self.dx, self.dy, self.dz + + def get_shape(self): + return self.shape + + def get_index(self): + return ( + compute_index(self.x1, self.dx, self.x_wind), + compute_index(self.x2, self.dy, self.y_wind), + compute_index(self.x3, self.dz, self.z_wind), + ) + + def set_velocity(self, v1, v2, v3): + if self.update_v: + if self.update_vx: + self.v1 = v1 + if self.update_vy: + self.v2 = v2 + if self.update_vz: + self.v3 = v3 + + def set_position(self, x1, x2, x3): + self.x1 = x1 + self.x2 = x2 + self.x3 = x3 + + def set_mass(self, mass): + self.mass = mass + + def set_weight(self, weight): + self.weight = weight + + def kinetic_energy(self): + v2 = jnp.square(self.v1) + jnp.square(self.v2) + jnp.square(self.v3) + return 0.5 * self.weight * self.mass * jnp.sum(v2) + + def momentum(self): + return self.mass * self.weight * jnp.sum(jnp.sqrt(self.v1**2 + self.v2**2 + self.v3**2)) + + def boundary_conditions(self): + x1, x2, x3 = self.x1, self.x2, self.x3 + v1, v2, v3 = self.v1, self.v2, self.v3 + + x1, v1 = apply_axis_boundary_condition( + x1, v1, self.x_wind, self.half_x_wind, self.x_periodic, self.x_reflecting + ) + x2, v2 = apply_axis_boundary_condition( + x2, v2, self.y_wind, self.half_y_wind, self.y_periodic, self.y_reflecting + ) + x3, v3 = apply_axis_boundary_condition( + x3, v3, self.z_wind, self.half_z_wind, self.z_periodic, self.z_reflecting + ) + + self.x1, self.x2, self.x3 = x1, x2, x3 + self.v1, self.v2, self.v3 = v1, v2, v3 + + def update_position(self): + if self.update_pos: + if self.update_x: + self.x1 = self.x1 + self.v1 * self.dt + if self.update_y: + self.x2 = self.x2 + self.v2 * self.dt + if self.update_z: + self.x3 = self.x3 + self.v3 * self.dt + + def tree_flatten(self): + children = (self.v1, self.v2, self.v3, self.x1, self.x2, self.x3) + + aux_data = ( + self.name, + self.N_particles, + self.charge, + self.mass, + self.T, + self.x_wind, + self.y_wind, + self.z_wind, + self.dx, + self.dy, + self.dz, + self.weight, + self.x_bc, + self.y_bc, + self.z_bc, + self.update_pos, + self.update_v, + self.update_x, + self.update_y, + self.update_z, + self.update_vx, + self.update_vy, + self.update_vz, + self.shape, + self.dt, + ) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + v1, v2, v3, x1, x2, x3 = children + + ( + name, + N_particles, + charge, + mass, + T, + x_wind, + y_wind, + z_wind, + dx, + dy, + dz, + weight, + x_bc, + y_bc, + z_bc, + update_pos, + update_v, + update_x, + update_y, + update_z, + update_vx, + update_vy, + update_vz, + shape, + dt, + ) = aux_data + + obj = cls( + name=name, + N_particles=N_particles, + charge=charge, + mass=mass, + T=T, + x1=x1, + x2=x2, + x3=x3, + v1=v1, + v2=v2, + v3=v3, + xwind=x_wind, + ywind=y_wind, + zwind=z_wind, + dx=dx, + dy=dy, + dz=dz, + weight=weight, + x_bc=x_bc, + y_bc=y_bc, + z_bc=z_bc, + update_x=update_x, + update_y=update_y, + update_z=update_z, + update_vx=update_vx, + update_vy=update_vy, + update_vz=update_vz, + update_pos=update_pos, + update_v=update_v, + shape=shape, + dt=dt, + ) + + return obj diff --git a/PyPIC3D/rho.py b/PyPIC3D/rho.py deleted file mode 100644 index f4e404a..0000000 --- a/PyPIC3D/rho.py +++ /dev/null @@ -1,120 +0,0 @@ -import jax -from jax import jit -import jax.numpy as jnp -from jax import lax -# import external libraries - -from PyPIC3D.utils import digital_filter, wrap_around -from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights -# import internal libraries - -@jit -def compute_rho(particles, rho, world, constants): - """ - Compute the charge density (rho) for a given set of particles in a simulation world. - Parameters: - particles (list): A list of particle species, each containing methods to get the number of particles, - their positions, and their charge. - rho (ndarray): The initial charge density array to be updated. - world (dict): A dictionary containing the simulation world parameters, including: - - 'dx': Grid spacing in the x-direction. - - 'dy': Grid spacing in the y-direction. - - 'dz': Grid spacing in the z-direction. - - 'x_wind': Window size in the x-direction. - - 'y_wind': Window size in the y-direction. - - 'z_wind': Window size in the z-direction. - Returns: - ndarray: The updated charge density array. - """ - dx = world['dx'] - dy = world['dy'] - dz = world['dz'] - x_wind = world['x_wind'] - y_wind = world['y_wind'] - z_wind = world['z_wind'] - grid = world['grids']['vertex'] - Nx, Ny, Nz = rho.shape - # get the shape of the charge density array - - rho = jnp.zeros_like(rho) - # reset rho to zero - - for species in particles: - shape_factor = species.get_shape() - # get the shape factor of the species, which determines the weighting function - N_particles = species.get_number_of_particles() - q = species.get_charge() - # get the number of particles and their charge - dq = q / dx / dy / dz - # calculate the charge per unit volume - x, y, z = species.get_position() - # get the position of the particles in the species - - - x0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (x - grid[0][0]) / dx).astype(int), - lambda _: jnp.round( (x - grid[0][0]) / dx).astype(int), - operand=None - ) - - y0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (y - grid[1][0]) / dy).astype(int), - lambda _: jnp.round( (y - grid[1][0]) / dy).astype(int), - operand=None - ) - - z0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (z - grid[2][0]) / dz).astype(int), - lambda _: jnp.round( (z - grid[2][0]) / dz).astype(int), - operand=None - ) - # calculate the nearest grid point based on shape factor - - deltax = x - (x0 * dx + grid[0][0]) - deltay = y - (y0 * dy + grid[1][0]) - deltaz = z - (z0 * dz + grid[2][0]) - # calculate the difference based on shape factor - - x0 = wrap_around(x0, Nx) - y0 = wrap_around(y0, Ny) - z0 = wrap_around(z0, Nz) - # ensure indices are within bounds - - x1 = wrap_around(x0 + 1, Nx) - y1 = wrap_around(y0 + 1, Ny) - z1 = wrap_around(z0 + 1, Nz) - # Calculate the index of the next grid point - - x_minus1 = x0 - 1 - y_minus1 = y0 - 1 - z_minus1 = z0 - 1 - # Calculate the index of the previous grid point - - xpts = [x_minus1, x0, x1] - ypts = [y_minus1, y0, y1] - zpts = [z_minus1, z0, z1] - # place all the points in a list - - x_weights, y_weights, z_weights = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), - lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), - operand=None - ) - # get the weighting factors based on the shape factor - - - for i in range(3): - for j in range(3): - for k in range(3): - rho = rho.at[xpts[i], ypts[j], zpts[k]].add( dq * x_weights[i] * y_weights[j] * z_weights[k], mode='drop') - # distribute the charge of the particles to the grid points using the weighting factors - - alpha = constants['alpha'] - rho = digital_filter(rho, alpha) - # apply a digital filter to the charge density array - - return rho diff --git a/PyPIC3D/shapes.py b/PyPIC3D/shapes.py deleted file mode 100644 index 2738aef..0000000 --- a/PyPIC3D/shapes.py +++ /dev/null @@ -1,65 +0,0 @@ -from jax import jit -import jax.numpy as jnp - - -@jit -def get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz): - """ - Calculate the second-order weights for particle current distribution. - - Args: - deltax, deltay, deltaz (float): Particle position offsets from grid points. - dx, dy, dz (float): Grid spacings in x, y, and z directions. - - Returns: - tuple: Weights for x, y, and z directions. - """ - Sx0 = (3/4) - (deltax/dx)**2 - Sy0 = (3/4) - (deltay/dy)**2 - Sz0 = (3/4) - (deltaz/dz)**2 - - Sx1 = (1/2) * ((1/2) + (deltax/dx))**2 - Sy1 = (1/2) * ((1/2) + (deltay/dy))**2 - Sz1 = (1/2) * ((1/2) + (deltaz/dz))**2 - - Sx_minus1 = (1/2) * ((1/2) - (deltax/dx))**2 - Sy_minus1 = (1/2) * ((1/2) - (deltay/dy))**2 - Sz_minus1 = (1/2) * ((1/2) - (deltaz/dz))**2 - # second order weights - - x_weights = [Sx_minus1, Sx0, Sx1] - y_weights = [Sy_minus1, Sy0, Sy1] - z_weights = [Sz_minus1, Sz0, Sz1] - - return x_weights, y_weights, z_weights - -@jit -def get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz): - """ - Calculate the first-order weights for particle current distribution. - - Args: - deltax, deltay, deltaz (float): Particle position offsets from grid points. - dx, dy, dz (float): Grid spacings in x, y, and z directions. - - Returns: - tuple: Weights for x, y, and z directions. - """ - Sx0 = jnp.asarray(1 - deltax / dx) - Sy0 = jnp.asarray(1 - deltay / dy) - Sz0 = jnp.asarray(1 - deltaz / dz) - - Sx1 = jnp.asarray(deltax / dx) - Sy1 = jnp.asarray(deltay / dy) - Sz1 = jnp.asarray(deltaz / dz) - - Sx_minus1 = jnp.zeros_like(Sx0) - Sy_minus1 = jnp.zeros_like(Sy0) - Sz_minus1 = jnp.zeros_like(Sz0) - # No second-order weights for first-order weighting - - x_weights = [Sx_minus1, Sx0, Sx1] - y_weights = [Sy_minus1, Sy0, Sy1] - z_weights = [Sz_minus1, Sz0, Sz1] - - return x_weights, y_weights, z_weights \ No newline at end of file diff --git a/PyPIC3D/solvers/electrostatic_yee.py b/PyPIC3D/solvers/electrostatic_yee.py index e7446b8..4bf122c 100644 --- a/PyPIC3D/solvers/electrostatic_yee.py +++ b/PyPIC3D/solvers/electrostatic_yee.py @@ -3,7 +3,7 @@ from jax import lax from functools import partial -from PyPIC3D.rho import compute_rho +from PyPIC3D.deposition.rho import compute_rho from PyPIC3D.solvers.fdtd import centered_finite_difference_gradient from PyPIC3D.solvers.pstd import spectral_gradient from PyPIC3D.utils import digital_filter diff --git a/tests/electrostatic_yee_test.py b/tests/electrostatic_yee_test.py index f243e82..7478e92 100644 --- a/tests/electrostatic_yee_test.py +++ b/tests/electrostatic_yee_test.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from PyPIC3D.particle import particle_species +from PyPIC3D.particles.species_class import particle_species from PyPIC3D.solvers.electrostatic_yee import ( solve_poisson_with_fft, solve_poisson_with_conjugate_gradient, diff --git a/tests/particle_test.py b/tests/particle_test.py index fd2a5e5..5cc0d03 100644 --- a/tests/particle_test.py +++ b/tests/particle_test.py @@ -5,13 +5,16 @@ import os -from PyPIC3D.particle import ( - initial_particles, particle_species -) +from PyPIC3D.particles.particle_initialization import ( + initial_particles + ) -from PyPIC3D.J import J_from_rhov, Esirkepov_current +from PyPIC3D.particles.species_class import particle_species + +from PyPIC3D.deposition.J_from_rhov import J_from_rhov +from PyPIC3D.deposition.Esirkepov import Esirkepov_current +from PyPIC3D.deposition.rho import compute_rho -from PyPIC3D.rho import compute_rho jax.config.update("jax_enable_x64", True)