diff --git a/shap_e/rendering/raycast/cast.py b/shap_e/rendering/raycast/cast.py index 7ef357128..5252ddf24 100644 --- a/shap_e/rendering/raycast/cast.py +++ b/shap_e/rendering/raycast/cast.py @@ -111,7 +111,7 @@ def forward( def backward( ctx, _collides_grad, ray_dists_grad, _tri_indices_grad, barycentric_grad, normals_grad ): - origins, directions, faces, vertices = ctx.input_tensors + origins, directions, faces, vertices = ctx.saved_tensors origins = origins.detach().requires_grad_(True) directions = directions.detach().requires_grad_(True)