diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 339da84cd1..769a5dfeeb 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -13,7 +13,6 @@ import builtins import math from collections.abc import Callable -from copy import copy from itertools import chain from textwrap import dedent from typing import Any, TypeAlias @@ -779,9 +778,11 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType: This caches objects to save allocation and run time. """ - if dtype not in cache: - cache[dtype] = ScalarType(dtype=dtype) - return cache[dtype] + try: + return cache[dtype] + except KeyError: + cache[dtype] = res = ScalarType(dtype=dtype) + return res # Register C code for ViewOp on Scalars. @@ -987,25 +988,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant: def as_scalar(x: Any, name: str | None = None) -> ScalarVariable: - from pytensor.tensor.basic import scalar_from_tensor - from pytensor.tensor.type import TensorType + if isinstance(x, ScalarVariable): + return x + + if isinstance(x, Variable): + from pytensor.tensor.basic import scalar_from_tensor + from pytensor.tensor.type import TensorType + + if isinstance(x.type, TensorType) and x.type.ndim == 0: + return scalar_from_tensor(x) + else: + raise TypeError(f"Cannot convert {x} to a scalar type") if isinstance(x, Apply): + # FIXME: Why do we support calling this with Apply? + # Also, if we do, why can't we support multiple outputs? if len(x.outputs) != 1: raise ValueError( "It is ambiguous which output of a multi-output" " Op has to be fetched.", x, ) - else: - x = x.outputs[0] - if isinstance(x, Variable): - if isinstance(x, ScalarVariable): - return x - elif isinstance(x.type, TensorType) and x.type.ndim == 0: - return scalar_from_tensor(x) - else: - raise TypeError(f"Cannot convert {x} to a scalar type") + return as_scalar(x.outputs[0]) return constant(x) @@ -1329,32 +1333,26 @@ def supports_c_code(self, inputs, outputs): the given Elemwise inputs, outputs. """ - try: - tmp_s_input = [] - # To keep the same aliasing between inputs - mapping = dict() - for ii in inputs: - if ii in mapping: - tmp_s_input.append(mapping[ii]) - else: - tmp = get_scalar_type(ii.dtype).make_variable() - tmp_s_input.append(tmp) - mapping[ii] = tmp_s_input[-1] - - with config.change_flags(compute_test_value="ignore"): - s_op = self(*tmp_s_input, return_list=True) + tmp_s_input = [] + # To keep the same aliasing between inputs + mapping = {} + for ii in inputs: + if ii in mapping: + tmp_s_input.append(mapping[ii]) + else: + tmp = mapping[ii] = get_scalar_type(ii.dtype).make_variable() + tmp_s_input.append(tmp) - # if the scalar_op don't have a c implementation, - # we skip its fusion to allow the fusion of the - # other ops. + try: self.c_code( - s_op[0].owner, + self.make_node(*tmp_s_input), "test_presence_of_c_code", + # FIXME: Shouldn't this be a unique name per unique variable? ["x" for x in inputs], ["z" for z in outputs], {"fail": "%(fail)s"}, ) - except (MethodNotDefined, NotImplementedError): + except (NotImplementedError, MethodNotDefined): return False return True @@ -4094,12 +4092,12 @@ def __init__(self, *args, **kwargs): self.prepare_node_called = set() super().__init__(*args, **kwargs) - def _cleanup_graph(self, inputs, outputs): + def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True): # TODO: We could convert to TensorVariable, optimize graph, # and then convert back to ScalarVariable. # This would introduce rewrites like `log(1 + x) -> log1p`. 
- fgraph = FunctionGraph(copy(inputs), copy(outputs)) + fgraph = FunctionGraph(inputs, outputs, clone=clone) # Validate node types for node in fgraph.apply_nodes: @@ -4282,7 +4280,9 @@ class Composite(ScalarInnerGraphOp): init_param: tuple[str, ...] = ("inputs", "outputs") - def __init__(self, inputs, outputs, name="Composite"): + def __init__( + self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True + ): self.name = name self._name = None # We need to clone the graph as sometimes its nodes already @@ -4300,10 +4300,13 @@ def __init__(self, inputs, outputs, name="Composite"): if len(outputs) > 1 or not any( isinstance(var.owner.op, Composite) for var in outputs ): - # No inner Composite - inputs, outputs = clone(inputs, outputs) + if clone_graph: + inputs, outputs = clone(inputs, outputs) + else: # Inner Composite that we need to flatten + # FIXME: There could be a composite in the middle of the graph, why is this here? + # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway. assert len(outputs) == 1 # 1. Create a new graph from inputs up to the # Composite @@ -4322,7 +4325,8 @@ def __init__(self, inputs, outputs, name="Composite"): assert res[0] != inputs inputs, outputs = res[0], res2[1] - self.inputs, self.outputs = self._cleanup_graph(inputs, outputs) + # We already cloned the graph, or the user told us there was no need for it + self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False) self.inputs_type = tuple(input.type for input in self.inputs) self.outputs_type = tuple(output.type for output in self.outputs) self.nin = len(inputs) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index e2d420f361..dfcdfdd471 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -2,22 +2,21 @@ import itertools import operator import sys -from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce -from typing import TypeVar +from heapq import heapify, heappop, heappush +from operator import or_ from warnings import warn -import pytensor.scalar.basic as ps -from pytensor import clone_replace, compile from pytensor.compile.function.types import Supervisor -from pytensor.compile.mode import get_target_language +from pytensor.compile.mode import get_target_language, optdb from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates from pytensor.graph.features import ReplaceValidate from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op +from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import ( GraphRewriter, copy_stack_trace, @@ -28,13 +27,23 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.unify import OpPattern -from pytensor.graph.traversal import ancestors +from pytensor.graph.traversal import toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined -from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop -from pytensor.tensor.basic import ( - MakeVector, - constant, +from pytensor.scalar import ( + Add, + Composite, + Mul, + ScalarOp, + get_scalar_type, + transfer_type, + upcast_out, + upgrade_to_float, ) +from pytensor.scalar import cast as scalar_cast +from pytensor.scalar import constant as scalar_constant +from 
pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop +from pytensor.tensor.basic import MakeVector +from pytensor.tensor.basic import constant as tensor_constant from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import add, exp, mul from pytensor.tensor.rewriting.basic import ( @@ -280,7 +289,7 @@ def create_inplace_node(self, node, inplace_pattern): inplace_pattern = {i: o for i, [o] in inplace_pattern.items()} if hasattr(scalar_op, "make_new_inplace"): new_scalar_op = scalar_op.make_new_inplace( - ps.transfer_type( + transfer_type( *[ inplace_pattern.get(i, o.dtype) for i, o in enumerate(node.outputs) @@ -289,14 +298,14 @@ def create_inplace_node(self, node, inplace_pattern): ) else: new_scalar_op = type(scalar_op)( - ps.transfer_type( + transfer_type( *[inplace_pattern.get(i, None) for i in range(len(node.outputs))] ) ) return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs) -compile.optdb.register( +optdb.register( "inplace_elemwise", InplaceElemwiseOptimizer(), "inplace_elemwise_opt", # for historic reason @@ -428,10 +437,8 @@ def local_useless_dimshuffle_makevector(fgraph, node): @register_canonicalize @node_rewriter( [ - elemwise_of( - OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float) - ), - elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)), + elemwise_of(OpPattern(ScalarOp, output_types_preference=upgrade_to_float)), + elemwise_of(OpPattern(ScalarOp, output_types_preference=upcast_out)), ] ) def local_upcast_elemwise_constant_inputs(fgraph, node): @@ -452,7 +459,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): changed = False for i, inp in enumerate(node.inputs): if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant): - new_inputs[i] = constant(inp.data.astype(output_dtype)) + new_inputs[i] = tensor_constant(inp.data.astype(output_dtype)) changed = True if not changed: @@ -530,424 +537,344 @@ def add_requirements(self, fgraph): @staticmethod def elemwise_to_scalar(inputs, outputs): - replace_inputs = [(inp, inp.clone()) for inp in inputs] - outputs = clone_replace(outputs, replace=replace_inputs) - - inputs = [inp for _, inp in replace_inputs] - fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) - middle_inputs = [] - - scalar_inputs = [ - ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs - ] - middle_scalar_inputs = [] - - for node in fg.toposort(): - node_scalar_inputs = [] - for inp in node.inputs: - if inp in inputs: - node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) - elif inp in middle_inputs: - node_scalar_inputs.append( - middle_scalar_inputs[middle_inputs.index(inp)] + replacement = { + inp: get_scalar_type(inp.type.dtype).make_variable() for inp in inputs + } + for node in toposort(outputs, blockers=inputs): + scalar_inputs = [replacement[inp] for inp in node.inputs] + replacement.update( + dict( + zip( + node.outputs, + node.op.scalar_op.make_node(*scalar_inputs).outputs, ) - else: - new_scalar_input = ps.get_scalar_type( - inp.type.dtype - ).make_variable() - node_scalar_inputs.append(new_scalar_input) - middle_scalar_inputs.append(new_scalar_input) - middle_inputs.append(inp) - - new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) - middle_scalar_inputs.append(new_scalar_node.outputs[0]) - middle_inputs.append(node.outputs[0]) - - scalar_outputs = [ - middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs - ] - return scalar_inputs, scalar_outputs + ) + ) - def 
apply(self, fgraph): - nb_replacement = 0 + return ( + [replacement[inp] for inp in inputs], + [replacement[out] for out in outputs], + ) + def apply(self, fgraph): if fgraph.profile: validate_before = fgraph.profile.validate_time callbacks_before = fgraph.execute_callbacks_times.copy() callback_before = fgraph.execute_callbacks_time - max_operands = elemwise_max_operands_fct(None) - - def find_next_fuseable_subgraph( + def find_fuseable_subgraphs( fg: FunctionGraph, - ) -> Generator[tuple[list[Variable], list[Variable]], None, None]: - """Find all subgraphs in a FunctionGraph that can be fused together - - Yields - ------- - List of inputs and outputs that determine subgraphs which can be fused. - This generator assumes that such subgraph is replaced by a single - Elemwise Composite before being accessed again in the next iteration. - """ + ) -> Generator[tuple[tuple[Variable], tuple[Variable]], None, None]: + """Find subgraphs of Elemwise nodes that can be fused together. - FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]] - UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] - - def initialize_fuseable_mappings( - *, fg: FunctionGraph - ) -> tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]: - @cache - def elemwise_scalar_op_has_c_code(node: Apply) -> bool: - # TODO: This should not play a role in non-c backends! - if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): - return True - else: - if config.optimizer_verbose: - warn( - f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." - ) - return False - - # Fuseable nodes have to be accessed in a deterministic manner - # to ensure the rewrite remains deterministic. - # This is not a problem from unfuseable ones, as they can never - # become part of the graph. - fuseable_clients: FUSEABLE_MAPPING = defaultdict(list) - unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) - for out, clients in fg.clients.items(): - # Old FunctionGraph nodes remain in the clients dictionary - # even after they are removed by rewrites - if not clients: - continue + In general, there is no single solution. 
We try to find large subgraphs eagerly - out_maybe_fuseable = ( - out.owner - and isinstance(out.owner.op, Elemwise) - # and not isinstance(out.owner.op.scalar_op, ps.Composite) - and len(out.owner.outputs) == 1 - and elemwise_scalar_op_has_c_code(out.owner) - ) - for client, _ in clients: - if ( - out_maybe_fuseable - and isinstance(client.op, Elemwise) - # and not isinstance(client.op.scalar_op, ps.Composite) - and len(client.outputs) == 1 - and out.type.broadcastable - == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) - ): - if client not in fuseable_clients[out]: - fuseable_clients[out].append(client) - else: - unfuseable_clients[out].add(client) - - return fuseable_clients, unfuseable_clients - - def find_fuseable_subgraph( - *, - fg: FunctionGraph, - visited_nodes: set[Apply], - fuseable_clients: FUSEABLE_MAPPING, - unfuseable_clients: UNFUSEABLE_MAPPING, - ) -> tuple[list[Variable], list[Variable]]: - KT = TypeVar("KT") - VT = TypeVar("VT", list, set) - - def shallow_clone_defaultdict( - d: defaultdict[KT, VT], - ) -> defaultdict[KT, VT]: - new_dict: defaultdict[KT, VT] = defaultdict(d.default_factory) - new_dict.update({k: v.copy() for k, v in d.items()}) - return new_dict - - def variables_depend_on( - variables, depend_on, stop_search_at=None - ) -> bool: - return any( - a in depend_on - for a in ancestors(variables, blockers=stop_search_at) - ) + Any two consecutive Elemwise nodes that have the same broadcasting pattern, + and a C-implementation (historical accident that should be revisited), are potentially fuseable. - toposort = fg.toposort() - for starting_node in toposort: - if starting_node in visited_nodes: - continue + However, not all collections of fuseable pairs make a valid fused subgraph. + A valid fused subgraph must be "convex", meaning that no two nodes in the subgraph + are connected via a path that goes outside the subgraph, either because they + are connected via unfuseable nodes, or nodes that have been claimed by another fused subgraph. - starting_out = starting_node.outputs[0] - if not fuseable_clients.get(starting_out): - visited_nodes.add(starting_node) - continue + For example the subgraph add(sin(exp(x)), sum(exp(x)) cannot be fused together, + because the sum node breaks the convexity of the subgraph {exp, sin, add}. + However, we can fuse {exp, sin}, and perhaps fuse add with something else. - subgraph_inputs: list[Variable] = [] - subgraph_outputs: list[Variable] = [] - unfuseable_clients_subgraph: set[Variable] = set() + This function yields subgraph in reverse topological order so they can be safely replaced one at a time + """ - # Shallow cloning of maps so that they can be manipulated in place - fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients) - unfuseable_clients_clone = shallow_clone_defaultdict( - unfuseable_clients + @cache + def elemwise_scalar_op_has_c_code( + node: Apply, optimizer_verbose=config.optimizer_verbose + ) -> bool: + # TODO: This should not play a role in non-c backends! + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + elif optimizer_verbose: + warn( + f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." 
) + return False + + # Create a map from node to a set of fuseable client (successor) nodes + # A node and a client are fuseable if they are both single output Elemwise + # (with C-implementation) and have the same output broadcastable pattern + # Nodes that have no fuseable clients are not included + fuseable_clients: dict[Apply, set[Apply]] = {} + # We also create a set with candidate nodes from which to start a subgraph expansion + # These are Single output Elemwise nodes (with C-implementation) that may or not + # have fuseable ancestors/clients at the start. + candidate_starting_nodes = set() + fg_clients = fg.clients + for out, clients_and_indices in fg_clients.items(): + out_node = out.owner + + if not ( + out_node is not None + and len(out_node.outputs) == 1 + and isinstance(out_node.op, Elemwise) + and elemwise_scalar_op_has_c_code(out_node) + ): + continue - fuseable_nodes_to_visit = deque([starting_node]) - - # We now try to expand as much as possible towards the potentially - # fuseable clients and ancestors to detect the largest possible - # subgraph that can be Composed together into a single `Op`. The - # largest issue to watch out is for cyclical dependencies, where - # some inputs or clients may depend on other nodes of the same - # subgraph via a path that cannot be included in the Composite - # (unfuseable) - while fuseable_nodes_to_visit: - next_node = fuseable_nodes_to_visit.popleft() - visited_nodes.add(next_node) - next_out = next_node.outputs[0] - - # If the output variable of next_node has no fuseable clients - # or has unfuseable clients, then next_node must become an output - # if it is to be fused. - must_become_output = ( - next_out not in fuseable_clients_temp - or next_out in unfuseable_clients_clone - ) - - # We have backtracked to this node, and it may no longer be a viable output, - # so we remove it and check again as if we had never seen this node - if must_become_output and next_out in subgraph_outputs: - subgraph_outputs.remove(next_out) - - required_unfuseable_inputs = [ - inp - for inp in next_node.inputs - if next_node in unfuseable_clients_clone.get(inp, ()) - ] - new_required_unfuseable_inputs = [ - inp - for inp in required_unfuseable_inputs - if inp not in subgraph_inputs - ] - - must_backtrack = False - if new_required_unfuseable_inputs and subgraph_outputs: - # We need to check that any new inputs required by this node - # do not depend on other outputs of the current subgraph, - # via an unfuseable path. - if variables_depend_on( - [next_out], - depend_on=unfuseable_clients_subgraph, - stop_search_at=subgraph_outputs, - ): - must_backtrack = True - - if not must_backtrack: - implied_unfuseable_clients = { - c - for client in unfuseable_clients_clone.get(next_out, ()) - if not isinstance(client.op, Output) - for c in client.outputs - } - - new_implied_unfuseable_clients = ( - implied_unfuseable_clients - unfuseable_clients_subgraph - ) + candidate_starting_nodes.add(out_node) + out_bcast = out.type.broadcastable + out_fuseable_clients = { + client + for client, _ in clients_and_indices + if ( + len(client.outputs) == 1 + and isinstance(client.op, Elemwise) + and out_bcast == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ) + } + if out_fuseable_clients: + fuseable_clients[out_node] = out_fuseable_clients + + if not candidate_starting_nodes: + return None + + # To enable fast dependency queries, we create a bitset of ancestors for each node. 
+ # Each node is first represented by a bit flag of its position in the toposort + # This can be achieved with python integers, via 1 << toposort_idx (equivalent to slower 2 ** toposort_idx) + # The ancestors bitsets of each node are obtained by bitwise OR of the ancestor bitsets + # of each of the nodes' inputs, and the bit flag of the node itself. + # + # Example: With three variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c, + # the nodes bit flags would be {A: 0b001, B: 0b010, C: 0b100} (integers {A: 1, B: 2, C: 4}) + # and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111} (integers {A: 1, B: 3, C: 7}) + # + # This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND + # For example, to ask if A is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[A] != 0` + # We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do + # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` + nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} + # Root variables have `None` as owner, which we can handle with a bitset of 0 + ancestors_bitset = {None: 0} + for node, node_bitflag in nodes_bitflags.items(): + # The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag + ancestors_bitset[node] = reduce( + or_, + (ancestors_bitset[inp.owner] for inp in node.inputs), + node_bitflag, + ) + # Handle root and leaf nodes gracefully + # We do it after the ancestors_bitset are built to simplify the previous loop. + # Root variables have `None` as owner, which we can handle with a bitflag of 0 + nodes_bitflags[None] = 0 + # Nothing ever depends on the special Output nodes, so just use a new bit for all of them + out_bitflag = 1 << len(nodes_bitflags) + for out in fg.outputs: + for client, _ in fg_clients[out]: + if isinstance(client.op, Output): + nodes_bitflags[client] = out_bitflag + + # Start main loop to find the collection of fuseable subgraphs + # We store the collection in `sorted_subgraphs`, in reverse topological order + sorted_subgraphs: list[ + tuple[int, tuple[tuple[Variable], tuple[Variable]]] + ] = [] + # Keep a bitset of nodes that have been claimed by subgraphs + all_subgraphs_bitset = 0 + # Start exploring in reverse topological order from candidate sink nodes + # Sink nodes are nodes that don't have any potential fuseable clients + for starting_node, starting_bitflag in reversed(nodes_bitflags.items()): + if ( + starting_bitflag & all_subgraphs_bitset + or starting_node not in candidate_starting_nodes + or starting_node in fuseable_clients + ): + continue + + # We use an ordered queue to control the direction in which we expand the subgraph + # For simplicity, we always want to visit ancestors before clients + # For ancestors, we want to visit the later nodes first (those that have more dependencies) + # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) + # To achieve this we use the bitflag as the sorting key (which encodes the topological order) + # and negate it for ancestors. 
+ fuseables_nodes_queue = [(-starting_bitflag, starting_node)] + heapify(fuseables_nodes_queue) + + # We keep 3 bitsets during the exploration of a new subgraph: + # - the nodes that are part of the subgraph + # - the unfuseable ancestors of the subgraph (i.e., ancestors that are not fuseable with a node in the subgraph) + # - the unfuseable clients of the subgraph (i.e., clients that are not fuseable with a node in the subgraph) + # Whenever we visit a candidate node, we check if the subgraph's unfuseable ancestors depend on it, + # or if it depends on one of the subgraphs' unfuseable client, in which case we can't add it. + # If we can add it, we then add its unfuseable ancestors/clients to the respective bitsets + # and add its fuseable ancestors/clients to the queue to explore later. + # To work correctly, we must visit candidate subgraph nodes in the order described by the queue above. + # Otherwise, we would need to perform more complex dependency checks in every iteration and/or backtrack. + subgraph_nodes = [] + subgraph_bitset = 0 + unfuseable_ancestors_bitset = 0 + unfuseable_clients_bitset = 0 + + while fuseables_nodes_queue: + node_bitflag, node = heappop(fuseables_nodes_queue) + is_ancestor = node_bitflag < 0 + if is_ancestor: + node_bitflag = -node_bitflag + + if node_bitflag & subgraph_bitset: + # Already part of the subgraph + continue - if new_implied_unfuseable_clients and subgraph_inputs: - # We need to check that any inputs of the current subgraph - # do not depend on other clients of this node, - # via an unfuseable path. - if variables_depend_on( - subgraph_inputs, - depend_on=new_implied_unfuseable_clients, - ): - must_backtrack = True - - if must_backtrack: - for inp in next_node.inputs: - if ( - inp.owner in visited_nodes - # next_node could have the same input repeated - and next_node in fuseable_clients_temp[inp] - ): - fuseable_clients_temp[inp].remove(next_node) - unfuseable_clients_clone[inp].add(next_node) - # This input must become an output of the subgraph, - # because it can't be merged with next_node. - # We will revisit it to make sure this is safe. - fuseable_nodes_to_visit.appendleft(inp.owner) - - for client in fuseable_clients_temp[next_out]: - if client in visited_nodes: - fuseable_clients_temp[next_out].remove(client) - unfuseable_clients_clone[next_out].add(client) - # next_out must become an input of the subgraph. - # We will revisit any of its clients currently - # in the subgraph to make sure this is safe. - fuseable_nodes_to_visit.appendleft(client) - - # Revisit node at a later time - visited_nodes.remove(next_node) + if is_ancestor: + if node_bitflag & unfuseable_ancestors_bitset: + # An unfuseable ancestor of the subgraph depends on this node, can't fuse continue + elif ancestors_bitset[node] & unfuseable_clients_bitset: + # This node depends on an unfuseable client of the subgraph, can't fuse + continue - # Adding next_node to subgraph does not result in any - # immediate dependency problems. Update subgraph - # mappings as if it next_node was part of it. 
- # Useless inputs will be removed by the useless Composite rewrite - for inp in new_required_unfuseable_inputs: - if inp not in subgraph_inputs: - subgraph_inputs.append(inp) - - if must_become_output: - subgraph_outputs.append(next_out) - unfuseable_clients_subgraph.update( - new_implied_unfuseable_clients + # Add node to subgraph + subgraph_nodes.append(node) + subgraph_bitset |= node_bitflag + + # Expand through ancestors and client nodes + # A node can either be: + # - already part of the subgraph (skip) + # - fuseable (add to queue) + # - unfuseable (add to respective unfuseable bitset) + for inp in node.inputs: + ancestor_node = inp.owner + ancestor_bitflag = nodes_bitflags[ancestor_node] + if ancestor_bitflag & subgraph_bitset: + continue + if node in fuseable_clients.get(ancestor_node, ()): + heappush( + fuseables_nodes_queue, + (-ancestor_bitflag, ancestor_node), ) + else: + # If the node is not in the ancestor's fuseable clients set, it's not fuseable with it, + # nor with any of the ancestor's ancestors + unfuseable_ancestors_bitset |= ancestors_bitset[ + ancestor_node + ] + + next_fuseable_clients = fuseable_clients.get(node, ()) + for client, _ in fg_clients[node.outputs[0]]: + client_bitflag = nodes_bitflags[client] + if client_bitflag & subgraph_bitset: + continue + if client in next_fuseable_clients: + heappush(fuseables_nodes_queue, (client_bitflag, client)) + else: + # If a client is not in the node's fuseable clients set, it's not fuseable with it, + # nor with any of its clients. But we don't need to keep track of those as any downstream + # client we may consider later will also depend on this unfuseable client and be rejected. + unfuseable_clients_bitset |= client_bitflag - # Expand through unvisited fuseable ancestors - for inp in sorted( - ( - inp - for inp in next_node.inputs - if ( - inp not in required_unfuseable_inputs - and inp.owner not in visited_nodes - ) - ), - key=lambda inp: toposort.index(inp.owner), - reverse=True, - ): - fuseable_nodes_to_visit.appendleft(inp.owner) - - # Expand through unvisited fuseable clients - for next_node in sorted( - ( - node - for node in fuseable_clients_temp.get(next_out, ()) - if node not in visited_nodes - ), - key=lambda node: toposort.index(node), - ): - fuseable_nodes_to_visit.append(next_node) - - # Don't return if final subgraph is just the original Elemwise - if len(subgraph_outputs) == 1 and set( - subgraph_outputs[0].owner.inputs - ) == set(subgraph_inputs): - # Update global fuseable mappings - # No input was actually fuseable - for inp in starting_node.inputs: - if starting_node in fuseable_clients.get(inp, ()): - fuseable_clients[inp].remove(starting_node) - unfuseable_clients[inp].add(starting_node) - # No client was actually fuseable - unfuseable_clients[starting_out].update( - fuseable_clients.pop(starting_out, ()) - ) - continue + # Finished exploring this subgraph + all_subgraphs_bitset |= subgraph_bitset + + if subgraph_bitset == starting_bitflag: + # We ended where we started, no fusion possible + continue - return subgraph_inputs, subgraph_outputs - raise ValueError - - def update_fuseable_mappings_after_fg_replace( - *, - fg: FunctionGraph, - visited_nodes: set[Apply], - fuseable_clients: FUSEABLE_MAPPING, - unfuseable_clients: UNFUSEABLE_MAPPING, - starting_nodes: set[Apply], - ) -> None: - # Find new composite node and dropped intermediate nodes - # by comparing the current fg.apply nodes with the cached - # original nodes - next_nodes = fg.apply_nodes - (new_composite_node,) = next_nodes - starting_nodes - 
dropped_nodes = starting_nodes - next_nodes - - # Remove intermediate Composite nodes from mappings - for dropped_node in dropped_nodes: - (dropped_out,) = dropped_node.outputs - fuseable_clients.pop(dropped_out, None) - unfuseable_clients.pop(dropped_out, None) - visited_nodes.remove(dropped_node) - - # Update fuseable information for subgraph inputs + # Find out the actual inputs/outputs variables of the subgraph + not_subgraph_bitset = ~subgraph_bitset + # Inputs are variables whose nodes are not part of the subgraph (including root variables without nodes) + # Use a dict to deduplicate while preserving order + subgraph_inputs = tuple( + dict.fromkeys( + inp + for node in subgraph_nodes + for inp in node.inputs + if (inp_node := inp.owner) is None + or nodes_bitflags[inp_node] & not_subgraph_bitset + ) + ) + # Outputs are variables with client nodes that are not part of the subgraph (including special fgraph output nodes) + # Outputs are unique, no need to deduplicate + subgraph_outputs = tuple( + node.outputs[0] + for node in subgraph_nodes + if any( + nodes_bitflags[client] & not_subgraph_bitset + for client, _ in fg_clients[node.outputs[0]] + ) + ) + + # Update fuseable clients mapping for subgraph inputs and outputs + # Inputs cannot be fused with nodes in the subgraph for inp in subgraph_inputs: - if inp in fuseable_clients: - new_fuseable_clients = [ - client - for client in fuseable_clients[inp] - if client not in dropped_nodes - ] - if new_fuseable_clients: - fuseable_clients[inp] = new_fuseable_clients - else: - fuseable_clients.pop(inp) - unfuseable_clients[inp] = ( - unfuseable_clients[inp] - dropped_nodes - ) | {new_composite_node} - - # Update fuseable information for subgraph outputs - for out in new_composite_node.outputs: - unfuseable_clients[out] = {client for client, _ in fg.clients[out]} - - visited_nodes.add(new_composite_node) - return - - # We start by creating two maps, 1) from each node to each potentially - # fuseable client (both nodes must be single output Elemwise with same - # broadcast type) and 2) from each node to each certainly unfuseable - # client (those that don't fit into 1)) - fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) - visited_nodes: set[Apply] = set() - while True: - starting_nodes = fg.apply_nodes.copy() - try: - subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( - fg=fg, - visited_nodes=visited_nodes, - fuseable_clients=fuseable_clients, - unfuseable_clients=unfuseable_clients, + if (inp_node := inp.owner) is not None and ( + inp_fuseable_clients := fuseable_clients.get(inp_node) + ): + inp_fuseable_clients.difference_update(subgraph_nodes) + # If there are no fuseable_clients left for this input delete it's entry + if not inp_fuseable_clients: + del fuseable_clients[inp_node] + # Outputs cannot be fused with anything else + for out in subgraph_outputs: + fuseable_clients.pop(out.owner, None) + + # Add new subgraph to sorted_subgraphs + # Because we start from sink nodes in reverse topological order, most times new subgraphs + # don't depend on previous subgraphs, so we can just append them at the end. 
+ if not (unfuseable_ancestors_bitset & all_subgraphs_bitset): + # That's the case here + # None of the unfuseable_ancestors (i.e., the ancestors) are present in the previously collected subgraphs + sorted_subgraphs.append( + (subgraph_bitset, (subgraph_inputs, subgraph_outputs)) ) - except ValueError: - return else: - # The caller is now expected to update fg in place, - # by replacing the subgraph with a Composite Op - yield subgraph_inputs, subgraph_outputs - - # This is where we avoid repeated work by using a stateful - # generator. For large models (as in `TestFusion.test_big_fusion`) - # this can provide huge speedups - update_fuseable_mappings_after_fg_replace( - fg=fg, - visited_nodes=visited_nodes, - fuseable_clients=fuseable_clients, - unfuseable_clients=unfuseable_clients, - starting_nodes=starting_nodes, + # But not here, so we need to find the right position for insertion. + # We iterate through the previous subgraphs in topological order (reverse of the stored order). + # We cumulatively exclude each subgraph_bitset and perform the same dependency check again. + # The (index + 1) of the first iteration where the check passes is the correct insertion position. + remaining_subgraphs_bitset = all_subgraphs_bitset + for index, (other_subgraph_bitset, _) in enumerate( + reversed(sorted_subgraphs) + ): + # Exclude subgraph bitset + remaining_subgraphs_bitset &= ~other_subgraph_bitset + if not ( + unfuseable_ancestors_bitset & remaining_subgraphs_bitset + ): + break # bingo + sorted_subgraphs.insert( + -(index + 1), + (subgraph_bitset, (subgraph_inputs, subgraph_outputs)), ) - for inputs, outputs in find_next_fuseable_subgraph(fgraph): + # yield from sorted_subgraphs, discarding the subgraph_bitset + yield from (io for _, io in sorted_subgraphs) + + max_operands = elemwise_max_operands_fct(None) + reason = self.__class__.__name__ + nb_fused = 0 + nb_replacement = 0 + for inputs, outputs in find_fuseable_subgraphs(fgraph): if (len(inputs) + len(outputs)) > max_operands: warn( - "Loop fusion failed because the resulting node would exceed " - "the kernel argument limit." + "Loop fusion failed because the resulting node would exceed the kernel argument limit."
) - break + continue scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) - composite_outputs = Elemwise(ps.Composite(scalar_inputs, scalar_outputs))( - *inputs - ) - if not isinstance(composite_outputs, list): - composite_outputs = [composite_outputs] - for old_out, composite_out in zip(outputs, composite_outputs, strict=True): - if old_out.name: - composite_out.name = old_out.name - + composite_outputs = Elemwise( + # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables + Composite(scalar_inputs, scalar_outputs, clone_graph=False) + )(*inputs, return_list=True) + assert len(outputs) == len(composite_outputs) + for old_out, composite_out in zip(outputs, composite_outputs): + # Preserve any names on the original outputs + if old_name := old_out.name: + composite_out.name = old_name + + starting_nodes = len(fgraph.apply_nodes) fgraph.replace_all_validate( - list(zip(outputs, composite_outputs, strict=True)), - reason=self.__class__.__name__, + tuple(zip(outputs, composite_outputs)), + reason=reason, ) - nb_replacement += 1 + nb_fused += 1 + nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1 if fgraph.profile: validate_time = fgraph.profile.validate_time - validate_before @@ -965,7 +892,7 @@ def update_fuseable_mappings_after_fg_replace( return ( self, - 1, # nb_iter + nb_fused, nb_replacement, 0, # nb_inconsintency_replace validate_time, @@ -978,7 +905,7 @@ def update_fuseable_mappings_after_fg_replace( def print_profile(stream, prof, level=0): blanc = " " * level print(blanc, "FusionOptimizer", file=stream) - print(blanc, " nb_iter", prof[1], file=stream) + print(blanc, " nb_fused", prof[1], file=stream) print(blanc, " nb_replacement", prof[2], file=stream) print(blanc, " nb_inconsistency_replace", prof[3], file=stream) print(blanc, " validate_time", prof[4], file=stream) @@ -993,7 +920,7 @@ def print_profile(stream, prof, level=0): @register_canonicalize @register_specialize -@node_rewriter([elemwise_of(ps.Composite)]) +@node_rewriter([elemwise_of(Composite)]) def local_useless_composite_outputs(fgraph, node): """Remove inputs and outputs of Composite Ops that are not used anywhere.""" comp = node.op.scalar_op @@ -1014,7 +941,7 @@ def local_useless_composite_outputs(fgraph, node): node.outputs ): used_inputs = [node.inputs[i] for i in used_inputs_idxs] - c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) + c = Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) e = Elemwise(scalar_op=c)(*used_inputs, return_list=True) return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True)) @@ -1028,7 +955,7 @@ def local_careduce_fusion(fgraph, node): # FIXME: This check is needed because of the faulty logic in the FIXME below! 
# Right now, rewrite only works for `Sum`/`Prod` - if not isinstance(car_scalar_op, ps.Add | ps.Mul): + if not isinstance(car_scalar_op, Add | Mul): return None elm_node = car_input.owner @@ -1072,19 +999,19 @@ def local_careduce_fusion(fgraph, node): car_acc_dtype = node.op.acc_dtype scalar_elm_inputs = [ - ps.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs + get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs ] elm_output = elm_scalar_op(*scalar_elm_inputs) # This input represents the previous value in the `CAReduce` binary reduction - carried_car_input = ps.get_scalar_type(car_acc_dtype).make_variable() + carried_car_input = get_scalar_type(car_acc_dtype).make_variable() scalar_fused_output = car_scalar_op(carried_car_input, elm_output) if scalar_fused_output.type.dtype != car_acc_dtype: - scalar_fused_output = ps.cast(scalar_fused_output, car_acc_dtype) + scalar_fused_output = scalar_cast(scalar_fused_output, car_acc_dtype) - fused_scalar_op = ps.Composite( + fused_scalar_op = Composite( inputs=[carried_car_input, *scalar_elm_inputs], outputs=[scalar_fused_output] ) @@ -1105,7 +1032,7 @@ def local_careduce_fusion(fgraph, node): return [new_car_op(*elm_inputs)] -@node_rewriter([elemwise_of(ps.Composite)]) +@node_rewriter([elemwise_of(Composite)]) def local_inline_composite_constants(fgraph, node): """Inline scalar constants in Composite graphs.""" composite_op = node.op.scalar_op @@ -1121,7 +1048,7 @@ def local_inline_composite_constants(fgraph, node): and "complex" not in outer_inp.type.dtype ): if outer_inp.unique_value is not None: - inner_replacements[inner_inp] = ps.constant( + inner_replacements[inner_inp] = scalar_constant( outer_inp.unique_value, dtype=inner_inp.dtype ) continue @@ -1134,7 +1061,7 @@ def local_inline_composite_constants(fgraph, node): new_inner_outs = clone_replace( composite_op.fgraph.outputs, replace=inner_replacements ) - new_composite_op = ps.Composite(new_inner_inputs, new_inner_outs) + new_composite_op = Composite(new_inner_inputs, new_inner_outs) new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs # Some of the inlined constants were broadcasting the output shape @@ -1175,7 +1102,7 @@ def constant_fold_branches_of_add_mul(fgraph, node): if other_inps: python_op = operator.mul if node.op == mul else operator.add folded_inputs = [reference_inp, *other_inps] - new_inp = constant( + new_inp = tensor_constant( reduce(python_op, (const.data for const in folded_inputs)) ) new_constants = [ @@ -1199,7 +1126,7 @@ def constant_fold_branches_of_add_mul(fgraph, node): add_mul_fusion_seqopt = SequenceDB() -compile.optdb.register( +optdb.register( "add_mul_fusion", add_mul_fusion_seqopt, "fast_run", @@ -1220,7 +1147,7 @@ def constant_fold_branches_of_add_mul(fgraph, node): # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites) fuse_seqopt = SequenceDB() -compile.optdb.register( +optdb.register( "elemwise_fusion", fuse_seqopt, "fast_run", @@ -1351,7 +1278,7 @@ def split_2f1grad_loop(fgraph, node): return replacements -compile.optdb["py_only"].register( +optdb["py_only"].register( "split_2f1grad_loop", split_2f1grad_loop, "fast_compile", diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index c23d0ac23a..523effb1d1 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -273,7 +273,8 @@ def my_init(dtype="float64", num=0): fwx = fw + fx ftanx = tan(fx) - def large_fuseable_graph(self, n): + 
@staticmethod + def large_fuseable_graph(n): factors = [] sd = dscalar() means = dvector() @@ -296,6 +297,48 @@ def large_fuseable_graph(self, n): dlogp = [pytensor.grad(logp, v) for v in vars] return vars, dlogp + @staticmethod + def deep_small_kernels(n): + x = pt.matrix("x") + out = x + for _ in range(n): + out = pt.sin(out.T) + pt.cos(out) + + return [x], [out] + + @staticmethod + def test_diamond_graph(): + a = pt.matrix("a") + b = pt.exp(a) + c = pt.log(b) + d = pt.sin(c) + e = c + d + + fg = FunctionGraph([a], [e], clone=False) + _, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg) + assert nb_fused == 1 + assert nb_replacement == 4 + + def test_expansion_order(self): + # This test is designed to fail if we don't use the right expansion order in the current implementation + # It may be considered irrelevant if the algorithm changes and this is no longer a concern. + # In that case the test can be tweaked or removed + a = pt.vector("a") + b = pt.exp(a) + # Unique creates an unfuesable path between b and d/e + c = pt.unique(b) + d = pt.log(c) + # The critical aspect of the current implementation, is that we must visit d before c, + # so we learn about the unfuseable path by the time we visit c + e1 = b + d + e2 = d + b # test both orders + + fg = FunctionGraph([a], [e1, e2], clone=False) + _, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg) + fg.dprint() + assert nb_fused == 1 + assert nb_replacement == 3 + @pytest.mark.parametrize( "case", [ @@ -1347,16 +1390,26 @@ def test_eval_benchmark(self, benchmark): benchmark(func) @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") - def test_rewrite_benchmark(self, benchmark): - inps, outs = self.large_fuseable_graph(n=25) + @pytest.mark.parametrize( + "graph_fn, n, expected_n_repl", + [ + ("deep_small_kernels", 20, (20, 60)), + ("large_fuseable_graph", 25, (128, 876)), + ], + ) + def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): + inps, outs = getattr(self, graph_fn)(n) fg = FunctionGraph(inps, outs) opt = FusionOptimizer() def rewrite_func(): - nb_replacement = opt.apply(fg.clone())[2] - return nb_replacement + fg_clone = fg.clone() + _, nb_fused, nb_replacement, *_ = opt.apply(fg_clone) + # fg_clone.dprint() + return nb_fused, nb_replacement - assert benchmark(rewrite_func) == 103 + assert rewrite_func() == expected_n_repl + benchmark.pedantic(rewrite_func, rounds=7, iterations=5) def test_no_warning_from_old_client(self): # There used to be a warning issued when creating fuseable mapping diff --git a/tests/test_printing.py b/tests/test_printing.py index 95c3c938cf..dbad8c063b 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -301,7 +301,8 @@ def test_debugprint(): Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv" exp_res = dedent( r""" - Composite{(i2 + (i0 - i1))} 4 + Composite{(i0 + (i1 - i2))} 4 + ├─ A ├─ ExpandDims{axis=0} v={0: [0]} 3 """ f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2" @@ -313,17 +314,16 @@ def test_debugprint(): │ ├─ B │ ├─ │ └─ 0.0 - ├─ D - └─ A + └─ D Inner graphs: - Composite{(i2 + (i0 - i1))} + Composite{(i0 + (i1 - i2))} ← add 'o0' - ├─ i2 - └─ sub ├─ i0 - └─ i1 + └─ sub + ├─ i1 + └─ i2 """ ).lstrip()
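The ancestor-bitset bookkeeping introduced in FusionOptimizer above is the core trick of the new algorithm, so here is a minimal, self-contained sketch of the same idea in plain Python (the toy DAG, the `parents` mapping, and all names are invented for illustration and are not part of the patch):

# Sketch of the ancestor-bitset technique (hypothetical toy example)
from functools import reduce
from operator import or_

# Toy DAG described as node -> list of parent nodes: A -> B -> C, and A -> D
parents = {"A": [], "B": ["A"], "C": ["B"], "D": ["A"]}
toposort = ["A", "B", "C", "D"]  # any valid topological order

# Assign one bit per node, following the topological order
bitflag = {node: 1 << i for i, node in enumerate(toposort)}

# A node's ancestor bitset is the union of its parents' ancestor bitsets plus its own bit
ancestors = {}
for node in toposort:
    ancestors[node] = reduce(or_, (ancestors[p] for p in parents[node]), bitflag[node])

# Dependency queries then become single bitwise ANDs:
assert ancestors["C"] & bitflag["A"]  # A is an ancestor of C
assert not (ancestors["D"] & bitflag["B"])  # B is not an ancestor of D
# Several nodes can be tested at once: is B or D an ancestor of C?
assert ancestors["C"] & (bitflag["B"] | bitflag["D"])

This encoding is what lets the rewrite accept or reject a candidate node with a single AND against the subgraph's accumulated unfuseable-ancestor/client bitsets, instead of re-walking the graph with `ancestors(...)` as the previous implementation did.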
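To eyeball the effect of the rewrite end to end, a quick check along these lines should work (variable names are arbitrary; the exact printout depends on the compilation mode and on a C compiler being available, since the rewrite currently skips scalar ops without a C implementation):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(pt.sin(x) + pt.cos(x))

# Under the default FAST_RUN mode the three Elemwise nodes should end up
# fused into a single Elemwise{Composite{...}} node in the compiled graph.
fn = pytensor.function([x], y)
pytensor.dprint(fn)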