|  | 
| 96 | 96 | } | 
| 97 | 97 | 
 | 
| 98 | 98 | 
 | 
|  | 99 | +try: | 
|  | 100 | +    # torch >=2.3 | 
|  | 101 | +    _uint_promotion_table = { | 
|  | 102 | +        # uints | 
|  | 103 | +        (torch.uint8, torch.uint16): torch.uint16, | 
|  | 104 | +        (torch.uint8, torch.uint32): torch.uint32, | 
|  | 105 | +        (torch.uint8, torch.uint64): torch.uint64, | 
|  | 106 | +        (torch.uint16, torch.uint8): torch.uint16, | 
|  | 107 | +        (torch.uint16, torch.uint16): torch.uint16, | 
|  | 108 | +        (torch.uint16, torch.uint32): torch.uint32, | 
|  | 109 | +        (torch.uint16, torch.uint64): torch.uint64, | 
|  | 110 | +        (torch.uint32, torch.uint8): torch.uint32, | 
|  | 111 | +        (torch.uint32, torch.uint16): torch.uint32, | 
|  | 112 | +        (torch.uint32, torch.uint32): torch.uint32, | 
|  | 113 | +        (torch.uint32, torch.uint64): torch.uint64, | 
|  | 114 | +        (torch.uint64, torch.uint8): torch.uint64, | 
|  | 115 | +        (torch.uint64, torch.uint16): torch.uint64, | 
|  | 116 | +        (torch.uint64, torch.uint32): torch.uint64, | 
|  | 117 | +        (torch.uint64, torch.uint64): torch.uint64, | 
|  | 118 | +        # ints and uints (mixed sign) | 
|  | 119 | +        (torch.int8, torch.uint16): torch.int32, | 
|  | 120 | +        (torch.int8, torch.uint32): torch.int64, | 
|  | 121 | +        (torch.int16, torch.uint8): torch.int16, | 
|  | 122 | +        (torch.int16, torch.uint16): torch.int32, | 
|  | 123 | +        (torch.int16, torch.uint32): torch.int64, | 
|  | 124 | +        (torch.int32, torch.uint8): torch.int32, | 
|  | 125 | +        (torch.int32, torch.uint16): torch.int32, | 
|  | 126 | +        (torch.int32, torch.uint32): torch.int64, | 
|  | 127 | +        (torch.int64, torch.uint8): torch.int64, | 
|  | 128 | +        (torch.int64, torch.uint16): torch.int64, | 
|  | 129 | +        (torch.int64, torch.uint32): torch.int64, | 
|  | 130 | +        (torch.uint16, torch.int8): torch.int32, | 
|  | 131 | +        (torch.uint16, torch.int16): torch.int32, | 
|  | 132 | +        (torch.uint16, torch.int32): torch.int32, | 
|  | 133 | +        (torch.uint16, torch.int64): torch.int64, | 
|  | 134 | +        (torch.uint32, torch.int8): torch.int64, | 
|  | 135 | +        (torch.uint32, torch.int16): torch.int64, | 
|  | 136 | +        (torch.uint32, torch.int32): torch.int64, | 
|  | 137 | +        (torch.uint32, torch.int64): torch.int64, | 
|  | 138 | +} | 
|  | 139 | +except AttributeError: | 
|  | 140 | +    pass | 
|  | 141 | + | 
|  | 142 | +_promotion_table.update(**_uint_promotion_table) | 
|  | 143 | + | 
|  | 144 | + | 
| 99 | 145 | def _two_arg(f): | 
| 100 | 146 |     @_wraps(f) | 
| 101 | 147 |     def _f(x1, x2, /, **kwargs): | 
|  | 
0 commit comments