88from  operator  import  or_ 
99from  warnings  import  warn 
1010
11- import  pytensor .scalar .basic  as  ps 
12- from  pytensor  import  clone_replace , compile 
1311from  pytensor .compile .function .types  import  Supervisor 
14- from  pytensor .compile .mode  import  get_target_language 
12+ from  pytensor .compile .mode  import  get_target_language ,  optdb 
1513from  pytensor .configdefaults  import  config 
1614from  pytensor .graph .basic  import  Apply , Variable 
1715from  pytensor .graph .destroyhandler  import  DestroyHandler , inplace_candidates 
1816from  pytensor .graph .features  import  ReplaceValidate 
1917from  pytensor .graph .fg  import  FunctionGraph , Output 
2018from  pytensor .graph .op  import  Op 
19+ from  pytensor .graph .replace  import  clone_replace 
2120from  pytensor .graph .rewriting .basic  import  (
2221    GraphRewriter ,
2322    copy_stack_trace ,
3029from  pytensor .graph .rewriting .unify  import  OpPattern 
3130from  pytensor .graph .traversal  import  toposort 
3231from  pytensor .graph .utils  import  InconsistencyError , MethodNotDefined 
33- from  pytensor .scalar .math  import  Grad2F1Loop , _grad_2f1_loop 
34- from  pytensor .tensor .basic  import  (
35-     MakeVector ,
36-     constant ,
32+ from  pytensor .scalar  import  (
33+     Add ,
34+     Composite ,
35+     Mul ,
36+     ScalarOp ,
37+     get_scalar_type ,
38+     transfer_type ,
39+     upcast_out ,
40+     upgrade_to_float ,
3741)
42+ from  pytensor .scalar  import  cast  as  scalar_cast 
43+ from  pytensor .scalar  import  constant  as  scalar_constant 
44+ from  pytensor .scalar .math  import  Grad2F1Loop , _grad_2f1_loop 
45+ from  pytensor .tensor .basic  import  MakeVector 
46+ from  pytensor .tensor .basic  import  constant  as  tensor_constant 
3847from  pytensor .tensor .elemwise  import  CAReduce , DimShuffle , Elemwise 
3948from  pytensor .tensor .math  import  add , exp , mul 
4049from  pytensor .tensor .rewriting .basic  import  (
@@ -280,7 +289,7 @@ def create_inplace_node(self, node, inplace_pattern):
280289        inplace_pattern  =  {i : o  for  i , [o ] in  inplace_pattern .items ()}
281290        if  hasattr (scalar_op , "make_new_inplace" ):
282291            new_scalar_op  =  scalar_op .make_new_inplace (
283-                 ps . transfer_type (
292+                 transfer_type (
284293                    * [
285294                        inplace_pattern .get (i , o .dtype )
286295                        for  i , o  in  enumerate (node .outputs )
@@ -289,14 +298,14 @@ def create_inplace_node(self, node, inplace_pattern):
289298            )
290299        else :
291300            new_scalar_op  =  type (scalar_op )(
292-                 ps . transfer_type (
301+                 transfer_type (
293302                    * [inplace_pattern .get (i , None ) for  i  in  range (len (node .outputs ))]
294303                )
295304            )
296305        return  type (op )(new_scalar_op , inplace_pattern ).make_node (* node .inputs )
297306
298307
299- compile . optdb .register (
308+ optdb .register (
300309    "inplace_elemwise" ,
301310    InplaceElemwiseOptimizer (),
302311    "inplace_elemwise_opt" ,  # for historic reason 
@@ -428,10 +437,8 @@ def local_useless_dimshuffle_makevector(fgraph, node):
428437@register_canonicalize  
429438@node_rewriter ( 
430439    [ 
431-         elemwise_of ( 
432-             OpPattern (ps .ScalarOp , output_types_preference = ps .upgrade_to_float ) 
433-         ), 
434-         elemwise_of (OpPattern (ps .ScalarOp , output_types_preference = ps .upcast_out )), 
440+         elemwise_of (OpPattern (ScalarOp , output_types_preference = upgrade_to_float )), 
441+         elemwise_of (OpPattern (ScalarOp , output_types_preference = upcast_out )), 
435442    ] 
436443) 
437444def  local_upcast_elemwise_constant_inputs (fgraph , node ):
@@ -452,7 +459,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
452459    changed  =  False 
453460    for  i , inp  in  enumerate (node .inputs ):
454461        if  inp .type .dtype  !=  output_dtype  and  isinstance (inp , TensorConstant ):
455-             new_inputs [i ] =  constant (inp .data .astype (output_dtype ))
462+             new_inputs [i ] =  tensor_constant (inp .data .astype (output_dtype ))
456463            changed  =  True 
457464
458465    if  not  changed :
@@ -531,7 +538,7 @@ def add_requirements(self, fgraph):
531538    @staticmethod  
532539    def  elemwise_to_scalar (inputs , outputs ):
533540        replacement  =  {
534-             inp : ps . get_scalar_type (inp .type .dtype ).make_variable () for  inp  in  inputs 
541+             inp : get_scalar_type (inp .type .dtype ).make_variable () for  inp  in  inputs 
535542        }
536543        for  node  in  toposort (outputs , blockers = inputs ):
537544            scalar_inputs  =  [replacement [inp ] for  inp  in  node .inputs ]
@@ -853,7 +860,7 @@ def elemwise_scalar_op_has_c_code(
853860            scalar_inputs , scalar_outputs  =  self .elemwise_to_scalar (inputs , outputs )
854861            composite_outputs  =  Elemwise (
855862                # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables 
856-                 ps . Composite (scalar_inputs , scalar_outputs , clone_graph = False )
863+                 Composite (scalar_inputs , scalar_outputs , clone_graph = False )
857864            )(* inputs , return_list = True )
858865            assert  len (outputs ) ==  len (composite_outputs )
859866            for  old_out , composite_out  in  zip (outputs , composite_outputs ):
@@ -913,7 +920,7 @@ def print_profile(stream, prof, level=0):
913920
914921@register_canonicalize  
915922@register_specialize  
916- @node_rewriter ([elemwise_of (ps . Composite )]) 
923+ @node_rewriter ([elemwise_of (Composite )]) 
917924def  local_useless_composite_outputs (fgraph , node ):
918925    """Remove inputs and outputs of Composite Ops that are not used anywhere.""" 
919926    comp  =  node .op .scalar_op 
@@ -934,7 +941,7 @@ def local_useless_composite_outputs(fgraph, node):
934941        node .outputs 
935942    ):
936943        used_inputs  =  [node .inputs [i ] for  i  in  used_inputs_idxs ]
937-         c  =  ps . Composite (inputs = used_inner_inputs , outputs = used_inner_outputs )
944+         c  =  Composite (inputs = used_inner_inputs , outputs = used_inner_outputs )
938945        e  =  Elemwise (scalar_op = c )(* used_inputs , return_list = True )
939946        return  dict (zip ([node .outputs [i ] for  i  in  used_outputs_idxs ], e , strict = True ))
940947
@@ -948,7 +955,7 @@ def local_careduce_fusion(fgraph, node):
948955
949956    # FIXME: This check is needed because of the faulty logic in the FIXME below! 
950957    # Right now, rewrite only works for `Sum`/`Prod` 
951-     if  not  isinstance (car_scalar_op , ps . Add  |  ps . Mul ):
958+     if  not  isinstance (car_scalar_op , Add  |  Mul ):
952959        return  None 
953960
954961    elm_node  =  car_input .owner 
@@ -992,19 +999,19 @@ def local_careduce_fusion(fgraph, node):
992999    car_acc_dtype  =  node .op .acc_dtype 
9931000
9941001    scalar_elm_inputs  =  [
995-         ps . get_scalar_type (inp .type .dtype ).make_variable () for  inp  in  elm_inputs 
1002+         get_scalar_type (inp .type .dtype ).make_variable () for  inp  in  elm_inputs 
9961003    ]
9971004
9981005    elm_output  =  elm_scalar_op (* scalar_elm_inputs )
9991006
10001007    # This input represents the previous value in the `CAReduce` binary reduction 
1001-     carried_car_input  =  ps . get_scalar_type (car_acc_dtype ).make_variable ()
1008+     carried_car_input  =  get_scalar_type (car_acc_dtype ).make_variable ()
10021009
10031010    scalar_fused_output  =  car_scalar_op (carried_car_input , elm_output )
10041011    if  scalar_fused_output .type .dtype  !=  car_acc_dtype :
1005-         scalar_fused_output  =  ps . cast (scalar_fused_output , car_acc_dtype )
1012+         scalar_fused_output  =  scalar_cast (scalar_fused_output , car_acc_dtype )
10061013
1007-     fused_scalar_op  =  ps . Composite (
1014+     fused_scalar_op  =  Composite (
10081015        inputs = [carried_car_input , * scalar_elm_inputs ], outputs = [scalar_fused_output ]
10091016    )
10101017
@@ -1025,7 +1032,7 @@ def local_careduce_fusion(fgraph, node):
10251032    return  [new_car_op (* elm_inputs )]
10261033
10271034
1028- @node_rewriter ([elemwise_of (ps . Composite )]) 
1035+ @node_rewriter ([elemwise_of (Composite )]) 
10291036def  local_inline_composite_constants (fgraph , node ):
10301037    """Inline scalar constants in Composite graphs.""" 
10311038    composite_op  =  node .op .scalar_op 
@@ -1041,7 +1048,7 @@ def local_inline_composite_constants(fgraph, node):
10411048            and  "complex"  not  in   outer_inp .type .dtype 
10421049        ):
10431050            if  outer_inp .unique_value  is  not   None :
1044-                 inner_replacements [inner_inp ] =  ps . constant (
1051+                 inner_replacements [inner_inp ] =  scalar_constant (
10451052                    outer_inp .unique_value , dtype = inner_inp .dtype 
10461053                )
10471054                continue 
@@ -1054,7 +1061,7 @@ def local_inline_composite_constants(fgraph, node):
10541061    new_inner_outs  =  clone_replace (
10551062        composite_op .fgraph .outputs , replace = inner_replacements 
10561063    )
1057-     new_composite_op  =  ps . Composite (new_inner_inputs , new_inner_outs )
1064+     new_composite_op  =  Composite (new_inner_inputs , new_inner_outs )
10581065    new_outputs  =  Elemwise (new_composite_op ).make_node (* new_outer_inputs ).outputs 
10591066
10601067    # Some of the inlined constants were broadcasting the output shape 
@@ -1095,7 +1102,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
10951102            if  other_inps :
10961103                python_op  =  operator .mul  if  node .op  ==  mul  else  operator .add 
10971104                folded_inputs  =  [reference_inp , * other_inps ]
1098-                 new_inp  =  constant (
1105+                 new_inp  =  tensor_constant (
10991106                    reduce (python_op , (const .data  for  const  in  folded_inputs ))
11001107                )
11011108                new_constants  =  [
@@ -1119,7 +1126,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
11191126
11201127
11211128add_mul_fusion_seqopt  =  SequenceDB ()
1122- compile . optdb .register (
1129+ optdb .register (
11231130    "add_mul_fusion" ,
11241131    add_mul_fusion_seqopt ,
11251132    "fast_run" ,
@@ -1140,7 +1147,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
11401147
11411148# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites) 
11421149fuse_seqopt  =  SequenceDB ()
1143- compile . optdb .register (
1150+ optdb .register (
11441151    "elemwise_fusion" ,
11451152    fuse_seqopt ,
11461153    "fast_run" ,
@@ -1271,7 +1278,7 @@ def split_2f1grad_loop(fgraph, node):
12711278    return  replacements 
12721279
12731280
1274- compile . optdb ["py_only" ].register (
1281+ optdb ["py_only" ].register (
12751282    "split_2f1grad_loop" ,
12761283    split_2f1grad_loop ,
12771284    "fast_compile" ,
0 commit comments