diff --git a/PyPIC3D/boris.py b/PyPIC3D/boris.py index 49405c8..7510f90 100644 --- a/PyPIC3D/boris.py +++ b/PyPIC3D/boris.py @@ -63,8 +63,24 @@ def particle_push(particles, E, B, grid, staggered_grid, dt, constants, periodic #################### BORIS ALGORITHM #################################### - boris_vmap = jax.vmap(boris_single_particle, in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None, None)) - relativistic_boris_vmap = jax.vmap(relativistic_boris_single_particle, in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None, None)) + if jnp.ndim(q) == 0: + boris_vmap = jax.vmap( + boris_single_particle, + in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None, None), + ) + relativistic_boris_vmap = jax.vmap( + relativistic_boris_single_particle, + in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None, None), + ) + else: + boris_vmap = jax.vmap( + boris_single_particle, + in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, None), + ) + relativistic_boris_vmap = jax.vmap( + relativistic_boris_single_particle, + in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, None), + ) # vectorize the Boris algorithm for batch processing newvx, newvy, newvz = jax.lax.cond( diff --git a/PyPIC3D/diagnostics/openPMD.py b/PyPIC3D/diagnostics/openPMD.py index a76c9fd..842b9c6 100644 --- a/PyPIC3D/diagnostics/openPMD.py +++ b/PyPIC3D/diagnostics/openPMD.py @@ -113,8 +113,29 @@ def write_openpmd_particles_to_iteration(iteration, particles, constants): gamma = _ensure_openpmd_array(gamma) num_particles = x.shape[0] - particle_mass = float(species.mass) - particle_charge = float(species.charge) + # number of particles in this species + + particle_mass = species.get_mass() + particle_charge = species.get_charge() + weights = species.get_weight() + # get the particle mass, charge, and weight for this species + + + if jnp.ndim(weights) == 0: + weights = np.full(num_particles, float(weights), dtype=np.float64) + else: + weights = _ensure_openpmd_array(weights) + + if jnp.ndim(particle_mass) == 0: + masses = np.full(num_particles, float(particle_mass), dtype=np.float64) + else: + masses = _ensure_openpmd_array(particle_mass) + + if jnp.ndim(particle_charge) == 0: + charges = np.full(num_particles, float(particle_charge), dtype=np.float64) + else: + charges = _ensure_openpmd_array(particle_charge) + # ensure weights, masses, and charges are 1D arrays of the correct length for openPMD output position = species_group["position"] for component, data in zip(("x", "y", "z"), (x, y, z)): @@ -123,6 +144,7 @@ def write_openpmd_particles_to_iteration(iteration, particles, constants): record_component.store_chunk(data, [0], [num_particles]) record_component.unit_SI = 1.0 + # positionOffset: required by openPMD consumers (WarpX expects it) pos_off = species_group["positionOffset"] zeros = np.zeros(num_particles, dtype=np.float64) for comp in ("x", "y", "z"): @@ -135,25 +157,24 @@ def write_openpmd_particles_to_iteration(iteration, particles, constants): for component, data in zip(("x", "y", "z"), (vx, vy, vz)): record_component = momentum[component] record_component.reset_dataset(io.Dataset(data.dtype, [num_particles])) - record_component.store_chunk(data * particle_mass * gamma, [0], [num_particles]) + momenta = data * masses * gamma + # compute the momentum for each particle + record_component.store_chunk(momenta, [0], [num_particles]) record_component.unit_SI = 1.0 weighting = species_group["weighting"] - weights = np.full(num_particles, float(species.weight), dtype=np.float64) weighting.reset_dataset(io.Dataset(weights.dtype, [num_particles])) weighting.store_chunk(weights, [0], [num_particles]) weighting.unit_SI = 1.0 charge = species_group["charge"] - charges = np.full(num_particles, particle_charge, dtype=np.float64) charge.reset_dataset(io.Dataset(charges.dtype, [num_particles])) - charge.store_chunk(charges, [0], [num_particles]) + charge.store_chunk(charges / weights, [0], [num_particles]) charge.unit_SI = 1.0 mass = species_group["mass"] - masses = np.full(num_particles, particle_mass, dtype=np.float64) mass.reset_dataset(io.Dataset(masses.dtype, [num_particles])) - mass.store_chunk(masses, [0], [num_particles]) + mass.store_chunk(masses / weights, [0], [num_particles]) mass.unit_SI = 1.0 diff --git a/PyPIC3D/flat_particles.py b/PyPIC3D/flat_particles.py new file mode 100644 index 0000000..e16b04f --- /dev/null +++ b/PyPIC3D/flat_particles.py @@ -0,0 +1,353 @@ +import jax +import jax.numpy as jnp + + +@jax.tree_util.register_pytree_node_class +class flat_particle_species: + def __init__( + self, + name, + N_particles, + charge, + mass, + weight, + T, + x1, + x2, + x3, + v1, + v2, + v3, + x_wind, + y_wind, + z_wind, + dx, + dy, + dz, + x_bc, + y_bc, + z_bc, + update_pos, + update_v, + update_x, + update_y, + update_z, + update_vx, + update_vy, + update_vz, + shape, + dt, + species_meta, + ): + self.name = name + self.N_particles = N_particles + self.charge = charge + self.mass = mass + self.weight = weight + self.T = T + self.x1 = x1 + self.x2 = x2 + self.x3 = x3 + self.v1 = v1 + self.v2 = v2 + self.v3 = v3 + self.x_wind = x_wind + self.y_wind = y_wind + self.z_wind = z_wind + self.dx = dx + self.dy = dy + self.dz = dz + self.x_bc = x_bc + self.y_bc = y_bc + self.z_bc = z_bc + self.update_pos = update_pos + self.update_v = update_v + self.update_x = update_x + self.update_y = update_y + self.update_z = update_z + self.update_vx = update_vx + self.update_vy = update_vy + self.update_vz = update_vz + self.shape = shape + self.dt = dt + self.species_meta = species_meta + + def get_name(self): + return self.name + + def get_charge(self): + return self.charge * self.weight + + def get_number_of_particles(self): + return self.N_particles + + def get_temperature(self): + return self.T + + def get_velocity(self): + return self.v1, self.v2, self.v3 + + def get_forward_position(self): + return self.x1, self.x2, self.x3 + + def get_position(self): + x1_back = self.x1 - self.v1 * self.dt / 2 + x2_back = self.x2 - self.v2 * self.dt / 2 + x3_back = self.x3 - self.v3 * self.dt / 2 + + half_x = self.x_wind / 2 + half_y = self.y_wind / 2 + half_z = self.z_wind / 2 + + x1_back = jnp.where(x1_back > half_x, x1_back - self.x_wind, jnp.where(x1_back < -half_x, x1_back + self.x_wind, x1_back)) + x2_back = jnp.where(x2_back > half_y, x2_back - self.y_wind, jnp.where(x2_back < -half_y, x2_back + self.y_wind, x2_back)) + x3_back = jnp.where(x3_back > half_z, x3_back - self.z_wind, jnp.where(x3_back < -half_z, x3_back + self.z_wind, x3_back)) + + return x1_back, x2_back, x3_back + + def get_mass(self): + return self.mass * self.weight + + def get_weight(self): + return self.weight + + def get_shape(self): + return self.shape + + def momentum(self): + vmag = jnp.sqrt(self.v1**2 + self.v2**2 + self.v3**2) + return jnp.sum(vmag * self.mass * self.weight) + + def set_velocity(self, v1, v2, v3): + if self.update_v: + if self.update_vx: + self.v1 = v1 + if self.update_vy: + self.v2 = v2 + if self.update_vz: + self.v3 = v3 + + def update_position(self): + if self.update_pos: + if self.update_x: + self.x1 = self.x1 + self.v1 * self.dt + if self.update_y: + self.x2 = self.x2 + self.v2 * self.dt + if self.update_z: + self.x3 = self.x3 + self.v3 * self.dt + + def boundary_conditions(self): + half_x = self.x_wind / 2 + half_y = self.y_wind / 2 + half_z = self.z_wind / 2 + + self.x1 = jnp.where(self.x1 > half_x, self.x1 - self.x_wind, jnp.where(self.x1 < -half_x, self.x1 + self.x_wind, self.x1)) + self.x2 = jnp.where(self.x2 > half_y, self.x2 - self.y_wind, jnp.where(self.x2 < -half_y, self.x2 + self.y_wind, self.x2)) + self.x3 = jnp.where(self.x3 > half_z, self.x3 - self.z_wind, jnp.where(self.x3 < -half_z, self.x3 + self.z_wind, self.x3)) + + def tree_flatten(self): + children = (self.x1, self.x2, self.x3, self.v1, self.v2, self.v3) + aux_data = ( + self.name, + self.N_particles, + self.charge, + self.mass, + self.weight, + self.T, + self.x_wind, + self.y_wind, + self.z_wind, + self.dx, + self.dy, + self.dz, + self.x_bc, + self.y_bc, + self.z_bc, + self.update_pos, + self.update_v, + self.update_x, + self.update_y, + self.update_z, + self.update_vx, + self.update_vy, + self.update_vz, + self.shape, + self.dt, + self.species_meta, + ) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + x1, x2, x3, v1, v2, v3 = children + ( + name, + N_particles, + charge, + mass, + weight, + T, + x_wind, + y_wind, + z_wind, + dx, + dy, + dz, + x_bc, + y_bc, + z_bc, + update_pos, + update_v, + update_x, + update_y, + update_z, + update_vx, + update_vy, + update_vz, + shape, + dt, + species_meta, + ) = aux_data + return cls( + name=name, + N_particles=N_particles, + charge=charge, + mass=mass, + weight=weight, + T=T, + x1=x1, + x2=x2, + x3=x3, + v1=v1, + v2=v2, + v3=v3, + x_wind=x_wind, + y_wind=y_wind, + z_wind=z_wind, + dx=dx, + dy=dy, + dz=dz, + x_bc=x_bc, + y_bc=y_bc, + z_bc=z_bc, + update_pos=update_pos, + update_v=update_v, + update_x=update_x, + update_y=update_y, + update_z=update_z, + update_vx=update_vx, + update_vy=update_vy, + update_vz=update_vz, + shape=shape, + dt=dt, + species_meta=species_meta, + ) + + +def _normalize_attr(value): + try: + return jnp.asarray(value).item() + except Exception: + return value + + +def _same(attr_list): + norm = [_normalize_attr(v) for v in attr_list] + return len(set(norm)) == 1 + + +def check_flat_compat(particles): + if not particles: + return False + if not _same([p.get_shape() for p in particles]): + return False + if not _same([p.x_bc for p in particles]) or not _same([p.y_bc for p in particles]) or not _same([p.z_bc for p in particles]): + return False + if particles[0].x_bc != "periodic" or particles[0].y_bc != "periodic" or particles[0].z_bc != "periodic": + return False + if not _same([p.update_pos for p in particles]) or not _same([p.update_v for p in particles]): + return False + return True + + +def to_flat_particles(particles): + species_meta = [] + x_list, y_list, z_list = [], [], [] + vx_list, vy_list, vz_list = [], [], [] + q_list, m_list, w_list, T_list = [], [], [], [] + + for species in particles: + x, y, z = species.get_forward_position() + vx, vy, vz = species.get_velocity() + x_list.append(x) + y_list.append(y) + z_list.append(z) + vx_list.append(vx) + vy_list.append(vy) + vz_list.append(vz) + q_list.append(jnp.full_like(x, species.charge)) + m_list.append(jnp.full_like(x, species.mass)) + w_list.append(jnp.full_like(x, species.weight)) + T_list.append(jnp.full_like(x, species.T)) + species_meta.append( + { + "name": species.name, + "N_particles": float(species.N_particles), + "weight": float(species.weight), + "charge": float(species.charge), + "mass": float(species.mass), + "temperature": float(species.T), + "scaled mass": float(species.get_mass()), + "scaled charge": float(species.get_charge()), + "update_pos": species.update_pos, + "update_v": species.update_v, + } + ) + + x = jnp.concatenate(x_list, axis=0) + y = jnp.concatenate(y_list, axis=0) + z = jnp.concatenate(z_list, axis=0) + vx = jnp.concatenate(vx_list, axis=0) + vy = jnp.concatenate(vy_list, axis=0) + vz = jnp.concatenate(vz_list, axis=0) + charge = jnp.concatenate(q_list, axis=0) + mass = jnp.concatenate(m_list, axis=0) + weight = jnp.concatenate(w_list, axis=0) + T = jnp.concatenate(T_list, axis=0) + + first = particles[0] + flat = flat_particle_species( + name="flat_all", + N_particles=int(x.shape[0]), + charge=charge, + mass=mass, + weight=weight, + T=T, + x1=x, + x2=y, + x3=z, + v1=vx, + v2=vy, + v3=vz, + x_wind=first.x_wind, + y_wind=first.y_wind, + z_wind=first.z_wind, + dx=first.dx, + dy=first.dy, + dz=first.dz, + x_bc=first.x_bc, + y_bc=first.y_bc, + z_bc=first.z_bc, + update_pos=first.update_pos, + update_v=first.update_v, + update_x=first.update_x, + update_y=first.update_y, + update_z=first.update_z, + update_vx=first.update_vx, + update_vy=first.update_vy, + update_vz=first.update_vz, + shape=first.shape, + dt=first.dt, + species_meta=species_meta, + ) + return [flat] diff --git a/PyPIC3D/initialization.py b/PyPIC3D/initialization.py index 665a38b..c07e6f1 100644 --- a/PyPIC3D/initialization.py +++ b/PyPIC3D/initialization.py @@ -38,6 +38,10 @@ write_openpmd_initial_particles, write_openpmd_initial_fields ) +from PyPIC3D.flat_particles import ( + to_flat_particles, check_flat_compat +) + from PyPIC3D.evolve import ( time_loop_electrodynamic, time_loop_electrostatic, time_loop_vector_potential @@ -97,6 +101,7 @@ def default_parameters(): "name": "Default Simulation", "output_dir": os.getcwd(), "solver": "fdtd", # solver: spectral, fdtd, vector_potential, curl_curl + "fast_backend": "flat", # flat | default (flat when compatible, else fallback) "particle_bc": "periodic", # particle boundary conditions: periodic, absorb, reflect # "bc": "periodic", # boundary conditions: periodic, dirichlet, neumann "x_bc": "periodic", # x boundary conditions: periodic, conducting @@ -299,6 +304,23 @@ def initialize_simulation(toml_file): write_openpmd_initial_particles(particles, world, constants, simulation_parameters['output_dir']) # write the initial particles to an openPMD file + + fast_backend = simulation_parameters.get("fast_backend", "flat") + if fast_backend == "flat": + if electrostatic or solver == "vector_potential": + print("fast_backend='flat' not supported for electrostatic/vector_potential; falling back to default") + simulation_parameters["fast_backend"] = "default" + elif not check_flat_compat(particles): + print("fast_backend='flat' incompatible with species layout; falling back to default") + simulation_parameters["fast_backend"] = "default" + else: + print("Using flat particle backend for simulation") + particles = to_flat_particles(particles) + simulation_parameters["fast_backend"] = "flat" + # check if we can use the flat particle backend for faster performance, and convert to flat particles if possible + + + E, B, J, phi, rho = initialize_fields(Nx, Ny, Nz) # initialize the electric and magnetic fields diff --git a/PyPIC3D/particle.py b/PyPIC3D/particle.py index 4655277..6c43860 100644 --- a/PyPIC3D/particle.py +++ b/PyPIC3D/particle.py @@ -652,6 +652,9 @@ def set_mass(self, mass): def set_weight(self, weight): self.weight = weight + def get_weight(self): + return self.weight + def kinetic_energy(self): v2 = jnp.square(self.v1) + jnp.square(self.v2) + jnp.square(self.v3) # compute the square of the velocity diff --git a/PyPIC3D/utils.py b/PyPIC3D/utils.py index 1938e8d..4a917c6 100644 --- a/PyPIC3D/utils.py +++ b/PyPIC3D/utils.py @@ -696,6 +696,9 @@ def dump_parameters_to_toml(simulation_stats, simulation_parameters, plasma_para } for particle in particles: + if hasattr(particle, "species_meta") and particle.species_meta: + config["particles"].extend(particle.species_meta) + continue particle_dict = { "name": particle.name, "N_particles": float(particle.N_particles),