Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 213 additions & 82 deletions PyPIC3D/J.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,127 @@
from PyPIC3D.utils import digital_filter, wrap_around, bilinear_filter
from PyPIC3D.shapes import get_first_order_weights, get_second_order_weights

@partial(jit, static_argnames=("filter",))
def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'):

def _weights_order1(r):
w0 = 1.0 - r
w1 = r
return jnp.stack((w0, w1), axis=0)


def _weights_order2(r):
w0 = 0.5 * (0.5 - r) ** 2
w1 = 0.75 - r**2
w2 = 0.5 * (0.5 + r) ** 2
return jnp.stack((w0, w1, w2), axis=0)


def _deposit_1d(J_stack, dq, vx, vy, vz, x, grid_x0, dx, dt, shape_factor, Nx):
if shape_factor == 1:
x0 = jnp.floor((x - grid_x0) / dx).astype(jnp.int32)
deltax_node = (x - grid_x0) - x0 * dx
deltax_face = (x - grid_x0) - (x0 + 0.5) * dx

r_node = deltax_node / dx
r_face = deltax_face / dx

w_node = _weights_order1(r_node) # (2,Np)
w_face = _weights_order1(r_face) # (2,Np)

ix = jnp.stack((x0, x0 + 1), axis=0)
ix = wrap_around(ix, Nx)
else:
x0 = jnp.round((x - grid_x0) / dx).astype(jnp.int32)
deltax_node = (x - grid_x0) - x0 * dx
deltax_face = (x - grid_x0) - (x0 + 0.5) * dx

r_node = deltax_node / dx
r_face = deltax_face / dx

w_node = _weights_order2(r_node) # (3,Np)
w_face = _weights_order2(r_face) # (3,Np)

ix = jnp.stack((x0 - 1, x0, x0 + 1), axis=0)
ix = wrap_around(ix, Nx)

# Jx uses face weights; Jy/Jz use node weights.
val = jnp.stack(
(
(dq * vx)[None, :] * w_face,
(dq * vy)[None, :] * w_node,
(dq * vz)[None, :] * w_node,
),
axis=-1,
) # (S,Np,3)

comp = jnp.arange(3, dtype=ix.dtype)[None, None, :] # (1,1,3)
idx = ix[:, :, None] + comp * jnp.asarray(Nx, dtype=ix.dtype) # (S,Np,3)

out = jnp.bincount(
idx.reshape(-1),
weights=val.reshape(-1),
length=Nx * 3,
).reshape(3, Nx)

Jx, Jy, Jz = out[0], out[1], out[2]
return jnp.stack((Jx, Jy, Jz), axis=-1).reshape((Nx, 1, 1, 3))


def _deposit_2d(J_stack, dq, vx, vy, vz, x, y, xmin, ymin, dx, dy, dt, shape_factor, Nx, Ny):
if shape_factor == 1:
x0 = jnp.floor((x - xmin) / dx).astype(jnp.int32)
y0 = jnp.floor((y - ymin) / dy).astype(jnp.int32)
deltax_node = (x - xmin) - x0 * dx
deltay_node = (y - ymin) - y0 * dy
deltax_face = (x - xmin) - (x0 + 0.5) * dx
deltay_face = (y - ymin) - (y0 + 0.5) * dy

wx_node = _weights_order1(deltax_node / dx) # (2,Np)
wy_node = _weights_order1(deltay_node / dy) # (2,Np)
wx_face = _weights_order1(deltax_face / dx) # (2,Np)
wy_face = _weights_order1(deltay_face / dy) # (2,Np)

ix = jnp.stack((x0, x0 + 1), axis=0)
iy = jnp.stack((y0, y0 + 1), axis=0)
ix = wrap_around(ix, Nx)
iy = wrap_around(iy, Ny)
else:
x0 = jnp.round((x - xmin) / dx).astype(jnp.int32)
y0 = jnp.round((y - ymin) / dy).astype(jnp.int32)
deltax_node = (x - xmin) - x0 * dx
deltay_node = (y - ymin) - y0 * dy
deltax_face = (x - xmin) - (x0 + 0.5) * dx
deltay_face = (y - ymin) - (y0 + 0.5) * dy

wx_node = _weights_order2(deltax_node / dx) # (3,Np)
wy_node = _weights_order2(deltay_node / dy) # (3,Np)
wx_face = _weights_order2(deltax_face / dx) # (3,Np)
wy_face = _weights_order2(deltay_face / dy) # (3,Np)

ix = jnp.stack((x0 - 1, x0, x0 + 1), axis=0)
iy = jnp.stack((y0 - 1, y0, y0 + 1), axis=0)
ix = wrap_around(ix, Nx)
iy = wrap_around(iy, Ny)

idx = ix[:, None, :] + Nx * iy[None, :, :] # (Sx,Sy,Np)
idx_flat = idx.reshape(-1)

# weights for each component
wjx = wx_face[:, None, :] * wy_node[None, :, :]
wjy = wx_node[:, None, :] * wy_face[None, :, :]
wjz = wx_node[:, None, :] * wy_node[None, :, :]

