Skip to content

Commit 626ce6e

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
Reorganize dependencies to speed up build
PiperOrigin-RevId: 830815100
1 parent 10bb57a commit 626ce6e

File tree

6 files changed

+281
-116
lines changed

6 files changed

+281
-116
lines changed

jaxlib/mosaic/BUILD

Lines changed: 117 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,23 +47,70 @@ cc_library(
4747
"dialect/tpu/tpu_ops.cc",
4848
"dialect/tpu/util.cc",
4949
"dialect/tpu/vreg_util.cc",
50-
":extension_srcs",
51-
] + glob([
52-
"dialect/tpu/transforms/*.cc",
53-
]),
50+
],
5451
hdrs = [
5552
"dialect/tpu/array_util.h",
5653
"dialect/tpu/layout.h",
5754
"dialect/tpu/tpu_dialect.h",
5855
"dialect/tpu/util.h",
5956
"dialect/tpu/vreg_util.h",
60-
] + glob([
61-
"dialect/tpu/transforms/*.h",
62-
]),
57+
],
58+
# compatible with libtpu
59+
deps = [
60+
":tpu_inc_gen",
61+
"@com_google_absl//absl/algorithm:container",
62+
"@com_google_absl//absl/hash",
63+
"@com_google_absl//absl/log",
64+
"@com_google_absl//absl/log:check",
65+
"@com_google_absl//absl/status",
66+
"@com_google_absl//absl/strings",
67+
"@com_google_absl//absl/strings:str_format",
68+
"@com_google_absl//absl/types:span",
69+
"@llvm-project//llvm:Support",
70+
"@llvm-project//mlir:ArithDialect",
71+
"@llvm-project//mlir:CommonFolders",
72+
"@llvm-project//mlir:DataLayoutInterfaces",
73+
"@llvm-project//mlir:Dialect",
74+
"@llvm-project//mlir:DialectUtils",
75+
"@llvm-project//mlir:FuncDialect",
76+
"@llvm-project//mlir:IR",
77+
"@llvm-project//mlir:MathDialect",
78+
"@llvm-project//mlir:MemRefDialect",
79+
"@llvm-project//mlir:Pass",
80+
"@llvm-project//mlir:Support",
81+
"@llvm-project//mlir:VectorDialect",
82+
"@xla//xla:array",
83+
"@xla//xla:shape_util",
84+
"@xla//xla/tsl/platform:statusor",
85+
] + mosaic_extension_deps,
86+
)
87+
88+
cc_library(
89+
name = "tpu_passes",
90+
srcs = [
91+
":extension_srcs",
92+
] + glob(
93+
[
94+
"dialect/tpu/transforms/*.cc",
95+
],
96+
exclude = [
97+
"dialect/tpu/transforms/serde.cc",
98+
"dialect/tpu/transforms/linalg_vectorization.cc",
99+
],
100+
),
101+
hdrs = glob(
102+
[
103+
"dialect/tpu/transforms/*.h",
104+
],
105+
exclude = [
106+
"dialect/tpu/transforms/serde.h",
107+
"dialect/tpu/transforms/linalg_vectorization.h",
108+
],
109+
),
63110
# compatible with libtpu
64111
deps = [
65112
":pass_boilerplate",
66-
":serde",
113+
":tpu_dialect",
67114
":tpu_inc_gen",
68115
"@com_google_absl//absl/algorithm:container",
69116
"@com_google_absl//absl/container:flat_hash_set",
@@ -84,13 +131,11 @@ cc_library(
84131
"@llvm-project//mlir:DialectUtils",
85132
"@llvm-project//mlir:FuncDialect",
86133
"@llvm-project//mlir:IR",
87-
"@llvm-project//mlir:LinalgTransforms",
88134
"@llvm-project//mlir:MathDialect",
89135
"@llvm-project//mlir:MemRefDialect",
90136
"@llvm-project//mlir:Pass",
91137
"@llvm-project//mlir:SCFDialect",
92138
"@llvm-project//mlir:Support",
93-
"@llvm-project//mlir:TensorDialect",
94139
"@llvm-project//mlir:TransformUtils",
95140
"@llvm-project//mlir:VectorDialect",
96141
"@llvm-project//mlir:VectorTransforms",
@@ -103,6 +148,66 @@ cc_library(
103148
] + mosaic_extension_deps,
104149
)
105150

151+
cc_library(
152+
name = "tpu_serde_pass",
153+
srcs = ["dialect/tpu/transforms/serde.cc"],
154+
hdrs = ["dialect/tpu/transforms/serde.h"],
155+
# compatible with libtpu
156+
deps = [
157+
":pass_boilerplate",
158+
":serde",
159+
":tpu_dialect",
160+
":tpu_inc_gen",
161+
"@llvm-project//llvm:Support",
162+
"@llvm-project//mlir:ArithDialect",
163+
"@llvm-project//mlir:ControlFlowDialect",
164+
"@llvm-project//mlir:DataLayoutInterfaces",
165+
"@llvm-project//mlir:Dialect",
166+
"@llvm-project//mlir:FuncDialect",
167+
"@llvm-project//mlir:IR",
168+
"@llvm-project//mlir:MathDialect",
169+
"@llvm-project//mlir:MemRefDialect",
170+
"@llvm-project//mlir:Pass",
171+
"@llvm-project//mlir:SCFDialect",
172+
"@llvm-project//mlir:Support",
173+
"@llvm-project//mlir:VectorDialect",
174+
] + mosaic_extension_deps,
175+
)
176+
177+
cc_library(
178+
name = "tpu_linalg_vectorization_pass",
179+
srcs = ["dialect/tpu/transforms/linalg_vectorization.cc"],
180+
hdrs = ["dialect/tpu/transforms/linalg_vectorization.h"],
181+
# compatible with libtpu
182+
deps = [
183+
":pass_boilerplate",
184+
":serde",
185+
":tpu_dialect",
186+
":tpu_inc_gen",
187+
"@com_google_absl//absl/algorithm:container",
188+
"@com_google_absl//absl/container:flat_hash_set",
189+
"@llvm-project//llvm:Support",
190+
"@llvm-project//mlir:ArithDialect",
191+
"@llvm-project//mlir:CommonFolders",
192+
"@llvm-project//mlir:ControlFlowDialect",
193+
"@llvm-project//mlir:DataLayoutInterfaces",
194+
"@llvm-project//mlir:Dialect",
195+
"@llvm-project//mlir:DialectUtils",
196+
"@llvm-project//mlir:FuncDialect",
197+
"@llvm-project//mlir:IR",
198+
"@llvm-project//mlir:LinalgTransforms",
199+
"@llvm-project//mlir:MathDialect",
200+
"@llvm-project//mlir:MemRefDialect",
201+
"@llvm-project//mlir:Pass",
202+
"@llvm-project//mlir:SCFDialect",
203+
"@llvm-project//mlir:Support",
204+
"@llvm-project//mlir:TensorDialect",
205+
"@llvm-project//mlir:TransformUtils",
206+
"@llvm-project//mlir:VectorDialect",
207+
"@llvm-project//mlir:VectorTransforms",
208+
] + mosaic_extension_deps,
209+
)
210+
106211
gentbl_cc_library(
107212
name = "tpu_inc_gen",
108213
# compatible with libtpu
@@ -170,6 +275,8 @@ cc_library(
170275
deps = [
171276
":tpu_dialect",
172277
":tpu_inc_gen",
278+
":tpu_passes",
279+
":tpu_serde_pass",
173280
"@com_google_absl//absl/log",
174281
"@com_google_absl//absl/log:check",
175282
"@llvm-project//llvm:Support",

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,22 +1576,6 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO
15761576
];
15771577
}
15781578

1579-
def LinalgVectorizationPass : Pass<"linalg-vectorization", "::mlir::func::FuncOp"> {
1580-
let dependentDialects = [
1581-
"::mlir::func::FuncDialect",
1582-
"::mlir::memref::MemRefDialect",
1583-
"::mlir::linalg::LinalgDialect",
1584-
"::mlir::tensor::TensorDialect",
1585-
"::mlir::vector::VectorDialect",
1586-
"::mlir::tpu::TPUDialect",
1587-
];
1588-
let constructor = "::mlir::tpu::createLinalgVectorizationPass(false)";
1589-
let options = [
1590-
Option<"supports_bf16_alu_instructions", "supports-bf16-alu-instructions", "bool", "", "">,
1591-
Option<"supports_bf16_matmul", "supports-bf16-matmul", "bool", "", "">,
1592-
];
1593-
}
1594-
15951579
def PreCanonicalizationOptimizationPass : Pass<"pre-canonicalization-optimization", "::mlir::func::FuncOp"> {
15961580
let summary = "Fold matmul rhs tranpose into the op before layout inference";
15971581
let constructor = "::mlir::tpu::createPreCanonicalizationOptimizationPass()";

jaxlib/mosaic/dialect/tpu/tpu_dialect.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <array>
2020
#include <cstdint>
2121
#include <memory>
22+
#include <string_view>
2223
#include <utility>
2324

2425
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -98,10 +99,6 @@ createPreCanonicalizationOptimizationPass(
9899
std::unique_ptr<OperationPass<func::FuncOp>>
99100
createLogicalToPhysicalDeviceIdPass(int64_t total_devices);
100101

101-
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgVectorizationPass(
102-
bool supports_bf16_alu_instructions = false,
103-
bool supports_bf16_matmul = false);
104-
105102
std::unique_ptr<OperationPass<func::FuncOp>> createDebugAssertInsertionPass();
106103

107104
#define GEN_PASS_DECL_MOSAICSERDEPASS
@@ -129,6 +126,8 @@ DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder,
129126
#define GEN_PASS_REGISTRATION
130127
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
131128

129+
constexpr std::string_view kLeadingTileRows = "leading_tile_rows";
130+
132131
} // namespace tpu
133132
} // namespace mlir
134133

jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ limitations under the License.
1818

1919
#include <array>
2020
#include <cstdint>
21-
#include <string_view>
2221

2322
#include "mlir/IR/BuiltinTypes.h"
2423
#include "mlir/Support/LogicalResult.h"
@@ -32,8 +31,6 @@ FailureOr<MemRefType> inferMemref(MemRefType memref, int hardware_generation,
3231
bool is_kernel_argument,
3332
int64_t leading_tile_rows = 0);
3433

35-
const std::string_view kLeadingTileRows = "leading_tile_rows";
36-
3734
} // namespace mlir::tpu
3835

3936
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_

0 commit comments

Comments
 (0)