From c85ec0bb5298367246a2856df95d0c5da2c46b3c Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Sat, 14 Mar 2026 16:12:44 -0400 Subject: [PATCH 1/2] CPU fastpath: cheaper stencils, fewer allocations --- PyPIC3D/J.py | 134 +++++++++----------- PyPIC3D/__main__.py | 118 +++++++++++++---- PyPIC3D/boris.py | 197 +++++++++++++++++++++-------- PyPIC3D/evolve.py | 6 +- PyPIC3D/initialization.py | 13 +- PyPIC3D/solvers/first_order_yee.py | 23 +--- PyPIC3D/utils.py | 54 ++++---- 7 files changed, 341 insertions(+), 204 deletions(-) diff --git a/PyPIC3D/J.py b/PyPIC3D/J.py index 2a8620d..fc073a7 100644 --- a/PyPIC3D/J.py +++ b/PyPIC3D/J.py @@ -29,24 +29,17 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): dx = world['dx'] dy = world['dy'] dz = world['dz'] - Nx = world['Nx'] - Ny = world['Ny'] - Nz = world['Nz'] - # get the world parameters - Jx, Jy, Jz = J + Nx, Ny, Nz = Jx.shape + # get the world parameters x_active = Jx.shape[0] != 1 y_active = Jx.shape[1] != 1 z_active = Jx.shape[2] != 1 # infer effective dimensionality from the current-grid shape - # unpack the values of J - Jx = Jx.at[:, :, :].set(0) - Jy = Jy.at[:, :, :].set(0) - Jz = Jz.at[:, :, :].set(0) - # initialize the current arrays as 0 - J = (Jx, Jy, Jz) - # initialize the current density as a tuple + J_stack = jnp.stack((Jx, Jy, Jz), axis=-1) + J_stack = jnp.zeros_like(J_stack) + # keep J together so deposition and filtering can be fused across components for species in particles: shape_factor = species.get_shape() @@ -62,24 +55,14 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): z = z - vz * world['dt'] / 2 # step back to half time step positions for proper time staggering - x0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (x - grid[0][0]) / dx).astype(int), - lambda _: jnp.round( (x - grid[0][0]) / dx).astype(int), - operand=None - ) - y0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (y - grid[1][0]) / dy).astype(int), - lambda _: jnp.round( (y - grid[1][0]) / dy).astype(int), - operand=None - ) - z0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor( (z - grid[2][0]) / dz).astype(int), - lambda _: jnp.round( (z - grid[2][0]) / dz).astype(int), - operand=None - ) + if shape_factor == 1: + x0 = jnp.floor((x - grid[0][0]) / dx).astype(jnp.int32) + y0 = jnp.floor((y - grid[1][0]) / dy).astype(jnp.int32) + z0 = jnp.floor((z - grid[2][0]) / dz).astype(jnp.int32) + else: + x0 = jnp.round((x - grid[0][0]) / dx).astype(jnp.int32) + y0 = jnp.round((y - grid[1][0]) / dy).astype(jnp.int32) + z0 = jnp.round((z - grid[2][0]) / dz).astype(jnp.int32) # calculate the nearest grid point based on shape factor deltax_node = (x - grid[0][0]) - (x0 * dx) @@ -109,19 +92,12 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): zpts = [z_minus1, z0, z1] # place all the points in a list - x_weights_node, y_weights_node, z_weights_node = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights( deltax_node, deltay_node, deltaz_node, dx, dy, dz), - lambda _: get_second_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz), - operand=None - ) - - x_weights_face, y_weights_face, z_weights_face = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights( deltax_face, deltay_face, deltaz_face, dx, dy, dz), - lambda _: get_second_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz), - operand=None - ) + if shape_factor == 1: + x_weights_node, y_weights_node, z_weights_node = get_first_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz) + x_weights_face, y_weights_face, z_weights_face = get_first_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz) + else: + x_weights_node, y_weights_node, z_weights_node = get_second_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz) + x_weights_face, y_weights_face, z_weights_face = get_second_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz) # get the weights for node and face positions xpts = jnp.asarray(xpts) # (Sx, Np) @@ -136,6 +112,18 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): y_weights_node = jnp.asarray(y_weights_node) # (Sy, Np) z_weights_node = jnp.asarray(z_weights_node) # (Sz, Np) + if shape_factor == 1: + # drop the redundant (-1) stencil point for first-order (its weights are identically 0) + xpts = xpts[1:, ...] + ypts = ypts[1:, ...] + zpts = zpts[1:, ...] + x_weights_face = x_weights_face[1:, ...] + y_weights_face = y_weights_face[1:, ...] + z_weights_face = z_weights_face[1:, ...] + x_weights_node = x_weights_node[1:, ...] + y_weights_node = y_weights_node[1:, ...] + z_weights_node = z_weights_node[1:, ...] + # Keep full shape-factor computation but collapse inactive axes to an # effective stencil of size 1 to avoid redundant deposition work. if x_active: @@ -184,41 +172,41 @@ def idx_and_dJ_values(idx): valy = (dq * vy) * x_weights_node_eff[i, ...] * y_weights_face_eff[j, ...] * z_weights_node_eff[k, ...] valz = (dq * vz) * x_weights_node_eff[i, ...] * y_weights_node_eff[j, ...] * z_weights_face_eff[k, ...] # calculate the current contributions for this stencil point - return ix, iy, iz, valx, valy, valz + return ix, iy, iz, jnp.stack((valx, valy, valz), axis=-1) - ix, iy, iz, valx, valy, valz = jax.vmap(idx_and_dJ_values)(combos) # each: (M, Np) + ix, iy, iz, dJ = jax.vmap(idx_and_dJ_values)(combos) # (M,Np), (M,Np), (M,Np), (M,Np,3) # vectorized computation of indices and current contributions - Jx = Jx.at[(ix, iy, iz)].add(valx, mode="drop") - Jy = Jy.at[(ix, iy, iz)].add(valy, mode="drop") - Jz = Jz.at[(ix, iy, iz)].add(valz, mode="drop") - # deposit the current contributions into the global J arrays - - def filter_func(J_, filter): - J_ = jax.lax.cond( - filter == 'bilinear', - lambda J_: bilinear_filter(J_), - lambda J_: J_, - operand=J_ + ix_flat = ix.reshape(-1) + iy_flat = iy.reshape(-1) + iz_flat = iz.reshape(-1) + dJ_flat = dJ.reshape(-1, 3) + + in_bounds = ( + (ix_flat >= 0) + & (ix_flat < Nx) + & (iy_flat >= 0) + & (iy_flat < Ny) + & (iz_flat >= 0) + & (iz_flat < Nz) ) - # alpha = constants['alpha'] - # J_ = jax.lax.cond( - # filter == 'digital', - # lambda J_: digital_filter(J_, alpha), - # lambda J_: J_, - # operand=J_ - # ) - return J_ - # define a filtering function - - Jx = filter_func(Jx, filter) - Jy = filter_func(Jy, filter) - Jz = filter_func(Jz, filter) - # apply the selected filter to each component of J - J = (Jx, Jy, Jz) - - return J + ix_flat = jnp.clip(ix_flat, 0, Nx - 1) + iy_flat = jnp.clip(iy_flat, 0, Ny - 1) + iz_flat = jnp.clip(iz_flat, 0, Nz - 1) + + idx_flat = ix_flat + Nx * (iy_flat + Ny * iz_flat) + dJ_flat = jnp.where(in_bounds[:, None], dJ_flat, 0) + + J_flat = jax.ops.segment_sum(dJ_flat, idx_flat, num_segments=Nx * Ny * Nz) + J_stack = J_stack + J_flat.reshape((Nx, Ny, Nz, 3)) + # segment_sum avoids large scatter updates on CPU + + if filter == "bilinear": + J_stack = bilinear_filter(J_stack) + # (optional) digital filter disabled by default + + return (J_stack[..., 0], J_stack[..., 1], J_stack[..., 2]) def _roll_old_weights_to_new_frame(old_w_list, shift): """ diff --git a/PyPIC3D/__main__.py b/PyPIC3D/__main__.py index 775213c..3541504 100644 --- a/PyPIC3D/__main__.py +++ b/PyPIC3D/__main__.py @@ -8,7 +8,7 @@ import os import time import jax -from jax import block_until_ready +from jax import block_until_ready, lax import jax.numpy as jnp from tqdm import tqdm @@ -23,10 +23,6 @@ write_openpmd_particles, write_openpmd_fields ) -from PyPIC3D.diagnostics.vtk import ( - plot_field_slice_vtk, plot_vectorfield_slice_vtk, plot_vtk_particles -) - from PyPIC3D.utils import ( dump_parameters_to_toml, load_config_file, compute_energy, setup_pmd_files @@ -53,8 +49,6 @@ def run_PyPIC3D(config_file): loop, particles, fields, world, simulation_parameters, constants, plotting_parameters, plasma_parameters, solver, electrostatic, verbose, GPUs, Nt, curl_func, J_func, relativistic = initialize_simulation(config_file) # initialize the simulation - jit_loop = jax.jit(loop, static_argnames=('curl_func', 'J_func', 'solver', 'relativistic')) - dt = world['dt'] output_dir = simulation_parameters['output_dir'] vertex_grid = world['grids']['vertex'] @@ -77,12 +71,43 @@ def run_PyPIC3D(config_file): ###################################################### SIMULATION LOOP ##################################### - for t in tqdm(range(Nt)): + scan_chunk = int(simulation_parameters.get("scan_chunk", 1)) + plotting_interval = int(plotting_parameters["plotting_interval"]) + + if scan_chunk < 1: + raise ValueError("simulation_parameters.scan_chunk must be >= 1") + + if scan_chunk > 1 and (plotting_interval % scan_chunk) != 0: + raise ValueError( + f"scan_chunk={scan_chunk} requires plotting_interval to be a multiple of scan_chunk " + f"(got plotting_interval={plotting_interval})." + ) + + def make_advance(n_steps): + def advance(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): + def body(_, state): + p, f = state + return loop(p, f, world, constants, curl_func, J_func, solver, relativistic=relativistic) + + return lax.fori_loop(0, n_steps, body, (particles, fields)) + + return jax.jit( + advance, + static_argnames=("curl_func", "J_func", "solver", "relativistic"), + donate_argnums=(0, 1), + ) + + advance_full = make_advance(scan_chunk) if scan_chunk > 1 else None + tail = Nt % scan_chunk + advance_tail = make_advance(tail) if (scan_chunk > 1 and tail) else None + + step_iter = range(0, Nt, scan_chunk) if scan_chunk > 1 else range(Nt) + for t in tqdm(step_iter): # plot the data - if t % plotting_parameters['plotting_interval'] == 0: + if t % plotting_interval == 0: - plot_num = t // plotting_parameters['plotting_interval'] + plot_num = t // plotting_interval # determine the plot number E, B, J, rho, *rest = fields @@ -111,6 +136,11 @@ def run_PyPIC3D(config_file): if plotting_parameters['plot_vtk_scalars']: + try: + from PyPIC3D.diagnostics.vtk import plot_field_slice_vtk + except ModuleNotFoundError as e: + raise ModuleNotFoundError("VTK diagnostics requested but 'vtk' is not installed.") from e + rho = compute_rho(particles, rho, world, constants) # calculate the charge density based on the particle positions mass_density = compute_mass_density(particles, rho, world) @@ -122,6 +152,11 @@ def run_PyPIC3D(config_file): if plotting_parameters['plot_vtk_vectors']: + try: + from PyPIC3D.diagnostics.vtk import plot_vectorfield_slice_vtk + except ModuleNotFoundError as e: + raise ModuleNotFoundError("VTK diagnostics requested but 'vtk' is not installed.") from e + vector_field_slices = [ [E[0][:,world['Ny']//2,:], E[1][:,world['Ny']//2,:], E[2][:,world['Ny']//2,:]], [B[0][:,world['Ny']//2,:], B[1][:,world['Ny']//2,:], B[2][:,world['Ny']//2,:]], [J[0][:,world['Ny']//2,:], J[1][:,world['Ny']//2,:], J[2][:,world['Ny']//2,:]]] @@ -129,6 +164,11 @@ def run_PyPIC3D(config_file): # Plot the vector fields in VTK format if plotting_parameters['plot_vtk_particles']: + try: + from PyPIC3D.diagnostics.vtk import plot_vtk_particles + except ModuleNotFoundError as e: + raise ModuleNotFoundError("VTK diagnostics requested but 'vtk' is not installed.") from e + plot_vtk_particles(particles, plot_num, output_dir) # Plot the particles in VTK format @@ -143,25 +183,54 @@ def run_PyPIC3D(config_file): fields = (E, B, J, rho, *rest) # repack the fields - particles, fields = jit_loop( - particles, - fields, - world, - constants, - curl_func, - J_func, - solver, - relativistic=relativistic, - ) - # time loop to update the particles and fields + if scan_chunk == 1: + particles, fields = loop( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + else: + if (t + scan_chunk) <= Nt: + particles, fields = advance_full( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + else: + particles, fields = advance_tail( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + # advance the particles and fields return Nt, plotting_parameters, simulation_parameters, plasma_parameters, constants, particles, fields, world def main(): ###################### JAX SETTINGS ######################################################################## - jax.config.update("jax_enable_x64", True) - # set Jax to use 64 bit precision + toml_file = load_config_file() + # load the configuration file + + enable_x64 = bool(toml_file.get("simulation_parameters", {}).get("enable_x64", True)) + jax.config.update("jax_enable_x64", enable_x64) + # set Jax precision (default preserves legacy behavior) + # jax.config.update("jax_debug_nans", True) # debugging for nans jax.config.update('jax_platform_name', 'cpu') @@ -169,9 +238,6 @@ def main(): #jax.config.update("jax_disable_jit", True) ############################################################################################################ - toml_file = load_config_file() - # load the configuration file - start = time.time() # start the timer diff --git a/PyPIC3D/boris.py b/PyPIC3D/boris.py index 49405c8..d2b08c1 100644 --- a/PyPIC3D/boris.py +++ b/PyPIC3D/boris.py @@ -1,6 +1,7 @@ import jax from jax import jit import jax.numpy as jnp +from functools import partial from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights from PyPIC3D.utils import wrap_around @@ -63,17 +64,40 @@ def particle_push(particles, E, B, grid, staggered_grid, dt, constants, periodic #################### BORIS ALGORITHM #################################### - boris_vmap = jax.vmap(boris_single_particle, in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None, None)) - relativistic_boris_vmap = jax.vmap(relativistic_boris_single_particle, in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None, None)) - # vectorize the Boris algorithm for batch processing - newvx, newvy, newvz = jax.lax.cond( relativistic == True, - lambda _: relativistic_boris_vmap(vx, vy, vz, efield_atx, efield_aty, efield_atz, bfield_atx, bfield_aty, bfield_atz, q, m, dt, constants), - lambda _: boris_vmap(vx, vy, vz, efield_atx, efield_aty, efield_atz, bfield_atx, bfield_aty, bfield_atz, q, m, dt, constants), - operand=None + lambda _: relativistic_boris_push( + vx, + vy, + vz, + efield_atx, + efield_aty, + efield_atz, + bfield_atx, + bfield_aty, + bfield_atz, + q, + m, + dt, + constants, + ), + lambda _: boris_push( + vx, + vy, + vz, + efield_atx, + efield_aty, + efield_atz, + bfield_atx, + bfield_aty, + bfield_atz, + q, + m, + dt, + ), + operand=None, ) - # apply the Boris algorithm to update the velocities of the particles + # apply the Boris algorithm (vectorized over particles) ######################################################################### @@ -81,6 +105,84 @@ def particle_push(particles, E, B, grid, staggered_grid, dt, constants, periodic # set the new velocities of the particles return particles + +def boris_push(vx, vy, vz, ex, ey, ez, bx, by, bz, q, m, dt): + qmdt2 = q * dt / (2 * m) + + vminus_x = vx + qmdt2 * ex + vminus_y = vy + qmdt2 * ey + vminus_z = vz + qmdt2 * ez + + t_x = qmdt2 * bx + t_y = qmdt2 * by + t_z = qmdt2 * bz + + t2 = t_x * t_x + t_y * t_y + t_z * t_z + inv = 1.0 / (1.0 + t2) + s_x = 2.0 * t_x * inv + s_y = 2.0 * t_y * inv + s_z = 2.0 * t_z * inv + + vprime_x = vminus_x + (vminus_y * t_z - vminus_z * t_y) + vprime_y = vminus_y + (vminus_z * t_x - vminus_x * t_z) + vprime_z = vminus_z + (vminus_x * t_y - vminus_y * t_x) + + vplus_x = vminus_x + (vprime_y * s_z - vprime_z * s_y) + vplus_y = vminus_y + (vprime_z * s_x - vprime_x * s_z) + vplus_z = vminus_z + (vprime_x * s_y - vprime_y * s_x) + + newvx = vplus_x + qmdt2 * ex + newvy = vplus_y + qmdt2 * ey + newvz = vplus_z + qmdt2 * ez + + return newvx, newvy, newvz + + +def relativistic_boris_push(vx, vy, vz, ex, ey, ez, bx, by, bz, q, m, dt, constants): + C = constants["C"] + qmdt2 = q * dt / (2 * m) + + v2_over_c2 = (vx * vx + vy * vy + vz * vz) / (C * C) + gamma = 1.0 / jnp.sqrt(1.0 - v2_over_c2) + + uminus_x = vx * gamma + qmdt2 * ex + uminus_y = vy * gamma + qmdt2 * ey + uminus_z = vz * gamma + qmdt2 * ez + + uminus2_over_c2 = (uminus_x * uminus_x + uminus_y * uminus_y + uminus_z * uminus_z) / (C * C) + gamma_minus = jnp.sqrt(1.0 + uminus2_over_c2) + + t_x = (qmdt2 * bx) / gamma_minus + t_y = (qmdt2 * by) / gamma_minus + t_z = (qmdt2 * bz) / gamma_minus + + t2 = t_x * t_x + t_y * t_y + t_z * t_z + inv = 1.0 / (1.0 + t2) + s_x = 2.0 * t_x * inv + s_y = 2.0 * t_y * inv + s_z = 2.0 * t_z * inv + + uprime_x = uminus_x + (uminus_y * t_z - uminus_z * t_y) + uprime_y = uminus_y + (uminus_z * t_x - uminus_x * t_z) + uprime_z = uminus_z + (uminus_x * t_y - uminus_y * t_x) + + uplus_x = uminus_x + (uprime_y * s_z - uprime_z * s_y) + uplus_y = uminus_y + (uprime_z * s_x - uprime_x * s_z) + uplus_z = uminus_z + (uprime_x * s_y - uprime_y * s_x) + + newu_x = uplus_x + qmdt2 * ex + newu_y = uplus_y + qmdt2 * ey + newu_z = uplus_z + qmdt2 * ez + + newu2_over_c2 = (newu_x * newu_x + newu_y * newu_y + newu_z * newu_z) / (C * C) + new_gamma = jnp.sqrt(1.0 + newu2_over_c2) + + newvx = newu_x / new_gamma + newvy = newu_y / new_gamma + newvz = newu_z / new_gamma + + return newvx, newvy, newvz + @jit def boris_single_particle(vx, vy, vz, efield_atx, efield_aty, efield_atz, bfield_atx, bfield_aty, bfield_atz, q, m, dt, constants): """ @@ -194,7 +296,7 @@ def relativistic_boris_single_particle(vx, vy, vz, efield_atx, efield_aty, efiel return newv[0], newv[1], newv[2] -@jit +@partial(jit, static_argnames=("shape_factor",)) def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): """ Interpolate a Yee-grid field component to particle positions using PIC shape functions. @@ -227,24 +329,14 @@ def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): dz = z_grid[1] - z_grid[0] if Nz > 1 else 1.0 # grid spacing in each direction - x0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor((x - xmin) / dx).astype(int), - lambda _: jnp.round((x - xmin) / dx).astype(int), - operand=None, - ) - y0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor((y - ymin) / dy).astype(int), - lambda _: jnp.round((y - ymin) / dy).astype(int), - operand=None, - ) - z0 = jax.lax.cond( - shape_factor == 1, - lambda _: jnp.floor((z - zmin) / dz).astype(int), - lambda _: jnp.round((z - zmin) / dz).astype(int), - operand=None, - ) + if shape_factor == 1: + x0 = jnp.floor((x - xmin) / dx).astype(jnp.int32) + y0 = jnp.floor((y - ymin) / dy).astype(jnp.int32) + z0 = jnp.floor((z - zmin) / dz).astype(jnp.int32) + else: + x0 = jnp.round((x - xmin) / dx).astype(jnp.int32) + y0 = jnp.round((y - ymin) / dy).astype(jnp.int32) + z0 = jnp.round((z - zmin) / dz).astype(jnp.int32) # compute the stencil anchor points (cell-left for first order, nearest node for second order) deltax = (x - xmin) - x0 * dx @@ -252,12 +344,10 @@ def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): deltaz = (z - zmin) - z0 * dz # determine the distance from the closest grid nodes - x_weights, y_weights, z_weights = jax.lax.cond( - shape_factor == 1, - lambda _: get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz), - lambda _: get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz), - operand=None, - ) + if shape_factor == 1: + x_weights, y_weights, z_weights = get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz) + else: + x_weights, y_weights, z_weights = get_second_order_weights(deltax, deltay, deltaz, dx, dy, dz) x_weights = jnp.asarray(x_weights) y_weights = jnp.asarray(y_weights) z_weights = jnp.asarray(z_weights) @@ -280,6 +370,15 @@ def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): zpts = jnp.asarray([z_minus1, z0, z1]) # place all the points in a list + if shape_factor == 1: + # drop the redundant (-1) stencil point for first-order (its weights are identically 0) + xpts = xpts[1:, ...] + ypts = ypts[1:, ...] + zpts = zpts[1:, ...] + x_weights = x_weights[1:, ...] + y_weights = y_weights[1:, ...] + z_weights = z_weights[1:, ...] + # Keep full shape-factor computation but collapse inactive axes to an # effective stencil size of 1 to avoid redundant interpolation work. if x_active: @@ -303,26 +402,14 @@ def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): zpts_eff = jnp.zeros((1, zpts.shape[1]), dtype=zpts.dtype) z_weights_eff = jnp.sum(z_weights, axis=0, keepdims=True) - def stencil_contribution(stencil_idx): - i, j, k = stencil_idx - return ( - field[xpts_eff[i, ...], ypts_eff[j, ...], zpts_eff[k, ...]] - * x_weights_eff[i, ...] - * y_weights_eff[j, ...] - * z_weights_eff[k, ...] - ) - # define a function to compute the contribution from each point in the effective stencil - - ii, jj, kk = jnp.meshgrid( - jnp.arange(xpts_eff.shape[0]), - jnp.arange(ypts_eff.shape[0]), - jnp.arange(zpts_eff.shape[0]), - indexing="ij", + field_vals = field[ + xpts_eff[:, None, None, :], + ypts_eff[None, :, None, :], + zpts_eff[None, None, :, :], + ] + weights = ( + x_weights_eff[:, None, None, :] + * y_weights_eff[None, :, None, :] + * z_weights_eff[None, None, :, :] ) - stencil_indicies = jnp.stack([ii.ravel(), jj.ravel(), kk.ravel()], axis=1) - # build effective stencil indices with shape (Sx*Sy*Sz, 3) - - interpolated_field = jnp.sum(jax.vmap(stencil_contribution)(stencil_indicies), axis=0) - # sum the contributions from all stencil points to get the final interpolated field value at each particle position - - return interpolated_field + return jnp.sum(field_vals * weights, axis=(0, 1, 2)) diff --git a/PyPIC3D/evolve.py b/PyPIC3D/evolve.py index 5cdacff..e0525b2 100644 --- a/PyPIC3D/evolve.py +++ b/PyPIC3D/evolve.py @@ -22,7 +22,7 @@ E_from_A, B_from_A, update_vector_potential ) -@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic")) +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) def time_loop_electrostatic(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ Advances the simulation by one time step for an electrostatic Particle-In-Cell (PIC) loop. @@ -74,7 +74,7 @@ def time_loop_electrostatic(particles, fields, world, constants, curl_func, J_fu return particles, fields -@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic")) +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) def time_loop_electrodynamic(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ Advance an electrodynamic Particle-In-Cell (PIC) system by one time step. @@ -154,7 +154,7 @@ def time_loop_electrodynamic(particles, fields, world, constants, curl_func, J_f return particles, fields -@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic")) +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) def time_loop_vector_potential(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ Advance a PIC (Particle-In-Cell) simulation by one time step using a diff --git a/PyPIC3D/initialization.py b/PyPIC3D/initialization.py index 665a38b..03f2843 100644 --- a/PyPIC3D/initialization.py +++ b/PyPIC3D/initialization.py @@ -113,9 +113,11 @@ def default_parameters(): "Nt": None, # number of time steps "electrostatic": False, # boolean for electrostatic simulation "relativistic": True, # boolean for relativistic simulation + "enable_x64": True, # enable 64-bit JAX dtypes (slower but higher precision) "benchmark": False, # boolean for using the profiler "verbose": False, # boolean for printing verbose output "GPUs": False, # boolean for using GPUs + "scan_chunk": 1, # advance multiple steps per dispatch (1 keeps legacy per-step loop) "cfl" : 1.0, # CFL condition number "ds_per_debye" : None, # number of grid spacings per debye length "shape_factor" : 1, # shape factor for the simulation (1 for 1st order, 2 for 2nd order) @@ -248,11 +250,8 @@ def initialize_simulation(toml_file): } # set the simulation world parameters - world = convert_to_jax_compatible(world) - constants = convert_to_jax_compatible(constants) - simulation_parameters = convert_to_jax_compatible(simulation_parameters) - plotting_parameters = convert_to_jax_compatible(plotting_parameters) - # convert the world parameters to jax compatible format + # Keep scalar parameters as Python types so JAX can treat them as static + # (avoids traced metadata in PyTrees and enables compile-time specialization). # if solver == "vector_potential": # B_grid, E_grid = build_collocated_grid(world) @@ -307,8 +306,8 @@ def initialize_simulation(toml_file): # convert the E, B, and J tuples into one big list fields = load_external_fields_from_toml(fields, toml_file) # add any external fields to the simulation - E, B, J = fields[:3], fields[3:6], fields[6:9] - # convert the fields list back into tuples + E, B, J = tuple(fields[:3]), tuple(fields[3:6]), tuple(fields[6:9]) + # convert the fields list back into tuples (JAX scan expects stable PyTree types) if solver == "spectral": curl_func = functools.partial(spectral_curl, world=world) diff --git a/PyPIC3D/solvers/first_order_yee.py b/PyPIC3D/solvers/first_order_yee.py index ce24825..7b25011 100644 --- a/PyPIC3D/solvers/first_order_yee.py +++ b/PyPIC3D/solvers/first_order_yee.py @@ -60,11 +60,6 @@ def update_E(E, B, J, world, constants, curl_func): eps = constants['eps'] # get the time resolution and necessary constants - Bx = jnp.pad(Bx, ((1,1), (1,1), (1,1)), mode="wrap") - By = jnp.pad(By, ((1,1), (1,1), (1,1)), mode="wrap") - Bz = jnp.pad(Bz, ((1,1), (1,1), (1,1)), mode="wrap") - # pad the magnetic field components for periodic boundary conditions - dBz_dy = (jnp.roll(Bz, shift=-1, axis=1) - Bz) / dy dBx_dy = (jnp.roll(Bx, shift=-1, axis=1) - Bx) / dy dBy_dz = (jnp.roll(By, shift=-1, axis=2) - By) / dz @@ -72,9 +67,9 @@ def update_E(E, B, J, world, constants, curl_func): dBz_dx = (jnp.roll(Bz, shift=-1, axis=0) - Bz) / dx dBy_dx = (jnp.roll(By, shift=-1, axis=0) - By) / dx - curl_x = (dBz_dy - dBy_dz)[1:-1,1:-1,1:-1] - curl_y = (dBx_dz - dBz_dx)[1:-1,1:-1,1:-1] - curl_z = (dBy_dx - dBx_dy)[1:-1,1:-1,1:-1] + curl_x = dBz_dy - dBy_dz + curl_y = dBx_dz - dBz_dx + curl_z = dBy_dx - dBx_dy # calculate the curl of the magnetic field Ex = Ex + ( C**2 * curl_x - Jx / eps ) * dt @@ -180,11 +175,6 @@ def update_B(E, B, world, constants, curl_func): Bx, By, Bz = B # unpack the E and B fields - Ex = jnp.pad(Ex, ((1,1), (1,1), (1,1)), mode="wrap") - Ey = jnp.pad(Ey, ((1,1), (1,1), (1,1)), mode="wrap") - Ez = jnp.pad(Ez, ((1,1), (1,1), (1,1)), mode="wrap") - # pad the electric field components for periodic boundary conditions - dEz_dy = (Ez - jnp.roll(Ez, shift=1, axis=1)) / dy dEx_dy = (Ex - jnp.roll(Ex, shift=1, axis=1)) / dy dEy_dz = (Ey - jnp.roll(Ey, shift=1, axis=2)) / dz @@ -192,9 +182,9 @@ def update_B(E, B, world, constants, curl_func): dEz_dx = (Ez - jnp.roll(Ez, shift=1, axis=0)) / dx dEy_dx = (Ey - jnp.roll(Ey, shift=1, axis=0)) / dx - curl_x = (dEz_dy - dEy_dz)[1:-1,1:-1,1:-1] - curl_y = (dEx_dz - dEz_dx)[1:-1,1:-1,1:-1] - curl_z = (dEy_dx - dEx_dy)[1:-1,1:-1,1:-1] + curl_x = dEz_dy - dEy_dz + curl_y = dEx_dz - dEz_dx + curl_z = dEy_dx - dEx_dy # calculate the curl of the electric field Bx = Bx - dt*curl_x @@ -209,4 +199,3 @@ def update_B(E, B, world, constants, curl_func): # apply a digital filter to the magnetic field components return (Bx, By, Bz) - diff --git a/PyPIC3D/utils.py b/PyPIC3D/utils.py index 1938e8d..a4af249 100644 --- a/PyPIC3D/utils.py +++ b/PyPIC3D/utils.py @@ -39,24 +39,38 @@ def wrap_around(ix, size): @jit def bilinear_filter(phi, mode="wrap"): """ - Apply a 3D (tri-linear) smoothing filter to a 3D array using a separable - [1, 2, 1]/4 kernel in each dimension. + Apply a tri-linear smoothing filter using a separable [1, 2, 1]/4 kernel + in each spatial dimension. Args: - phi (jnp.ndarray): 3D field array with shape (Nx, Ny, Nz). + phi (jnp.ndarray): Field array with leading spatial shape (Nx, Ny, Nz). + Any trailing feature dimensions are preserved (e.g. (Nx, Ny, Nz, 3)). mode (str): Padding mode passed to jnp.pad (default: "wrap"). Returns: jnp.ndarray: Filtered array with the same shape as phi. """ - k1 = jnp.array([1.0, 2.0, 1.0], dtype=phi.dtype) / 4.0 # sums to 1 - k3 = k1[:, None, None] * k1[None, :, None] * k1[None, None, :] # (3,3,3), sums to 1 + if mode == "wrap": + quarter = jnp.asarray(0.25, dtype=phi.dtype) + + def smooth_axis(arr, axis): + return (jnp.roll(arr, 1, axis=axis) + 2 * arr + jnp.roll(arr, -1, axis=axis)) * quarter + + phi = smooth_axis(phi, 0) + phi = smooth_axis(phi, 1) + phi = smooth_axis(phi, 2) + return phi + + if phi.ndim != 3: + raise ValueError("bilinear_filter only supports non-wrap mode for 3D arrays.") + + k1 = jnp.array([1.0, 2.0, 1.0], dtype=phi.dtype) / 4.0 + k3 = k1[:, None, None] * k1[None, :, None] * k1[None, None, :] kernel = jnp.zeros((3, 3, 3, 1, 1), dtype=phi.dtype) kernel = kernel.at[:, :, :, 0, 0].set(k3) padded_phi = jnp.pad(phi, ((1, 1), (1, 1), (1, 1)), mode=mode) - filtered = jax.lax.conv_general_dilated( padded_phi[jnp.newaxis, ..., jnp.newaxis], kernel, @@ -80,24 +94,18 @@ def digital_filter(phi, alpha): ndarray: Filtered field array. """ neighbor_weight = (1 - alpha) / 6 - kernel = jnp.zeros((3, 3, 3, 1, 1), dtype=phi.dtype) - kernel = kernel.at[1, 1, 1, 0, 0].set(alpha) - kernel = kernel.at[0, 1, 1, 0, 0].set(neighbor_weight) - kernel = kernel.at[2, 1, 1, 0, 0].set(neighbor_weight) - kernel = kernel.at[1, 0, 1, 0, 0].set(neighbor_weight) - kernel = kernel.at[1, 2, 1, 0, 0].set(neighbor_weight) - kernel = kernel.at[1, 1, 0, 0, 0].set(neighbor_weight) - kernel = kernel.at[1, 1, 2, 0, 0].set(neighbor_weight) - - padded_phi = jnp.pad(phi, ((1, 1), (1, 1), (1, 1)), mode="wrap") - filtered = jax.lax.conv_general_dilated( - padded_phi[jnp.newaxis, ..., jnp.newaxis], - kernel, - window_strides=(1, 1, 1), - padding="VALID", - dimension_numbers=("NDHWC", "DHWIO", "NDHWC"), + return ( + alpha * phi + + neighbor_weight + * ( + jnp.roll(phi, 1, axis=0) + + jnp.roll(phi, -1, axis=0) + + jnp.roll(phi, 1, axis=1) + + jnp.roll(phi, -1, axis=1) + + jnp.roll(phi, 1, axis=2) + + jnp.roll(phi, -1, axis=2) + ) ) - return jnp.squeeze(filtered, axis=(0, 4)) def mae(x, y): """ From 5184a38d658f2e3b545be9f671010699b0109782 Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Sat, 14 Mar 2026 17:03:11 -0400 Subject: [PATCH 2/2] Aggressive CPU fast modes: fewer ops, opt-in extreme --- PyPIC3D/J.py | 165 +++++++++++++++++++++++++++-- PyPIC3D/__main__.py | 85 ++++++++++++++- PyPIC3D/boris.py | 80 +++++++++++--- PyPIC3D/evolve.py | 17 ++- PyPIC3D/initialization.py | 15 ++- PyPIC3D/particle.py | 9 +- PyPIC3D/solvers/first_order_yee.py | 22 ++-- PyPIC3D/utils.py | 38 +++---- 8 files changed, 373 insertions(+), 58 deletions(-) diff --git a/PyPIC3D/J.py b/PyPIC3D/J.py index fc073a7..51f86ce 100644 --- a/PyPIC3D/J.py +++ b/PyPIC3D/J.py @@ -8,8 +8,127 @@ from PyPIC3D.utils import digital_filter, wrap_around, bilinear_filter from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights -@partial(jit, static_argnames=("filter",)) -def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): + +def _weights_order1(r): + w0 = 1.0 - r + w1 = r + return jnp.stack((w0, w1), axis=0) + + +def _weights_order2(r): + w0 = 0.5 * (0.5 - r) ** 2 + w1 = 0.75 - r**2 + w2 = 0.5 * (0.5 + r) ** 2 + return jnp.stack((w0, w1, w2), axis=0) + + +def _deposit_1d(J_stack, dq, vx, vy, vz, x, grid_x0, dx, dt, shape_factor, Nx): + if shape_factor == 1: + x0 = jnp.floor((x - grid_x0) / dx).astype(jnp.int32) + deltax_node = (x - grid_x0) - x0 * dx + deltax_face = (x - grid_x0) - (x0 + 0.5) * dx + + r_node = deltax_node / dx + r_face = deltax_face / dx + + w_node = _weights_order1(r_node) # (2,Np) + w_face = _weights_order1(r_face) # (2,Np) + + ix = jnp.stack((x0, x0 + 1), axis=0) + ix = wrap_around(ix, Nx) + else: + x0 = jnp.round((x - grid_x0) / dx).astype(jnp.int32) + deltax_node = (x - grid_x0) - x0 * dx + deltax_face = (x - grid_x0) - (x0 + 0.5) * dx + + r_node = deltax_node / dx + r_face = deltax_face / dx + + w_node = _weights_order2(r_node) # (3,Np) + w_face = _weights_order2(r_face) # (3,Np) + + ix = jnp.stack((x0 - 1, x0, x0 + 1), axis=0) + ix = wrap_around(ix, Nx) + + # Jx uses face weights; Jy/Jz use node weights. + val = jnp.stack( + ( + (dq * vx)[None, :] * w_face, + (dq * vy)[None, :] * w_node, + (dq * vz)[None, :] * w_node, + ), + axis=-1, + ) # (S,Np,3) + + comp = jnp.arange(3, dtype=ix.dtype)[None, None, :] # (1,1,3) + idx = ix[:, :, None] + comp * jnp.asarray(Nx, dtype=ix.dtype) # (S,Np,3) + + out = jnp.bincount( + idx.reshape(-1), + weights=val.reshape(-1), + length=Nx * 3, + ).reshape(3, Nx) + + Jx, Jy, Jz = out[0], out[1], out[2] + return jnp.stack((Jx, Jy, Jz), axis=-1).reshape((Nx, 1, 1, 3)) + + +def _deposit_2d(J_stack, dq, vx, vy, vz, x, y, xmin, ymin, dx, dy, dt, shape_factor, Nx, Ny): + if shape_factor == 1: + x0 = jnp.floor((x - xmin) / dx).astype(jnp.int32) + y0 = jnp.floor((y - ymin) / dy).astype(jnp.int32) + deltax_node = (x - xmin) - x0 * dx + deltay_node = (y - ymin) - y0 * dy + deltax_face = (x - xmin) - (x0 + 0.5) * dx + deltay_face = (y - ymin) - (y0 + 0.5) * dy + + wx_node = _weights_order1(deltax_node / dx) # (2,Np) + wy_node = _weights_order1(deltay_node / dy) # (2,Np) + wx_face = _weights_order1(deltax_face / dx) # (2,Np) + wy_face = _weights_order1(deltay_face / dy) # (2,Np) + + ix = jnp.stack((x0, x0 + 1), axis=0) + iy = jnp.stack((y0, y0 + 1), axis=0) + ix = wrap_around(ix, Nx) + iy = wrap_around(iy, Ny) + else: + x0 = jnp.round((x - xmin) / dx).astype(jnp.int32) + y0 = jnp.round((y - ymin) / dy).astype(jnp.int32) + deltax_node = (x - xmin) - x0 * dx + deltay_node = (y - ymin) - y0 * dy + deltax_face = (x - xmin) - (x0 + 0.5) * dx + deltay_face = (y - ymin) - (y0 + 0.5) * dy + + wx_node = _weights_order2(deltax_node / dx) # (3,Np) + wy_node = _weights_order2(deltay_node / dy) # (3,Np) + wx_face = _weights_order2(deltax_face / dx) # (3,Np) + wy_face = _weights_order2(deltay_face / dy) # (3,Np) + + ix = jnp.stack((x0 - 1, x0, x0 + 1), axis=0) + iy = jnp.stack((y0 - 1, y0, y0 + 1), axis=0) + ix = wrap_around(ix, Nx) + iy = wrap_around(iy, Ny) + + idx = ix[:, None, :] + Nx * iy[None, :, :] # (Sx,Sy,Np) + idx_flat = idx.reshape(-1) + + # weights for each component + wjx = wx_face[:, None, :] * wy_node[None, :, :] + wjy = wx_node[:, None, :] * wy_face[None, :, :] + wjz = wx_node[:, None, :] * wy_node[None, :, :] + + valx = (dq * vx)[None, None, :] * wjx + valy = (dq * vy)[None, None, :] * wjy + valz = (dq * vz)[None, None, :] * wjz + + vals = jnp.stack((valx, valy, valz), axis=-1).reshape(-1, 3) + J_flat = jax.ops.segment_sum(vals, idx_flat, num_segments=Nx * Ny) # (Nx*Ny,3) + J2 = J_flat.reshape((Nx, Ny, 3))[:, :, None, :] + return J2 + + +@partial(jit, static_argnames=("filter", "shape_factor")) +def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear', shape_factor=2): """ Compute the current density from the charge density and particle velocities. @@ -37,12 +156,10 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): z_active = Jx.shape[2] != 1 # infer effective dimensionality from the current-grid shape - J_stack = jnp.stack((Jx, Jy, Jz), axis=-1) - J_stack = jnp.zeros_like(J_stack) + J_stack = jnp.zeros((Nx, Ny, Nz, 3), dtype=Jx.dtype) # keep J together so deposition and filtering can be fused across components for species in particles: - shape_factor = species.get_shape() charge = species.get_charge() dq = charge / (dx * dy * dz) # calculate the charge density contribution per particle @@ -50,11 +167,37 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): vx, vy, vz = species.get_velocity() # get the particles positions and velocities - x = x - vx * world['dt'] / 2 - y = y - vy * world['dt'] / 2 - z = z - vz * world['dt'] / 2 + dt = world["dt"] + x = x - vx * dt / 2 # step back to half time step positions for proper time staggering + if Ny == 1 and Nz == 1: + J_stack = _deposit_1d(J_stack, dq, vx, vy, vz, x, grid[0][0], dx, world["dt"], shape_factor, Nx) + continue + if Nz == 1: + y = y - vy * dt / 2 + J_stack = _deposit_2d( + J_stack, + dq, + vx, + vy, + vz, + x, + y, + grid[0][0], + grid[1][0], + dx, + dy, + world["dt"], + shape_factor, + Nx, + Ny, + ) + continue + + y = y - vy * dt / 2 + z = z - vz * dt / 2 + if shape_factor == 1: x0 = jnp.floor((x - grid[0][0]) / dx).astype(jnp.int32) y0 = jnp.floor((y - grid[1][0]) / dy).astype(jnp.int32) @@ -79,9 +222,9 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'): y0 = wrap_around(y0, Ny) z0 = wrap_around(z0, Nz) # wrap around the grid points for periodic boundary conditions - x1 = wrap_around(x0+1, Nx) - y1 = wrap_around(y0+1, Ny) - z1 = wrap_around(z0+1, Nz) + x1 = wrap_around(x0 + 1, Nx) + y1 = wrap_around(y0 + 1, Ny) + z1 = wrap_around(z0 + 1, Nz) # calculate the right grid point x_minus1 = x0 - 1 y_minus1 = y0 - 1 diff --git a/PyPIC3D/__main__.py b/PyPIC3D/__main__.py index 3541504..22b6ec7 100644 --- a/PyPIC3D/__main__.py +++ b/PyPIC3D/__main__.py @@ -46,7 +46,39 @@ def run_PyPIC3D(config_file): ##################################### INITIALIZE SIMULATION ################################################ - loop, particles, fields, world, simulation_parameters, constants, plotting_parameters, plasma_parameters, solver, electrostatic, verbose, GPUs, Nt, curl_func, J_func, relativistic = initialize_simulation(config_file) + cfg = config_file + if isinstance(cfg, dict): + sim = cfg.setdefault("simulation_parameters", {}) + fast_mode = sim.get("fast_mode", "off") + if fast_mode not in ("off", "fp32", "aggressive", "extreme"): + raise ValueError("simulation_parameters.fast_mode must be one of: off, fp32, aggressive, extreme") + + if fast_mode in ("fp32", "aggressive", "extreme"): + sim["enable_x64"] = False + + if fast_mode in ("aggressive", "extreme"): + sim["shape_factor"] = 1 + sim["filter_j"] = "none" + + plot = cfg.setdefault("plotting", {}) + for key in ( + "plot_phasespace", + "plot_vtk_particles", + "plot_vtk_scalars", + "plot_vtk_vectors", + "plot_openpmd_particles", + "plot_openpmd_fields", + "dump_particles", + "dump_fields", + ): + plot[key] = False + plot["plotting_interval"] = 10**9 + + if fast_mode == "extreme": + # opt-in physics approximation for maximum throughput + sim["relativistic"] = False + + loop, particles, fields, world, simulation_parameters, constants, plotting_parameters, plasma_parameters, solver, electrostatic, verbose, GPUs, Nt, curl_func, J_func, relativistic = initialize_simulation(cfg) # initialize the simulation dt = world['dt'] @@ -73,6 +105,9 @@ def run_PyPIC3D(config_file): scan_chunk = int(simulation_parameters.get("scan_chunk", 1)) plotting_interval = int(plotting_parameters["plotting_interval"]) + fast_mode = str(simulation_parameters.get("fast_mode", "off")) + advance_impl = str(simulation_parameters.get("advance_impl", "fori")) + scan_unroll = int(simulation_parameters.get("scan_unroll", 1)) if scan_chunk < 1: raise ValueError("simulation_parameters.scan_chunk must be >= 1") @@ -85,6 +120,20 @@ def run_PyPIC3D(config_file): def make_advance(n_steps): def advance(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): + if fast_mode == "aggressive" and advance_impl == "scan": + def scan_body(carry, _): + p, f = carry + return loop(p, f, world, constants, curl_func, J_func, solver, relativistic=relativistic), None + + (particles, fields), _ = lax.scan( + scan_body, + (particles, fields), + xs=None, + length=n_steps, + unroll=scan_unroll, + ) + return particles, fields + def body(_, state): p, f = state return loop(p, f, world, constants, curl_func, J_func, solver, relativistic=relativistic) @@ -101,6 +150,34 @@ def body(_, state): tail = Nt % scan_chunk advance_tail = make_advance(tail) if (scan_chunk > 1 and tail) else None + outputs_enabled = any( + plotting_parameters.get(k, False) + for k in ( + "plot_phasespace", + "plot_vtk_scalars", + "plot_vtk_vectors", + "plot_vtk_particles", + "plot_openpmd_particles", + "plot_openpmd_fields", + "dump_particles", + "dump_fields", + ) + ) + + if (not outputs_enabled) and plotting_interval > Nt: + advance_all = make_advance(Nt) + particles, fields = advance_all( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + return Nt, plotting_parameters, simulation_parameters, plasma_parameters, constants, particles, fields, world + step_iter = range(0, Nt, scan_chunk) if scan_chunk > 1 else range(Nt) for t in tqdm(step_iter): @@ -227,7 +304,11 @@ def main(): toml_file = load_config_file() # load the configuration file - enable_x64 = bool(toml_file.get("simulation_parameters", {}).get("enable_x64", True)) + sim = toml_file.get("simulation_parameters", {}) if isinstance(toml_file, dict) else {} + fast_mode = sim.get("fast_mode", "off") + enable_x64 = bool(sim.get("enable_x64", True)) + if fast_mode in ("fp32", "aggressive", "extreme"): + enable_x64 = False jax.config.update("jax_enable_x64", enable_x64) # set Jax precision (default preserves legacy behavior) diff --git a/PyPIC3D/boris.py b/PyPIC3D/boris.py index d2b08c1..f2cd8e2 100644 --- a/PyPIC3D/boris.py +++ b/PyPIC3D/boris.py @@ -50,16 +50,45 @@ def particle_push(particles, E, B, grid, staggered_grid, dt, constants, periodic ################## INTERPOLATE FIELDS TO PARTICLE POSITIONS ############## Ex, Ey, Ez = E # unpack the electric field components - efield_atx = interpolate_field_to_particles(Ex, x, y, z, Ex_grid, shape_factor) - efield_aty = interpolate_field_to_particles(Ey, x, y, z, Ey_grid, shape_factor) - efield_atz = interpolate_field_to_particles(Ez, x, y, z, Ez_grid, shape_factor) - # calculate the electric field at the particle positions on the Yee-staggered component grids Bx, By, Bz = B # unpack the magnetic field components - bfield_atx = interpolate_field_to_particles(Bx, x, y, z, Bx_grid, shape_factor) - bfield_aty = interpolate_field_to_particles(By, x, y, z, By_grid, shape_factor) - bfield_atz = interpolate_field_to_particles(Bz, x, y, z, Bz_grid, shape_factor) - # calculate the magnetic field at the particle positions on the Yee-staggered component grids + + Ny = len(grid[1]) + Nz = len(grid[2]) + + if Ny == 1 and Nz == 1: + node_stack = jnp.stack((Ey, Ez, Bx), axis=-1) + face_stack = jnp.stack((Ex, By, Bz), axis=-1) + + node_vals = interpolate_field_to_particles(node_stack, x, y, z, (grid[0], grid[1], grid[2]), shape_factor) + face_vals = interpolate_field_to_particles(face_stack, x, y, z, (staggered_grid[0], grid[1], grid[2]), shape_factor) + + efield_aty, efield_atz, bfield_atx = node_vals[:, 0], node_vals[:, 1], node_vals[:, 2] + efield_atx, bfield_aty, bfield_atz = face_vals[:, 0], face_vals[:, 1], face_vals[:, 2] + + elif Nz == 1: + ex_by = jnp.stack((Ex, By), axis=-1) + ey_bx = jnp.stack((Ey, Bx), axis=-1) + + ex_by_vals = interpolate_field_to_particles(ex_by, x, y, z, Ex_grid, shape_factor) + ey_bx_vals = interpolate_field_to_particles(ey_bx, x, y, z, Ey_grid, shape_factor) + ez_vals = interpolate_field_to_particles(Ez, x, y, z, Ez_grid, shape_factor) + bz_vals = interpolate_field_to_particles(Bz, x, y, z, Bz_grid, shape_factor) + + efield_atx, bfield_aty = ex_by_vals[:, 0], ex_by_vals[:, 1] + efield_aty, bfield_atx = ey_bx_vals[:, 0], ey_bx_vals[:, 1] + efield_atz = ez_vals + bfield_atz = bz_vals + + else: + efield_atx = interpolate_field_to_particles(Ex, x, y, z, Ex_grid, shape_factor) + efield_aty = interpolate_field_to_particles(Ey, x, y, z, Ey_grid, shape_factor) + efield_atz = interpolate_field_to_particles(Ez, x, y, z, Ez_grid, shape_factor) + # calculate the electric field at the particle positions on the Yee-staggered component grids + bfield_atx = interpolate_field_to_particles(Bx, x, y, z, Bx_grid, shape_factor) + bfield_aty = interpolate_field_to_particles(By, x, y, z, By_grid, shape_factor) + bfield_atz = interpolate_field_to_particles(Bz, x, y, z, Bz_grid, shape_factor) + # calculate the magnetic field at the particle positions on the Yee-staggered component grids ######################################################################### @@ -331,18 +360,43 @@ def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): if shape_factor == 1: x0 = jnp.floor((x - xmin) / dx).astype(jnp.int32) + else: + x0 = jnp.round((x - xmin) / dx).astype(jnp.int32) + # compute the stencil anchor points (cell-left for first order, nearest node for second order) + + deltax = (x - xmin) - x0 * dx + # determine the distance from the closest grid nodes + + if x_active and (not y_active) and (not z_active): + x0 = wrap_around(x0, Nx) + if shape_factor == 1: + r = deltax / dx + xpts = jnp.stack((x0, wrap_around(x0 + 1, Nx)), axis=0) + xw = jnp.stack((1.0 - r, r), axis=0) + else: + r = deltax / dx + xpts = jnp.stack((wrap_around(x0 - 1, Nx), x0, wrap_around(x0 + 1, Nx)), axis=0) + xw = jnp.stack( + ( + 0.5 * (0.5 - r) ** 2, + 0.75 - r**2, + 0.5 * (0.5 + r) ** 2, + ), + axis=0, + ) + if field.ndim == 4: + return jnp.sum(field[xpts, 0, 0, :] * xw[:, :, None], axis=0) + return jnp.sum(field[xpts, 0, 0] * xw, axis=0) + + if shape_factor == 1: y0 = jnp.floor((y - ymin) / dy).astype(jnp.int32) z0 = jnp.floor((z - zmin) / dz).astype(jnp.int32) else: - x0 = jnp.round((x - xmin) / dx).astype(jnp.int32) y0 = jnp.round((y - ymin) / dy).astype(jnp.int32) z0 = jnp.round((z - zmin) / dz).astype(jnp.int32) - # compute the stencil anchor points (cell-left for first order, nearest node for second order) - deltax = (x - xmin) - x0 * dx deltay = (y - ymin) - y0 * dy deltaz = (z - zmin) - z0 * dz - # determine the distance from the closest grid nodes if shape_factor == 1: x_weights, y_weights, z_weights = get_first_order_weights(deltax, deltay, deltaz, dx, dy, dz) @@ -412,4 +466,6 @@ def interpolate_field_to_particles(field, x, y, z, grid, shape_factor): * y_weights_eff[None, :, None, :] * z_weights_eff[None, None, :, :] ) + if field.ndim == 4: + weights = weights[..., None] return jnp.sum(field_vals * weights, axis=(0, 1, 2)) diff --git a/PyPIC3D/evolve.py b/PyPIC3D/evolve.py index e0525b2..4d99927 100644 --- a/PyPIC3D/evolve.py +++ b/PyPIC3D/evolve.py @@ -74,8 +74,7 @@ def time_loop_electrostatic(particles, fields, world, constants, curl_func, J_fu return particles, fields -@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) -def time_loop_electrodynamic(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): +def time_loop_electrodynamic_inline(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ Advance an electrodynamic Particle-In-Cell (PIC) system by one time step. This routine performs, in order: @@ -154,6 +153,20 @@ def time_loop_electrodynamic(particles, fields, world, constants, curl_func, J_f return particles, fields +@partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) +def time_loop_electrodynamic(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): + return time_loop_electrodynamic_inline( + particles, + fields, + world, + constants, + curl_func, + J_func, + solver, + relativistic=relativistic, + ) + + @partial(jit, static_argnames=("curl_func", "J_func", "solver", "relativistic"), donate_argnums=(0, 1)) def time_loop_vector_potential(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True): """ diff --git a/PyPIC3D/initialization.py b/PyPIC3D/initialization.py index 03f2843..8caa438 100644 --- a/PyPIC3D/initialization.py +++ b/PyPIC3D/initialization.py @@ -40,7 +40,10 @@ from PyPIC3D.evolve import ( - time_loop_electrodynamic, time_loop_electrostatic, time_loop_vector_potential + time_loop_electrodynamic, + time_loop_electrodynamic_inline, + time_loop_electrostatic, + time_loop_vector_potential, ) from PyPIC3D.J import ( @@ -97,6 +100,7 @@ def default_parameters(): "name": "Default Simulation", "output_dir": os.getcwd(), "solver": "fdtd", # solver: spectral, fdtd, vector_potential, curl_curl + "fast_mode": "off", # off | fp32 | aggressive | extreme (trades accuracy for speed) "particle_bc": "periodic", # particle boundary conditions: periodic, absorb, reflect # "bc": "periodic", # boundary conditions: periodic, dirichlet, neumann "x_bc": "periodic", # x boundary conditions: periodic, conducting @@ -343,13 +347,20 @@ def initialize_simulation(toml_file): evolve_loop = time_loop_electrodynamic # set the evolve loop function based on the electrostatic flag + if simulation_parameters.get("fast_mode", "off") in ("aggressive", "extreme"): + evolve_loop = time_loop_electrodynamic_inline + if simulation_parameters['current_calculation'] == "esirkepov": print("Using Esirkepov current calculation method") # raise NotImplementedError("Esirkepov current calculation method is not fully functional yet.") J_func = Esirkepov_current elif simulation_parameters['current_calculation'] == "j_from_rhov": print(f"Using J from rhov current calculation method with filter: {simulation_parameters['filter_j']}") - J_func = functools.partial(J_from_rhov, filter=simulation_parameters['filter_j']) + J_func = functools.partial( + J_from_rhov, + filter=simulation_parameters["filter_j"], + shape_factor=int(simulation_parameters["shape_factor"]), + ) if solver == "vector_potential": diff --git a/PyPIC3D/particle.py b/PyPIC3D/particle.py index 4655277..2aa6b19 100644 --- a/PyPIC3D/particle.py +++ b/PyPIC3D/particle.py @@ -664,9 +664,12 @@ def boundary_conditions(self): x1, x2, x3 = self.x1, self.x2, self.x3 v1, v2, v3 = self.v1, self.v2, self.v3 - x1, v1 = apply_axis_boundary_condition(x1, v1, self.x_wind, self.half_x_wind, self.x_periodic, self.x_reflecting) - x2, v2 = apply_axis_boundary_condition(x2, v2, self.y_wind, self.half_y_wind, self.y_periodic, self.y_reflecting) - x3, v3 = apply_axis_boundary_condition(x3, v3, self.z_wind, self.half_z_wind, self.z_periodic, self.z_reflecting) + if self.update_x: + x1, v1 = apply_axis_boundary_condition(x1, v1, self.x_wind, self.half_x_wind, self.x_periodic, self.x_reflecting) + if self.update_y: + x2, v2 = apply_axis_boundary_condition(x2, v2, self.y_wind, self.half_y_wind, self.y_periodic, self.y_reflecting) + if self.update_z: + x3, v3 = apply_axis_boundary_condition(x3, v3, self.z_wind, self.half_z_wind, self.z_periodic, self.z_reflecting) self.x1, self.x2, self.x3 = x1, x2, x3 self.v1, self.v2, self.v3 = v1, v2, v3 diff --git a/PyPIC3D/solvers/first_order_yee.py b/PyPIC3D/solvers/first_order_yee.py index 7b25011..775ce1c 100644 --- a/PyPIC3D/solvers/first_order_yee.py +++ b/PyPIC3D/solvers/first_order_yee.py @@ -60,10 +60,13 @@ def update_E(E, B, J, world, constants, curl_func): eps = constants['eps'] # get the time resolution and necessary constants - dBz_dy = (jnp.roll(Bz, shift=-1, axis=1) - Bz) / dy - dBx_dy = (jnp.roll(Bx, shift=-1, axis=1) - Bx) / dy - dBy_dz = (jnp.roll(By, shift=-1, axis=2) - By) / dz - dBx_dz = (jnp.roll(Bx, shift=-1, axis=2) - Bx) / dz + Ny = Ex.shape[1] + Nz = Ex.shape[2] + + dBz_dy = (jnp.roll(Bz, shift=-1, axis=1) - Bz) / dy if Ny != 1 else 0.0 + dBx_dy = (jnp.roll(Bx, shift=-1, axis=1) - Bx) / dy if Ny != 1 else 0.0 + dBy_dz = (jnp.roll(By, shift=-1, axis=2) - By) / dz if Nz != 1 else 0.0 + dBx_dz = (jnp.roll(Bx, shift=-1, axis=2) - Bx) / dz if Nz != 1 else 0.0 dBz_dx = (jnp.roll(Bz, shift=-1, axis=0) - Bz) / dx dBy_dx = (jnp.roll(By, shift=-1, axis=0) - By) / dx @@ -175,10 +178,13 @@ def update_B(E, B, world, constants, curl_func): Bx, By, Bz = B # unpack the E and B fields - dEz_dy = (Ez - jnp.roll(Ez, shift=1, axis=1)) / dy - dEx_dy = (Ex - jnp.roll(Ex, shift=1, axis=1)) / dy - dEy_dz = (Ey - jnp.roll(Ey, shift=1, axis=2)) / dz - dEx_dz = (Ex - jnp.roll(Ex, shift=1, axis=2)) / dz + Ny = Ex.shape[1] + Nz = Ex.shape[2] + + dEz_dy = (Ez - jnp.roll(Ez, shift=1, axis=1)) / dy if Ny != 1 else 0.0 + dEx_dy = (Ex - jnp.roll(Ex, shift=1, axis=1)) / dy if Ny != 1 else 0.0 + dEy_dz = (Ey - jnp.roll(Ey, shift=1, axis=2)) / dz if Nz != 1 else 0.0 + dEx_dz = (Ex - jnp.roll(Ex, shift=1, axis=2)) / dz if Nz != 1 else 0.0 dEz_dx = (Ez - jnp.roll(Ez, shift=1, axis=0)) / dx dEy_dx = (Ey - jnp.roll(Ey, shift=1, axis=0)) / dx diff --git a/PyPIC3D/utils.py b/PyPIC3D/utils.py index a4af249..c4a2ce1 100644 --- a/PyPIC3D/utils.py +++ b/PyPIC3D/utils.py @@ -2,7 +2,7 @@ import plotly import tqdm import pyevtk -from jax import jit +from jax import jit, lax import argparse import jax.numpy as jnp import functools @@ -93,19 +93,22 @@ def digital_filter(phi, alpha): Returns: ndarray: Filtered field array. """ - neighbor_weight = (1 - alpha) / 6 - return ( - alpha * phi - + neighbor_weight - * ( - jnp.roll(phi, 1, axis=0) - + jnp.roll(phi, -1, axis=0) - + jnp.roll(phi, 1, axis=1) - + jnp.roll(phi, -1, axis=1) - + jnp.roll(phi, 1, axis=2) - + jnp.roll(phi, -1, axis=2) + def apply(phi): + neighbor_weight = (1 - alpha) / 6 + return ( + alpha * phi + + neighbor_weight + * ( + jnp.roll(phi, 1, axis=0) + + jnp.roll(phi, -1, axis=0) + + jnp.roll(phi, 1, axis=1) + + jnp.roll(phi, -1, axis=1) + + jnp.roll(phi, 1, axis=2) + + jnp.roll(phi, -1, axis=2) + ) ) - ) + + return lax.cond(alpha == 1.0, lambda phi: phi, apply, phi) def mae(x, y): """ @@ -230,11 +233,10 @@ def nd_trapezoid(arr, dxs): mass = species.get_mass() vx, vy, vz = species.get_velocity() v2 = vx**2 + vy**2 + vz**2 - gamma = 1.0 / jnp.sqrt(1 - v2 / C**2) - momentum2 = jnp.square(mass * gamma ) * v2 - # compute the squared momentum for each particle - KE = jnp.sum( jnp.sqrt( momentum2 * C**2 + mass**2 * C**4) - mass * C**2 ) - # compute the kinetic energy for this species + # Relativistic KE: use m c^2 (gamma - 1) to avoid catastrophic cancellation, + # especially in fp32 fast modes. + gamma = 1.0 / jnp.sqrt(jnp.maximum(1.0 - v2 / C**2, 0.0)) + KE = jnp.sum(mass * C**2 * (gamma - 1.0)) kinetic_energy += KE # add to total kinetic energy