Skip to content

Commit d01f1ad

Browse files
committed
Respect use_vectorized_jac in RootOp
1 parent 90ae165 commit d01f1ad

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

pytensor/tensor/optimize.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,7 @@ def __init__(
831831
)
832832

833833
self.fgraph = FunctionGraph([variables, *args], [equations])
834+
self.use_vectorized_jac = use_vectorized_jac
834835

835836
if jac:
836837
jac_wrt_x = jacobian(
@@ -914,12 +915,15 @@ def L_op(self, inputs, outputs, output_grads):
914915
inner_fx = self.fgraph.outputs[0]
915916

916917
df_dx = (
917-
jacobian(inner_fx, inner_x, vectorize=True)
918+
jacobian(inner_fx, inner_x, vectorize=self.use_vectorized_jac)
918919
if not self.jac
919920
else self.fgraph.outputs[1]
920921
)
921922
df_dtheta_columns = jacobian(
922-
inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True
923+
inner_fx,
924+
inner_args,
925+
disconnected_inputs="ignore",
926+
vectorize=self.use_vectorized_jac,
923927
)
924928

925929
grad_wrt_args = implict_optimization_grads(

0 commit comments

Comments
 (0)