Skip to content

Commit ddac3aa

Browse files
WanHsuanLinweinbe58Copilot
authored
Subgrid shift (#33)
* add test cases for sub grid shift * add test in test_concrete * add implementation for shift_subgrid_x and shift_subgrid_y * add test for shift_sub_grid_x/y with slice * add implementation for shift_sub_grid_x/y with slice * fix format * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix format --------- Co-authored-by: Phillip Weinberg <pweinberg@quera.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 957ac1e commit ddac3aa

File tree

7 files changed

+301
-0
lines changed

7 files changed

+301
-0
lines changed

src/bloqade/geometry/dialects/grid/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,7 @@
3131
Scale as Scale,
3232
Shape as Shape,
3333
Shift as Shift,
34+
ShiftSubgridX as ShiftSubgridX,
35+
ShiftSubgridY as ShiftSubgridY,
3436
)
3537
from .types import Grid as Grid, GridType as GridType

src/bloqade/geometry/dialects/grid/_interface.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
Scale,
1818
Shape,
1919
Shift,
20+
ShiftSubgridX,
21+
ShiftSubgridY,
2022
)
2123
from .types import Grid
2224

@@ -231,6 +233,38 @@ def shift(grid: Grid[Nx, Ny], x_shift: float, y_shift: float) -> Grid[Nx, Ny]:
231233
...
232234

233235

