Skip to content

Commit 51e3fcc

Browse files
Athe-kunaloulgen
authored andcommitted
Helion Puzzle docs bug fixes (#1062)
1 parent 7b36a9e commit 51e3fcc

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

docs/helion_puzzles.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ While PyTorch and torch.compile automatically generates the backwards pass for y
269269
x = x.clone()
270270
y = y.clone()
271271
x = x.requires_grad_(True)
272+
y = y.requires_grad_(True)
272273
z = torch.relu(x * y[:, None])
273274
grad_x, grad_y = torch.autograd.grad(z, [x, y], dz, retain_graph=True)
274275
return grad_x
@@ -325,7 +326,7 @@ Sum of a batch of numbers.
325326
# Use Helion to tile the batch dimension
326327
for tile_batch in hl.tile(batch):
327328
# Initialize accumulator for each batch element
328-
acc = torch.zeros_like(tile_batch, dtype=torch.float32)
329+
acc = torch.zeros(tile_batch, dtype=torch.float32, device=x.device)
329330
330331
# Process the sequence in chunks
331332
for tile_seq in hl.tile(seq_len):

0 commit comments

Comments
 (0)