 import logging
 from collections.abc import Sequence
 from copy import copy
-from typing import cast

 import numpy as np

@@ -126,7 +125,9 @@ def clear_cache(self):
         self.hess_calls = 0


-def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
+def _find_optimization_parameters(
+    objective: TensorVariable, x: TensorVariable
+) -> list[Variable]:
     """
     Find the parameters of the optimization problem that are not the variable `x`.

@@ -140,23 +141,19 @@ def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):


 def _get_parameter_grads_from_vector(
-    grad_wrt_args_vector: Variable,
-    x_star: Variable,
+    grad_wrt_args_vector: TensorVariable,
+    x_star: TensorVariable,
     args: Sequence[Variable],
-    output_grad: Variable,
-):
+    output_grad: TensorVariable,
+) -> list[TensorVariable]:
     """
     Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
     returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
     """
-    grad_wrt_args_vector = cast(TensorVariable, grad_wrt_args_vector)
-    x_star = cast(TensorVariable, x_star)
-
     cursor = 0
     grad_wrt_args = []

     for arg in args:
-        arg = cast(TensorVariable, arg)
         arg_shape = arg.shape
         arg_size = arg_shape.prod()
         arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
@@ -268,38 +265,37 @@ def build_fn(self):


 def scalar_implict_optimization_grads(
-    inner_fx: Variable,
-    inner_x: Variable,
+    inner_fx: TensorVariable,
+    inner_x: TensorVariable,
     inner_args: Sequence[Variable],
     args: Sequence[Variable],
-    x_star: Variable,
-    output_grad: Variable,
+    x_star: TensorVariable,
+    output_grad: TensorVariable,
     fgraph: FunctionGraph,
 ) -> list[Variable]:
-    df_dx, *df_dthetas = cast(
-        list[Variable],
-        grad(inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"),
+    df_dx, *df_dthetas = grad(
+        inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
     )

     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
     df_dx_star, *df_dthetas_stars = graph_replace([df_dx, *df_dthetas], replace=replace)

     grad_wrt_args = [
         (-df_dtheta_star / df_dx_star) * output_grad
-        for df_dtheta_star in cast(list[TensorVariable], df_dthetas_stars)
+        for df_dtheta_star in df_dthetas_stars
     ]

     return grad_wrt_args


 def implict_optimization_grads(
-    df_dx: Variable,
-    df_dtheta_columns: Sequence[Variable],
+    df_dx: TensorVariable,
+    df_dtheta_columns: Sequence[TensorVariable],
     args: Sequence[Variable],
-    x_star: Variable,
-    output_grad: Variable,
+    x_star: TensorVariable,
+    output_grad: TensorVariable,
     fgraph: FunctionGraph,
-):
+) -> list[TensorVariable]:
     r"""
     Compute gradients of an optimization problem with respect to its parameters.

@@ -341,21 +337,15 @@ def implict_optimization_grads(
     fgraph : FunctionGraph
         The function graph that contains the inputs and outputs of the optimization problem.
     """
-    df_dx = cast(TensorVariable, df_dx)
-
     df_dtheta = concatenate(
-        [
-            atleast_2d(jac_col, left=False)
-            for jac_col in cast(list[TensorVariable], df_dtheta_columns)
-        ],
+        [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
         axis=-1,
     )

     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))

-    df_dx_star, df_dtheta_star = cast(
-        list[TensorVariable],
-        graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
+    df_dx_star, df_dtheta_star = graph_replace(
+        [atleast_2d(df_dx), df_dtheta], replace=replace
     )

     grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
@@ -369,20 +359,24 @@ def implict_optimization_grads(
 class MinimizeScalarOp(ScipyScalarWrapperOp):
     def __init__(
         self,
-        x: Variable,
+        x: TensorVariable,
         *args: Variable,
-        objective: Variable,
-        method: str = "brent",
+        objective: TensorVariable,
+        method: str,
         optimizer_kwargs: dict | None = None,
     ):
-        if not cast(TensorVariable, x).ndim == 0:
+        if not (isinstance(x, TensorVariable) and x.ndim == 0):
             raise ValueError(
                 "The variable `x` must be a scalar (0-dimensional) tensor for minimize_scalar."
             )
-        if not cast(TensorVariable, objective).ndim == 0:
+        if not (isinstance(objective, TensorVariable) and objective.ndim == 0):
             raise ValueError(
                 "The objective function must be a scalar (0-dimensional) tensor for minimize_scalar."
             )
+        if x not in ancestors([objective]):
+            raise ValueError(
+                "The variable `x` must be an input to the computational graph of the objective function."
+            )
         self.fgraph = FunctionGraph([x, *args], [objective])

         self.method = method
@@ -468,7 +462,6 @@ def minimize_scalar(
         Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
         value, based on the requested convergence criteria.
     """
-
     args = _find_optimization_parameters(objective, x)

     minimize_scalar_op = MinimizeScalarOp(
@@ -479,27 +472,29 @@ def minimize_scalar(
         optimizer_kwargs=optimizer_kwargs,
     )

-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], minimize_scalar_op(x, *args)
-    )
+    solution, success = minimize_scalar_op(x, *args)

     return solution, success


 class MinimizeOp(ScipyVectorWrapperOp):
     def __init__(
         self,
-        x: Variable,
+        x: TensorVariable,
         *args: Variable,
-        objective: Variable,
-        method: str = "BFGS",
+        objective: TensorVariable,
+        method: str,
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
         use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
-        if not cast(TensorVariable, objective).ndim == 0:
+        if not (isinstance(x, TensorVariable) and x.ndim in (0, 1)):
+            raise ValueError(
+                "The variable `x` must be a scalar or vector (0-or-1-dimensional) tensor for minimize."
+            )
+        if not (isinstance(objective, TensorVariable) and objective.ndim == 0):
             raise ValueError(
                 "The objective function must be a scalar (0-dimensional) tensor for minimize."
             )
@@ -512,19 +507,14 @@ def __init__(
         self.use_vectorized_jac = use_vectorized_jac

         if jac:
-            grad_wrt_x = cast(
-                Variable, grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
-            )
+            grad_wrt_x = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
             self.fgraph.add_output(grad_wrt_x)

         if hess:
-            hess_wrt_x = cast(
-                Variable,
-                jacobian(
-                    self.fgraph.outputs[-1],
-                    self.fgraph.inputs[0],
-                    vectorize=use_vectorized_jac,
-                ),
+            hess_wrt_x = jacobian(
+                self.fgraph.outputs[-1],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
             )
             self.fgraph.add_output(hess_wrt_x)

@@ -654,41 +644,39 @@ def minimize(
         optimizer_kwargs=optimizer_kwargs,
     )

-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], minimize_op(x, *args)
-    )
+    solution, success = minimize_op(x, *args)

     return solution, success


 class RootScalarOp(ScipyScalarWrapperOp):
     def __init__(
         self,
-        variables,
-        *args,
-        equation,
-        method,
+        variables: TensorVariable,
+        *args: Variable,
+        equation: TensorVariable,
+        method: str,
         jac: bool = False,
         hess: bool = False,
         optimizer_kwargs=None,
     ):
-        if not equation.ndim == 0:
+        if not (isinstance(variables, TensorVariable) and variables.ndim == 0):
+            raise ValueError(
+                "The variable `x` must be a scalar (0-dimensional) tensor for root_scalar."
+            )
+        if not (isinstance(equation, TensorVariable) and equation.ndim == 0):
             raise ValueError(
                 "The equation must be a scalar (0-dimensional) tensor for root_scalar."
             )
-        if not isinstance(variables, Variable) or variables not in ancestors(
-            [equation]
-        ):
+        if variables not in ancestors([equation]):
             raise ValueError(
                 "The variable `variables` must be an input to the computational graph of the equation."
             )

         self.fgraph = FunctionGraph([variables, *args], [equation])

         if jac:
-            f_prime = cast(
-                Variable, grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
-            )
+            f_prime = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
             self.fgraph.add_output(f_prime)

         if hess:
@@ -697,9 +685,7 @@ def __init__(
697685 "Cannot set `hess=True` without `jac=True`. No methods use second derivatives without also"
698686 " using first derivatives."
699687 )
700- f_double_prime = cast (
701- Variable , grad (self .fgraph .outputs [- 1 ], self .fgraph .inputs [0 ])
702- )
688+ f_double_prime = grad (self .fgraph .outputs [- 1 ], self .fgraph .inputs [0 ])
703689 self .fgraph .add_output (f_double_prime )
704690
705691 self .method = method
@@ -813,9 +799,7 @@ def root_scalar(
         optimizer_kwargs=optimizer_kwargs,
     )

-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], root_scalar_op(variable, *args)
-    )
+    solution, success = root_scalar_op(variable, *args)

     return solution, success

@@ -825,15 +809,19 @@ class RootOp(ScipyVectorWrapperOp):

     def __init__(
         self,
-        variables: Variable,
+        variables: TensorVariable,
         *args: Variable,
-        equations: Variable,
-        method: str = "hybr",
+        equations: TensorVariable,
+        method: str,
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
         use_vectorized_jac: bool = False,
     ):
-        if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
+        if not isinstance(variables, TensorVariable):
+            raise ValueError("The variable `variables` must be a tensor for root.")
+        if not isinstance(equations, TensorVariable):
+            raise ValueError("The equations must be a tensor for root.")
+        if variables.ndim != equations.ndim:
             raise ValueError(
                 "The variable `variables` must have the same number of dimensions as the equations."
             )
@@ -916,12 +904,8 @@ def perform(self, node, inputs, outputs):
         outputs[0][0] = res.x.reshape(variables.shape).astype(variables.dtype)
         outputs[1][0] = np.bool_(res.success)

-    def L_op(
-        self,
-        inputs: Sequence[Variable],
-        outputs: Sequence[Variable],
-        output_grads: Sequence[Variable],
-    ) -> list[Variable]:
+    def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs
         x, *args = inputs
         x_star, _ = outputs
         output_grad, _ = output_grads
@@ -1004,9 +988,7 @@ def root(
         use_vectorized_jac=use_vectorized_jac,
     )

-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], root_op(variables, *args)
-    )
+    solution, success = root_op(variables, *args)

     return solution, success

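For orientation, a minimal usage sketch of the `minimize_scalar` entry point touched by this diff. The import path `pytensor.tensor.optimize`, the `(objective, x, method=...)` argument order, and all variable names here are assumptions for illustration, not taken from this commit; the point shown is only that the returned pair is symbolic, so callers no longer need `typing.cast` on the results.

# Illustrative sketch only: module path and call signature are assumed, not confirmed by this diff.
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize_scalar

a = pt.scalar("a")        # hypothetical parameter of the objective
x = pt.scalar("x")        # the scalar optimization variable

objective = (x - a) ** 2  # scalar (0-dimensional) objective, as MinimizeScalarOp requires

# Both outputs are symbolic tensor variables; no cast() is needed downstream.
x_star, success = minimize_scalar(objective, x, method="brent")

# The wrapped Op takes x as the initial value and the remaining graph inputs as parameters,
# so the compiled function is parameterized by both.
fn = pytensor.function([a, x], [x_star, success])
x_opt, converged = fn(2.0, 0.0)  # x_opt should be close to 2.0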