@@ -1255,7 +1255,9 @@ def make_node(self, *inputs):
12551255 f"Wrong number of inputs for { self } .make_node "
12561256 f"(got { len (inputs )} ({ inputs } ), expected { self .nin } )"
12571257 )
1258- inputs = [as_scalar (input ) for input in inputs ]
1258+ inputs = [
1259+ inp if isinstance (inp , ScalarVariable ) else as_scalar (inp ) for inp in inputs
1260+ ]
12591261 outputs = [t () for t in self .output_types ([input .type for input in inputs ])]
12601262 if len (outputs ) != self .nout :
12611263 inputs_str = (", " .join (str (input ) for input in inputs ),)
@@ -4294,7 +4296,13 @@ class Composite(ScalarInnerGraphOp):
42944296 init_param : tuple [str , ...] = ("inputs" , "outputs" )
42954297
42964298 def __init__ (
4297- self , inputs , outputs , name = "Composite" , clone_graph : builtins .bool = True
4299+ self ,
4300+ inputs ,
4301+ outputs ,
4302+ name = "Composite" ,
4303+ clone_graph : builtins .bool = True ,
4304+ cleanup_graph : builtins .bool = True ,
4305+ output_types_preference = None ,
42984306 ):
42994307 self .name = name
43004308 self ._name = None
@@ -4324,7 +4332,9 @@ def __init__(
43244332 # 1. Create a new graph from inputs up to the
43254333 # Composite
43264334 res = pytensor .compile .rebuild_collect_shared (
4327- inputs = inputs , outputs = outputs [0 ].owner .inputs , copy_inputs_over = False
4335+ inputs = inputs ,
4336+ outputs = outputs [0 ].owner .inputs ,
4337+ copy_inputs_over = False ,
43284338 ) # Clone also the inputs
43294339 # 2. We continue this partial clone with the graph in
43304340 # the inner Composite
@@ -4338,36 +4348,42 @@ def __init__(
43384348 assert res [0 ] != inputs
43394349 inputs , outputs = res [0 ], res2 [1 ]
43404350
4341- # We already cloned the graph, or the user told us there was no need for it
4342- self .inputs , self .outputs = self ._cleanup_graph (inputs , outputs , clone = False )
4351+ if cleanup_graph :
4352+ # We already cloned the graph, or the user told us there was no need for it
4353+ self .inputs , self .outputs = self ._cleanup_graph (
4354+ inputs , outputs , clone = False
4355+ )
4356+ else :
4357+ self .inputs , self .outputs = inputs , outputs
43434358 self .inputs_type = tuple (input .type for input in self .inputs )
43444359 self .outputs_type = tuple (output .type for output in self .outputs )
43454360 self .nin = len (inputs )
43464361 self .nout = len (outputs )
4347- super ().__init__ ()
4362+ super ().__init__ (output_types_preference = output_types_preference )
43484363
43494364 def __str__ (self ):
43504365 if self ._name is not None :
43514366 return self ._name
43524367
4353- # Rename internal variables
4354- for i , r in enumerate (self .fgraph .inputs ):
4355- r .name = f"i{ i } "
4356- for i , r in enumerate (self .fgraph .outputs ):
4357- r .name = f"o{ i } "
4358- io = set (self .fgraph .inputs + self .fgraph .outputs )
4359- for i , r in enumerate (self .fgraph .variables ):
4360- if (
4361- not isinstance (r , Constant )
4362- and r not in io
4363- and len (self .fgraph .clients [r ]) > 1
4364- ):
4365- r .name = f"t{ i } "
4368+ fgraph = self .fgraph
43664369
4367- if len (self . fgraph .outputs ) > 1 or len (self . fgraph .apply_nodes ) > 10 :
4370+ if len (fgraph .outputs ) > 1 or len (fgraph .apply_nodes ) > 10 :
43684371 self ._name = "Composite{...}"
43694372 else :
4370- outputs_str = ", " .join (pprint (output ) for output in self .fgraph .outputs )
4373+ # Rename internal variables
4374+ for i , r in enumerate (fgraph .inputs ):
4375+ r .name = f"i{ i } "
4376+ for i , r in enumerate (fgraph .outputs ):
4377+ r .name = f"o{ i } "
4378+ io = set (fgraph .inputs + fgraph .outputs )
4379+ for i , r in enumerate (fgraph .variables ):
4380+ if (
4381+ not isinstance (r , Constant )
4382+ and r not in io
4383+ and len (fgraph .clients [r ]) > 1
4384+ ):
4385+ r .name = f"t{ i } "
4386+ outputs_str = ", " .join (pprint (output ) for output in fgraph .outputs )
43714387 self ._name = f"Composite{{{ outputs_str } }}"
43724388
43734389 return self ._name
@@ -4380,12 +4396,16 @@ def make_new_inplace(self, output_types_preference=None, name=None):
43804396
43814397 """
43824398 d = {k : getattr (self , k ) for k in self .init_param }
4383- out = self .__class__ (** d )
4384- if name :
4385- out .name = name
4386- else :
4387- name = out .name
4388- super (Composite , out ).__init__ (output_types_preference , name )
4399+ out = type (self )(
4400+ ** d ,
4401+ cleanup_graph = False ,
4402+ clone_graph = False ,
4403+ output_types_preference = output_types_preference ,
4404+ name = name or self .name ,
4405+ )
4406+ # No need to recompute the _c_code and nodenames if they were already computed (which is true if the hash of the Op was requested)
4407+ out ._c_code = self ._c_code
4408+ out .nodenames = self .nodenames
43894409 return out
43904410
43914411 @property
@@ -4452,9 +4472,10 @@ def c_code_template(self):
44524472 fg = self .fgraph
44534473 subd = {e : f"%(i{ i } )s" for i , e in enumerate (fg .inputs )}
44544474
4475+ inputs_set = frozenset (fg .inputs )
44554476 for var in fg .variables :
44564477 if var .owner is None :
4457- if var not in fg . inputs :
4478+ if var not in inputs_set :
44584479 # This is an orphan
44594480 if isinstance (var , Constant ) and isinstance (var .type , CLinkerType ):
44604481 subd [var ] = f"({ var .type .c_literal (var .data )} )"