diff --git a/PyPIC3D/J.py b/PyPIC3D/J.py index 2a8620d..37a123d 100644 --- a/PyPIC3D/J.py +++ b/PyPIC3D/J.py @@ -7,9 +7,247 @@ from PyPIC3D.utils import digital_filter, wrap_around, bilinear_filter from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights +from PyPIC3D.indexed_particles import _advance_index_and_frac -@partial(jit, static_argnames=("filter",)) -def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): + +def _weights_order1(r): + w0 = 1.0 - r + w1 = r + return jnp.stack((w0, w1), axis=0) + + +def _weights_order2(r): + w0 = 0.5 * (0.5 - r) ** 2 + w1 = 0.75 - r**2 + w2 = 0.5 * (0.5 + r) ** 2 + return jnp.stack((w0, w1, w2), axis=0) + + +def _weights_face(r, shape_factor): + if shape_factor == 1: + return _weights_order1(r - 0.5) + return _weights_order2(r - 0.5) + + +def _deposit_1d(J_stack, dq, vx, vy, vz, x, grid_x0, dx, dt, shape_factor, Nx): + if shape_factor == 1: + x0 = jnp.floor((x - grid_x0) / dx).astype(jnp.int32) + deltax_node = (x - grid_x0) - x0 * dx + deltax_face = (x - grid_x0) - (x0 + 0.5) * dx + + r_node = deltax_node / dx + r_face = deltax_face / dx + + w_node = _weights_order1(r_node) # (2,Np) + w_face = _weights_order1(r_face) # (2,Np) + + ix = jnp.stack((x0, x0 + 1), axis=0) + ix = wrap_around(ix, Nx) + else: + x0 = jnp.round((x - grid_x0) / dx).astype(jnp.int32) + deltax_node = (x - grid_x0) - x0 * dx + deltax_face = (x - grid_x0) - (x0 + 0.5) * dx + + r_node = deltax_node / dx + r_face = deltax_face / dx + + w_node = _weights_order2(r_node) # (3,Np) + w_face = _weights_order2(r_face) # (3,Np) + + ix = jnp.stack((x0 - 1, x0, x0 + 1), axis=0) + ix = wrap_around(ix, Nx) + + # Jx uses face weights; Jy/Jz use node weights. + val = jnp.stack( + ( + (dq * vx)[None, :] * w_face, + (dq * vy)[None, :] * w_node, + (dq * vz)[None, :] * w_node, + ), + axis=-1, + ) # (S,Np,3) + + comp = jnp.arange(3, dtype=ix.dtype)[None, None, :] # (1,1,3) + idx = ix[:, :, None] + comp * jnp.asarray(Nx, dtype=ix.dtype) # (S,Np,3) + + out = jnp.bincount( + idx.reshape(-1), + weights=val.reshape(-1), + length=Nx * 3, + ).reshape(3, Nx) + + Jx, Jy, Jz = out[0], out[1], out[2] + return jnp.stack((Jx, Jy, Jz), axis=-1).reshape((Nx, 1, 1, 3)) + + +def _deposit_2d(J_stack, dq, vx, vy, vz, x, y, xmin, ymin, dx, dy, dt, shape_factor, Nx, Ny): + if shape_factor == 1: + x0 = jnp.floor((x - xmin) / dx).astype(jnp.int32) + y0 = jnp.floor((y - ymin) / dy).astype(jnp.int32) + deltax_node = (x - xmin) - x0 * dx + deltay_node = (y - ymin) - y0 * dy + deltax_face = (x - xmin) - (x0 + 0.5) * dx + deltay_face = (y - ymin) - (y0 + 0.5) * dy + + wx_node = _weights_order1(deltax_node / dx) # (2,Np) + wy_node = _weights_order1(deltay_node / dy) # (2,Np) + wx_face = _weights_order1(deltax_face / dx) # (2,Np) + wy_face = _weights_order1(deltay_face / dy) # (2,Np) + + ix = jnp.stack((x0, x0 + 1), axis=0) + iy = jnp.stack((y0, y0 + 1), axis=0) + ix = wrap_around(ix, Nx) + iy = wrap_around(iy, Ny) + else: + x0 = jnp.round((x - xmin) / dx).astype(jnp.int32) + y0 = jnp.round((y - ymin) / dy).astype(jnp.int32) + deltax_node = (x - xmin) - x0 * dx + deltay_node = (y - ymin) - y0 * dy + deltax_face = (x - xmin) - (x0 + 0.5) * dx + deltay_face = (y - ymin) - (y0 + 0.5) * dy + + wx_node = _weights_order2(deltax_node / dx) # (3,Np) + wy_node = _weights_order2(deltay_node / dy) # (3,Np) + wx_face = _weights_order2(deltax_face / dx) # (3,Np) + wy_face = _weights_order2(deltay_face / dy) # (3,Np) + + ix = jnp.stack((x0 - 1, x0, x0 + 1), axis=0) + iy = jnp.stack((y0 - 1, y0, y0 + 1), axis=0) + ix = wrap_around(ix, Nx) + iy = wrap_around(iy, Ny) + + idx = ix[:, None, :] + Nx * iy[None, :, :] # (Sx,Sy,Np) + idx_flat = idx.reshape(-1) + + # weights for each component + wjx = wx_face[:, None, :] * wy_node[None, :, :] + wjy = wx_node[:, None, :] * wy_face[None, :, :] + wjz = wx_node[:, None, :] * wy_node[None, :, :] + + valx = (dq * vx)[None, None, :] * wjx + valy = (dq * vy)[None, None, :] * wjy + valz = (dq * vz)[None, None, :] * wjz + + vals = jnp.stack((valx, valy, valz), axis=-1).reshape(-1, 3) + J_flat = jax.ops.segment_sum(vals, idx_flat, num_segments=Nx * Ny) # (Nx*Ny,3) + J2 = J_flat.reshape((Nx, Ny, 3))[:, :, None, :] + return J2 + + +def _deposit_1d_indexed(dq, vx, vy, vz, i0, r, shape_factor, Nx): + if shape_factor == 1: + ix = jnp.stack((i0, wrap_around(i0 + 1, Nx)), axis=0) + else: + ix = jnp.stack((wrap_around(i0 - 1, Nx), i0, wrap_around(i0 + 1, Nx)), axis=0) + + w_node = _weights_order1(r) if shape_factor == 1 else _weights_order2(r) + w_face = _weights_face(r, shape_factor) + + val = jnp.stack( + ( + (dq * vx)[None, :] * w_face, + (dq * vy)[None, :] * w_node, + (dq * vz)[None, :] * w_node, + ), + axis=-1, + ) + + comp = jnp.arange(3, dtype=ix.dtype)[None, None, :] + idx = ix[:, :, None] + comp * jnp.asarray(Nx, dtype=ix.dtype) + + out = jnp.bincount(idx.reshape(-1), weights=val.reshape(-1), length=Nx * 3).reshape(3, Nx) + Jx, Jy, Jz = out[0], out[1], out[2] + return jnp.stack((Jx, Jy, Jz), axis=-1).reshape((Nx, 1, 1, 3)) + + +def _deposit_2d_indexed(dq, vx, vy, vz, i0, j0, rx, ry, shape_factor, Nx, Ny): + if shape_factor == 1: + ix = jnp.stack((i0, wrap_around(i0 + 1, Nx)), axis=0) + iy = jnp.stack((j0, wrap_around(j0 + 1, Ny)), axis=0) + else: + ix = jnp.stack((wrap_around(i0 - 1, Nx), i0, wrap_around(i0 + 1, Nx)), axis=0) + iy = jnp.stack((wrap_around(j0 - 1, Ny), j0, wrap_around(j0 + 1, Ny)), axis=0) + + wx_node = _weights_order1(rx) if shape_factor == 1 else _weights_order2(rx) + wy_node = _weights_order1(ry) if shape_factor == 1 else _weights_order2(ry) + wx_face = _weights_face(rx, shape_factor) + wy_face = _weights_face(ry, shape_factor) + + idx = ix[:, None, :] + Nx * iy[None, :, :] + idx_flat = idx.reshape(-1) + + wjx = wx_face[:, None, :] * wy_node[None, :, :] + wjy = wx_node[:, None, :] * wy_face[None, :, :] + wjz = wx_node[:, None, :] * wy_node[None, :, :] + + valx = (dq * vx)[None, None, :] * wjx + valy = (dq * vy)[None, None, :] * wjy + valz = (dq * vz)[None, None, :] * wjz + + vals = jnp.stack((valx, valy, valz), axis=-1).reshape(-1, 3) + J_flat = jax.ops.segment_sum(vals, idx_flat, num_segments=Nx * Ny) + return J_flat.reshape((Nx, Ny, 3))[:, :, None, :] + + +def _deposit_3d_indexed(dq, vx, vy, vz, i0, j0, k0, rx, ry, rz, shape_factor, Nx, Ny, Nz): + if shape_factor == 1: + xpts = jnp.stack((i0, wrap_around(i0 + 1, Nx)), axis=0) + ypts = jnp.stack((j0, wrap_around(j0 + 1, Ny)), axis=0) + zpts = jnp.stack((k0, wrap_around(k0 + 1, Nz)), axis=0) + x_weights_node = _weights_order1(rx) + y_weights_node = _weights_order1(ry) + z_weights_node = _weights_order1(rz) + x_weights_face = _weights_face(rx, shape_factor) + y_weights_face = _weights_face(ry, shape_factor) + z_weights_face = _weights_face(rz, shape_factor) + else: + xpts = jnp.stack((wrap_around(i0 - 1, Nx), i0, wrap_around(i0 + 1, Nx)), axis=0) + ypts = jnp.stack((wrap_around(j0 - 1, Ny), j0, wrap_around(j0 + 1, Ny)), axis=0) + zpts = jnp.stack((wrap_around(k0 - 1, Nz), k0, wrap_around(k0 + 1, Nz)), axis=0) + x_weights_node = _weights_order2(rx) + y_weights_node = _weights_order2(ry) + z_weights_node = _weights_order2(rz) + x_weights_face = _weights_face(rx, shape_factor) + y_weights_face = _weights_face(ry, shape_factor) + z_weights_face = _weights_face(rz, shape_factor) + + if shape_factor == 1: + xpts = xpts + ypts = ypts + zpts = zpts + + ii, jj, kk = jnp.meshgrid( + jnp.arange(xpts.shape[0]), + jnp.arange(ypts.shape[0]), + jnp.arange(zpts.shape[0]), + indexing="ij", + ) + combos = jnp.stack([ii.ravel(), jj.ravel(), kk.ravel()], axis=1) + + def idx_and_dJ_values(idx): + i, j, k = idx + ix = xpts[i, ...] + iy = ypts[j, ...] + iz = zpts[k, ...] + valx = (dq * vx) * x_weights_face[i, ...] * y_weights_node[j, ...] * z_weights_node[k, ...] + valy = (dq * vy) * x_weights_node[i, ...] * y_weights_face[j, ...] * z_weights_node[k, ...] + valz = (dq * vz) * x_weights_node[i, ...] * y_weights_node[j, ...] * z_weights_face[k, ...] + return ix, iy, iz, jnp.stack((valx, valy, valz), axis=-1) + + ix, iy, iz, dJ = jax.vmap(idx_and_dJ_values)(combos) + + ix_flat = ix.reshape(-1) + iy_flat = iy.reshape(-1) + iz_flat = iz.reshape(-1) + dJ_flat = dJ.reshape(-1, 3) + + idx_flat = ix_flat + Nx * (iy_flat + Ny * iz_flat) + J_flat = jax.ops.segment_sum(dJ_flat, idx_flat, num_segments=Nx * Ny * Nz) + return J_flat.reshape((Nx, Ny, Nz, 3)) + + +@partial(jit, static_argnames=("filter", "shape_factor")) +def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear', shape_factor=2): """ Compute the current density from the charge density and particle velocities. @@ -29,27 +267,18 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): dx = world['dx'] dy = world['dy'] dz = world['dz'] - Nx = world['Nx'] - Ny = world['Ny'] - Nz = world['Nz'] - # get the world parameters - Jx, Jy, Jz = J + Nx, Ny, Nz = Jx.shape + # get the world parameters x_active = Jx.shape[0] != 1 y_active = Jx.shape[1] != 1 z_active = Jx.shape[2] != 1 # infer effective dimensionality from the current-grid shape - # unpack the values of J - Jx = Jx.at[:, :, :].set(0) - Jy = Jy.at[:, :, :].set(0) - Jz = Jz.at[:, :, :].set(0) - # initialize the current arrays as 0 - J = (Jx, Jy, Jz) - # initialize the current density as a tuple + J_stack = jnp.zeros((Nx, Ny, Nz, 3), dtype=Jx.dtype) + # keep J together so deposition and filtering can be fused across components for species in particles: - shape_factor = species.get_shape() charge = species.get_charge() dq = charge / (dx * dy * dz) # calculate the charge density contribution per particle @@ -57,29 +286,45 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): vx, vy, vz = species.get_velocity() # get the particles positions and velocities - x = x - vx * world['dt'] / 2 - y = y - vy * world['dt'] / 2 - z = z - vz * world['dt'] / 2 + dt = world["dt"] + x = x - vx * dt / 2 # step back to half time step positions for proper time staggering - x0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (x - grid[0][0]) / dx).astype(int), - lambda _: jnp.round( (x - grid[0][0]) / dx).astype(int), - operand=None - ) - y0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (y - grid[1][0]) / dy).astype(int), - lambda _: jnp.round( (y - grid[1][0]) / dy).astype(int), - operand=None - ) - z0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (z - grid[2][0]) / dz).astype(int), - lambda _: jnp.round( (z - grid[2][0]) / dz).astype(int), - operand=None - ) + if Ny == 1 and Nz == 1: + J_stack = _deposit_1d(J_stack, dq, vx, vy, vz, x, grid[0][0], dx, world["dt"], shape_factor, Nx) + continue + if Nz == 1: + y = y - vy * dt / 2 + J_stack = _deposit_2d( + J_stack, + dq, + vx, + vy, + vz, + x, + y, + grid[0][0], + grid[1][0], + dx, + dy, + world["dt"], + shape_factor, + Nx, + Ny, + ) + continue + + y = y - vy * dt / 2 + z = z - vz * dt / 2 + + if shape_factor == 1: + x0 = jnp.floor((x - grid[0][0]) / dx).astype(jnp.int32) + y0 = jnp.floor((y - grid[1][0]) / dy).astype(jnp.int32) + z0 = jnp.floor((z - grid[2][0]) / dz).astype(jnp.int32) + else: + x0 = jnp.round((x - grid[0][0]) / dx).astype(jnp.int32) + y0 = jnp.round((y - grid[1][0]) / dy).astype(jnp.int32) + z0 = jnp.round((z - grid[2][0]) / dz).astype(jnp.int32) # calculate the nearest grid point based on shape factor deltax_node = (x - grid[0][0]) - (x0 * dx) @@ -96,9 +341,9 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): y0 = wrap_around(y0, Ny) z0 = wrap_around(z0, Nz) # wrap around the grid points for periodic boundary conditions - x1 = wrap_around(x0+1, Nx) - y1 = wrap_around(y0+1, Ny) - z1 = wrap_around(z0+1, Nz) + x1 = wrap_around(x0 + 1, Nx) + y1 = wrap_around(y0 + 1, Ny) + z1 = wrap_around(z0 + 1, Nz) # calculate the right grid point x_minus1 = x0 - 1 y_minus1 = y0 - 1 @@ -109,19 +354,12 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): zpts = [z_minus1, z0, z1] # place all the points in a list - x_weights_node, y_weights_node, z_weights_node = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights( deltax_node, deltay_node, deltaz_node, dx, dy, dz), - lambda _: get_second_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz), - operand=None - ) - - x_weights_face, y_weights_face, z_weights_face = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights( deltax_face, deltay_face, deltaz_face, dx, dy, dz), - lambda _: get_second_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz), - operand=None - ) + if shape_factor == 1: + x_weights_node, y_weights_node, z_weights_node = get_first_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz) + x_weights_face, y_weights_face, z_weights_face = get_first_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz) + else: + x_weights_node, y_weights_node, z_weights_node = get_second_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz) + x_weights_face, y_weights_face, z_weights_face = get_second_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz) # get the weights for node and face positions xpts = jnp.asarray(xpts) # (Sx, Np) @@ -136,6 +374,18 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): y_weights_node = jnp.asarray(y_weights_node) # (Sy, Np) z_weights_node = jnp.asarray(z_weights_node) # (Sz, Np) + if shape_factor == 1: + # drop the redundant (-1) stencil point for first-order (its weights are identically 0) + xpts = xpts[1:, ...] + ypts = ypts[1:, ...] + zpts = zpts[1:, ...] + x_weights_face = x_weights_face[1:, ...] + y_weights_face = y_weights_face[1:, ...] + z_weights_face = z_weights_face[1:, ...] + x_weights_node = x_weights_node[1:, ...] + y_weights_node = y_weights_node[1:, ...] + z_weights_node = z_weights_node[1:, ...] + # Keep full shape-factor computation but collapse inactive axes to an # effective stencil of size 1 to avoid redundant deposition work. if x_active: @@ -184,41 +434,83 @@ def idx_and_dJ_values(idx): valy = (dq * vy) * x_weights_node_eff[i, ...] * y_weights_face_eff[j, ...] * z_weights_node_eff[k, ...] valz = (dq * vz) * x_weights_node_eff[i, ...] * y_weights_node_eff[j, ...] * z_weights_face_eff[k, ...] # calculate the current contributions for this stencil point - return ix, iy, iz, valx, valy, valz + return ix, iy, iz, jnp.stack((valx, valy, valz), axis=-1) - ix, iy, iz, valx, valy, valz = jax.vmap(idx_and_dJ_values)(combos) # each: (M, Np) + ix, iy, iz, dJ = jax.vmap(idx_and_dJ_values)(combos) # (M,Np), (M,Np), (M,Np), (M,Np,3) # vectorized computation of indices and current contributions - Jx = Jx.at[(ix, iy, iz)].add(valx, mode="drop") - Jy = Jy.at[(ix, iy, iz)].add(valy, mode="drop") - Jz = Jz.at[(ix, iy, iz)].add(valz, mode="drop") - # deposit the current contributions into the global J arrays - - def filter_func(J_, filter): - J_ = jax.lax.cond( - filter == 'bilinear', - lambda J_: bilinear_filter(J_), - lambda J_: J_, - operand=J_ + ix_flat = ix.reshape(-1) + iy_flat = iy.reshape(-1) + iz_flat = iz.reshape(-1) + dJ_flat = dJ.reshape(-1, 3) + + in_bounds = ( + (ix_flat >= 0) + & (ix_flat < Nx) + & (iy_flat >= 0) + & (iy_flat < Ny) + & (iz_flat >= 0) + & (iz_flat < Nz) ) - # alpha = constants['alpha'] - # J_ = jax.lax.cond( - # filter == 'digital', - # lambda J_: digital_filter(J_, alpha), - # lambda J_: J_, - # operand=J_ - # ) - return J_ - # define a filtering function - - Jx = filter_func(Jx, filter) - Jy = filter_func(Jy, filter) - Jz = filter_func(Jz, filter) - # apply the selected filter to each component of J - J = (Jx, Jy, Jz) - - return J + ix_flat = jnp.clip(ix_flat, 0, Nx - 1) + iy_flat = jnp.clip(iy_flat, 0, Ny - 1) + iz_flat = jnp.clip(iz_flat, 0, Nz - 1) + + idx_flat = ix_flat + Nx * (iy_flat + Ny * iz_flat) + dJ_flat = jnp.where(in_bounds[:, None], dJ_flat, 0) + + J_flat = jax.ops.segment_sum(dJ_flat, idx_flat, num_segments=Nx * Ny * Nz) + J_stack = J_stack + J_flat.reshape((Nx, Ny, Nz, 3)) + # segment_sum avoids large scatter updates on CPU + + if filter == "bilinear": + J_stack = bilinear_filter(J_stack) + # (optional) digital filter disabled by default + + return (J_stack[..., 0], J_stack[..., 1], J_stack[..., 2]) + + +@partial(jit, static_argnames=("filter",)) +def J_from_rhov_indexed(particles, J, constants, world, grid=None, filter="bilinear"): + if grid is None: + grid = world["grids"]["center"] + + dx = world["dx"] + dy = world["dy"] + dz = world["dz"] + Jx, Jy, Jz = J + Nx, Ny, Nz = Jx.shape + + J_stack = jnp.zeros((Nx, Ny, Nz, 3), dtype=Jx.dtype) + + dt = world["dt"] + for species in particles: + charge = species.get_charge() + dq = charge / (dx * dy * dz) + vx, vy, vz = species.get_velocity() + i0, j0, k0, rx, ry, rz = species.get_indexed_position() + shape_factor = species.get_shape() + + i0b, rxb = _advance_index_and_frac(i0, rx, -vx * dt / (2 * dx), Nx, shape_factor) + + if Ny == 1 and Nz == 1: + J_stack = J_stack + _deposit_1d_indexed(dq, vx, vy, vz, i0b, rxb, shape_factor, Nx) + continue + + j0b, ryb = _advance_index_and_frac(j0, ry, -vy * dt / (2 * dy), Ny, shape_factor) + + if Nz == 1: + J_stack = J_stack + _deposit_2d_indexed(dq, vx, vy, vz, i0b, j0b, rxb, ryb, shape_factor, Nx, Ny) + continue + + k0b, rzb = _advance_index_and_frac(k0, rz, -vz * dt / (2 * dz), Nz, shape_factor) + J_stack = J_stack + _deposit_3d_indexed(dq, vx, vy, vz, i0b, j0b, k0b, rxb, ryb, rzb, shape_factor, Nx, Ny, Nz) + + if filter == "bilinear": + J_stack = bilinear_filter(J_stack) + + return (J_stack[..., 0], J_stack[..., 1], J_stack[..., 2]) def _roll_old_weights_to_new_frame(old_w_list, shift): """ diff --git a/PyPIC3D/__main__.py b/PyPIC3D/__main__.py index 775213c..22b6ec7 100644 --- a/PyPIC3D/__main__.py +++ b/PyPIC3D/__main__.py @@ -8,7 +8,7 @@ import os import time import jax -from jax import block_until_ready +from jax import block_until_ready, lax import jax.numpy as jnp from tqdm import tqdm @@ -23,10 +23,6 @@ write_openpmd_particles, write_openpmd_fields ) -from PyPIC3D.diagnostics.vtk import ( - plot_field_slice_vtk, plot_vectorfield_slice_vtk, plot_vtk_particles -) - from PyPIC3D.utils import ( dump_parameters_to_toml, load_config_file, compute_energy, setup_pmd_files @@ -50,11 +46,41 @@ def run_PyPIC3D(config_file): ##################################### INITIALIZE SIMULATION ################################################ - loop, particles, fields, world, simulation_parameters, constants, plotting_parameters, plasma_parameters, solver, electrostatic, verbose, GPUs, Nt, curl_func, J_func, relativistic = initialize_simulation(config_file) + cfg = config_file + if isinstance(cfg, dict): + sim = cfg.setdefault("simulation_parameters", {}) + fast_mode = sim.get("fast_mode", "off") + if fast_mode not in ("off", "fp32", "aggressive", "extreme"): + raise ValueError("simulation_parameters.fast_mode must be one of: off, fp32, aggressive, extreme") + + if fast_mode in ("fp32", "aggressive", "extreme"): + sim["enable_x64"] = False + + if fast_mode in ("aggressive", "extreme"): + sim["shape_factor"] = 1 + sim["filter_j"] = "none" + + plot = cfg.setdefault("plotting", {}) + for key in ( + "plot_phasespace", + "plot_vtk_particles", + "plot_vtk_scalars", + "plot_vtk_vectors", + "plot_openpmd_particles", + "plot_openpmd_fields", + "dump_particles", + "dump_fields", + ): + plot[key] = False + plot["plotting_interval"] = 10**9 + + if fast_mode == "extreme": + # opt-in physics approximation for maximum throughput + sim["relativistic"] = False + + loop, particles, fields, world, simulation_parameters, constants, plotting_parameters, plasma_parameters, solver, electrostatic, verbose, GPUs, Nt, curl_func, J_func, relativistic = initialize_simulation(cfg) # initialize the simulation - jit_loop = jax.jit(loop, static_argnames=('curl_func', 'J_func', 'solver', 'relativistic')) - dt = world['dt'] output_dir = simulation_parameters['output_dir'] vertex_grid = world['grids']['vertex'] @@ -77,12 +103,88 @@ def run_PyPIC3D(config_file): ###################################################### SIMULATION LOOP ##################################### - for t in tqdm(range(Nt)): + scan_chunk = int(simulation_parameters.get("scan_chunk", 1)) + plotting_interval = int(plotting_parameters["plotting_interval"]) + fast_mode = str(simulation_parameters.get("fast_mode", "off")) + advance_impl = str(simulation_parameters.get("advance_impl", "fori")) + scan_unroll = int(simulation_parameters.get("scan_unroll", 1)) + + if scan_chunk < 1: + raise ValueError("simulation_parameters.scan_chunk must be >= 1") + + if scan_chunk > 1 and (plotting_interval % scan_chunk) != 0: + raise ValueError( + f"scan_chunk={scan_chunk} requires plotting_interval to be a multiple of scan_chunk " + f"(got plotting_interval={plotting_interval})." + ) + + def make_advance(n_steps): + def advance(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): + if fast_mode == "aggressive" and advance_impl == "scan": + def scan_body(carry, _): + p, f = carry + return loop(p, f, world, constants, curl_func, J_func, solver, relativistic=relativistic), None + + (particles, fields), _ = lax.scan( + scan_body, + (particles, fields), + xs=None, + length=n_steps, + unroll=scan_unroll, + ) + return particles, fields + + def body(_, state): + p, f = state + return loop(p, f, world, constants, curl_func, J_func, solver, relativistic=relativistic) + + return lax.fori_loop(0, n_steps, body, (particles, fields)) + + return jax.jit( + advance, + static_argnames=("curl_func", "J_func", "solver", "relativistic"), + donate_argnums=(0, 1), + ) + + advance_full = make_advance(scan_chunk) if scan_chunk > 1 else None + tail = Nt % scan_chunk + advance_tail = make_advance(tail) if (scan_chunk > 1 and tail) else None + + outputs_enabled = any( + plotting_parameters.get(k, False) + for k in ( + "plot_phasespace", + "plot_vtk_scalars", + "plot_vtk_vectors", + "plot_vtk_particles", + "plot_openpmd_particles", + "plot_openpmd_fields", + "dump_particles", + "dump_fields", + ) + ) + + if (not outputs_enabled) and plotting_interval > Nt: + advance_all = make_advance(Nt) + particles, fields = advance_all( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + return Nt, plotting_parameters, simulation_parameters, plasma_parameters, constants, particles, fields, world + + step_iter = range(0, Nt, scan_chunk) if scan_chunk > 1 else range(Nt) + for t in tqdm(step_iter): # plot the data - if t % plotting_parameters['plotting_interval'] == 0: + if t % plotting_interval == 0: - plot_num = t // plotting_parameters['plotting_interval'] + plot_num = t // plotting_interval # determine the plot number E, B, J, rho, *rest = fields @@ -111,6 +213,11 @@ def run_PyPIC3D(config_file): if plotting_parameters['plot_vtk_scalars']: + try: + from PyPIC3D.diagnostics.vtk import plot_field_slice_vtk + except ModuleNotFoundError as e: + raise ModuleNotFoundError("VTK diagnostics requested but 'vtk' is not installed.") from e + rho = compute_rho(particles, rho, world, constants) # calculate the charge density based on the particle positions mass_density = compute_mass_density(particles, rho, world) @@ -122,6 +229,11 @@ def run_PyPIC3D(config_file): if plotting_parameters['plot_vtk_vectors']: + try: + from PyPIC3D.diagnostics.vtk import plot_vectorfield_slice_vtk + except ModuleNotFoundError as e: + raise ModuleNotFoundError("VTK diagnostics requested but 'vtk' is not installed.") from e + vector_field_slices = [ [E[0][:,world['Ny']//2,:], E[1][:,world['Ny']//2,:], E[2][:,world['Ny']//2,:]], [B[0][:,world['Ny']//2,:], B[1][:,world['Ny']//2,:], B[2][:,world['Ny']//2,:]], [J[0][:,world['Ny']//2,:], J[1][:,world['Ny']//2,:], J[2][:,world['Ny']//2,:]]] @@ -129,6 +241,11 @@ def run_PyPIC3D(config_file): # Plot the vector fields in VTK format if plotting_parameters['plot_vtk_particles']: + try: + from PyPIC3D.diagnostics.vtk import plot_vtk_particles + except ModuleNotFoundError as e: + raise ModuleNotFoundError("VTK diagnostics requested but 'vtk' is not installed.") from e + plot_vtk_particles(particles, plot_num, output_dir) # Plot the particles in VTK format @@ -143,25 +260,58 @@ def run_PyPIC3D(config_file): fields = (E, B, J, rho, *rest) # repack the fields - particles, fields = jit_loop( - particles, - fields, - world, - constants, - curl_func, - J_func, - solver, - relativistic=relativistic, - ) - # time loop to update the particles and fields + if scan_chunk == 1: + particles, fields = loop( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + else: + if (t + scan_chunk) <= Nt: + particles, fields = advance_full( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + else: + particles, fields = advance_tail( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + # advance the particles and fields return Nt, plotting_parameters, simulation_parameters, plasma_parameters, constants, particles, fields, world def main(): ###################### JAX SETTINGS ######################################################################## - jax.config.update("jax_enable_x64", True) - # set Jax to use 64 bit precision + toml_file = load_config_file() + # load the configuration file + + sim = toml_file.get("simulation_parameters", {}) if isinstance(toml_file, dict) else {} + fast_mode = sim.get("fast_mode", "off") + enable_x64 = bool(sim.get("enable_x64", True)) + if fast_mode in ("fp32", "aggressive", "extreme"): + enable_x64 = False + jax.config.update("jax_enable_x64", enable_x64) + # set Jax precision (default preserves legacy behavior) + # jax.config.update("jax_debug_nans", True) # debugging for nans jax.config.update('jax_platform_name', 'cpu') @@ -169,9 +319,6 @@ def main(): #jax.config.update("jax_disable_jit", True) ############################################################################################################ - toml_file = load_config_file() - # load the configuration file - start = time.time() # start the timer diff --git a/PyPIC3D/boris.py b/PyPIC3D/boris.py index 49405c8..763e9ad 100644 --- a/PyPIC3D/boris.py +++ b/PyPIC3D/boris.py @@ -1,9 +1,11 @@ import jax from jax import jit import jax.numpy as jnp +from functools import partial from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights from PyPIC3D.utils import wrap_around +from PyPIC3D.indexed_particles import _advance_index_and_frac @jit def particle_push(particles, E, B, grid, staggered_grid, dt, constants, periodic=True, relativistic=True): @@ -49,31 +51,83 @@ def particle_push(particles, E, B, grid, staggered_grid, dt, constants, periodic ################## INTERPOLATE FIELDS TO PARTICLE POSITIONS ############## Ex, Ey, Ez = E # unpack the electric field components - efield_atx = interpolate_field_to_particles(Ex, x, y, z, Ex_grid, shape_factor) - efield_aty = interpolate_field_to_particles(Ey, x, y, z, Ey_grid, shape_factor) - efield_atz = interpolate_field_to_particles(Ez, x, y, z, Ez_grid, shape_factor) - # calculate the electric field at the particle positions on the Yee-staggered component grids Bx, By, Bz = B # unpack the magnetic field components - bfield_atx = interpolate_field_to_particles(Bx, x, y, z, Bx_grid, shape_factor) - bfield_aty = interpolate_field_to_particles(By, x, y, z, By_grid, shape_factor) - bfield_atz = interpolate_field_to_particles(Bz, x, y, z, Bz_grid, shape_factor) - # calculate the magnetic field at the particle positions on the Yee-staggered component grids + + Ny = len(grid[1]) + Nz = len(grid[2]) + + if Ny == 1 and Nz == 1: + node_stack = jnp.stack((Ey, Ez, Bx), axis=-1) + face_stack = jnp.stack((Ex, By, Bz), axis=-1) + + node_vals = interpolate_field_to_particles(node_stack, x, y, z, (grid[0], grid[1], grid[2]), shape_factor) + face_vals = interpolate_field_to_particles(face_stack, x, y, z, (staggered_grid[0], grid[1], grid[2]), shape_factor) + + efield_aty, efield_atz, bfield_atx = node_vals[:, 0], node_vals[:, 1], node_vals[:, 2] + efield_atx, bfield_aty, bfield_atz = face_vals[:, 0], face_vals[:, 1], face_vals[:, 2] + + elif Nz == 1: + ex_by = jnp.stack((Ex, By), axis=-1) + ey_bx = jnp.stack((Ey, Bx), axis=-1) + + ex_by_vals = interpolate_field_to_particles(ex_by, x, y, z, Ex_grid, shape_factor) + ey_bx_vals = interpolate_field_to_particles(ey_bx, x, y, z, Ey_grid, shape_factor) + ez_vals = interpolate_field_to_particles(Ez, x, y, z, Ez_grid, shape_factor) + bz_vals = interpolate_field_to_particles(Bz, x, y, z, Bz_grid, shape_factor) + + efield_atx, bfield_aty = ex_by_vals[:, 0], ex_by_vals[:, 1] + efield_aty, bfield_atx = ey_bx_vals[:, 0], ey_bx_vals[:, 1] + efield_atz = ez_vals + bfield_atz = bz_vals + + else: + efield_atx = interpolate_field_to_particles(Ex, x, y, z, Ex_grid, shape_factor) + efield_aty = interpolate_field_to_particles(Ey, x, y, z, Ey_grid, shape_factor) + efield_atz = interpolate_field_to_particles(Ez, x, y, z, Ez_grid, shape_factor) + # calculate the electric field at the particle positions on the Yee-staggered component grids + bfield_atx = interpolate_field_to_particles(Bx, x, y, z, Bx_grid, shape_factor) + bfield_aty = interpolate_field_to_particles(By, x, y, z, By_grid, shape_factor) + bfield_atz = interpolate_field_to_particles(Bz, x, y, z, Bz_grid, shape_factor) + # calculate the magnetic field at the particle positions on the Yee-staggered component grids ######################################################################### #################### BORIS ALGORITHM #################################### - boris_vmap = jax.vmap(boris_single_particle, in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None, None)) - relativistic_boris_vmap = jax.vmap(relativistic_boris_single_particle, in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None, None)) - # vectorize the Boris algorithm for batch processing - newvx, newvy, newvz = jax.lax.cond( relativistic == True, - lambda _: relativistic_boris_vmap(vx, vy, vz, efield_atx, efield_aty, efield_atz, bfield_atx, bfield_aty, bfield_atz, q, m, dt, constants), - lambda _: boris_vmap(vx, vy, vz, efield_atx, efield_aty, efield_atz, bfield_atx, bfield_aty, bfield_atz, q, m, dt, constants), - operand=None + lambda _: relativistic_boris_push( + vx, + vy, + vz, + efield_atx, + efield_aty, + efield_atz, + bfield_atx, + bfield_aty, + bfield_atz, + q, + m, + dt, + constants, + ), + lambda _: boris_push( + vx, + vy, + vz, + efield_atx, + efield_aty, + efield_atz, + bfield_atx, + bfield_aty, + bfield_atz, + q, + m, + dt, + ), + operand=None, ) - # apply the Boris algorithm to update the velocities of the particles + # apply the Boris algorithm (vectorized over particles) ######################################################################### @@ -81,6 +135,84 @@ def particle_push(particles, E, B, grid, staggered_grid, dt, constants, periodic # set the new velocities of the particles return particles + +def boris_push(vx, vy, vz, ex, ey, ez, bx, by, bz, q, m, dt): + qmdt2 = q * dt / (2 * m) + + vminus_x = vx + qmdt2 * ex + vminus_y = vy + qmdt2 * ey + vminus_z = vz + qmdt2 * ez + + t_x = qmdt2 * bx + t_y = qmdt2 * by + t_z = qmdt2 * bz + + t2 = t_x * t_x + t_y * t_y + t_z * t_z + inv = 1.0 / (1.0 + t2) + s_x = 2.0 * t_x * inv + s_y = 2.0 * t_y * inv + s_z = 2.0 * t_z * inv + + vprime_x = vminus_x + (vminus_y * t_z - vminus_z * t_y) + vprime_y = vminus_y + (vminus_z * t_x - vminus_x * t_z) + vprime_z = vminus_z + (vminus_x * t_y - vminus_y * t_x) + + vplus_x = vminus_x + (vprime_y * s_z - vprime_z * s_y) + vplus_y = vminus_y + (vprime_z * s_x - vprime_x * s_z) + vplus_z = vminus_z + (vprime_x * s_y - vprime_y * s_x) + + newvx = vplus_x + qmdt2 * ex + newvy = vplus_y + qmdt2 * ey + newvz = vplus_z + qmdt2 * ez + + return newvx, newvy, newvz + + +def relativistic_boris_push(vx, vy, vz, ex, ey, ez, bx, by, bz, q, m, dt, constants): + C = constants["C"] + qmdt2 = q * dt / (2 * m) + + v2_over_c2 = (vx * vx + vy * vy + vz * vz) / (C * C) + gamma = 1.0 / jnp.sqrt(1.0 - v2_over_c2) + + uminus_x = vx * gamma + qmdt2 * ex + uminus_y = vy * gamma + qmdt2 * ey + uminus_z = vz * gamma + qmdt2 * ez + + uminus2_over_c2 = (uminus_x * uminus_x + uminus_y * uminus_y + uminus_z * uminus_z) / (C * C) + gamma_minus = jnp.sqrt(1.0 + uminus2_over_c2) + + t_x = (qmdt2 * bx) / gamma_minus + t_y = (qmdt2 * by) / gamma_minus + t_z = (qmdt2 * bz) / gamma_minus + + t2 = t_x * t_x + t_y * t_y + t_z * t_z + inv = 1.0 / (1.0 + t2) + s_x = 2.0 * t_x * inv + s_y = 2.0 * t_y * inv + s_z = 2.0 * t_z * inv + + uprime_x = uminus_x + (uminus_y * t_z - uminus_z * t_y) + uprime_y = uminus_y + (uminus_z * t_x - uminus_x * t_z) + uprime_z = uminus_z + (uminus_x * t_y - uminus_y * t_x) + + uplus_x = uminus_x + (uprime_y * s_z - uprime_z * s_y) + uplus_y = uminus_y + (uprime_z * s_x - uprime_x * s_z) + uplus_z = uminus_z + (uprime_x * s_y - uprime_y * s_x) + + newu_x = uplus_x + qmdt2 * ex + newu_y = uplus_y + qmdt2 * ey + newu_z = uplus_z + qmdt2 * ez + + newu2_over_c2 = (newu_x * newu_x + newu_y * newu_y + newu_z * newu_z) / (C * C) + new_gamma = jnp.sqrt(1.0 + newu2_over_c2) + + newvx = newu_x / new_gamma + newvy = newu_y / new_gamma + newvz = newu_z / new_gamma + + return newvx, newvy, newvz + @jit def boris_single_particle(vx, vy, vz, efield_atx, efield_aty, efield_atz, bfield_atx, bfield_aty, bfield_atz, q, m, dt, constants): """ @@ -194,7 +326,7 @@ def relativistic_boris_single_particle(vx, vy, vz, efield_atx, efield_aty, efiel return newv[0], newv[1], newv[2] -@jit +@partial(jit, static_argnames=("shape_factor",)) def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): """ Interpolate a Yee-grid field component to particle positions using PIC shape functions. @@ -227,37 +359,50 @@ def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): dz = z_grid[1] - z_grid[0] if Nz > 1 else 1.0 # grid spacing in each direction - x0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor((x - xmin) / dx).astype(int), - lambda _: jnp.round((x - xmin) / dx).astype(int), - operand=None, - ) - y0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor((y - ymin) / dy).astype(int), - lambda _: jnp.round((y - ymin) / dy).astype(int), - operand=None, - ) - z0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor((z - zmin) / dz).astype(int), - lambda _: jnp.round((z - zmin) / dz).astype(int), - operand=None, - ) + if shape_factor == 1: + x0 = jnp.floor((x - xmin) / dx).astype(jnp.int32) + else: + x0 = jnp.round((x - xmin) / dx).astype(jnp.int32) # compute the stencil anchor points (cell-left for first order, nearest node for second order) deltax = (x - xmin) - x0 * dx + # determine the distance from the closest grid nodes + + if x_active and (not y_active) and (not z_active): + x0 = wrap_around(x0, Nx) + if shape_factor == 1: + r = deltax / dx + xpts = jnp.stack((x0, wrap_around(x0 + 1, Nx)), axis=0) + xw = jnp.stack((1.0 - r, r), axis=0) + else: + r = deltax / dx + xpts = jnp.stack((wrap_around(x0 - 1, Nx), x0, wrap_around(x0 + 1, Nx)), axis=0) + xw = jnp.stack( + ( + 0.5 * (0.5 - r) ** 2, + 0.75 - r**2, + 0.5 * (0.5 + r) ** 2, + ), + axis=0, + ) + if field.ndim == 4: + return jnp.sum(field[xpts, 0, 0, :] * xw[:, :, None], axis=0) + return jnp.sum(field[xpts, 0, 0] * xw, axis=0) + + if shape_factor == 1: + y0 = jnp.floor((y - ymin) / dy).astype(jnp.int32) + z0 = jnp.floor((z - zmin) / dz).astype(jnp.int32) + else: + y0 = jnp.round((y - ymin) / dy).astype(jnp.int32) + z0 = jnp.round((z - zmin) / dz).astype(jnp.int32) + deltay = (y - ymin) - y0 * dy deltaz = (z - zmin) - z0 * dz - # determine the distance from the closest grid nodes - x_weights, y_weights, z_weights = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), - lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), - operand=None, - ) + if shape_factor == 1: + x_weights, y_weights, z_weights = get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz) + else: + x_weights, y_weights, z_weights = get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz) x_weights = jnp.asarray(x_weights) y_weights = jnp.asarray(y_weights) z_weights = jnp.asarray(z_weights) @@ -280,6 +425,15 @@ def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): zpts = jnp.asarray([z_minus1, z0, z1]) # place all the points in a list + if shape_factor == 1: + # drop the redundant (-1) stencil point for first-order (its weights are identically 0) + xpts = xpts[1:, ...] + ypts = ypts[1:, ...] + zpts = zpts[1:, ...] + x_weights = x_weights[1:, ...] + y_weights = y_weights[1:, ...] + z_weights = z_weights[1:, ...] + # Keep full shape-factor computation but collapse inactive axes to an # effective stencil size of 1 to avoid redundant interpolation work. if x_active: @@ -303,26 +457,152 @@ def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): zpts_eff = jnp.zeros((1, zpts.shape[1]), dtype=zpts.dtype) z_weights_eff = jnp.sum(z_weights, axis=0, keepdims=True) - def stencil_contribution(stencil_idx): - i, j, k = stencil_idx - return ( - field[xpts_eff[i, ...], ypts_eff[j, ...], zpts_eff[k, ...]] - * x_weights_eff[i, ...] - * y_weights_eff[j, ...] - * z_weights_eff[k, ...] - ) - # define a function to compute the contribution from each point in the effective stencil - - ii, jj, kk = jnp.meshgrid( - jnp.arange(xpts_eff.shape[0]), - jnp.arange(ypts_eff.shape[0]), - jnp.arange(zpts_eff.shape[0]), - indexing="ij", + field_vals = field[ + xpts_eff[:, None, None, :], + ypts_eff[None, :, None, :], + zpts_eff[None, None, :, :], + ] + weights = ( + x_weights_eff[:, None, None, :] + * y_weights_eff[None, :, None, :] + * z_weights_eff[None, None, :, :] + ) + if field.ndim == 4: + weights = weights[..., None] + return jnp.sum(field_vals * weights, axis=(0, 1, 2)) + + +def _weights_from_r(r, shape_factor): + if shape_factor == 1: + return jnp.stack((1.0 - r, r), axis=0) + return jnp.stack( + ( + 0.5 * (0.5 - r) ** 2, + 0.75 - r**2, + 0.5 * (0.5 + r) ** 2, + ), + axis=0, ) - stencil_indicies = jnp.stack([ii.ravel(), jj.ravel(), kk.ravel()], axis=1) - # build effective stencil indices with shape (Sx*Sy*Sz, 3) - interpolated_field = jnp.sum(jax.vmap(stencil_contribution)(stencil_indicies), axis=0) - # sum the contributions from all stencil points to get the final interpolated field value at each particle position - return interpolated_field +@partial(jit, static_argnames=("shape_factor", "Nx", "Ny", "Nz")) +def interpolate_field_indexed(field, i0, j0, k0, rx, ry, rz, shape_factor, Nx, Ny, Nz): + if Ny == 1 and Nz == 1: + xw = _weights_from_r(rx, shape_factor) + if shape_factor == 1: + xpts = jnp.stack((i0, wrap_around(i0 + 1, Nx)), axis=0) + else: + xpts = jnp.stack((wrap_around(i0 - 1, Nx), i0, wrap_around(i0 + 1, Nx)), axis=0) + + if field.ndim == 4: + return jnp.sum(field[xpts, 0, 0, :] * xw[:, :, None], axis=0) + return jnp.sum(field[xpts, 0, 0] * xw, axis=0) + + xw = _weights_from_r(rx, shape_factor) + yw = _weights_from_r(ry, shape_factor) + zw = _weights_from_r(rz, shape_factor) + + if shape_factor == 1: + xpts = jnp.stack((i0, wrap_around(i0 + 1, Nx)), axis=0) + ypts = jnp.stack((j0, wrap_around(j0 + 1, Ny)), axis=0) + zpts = jnp.stack((k0, wrap_around(k0 + 1, Nz)), axis=0) + else: + xpts = jnp.stack((wrap_around(i0 - 1, Nx), i0, wrap_around(i0 + 1, Nx)), axis=0) + ypts = jnp.stack((wrap_around(j0 - 1, Ny), j0, wrap_around(j0 + 1, Ny)), axis=0) + zpts = jnp.stack((wrap_around(k0 - 1, Nz), k0, wrap_around(k0 + 1, Nz)), axis=0) + + field_vals = field[ + xpts[:, None, None, :], + ypts[None, :, None, :], + zpts[None, None, :, :], + ] + weights = xw[:, None, None, :] * yw[None, :, None, :] * zw[None, None, :, :] + if field.ndim == 4: + weights = weights[..., None] + return jnp.sum(field_vals * weights, axis=(0, 1, 2)) + + +@jit +def particle_push_indexed(particles, E, B, world, constants, relativistic=True): + q = particles.get_charge() + m = particles.get_mass() + vx, vy, vz = particles.get_velocity() + shape_factor = particles.get_shape() + + Ex, Ey, Ez = E + Bx, By, Bz = B + + i0, j0, k0, rx, ry, rz = particles.get_indexed_position() + + Nx, Ny, Nz = Ex.shape + + i0s, rxs = _advance_index_and_frac(i0, rx, -0.5, Nx, shape_factor) + j0s, rys = _advance_index_and_frac(j0, ry, -0.5, Ny, shape_factor) + k0s, rzs = _advance_index_and_frac(k0, rz, -0.5, Nz, shape_factor) + + if Ny == 1 and Nz == 1: + node_stack = jnp.stack((Ey, Ez, Bx), axis=-1) + face_stack = jnp.stack((Ex, By, Bz), axis=-1) + + node_vals = interpolate_field_indexed(node_stack, i0, j0, k0, rx, ry, rz, shape_factor, Nx, Ny, Nz) + face_vals = interpolate_field_indexed(face_stack, i0s, j0, k0, rxs, ry, rz, shape_factor, Nx, Ny, Nz) + + efield_aty, efield_atz, bfield_atx = node_vals[:, 0], node_vals[:, 1], node_vals[:, 2] + efield_atx, bfield_aty, bfield_atz = face_vals[:, 0], face_vals[:, 1], face_vals[:, 2] + elif Nz == 1: + ex_by = jnp.stack((Ex, By), axis=-1) + ey_bx = jnp.stack((Ey, Bx), axis=-1) + + ex_by_vals = interpolate_field_indexed(ex_by, i0s, j0, k0, rxs, ry, rz, shape_factor, Nx, Ny, Nz) + ey_bx_vals = interpolate_field_indexed(ey_bx, i0, j0s, k0, rx, rys, rz, shape_factor, Nx, Ny, Nz) + ez_vals = interpolate_field_indexed(Ez, i0, j0, k0s, rx, ry, rzs, shape_factor, Nx, Ny, Nz) + bz_vals = interpolate_field_indexed(Bz, i0s, j0s, k0, rxs, rys, rz, shape_factor, Nx, Ny, Nz) + + efield_atx, bfield_aty = ex_by_vals[:, 0], ex_by_vals[:, 1] + efield_aty, bfield_atx = ey_bx_vals[:, 0], ey_bx_vals[:, 1] + efield_atz = ez_vals + bfield_atz = bz_vals + else: + efield_atx = interpolate_field_indexed(Ex, i0s, j0, k0, rxs, ry, rz, shape_factor, Nx, Ny, Nz) + efield_aty = interpolate_field_indexed(Ey, i0, j0s, k0, rx, rys, rz, shape_factor, Nx, Ny, Nz) + efield_atz = interpolate_field_indexed(Ez, i0, j0, k0s, rx, ry, rzs, shape_factor, Nx, Ny, Nz) + bfield_atx = interpolate_field_indexed(Bx, i0, j0s, k0s, rx, rys, rzs, shape_factor, Nx, Ny, Nz) + bfield_aty = interpolate_field_indexed(By, i0s, j0, k0s, rxs, ry, rzs, shape_factor, Nx, Ny, Nz) + bfield_atz = interpolate_field_indexed(Bz, i0s, j0s, k0, rxs, rys, rz, shape_factor, Nx, Ny, Nz) + + newvx, newvy, newvz = jax.lax.cond( + relativistic == True, + lambda _: relativistic_boris_push( + vx, + vy, + vz, + efield_atx, + efield_aty, + efield_atz, + bfield_atx, + bfield_aty, + bfield_atz, + q, + m, + world["dt"], + constants, + ), + lambda _: boris_push( + vx, + vy, + vz, + efield_atx, + efield_aty, + efield_atz, + bfield_atx, + bfield_aty, + bfield_atz, + q, + m, + world["dt"], + ), + operand=None, + ) + + particles.set_velocity(newvx, newvy, newvz) + return particles diff --git a/PyPIC3D/evolve.py b/PyPIC3D/evolve.py index 5cdacff..f5fc31e 100644 --- a/PyPIC3D/evolve.py +++ b/PyPIC3D/evolve.py @@ -7,7 +7,8 @@ from functools import partial from PyPIC3D.boris import ( - particle_push + particle_push, + particle_push_indexed, ) from PyPIC3D.solvers.first_order_yee import ( @@ -22,7 +23,7 @@ E_from_A, B_from_A, update_vector_potential ) -@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic")) +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) def time_loop_electrostatic(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ Advances the simulation by one time step for an electrostatic Particle-In-Cell (PIC) loop. @@ -74,8 +75,7 @@ def time_loop_electrostatic(particles, fields, world, constants, curl_func, J_fu return particles, fields -@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic")) -def time_loop_electrodynamic(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): +def time_loop_electrodynamic_inline(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ Advance an electrodynamic Particle-In-Cell (PIC) system by one time step. This routine performs, in order: @@ -154,7 +154,37 @@ def time_loop_electrodynamic(particles, fields, world, constants, curl_func, J_f return particles, fields -@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic")) +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) +def time_loop_electrodynamic(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): + return time_loop_electrodynamic_inline( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + + +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) +def time_loop_electrodynamic_indexed(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): + E, B, J, rho, phi = fields + + for i in range(len(particles)): + particles[i] = particle_push_indexed(particles[i], E, B, world, constants, relativistic=relativistic) + particles[i].update_position(world) + + J = J_func(particles, J, constants, world) + E = update_E(E, B, J, world, constants, curl_func) + B = update_B(E, B, world, constants, curl_func) + + fields = (E, B, J, rho, phi) + return particles, fields + + +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) def time_loop_vector_potential(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ Advance a PIC (Particle-In-Cell) simulation by one time step using a diff --git a/PyPIC3D/flat_particles.py b/PyPIC3D/flat_particles.py new file mode 100644 index 0000000..2413ee4 --- /dev/null +++ b/PyPIC3D/flat_particles.py @@ -0,0 +1,324 @@ +import jax +import jax.numpy as jnp + + +@jax.tree_util.register_pytree_node_class +class flat_particle_species: + def __init__( + self, + name, + N_particles, + charge, + mass, + weight, + T, + x1, + x2, + x3, + v1, + v2, + v3, + x_wind, + y_wind, + z_wind, + dx, + dy, + dz, + x_bc, + y_bc, + z_bc, + update_pos, + update_v, + update_x, + update_y, + update_z, + update_vx, + update_vy, + update_vz, + shape, + dt, + species_meta, + ): + self.name = name + self.N_particles = N_particles + self.charge = charge + self.mass = mass + self.weight = weight + self.T = T + self.x1 = x1 + self.x2 = x2 + self.x3 = x3 + self.v1 = v1 + self.v2 = v2 + self.v3 = v3 + self.x_wind = x_wind + self.y_wind = y_wind + self.z_wind = z_wind + self.dx = dx + self.dy = dy + self.dz = dz + self.x_bc = x_bc + self.y_bc = y_bc + self.z_bc = z_bc + self.update_pos = update_pos + self.update_v = update_v + self.update_x = update_x + self.update_y = update_y + self.update_z = update_z + self.update_vx = update_vx + self.update_vy = update_vy + self.update_vz = update_vz + self.shape = shape + self.dt = dt + self.species_meta = species_meta + + def get_name(self): + return self.name + + def get_charge(self): + return self.charge * self.weight + + def get_number_of_particles(self): + return self.N_particles + + def get_temperature(self): + return self.T + + def get_velocity(self): + return self.v1, self.v2, self.v3 + + def get_forward_position(self): + return self.x1, self.x2, self.x3 + + def get_mass(self): + return self.mass * self.weight + + def get_shape(self): + return self.shape + + def set_velocity(self, v1, v2, v3): + if self.update_v: + if self.update_vx: + self.v1 = v1 + if self.update_vy: + self.v2 = v2 + if self.update_vz: + self.v3 = v3 + + def update_position(self): + if self.update_pos: + if self.update_x: + self.x1 = self.x1 + self.v1 * self.dt + if self.update_y: + self.x2 = self.x2 + self.v2 * self.dt + if self.update_z: + self.x3 = self.x3 + self.v3 * self.dt + + def boundary_conditions(self): + half_x = self.x_wind / 2 + half_y = self.y_wind / 2 + half_z = self.z_wind / 2 + + self.x1 = jnp.where(self.x1 > half_x, self.x1 - self.x_wind, jnp.where(self.x1 < -half_x, self.x1 + self.x_wind, self.x1)) + self.x2 = jnp.where(self.x2 > half_y, self.x2 - self.y_wind, jnp.where(self.x2 < -half_y, self.x2 + self.y_wind, self.x2)) + self.x3 = jnp.where(self.x3 > half_z, self.x3 - self.z_wind, jnp.where(self.x3 < -half_z, self.x3 + self.z_wind, self.x3)) + + def tree_flatten(self): + children = (self.x1, self.x2, self.x3, self.v1, self.v2, self.v3) + aux_data = ( + self.name, + self.N_particles, + self.charge, + self.mass, + self.weight, + self.T, + self.x_wind, + self.y_wind, + self.z_wind, + self.dx, + self.dy, + self.dz, + self.x_bc, + self.y_bc, + self.z_bc, + self.update_pos, + self.update_v, + self.update_x, + self.update_y, + self.update_z, + self.update_vx, + self.update_vy, + self.update_vz, + self.shape, + self.dt, + self.species_meta, + ) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + x1, x2, x3, v1, v2, v3 = children + ( + name, + N_particles, + charge, + mass, + weight, + T, + x_wind, + y_wind, + z_wind, + dx, + dy, + dz, + x_bc, + y_bc, + z_bc, + update_pos, + update_v, + update_x, + update_y, + update_z, + update_vx, + update_vy, + update_vz, + shape, + dt, + species_meta, + ) = aux_data + return cls( + name=name, + N_particles=N_particles, + charge=charge, + mass=mass, + weight=weight, + T=T, + x1=x1, + x2=x2, + x3=x3, + v1=v1, + v2=v2, + v3=v3, + x_wind=x_wind, + y_wind=y_wind, + z_wind=z_wind, + dx=dx, + dy=dy, + dz=dz, + x_bc=x_bc, + y_bc=y_bc, + z_bc=z_bc, + update_pos=update_pos, + update_v=update_v, + update_x=update_x, + update_y=update_y, + update_z=update_z, + update_vx=update_vx, + update_vy=update_vy, + update_vz=update_vz, + shape=shape, + dt=dt, + species_meta=species_meta, + ) + + +def _same(attr_list): + return len(set(attr_list)) == 1 + + +def check_flat_compat(particles): + if not particles: + return False + if not _same([p.get_shape() for p in particles]): + return False + if not _same([p.x_bc for p in particles]) or not _same([p.y_bc for p in particles]) or not _same([p.z_bc for p in particles]): + return False + if particles[0].x_bc != "periodic" or particles[0].y_bc != "periodic" or particles[0].z_bc != "periodic": + return False + if not _same([p.update_pos for p in particles]) or not _same([p.update_v for p in particles]): + return False + return True + + +def to_flat_particles(particles): + species_meta = [] + x_list, y_list, z_list = [], [], [] + vx_list, vy_list, vz_list = [], [], [] + q_list, m_list, w_list, T_list = [], [], [], [] + + for species in particles: + x, y, z = species.get_forward_position() + vx, vy, vz = species.get_velocity() + x_list.append(x) + y_list.append(y) + z_list.append(z) + vx_list.append(vx) + vy_list.append(vy) + vz_list.append(vz) + q_list.append(jnp.full_like(x, species.charge)) + m_list.append(jnp.full_like(x, species.mass)) + w_list.append(jnp.full_like(x, species.weight)) + T_list.append(jnp.full_like(x, species.T)) + species_meta.append( + { + "name": species.name, + "N_particles": float(species.N_particles), + "weight": float(species.weight), + "charge": float(species.charge), + "mass": float(species.mass), + "temperature": float(species.T), + "scaled mass": float(species.get_mass()), + "scaled charge": float(species.get_charge()), + "update_pos": species.update_pos, + "update_v": species.update_v, + } + ) + + x = jnp.concatenate(x_list, axis=0) + y = jnp.concatenate(y_list, axis=0) + z = jnp.concatenate(z_list, axis=0) + vx = jnp.concatenate(vx_list, axis=0) + vy = jnp.concatenate(vy_list, axis=0) + vz = jnp.concatenate(vz_list, axis=0) + charge = jnp.concatenate(q_list, axis=0) + mass = jnp.concatenate(m_list, axis=0) + weight = jnp.concatenate(w_list, axis=0) + T = jnp.concatenate(T_list, axis=0) + + first = particles[0] + flat = flat_particle_species( + name="flat_all", + N_particles=int(x.shape[0]), + charge=charge, + mass=mass, + weight=weight, + T=T, + x1=x, + x2=y, + x3=z, + v1=vx, + v2=vy, + v3=vz, + x_wind=first.x_wind, + y_wind=first.y_wind, + z_wind=first.z_wind, + dx=first.dx, + dy=first.dy, + dz=first.dz, + x_bc=first.x_bc, + y_bc=first.y_bc, + z_bc=first.z_bc, + update_pos=first.update_pos, + update_v=first.update_v, + update_x=first.update_x, + update_y=first.update_y, + update_z=first.update_z, + update_vx=first.update_vx, + update_vy=first.update_vy, + update_vz=first.update_vz, + shape=first.shape, + dt=first.dt, + species_meta=species_meta, + ) + return [flat] + diff --git a/PyPIC3D/indexed_particles.py b/PyPIC3D/indexed_particles.py new file mode 100644 index 0000000..2ccafff --- /dev/null +++ b/PyPIC3D/indexed_particles.py @@ -0,0 +1,356 @@ +import jax +import jax.numpy as jnp +from functools import partial + +from PyPIC3D.utils import wrap_around + + +def _init_index_and_frac(x, xmin, dx, shape_factor): + s = (x - xmin) / dx + if shape_factor == 1: + i0 = jnp.floor(s).astype(jnp.int32) + else: + i0 = jnp.round(s).astype(jnp.int32) + r = s - i0 + return i0, r + + +def _advance_index_and_frac(i0, r, dr, n, shape_factor): + r_new = r + dr + if shape_factor == 1: + shift = jnp.floor(r_new).astype(jnp.int32) + else: + shift = jnp.floor(r_new + 0.5).astype(jnp.int32) + r_new = r_new - shift + i0_new = wrap_around(i0 + shift, n) + return i0_new, r_new + + +@jax.tree_util.register_pytree_node_class +class indexed_particle_species: + def __init__( + self, + name, + N_particles, + charge, + mass, + weight, + T, + v1, + v2, + v3, + i1, + i2, + i3, + r1, + r2, + r3, + x_wind, + y_wind, + z_wind, + dx, + dy, + dz, + xmin, + ymin, + zmin, + x_bc, + y_bc, + z_bc, + update_pos, + update_v, + update_x, + update_y, + update_z, + update_vx, + update_vy, + update_vz, + shape, + dt, + ): + self.name = name + self.N_particles = N_particles + self.charge = charge + self.mass = mass + self.weight = weight + self.T = T + self.v1 = v1 + self.v2 = v2 + self.v3 = v3 + self.i1 = i1 + self.i2 = i2 + self.i3 = i3 + self.r1 = r1 + self.r2 = r2 + self.r3 = r3 + + self.x_wind = x_wind + self.y_wind = y_wind + self.z_wind = z_wind + self.dx = dx + self.dy = dy + self.dz = dz + self.xmin = xmin + self.ymin = ymin + self.zmin = zmin + self.x_bc = x_bc + self.y_bc = y_bc + self.z_bc = z_bc + self.update_pos = update_pos + self.update_v = update_v + self.update_x = update_x + self.update_y = update_y + self.update_z = update_z + self.update_vx = update_vx + self.update_vy = update_vy + self.update_vz = update_vz + self.shape = shape + self.dt = dt + + def get_name(self): + return self.name + + def get_charge(self): + return self.charge * self.weight + + def get_number_of_particles(self): + return self.N_particles + + def get_velocity(self): + return self.v1, self.v2, self.v3 + + def get_mass(self): + return self.mass * self.weight + + def get_shape(self): + return self.shape + + def get_forward_position(self): + x = (self.i1 + self.r1) * self.dx + self.xmin + y = (self.i2 + self.r2) * self.dy + self.ymin + z = (self.i3 + self.r3) * self.dz + self.zmin + return x, y, z + + def get_indexed_position(self): + return self.i1, self.i2, self.i3, self.r1, self.r2, self.r3 + + def set_velocity(self, v1, v2, v3): + if self.update_v: + if self.update_vx: + self.v1 = v1 + if self.update_vy: + self.v2 = v2 + if self.update_vz: + self.v3 = v3 + + def update_position(self, world): + if not self.update_pos: + return + + dt = world["dt"] + if self.update_x: + self.i1, self.r1 = _advance_index_and_frac( + self.i1, + self.r1, + self.v1 * dt / self.dx, + world["Nx"], + self.shape, + ) + if self.update_y: + self.i2, self.r2 = _advance_index_and_frac( + self.i2, + self.r2, + self.v2 * dt / self.dy, + world["Ny"], + self.shape, + ) + if self.update_z: + self.i3, self.r3 = _advance_index_and_frac( + self.i3, + self.r3, + self.v3 * dt / self.dz, + world["Nz"], + self.shape, + ) + + def tree_flatten(self): + children = ( + self.v1, + self.v2, + self.v3, + self.i1, + self.i2, + self.i3, + self.r1, + self.r2, + self.r3, + ) + aux_data = ( + self.name, + self.N_particles, + self.charge, + self.mass, + self.weight, + self.T, + self.x_wind, + self.y_wind, + self.z_wind, + self.dx, + self.dy, + self.dz, + self.xmin, + self.ymin, + self.zmin, + self.x_bc, + self.y_bc, + self.z_bc, + self.update_pos, + self.update_v, + self.update_x, + self.update_y, + self.update_z, + self.update_vx, + self.update_vy, + self.update_vz, + self.shape, + self.dt, + ) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + v1, v2, v3, i1, i2, i3, r1, r2, r3 = children + ( + name, + N_particles, + charge, + mass, + weight, + T, + x_wind, + y_wind, + z_wind, + dx, + dy, + dz, + xmin, + ymin, + zmin, + x_bc, + y_bc, + z_bc, + update_pos, + update_v, + update_x, + update_y, + update_z, + update_vx, + update_vy, + update_vz, + shape, + dt, + ) = aux_data + return cls( + name=name, + N_particles=N_particles, + charge=charge, + mass=mass, + weight=weight, + T=T, + v1=v1, + v2=v2, + v3=v3, + i1=i1, + i2=i2, + i3=i3, + r1=r1, + r2=r2, + r3=r3, + x_wind=x_wind, + y_wind=y_wind, + z_wind=z_wind, + dx=dx, + dy=dy, + dz=dz, + xmin=xmin, + ymin=ymin, + zmin=zmin, + x_bc=x_bc, + y_bc=y_bc, + z_bc=z_bc, + update_pos=update_pos, + update_v=update_v, + update_x=update_x, + update_y=update_y, + update_z=update_z, + update_vx=update_vx, + update_vy=update_vy, + update_vz=update_vz, + shape=shape, + dt=dt, + ) + + +def to_indexed_particles(particles, world): + indexed = [] + xmin, ymin, zmin = world["grids"]["center"][0][0], world["grids"]["center"][1][0], world["grids"]["center"][2][0] + for species in particles: + x, y, z = species.get_forward_position() + shape = species.get_shape() + i1, r1 = _init_index_and_frac(x, xmin, world["dx"], shape) + i2, r2 = _init_index_and_frac(y, ymin, world["dy"], shape) + i3, r3 = _init_index_and_frac(z, zmin, world["dz"], shape) + i1 = wrap_around(i1, world["Nx"]) + i2 = wrap_around(i2, world["Ny"]) + i3 = wrap_around(i3, world["Nz"]) + v1, v2, v3 = species.get_velocity() + + indexed.append( + indexed_particle_species( + name=species.name, + N_particles=species.N_particles, + charge=species.charge, + mass=species.mass, + weight=species.weight, + T=species.T, + v1=v1, + v2=v2, + v3=v3, + i1=i1, + i2=i2, + i3=i3, + r1=r1, + r2=r2, + r3=r3, + x_wind=species.x_wind, + y_wind=species.y_wind, + z_wind=species.z_wind, + dx=species.dx, + dy=species.dy, + dz=species.dz, + xmin=xmin, + ymin=ymin, + zmin=zmin, + x_bc=species.x_bc, + y_bc=species.y_bc, + z_bc=species.z_bc, + update_pos=species.update_pos, + update_v=species.update_v, + update_x=species.update_x, + update_y=species.update_y, + update_z=species.update_z, + update_vx=species.update_vx, + update_vy=species.update_vy, + update_vz=species.update_vz, + shape=shape, + dt=species.dt, + ) + ) + return indexed + + +def check_periodic_bc(particles): + for species in particles: + if species.x_bc != "periodic" or species.y_bc != "periodic" or species.z_bc != "periodic": + return False + return True diff --git a/PyPIC3D/initialization.py b/PyPIC3D/initialization.py index 665a38b..fe85e9b 100644 --- a/PyPIC3D/initialization.py +++ b/PyPIC3D/initialization.py @@ -40,12 +40,18 @@ from PyPIC3D.evolve import ( - time_loop_electrodynamic, time_loop_electrostatic, time_loop_vector_potential + time_loop_electrodynamic, + time_loop_electrodynamic_inline, + time_loop_electrodynamic_indexed, + time_loop_electrostatic, + time_loop_vector_potential, ) from PyPIC3D.J import ( - J_from_rhov, Esirkepov_current + J_from_rhov, Esirkepov_current, J_from_rhov_indexed ) +from PyPIC3D.indexed_particles import to_indexed_particles, check_periodic_bc +from PyPIC3D.flat_particles import to_flat_particles, check_flat_compat from PyPIC3D.solvers.vector_potential import initialize_vector_potential @@ -97,6 +103,8 @@ def default_parameters(): "name": "Default Simulation", "output_dir": os.getcwd(), "solver": "fdtd", # solver: spectral, fdtd, vector_potential, curl_curl + "fast_mode": "off", # off | fp32 | aggressive | extreme (trades accuracy for speed) + "fast_backend": "default", # default | indexed | flat (physics-preserving refactor) "particle_bc": "periodic", # particle boundary conditions: periodic, absorb, reflect # "bc": "periodic", # boundary conditions: periodic, dirichlet, neumann "x_bc": "periodic", # x boundary conditions: periodic, conducting @@ -113,9 +121,11 @@ def default_parameters(): "Nt": None, # number of time steps "electrostatic": False, # boolean for electrostatic simulation "relativistic": True, # boolean for relativistic simulation + "enable_x64": True, # enable 64-bit JAX dtypes (slower but higher precision) "benchmark": False, # boolean for using the profiler "verbose": False, # boolean for printing verbose output "GPUs": False, # boolean for using GPUs + "scan_chunk": 1, # advance multiple steps per dispatch (1 keeps legacy per-step loop) "cfl" : 1.0, # CFL condition number "ds_per_debye" : None, # number of grid spacings per debye length "shape_factor" : 1, # shape factor for the simulation (1 for 1st order, 2 for 2nd order) @@ -248,11 +258,8 @@ def initialize_simulation(toml_file): } # set the simulation world parameters - world = convert_to_jax_compatible(world) - constants = convert_to_jax_compatible(constants) - simulation_parameters = convert_to_jax_compatible(simulation_parameters) - plotting_parameters = convert_to_jax_compatible(plotting_parameters) - # convert the world parameters to jax compatible format + # Keep scalar parameters as Python types so JAX can treat them as static + # (avoids traced metadata in PyTrees and enables compile-time specialization). # if solver == "vector_potential": # B_grid, E_grid = build_collocated_grid(world) @@ -307,8 +314,8 @@ def initialize_simulation(toml_file): # convert the E, B, and J tuples into one big list fields = load_external_fields_from_toml(fields, toml_file) # add any external fields to the simulation - E, B, J = fields[:3], fields[3:6], fields[6:9] - # convert the fields list back into tuples + E, B, J = tuple(fields[:3]), tuple(fields[3:6]), tuple(fields[6:9]) + # convert the fields list back into tuples (JAX scan expects stable PyTree types) if solver == "spectral": curl_func = functools.partial(spectral_curl, world=world) @@ -343,6 +350,22 @@ def initialize_simulation(toml_file): print(f"Using electrodynamic solver with: {solver}") evolve_loop = time_loop_electrodynamic # set the evolve loop function based on the electrostatic flag + fast_backend = simulation_parameters.get("fast_backend", "default") + if fast_backend == "indexed": + if electrostatic or solver == "vector_potential": + raise ValueError("fast_backend='indexed' currently supports electrodynamic solver only.") + if not check_periodic_bc(particles): + raise ValueError("fast_backend='indexed' requires periodic particle boundary conditions.") + evolve_loop = time_loop_electrodynamic_indexed + elif fast_backend == "flat": + if electrostatic or solver == "vector_potential": + raise ValueError("fast_backend='flat' currently supports electrodynamic solver only.") + if not check_flat_compat(particles): + raise ValueError("fast_backend='flat' requires periodic boundary conditions and uniform particle shape.") + evolve_loop = time_loop_electrodynamic_inline + + if simulation_parameters.get("fast_mode", "off") in ("aggressive", "extreme") and fast_backend != "indexed": + evolve_loop = time_loop_electrodynamic_inline if simulation_parameters['current_calculation'] == "esirkepov": print("Using Esirkepov current calculation method") @@ -350,7 +373,22 @@ def initialize_simulation(toml_file): J_func = Esirkepov_current elif simulation_parameters['current_calculation'] == "j_from_rhov": print(f"Using J from rhov current calculation method with filter: {simulation_parameters['filter_j']}") - J_func = functools.partial(J_from_rhov, filter=simulation_parameters['filter_j']) + if fast_backend == "indexed": + J_func = functools.partial( + J_from_rhov_indexed, + filter=simulation_parameters["filter_j"], + ) + else: + J_func = functools.partial( + J_from_rhov, + filter=simulation_parameters["filter_j"], + shape_factor=int(simulation_parameters["shape_factor"]), + ) + + if fast_backend == "indexed": + particles = to_indexed_particles(particles, world) + elif fast_backend == "flat": + particles = to_flat_particles(particles) if solver == "vector_potential": diff --git a/PyPIC3D/particle.py b/PyPIC3D/particle.py index 4655277..2aa6b19 100644 --- a/PyPIC3D/particle.py +++ b/PyPIC3D/particle.py @@ -664,9 +664,12 @@ def boundary_conditions(self): x1, x2, x3 = self.x1, self.x2, self.x3 v1, v2, v3 = self.v1, self.v2, self.v3 - x1, v1 = apply_axis_boundary_condition(x1, v1, self.x_wind, self.half_x_wind, self.x_periodic, self.x_reflecting) - x2, v2 = apply_axis_boundary_condition(x2, v2, self.y_wind, self.half_y_wind, self.y_periodic, self.y_reflecting) - x3, v3 = apply_axis_boundary_condition(x3, v3, self.z_wind, self.half_z_wind, self.z_periodic, self.z_reflecting) + if self.update_x: + x1, v1 = apply_axis_boundary_condition(x1, v1, self.x_wind, self.half_x_wind, self.x_periodic, self.x_reflecting) + if self.update_y: + x2, v2 = apply_axis_boundary_condition(x2, v2, self.y_wind, self.half_y_wind, self.y_periodic, self.y_reflecting) + if self.update_z: + x3, v3 = apply_axis_boundary_condition(x3, v3, self.z_wind, self.half_z_wind, self.z_periodic, self.z_reflecting) self.x1, self.x2, self.x3 = x1, x2, x3 self.v1, self.v2, self.v3 = v1, v2, v3 diff --git a/PyPIC3D/solvers/first_order_yee.py b/PyPIC3D/solvers/first_order_yee.py index ce24825..775ce1c 100644 --- a/PyPIC3D/solvers/first_order_yee.py +++ b/PyPIC3D/solvers/first_order_yee.py @@ -60,21 +60,19 @@ def update_E(E, B, J, world, constants, curl_func): eps = constants['eps'] # get the time resolution and necessary constants - Bx = jnp.pad(Bx, ((1,1), (1,1), (1,1)), mode="wrap") - By = jnp.pad(By, ((1,1), (1,1), (1,1)), mode="wrap") - Bz = jnp.pad(Bz, ((1,1), (1,1), (1,1)), mode="wrap") - # pad the magnetic field components for periodic boundary conditions - - dBz_dy = (jnp.roll(Bz, shift=-1, axis=1) - Bz) / dy - dBx_dy = (jnp.roll(Bx, shift=-1, axis=1) - Bx) / dy - dBy_dz = (jnp.roll(By, shift=-1, axis=2) - By) / dz - dBx_dz = (jnp.roll(Bx, shift=-1, axis=2) - Bx) / dz + Ny = Ex.shape[1] + Nz = Ex.shape[2] + + dBz_dy = (jnp.roll(Bz, shift=-1, axis=1) - Bz) / dy if Ny != 1 else 0.0 + dBx_dy = (jnp.roll(Bx, shift=-1, axis=1) - Bx) / dy if Ny != 1 else 0.0 + dBy_dz = (jnp.roll(By, shift=-1, axis=2) - By) / dz if Nz != 1 else 0.0 + dBx_dz = (jnp.roll(Bx, shift=-1, axis=2) - Bx) / dz if Nz != 1 else 0.0 dBz_dx = (jnp.roll(Bz, shift=-1, axis=0) - Bz) / dx dBy_dx = (jnp.roll(By, shift=-1, axis=0) - By) / dx - curl_x = (dBz_dy - dBy_dz)[1:-1,1:-1,1:-1] - curl_y = (dBx_dz - dBz_dx)[1:-1,1:-1,1:-1] - curl_z = (dBy_dx - dBx_dy)[1:-1,1:-1,1:-1] + curl_x = dBz_dy - dBy_dz + curl_y = dBx_dz - dBz_dx + curl_z = dBy_dx - dBx_dy # calculate the curl of the magnetic field Ex = Ex + ( C**2 * curl_x - Jx / eps ) * dt @@ -180,21 +178,19 @@ def update_B(E, B, world, constants, curl_func): Bx, By, Bz = B # unpack the E and B fields - Ex = jnp.pad(Ex, ((1,1), (1,1), (1,1)), mode="wrap") - Ey = jnp.pad(Ey, ((1,1), (1,1), (1,1)), mode="wrap") - Ez = jnp.pad(Ez, ((1,1), (1,1), (1,1)), mode="wrap") - # pad the electric field components for periodic boundary conditions + Ny = Ex.shape[1] + Nz = Ex.shape[2] - dEz_dy = (Ez - jnp.roll(Ez, shift=1, axis=1)) / dy - dEx_dy = (Ex - jnp.roll(Ex, shift=1, axis=1)) / dy - dEy_dz = (Ey - jnp.roll(Ey, shift=1, axis=2)) / dz - dEx_dz = (Ex - jnp.roll(Ex, shift=1, axis=2)) / dz + dEz_dy = (Ez - jnp.roll(Ez, shift=1, axis=1)) / dy if Ny != 1 else 0.0 + dEx_dy = (Ex - jnp.roll(Ex, shift=1, axis=1)) / dy if Ny != 1 else 0.0 + dEy_dz = (Ey - jnp.roll(Ey, shift=1, axis=2)) / dz if Nz != 1 else 0.0 + dEx_dz = (Ex - jnp.roll(Ex, shift=1, axis=2)) / dz if Nz != 1 else 0.0 dEz_dx = (Ez - jnp.roll(Ez, shift=1, axis=0)) / dx dEy_dx = (Ey - jnp.roll(Ey, shift=1, axis=0)) / dx - curl_x = (dEz_dy - dEy_dz)[1:-1,1:-1,1:-1] - curl_y = (dEx_dz - dEz_dx)[1:-1,1:-1,1:-1] - curl_z = (dEy_dx - dEx_dy)[1:-1,1:-1,1:-1] + curl_x = dEz_dy - dEy_dz + curl_y = dEx_dz - dEz_dx + curl_z = dEy_dx - dEx_dy # calculate the curl of the electric field Bx = Bx - dt*curl_x @@ -209,4 +205,3 @@ def update_B(E, B, world, constants, curl_func): # apply a digital filter to the magnetic field components return (Bx, By, Bz) - diff --git a/PyPIC3D/utils.py b/PyPIC3D/utils.py index 1938e8d..d35a072 100644 --- a/PyPIC3D/utils.py +++ b/PyPIC3D/utils.py @@ -2,7 +2,7 @@ import plotly import tqdm import pyevtk -from jax import jit +from jax import jit, lax import argparse import jax.numpy as jnp import functools @@ -39,24 +39,38 @@ def wrap_around(ix, size): @jit def bilinear_filter(phi, mode="wrap"): """ - Apply a 3D (tri-linear) smoothing filter to a 3D array using a separable - [1, 2, 1]/4 kernel in each dimension. + Apply a tri-linear smoothing filter using a separable [1, 2, 1]/4 kernel + in each spatial dimension. Args: - phi (jnp.ndarray): 3D field array with shape (Nx, Ny, Nz). + phi (jnp.ndarray): Field array with leading spatial shape (Nx, Ny, Nz). + Any trailing feature dimensions are preserved (e.g. (Nx, Ny, Nz, 3)). mode (str): Padding mode passed to jnp.pad (default: "wrap"). Returns: jnp.ndarray: Filtered array with the same shape as phi. """ - k1 = jnp.array([1.0, 2.0, 1.0], dtype=phi.dtype) / 4.0 # sums to 1 - k3 = k1[:, None, None] * k1[None, :, None] * k1[None, None, :] # (3,3,3), sums to 1 + if mode == "wrap": + quarter = jnp.asarray(0.25, dtype=phi.dtype) + + def smooth_axis(arr, axis): + return (jnp.roll(arr, 1, axis=axis) + 2 * arr + jnp.roll(arr, -1, axis=axis)) * quarter + + phi = smooth_axis(phi, 0) + phi = smooth_axis(phi, 1) + phi = smooth_axis(phi, 2) + return phi + + if phi.ndim != 3: + raise ValueError("bilinear_filter only supports non-wrap mode for 3D arrays.") + + k1 = jnp.array([1.0, 2.0, 1.0], dtype=phi.dtype) / 4.0 + k3 = k1[:, None, None] * k1[None, :, None] * k1[None, None, :] kernel = jnp.zeros((3, 3, 3, 1, 1), dtype=phi.dtype) kernel = kernel.at[:, :, :, 0, 0].set(k3) padded_phi = jnp.pad(phi, ((1, 1), (1, 1), (1, 1)), mode=mode) - filtered = jax.lax.conv_general_dilated( padded_phi[jnp.newaxis, ..., jnp.newaxis], kernel, @@ -79,25 +93,22 @@ def digital_filter(phi, alpha): Returns: ndarray: Filtered field array. """ - neighbor_weight = (1 - alpha) / 6 - kernel = jnp.zeros((3, 3, 3, 1, 1), dtype=phi.dtype) - kernel = kernel.at[1, 1, 1, 0, 0].set(alpha) - kernel = kernel.at[0, 1, 1, 0, 0].set(neighbor_weight) - kernel = kernel.at[2, 1, 1, 0, 0].set(neighbor_weight) - kernel = kernel.at[1, 0, 1, 0, 0].set(neighbor_weight) - kernel = kernel.at[1, 2, 1, 0, 0].set(neighbor_weight) - kernel = kernel.at[1, 1, 0, 0, 0].set(neighbor_weight) - kernel = kernel.at[1, 1, 2, 0, 0].set(neighbor_weight) - - padded_phi = jnp.pad(phi, ((1, 1), (1, 1), (1, 1)), mode="wrap") - filtered = jax.lax.conv_general_dilated( - padded_phi[jnp.newaxis, ..., jnp.newaxis], - kernel, - window_strides=(1, 1, 1), - padding="VALID", - dimension_numbers=("NDHWC", "DHWIO", "NDHWC"), - ) - return jnp.squeeze(filtered, axis=(0, 4)) + def apply(phi): + neighbor_weight = (1 - alpha) / 6 + return ( + alpha * phi + + neighbor_weight + * ( + jnp.roll(phi, 1, axis=0) + + jnp.roll(phi, -1, axis=0) + + jnp.roll(phi, 1, axis=1) + + jnp.roll(phi, -1, axis=1) + + jnp.roll(phi, 1, axis=2) + + jnp.roll(phi, -1, axis=2) + ) + ) + + return lax.cond(alpha == 1.0, lambda phi: phi, apply, phi) def mae(x, y): """ @@ -222,11 +233,10 @@ def nd_trapezoid(arr, dxs): mass = species.get_mass() vx, vy, vz = species.get_velocity() v2 = vx**2 + vy**2 + vz**2 - gamma = 1.0 / jnp.sqrt(1 - v2 / C**2) - momentum2 = jnp.square(mass * gamma ) * v2 - # compute the squared momentum for each particle - KE = jnp.sum( jnp.sqrt( momentum2 * C**2 + mass**2 * C**4) - mass * C**2 ) - # compute the kinetic energy for this species + # Relativistic KE: use m c^2 (gamma - 1) to avoid catastrophic cancellation, + # especially in fp32 fast modes. + gamma = 1.0 / jnp.sqrt(jnp.maximum(1.0 - v2 / C**2, 0.0)) + KE = jnp.sum(mass * C**2 * (gamma - 1.0)) kinetic_energy += KE # add to total kinetic energy @@ -696,6 +706,9 @@ def dump_parameters_to_toml(simulation_stats, simulation_parameters, plasma_para } for particle in particles: + if hasattr(particle, "species_meta") and particle.species_meta: + config["particles"].extend(particle.species_meta) + continue particle_dict = { "name": particle.name, "N_particles": float(particle.N_particles),