Skip to content

Commit 90ab712

Browse files
committed
.speedup composite rewrites
This should be cleaned up into nicer git commits
1 parent 272816e commit 90ab712

File tree

2 files changed

+87
-47
lines changed

2 files changed

+87
-47
lines changed

pytensor/scalar/basic.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,9 @@ def make_node(self, *inputs):
12551255
f"Wrong number of inputs for {self}.make_node "
12561256
f"(got {len(inputs)}({inputs}), expected {self.nin})"
12571257
)
1258-
inputs = [as_scalar(input) for input in inputs]
1258+
inputs = [
1259+
inp if isinstance(inp, ScalarVariable) else as_scalar(inp) for inp in inputs
1260+
]
12591261
outputs = [t() for t in self.output_types([input.type for input in inputs])]
12601262
if len(outputs) != self.nout:
12611263
inputs_str = (", ".join(str(input) for input in inputs),)
@@ -4294,7 +4296,13 @@ class Composite(ScalarInnerGraphOp):
42944296
init_param: tuple[str, ...] = ("inputs", "outputs")
42954297

42964298
def __init__(
4297-
self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True
4299+
self,
4300+
inputs,
4301+
outputs,
4302+
name="Composite",
4303+
clone_graph: builtins.bool = True,
4304+
cleanup_graph: builtins.bool = True,
4305+
output_types_preference=None,
42984306
):
42994307
self.name = name
43004308
self._name = None
@@ -4324,7 +4332,9 @@ def __init__(
43244332
# 1. Create a new graph from inputs up to the
43254333
# Composite
43264334
res = pytensor.compile.rebuild_collect_shared(
4327-
inputs=inputs, outputs=outputs[0].owner.inputs, copy_inputs_over=False
4335+
inputs=inputs,
4336+
outputs=outputs[0].owner.inputs,
4337+
copy_inputs_over=False,
43284338
) # Clone also the inputs
43294339
# 2. We continue this partial clone with the graph in
43304340
# the inner Composite
@@ -4338,36 +4348,42 @@ def __init__(
43384348
assert res[0] != inputs
43394349
inputs, outputs = res[0], res2[1]
43404350

4341-
# We already cloned the graph, or the user told us there was no need for it
4342-
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False)
4351+
if cleanup_graph:
4352+
# We already cloned the graph, or the user told us there was no need for it
4353+
self.inputs, self.outputs = self._cleanup_graph(
4354+
inputs, outputs, clone=False
4355+
)
4356+
else:
4357+
self.inputs, self.outputs = inputs, outputs
43434358
self.inputs_type = tuple(input.type for input in self.inputs)
43444359
self.outputs_type = tuple(output.type for output in self.outputs)
43454360
self.nin = len(inputs)
43464361
self.nout = len(outputs)
4347-
super().__init__()
4362+
super().__init__(output_types_preference=output_types_preference)
43484363

43494364
def __str__(self):
    """Return the display name of this Composite, caching it in ``self._name``.

    Graphs with more than one output or more than ten apply nodes are
    abbreviated to ``Composite{...}``. Otherwise the inner variables are
    renamed (``i<n>`` for inputs, ``o<n>`` for outputs, ``t<n>`` for
    non-constant intermediates that have more than one client) and the
    pretty-printed outputs are embedded in the name.
    """
    cached = self._name
    if cached is not None:
        return cached

    graph = self.fgraph

    # Large or multi-output graphs get an abbreviated name; the inner
    # variables are left untouched in that case.
    if len(graph.outputs) > 1 or len(graph.apply_nodes) > 10:
        self._name = "Composite{...}"
        return self._name

    # Rename internal variables so that the pretty-printed form is readable
    for idx, var in enumerate(graph.inputs):
        var.name = f"i{idx}"
    for idx, var in enumerate(graph.outputs):
        var.name = f"o{idx}"

    boundary = set(graph.inputs + graph.outputs)
    for idx, var in enumerate(graph.variables):
        if isinstance(var, Constant) or var in boundary:
            continue
        # Only intermediates that are reused get an explicit temporary name
        if len(graph.clients[var]) > 1:
            var.name = f"t{idx}"

    rendered = ", ".join(pprint(out) for out in graph.outputs)
    self._name = f"Composite{{{rendered}}}"
    return self._name
@@ -4380,12 +4396,16 @@ def make_new_inplace(self, output_types_preference=None, name=None):
43804396
43814397
"""
43824398
d = {k: getattr(self, k) for k in self.init_param}
4383-
out = self.__class__(**d)
4384-
if name:
4385-
out.name = name
4386-
else:
4387-
name = out.name
4388-
super(Composite, out).__init__(output_types_preference, name)
4399+
out = type(self)(
4400+
**d,
4401+
cleanup_graph=False,
4402+
clone_graph=False,
4403+
output_types_preference=output_types_preference,
4404+
name=name or self.name,
4405+
)
4406+
# No need to recompute _c_code and nodenames if they were already computed (which is the case if the hash of the Op was requested)
4407+
out._c_code = self._c_code
4408+
out.nodenames = self.nodenames
43894409
return out
43904410

43914411
@property
@@ -4452,9 +4472,10 @@ def c_code_template(self):
44524472
fg = self.fgraph
44534473
subd = {e: f"%(i{i})s" for i, e in enumerate(fg.inputs)}
44544474

4475+
inputs_set = frozenset(fg.inputs)
44554476
for var in fg.variables:
44564477
if var.owner is None:
4457-
if var not in fg.inputs:
4478+
if var not in inputs_set:
44584479
# This is an orphan
44594480
if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
44604481
subd[var] = f"({var.type.c_literal(var.data)})"

pytensor/tensor/rewriting/elemwise.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828
)
2929
from pytensor.graph.rewriting.db import SequenceDB
3030
from pytensor.graph.rewriting.unify import OpPattern
31-
from pytensor.graph.traversal import toposort
31+
from pytensor.graph.traversal import graph_inputs, toposort
3232
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
33+
from pytensor.scalar import ScalarConstant
3334
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
3435
from pytensor.tensor.basic import (
3536
MakeVector,
@@ -885,26 +886,44 @@ def print_profile(stream, prof, level=0):
885886
def local_useless_composite_outputs(fgraph, node):
    """Remove inputs and outputs of Composite Ops that are not used anywhere.

    Parameters
    ----------
    fgraph
        The function graph being rewritten; ``fgraph.clients`` is consulted to
        decide which outer outputs of ``node`` are still in use.
    node
        An apply node whose ``op.scalar_op`` exposes ``inputs`` and ``outputs``
        (the inner Composite graph).

    Returns
    -------
    dict or None
        ``None`` when every input and every output is still needed. Otherwise
        a mapping from each used outer output of ``node`` to the matching
        output of a new, smaller Elemwise Composite.
    """
    comp = node.op.scalar_op

    clients = fgraph.clients
    outer_inputs, outer_outputs = node.inputs, node.outputs
    inner_inputs, inner_outputs = comp.inputs, comp.outputs

    # Track used outputs by position (not in a set): the order must be stable
    # so the new Composite's outputs line up with the outer outputs they
    # replace, and an inner variable appearing twice must stay two outputs.
    used_outputs_idxs = [
        i for i, outer_out in enumerate(outer_outputs) if clients[outer_out]
    ]
    used_inner_outputs = [inner_outputs[i] for i in used_outputs_idxs]

    # Root variables the used outputs actually depend on (constants excluded).
    reachable_roots = {
        inner_inp
        for inner_inp in graph_inputs(used_inner_outputs)
        if not isinstance(inner_inp, ScalarConstant)
    }
    used_inputs_idxs = [
        i for i, inner_inp in enumerate(inner_inputs) if inner_inp in reachable_roots
    ]

    # Nothing to prune only when BOTH all inputs and all outputs are in use.
    if len(used_inputs_idxs) == len(outer_inputs) and len(used_outputs_idxs) == len(
        outer_outputs
    ):
        return None

    used_inner_inputs = [inner_inputs[i] for i in used_inputs_idxs]
    used_outer_inputs = [outer_inputs[i] for i in used_inputs_idxs]

    new_comp = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
    new_outer_outputs = Elemwise(scalar_op=new_comp)(
        *used_outer_inputs, return_list=True
    )

    used_outer_outputs = [outer_outputs[i] for i in used_outputs_idxs]
    return dict(zip(used_outer_outputs, new_outer_outputs, strict=True))
908927

909928

910929
@node_rewriter([CAReduce])

0 commit comments

Comments
 (0)