
Commit e2cb1e2

Optimize: Enforce input types, and stop appeasing mypy
1 parent: 10d225d

2 files changed: +71 additions, −88 deletions
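In short: the commit replaces `typing.cast` calls that only appeased mypy with `isinstance` checks that actually validate inputs when the Ops are built. A minimal sketch of the pattern, with a hypothetical helper name (not code from this commit):

```python
import pytensor.tensor as pt
from pytensor.tensor.variable import TensorVariable


def _check_scalar_input(x) -> TensorVariable:
    # Runtime validation that also narrows the type for mypy,
    # instead of a cast(TensorVariable, x) that checks nothing.
    if not (isinstance(x, TensorVariable) and x.ndim == 0):
        raise ValueError("`x` must be a scalar (0-dimensional) tensor.")
    return x


_check_scalar_input(pt.scalar("x"))    # passes
# _check_scalar_input(pt.vector("v"))  # would raise ValueError
```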

pytensor/tensor/optimize.py

Lines changed: 70 additions & 88 deletions
@@ -1,7 +1,6 @@
 import logging
 from collections.abc import Sequence
 from copy import copy
-from typing import cast

 import numpy as np

@@ -126,7 +125,9 @@ def clear_cache(self):
         self.hess_calls = 0


-def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
+def _find_optimization_parameters(
+    objective: TensorVariable, x: TensorVariable
+) -> list[Variable]:
     """
     Find the parameters of the optimization problem that are not the variable `x`.

@@ -140,23 +141,19 @@ def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):


 def _get_parameter_grads_from_vector(
-    grad_wrt_args_vector: Variable,
-    x_star: Variable,
+    grad_wrt_args_vector: TensorVariable,
+    x_star: TensorVariable,
     args: Sequence[Variable],
-    output_grad: Variable,
-):
+    output_grad: TensorVariable,
+) -> list[TensorVariable]:
     """
     Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
     returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
     """
-    grad_wrt_args_vector = cast(TensorVariable, grad_wrt_args_vector)
-    x_star = cast(TensorVariable, x_star)
-
     cursor = 0
     grad_wrt_args = []

     for arg in args:
-        arg = cast(TensorVariable, arg)
         arg_shape = arg.shape
         arg_size = arg_shape.prod()
         arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
@@ -268,38 +265,37 @@ def build_fn(self):


 def scalar_implict_optimization_grads(
-    inner_fx: Variable,
-    inner_x: Variable,
+    inner_fx: TensorVariable,
+    inner_x: TensorVariable,
     inner_args: Sequence[Variable],
     args: Sequence[Variable],
-    x_star: Variable,
-    output_grad: Variable,
+    x_star: TensorVariable,
+    output_grad: TensorVariable,
     fgraph: FunctionGraph,
 ) -> list[Variable]:
-    df_dx, *df_dthetas = cast(
-        list[Variable],
-        grad(inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"),
+    df_dx, *df_dthetas = grad(
+        inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
     )

     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
     df_dx_star, *df_dthetas_stars = graph_replace([df_dx, *df_dthetas], replace=replace)

     grad_wrt_args = [
         (-df_dtheta_star / df_dx_star) * output_grad
-        for df_dtheta_star in cast(list[TensorVariable], df_dthetas_stars)
+        for df_dtheta_star in df_dthetas_stars
     ]

     return grad_wrt_args


 def implict_optimization_grads(
-    df_dx: Variable,
-    df_dtheta_columns: Sequence[Variable],
+    df_dx: TensorVariable,
+    df_dtheta_columns: Sequence[TensorVariable],
     args: Sequence[Variable],
-    x_star: Variable,
-    output_grad: Variable,
+    x_star: TensorVariable,
+    output_grad: TensorVariable,
     fgraph: FunctionGraph,
-):
+) -> list[TensorVariable]:
     r"""
     Compute gradients of an optimization problem with respect to its parameters.

@@ -341,21 +337,15 @@ def implict_optimization_grads(
     fgraph : FunctionGraph
         The function graph that contains the inputs and outputs of the optimization problem.
     """
-    df_dx = cast(TensorVariable, df_dx)
-
     df_dtheta = concatenate(
-        [
-            atleast_2d(jac_col, left=False)
-            for jac_col in cast(list[TensorVariable], df_dtheta_columns)
-        ],
+        [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
         axis=-1,
     )

     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))

-    df_dx_star, df_dtheta_star = cast(
-        list[TensorVariable],
-        graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
+    df_dx_star, df_dtheta_star = graph_replace(
+        [atleast_2d(df_dx), df_dtheta], replace=replace
     )

     grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
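Aside (not part of the diff): both `scalar_implict_optimization_grads` and `implict_optimization_grads` apply the implicit function theorem. If x*(θ) satisfies f(x*(θ), θ) = 0 (the root equation, or the objective's gradient in the minimize case), differentiating the optimality condition gives

\[
\frac{d x^{*}}{d\theta} \;=\; -\left(\left.\frac{\partial f}{\partial x}\right|_{x^{*}}\right)^{-1} \left.\frac{\partial f}{\partial \theta}\right|_{x^{*}},
\]

which is what `solve(-df_dx_star, df_dtheta_star)` computes above, and what the scalar variant computes as `-df_dtheta_star / df_dx_star` scaled by `output_grad` for the chain rule.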
@@ -369,20 +359,24 @@ def implict_optimization_grads(
 class MinimizeScalarOp(ScipyScalarWrapperOp):
     def __init__(
         self,
-        x: Variable,
+        x: TensorVariable,
         *args: Variable,
-        objective: Variable,
-        method: str = "brent",
+        objective: TensorVariable,
+        method: str,
         optimizer_kwargs: dict | None = None,
     ):
-        if not cast(TensorVariable, x).ndim == 0:
+        if not (isinstance(x, TensorVariable) and x.ndim == 0):
             raise ValueError(
                 "The variable `x` must be a scalar (0-dimensional) tensor for minimize_scalar."
             )
-        if not cast(TensorVariable, objective).ndim == 0:
+        if not (isinstance(objective, TensorVariable) and objective.ndim == 0):
             raise ValueError(
                 "The objective function must be a scalar (0-dimensional) tensor for minimize_scalar."
             )
+        if x not in ancestors([objective]):
+            raise ValueError(
+                "The variable `x` must be an input to the computational graph of the objective function."
+            )
         self.fgraph = FunctionGraph([x, *args], [objective])

         self.method = method
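The newly added ancestry check guards against an objective that never uses `x`. A small standalone sketch of what it rejects, using only the public `pytensor.graph.basic.ancestors` helper:

```python
import pytensor.tensor as pt
from pytensor.graph.basic import ancestors

x = pt.scalar("x")
y = pt.scalar("y")
objective = (y - 2.0) ** 2  # built from y only; x never enters the graph

# The new __init__ would reject this pairing, since x is not an ancestor of objective.
print(x in ancestors([objective]))  # False
```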
@@ -468,7 +462,6 @@ def minimize_scalar(
         Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
         value, based on the requested convergence criteria.
     """
-
     args = _find_optimization_parameters(objective, x)

     minimize_scalar_op = MinimizeScalarOp(
@@ -479,27 +472,29 @@
         optimizer_kwargs=optimizer_kwargs,
     )

-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], minimize_scalar_op(x, *args)
-    )
+    solution, success = minimize_scalar_op(x, *args)

     return solution, success


 class MinimizeOp(ScipyVectorWrapperOp):
     def __init__(
         self,
-        x: Variable,
+        x: TensorVariable,
         *args: Variable,
-        objective: Variable,
-        method: str = "BFGS",
+        objective: TensorVariable,
+        method: str,
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
         use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
-        if not cast(TensorVariable, objective).ndim == 0:
+        if not (isinstance(x, TensorVariable) and x.ndim in (0, 1)):
+            raise ValueError(
+                "The variable `x` must be a scalar or vector (0-or-1-dimensional) tensor for minimize."
+            )
+        if not (isinstance(objective, TensorVariable) and objective.ndim == 0):
             raise ValueError(
                 "The objective function must be a scalar (0-dimensional) tensor for minimize."
             )
@@ -512,19 +507,14 @@ def __init__(
         self.use_vectorized_jac = use_vectorized_jac

         if jac:
-            grad_wrt_x = cast(
-                Variable, grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
-            )
+            grad_wrt_x = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
             self.fgraph.add_output(grad_wrt_x)

         if hess:
-            hess_wrt_x = cast(
-                Variable,
-                jacobian(
-                    self.fgraph.outputs[-1],
-                    self.fgraph.inputs[0],
-                    vectorize=use_vectorized_jac,
-                ),
+            hess_wrt_x = jacobian(
+                self.fgraph.outputs[-1],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
             )
             self.fgraph.add_output(hess_wrt_x)

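Aside (not part of the diff): with `jac=True` / `hess=True`, the symbolic gradient and the Jacobian of that gradient are appended as extra outputs of the wrapped `FunctionGraph`. A standalone sketch of the same constructions, assuming only the public `pytensor.gradient` helpers:

```python
import pytensor.tensor as pt
from pytensor.gradient import grad, jacobian

x = pt.vector("x")
objective = (x ** 2).sum()          # scalar objective of a vector variable

g = grad(objective, x)              # gradient vector, added as the jac output
H = jacobian(g, x, vectorize=False) # Hessian matrix, added as the hess output
```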
@@ -654,41 +644,39 @@ def minimize(
         optimizer_kwargs=optimizer_kwargs,
     )

-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], minimize_op(x, *args)
-    )
+    solution, success = minimize_op(x, *args)

     return solution, success


 class RootScalarOp(ScipyScalarWrapperOp):
     def __init__(
         self,
-        variables,
-        *args,
-        equation,
-        method,
+        variables: TensorVariable,
+        *args: Variable,
+        equation: TensorVariable,
+        method: str,
         jac: bool = False,
         hess: bool = False,
         optimizer_kwargs=None,
     ):
-        if not equation.ndim == 0:
+        if not (isinstance(variables, TensorVariable) and variables.ndim == 0):
+            raise ValueError(
+                "The variable `x` must be a scalar (0-dimensional) tensor for root_scalar."
+            )
+        if not (isinstance(equation, TensorVariable) and equation.ndim == 0):
             raise ValueError(
                 "The equation must be a scalar (0-dimensional) tensor for root_scalar."
             )
-        if not isinstance(variables, Variable) or variables not in ancestors(
-            [equation]
-        ):
+        if variables not in ancestors([equation]):
             raise ValueError(
                 "The variable `variables` must be an input to the computational graph of the equation."
             )

         self.fgraph = FunctionGraph([variables, *args], [equation])

         if jac:
-            f_prime = cast(
-                Variable, grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
-            )
+            f_prime = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
             self.fgraph.add_output(f_prime)

         if hess:
@@ -697,9 +685,7 @@ def __init__(
                 "Cannot set `hess=True` without `jac=True`. No methods use second derivatives without also"
                 " using first derivatives."
             )
-            f_double_prime = cast(
-                Variable, grad(self.fgraph.outputs[-1], self.fgraph.inputs[0])
-            )
+            f_double_prime = grad(self.fgraph.outputs[-1], self.fgraph.inputs[0])
             self.fgraph.add_output(f_double_prime)

         self.method = method
@@ -813,9 +799,7 @@ def root_scalar(
         optimizer_kwargs=optimizer_kwargs,
     )

-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], root_scalar_op(variable, *args)
-    )
+    solution, success = root_scalar_op(variable, *args)

     return solution, success

@@ -825,15 +809,19 @@ class RootOp(ScipyVectorWrapperOp):

     def __init__(
         self,
-        variables: Variable,
+        variables: TensorVariable,
         *args: Variable,
-        equations: Variable,
-        method: str = "hybr",
+        equations: TensorVariable,
+        method: str,
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
         use_vectorized_jac: bool = False,
     ):
-        if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
+        if not isinstance(variables, TensorVariable):
+            raise ValueError("The variable `variables` must be a tensor for root.")
+        if not isinstance(equations, TensorVariable):
+            raise ValueError("The equations must be a tensor for root.")
+        if variables.ndim != equations.ndim:
             raise ValueError(
                 "The variable `variables` must have the same number of dimensions as the equations."
             )
@@ -916,12 +904,8 @@ def perform(self, node, inputs, outputs):
         outputs[0][0] = res.x.reshape(variables.shape).astype(variables.dtype)
         outputs[1][0] = np.bool_(res.success)

-    def L_op(
-        self,
-        inputs: Sequence[Variable],
-        outputs: Sequence[Variable],
-        output_grads: Sequence[Variable],
-    ) -> list[Variable]:
+    def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs
         x, *args = inputs
         x_star, _ = outputs
         output_grad, _ = output_grads
@@ -1004,9 +988,7 @@ def root(
         use_vectorized_jac=use_vectorized_jac,
     )

-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], root_op(variables, *args)
-    )
+    solution, success = root_op(variables, *args)

     return solution, success

scripts/mypy-failing.txt

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ pytensor/tensor/blas_headers.py
 pytensor/tensor/elemwise.py
 pytensor/tensor/extra_ops.py
 pytensor/tensor/math.py
+pytensor/tensor/optimize.py
 pytensor/tensor/random/basic.py
 pytensor/tensor/random/op.py
 pytensor/tensor/random/utils.py
