Skip to content

Commit 17f8b31

Browse files
authored
[mypyc] Import librt.base64 capsule automatically if needed (#20233)
Allow primitives to specify the capsule they need via module name such as `librt.base64`. This way we can import the capsule automatically only when there are references to the contents of the capsule in the compiled code. Only make the change for `librt.base64`, but we can also do a similar thing for `librt.internal` in a follow-up PR.
1 parent 3d23716 commit 17f8b31

File tree

14 files changed

+113
-10
lines changed

14 files changed

+113
-10
lines changed

mypyc/analysis/capsule_deps.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from __future__ import annotations
2+
3+
from mypyc.ir.func_ir import FuncIR
4+
from mypyc.ir.ops import CallC, PrimitiveOp
5+
6+
7+
def find_implicit_capsule_dependencies(fn: FuncIR) -> set[str] | None:
    """Collect the capsules that the ops of a function depend on.

    Primitives or types that live in librt submodules (for example
    "librt.base64") can only be used after the matching capsule has been
    imported.

    A module may need such a capsule even when it never imports the librt
    module explicitly, e.g. through re-exported names or through the return
    types of functions defined in other modules.

    The function must already be lowered: encountering a PrimitiveOp trips
    an assertion.

    Returns the set of capsule module names referenced by CallC ops, or None
    if no op references a capsule.
    """
    capsules: set[str] | None = None
    for block in fn.blocks:
        for op in block.ops:
            # TODO: Also determine implicit type object dependencies (e.g. cast targets)
            if not isinstance(op, CallC) or op.capsule is None:
                # Nothing to record for this op; just verify the IR was lowered.
                assert not isinstance(op, PrimitiveOp), "Lowered IR is expected"
                continue
            if capsules is None:
                # Allocate the result set lazily, on the first capsule found.
                capsules = {op.capsule}
            else:
                capsules.add(op.capsule)
    return capsules

mypyc/build.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,6 @@ def mypycify(
495495
group_name: str | None = None,
496496
log_trace: bool = False,
497497
depends_on_librt_internal: bool = False,
498-
depends_on_librt_base64: bool = False,
499498
install_librt: bool = False,
500499
experimental_features: bool = False,
501500
) -> list[Extension]:
@@ -570,7 +569,6 @@ def mypycify(
570569
group_name=group_name,
571570
log_trace=log_trace,
572571
depends_on_librt_internal=depends_on_librt_internal,
573-
depends_on_librt_base64=depends_on_librt_base64,
574572
experimental_features=experimental_features,
575573
)
576574

mypyc/codegen/emitmodule.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from mypy.options import Options
2828
from mypy.plugin import Plugin, ReportConfigContext
2929
from mypy.util import hash_digest, json_dumps
30+
from mypyc.analysis.capsule_deps import find_implicit_capsule_dependencies
3031
from mypyc.codegen.cstring import c_string_initializer
3132
from mypyc.codegen.emit import Emitter, EmitterContext, HeaderDeclaration, c_array_initializer
3233
from mypyc.codegen.emitclass import generate_class, generate_class_reuse, generate_class_type_decl
@@ -259,6 +260,10 @@ def compile_scc_to_ir(
259260

260261
# Switch to lower abstraction level IR.
261262
lower_ir(fn, compiler_options)
263+
# Calculate implicit module dependencies (needed for librt)
264+
capsules = find_implicit_capsule_dependencies(fn)
265+
if capsules is not None:
266+
module.capsules.update(capsules)
262267
# Perform optimizations.
263268
do_copy_propagation(fn, compiler_options)
264269
do_flag_elimination(fn, compiler_options)
@@ -604,7 +609,7 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
604609
ext_declarations.emit_line("#include <CPy.h>")
605610
if self.compiler_options.depends_on_librt_internal:
606611
ext_declarations.emit_line("#include <librt_internal.h>")
607-
if self.compiler_options.depends_on_librt_base64:
612+
if any("librt.base64" in mod.capsules for mod in self.modules.values()):
608613
ext_declarations.emit_line("#include <librt_base64.h>")
609614

610615
declarations = Emitter(self.context)
@@ -1036,7 +1041,7 @@ def emit_module_exec_func(
10361041
emitter.emit_line("if (import_librt_internal() < 0) {")
10371042
emitter.emit_line("return -1;")
10381043
emitter.emit_line("}")
1039-
if self.compiler_options.depends_on_librt_base64:
1044+
if "librt.base64" in module.capsules:
10401045
emitter.emit_line("if (import_librt_base64() < 0) {")
10411046
emitter.emit_line("return -1;")
10421047
emitter.emit_line("}")

mypyc/ir/module_ir.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def __init__(
3030
# These are only visible in the module that defined them, so no need
3131
# to serialize.
3232
self.type_var_names = type_var_names
33+
# Capsules needed by the module, specified via module names such as "librt.base64"
34+
self.capsules: set[str] = set()
3335

3436
def serialize(self) -> JsonDict:
3537
return {
@@ -38,18 +40,21 @@ def serialize(self) -> JsonDict:
3840
"functions": [f.serialize() for f in self.functions],
3941
"classes": [c.serialize() for c in self.classes],
4042
"final_names": [(k, t.serialize()) for k, t in self.final_names],
43+
"capsules": sorted(self.capsules),
4144
}
4245

4346
@classmethod
4447
def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ModuleIR:
45-
return ModuleIR(
48+
module = ModuleIR(
4649
data["fullname"],
4750
data["imports"],
4851
[ctx.functions[FuncDecl.get_id_from_json(f)] for f in data["functions"]],
4952
[ClassIR.deserialize(c, ctx) for c in data["classes"]],
5053
[(k, deserialize_type(t, ctx)) for k, t in data["final_names"]],
5154
[],
5255
)
56+
module.capsules = set(data["capsules"])
57+
return module
5358

5459

5560
def deserialize_modules(data: dict[str, JsonDict], ctx: DeserMaps) -> dict[str, ModuleIR]:

mypyc/ir/ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,7 @@ def __init__(
708708
priority: int,
709709
is_pure: bool,
710710
experimental: bool,
711+
capsule: str | None,
711712
) -> None:
712713
# Each primitive must have a distinct name, but otherwise they are arbitrary.
713714
self.name: Final = name
@@ -733,6 +734,9 @@ def __init__(
733734
# Experimental primitives are not used unless mypyc experimental features are
734735
# explicitly enabled
735736
self.experimental = experimental
737+
# Capsule that needs to be imported and configured to call the primitive
738+
# (name of the target module, e.g. "librt.base64").
739+
self.capsule = capsule
736740

737741
def __repr__(self) -> str:
738742
return f"<PrimitiveDescription {self.name!r}: {self.arg_types}>"
@@ -1233,6 +1237,7 @@ def __init__(
12331237
*,
12341238
is_pure: bool = False,
12351239
returns_null: bool = False,
1240+
capsule: str | None = None,
12361241
) -> None:
12371242
self.error_kind = error_kind
12381243
super().__init__(line)
@@ -1250,6 +1255,9 @@ def __init__(
12501255
# The function might return a null value that does not indicate
12511256
# an error.
12521257
self.returns_null = returns_null
1258+
# A capsule from this module must be imported and initialized before calling this
1259+
# function (used for C functions exported from librt). Example value: "librt.base64"
1260+
self.capsule = capsule
12531261
if is_pure or returns_null:
12541262
assert error_kind == ERR_NEVER
12551263

mypyc/irbuild/ll_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,6 +2075,7 @@ def call_c(
20752075
var_arg_idx,
20762076
is_pure=desc.is_pure,
20772077
returns_null=desc.returns_null,
2078+
capsule=desc.capsule,
20782079
)
20792080
)
20802081
if desc.is_borrowed:
@@ -2159,6 +2160,7 @@ def primitive_op(
21592160
desc.priority,
21602161
is_pure=desc.is_pure,
21612162
returns_null=False,
2163+
capsule=desc.capsule,
21622164
)
21632165
return self.call_c(c_desc, args, line, result_type=result_type)
21642166

mypyc/options.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def __init__(
1818
group_name: str | None = None,
1919
log_trace: bool = False,
2020
depends_on_librt_internal: bool = False,
21-
depends_on_librt_base64: bool = False,
2221
experimental_features: bool = False,
2322
) -> None:
2423
self.strip_asserts = strip_asserts
@@ -57,7 +56,6 @@ def __init__(
5756
# only for mypy itself, third-party code compiled with mypyc should not use
5857
# librt.internal.
5958
self.depends_on_librt_internal = depends_on_librt_internal
60-
self.depends_on_librt_base64 = depends_on_librt_base64
6159
# Some experimental features are only available when building librt in
6260
# experimental mode (e.g. use _experimental suffix in librt run test).
6361
# These can't be used with a librt wheel installed from PyPI.

mypyc/primitives/misc_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,4 +473,5 @@
473473
c_function_name="LibRTBase64_b64encode_internal",
474474
error_kind=ERR_MAGIC,
475475
experimental=True,
476+
capsule="librt.base64",
476477
)

mypyc/primitives/registry.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class CFunctionDescription(NamedTuple):
6262
priority: int
6363
is_pure: bool
6464
returns_null: bool
65+
capsule: str | None
6566

6667

6768
# A description for C load operations including LoadGlobal and LoadAddress
@@ -100,6 +101,7 @@ def method_op(
100101
is_borrowed: bool = False,
101102
priority: int = 1,
102103
is_pure: bool = False,
104+
capsule: str | None = None,
103105
) -> PrimitiveDescription:
104106
"""Define a c function call op that replaces a method call.
105107
@@ -145,6 +147,7 @@ def method_op(
145147
priority,
146148
is_pure=is_pure,
147149
experimental=False,
150+
capsule=capsule,
148151
)
149152
ops.append(desc)
150153
return desc
@@ -164,6 +167,7 @@ def function_op(
164167
is_borrowed: bool = False,
165168
priority: int = 1,
166169
experimental: bool = False,
170+
capsule: str | None = None,
167171
) -> PrimitiveDescription:
168172
"""Define a C function call op that replaces a function call.
169173
@@ -193,6 +197,7 @@ def function_op(
193197
priority=priority,
194198
is_pure=False,
195199
experimental=experimental,
200+
capsule=capsule,
196201
)
197202
ops.append(desc)
198203
return desc
@@ -212,6 +217,7 @@ def binary_op(
212217
steals: StealsDescription = False,
213218
is_borrowed: bool = False,
214219
priority: int = 1,
220+
capsule: str | None = None,
215221
) -> PrimitiveDescription:
216222
"""Define a c function call op for a binary operation.
217223
@@ -240,6 +246,7 @@ def binary_op(
240246
priority=priority,
241247
is_pure=False,
242248
experimental=False,
249+
capsule=capsule,
243250
)
244251
ops.append(desc)
245252
return desc
@@ -281,6 +288,7 @@ def custom_op(
281288
0,
282289
is_pure=is_pure,
283290
returns_null=returns_null,
291+
capsule=None,
284292
)
285293

286294

@@ -297,6 +305,7 @@ def custom_primitive_op(
297305
steals: StealsDescription = False,
298306
is_borrowed: bool = False,
299307
is_pure: bool = False,
308+
capsule: str | None = None,
300309
) -> PrimitiveDescription:
301310
"""Define a primitive op that can't be automatically generated based on the AST.
302311
@@ -319,6 +328,7 @@ def custom_primitive_op(
319328
priority=0,
320329
is_pure=is_pure,
321330
experimental=False,
331+
capsule=capsule,
322332
)
323333

324334

@@ -335,6 +345,7 @@ def unary_op(
335345
is_borrowed: bool = False,
336346
priority: int = 1,
337347
is_pure: bool = False,
348+
capsule: str | None = None,
338349
) -> PrimitiveDescription:
339350
"""Define a primitive op for a unary operation.
340351
@@ -361,6 +372,7 @@ def unary_op(
361372
priority=priority,
362373
is_pure=is_pure,
363374
experimental=False,
375+
capsule=capsule,
364376
)
365377
ops.append(desc)
366378
return desc
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
[case testBase64_experimental]
2+
from librt.base64 import b64encode
3+
4+
def enc(b: bytes) -> bytes:
5+
return b64encode(b)
6+
[out]
7+
def enc(b):
8+
b, r0 :: bytes
9+
L0:
10+
r0 = LibRTBase64_b64encode_internal(b)
11+
return r0
12+
13+
[case testBase64ExperimentalDisabled]
14+
from librt.base64 import b64encode
15+
16+
def enc(b: bytes) -> bytes:
17+
return b64encode(b)
18+
[out]
19+
def enc(b):
20+
b :: bytes
21+
r0 :: dict
22+
r1 :: str
23+
r2 :: object
24+
r3 :: object[1]
25+
r4 :: object_ptr
26+
r5 :: object
27+
r6 :: bytes
28+
L0:
29+
r0 = __main__.globals :: static
30+
r1 = 'b64encode'
31+
r2 = CPyDict_GetItem(r0, r1)
32+
r3 = [b]
33+
r4 = load_address r3
34+
r5 = PyObject_Vectorcall(r2, r4, 1, 0)
35+
keep_alive b
36+
r6 = cast(bytes, r5)
37+
return r6

0 commit comments

Comments
 (0)