236+
@_wraps(ShiftSubgridX)
237+
def shift_subgrid_x(
238+
grid: Grid[Nx, Ny], x_indices: ilist.IList[int, typing.Any], x_shift: float
239+
) -> Grid[Nx, Ny]:
240+
"""Shift a sub grid of grid in the x directions.
241+
242+
Args:
243+
grid (Grid): a grid object
244+
x_indices (ilist.IList[int, typing.Any]): a list/ilist of x indices to shift
245+
x_shift (float): shift in the x direction
246+
Returns:
247+
Grid: a new grid object that has been shifted
248+
"""
249+
...
250+
251+
252+
@_wraps(ShiftSubgridY)
253+
def shift_subgrid_y(
254+
grid: Grid[Nx, Ny], y_indices: ilist.IList[int, typing.Any], y_shift: float
255+
) -> Grid[Nx, Ny]:
256+
"""Shift a sub grid of grid in the y directions.
257+
258+
Args:
259+
grid (Grid): a grid object
260+
y_indices (ilist.IList[int, typing.Any]): a list/ilist of y indices to shift
261+
y_shift (float): shift in the y direction
262+
Returns:
263+
Grid: a new grid object that has been shifted
264+
"""
265+
...
266+
267+
234268
@_wraps(Shape)
235269
def shape(grid: Grid) -> tuple[int, int]:
236270
"""Get the shape of a grid.

src/bloqade/geometry/dialects/grid/concrete.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,32 @@ def shift(
154154

155155
return (grid.shift(x_shift, y_shift),)
156156

157+
@impl(stmts.ShiftSubgridX)
158+
def shift_subgrid_x(
159+
self,
160+
interp: Interpreter,
161+
frame: Frame,
162+
stmt: stmts.ShiftSubgridX,
163+
):
164+
grid = frame.get_casted(stmt.zone, Grid)
165+
x_indices = frame.get_casted(stmt.x_indices, ilist.IList)
166+
x_shift = frame.get_casted(stmt.x_shift, float)
167+
168+
return (grid.shift_subgrid_x(x_indices, x_shift),)
169+
170+
@impl(stmts.ShiftSubgridY)
171+
def shift_subgrid_y(
172+
self,
173+
interp: Interpreter,
174+
frame: Frame,
175+
stmt: stmts.ShiftSubgridY,
176+
):
177+
grid = frame.get_casted(stmt.zone, Grid)
178+
y_indices = frame.get_casted(stmt.y_indices, ilist.IList)
179+
y_shift = frame.get_casted(stmt.y_shift, float)
180+
181+
return (grid.shift_subgrid_y(y_indices, y_shift),)
182+
157183
@impl(stmts.Scale)
158184
def scale(
159185
self,

src/bloqade/geometry/dialects/grid/stmts.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,34 @@ class Shift(ir.Statement):
146146
result: ir.ResultValue = info.result(GridType[NumX, NumY])
147147

148148

149+
@statement(dialect=dialect)
150+
class ShiftSubgridX(ir.Statement):
151+
name = "shift_subgrid_x"
152+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
153+
zone: ir.SSAValue = info.argument(
154+
type=GridType[NumX := types.TypeVar("NumX"), NumY := types.TypeVar("NumY")]
155+
)
156+
x_indices: ir.SSAValue = info.argument(
157+
ilist.IListType[types.Int, types.TypeVar("SubNumX")]
158+
)
159+
x_shift: ir.SSAValue = info.argument(types.Float)
160+
result: ir.ResultValue = info.result(GridType[NumX, NumY])
161+
162+
163+
@statement(dialect=dialect)
164+
class ShiftSubgridY(ir.Statement):
165+
name = "shift_subgrid_y"
166+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
167+
zone: ir.SSAValue = info.argument(
168+
type=GridType[NumX := types.TypeVar("NumX"), NumY := types.TypeVar("NumY")]
169+
)
170+
y_indices: ir.SSAValue = info.argument(
171+
ilist.IListType[types.Int, types.TypeVar("SubNumY")]
172+
)
173+
y_shift: ir.SSAValue = info.argument(types.Float)
174+
result: ir.ResultValue = info.result(GridType[NumX, NumY])
175+
176+
149177
@statement(dialect=dialect)
150178
class Scale(ir.Statement):
151179
name = "scale_grid"

src/bloqade/geometry/dialects/grid/types.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,84 @@ def shift(self, x_shift: float, y_shift: float) -> "Grid[NumX, NumY]":
377377
y_init=self.y_init + y_shift if self.y_init is not None else None,
378378
)
379379

380+
def shift_subgrid_x(
381+
self, x_indices: ilist.IList[int, Nx] | slice, x_shift: float
382+
) -> "Grid[NumX, NumY]":
383+
"""Shift a sub grid of grid in the x directions.
384+
385+
Args:
386+
grid (Grid): a grid object
387+
x_indices (float): a list/ilist of x indices to shift
388+
x_shift (float): shift in the x direction
389+
Returns:
390+
Grid: a new grid object that has been shifted
391+
"""
392+
indices = get_indices(len(self.x_spacing) + 1, x_indices)
393+
394+
def shift_x(index):
395+
new_spacing = self.x_spacing[index]
396+
if index in indices and (index + 1) not in indices:
397+
new_spacing -= x_shift
398+
elif index not in indices and (index + 1) in indices:
399+
new_spacing += x_shift
400+
return new_spacing
401+
402+
new_spacing = tuple(shift_x(i) for i in range(len(self.x_spacing)))
403+
404+
assert all(
405+
x >= 0 for x in new_spacing
406+
), "Invalid shift: column order changes after shift."
407+
408+
x_init = self.x_init
409+
if x_init is not None and 0 in indices:
410+
x_init += x_shift
411+
412+
return Grid(
413+
x_spacing=new_spacing,
414+
y_spacing=self.y_spacing,
415+
x_init=x_init,
416+
y_init=self.y_init,
417+
)
418+
419+
def shift_subgrid_y(
420+
self, y_indices: ilist.IList[int, Ny] | slice, y_shift: float
421+
) -> "Grid[NumX, NumY]":
422+
"""Shift a sub grid of grid in the y directions.
423+
424+
Args:
425+
grid (Grid): a grid object
426+
y_indices (float): a list/ilist of y indices to shift
427+
y_shift (float): shift in the y direction
428+
Returns:
429+
Grid: a new grid object that has been shifted
430+
"""
431+
indices = get_indices(len(self.y_spacing) + 1, y_indices)
432+
433+
def shift_y(index):
434+
new_spacing = self.y_spacing[index]
435+
if index in indices and (index + 1) not in indices:
436+
new_spacing -= y_shift
437+
elif index not in indices and (index + 1) in indices:
438+
new_spacing += y_shift
439+
return new_spacing
440+
441+
new_spacing = tuple(shift_y(i) for i in range(len(self.y_spacing)))
442+
443+
assert all(
444+
y >= 0 for y in new_spacing
445+
), "Invalid shift: row order changes after shift."
446+
447+
y_init = self.y_init
448+
if y_init is not None and 0 in indices:
449+
y_init += y_shift
450+
451+
return Grid(
452+
x_spacing=self.x_spacing,
453+
y_spacing=new_spacing,
454+
x_init=self.x_init,
455+
y_init=y_init,
456+
)
457+
380458
def repeat(
381459
self, x_times: int, y_times: int, x_gap: float, y_gap: float
382460
) -> "Grid[NumX, NumY]":

test/grid/test_concrete.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ def test_from_ranges(self):
8181
(grid.GetYPos, "y_positions", ()),
8282
(grid.Get, "get", ((1, 0),)),
8383
(grid.Shift, "shift", (1.0, 2.0)),
84+
(grid.ShiftSubgridX, "shift_subgrid_x", (ilist.IList([0]), -1)),
85+
(grid.ShiftSubgridY, "shift_subgrid_y", (ilist.IList([0]), -1)),
8486
(grid.Scale, "scale", (1.0, 2.0)),
8587
(grid.Repeat, "repeat", (1, 2, 0.5, 1.0)),
8688
(grid.GetSubGrid, "get_view", (ilist.IList((0,)), ilist.IList((1,)))),

test/grid/test_types.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,137 @@ def test_shift(self):
8888
)
8989
assert shifted_grid.is_equal(expected_grid)
9090

91+
@pytest.mark.parametrize(
92+
"x_indices, x_shift, expected_grid",
93+
[
94+
(
95+
ilist.IList([]),
96+
0,
97+
Grid(
98+
x_spacing=(1, 2, 3),
99+
y_spacing=(4, 5),
100+
x_init=1,
101+
y_init=2,
102+
),
103+
),
104+
(
105+
ilist.IList([0, 1]),
106+
1,
107+
Grid(
108+
x_spacing=(1, 1, 3),
109+
y_spacing=(4, 5),
110+
x_init=2,
111+
y_init=2,
112+
),
113+
),
114+
(
115+
ilist.IList([1]),
116+
1,
117+
Grid(
118+
x_spacing=(2, 1, 3),
119+
y_spacing=(4, 5),
120+
x_init=1,
121+
y_init=2,
122+
),
123+
),
124+
(
125+
ilist.IList([1, 2, 3]),
126+
1,
127+
Grid(
128+
x_spacing=(2, 2, 3),
129+
y_spacing=(4, 5),
130+
x_init=1,
131+
y_init=2,
132+
),
133+
),
134+
(
135+
slice(1, 4, 1),
136+
1,
137+
Grid(
138+
x_spacing=(2, 2, 3),
139+
y_spacing=(4, 5),
140+
x_init=1,
141+
y_init=2,
142+
),
143+
),
144+
(ilist.IList([1]), 3, None),
145+
],
146+
)
147+
def test_shift_subgrid_x(self, x_indices, x_shift, expected_grid):
148+
if expected_grid is None:
149+
with pytest.raises(AssertionError):
150+
shifted_grid = self.grid_obj.shift_subgrid_x(x_indices, x_shift)
151+
return
152+
153+
shifted_grid = self.grid_obj.shift_subgrid_x(x_indices, x_shift)
154+
assert shifted_grid.is_equal(expected_grid)
155+
156+
@pytest.mark.parametrize(
157+
"y_indices, y_shift, expected_grid",
158+
[
159+
(
160+
ilist.IList([]),
161+
0,
162+
Grid(
163+
x_spacing=(1, 2, 3),
164+
y_spacing=(4, 5),
165+
x_init=1,
166+
y_init=2,
167+
),
168+
),
169+
(
170+
ilist.IList([0]),
171+
-1,
172+
Grid(
173+
x_spacing=(1, 2, 3),
174+
y_spacing=(5, 5),
175+
x_init=1,
176+
y_init=1,
177+
),
178+
),
179+
(
180+
ilist.IList([1]),
181+
1,
182+
Grid(
183+
x_spacing=(1, 2, 3),
184+
y_spacing=(5, 4),
185+
x_init=1,
186+
y_init=2,
187+
),
188+
),
189+
(
190+
ilist.IList([0, 2]),
191+
1,
192+
Grid(
193+
x_spacing=(1, 2, 3),
194+
y_spacing=(3, 6),
195+
x_init=1,
196+
y_init=3,
197+
),
198+
),
199+
(
200+
slice(0, 1, 1),
201+
-1,
202+
Grid(
203+
x_spacing=(1, 2, 3),
204+
y_spacing=(5, 5),
205+
x_init=1,
206+
y_init=1,
207+
),
208+
),
209+
(ilist.IList([0]), 5, None),
210+
],
211+
)
212+
def test_shift_subgrid_y(self, y_indices, y_shift, expected_grid):
213+
214+
if expected_grid is None:
215+
with pytest.raises(AssertionError):
216+
shifted_grid = self.grid_obj.shift_subgrid_y(y_indices, y_shift)
217+
return
218+
219+
shifted_grid = self.grid_obj.shift_subgrid_y(y_indices, y_shift)
220+
assert shifted_grid.is_equal(expected_grid)
221+
91222
def test_scale(self):
92223
scaled_grid = self.grid_obj.scale(2, 3)
93224
expected_grid = Grid(

0 commit comments

Comments
 (0)