Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
456 changes: 374 additions & 82 deletions PyPIC3D/J.py

Large diffs are not rendered by default.

201 changes: 174 additions & 27 deletions PyPIC3D/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import time
import jax
from jax import block_until_ready
from jax import block_until_ready, lax
import jax.numpy as jnp
from tqdm import tqdm

Expand All @@ -23,10 +23,6 @@
write_openpmd_particles, write_openpmd_fields
)

from PyPIC3D.diagnostics.vtk import (
plot_field_slice_vtk, plot_vectorfield_slice_vtk, plot_vtk_particles
)

from PyPIC3D.utils import (
dump_parameters_to_toml, load_config_file, compute_energy,
setup_pmd_files
Expand All @@ -50,11 +46,41 @@
def run_PyPIC3D(config_file):
##################################### INITIALIZE SIMULATION ################################################

loop, particles, fields, world, simulation_parameters, constants, plotting_parameters, plasma_parameters, solver, electrostatic, verbose, GPUs, Nt, curl_func, J_func, relativistic = initialize_simulation(config_file)
cfg = config_file
if isinstance(cfg, dict):
sim = cfg.setdefault("simulation_parameters", {})
fast_mode = sim.get("fast_mode", "off")
if fast_mode not in ("off", "fp32", "aggressive", "extreme"):
raise ValueError("simulation_parameters.fast_mode must be one of: off, fp32, aggressive, extreme")

if fast_mode in ("fp32", "aggressive", "extreme"):
sim["enable_x64"] = False

if fast_mode in ("aggressive", "extreme"):
sim["shape_factor"] = 1
sim["filter_j"] = "none"

plot = cfg.setdefault("plotting", {})
for key in (
"plot_phasespace",
"plot_vtk_particles",
"plot_vtk_scalars",
"plot_vtk_vectors",
"plot_openpmd_particles",
"plot_openpmd_fields",
"dump_particles",
"dump_fields",
):
plot[key] = False
plot["plotting_interval"] = 10**9

if fast_mode == "extreme":
# opt-in physics approximation for maximum throughput
sim["relativistic"] = False

loop, particles, fields, world, simulation_parameters, constants, plotting_parameters, plasma_parameters, solver, electrostatic, verbose, GPUs, Nt, curl_func, J_func, relativistic = initialize_simulation(cfg)
# initialize the simulation

jit_loop = jax.jit(loop, static_argnames=('curl_func', 'J_func', 'solver', 'relativistic'))

dt = world['dt']
output_dir = simulation_parameters['output_dir']
vertex_grid = world['grids']['vertex']
Expand All @@ -77,12 +103,88 @@ def run_PyPIC3D(config_file):

###################################################### SIMULATION LOOP #####################################

for t in tqdm(range(Nt)):
scan_chunk = int(simulation_parameters.get("scan_chunk", 1))
plotting_interval = int(plotting_parameters["plotting_interval"])
fast_mode = str(simulation_parameters.get("fast_mode", "off"))
advance_impl = str(simulation_parameters.get("advance_impl", "fori"))
scan_unroll = int(simulation_parameters.get("scan_unroll", 1))

if scan_chunk < 1:
raise ValueError("simulation_parameters.scan_chunk must be >= 1")

if scan_chunk > 1 and (plotting_interval % scan_chunk) != 0:
raise ValueError(
f"scan_chunk={scan_chunk} requires plotting_interval to be a multiple of scan_chunk "
f"(got plotting_interval={plotting_interval})."
)

def make_advance(n_steps):
def advance(particles, fields, world, constants, curl_func, J_func, solver, relativistic=True):
if fast_mode == "aggressive" and advance_impl == "scan":
def scan_body(carry, _):
p, f = carry
return loop(p, f, world, constants, curl_func, J_func, solver, relativistic=relativistic), None

(particles, fields), _ = lax.scan(
scan_body,
(particles, fields),
xs=None,
length=n_steps,
unroll=scan_unroll,
)
return particles, fields

def body(_, state):
p, f = state
return loop(p, f, world, constants, curl_func, J_func, solver, relativistic=relativistic)

return lax.fori_loop(0, n_steps, body, (particles, fields))

return jax.jit(
advance,
static_argnames=("curl_func", "J_func", "solver", "relativistic"),
donate_argnums=(0, 1),
)

advance_full = make_advance(scan_chunk) if scan_chunk > 1 else None
tail = Nt % scan_chunk
advance_tail = make_advance(tail) if (scan_chunk > 1 and tail) else None

outputs_enabled = any(
plotting_parameters.get(k, False)
for k in (
"plot_phasespace",
"plot_vtk_scalars",
"plot_vtk_vectors",
"plot_vtk_particles",
"plot_openpmd_particles",
"plot_openpmd_fields",
"dump_particles",
"dump_fields",
)
)

if (not outputs_enabled) and plotting_interval > Nt:
advance_all = make_advance(Nt)
particles, fields = advance_all(
particles,
fields,
world,
constants,
curl_func,
J_func,
solver,
relativistic=relativistic,
)
return Nt, plotting_parameters, simulation_parameters, plasma_parameters, constants, particles, fields, world

step_iter = range(0, Nt, scan_chunk) if scan_chunk > 1 else range(Nt)
for t in tqdm(step_iter):

# plot the data
if t % plotting_parameters['plotting_interval'] == 0:
if t % plotting_interval == 0:

plot_num = t // plotting_parameters['plotting_interval']
plot_num = t // plotting_interval
# determine the plot number

E, B, J, rho, *rest = fields
Expand Down Expand Up @@ -111,6 +213,11 @@ def run_PyPIC3D(config_file):


if plotting_parameters['plot_vtk_scalars']:
try:
from PyPIC3D.diagnostics.vtk import plot_field_slice_vtk
except ModuleNotFoundError as e:
raise ModuleNotFoundError("VTK diagnostics requested but 'vtk' is not installed.") from e

rho = compute_rho(particles, rho, world, constants)
# calculate the charge density based on the particle positions
mass_density = compute_mass_density(particles, rho, world)
Expand All @@ -122,13 +229,23 @@ def run_PyPIC3D(config_file):


if plotting_parameters['plot_vtk_vectors']:
try:
from PyPIC3D.diagnostics.vtk import plot_vectorfield_slice_vtk
except ModuleNotFoundError as e:
raise ModuleNotFoundError("VTK diagnostics requested but 'vtk' is not installed.") from e

vector_field_slices = [ [E[0][:,world['Ny']//2,:], E[1][:,world['Ny']//2,:], E[2][:,world['Ny']//2,:]],
[B[0][:,world['Ny']//2,:], B[1][:,world['Ny']//2,:], B[2][:,world['Ny']//2,:]],
[J[0][:,world['Ny']//2,:], J[1][:,world['Ny']//2,:], J[2][:,world['Ny']//2,:]]]
plot_vectorfield_slice_vtk(vector_field_slices, vector_field_names, 1, vertex_grid, t, 'vector_field', output_dir, world)
# Plot the vector fields in VTK format

if plotting_parameters['plot_vtk_particles']:
try:
from PyPIC3D.diagnostics.vtk import plot_vtk_particles
except ModuleNotFoundError as e:
raise ModuleNotFoundError("VTK diagnostics requested but 'vtk' is not installed.") from e

plot_vtk_particles(particles, plot_num, output_dir)
# Plot the particles in VTK format

Expand All @@ -143,35 +260,65 @@ def run_PyPIC3D(config_file):
fields = (E, B, J, rho, *rest)
# repack the fields

particles, fields = jit_loop(
particles,
fields,
world,
constants,
curl_func,
J_func,
solver,
relativistic=relativistic,
)
# time loop to update the particles and fields
if scan_chunk == 1:
particles, fields = loop(
particles,
fields,
world,
constants,
curl_func,
J_func,
solver,
relativistic=relativistic,
)
else:
if (t + scan_chunk) <= Nt:
particles, fields = advance_full(
particles,
fields,
world,
constants,
curl_func,
J_func,
solver,
relativistic=relativistic,
)
else:
particles, fields = advance_tail(
particles,
fields,
world,
constants,
curl_func,
J_func,
solver,
relativistic=relativistic,
)
# advance the particles and fields


return Nt, plotting_parameters, simulation_parameters, plasma_parameters, constants, particles, fields, world

def main():
###################### JAX SETTINGS ########################################################################
jax.config.update("jax_enable_x64", True)
# set Jax to use 64 bit precision
toml_file = load_config_file()
# load the configuration file

sim = toml_file.get("simulation_parameters", {}) if isinstance(toml_file, dict) else {}
fast_mode = sim.get("fast_mode", "off")
enable_x64 = bool(sim.get("enable_x64", True))
if fast_mode in ("fp32", "aggressive", "extreme"):
enable_x64 = False
jax.config.update("jax_enable_x64", enable_x64)
# set Jax precision (default preserves legacy behavior)

# jax.config.update("jax_debug_nans", True)
# debugging for nans
jax.config.update('jax_platform_name', 'cpu')
# set Jax to use CPUs
#jax.config.update("jax_disable_jit", True)
############################################################################################################

toml_file = load_config_file()
# load the configuration file

start = time.time()
# start the timer

Expand Down
Loading
Loading