module {
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar1"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar2"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar GELU kernel on a rank-0 f32 tensor.
// With x = %arg0, computes
//   x * logistic(1.59576917 * x * (1 + 0.044715 * x^2))
// and returns (that value, x).
// 1.59576917 is approximately 2*sqrt(2/pi), so this is the tanh-form GELU
// approximation expressed through logistic: x*logistic(2u) == 0.5*x*(1 + tanh(u)).
func.func private @gelu_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  // Constants used by the approximation.
  %cubic_coeff = stablehlo.constant dense<4.471500e-02> : tensor<f32>
  %scale = stablehlo.constant dense<1.59576917> : tensor<f32>
  %one = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  // Input after the inbound rank-0 transpose (empty permutation).
  %x = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  // poly = 1 + 0.044715 * x^2
  %x_sq = stablehlo.multiply %x, %x : tensor<f32>
  %poly_term = stablehlo.multiply %x_sq, %cubic_coeff : tensor<f32>
  %poly = stablehlo.add %poly_term, %one : tensor<f32>
  // gate = logistic(scale * x * poly)
  %scaled_x = stablehlo.multiply %scale, %x : tensor<f32>
  %gate_arg = stablehlo.multiply %scaled_x, %poly : tensor<f32>
  %gate = stablehlo.logistic %gate_arg : tensor<f32>
  // Result value and the pass-through copy of the input.
  %gelu = stablehlo.multiply %x, %gate : tensor<f32>
  %gelu_out = stablehlo.transpose %gelu, dims = [] : (tensor<f32>) -> tensor<f32>
  %x_out = stablehlo.transpose %x, dims = [] : (tensor<f32>) -> tensor<f32>
  return %gelu_out, %x_out : tensor<f32>, tensor<f32>
}
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar3"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar4"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar GELU kernel on a rank-0 f32 tensor.
// With x = %arg0, computes
//   x * logistic(1.59576917 * x * (1 + 0.044715 * x^2))
// and returns (that value, x).
// 1.59576917 is approximately 2*sqrt(2/pi), so this is the tanh-form GELU
// approximation expressed through logistic: x*logistic(2u) == 0.5*x*(1 + tanh(u)).
func.func private @gelu_broadcast_scalar1(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  // Constants used by the approximation.
  %cubic_coeff = stablehlo.constant dense<4.471500e-02> : tensor<f32>
  %scale = stablehlo.constant dense<1.59576917> : tensor<f32>
  %one = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  // Input after the inbound rank-0 transpose (empty permutation).
  %x = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  // poly = 1 + 0.044715 * x^2
  %x_sq = stablehlo.multiply %x, %x : tensor<f32>
  %poly_term = stablehlo.multiply %x_sq, %cubic_coeff : tensor<f32>
  %poly = stablehlo.add %poly_term, %one : tensor<f32>
  // gate = logistic(scale * x * poly)
  %scaled_x = stablehlo.multiply %scale, %x : tensor<f32>
  %gate_arg = stablehlo.multiply %scaled_x, %poly : tensor<f32>
  %gate = stablehlo.logistic %gate_arg : tensor<f32>
  // Result value and the pass-through copy of the input.
  %gelu = stablehlo.multiply %x, %gate : tensor<f32>
  %gelu_out = stablehlo.transpose %gelu, dims = [] : (tensor<f32>) -> tensor<f32>
  %x_out = stablehlo.transpose %x, dims = [] : (tensor<f32>) -> tensor<f32>
  return %gelu_out, %x_out : tensor<f32>, tensor<f32>
}
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar5"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar6"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar GELU kernel on a rank-0 f32 tensor.
// With x = %arg0, computes
//   x * logistic(1.59576917 * x * (1 + 0.044715 * x^2))
// and returns (that value, x).
// 1.59576917 is approximately 2*sqrt(2/pi), so this is the tanh-form GELU
// approximation expressed through logistic: x*logistic(2u) == 0.5*x*(1 + tanh(u)).
func.func private @gelu_broadcast_scalar2(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  // Constants used by the approximation.
  %cubic_coeff = stablehlo.constant dense<4.471500e-02> : tensor<f32>
  %scale = stablehlo.constant dense<1.59576917> : tensor<f32>
  %one = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  // Input after the inbound rank-0 transpose (empty permutation).
  %x = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  // poly = 1 + 0.044715 * x^2
  %x_sq = stablehlo.multiply %x, %x : tensor<f32>
  %poly_term = stablehlo.multiply %x_sq, %cubic_coeff : tensor<f32>
  %poly = stablehlo.add %poly_term, %one : tensor<f32>
  // gate = logistic(scale * x * poly)
  %scaled_x = stablehlo.multiply %scale, %x : tensor<f32>
  %gate_arg = stablehlo.multiply %scaled_x, %poly : tensor<f32>
  %gate = stablehlo.logistic %gate_arg : tensor<f32>
  // Result value and the pass-through copy of the input.
  %gelu = stablehlo.multiply %x, %gate : tensor<f32>
  %gelu_out = stablehlo.transpose %gelu, dims = [] : (tensor<f32>) -> tensor<f32>
  %x_out = stablehlo.transpose %x, dims = [] : (tensor<f32>) -> tensor<f32>
  return %gelu_out, %x_out : tensor<f32>, tensor<f32>
}
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar7"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar8"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar GELU kernel on a rank-0 f32 tensor.
// With x = %arg0, computes
//   x * logistic(1.59576917 * x * (1 + 0.044715 * x^2))
// and returns (that value, x).
// 1.59576917 is approximately 2*sqrt(2/pi), so this is the tanh-form GELU
// approximation expressed through logistic: x*logistic(2u) == 0.5*x*(1 + tanh(u)).
func.func private @gelu_broadcast_scalar3(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  // Constants used by the approximation.
  %cubic_coeff = stablehlo.constant dense<4.471500e-02> : tensor<f32>
  %scale = stablehlo.constant dense<1.59576917> : tensor<f32>
  %one = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  // Input after the inbound rank-0 transpose (empty permutation).
  %x = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  // poly = 1 + 0.044715 * x^2
  %x_sq = stablehlo.multiply %x, %x : tensor<f32>
  %poly_term = stablehlo.multiply %x_sq, %cubic_coeff : tensor<f32>
  %poly = stablehlo.add %poly_term, %one : tensor<f32>
  // gate = logistic(scale * x * poly)
  %scaled_x = stablehlo.multiply %scale, %x : tensor<f32>
  %gate_arg = stablehlo.multiply %scaled_x, %poly : tensor<f32>
  %gate = stablehlo.logistic %gate_arg : tensor<f32>
  // Result value and the pass-through copy of the input.
  %gelu = stablehlo.multiply %x, %gate : tensor<f32>
  %gelu_out = stablehlo.transpose %gelu, dims = [] : (tensor<f32>) -> tensor<f32>
  %x_out = stablehlo.transpose %x, dims = [] : (tensor<f32>) -> tensor<f32>
  return %gelu_out, %x_out : tensor<f32>, tensor<f32>
}
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar9"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar GELU kernel on a rank-0 f32 tensor.
// With x = %arg0, computes
//   x * logistic(1.59576917 * x * (1 + 0.044715 * x^2))
// and returns (that value, x).
// 1.59576917 is approximately 2*sqrt(2/pi), so this is the tanh-form GELU
// approximation expressed through logistic: x*logistic(2u) == 0.5*x*(1 + tanh(u)).
func.func private @gelu_broadcast_scalar4(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  // Constants used by the approximation.
  %cubic_coeff = stablehlo.constant dense<4.471500e-02> : tensor<f32>
  %scale = stablehlo.constant dense<1.59576917> : tensor<f32>
  %one = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  // Input after the inbound rank-0 transpose (empty permutation).
  %x = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  // poly = 1 + 0.044715 * x^2
  %x_sq = stablehlo.multiply %x, %x : tensor<f32>
  %poly_term = stablehlo.multiply %x_sq, %cubic_coeff : tensor<f32>
  %poly = stablehlo.add %poly_term, %one : tensor<f32>
  // gate = logistic(scale * x * poly)
  %scaled_x = stablehlo.multiply %scale, %x : tensor<f32>
  %gate_arg = stablehlo.multiply %scaled_x, %poly : tensor<f32>
  %gate = stablehlo.logistic %gate_arg : tensor<f32>
  // Result value and the pass-through copy of the input.
  %gelu = stablehlo.multiply %x, %gate : tensor<f32>
  %gelu_out = stablehlo.transpose %gelu, dims = [] : (tensor<f32>) -> tensor<f32>
  %x_out = stablehlo.transpose %x, dims = [] : (tensor<f32>) -> tensor<f32>
  return %gelu_out, %x_out : tensor<f32>, tensor<f32>
}
// Scalar addition kernel on rank-0 f32 tensors.
// Results, in order: (%arg0 + %arg1, %arg0, %arg1).
// Every value is routed through stablehlo.transpose with an empty permutation,
// which leaves a rank-0 tensor unchanged.
func.func private @"+_broadcast_scalar10"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
  // Operands after the inbound rank-0 transposes.
  %lhs = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
  // Pass-through copies of both operands for the 2nd and 3rd results.
  %lhs_out = stablehlo.transpose %lhs, dims = [] : (tensor<f32>) -> tensor<f32>
  %rhs_out = stablehlo.transpose %rhs, dims = [] : (tensor<f32>) -> tensor<f32>
  // The actual sum, followed by its outbound transpose.
  %sum = stablehlo.add %lhs, %rhs : tensor<f32>
  %sum_out = stablehlo.transpose %sum, dims = [] : (tensor<f32>) -> tensor<f32>
  return %sum_out, %lhs_out, %rhs_out : tensor<f32>, tensor<f32>, tensor<f32>
}
// Scalar squared-magnitude kernel on a rank-0 f32 tensor.
// With x = %arg0, returns (abs(x) * abs(x), x).
// Both results are routed through stablehlo.transpose with an empty
// permutation, which leaves a rank-0 tensor unchanged.
func.func private @abs2_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  // Input after the inbound rank-0 transpose, plus its pass-through copy.
  %x = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
  %x_out = stablehlo.transpose %x, dims = [] : (tensor<f32>) -> tensor<f32>
  // |x| multiplied by itself.
  %mag = stablehlo.abs %x : tensor<f32>
  %mag_sq = stablehlo.multiply %mag, %mag : tensor<f32>
  %mag_sq_out = stablehlo.transpose %mag_sq, dims = [] : (tensor<f32>) -> tensor<f32>
  return %mag_sq_out, %x_out : tensor<f32>, tensor<f32>
}
func.func private @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff"(%arg0: tensor<2x64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64xf32>, %arg4: tensor<16x64x64xcomplex<f32>>, %arg5: tensor<64x64xf32>, %arg6: tensor<64xf32>, %arg7: tensor<16x64x64xcomplex<f32>>, %arg8: tensor<64x64xf32>, %arg9: tensor<64xf32>, %arg10: tensor<16x64x64xcomplex<f32>>, %arg11: tensor<64x64xf32>, %arg12: tensor<64xf32>, %arg13: tensor<16x64x64xcomplex<f32>>, %arg14: tensor<64x128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x1xf32>, %arg17: tensor<1xf32>, %arg18: tensor<5x32x2xf32>) -> (tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>) {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
%1 = stablehlo.transpose %arg1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%3 = stablehlo.transpose %arg3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%4 = stablehlo.transpose %arg4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%5 = stablehlo.transpose %arg5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%6 = stablehlo.transpose %arg6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%7 = stablehlo.transpose %arg7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%9 = stablehlo.transpose %arg9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%10 = stablehlo.transpose %arg10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%11 = stablehlo.transpose %arg11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%12 = stablehlo.transpose %arg12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%13 = stablehlo.transpose %arg13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%14 = stablehlo.transpose %arg14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
%15 = stablehlo.transpose %arg15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
%16 = stablehlo.transpose %arg16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
%17 = stablehlo.transpose %arg17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
%18 = stablehlo.transpose %arg18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
%19 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
%20 = stablehlo.reshape %19 : (tensor<5x32x2xf32>) -> tensor<160x2xf32>
%21 = stablehlo.transpose %20, dims = [1, 0] : (tensor<160x2xf32>) -> tensor<2x160xf32>
%22 = stablehlo.dot_general %0, %21, contracting_dims = [1] x [0] : (tensor<64x2xf32>, tensor<2x160xf32>) -> tensor<64x160xf32>
%23 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
%24:3 = enzyme.batch @"+_broadcast_scalar"(%22, %23) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
%25 = stablehlo.transpose %24#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
%26 = stablehlo.reshape %25 : (tensor<160x64xf32>) -> tensor<160x64xf32>
%27 = stablehlo.transpose %26, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
%28 = stablehlo.dot_general %2, %27, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
%29 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
%30:3 = enzyme.batch @"+_broadcast_scalar1"(%28, %29) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
%31 = stablehlo.transpose %24#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
%32 = stablehlo.reshape %31 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
%33 = stablehlo.transpose %32, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
%34 = stablehlo.transpose %33, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
%35 = stablehlo.convert %34 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
%36 = stablehlo.transpose %35, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%37 = stablehlo.fft %36, type = FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%38 = stablehlo.transpose %37, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%c = stablehlo.constant dense<0> : tensor<i64>
%c_1 = stablehlo.constant dense<0> : tensor<i64>
%c_2 = stablehlo.constant dense<0> : tensor<i64>
%39 = stablehlo.dynamic_slice %38, %c, %c_1, %c_2, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
%40 = stablehlo.transpose %39, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%41 = stablehlo.reshape %40 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%42 = stablehlo.transpose %41, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%43 = stablehlo.transpose %42, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
%cst_3 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
%44 = stablehlo.transpose %4, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%45 = stablehlo.transpose %43, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%46 = stablehlo.dot_general %44, %45, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%47 = stablehlo.transpose %46, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
%48 = stablehlo.transpose %47, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%cst_4 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
%49 = stablehlo.transpose %48, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%50 = stablehlo.reshape %49 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%51 = stablehlo.transpose %50, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%52 = stablehlo.pad %51, %cst_4, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%53 = stablehlo.transpose %52, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%54 = stablehlo.fft %53, type = IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%55 = stablehlo.transpose %54, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%56 = stablehlo.real %55 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
%57 = stablehlo.transpose %56, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
%58 = stablehlo.transpose %30#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
%59 = stablehlo.reshape %58 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
%60 = stablehlo.transpose %59, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
%61:3 = enzyme.batch @"+_broadcast_scalar2"(%60, %57) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
%62:2 = enzyme.batch @gelu_broadcast_scalar(%61#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
%63 = stablehlo.transpose %62#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
%64 = stablehlo.reshape %63 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
%65 = stablehlo.transpose %64, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
%66 = stablehlo.dot_general %5, %65, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
%67 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
%68:3 = enzyme.batch @"+_broadcast_scalar3"(%66, %67) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
%69 = stablehlo.transpose %62#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
%70 = stablehlo.convert %69 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
%71 = stablehlo.transpose %70, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%72 = stablehlo.fft %71, type = FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%73 = stablehlo.transpose %72, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%c_6 = stablehlo.constant dense<0> : tensor<i64>
%c_7 = stablehlo.constant dense<0> : tensor<i64>
%c_8 = stablehlo.constant dense<0> : tensor<i64>
%74 = stablehlo.dynamic_slice %73, %c_6, %c_7, %c_8, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
%75 = stablehlo.transpose %74, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%76 = stablehlo.reshape %75 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%77 = stablehlo.transpose %76, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%78 = stablehlo.transpose %77, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
%cst_9 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
%79 = stablehlo.transpose %7, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%80 = stablehlo.transpose %78, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%81 = stablehlo.dot_general %79, %80, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%82 = stablehlo.transpose %81, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
%83 = stablehlo.transpose %82, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%cst_10 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
%84 = stablehlo.transpose %83, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%85 = stablehlo.reshape %84 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%86 = stablehlo.transpose %85, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%87 = stablehlo.pad %86, %cst_10, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%88 = stablehlo.transpose %87, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%89 = stablehlo.fft %88, type = IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%90 = stablehlo.transpose %89, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%91 = stablehlo.real %90 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
%92 = stablehlo.transpose %91, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
%93 = stablehlo.transpose %68#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
%94 = stablehlo.reshape %93 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
%95 = stablehlo.transpose %94, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
%96:3 = enzyme.batch @"+_broadcast_scalar4"(%95, %92) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
%97:2 = enzyme.batch @gelu_broadcast_scalar1(%96#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
%cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
%98 = stablehlo.transpose %97#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
%99 = stablehlo.reshape %98 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
%100 = stablehlo.transpose %99, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
%101 = stablehlo.dot_general %8, %100, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
%102 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
%103:3 = enzyme.batch @"+_broadcast_scalar5"(%101, %102) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
%104 = stablehlo.transpose %97#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
%105 = stablehlo.convert %104 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
%106 = stablehlo.transpose %105, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%107 = stablehlo.fft %106, type = FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%108 = stablehlo.transpose %107, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%c_12 = stablehlo.constant dense<0> : tensor<i64>
%c_13 = stablehlo.constant dense<0> : tensor<i64>
%c_14 = stablehlo.constant dense<0> : tensor<i64>
%109 = stablehlo.dynamic_slice %108, %c_12, %c_13, %c_14, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
%110 = stablehlo.transpose %109, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%111 = stablehlo.reshape %110 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%112 = stablehlo.transpose %111, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%113 = stablehlo.transpose %112, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
%cst_15 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
%114 = stablehlo.transpose %10, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%115 = stablehlo.transpose %113, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%116 = stablehlo.dot_general %114, %115, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%117 = stablehlo.transpose %116, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
%118 = stablehlo.transpose %117, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%cst_16 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
%119 = stablehlo.transpose %118, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%120 = stablehlo.reshape %119 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%121 = stablehlo.transpose %120, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%122 = stablehlo.pad %121, %cst_16, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%123 = stablehlo.transpose %122, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%124 = stablehlo.fft %123, type = IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%125 = stablehlo.transpose %124, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%126 = stablehlo.real %125 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
%127 = stablehlo.transpose %126, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
%128 = stablehlo.transpose %103#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
%129 = stablehlo.reshape %128 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
%130 = stablehlo.transpose %129, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
%131:3 = enzyme.batch @"+_broadcast_scalar6"(%130, %127) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
%132:2 = enzyme.batch @gelu_broadcast_scalar2(%131#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
%cst_17 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
%133 = stablehlo.transpose %132#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
%134 = stablehlo.reshape %133 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
%135 = stablehlo.transpose %134, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
%136 = stablehlo.dot_general %11, %135, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
%137 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
%138:3 = enzyme.batch @"+_broadcast_scalar7"(%136, %137) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
%139 = stablehlo.transpose %132#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
%140 = stablehlo.convert %139 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
%141 = stablehlo.transpose %140, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%142 = stablehlo.fft %141, type = FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%143 = stablehlo.transpose %142, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%c_18 = stablehlo.constant dense<0> : tensor<i64>
%c_19 = stablehlo.constant dense<0> : tensor<i64>
%c_20 = stablehlo.constant dense<0> : tensor<i64>
%144 = stablehlo.dynamic_slice %143, %c_18, %c_19, %c_20, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
%145 = stablehlo.transpose %144, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%146 = stablehlo.reshape %145 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%147 = stablehlo.transpose %146, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%148 = stablehlo.transpose %147, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
%cst_21 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
%149 = stablehlo.transpose %13, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%150 = stablehlo.transpose %148, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%151 = stablehlo.dot_general %149, %150, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%152 = stablehlo.transpose %151, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
%153 = stablehlo.transpose %152, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%cst_22 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
%154 = stablehlo.transpose %153, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%155 = stablehlo.reshape %154 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
%156 = stablehlo.transpose %155, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
%157 = stablehlo.pad %156, %cst_22, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%158 = stablehlo.transpose %157, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%159 = stablehlo.fft %158, type = IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
%160 = stablehlo.transpose %159, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
%161 = stablehlo.real %160 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
%162 = stablehlo.transpose %161, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
%163 = stablehlo.transpose %138#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
%164 = stablehlo.reshape %163 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
%165 = stablehlo.transpose %164, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
%166:3 = enzyme.batch @"+_broadcast_scalar8"(%165, %162) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
%167:2 = enzyme.batch @gelu_broadcast_scalar3(%166#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
%cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<128x160xf32>
%168 = stablehlo.transpose %167#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
%169 = stablehlo.reshape %168 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
%170 = stablehlo.transpose %169, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
%171 = stablehlo.dot_general %14, %170, contracting_dims = [1] x [0] : (tensor<128x64xf32>, tensor<64x160xf32>) -> tensor<128x160xf32>
%172 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
%173 = stablehlo.reshape %172 : (tensor<128xf32>) -> tensor<1x128xf32>
%174 = stablehlo.transpose %173, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
%175 = stablehlo.broadcast_in_dim %174, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x160xf32>
%176:3 = enzyme.batch @"+_broadcast_scalar9"(%171, %175) {batch_shape = array<i64: 128, 160>} : (tensor<128x160xf32>, tensor<128x160xf32>) -> (tensor<128x160xf32>, tensor<128x160xf32>, tensor<128x160xf32>)
%177:2 = enzyme.batch @gelu_broadcast_scalar4(%176#0) {batch_shape = array<i64: 128, 160>} : (tensor<128x160xf32>) -> (tensor<128x160xf32>, tensor<128x160xf32>)
%cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<1x160xf32>
%178 = stablehlo.transpose %177#0, dims = [1, 0] : (tensor<128x160xf32>) -> tensor<160x128xf32>
%179 = stablehlo.reshape %178 : (tensor<160x128xf32>) -> tensor<160x128xf32>
%180 = stablehlo.transpose %179, dims = [1, 0] : (tensor<160x128xf32>) -> tensor<128x160xf32>
%181 = stablehlo.dot_general %16, %180, contracting_dims = [1] x [0] : (tensor<1x128xf32>, tensor<128x160xf32>) -> tensor<1x160xf32>
%182 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<1xf32>) -> tensor<1x160xf32>
%183:3 = enzyme.batch @"+_broadcast_scalar10"(%181, %182) {batch_shape = array<i64: 1, 160>} : (tensor<1x160xf32>, tensor<1x160xf32>) -> (tensor<1x160xf32>, tensor<1x160xf32>, tensor<1x160xf32>)
%184 = stablehlo.transpose %183#0, dims = [1, 0] : (tensor<1x160xf32>) -> tensor<160x1xf32>
%185 = stablehlo.reshape %184 : (tensor<160x1xf32>) -> tensor<5x32x1xf32>
%186 = stablehlo.transpose %185, dims = [2, 1, 0] : (tensor<5x32x1xf32>) -> tensor<1x32x5xf32>
%cst_25 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%187:2 = enzyme.batch @abs2_broadcast_scalar(%186) {batch_shape = array<i64: 1, 32, 5>} : (tensor<1x32x5xf32>) -> (tensor<1x32x5xf32>, tensor<1x32x5xf32>)
%188 = stablehlo.reduce(%187#0 init: %cst_25) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<1x32x5xf32>, tensor<f32>) -> tensor<f32>
%189 = stablehlo.transpose %188, dims = [] : (tensor<f32>) -> tensor<f32>
%190 = stablehlo.transpose %0, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
%191 = stablehlo.transpose %1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%192 = stablehlo.transpose %2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%193 = stablehlo.transpose %3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%194 = stablehlo.transpose %4, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%195 = stablehlo.transpose %5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%196 = stablehlo.transpose %6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%197 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%198 = stablehlo.transpose %8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%199 = stablehlo.transpose %9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%200 = stablehlo.transpose %10, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%201 = stablehlo.transpose %11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%202 = stablehlo.transpose %12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%203 = stablehlo.transpose %13, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%204 = stablehlo.transpose %14, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
%205 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
%206 = stablehlo.transpose %16, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
%207 = stablehlo.transpose %17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
%208 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
return %189, %190, %191, %192, %193, %194, %195, %196, %197, %198, %199, %200, %201, %202, %203, %204, %205, %206, %207, %208 : tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>
}
// Public entry point. Takes 19 tensors and forwards them (after a transpose /
// re-transpose round trip) to a single enzyme.autodiff call on
// @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff".
//   %arg0..%arg17 : marked enzyme_active in the call's `activity` list.
//   %arg18        : marked enzyme_const in the call's `activity` list.
// Returns 37 tensors: first 18 values derived from autodiff results #19..#36
// (same types as %arg0..%arg17), then 19 values derived from autodiff results
// #0..#18 (same types as %arg0..%arg18).
// NOTE(review): results #19..#36 are presumably the gradients of the active
// arguments and #0..#18 the callee's returned primal values -- confirm against
// the Enzyme MLIR dialect documentation.
func.func @main(%arg0: tensor<2x64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64xf32>, %arg4: tensor<16x64x64xcomplex<f32>>, %arg5: tensor<64x64xf32>, %arg6: tensor<64xf32>, %arg7: tensor<16x64x64xcomplex<f32>>, %arg8: tensor<64x64xf32>, %arg9: tensor<64xf32>, %arg10: tensor<16x64x64xcomplex<f32>>, %arg11: tensor<64x64xf32>, %arg12: tensor<64xf32>, %arg13: tensor<16x64x64xcomplex<f32>>, %arg14: tensor<64x128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x1xf32>, %arg17: tensor<1xf32>, %arg18: tensor<5x32x2xf32>) -> (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>) {
// Step 1: transpose every argument with a reversed dimension order
// ([1, 0] for rank 2, [2, 1, 0] for rank 3; rank-1 tensors use the identity
// permutation [0]).
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
%1 = stablehlo.transpose %arg1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%3 = stablehlo.transpose %arg3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%4 = stablehlo.transpose %arg4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%5 = stablehlo.transpose %arg5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%6 = stablehlo.transpose %arg6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%7 = stablehlo.transpose %arg7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%9 = stablehlo.transpose %arg9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%10 = stablehlo.transpose %arg10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%11 = stablehlo.transpose %arg11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%12 = stablehlo.transpose %arg12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%13 = stablehlo.transpose %arg13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%14 = stablehlo.transpose %arg14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
%15 = stablehlo.transpose %arg15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
%16 = stablehlo.transpose %arg16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
%17 = stablehlo.transpose %arg17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
%18 = stablehlo.transpose %arg18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
// Step 2: zero-filled constants, one per active argument, with the same types
// as %0..%17. After re-transposition they become operands %39..%56 of the
// autodiff call.
// NOTE(review): presumably the initial (zero) gradient accumulators -- confirm.
%cst = stablehlo.constant dense<0.000000e+00> : tensor<64x2xf32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
%cst_3 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
%cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
%cst_6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
%cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
%cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
%cst_9 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
%cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
%cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
%cst_12 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
%cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<128x64xf32>
%cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<128xf32>
%cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<1x128xf32>
%cst_16 = stablehlo.constant dense<0.000000e+00> : tensor<1xf32>
// Scalar 1.0; becomes operand %38 of the autodiff call.
// NOTE(review): presumably the reverse-mode seed for the callee's scalar
// tensor<f32> return -- confirm.
%cst_17 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// Step 3: apply the same reversed-dimension permutations again to %0..%18,
// which restores the original argument types (%19..%37 have the types of
// %arg0..%arg18).
%19 = stablehlo.transpose %0, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
%20 = stablehlo.transpose %1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%21 = stablehlo.transpose %2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%22 = stablehlo.transpose %3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%23 = stablehlo.transpose %4, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%24 = stablehlo.transpose %5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%25 = stablehlo.transpose %6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%26 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%27 = stablehlo.transpose %8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%28 = stablehlo.transpose %9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%29 = stablehlo.transpose %10, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%30 = stablehlo.transpose %11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%31 = stablehlo.transpose %12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%32 = stablehlo.transpose %13, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%33 = stablehlo.transpose %14, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
%34 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
%35 = stablehlo.transpose %16, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
%36 = stablehlo.transpose %17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
%37 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
// Step 4: the same transposition applied to the 1.0 scalar (%38, rank 0, empty
// permutation) and to the zero constants (%39..%56).
%38 = stablehlo.transpose %cst_17, dims = [] : (tensor<f32>) -> tensor<f32>
%39 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
%40 = stablehlo.transpose %cst_0, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%41 = stablehlo.transpose %cst_1, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%42 = stablehlo.transpose %cst_2, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%43 = stablehlo.transpose %cst_3, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%44 = stablehlo.transpose %cst_4, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%45 = stablehlo.transpose %cst_5, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%46 = stablehlo.transpose %cst_6, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%47 = stablehlo.transpose %cst_7, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%48 = stablehlo.transpose %cst_8, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%49 = stablehlo.transpose %cst_9, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%50 = stablehlo.transpose %cst_10, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%51 = stablehlo.transpose %cst_11, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%52 = stablehlo.transpose %cst_12, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%53 = stablehlo.transpose %cst_13, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
%54 = stablehlo.transpose %cst_14, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
%55 = stablehlo.transpose %cst_15, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
%56 = stablehlo.transpose %cst_16, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
// Step 5: the autodiff call (one operation spanning the next two lines).
//   Operands (38): %19..%37 are the 19 inputs, %38 is the 1.0 scalar,
//                  %39..%56 are the 18 zero tensors.
//   activity     : 18 x enzyme_active, then 1 x enzyme_const (for %37).
//   ret_activity : enzyme_activenoneed for the first callee return, 18 x
//                  enzyme_active, then enzyme_const for the last one.
//   Results (37) : #0..#18 have the types of %arg0..%arg18; #19..#36 have the
//                  types of %arg0..%arg17.
%57:37 = enzyme.autodiff @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff"(%19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]} : (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>, 
tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>) -> (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>)
// Step 6: transpose autodiff results #0..#18 with reversed dimension order.
%58 = stablehlo.transpose %57#0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
%59 = stablehlo.transpose %57#1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%60 = stablehlo.transpose %57#2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%61 = stablehlo.transpose %57#3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%62 = stablehlo.transpose %57#4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%63 = stablehlo.transpose %57#5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%64 = stablehlo.transpose %57#6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%65 = stablehlo.transpose %57#7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%66 = stablehlo.transpose %57#8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%67 = stablehlo.transpose %57#9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%68 = stablehlo.transpose %57#10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%69 = stablehlo.transpose %57#11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%70 = stablehlo.transpose %57#12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%71 = stablehlo.transpose %57#13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%72 = stablehlo.transpose %57#14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
%73 = stablehlo.transpose %57#15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
%74 = stablehlo.transpose %57#16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
%75 = stablehlo.transpose %57#17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
%76 = stablehlo.transpose %57#18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
// Step 7: transpose autodiff results #19..#36 the same way.
%77 = stablehlo.transpose %57#19, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
%78 = stablehlo.transpose %57#20, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%79 = stablehlo.transpose %57#21, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%80 = stablehlo.transpose %57#22, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%81 = stablehlo.transpose %57#23, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%82 = stablehlo.transpose %57#24, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%83 = stablehlo.transpose %57#25, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%84 = stablehlo.transpose %57#26, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%85 = stablehlo.transpose %57#27, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%86 = stablehlo.transpose %57#28, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%87 = stablehlo.transpose %57#29, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%88 = stablehlo.transpose %57#30, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%89 = stablehlo.transpose %57#31, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%90 = stablehlo.transpose %57#32, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
%91 = stablehlo.transpose %57#33, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
%92 = stablehlo.transpose %57#34, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
%93 = stablehlo.transpose %57#35, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
%94 = stablehlo.transpose %57#36, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
// Step 8: transpose %77..%94 back, giving the types of the first 18 function
// results.
%95 = stablehlo.transpose %77, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
%96 = stablehlo.transpose %78, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%97 = stablehlo.transpose %79, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%98 = stablehlo.transpose %80, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%99 = stablehlo.transpose %81, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%100 = stablehlo.transpose %82, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%101 = stablehlo.transpose %83, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%102 = stablehlo.transpose %84, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%103 = stablehlo.transpose %85, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%104 = stablehlo.transpose %86, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%105 = stablehlo.transpose %87, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%106 = stablehlo.transpose %88, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%107 = stablehlo.transpose %89, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%108 = stablehlo.transpose %90, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%109 = stablehlo.transpose %91, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
%110 = stablehlo.transpose %92, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
%111 = stablehlo.transpose %93, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
%112 = stablehlo.transpose %94, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
// Step 9: transpose %58..%76 back, giving the types of the last 19 function
// results.
%113 = stablehlo.transpose %58, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
%114 = stablehlo.transpose %59, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%115 = stablehlo.transpose %60, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%116 = stablehlo.transpose %61, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%117 = stablehlo.transpose %62, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%118 = stablehlo.transpose %63, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%119 = stablehlo.transpose %64, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%120 = stablehlo.transpose %65, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%121 = stablehlo.transpose %66, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%122 = stablehlo.transpose %67, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%123 = stablehlo.transpose %68, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%124 = stablehlo.transpose %69, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
%125 = stablehlo.transpose %70, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
%126 = stablehlo.transpose %71, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
%127 = stablehlo.transpose %72, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
%128 = stablehlo.transpose %73, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
%129 = stablehlo.transpose %74, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
%130 = stablehlo.transpose %75, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
%131 = stablehlo.transpose %76, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
// Result order: %95..%112 (from autodiff results #19..#36) come first, then
// %113..%131 (from autodiff results #0..#18).
return %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127, %128, %129, %130, %131 : tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>
}
}
Unoptimized MLIR
Error Message with debug build