@@ -7,7 +7,7 @@
 
 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, hessian, jacobian
+from pytensor.gradient import grad, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import ancestors, truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -483,6 +483,7 @@ def __init__(
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
+        use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
         if not cast(TensorVariable, objective).ndim == 0:
@@ -495,6 +496,7 @@ def __init__(
             )
 
         self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac
 
         if jac:
             grad_wrt_x = cast(
@@ -504,7 +506,12 @@ def __init__(
 
         if hess:
             hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
             )
             self.fgraph.add_output(hess_wrt_x)
 
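For context (not part of the diff): this hunk swaps the direct `hessian` call for the Jacobian of the gradient output appended to the graph just above, which is the same matrix for a scalar objective but lets the new `use_vectorized_jac` flag control how the graph is built. A minimal sketch of that equivalence, assuming a recent pytensor where `jacobian` accepts the `vectorize` keyword (variable names are illustrative only):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import grad, hessian, jacobian

x = pt.vector("x")
obj = (x**2).sum() + pt.prod(x)  # any scalar objective

g = grad(obj, x)
H_from_grad = jacobian(g, x, vectorize=False)  # scan-based, as with the default flag
H_direct = hessian(obj, x)                     # what the removed line computed

fn = pytensor.function([x], [H_from_grad, H_direct])
a, b = fn(np.array([1.0, 2.0, 3.0]))
np.testing.assert_allclose(a, b)  # both routes yield the same Hessian
```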
@@ -563,7 +570,7 @@ def L_op(self, inputs, outputs, output_grads):
             implicit_f,
             [inner_x, *inner_args],
             disconnected_inputs="ignore",
-            vectorize=True,
+            vectorize=self.use_vectorized_jac,
         )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
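For context (again, not part of the diff): the Jacobian assembled here feeds `implict_optimization_grads`, which, judging by its inputs, implements the implicit-function-theorem gradient. Writing the optimality or root condition as $f(x^*(\theta), \theta) = 0$ and differentiating through the solution gives

$$
\frac{\partial f}{\partial x}\,\frac{d x^*}{d\theta} + \frac{\partial f}{\partial \theta} = 0
\quad\Longrightarrow\quad
\frac{d x^*}{d\theta} = -\left(\frac{\partial f}{\partial x}\right)^{-1}\frac{\partial f}{\partial \theta},
$$

so `df_dx` and the `df_dtheta` columns are needed either way; the flag only changes whether that Jacobian graph is built with scan or with vmap.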
@@ -583,6 +590,7 @@ def minimize(
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -592,18 +600,21 @@ def minimize(
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-
-    x : TensorVariable
+    x: TensorVariable
         The variable with respect to which the objective function is minimized. It must be an input to the
         computational graph of `objective`.
-
-    method : str, optional
+    method: str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-
-    jac : bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac: bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
         Default is True.
-
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian (and/or hessian) matrix. If False, a
+        scan will be used instead. This comes down to a memory/compute trade-off. Vectorized graphs can be faster,
+        but use more memory. Default is False.
     optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize
 
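A hypothetical usage sketch of the new flag (the least-squares setup and variable names below are illustrative, not taken from the patch):

```python
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

X = pt.matrix("X")
y = pt.vector("y")
beta = pt.vector("beta")
loss = ((y - pt.dot(X, beta)) ** 2).sum()  # scalar objective in beta

# Default: Jacobian/Hessian graphs are built with scan (lower memory).
beta_hat, success = minimize(loss, beta, method="Newton-CG", hess=True)

# Vectorized (vmap-style) graphs: often faster to evaluate, but can use more memory.
beta_hat_v, success_v = minimize(
    loss, beta, method="Newton-CG", hess=True, use_vectorized_jac=True
)
```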
@@ -626,6 +637,7 @@ def minimize(
         method=method,
         jac=jac,
         hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
         optimizer_kwargs=optimizer_kwargs,
     )
 
@@ -806,6 +818,7 @@ def __init__(
         method: str = "hybr",
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
     ):
         if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
             raise ValueError(
@@ -820,7 +833,9 @@ def __init__(
 
         if jac:
             jac_wrt_x = jacobian(
-                self.fgraph.outputs[0], self.fgraph.inputs[0], vectorize=True
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
             )
             self.fgraph.add_output(atleast_2d(jac_wrt_x))
 
@@ -927,6 +942,7 @@ def root(
     variables: TensorVariable,
     method: str = "hybr",
     jac: bool = True,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -945,6 +961,10 @@ def root(
     jac : bool, optional
         Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
         Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian matrix. If False, a scan will be used instead.
+        This comes down to a memory/compute trade-off. Vectorized graphs can be faster, but use more memory.
+        Default is False.
     optimizer_kwargs : dict, optional
         Additional keyword arguments to pass to `scipy.optimize.root`.
 
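Similarly, a hypothetical sketch for `root` with the new flag (the two-equation system is made up for illustration):

```python
import pytensor.tensor as pt
from pytensor.tensor.optimize import root

v = pt.vector("v", shape=(2,))
equations = pt.stack([v[0] ** 2 + v[1] - 3.0, v[0] - v[1] ** 2 + 1.0])

# jac=True (default) builds the Jacobian of `equations` w.r.t. `v`;
# the new flag switches that graph from scan to a vectorized (vmap) form.
solution, success = root(equations, v, use_vectorized_jac=True)
```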
@@ -968,6 +988,7 @@ def root(
         method=method,
         jac=jac,
         optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
     )
 
     solution, success = cast(