Commit fbaa2fe

Optimize: Handle gradient wrt scalar inputs and guard against unsupported types
1 parent c333e3b commit fbaa2fe
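
As a quick orientation, the behaviour this commit targets can be exercised roughly as follows. This is a minimal sketch adapted from the regression test added below (test_optimize_grad_scalar_arg), assuming a pytensor build that includes this commit; minimize_scalar is used for brevity, and the scalar_from_tensor / tensor_from_scalar round-trip is only there to force a ScalarType variable onto the optimization node:

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor import scalar_from_tensor, tensor_from_scalar
from pytensor.tensor.optimize import minimize_scalar

x = pt.scalar("x")
theta = pt.scalar("theta")
# Route the objective through ScalarType variables so a ScalarVariable ends up
# as a direct input of the optimization node (the case this commit handles).
obj = tensor_from_scalar((scalar_from_tensor(x) + scalar_from_tensor(theta)) ** 2)
x_star, _success = minimize_scalar(obj, x)

# Implicit gradient of the argmin with respect to theta; analytically -1 here.
grad_wrt_theta = pt.grad(x_star, theta)
print(grad_wrt_theta.eval({x: np.pi, theta: np.e}))  # ~ -1.0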

2 files changed: +170 -13 lines


pytensor/tensor/optimize.py

Lines changed: 62 additions & 9 deletions
@@ -6,21 +6,24 @@
 
 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, jacobian
+from pytensor.gradient import grad, grad_not_implemented, jacobian
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.graph.replace import graph_replace
 from pytensor.graph.traversal import ancestors, truncated_graph_inputs
+from pytensor.scalar import ScalarType, ScalarVariable
 from pytensor.tensor.basic import (
     atleast_2d,
     concatenate,
+    scalar_from_tensor,
     tensor,
     tensor_from_scalar,
     zeros_like,
 )
 from pytensor.tensor.math import dot
 from pytensor.tensor.slinalg import solve
+from pytensor.tensor.type import DenseTensorType
 from pytensor.tensor.variable import TensorVariable, Variable
 
 
@@ -143,9 +146,9 @@ def _find_optimization_parameters(
 def _get_parameter_grads_from_vector(
     grad_wrt_args_vector: TensorVariable,
     x_star: TensorVariable,
-    args: Sequence[Variable],
+    args: Sequence[TensorVariable | ScalarVariable],
     output_grad: TensorVariable,
-) -> list[TensorVariable]:
+) -> list[TensorVariable | ScalarVariable]:
     """
     Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
     returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
@@ -160,7 +163,10 @@ def _get_parameter_grads_from_vector(
             (*x_star.shape, *arg_shape)
         )
 
-        grad_wrt_args.append(dot(output_grad, arg_grad))
+        grad_wrt_arg = dot(output_grad, arg_grad)
+        if isinstance(arg.type, ScalarType):
+            grad_wrt_arg = scalar_from_tensor(grad_wrt_arg)
+        grad_wrt_args.append(grad_wrt_arg)
 
         cursor += arg_size
 
@@ -267,12 +273,12 @@ def build_fn(self):
 def scalar_implict_optimization_grads(
     inner_fx: TensorVariable,
     inner_x: TensorVariable,
-    inner_args: Sequence[Variable],
-    args: Sequence[Variable],
+    inner_args: Sequence[TensorVariable | ScalarVariable],
+    args: Sequence[TensorVariable | ScalarVariable],
     x_star: TensorVariable,
     output_grad: TensorVariable,
     fgraph: FunctionGraph,
-) -> list[Variable]:
+) -> list[TensorVariable | ScalarVariable]:
     df_dx, *df_dthetas = grad(
         inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
     )
@@ -291,11 +297,11 @@ def scalar_implict_optimization_grads(
 def implict_optimization_grads(
     df_dx: TensorVariable,
     df_dtheta_columns: Sequence[TensorVariable],
-    args: Sequence[Variable],
+    args: Sequence[TensorVariable | ScalarVariable],
     x_star: TensorVariable,
     output_grad: TensorVariable,
     fgraph: FunctionGraph,
-) -> list[TensorVariable]:
+) -> list[TensorVariable | ScalarVariable]:
     r"""
     Compute gradients of an optimization problem with respect to its parameters.
 
@@ -410,7 +416,19 @@ def perform(self, node, inputs, outputs):
         outputs[1][0] = np.bool_(res.success)
 
     def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"Minimize gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _ = outputs
         output_grad, _ = output_grads
 
@@ -560,7 +578,19 @@ def perform(self, node, inputs, outputs):
         outputs[1][0] = np.bool_(res.success)
 
     def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"MinimizeOp gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _success = outputs
         output_grad, _ = output_grads
 
@@ -727,7 +757,19 @@ def perform(self, node, inputs, outputs):
         outputs[1][0] = np.bool_(res.converged)
 
     def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"RootScalarOp gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _ = outputs
         output_grad, _ = output_grads
 
@@ -908,6 +950,17 @@ def perform(self, node, inputs, outputs):
     def L_op(self, inputs, outputs, output_grads):
        # TODO: Handle disconnected inputs
        x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"RootOp gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _ = outputs
         output_grad, _ = output_grads
 
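
For reference, the two mechanisms added above can be read in isolation. Below is a standalone sketch, with hypothetical helper names (guard_unsupported_inputs, match_arg_kind) that are not part of pytensor's own code path: the type guard each L_op now starts with returns a "gradient not implemented" placeholder for every input as soon as any input type is neither DenseTensorType nor ScalarType, and the dense-tensor gradient of a parameter is cast back to a ScalarVariable when that parameter was a ScalarType input.

from pytensor.gradient import grad_not_implemented
from pytensor.scalar import ScalarType
from pytensor.tensor import scalar_from_tensor
from pytensor.tensor.type import DenseTensorType


def guard_unsupported_inputs(op, inputs):
    """Return NullType gradients for all inputs if any input type is unsupported, else None."""
    if non_supported_types := tuple(
        inp.type
        for inp in inputs
        if not isinstance(inp.type, DenseTensorType | ScalarType)
    ):
        msg = f"Gradient not implemented due to inputs of type {non_supported_types}"
        return [grad_not_implemented(op, i, inp, msg) for i, inp in enumerate(inputs)]
    return None


def match_arg_kind(grad_wrt_arg, arg):
    """Cast a tensor-valued gradient back to a ScalarVariable when `arg` is a ScalarType input."""
    if isinstance(arg.type, ScalarType):
        return scalar_from_tensor(grad_wrt_arg)
    return grad_wrt_arg

Because the guard fires before any gradient is built, every input (including otherwise differentiable numeric parameters) receives a NullType gradient whenever a single unsupported input is present; the last test added below pins down that behaviour.
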
tests/tensor/test_optimize.py

Lines changed: 108 additions & 4 deletions
@@ -3,9 +3,10 @@
 
 import pytensor
 import pytensor.tensor as pt
-from pytensor import config, function
-from pytensor.graph import Apply, Op
-from pytensor.tensor import scalar
+from pytensor import Variable, config, function
+from pytensor.gradient import NullTypeGradError, disconnected_type
+from pytensor.graph import Apply, Op, Type
+from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar
 from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar
 from tests import unittest_tools as utt
 
@@ -224,7 +225,7 @@ def root_fn(x, a, b):
 
 
 @pytest.mark.parametrize("optimize_op", (minimize, root))
-def test_minimize_0d(optimize_op):
+def test_optimize_0d(optimize_op):
     # Scipy vector minimizers upcast 0d x to 1d. We need to work-around this
 
     class AssertScalar(Op):
@@ -248,3 +249,106 @@ def L_op(self, inputs, outputs, out_grads):
     np.testing.assert_allclose(
         opt_x_res, 0, atol=1e-15 if floatX == "float64" else 1e-6
     )
+
+
+@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
+def test_optimize_grad_scalar_arg(optimize_op):
+    # Regression test for https://github.com/pymc-devs/pytensor/pull/1744
+    x = scalar("x")
+    theta = scalar("theta")
+    theta_scalar = scalar_from_tensor(theta)
+    obj = tensor_from_scalar((scalar_from_tensor(x) + theta_scalar) ** 2)
+    x0, _ = optimize_op(obj, x)
+
+    # Confirm theta is a direct input to the node
+    assert x0.owner.inputs[1] is theta_scalar
+
+    grad_wrt_theta = pt.grad(x0, theta)
+    np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: np.e}), -1)
+
+
+@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
+def test_optimize_grad_disconnected_numerical_inp(optimize_op):
+    x = scalar("x", dtype="float64")
+    theta = scalar("theta", dtype="int64")
+    obj = alloc(x**2, theta).sum()  # repeat theta times and sum
+    x0, _ = optimize_op(obj, x)
+
+    # Confirm theta is a direct input to the node
+    assert x0.owner.inputs[1] is theta
+
+    # This should technically raise, but does not right now
+    grad_wrt_theta = pt.grad(x0, theta, disconnected_inputs="raise")
+    np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: 5}), 0)
+
+    # This should work even if the previous one raised
+    grad_wrt_theta = pt.grad(x0, theta, disconnected_inputs="ignore")
+    np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: 5}), 0)
+
+
+@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
+def test_optimize_grad_disconnected_non_numerical_inp(optimize_op):
+    class StrType(Type):
+        def filter(self, x, **kwargs):
+            if isinstance(x, str):
+                return x
+            raise TypeError
+
+    class SmileOrFrown(Op):
+        def make_node(self, x, str_emoji):
+            return Apply(self, [x, str_emoji], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            [x, str_emoji] = inputs
+            match str_emoji:
+                case ":)":
+                    out = np.array(x)
+                case ":(":
+                    out = np.array(-x)
+                case _:
+                    ValueError("str_emoji must be a smile or a frown")
+            output_storage[0][0] = out
+
+        def connection_pattern(self, node):
+            # Gradient connected only to first input
+            return [[True], [False]]
+
+        def L_op(self, inputs, outputs, output_gradients):
+            [_x, str_emoji] = inputs
+            [g] = output_gradients
+            return [
+                self(g, str_emoji),
+                disconnected_type(),
+            ]
+
+    # We could try to use real types like NoneTypeT or SliceType, but this is more robust to future API changes
+    str_type = StrType()
+    smile_or_frown = SmileOrFrown()
+
+    x = scalar("x", dtype="float64")
+    num_theta = pt.scalar("num_theta", dtype="float64")
+    str_theta = Variable(str_type, None, None, name="str_theta")
+    obj = (smile_or_frown(x, str_theta) + num_theta) ** 2
+    x_star, _ = optimize_op(obj, x)
+
+    # Confirm thetas are direct inputs to the node
+    assert set(x_star.owner.inputs[1:]) == {num_theta, str_theta}
+
+    # Confirm forward pass works, no point in worrying about gradient otherwise
+    np.testing.assert_allclose(
+        x_star.eval({x: np.pi, num_theta: np.e, str_theta: ":)"}),
+        -np.e,
+    )
+    np.testing.assert_allclose(
+        x_star.eval({x: np.pi, num_theta: np.e, str_theta: ":("}),
+        np.e,
+    )
+
+    with pytest.raises(NullTypeGradError):
+        pt.grad(x_star, str_theta, disconnected_inputs="raise")
+
+    # This could be supported, but it is not right now.
+    with pytest.raises(NullTypeGradError):
+        _grad_wrt_num_theta = pt.grad(x_star, num_theta, disconnected_inputs="raise")
+    # np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":)"}), -1)
+    # np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":("}), 1)
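
As a sanity check on the -1 expected in test_optimize_grad_scalar_arg, the scalar minimization case of the implicit-function rule that these Ops rely on can be verified with plain pytensor graphs, no optimize Op involved (a small sketch; the general vector case in implict_optimization_grads uses a linear solve against the Hessian instead). At the optimum, dx*/dtheta = -(d^2 f / dx dtheta) / (d^2 f / dx^2), which is -2 / 2 = -1 for f(x, theta) = (x + theta) ** 2:

import numpy as np
import pytensor.tensor as pt

x = pt.scalar("x")
theta = pt.scalar("theta")
f = (x + theta) ** 2

df_dx = pt.grad(f, x)
d2f_dx2 = pt.grad(df_dx, x)
d2f_dxdtheta = pt.grad(df_dx, theta)

# Scalar implicit-function rule for the argmin of f with respect to theta
dxstar_dtheta = -d2f_dxdtheta / d2f_dx2

np.testing.assert_allclose(
    dxstar_dtheta.eval({x: -np.e, theta: np.e}), -1.0
)  # matches the -1 asserted in the test above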