valx = (dq * vx)[None, None, :] * wjx
valy = (dq * vy)[None, None, :] * wjy
valz = (dq * vz)[None, None, :] * wjz

vals = jnp.stack((valx, valy, valz), axis=-1).reshape(-1, 3)
J_flat = jax.ops.segment_sum(vals, idx_flat, num_segments=Nx * Ny) # (Nx*Ny,3)
J2 = J_flat.reshape((Nx, Ny, 3))[:, :, None, :]
return J2


@partial(jit, static_argnames=("filter", "shape_factor"))
def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear', shape_factor=2):
"""
Compute the current density from the charge density and particle velocities.

Expand All @@ -29,57 +148,64 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'):
dx = world['dx']
dy = world['dy']
dz = world['dz']
Nx = world['Nx']
Ny = world['Ny']
Nz = world['Nz']
# get the world parameters

Jx, Jy, Jz = J
Nx, Ny, Nz = Jx.shape
# get the world parameters
x_active = Jx.shape[0] != 1
y_active = Jx.shape[1] != 1
z_active = Jx.shape[2] != 1
# infer effective dimensionality from the current-grid shape

# unpack the values of J
Jx = Jx.at[:, :, :].set(0)
Jy = Jy.at[:, :, :].set(0)
Jz = Jz.at[:, :, :].set(0)
# initialize the current arrays as 0
J = (Jx, Jy, Jz)
# initialize the current density as a tuple
J_stack = jnp.zeros((Nx, Ny, Nz, 3), dtype=Jx.dtype)
# keep J together so deposition and filtering can be fused across components

for species in particles:
shape_factor = species.get_shape()
charge = species.get_charge()
dq = charge / (dx * dy * dz)
# calculate the charge density contribution per particle
x, y, z = species.get_forward_position()
vx, vy, vz = species.get_velocity()
# get the particles positions and velocities

x = x - vx * world['dt'] / 2
y = y - vy * world['dt'] / 2
z = z - vz * world['dt'] / 2
dt = world["dt"]
x = x - vx * dt / 2
# step back to half time step positions for proper time staggering

x0 = jax.lax.cond(
shape_factor == 1,
lambda _: jnp.floor( (x - grid[0][0]) / dx).astype(int),
lambda _: jnp.round( (x - grid[0][0]) / dx).astype(int),
operand=None
)
y0 = jax.lax.cond(
shape_factor == 1,
lambda _: jnp.floor( (y - grid[1][0]) / dy).astype(int),
lambda _: jnp.round( (y - grid[1][0]) / dy).astype(int),
operand=None
)
z0 = jax.lax.cond(
shape_factor == 1,
lambda _: jnp.floor( (z - grid[2][0]) / dz).astype(int),
lambda _: jnp.round( (z - grid[2][0]) / dz).astype(int),
operand=None
)
if Ny == 1 and Nz == 1:
J_stack = _deposit_1d(J_stack, dq, vx, vy, vz, x, grid[0][0], dx, world["dt"], shape_factor, Nx)
continue
if Nz == 1:
y = y - vy * dt / 2
J_stack = _deposit_2d(
J_stack,
dq,
vx,
vy,
vz,
x,
y,
grid[0][0],
grid[1][0],
dx,
dy,
world["dt"],
shape_factor,
Nx,
Ny,
)
continue

y = y - vy * dt / 2
z = z - vz * dt / 2

if shape_factor == 1:
x0 = jnp.floor((x - grid[0][0]) / dx).astype(jnp.int32)
y0 = jnp.floor((y - grid[1][0]) / dy).astype(jnp.int32)
z0 = jnp.floor((z - grid[2][0]) / dz).astype(jnp.int32)
else:
x0 = jnp.round((x - grid[0][0]) / dx).astype(jnp.int32)
y0 = jnp.round((y - grid[1][0]) / dy).astype(jnp.int32)
z0 = jnp.round((z - grid[2][0]) / dz).astype(jnp.int32)
# calculate the nearest grid point based on shape factor

deltax_node = (x - grid[0][0]) - (x0 * dx)
Expand All @@ -96,9 +222,9 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'):
y0 = wrap_around(y0, Ny)
z0 = wrap_around(z0, Nz)
# wrap around the grid points for periodic boundary conditions
x1 = wrap_around(x0+1, Nx)
y1 = wrap_around(y0+1, Ny)
z1 = wrap_around(z0+1, Nz)
x1 = wrap_around(x0 + 1, Nx)
y1 = wrap_around(y0 + 1, Ny)
z1 = wrap_around(z0 + 1, Nz)
# calculate the right grid point
x_minus1 = x0 - 1
y_minus1 = y0 - 1
Expand All @@ -109,19 +235,12 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'):
zpts = [z_minus1, z0, z1]
# place all the points in a list

x_weights_node, y_weights_node, z_weights_node = jax.lax.cond(
shape_factor == 1,
lambda _: get_first_order_weights( deltax_node, deltay_node, deltaz_node, dx, dy, dz),
lambda _: get_second_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz),
operand=None
)

x_weights_face, y_weights_face, z_weights_face = jax.lax.cond(
shape_factor == 1,
lambda _: get_first_order_weights( deltax_face, deltay_face, deltaz_face, dx, dy, dz),
lambda _: get_second_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz),
operand=None
)
if shape_factor == 1:
x_weights_node, y_weights_node, z_weights_node = get_first_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz)
x_weights_face, y_weights_face, z_weights_face = get_first_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz)
else:
x_weights_node, y_weights_node, z_weights_node = get_second_order_weights(deltax_node, deltay_node, deltaz_node, dx, dy, dz)
x_weights_face, y_weights_face, z_weights_face = get_second_order_weights(deltax_face, deltay_face, deltaz_face, dx, dy, dz)
# get the weights for node and face positions

xpts = jnp.asarray(xpts) # (Sx, Np)
Expand All @@ -136,6 +255,18 @@ def J_from_rhov(particles, J, constants, world, grid=None, filter='bilinear'):
y_weights_node = jnp.asarray(y_weights_node) # (Sy, Np)
z_weights_node = jnp.asarray(z_weights_node) # (Sz, Np)

if shape_factor == 1:
# drop the redundant (-1) stencil point for first-order (its weights are identically 0)
xpts = xpts[1:, ...]
ypts = ypts[1:, ...]
zpts = zpts[1:, ...]
x_weights_face = x_weights_face[1:, ...]
y_weights_face = y_weights_face[1:, ...]
z_weights_face = z_weights_face[1:, ...]
x_weights_node = x_weights_node[1:, ...]
y_weights_node = y_weights_node[1:, ...]
z_weights_node = z_weights_node[1:, ...]

# Keep full shape-factor computation but collapse inactive axes to an
# effective stencil of size 1 to avoid redundant deposition work.
if x_active:
Expand Down Expand Up @@ -184,41 +315,41 @@ def idx_and_dJ_values(idx):
valy = (dq * vy) * x_weights_node_eff[i, ...] * y_weights_face_eff[j, ...] * z_weights_node_eff[k, ...]
valz = (dq * vz) * x_weights_node_eff[i, ...] * y_weights_node_eff[j, ...] * z_weights_face_eff[k, ...]
# calculate the current contributions for this stencil point
return ix, iy, iz, valx, valy, valz
return ix, iy, iz, jnp.stack((valx, valy, valz), axis=-1)

ix, iy, iz, valx, valy, valz = jax.vmap(idx_and_dJ_values)(combos) # each: (M, Np)
ix, iy, iz, dJ = jax.vmap(idx_and_dJ_values)(combos) # (M,Np), (M,Np), (M,Np), (M,Np,3)
# vectorized computation of indices and current contributions

Jx = Jx.at[(ix, iy, iz)].add(valx, mode="drop")
Jy = Jy.at[(ix, iy, iz)].add(valy, mode="drop")
Jz = Jz.at[(ix, iy, iz)].add(valz, mode="drop")
# deposit the current contributions into the global J arrays

def filter_func(J_, filter):
J_ = jax.lax.cond(
filter == 'bilinear',
lambda J_: bilinear_filter(J_),
lambda J_: J_,
operand=J_
ix_flat = ix.reshape(-1)
iy_flat = iy.reshape(-1)
iz_flat = iz.reshape(-1)
dJ_flat = dJ.reshape(-1, 3)

in_bounds = (
(ix_flat >= 0)
& (ix_flat < Nx)
& (iy_flat >= 0)
& (iy_flat < Ny)
& (iz_flat >= 0)
& (iz_flat < Nz)
)

# alpha = constants['alpha']
# J_ = jax.lax.cond(
# filter == 'digital',
# lambda J_: digital_filter(J_, alpha),
# lambda J_: J_,
# operand=J_
# )
return J_
# define a filtering function

Jx = filter_func(Jx, filter)
Jy = filter_func(Jy, filter)
Jz = filter_func(Jz, filter)
# apply the selected filter to each component of J
J = (Jx, Jy, Jz)

return J
ix_flat = jnp.clip(ix_flat, 0, Nx - 1)
iy_flat = jnp.clip(iy_flat, 0, Ny - 1)
iz_flat = jnp.clip(iz_flat, 0, Nz - 1)

idx_flat = ix_flat + Nx * (iy_flat + Ny * iz_flat)
dJ_flat = jnp.where(in_bounds[:, None], dJ_flat, 0)

J_flat = jax.ops.segment_sum(dJ_flat, idx_flat, num_segments=Nx * Ny * Nz)
J_stack = J_stack + J_flat.reshape((Nx, Ny, Nz, 3))
# segment_sum avoids large scatter updates on CPU

if filter == "bilinear":
J_stack = bilinear_filter(J_stack)
# (optional) digital filter disabled by default

return (J_stack[..., 0], J_stack[..., 1], J_stack[..., 2])

def _roll_old_weights_to_new_frame(old_w_list, shift):
"""
Expand Down
Loading
Loading