From ec4a6852ec6473641a5702055612905aa1f45088 Mon Sep 17 00:00:00 2001 From: Pragyansh Chaturvedi Date: Sat, 21 Feb 2026 18:59:33 +0530 Subject: [PATCH] PythonBPF: Add Compilation Context to allow parallel compilation of multiple bpf programs --- pythonbpf/allocation_pass.py | 61 +++++---- pythonbpf/assign_pass.py | 24 ++-- pythonbpf/codegen.py | 25 ++-- pythonbpf/context.py | 82 ++++++++++++ pythonbpf/expr/call_registry.py | 8 +- pythonbpf/expr/expr_pass.py | 143 +++++++------------- pythonbpf/functions/functions_pass.py | 149 +++++++++------------ pythonbpf/globals_pass.py | 30 +++-- pythonbpf/helper/__init__.py | 16 +-- pythonbpf/helper/bpf_helper_handler.py | 172 ++++++++----------------- pythonbpf/helper/helper_utils.py | 146 +++++++++------------ pythonbpf/license_pass.py | 14 +- pythonbpf/maps/maps_pass.py | 71 +++++++--- pythonbpf/structs/structs_pass.py | 11 +- 14 files changed, 455 insertions(+), 497 deletions(-) create mode 100644 pythonbpf/context.py diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index db6f0bac..a0a7e68f 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -26,9 +26,7 @@ def create_targets_and_rvals(stmt): return stmt.targets, [stmt.value] -def handle_assign_allocation( - builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab -): +def handle_assign_allocation(compilation_context, builder, stmt, local_sym_tab): """Handle memory allocation for assignment statements.""" logger.info(f"Handling assignment for allocation: {ast.dump(stmt)}") @@ -59,7 +57,7 @@ def handle_assign_allocation( # Determine type and allocate based on rval if isinstance(rval, ast.Call): _allocate_for_call( - builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab + builder, var_name, rval, local_sym_tab, compilation_context ) elif isinstance(rval, ast.Constant): _allocate_for_constant(builder, var_name, rval, local_sym_tab) @@ -71,7 +69,7 @@ def handle_assign_allocation( elif isinstance(rval, ast.Attribute): # Struct field-to-variable assignment (a = dat.fld) _allocate_for_attribute( - builder, var_name, rval, local_sym_tab, structs_sym_tab + builder, var_name, rval, local_sym_tab, compilation_context ) else: logger.warning( @@ -79,10 +77,9 @@ def handle_assign_allocation( ) -def _allocate_for_call( - builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab -): +def _allocate_for_call(builder, var_name, rval, local_sym_tab, compilation_context): """Allocate memory for variable assigned from a call.""" + structs_sym_tab = compilation_context.structs_sym_tab if isinstance(rval.func, ast.Name): call_type = rval.func.id @@ -149,7 +146,7 @@ def _allocate_for_call( elif isinstance(rval.func, ast.Attribute): # Map method calls - need double allocation for ptr handling _allocate_for_map_method( - builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab + builder, var_name, rval, local_sym_tab, compilation_context ) else: @@ -157,9 +154,11 @@ def _allocate_for_call( def _allocate_for_map_method( - builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab + builder, var_name, rval, local_sym_tab, compilation_context ): """Allocate memory for variable assigned from map method (double alloc).""" + map_sym_tab = compilation_context.map_sym_tab + structs_sym_tab = compilation_context.structs_sym_tab map_name = rval.func.value.id method_name = rval.func.attr @@ -299,6 +298,15 @@ def allocate_temp_pool(builder, max_temps, local_sym_tab): logger.debug(f"Allocated temp variable: {temp_name}") +def _get_alignment(tmp_type): + """Return alignment for a given type.""" + if isinstance(tmp_type, ir.PointerType): + return 8 + elif isinstance(tmp_type, ir.IntType): + return tmp_type.width // 8 + return 8 + + def _allocate_for_name(builder, var_name, rval, local_sym_tab): """Allocate memory for variable-to-variable assignment (b = a).""" source_var = rval.id @@ -321,8 +329,22 @@ def _allocate_for_name(builder, var_name, rval, local_sym_tab): ) -def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_tab): +def _allocate_with_type(builder, var_name, ir_type): + """Allocate memory for a variable with a specific type.""" + var = builder.alloca(ir_type, name=var_name) + if isinstance(ir_type, ir.IntType): + var.align = ir_type.width // 8 + elif isinstance(ir_type, ir.PointerType): + var.align = 8 + return var + + +def _allocate_for_attribute( + builder, var_name, rval, local_sym_tab, compilation_context +): """Allocate memory for struct field-to-variable assignment (a = dat.fld).""" + structs_sym_tab = compilation_context.structs_sym_tab + if not isinstance(rval.value, ast.Name): logger.warning(f"Complex attribute access not supported for {var_name}") return @@ -455,20 +477,3 @@ def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_ logger.info( f"Pre-allocated {var_name} from {struct_var}.{field_name} with type {alloc_type}" ) - - -def _allocate_with_type(builder, var_name, ir_type): - """Allocate variable with appropriate alignment for type.""" - var = builder.alloca(ir_type, name=var_name) - var.align = _get_alignment(ir_type) - return var - - -def _get_alignment(ir_type): - """Get appropriate alignment for IR type.""" - if isinstance(ir_type, ir.IntType): - return ir_type.width // 8 - elif isinstance(ir_type, ir.ArrayType) and isinstance(ir_type.element, ir.IntType): - return ir_type.element.width // 8 - else: - return 8 # Default: pointer size diff --git a/pythonbpf/assign_pass.py b/pythonbpf/assign_pass.py index 412af932..d252c49a 100644 --- a/pythonbpf/assign_pass.py +++ b/pythonbpf/assign_pass.py @@ -12,7 +12,7 @@ def handle_struct_field_assignment( - func, module, builder, target, rval, local_sym_tab, map_sym_tab, structs_sym_tab + func, compilation_context, builder, target, rval, local_sym_tab ): """Handle struct field assignment (obj.field = value).""" @@ -24,7 +24,7 @@ def handle_struct_field_assignment( return struct_type = local_sym_tab[var_name].metadata - struct_info = structs_sym_tab[struct_type] + struct_info = compilation_context.structs_sym_tab[struct_type] if field_name not in struct_info.fields: logger.error(f"Field '{field_name}' not found in struct '{struct_type}'") @@ -33,9 +33,7 @@ def handle_struct_field_assignment( # Get field pointer and evaluate value field_ptr = struct_info.gep(builder, local_sym_tab[var_name].var, field_name) field_type = struct_info.field_type(field_name) - val_result = eval_expr( - func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab - ) + val_result = eval_expr(func, compilation_context, builder, rval, local_sym_tab) if val_result is None: logger.error(f"Failed to evaluate value for {var_name}.{field_name}") @@ -47,14 +45,12 @@ def handle_struct_field_assignment( if _is_char_array(field_type) and _is_i8_ptr(val_type): _copy_string_to_char_array( func, - module, + compilation_context, builder, val, field_ptr, field_type, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) logger.info(f"Copied string to char array {var_name}.{field_name}") return @@ -66,14 +62,12 @@ def handle_struct_field_assignment( def _copy_string_to_char_array( func, - module, + compilation_context, builder, src_ptr, dst_ptr, array_type, local_sym_tab, - map_sym_tab, - struct_sym_tab, ): """Copy string (i8*) to char array ([N x i8]) using bpf_probe_read_kernel_str""" @@ -109,7 +103,7 @@ def _is_i8_ptr(ir_type): def handle_variable_assignment( - func, module, builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab + func, compilation_context, builder, var_name, rval, local_sym_tab ): """Handle single named variable assignment.""" @@ -120,6 +114,8 @@ def handle_variable_assignment( var_ptr = local_sym_tab[var_name].var var_type = local_sym_tab[var_name].ir_type + structs_sym_tab = compilation_context.structs_sym_tab + # NOTE: Special case for struct initialization if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name): struct_name = rval.func.id @@ -142,9 +138,7 @@ def handle_variable_assignment( logger.info(f"Assigned char array pointer to {var_name}") return True - val_result = eval_expr( - func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab - ) + val_result = eval_expr(func, compilation_context, builder, rval, local_sym_tab) if val_result is None: logger.error(f"Failed to evaluate value for {var_name}") return False diff --git a/pythonbpf/codegen.py b/pythonbpf/codegen.py index 543675b5..7c95b140 100644 --- a/pythonbpf/codegen.py +++ b/pythonbpf/codegen.py @@ -1,5 +1,6 @@ import ast from llvmlite import ir +from .context import CompilationContext from .license_pass import license_processing from .functions import func_proc from .maps import maps_proc @@ -67,9 +68,10 @@ def find_bpf_chunks(tree): return bpf_functions -def processor(source_code, filename, module): +def processor(source_code, filename, compilation_context): tree = ast.parse(source_code, filename) logger.debug(ast.dump(tree, indent=4)) + module = compilation_context.module bpf_chunks = find_bpf_chunks(tree) for func_node in bpf_chunks: @@ -81,15 +83,18 @@ def processor(source_code, filename, module): if vmlinux_symtab: handler = VmlinuxHandler.initialize(vmlinux_symtab) VmlinuxHandlerRegistry.set_handler(handler) + compilation_context.vmlinux_handler = handler - populate_global_symbol_table(tree, module) - license_processing(tree, module) - globals_processing(tree, module) - structs_sym_tab = structs_proc(tree, module, bpf_chunks) - map_sym_tab = maps_proc(tree, module, bpf_chunks, structs_sym_tab) - func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab) + populate_global_symbol_table(tree, compilation_context) + license_processing(tree, compilation_context) + globals_processing(tree, compilation_context) + structs_sym_tab = structs_proc(tree, compilation_context, bpf_chunks) - globals_list_creation(tree, module) + map_sym_tab = maps_proc(tree, compilation_context, bpf_chunks) + + func_proc(tree, compilation_context, bpf_chunks) + + globals_list_creation(tree, compilation_context) return structs_sym_tab, map_sym_tab @@ -104,6 +109,8 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.INFO): module.data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128" module.triple = "bpf" + compilation_context = CompilationContext(module) + if not hasattr(module, "_debug_compile_unit"): debug_generator = DebugInfoGenerator(module) debug_generator.generate_file_metadata(filename, os.path.dirname(filename)) @@ -116,7 +123,7 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.INFO): True, ) - structs_sym_tab, maps_sym_tab = processor(source, filename, module) + structs_sym_tab, maps_sym_tab = processor(source, filename, compilation_context) wchar_size = module.add_metadata( [ diff --git a/pythonbpf/context.py b/pythonbpf/context.py new file mode 100644 index 00000000..5297889f --- /dev/null +++ b/pythonbpf/context.py @@ -0,0 +1,82 @@ +from llvmlite import ir +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pythonbpf.structs.struct_type import StructType + from pythonbpf.maps.maps_utils import MapSymbol + +logger = logging.getLogger(__name__) + + +class ScratchPoolManager: + """Manage the temporary helper variables in local_sym_tab""" + + def __init__(self): + self._counters = {} + + @property + def counter(self): + return sum(self._counters.values()) + + def reset(self): + self._counters.clear() + logger.debug("Scratch pool counter reset to 0") + + def _get_type_name(self, ir_type): + if isinstance(ir_type, ir.PointerType): + return "ptr" + elif isinstance(ir_type, ir.IntType): + return f"i{ir_type.width}" + elif isinstance(ir_type, ir.ArrayType): + return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]" + else: + return str(ir_type).replace(" ", "") + + def get_next_temp(self, local_sym_tab, expected_type=None): + # Default to i64 if no expected type provided + type_name = self._get_type_name(expected_type) if expected_type else "i64" + if type_name not in self._counters: + self._counters[type_name] = 0 + + counter = self._counters[type_name] + temp_name = f"__helper_temp_{type_name}_{counter}" + self._counters[type_name] += 1 + + if temp_name not in local_sym_tab: + raise ValueError( + f"Scratch pool exhausted or inadequate: {temp_name}. " + f"Type: {type_name} Counter: {counter}" + ) + + logger.debug(f"Using {temp_name} for type {type_name}") + return local_sym_tab[temp_name].var, temp_name + + +class CompilationContext: + """ + Holds the state for a single compilation run. + This replaces global mutable state modules. + """ + + def __init__(self, module: ir.Module): + self.module = module + + # Symbol tables + self.global_sym_tab: list[ir.GlobalVariable] = [] + self.structs_sym_tab: dict[str, "StructType"] = {} + self.map_sym_tab: dict[str, "MapSymbol"] = {} + + # Helper management + self.scratch_pool = ScratchPoolManager() + + # Vmlinux handling (optional, specialized) + self.vmlinux_handler = None # Can be VmlinuxHandler instance + + # Current function context (optional, if needed globally during function processing) + self.current_func = None + + def reset(self): + """Reset state between functions if necessary, though new context per compile is preferred.""" + self.scratch_pool.reset() + self.current_func = None diff --git a/pythonbpf/expr/call_registry.py b/pythonbpf/expr/call_registry.py index 858e23c4..2b3f8b3e 100644 --- a/pythonbpf/expr/call_registry.py +++ b/pythonbpf/expr/call_registry.py @@ -9,12 +9,8 @@ def set_handler(cls, handler): cls._handler = handler @classmethod - def handle_call( - cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab - ): + def handle_call(cls, call, compilation_context, builder, func, local_sym_tab): """Handle a call using the registered handler""" if cls._handler is None: return None - return cls._handler( - call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab - ) + return cls._handler(call, compilation_context, builder, func, local_sym_tab) diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index d34dff5c..2270fdbb 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -37,7 +37,7 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder raise SyntaxError(f"Undefined variable {expr.id}") -def _handle_constant_expr(module, builder, expr: ast.Constant): +def _handle_constant_expr(compilation_context, builder, expr: ast.Constant): """Handle ast.Constant expressions.""" if isinstance(expr.value, int) or isinstance(expr.value, bool): return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64) @@ -48,7 +48,9 @@ def _handle_constant_expr(module, builder, expr: ast.Constant): str_constant = ir.Constant(str_type, bytearray(str_bytes)) # Create global variable - global_str = ir.GlobalVariable(module, str_type, name=str_name) + global_str = ir.GlobalVariable( + compilation_context.module, str_type, name=str_name + ) global_str.linkage = "internal" global_str.global_constant = True global_str.initializer = str_constant @@ -64,10 +66,11 @@ def _handle_attribute_expr( func, expr: ast.Attribute, local_sym_tab: Dict, - structs_sym_tab: Dict, + compilation_context, builder: ir.IRBuilder, ): """Handle ast.Attribute expressions for struct field access.""" + structs_sym_tab = compilation_context.structs_sym_tab if isinstance(expr.value, ast.Name): var_name = expr.value.id attr_name = expr.attr @@ -157,9 +160,7 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde # ============================================================================ -def get_operand_value( - func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None -): +def get_operand_value(func, compilation_context, operand, builder, local_sym_tab): """Extract the value from an operand, handling variables and constants.""" logger.info(f"Getting operand value for: {ast.dump(operand)}") if isinstance(operand, ast.Name): @@ -187,13 +188,11 @@ def get_operand_value( raise TypeError(f"Unsupported constant type: {type(operand.value)}") elif isinstance(operand, ast.BinOp): res = _handle_binary_op_impl( - func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab + func, compilation_context, operand, builder, local_sym_tab ) return res else: - res = eval_expr( - func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab - ) + res = eval_expr(func, compilation_context, builder, operand, local_sym_tab) if res is None: raise ValueError(f"Failed to evaluate call expression: {operand}") val, _ = res @@ -205,15 +204,13 @@ def get_operand_value( raise TypeError(f"Unsupported operand type: {type(operand)}") -def _handle_binary_op_impl( - func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None -): +def _handle_binary_op_impl(func, compilation_context, rval, builder, local_sym_tab): op = rval.op left = get_operand_value( - func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab + func, compilation_context, rval.left, builder, local_sym_tab ) right = get_operand_value( - func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab + func, compilation_context, rval.right, builder, local_sym_tab ) logger.info(f"left is {left}, right is {right}, op is {op}") @@ -249,16 +246,14 @@ def _handle_binary_op_impl( def _handle_binary_op( func, - module, + compilation_context, rval, builder, var_name, local_sym_tab, - map_sym_tab, - structs_sym_tab=None, ): result = _handle_binary_op_impl( - func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab + func, compilation_context, rval, builder, local_sym_tab ) if var_name and var_name in local_sym_tab: logger.info( @@ -275,12 +270,10 @@ def _handle_binary_op( def _handle_ctypes_call( func, - module, + compilation_context, builder, expr, local_sym_tab, - map_sym_tab, - structs_sym_tab=None, ): """Handle ctypes type constructor calls.""" if len(expr.args) != 1: @@ -290,12 +283,10 @@ def _handle_ctypes_call( arg = expr.args[0] val = eval_expr( func, - module, + compilation_context, builder, arg, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) if val is None: logger.info("Failed to evaluate argument to ctypes constructor") @@ -344,9 +335,7 @@ def _handle_ctypes_call( return value, expected_type -def _handle_compare( - func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None -): +def _handle_compare(func, compilation_context, builder, cond, local_sym_tab): """Handle ast.Compare expressions.""" if len(cond.ops) != 1 or len(cond.comparators) != 1: @@ -354,21 +343,17 @@ def _handle_compare( return None lhs = eval_expr( func, - module, + compilation_context, builder, cond.left, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) rhs = eval_expr( func, - module, + compilation_context, builder, cond.comparators[0], local_sym_tab, - map_sym_tab, - structs_sym_tab, ) if lhs is None or rhs is None: @@ -382,12 +367,10 @@ def _handle_compare( def _handle_unary_op( func, - module, + compilation_context, builder, expr: ast.UnaryOp, local_sym_tab, - map_sym_tab, - structs_sym_tab=None, ): """Handle ast.UnaryOp expressions.""" if not isinstance(expr.op, ast.Not) and not isinstance(expr.op, ast.USub): @@ -395,7 +378,7 @@ def _handle_unary_op( return None operand = get_operand_value( - func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab + func, compilation_context, expr.operand, builder, local_sym_tab ) if operand is None: logger.error("Failed to evaluate operand for unary operation") @@ -418,7 +401,7 @@ def _handle_unary_op( # ============================================================================ -def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab): +def _handle_and_op(func, builder, expr, local_sym_tab, compilation_context): """Handle `and` boolean operations.""" logger.debug(f"Handling 'and' operator with {len(expr.values)} operands") @@ -433,7 +416,7 @@ def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_ # Evaluate current operand operand_result = eval_expr( - func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab + func, compilation_context, builder, value, local_sym_tab ) if operand_result is None: logger.error(f"Failed to evaluate operand {i} in 'and' expression") @@ -471,7 +454,7 @@ def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_ return phi, ir.IntType(1) -def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab): +def _handle_or_op(func, builder, expr, local_sym_tab, compilation_context): """Handle `or` boolean operations.""" logger.debug(f"Handling 'or' operator with {len(expr.values)} operands") @@ -486,7 +469,7 @@ def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_t # Evaluate current operand operand_result = eval_expr( - func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab + func, compilation_context, builder, value, local_sym_tab ) if operand_result is None: logger.error(f"Failed to evaluate operand {i} in 'or' expression") @@ -526,23 +509,17 @@ def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_t def _handle_boolean_op( func, - module, + compilation_context, builder, expr: ast.BoolOp, local_sym_tab, - map_sym_tab, - structs_sym_tab=None, ): """Handle `and` and `or` boolean operations.""" if isinstance(expr.op, ast.And): - return _handle_and_op( - func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab - ) + return _handle_and_op(func, builder, expr, local_sym_tab, compilation_context) elif isinstance(expr.op, ast.Or): - return _handle_or_op( - func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab - ) + return _handle_or_op(func, builder, expr, local_sym_tab, compilation_context) else: logger.error(f"Unsupported boolean operator: {type(expr.op).__name__}") return None @@ -555,12 +532,10 @@ def _handle_boolean_op( def _handle_vmlinux_cast( func, - module, + compilation_context, builder, expr, local_sym_tab, - map_sym_tab, - structs_sym_tab=None, ): # handle expressions such as struct_request(ctx.di) where struct_request is a vmlinux # struct and ctx.di is a pointer to a struct but is actually represented as a c_uint64 @@ -576,12 +551,10 @@ def _handle_vmlinux_cast( # Evaluate the argument (e.g., ctx.di which is a c_uint64) arg_result = eval_expr( func, - module, + compilation_context, builder, expr.args[0], local_sym_tab, - map_sym_tab, - structs_sym_tab, ) if arg_result is None: @@ -614,18 +587,17 @@ def _handle_vmlinux_cast( def _handle_user_defined_struct_cast( func, - module, + compilation_context, builder, expr, local_sym_tab, - map_sym_tab, - structs_sym_tab, ): """Handle user-defined struct cast expressions like iphdr(nh). This casts a pointer/integer value to a pointer to the user-defined struct, similar to how vmlinux struct casts work but for user-defined @struct types. """ + structs_sym_tab = compilation_context.structs_sym_tab if len(expr.args) != 1: logger.info("User-defined struct cast takes exactly one argument") return None @@ -643,12 +615,10 @@ def _handle_user_defined_struct_cast( # an address/pointer value) arg_result = eval_expr( func, - module, + compilation_context, builder, expr.args[0], local_sym_tab, - map_sym_tab, - structs_sym_tab, ) if arg_result is None: @@ -683,30 +653,28 @@ def _handle_user_defined_struct_cast( def eval_expr( func, - module, + compilation_context, builder, expr, local_sym_tab, - map_sym_tab, - structs_sym_tab=None, ): + structs_sym_tab = compilation_context.structs_sym_tab + logger.info(f"Evaluating expression: {ast.dump(expr)}") if isinstance(expr, ast.Name): return _handle_name_expr(expr, local_sym_tab, builder) elif isinstance(expr, ast.Constant): - return _handle_constant_expr(module, builder, expr) + return _handle_constant_expr(compilation_context, builder, expr) elif isinstance(expr, ast.Call): if isinstance(expr.func, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct( expr.func.id ): return _handle_vmlinux_cast( func, - module, + compilation_context, builder, expr, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) if isinstance(expr.func, ast.Name) and expr.func.id == "deref": return _handle_deref_call(expr, local_sym_tab, builder) @@ -714,26 +682,23 @@ def eval_expr( if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id): return _handle_ctypes_call( func, - module, + compilation_context, builder, expr, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) if isinstance(expr.func, ast.Name) and (expr.func.id in structs_sym_tab): return _handle_user_defined_struct_cast( func, - module, + compilation_context, builder, expr, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) + # NOTE: Updated handle_call signature result = CallHandlerRegistry.handle_call( - expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab + expr, compilation_context, builder, func, local_sym_tab ) if result is not None: return result @@ -742,30 +707,24 @@ def eval_expr( return None elif isinstance(expr, ast.Attribute): return _handle_attribute_expr( - func, expr, local_sym_tab, structs_sym_tab, builder + func, expr, local_sym_tab, compilation_context, builder ) elif isinstance(expr, ast.BinOp): return _handle_binary_op( func, - module, + compilation_context, expr, builder, None, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) elif isinstance(expr, ast.Compare): - return _handle_compare( - func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab - ) + return _handle_compare(func, compilation_context, builder, expr, local_sym_tab) elif isinstance(expr, ast.UnaryOp): - return _handle_unary_op( - func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab - ) + return _handle_unary_op(func, compilation_context, builder, expr, local_sym_tab) elif isinstance(expr, ast.BoolOp): return _handle_boolean_op( - func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab + func, compilation_context, builder, expr, local_sym_tab ) logger.info("Unsupported expression evaluation") return None @@ -773,12 +732,10 @@ def eval_expr( def handle_expr( func, - module, + compilation_context, builder, expr, local_sym_tab, - map_sym_tab, - structs_sym_tab, ): """Handle expression statements in the function body.""" logger.info(f"Handling expression: {ast.dump(expr)}") @@ -786,12 +743,10 @@ def handle_expr( if isinstance(call, ast.Call): eval_expr( func, - module, + compilation_context, builder, call, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) else: logger.info("Unsupported expression type") diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index f78ed923..b2d2fc8b 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -4,7 +4,6 @@ from pythonbpf.helper import ( HelperHandlerRegistry, - reset_scratch_pool, ) from pythonbpf.type_deducer import ctypes_to_ir from pythonbpf.expr import ( @@ -76,36 +75,30 @@ def count_temps_in_call(call_node, local_sym_tab): def handle_if_allocation( - module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab + compilation_context, builder, stmt, func, ret_type, local_sym_tab ): """Recursively handle allocations in if/else branches.""" if stmt.body: allocate_mem( - module, + compilation_context, builder, stmt.body, func, ret_type, - map_sym_tab, local_sym_tab, - structs_sym_tab, ) if stmt.orelse: allocate_mem( - module, + compilation_context, builder, stmt.orelse, func, ret_type, - map_sym_tab, local_sym_tab, - structs_sym_tab, ) -def allocate_mem( - module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab -): +def allocate_mem(compilation_context, builder, body, func, ret_type, local_sym_tab): max_temps_needed = {} def merge_type_counts(count_dict): @@ -137,19 +130,15 @@ def update_max_temps_for_stmt(stmt): # Handle allocations if isinstance(stmt, ast.If): handle_if_allocation( - module, + compilation_context, builder, stmt, func, ret_type, - map_sym_tab, local_sym_tab, - structs_sym_tab, ) elif isinstance(stmt, ast.Assign): - handle_assign_allocation( - builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab - ) + handle_assign_allocation(compilation_context, builder, stmt, local_sym_tab) allocate_temp_pool(builder, max_temps_needed, local_sym_tab) @@ -161,9 +150,7 @@ def update_max_temps_for_stmt(stmt): # ============================================================================ -def handle_assign( - func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab -): +def handle_assign(func, compilation_context, builder, stmt, local_sym_tab): """Handle assignment statements in the function body.""" # NOTE: Support multi-target assignments (e.g.: a, b = 1, 2) @@ -175,13 +162,11 @@ def handle_assign( var_name = target.id result = handle_variable_assignment( func, - module, + compilation_context, builder, var_name, rval, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) if not result: logger.error(f"Failed to handle assignment to {var_name}") @@ -191,13 +176,11 @@ def handle_assign( # NOTE: Struct field assignment case: pkt.field = value handle_struct_field_assignment( func, - module, + compilation_context, builder, target, rval, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) continue @@ -205,18 +188,12 @@ def handle_assign( logger.error(f"Unsupported assignment target: {ast.dump(target)}") -def handle_cond( - func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None -): - val = eval_expr( - func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab - )[0] +def handle_cond(func, compilation_context, builder, cond, local_sym_tab): + val = eval_expr(func, compilation_context, builder, cond, local_sym_tab)[0] return convert_to_bool(builder, val) -def handle_if( - func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab=None -): +def handle_if(func, compilation_context, builder, stmt, local_sym_tab): """Handle if statements in the function body.""" logger.info("Handling if statement") # start = builder.block.parent @@ -227,9 +204,7 @@ def handle_if( else: else_block = None - cond = handle_cond( - func, module, builder, stmt.test, local_sym_tab, map_sym_tab, structs_sym_tab - ) + cond = handle_cond(func, compilation_context, builder, stmt.test, local_sym_tab) if else_block: builder.cbranch(cond, then_block, else_block) else: @@ -237,9 +212,7 @@ def handle_if( builder.position_at_end(then_block) for s in stmt.body: - process_stmt( - func, module, builder, s, local_sym_tab, map_sym_tab, structs_sym_tab, False - ) + process_stmt(func, compilation_context, builder, s, local_sym_tab, False) if not builder.block.is_terminated: builder.branch(merge_block) @@ -248,12 +221,10 @@ def handle_if( for s in stmt.orelse: process_stmt( func, - module, + compilation_context, builder, s, local_sym_tab, - map_sym_tab, - structs_sym_tab, False, ) if not builder.block.is_terminated: @@ -262,21 +233,25 @@ def handle_if( builder.position_at_end(merge_block) -def handle_return(builder, stmt, local_sym_tab, ret_type): +def handle_return(builder, stmt, local_sym_tab, ret_type, compilation_context=None): logger.info(f"Handling return statement: {ast.dump(stmt)}") if stmt.value is None: return handle_none_return(builder) elif isinstance(stmt.value, ast.Name) and is_xdp_name(stmt.value.id): return handle_xdp_return(stmt, builder, ret_type) else: + # Fallback for now if ctx not passed, but caller should pass it + if compilation_context is None: + raise RuntimeError( + "CompilationContext required for return statement evaluation" + ) + val = eval_expr( func=None, - module=None, + compilation_context=compilation_context, builder=builder, expr=stmt.value, local_sym_tab=local_sym_tab, - map_sym_tab={}, - structs_sym_tab={}, ) logger.info(f"Evaluated return expression to {val}") builder.ret(val[0]) @@ -285,43 +260,34 @@ def handle_return(builder, stmt, local_sym_tab, ret_type): def process_stmt( func, - module, + compilation_context, builder, stmt, local_sym_tab, - map_sym_tab, - structs_sym_tab, did_return, ret_type=ir.IntType(64), ): logger.info(f"Processing statement: {ast.dump(stmt)}") - reset_scratch_pool() + # Use context scratch pool + compilation_context.scratch_pool.reset() + if isinstance(stmt, ast.Expr): handle_expr( func, - module, + compilation_context, builder, stmt, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) elif isinstance(stmt, ast.Assign): - handle_assign( - func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab - ) + handle_assign(func, compilation_context, builder, stmt, local_sym_tab) elif isinstance(stmt, ast.AugAssign): raise SyntaxError("Augmented assignment not supported") elif isinstance(stmt, ast.If): - handle_if( - func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab - ) + handle_if(func, compilation_context, builder, stmt, local_sym_tab) elif isinstance(stmt, ast.Return): did_return = handle_return( - builder, - stmt, - local_sym_tab, - ret_type, + builder, stmt, local_sym_tab, ret_type, compilation_context ) return did_return @@ -332,13 +298,11 @@ def process_stmt( def process_func_body( - module, + compilation_context, builder, func_node, func, ret_type, - map_sym_tab, - structs_sym_tab, ): """Process the body of a bpf function""" # TODO: A lot. We just have print -> bpf_trace_printk for now @@ -360,6 +324,9 @@ def process_func_body( raise TypeError( f"Unsupported annotation type: {ast.dump(context_arg.annotation)}" ) + + # Use context's handler if available, else usage of VmlinuxHandlerRegistry + # For now relying on VmlinuxHandlerRegistry which relies on codegen setting it if VmlinuxHandlerRegistry.is_vmlinux_struct(context_type_name): resolved_type = VmlinuxHandlerRegistry.get_struct_type( context_type_name @@ -370,14 +337,12 @@ def process_func_body( # pre-allocate dynamic variables local_sym_tab = allocate_mem( - module, + compilation_context, builder, func_node.body, func, ret_type, - map_sym_tab, local_sym_tab, - structs_sym_tab, ) logger.info(f"Local symbol table: {local_sym_tab.keys()}") @@ -385,12 +350,10 @@ def process_func_body( for stmt in func_node.body: did_return = process_stmt( func, - module, + compilation_context, builder, stmt, local_sym_tab, - map_sym_tab, - structs_sym_tab, did_return, ret_type, ) @@ -399,9 +362,12 @@ def process_func_body( builder.ret(ir.Constant(ir.IntType(64), 0)) -def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab): +def process_bpf_chunk(func_node, compilation_context, return_type): """Process a single BPF chunk (function) and emit corresponding LLVM IR.""" + # Set current function in context (optional but good for future) + compilation_context.current_func = func_node + func_name = func_node.name ret_type = return_type @@ -413,7 +379,7 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t param_types.append(ir.PointerType()) func_ty = ir.FunctionType(ret_type, param_types) - func = ir.Function(module, func_ty, func_name) + func = ir.Function(compilation_context.module, func_ty, func_name) func.linkage = "dso_local" func.attributes.add("nounwind") @@ -433,13 +399,11 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t builder = ir.IRBuilder(block) process_func_body( - module, + compilation_context, builder, func_node, func, ret_type, - map_sym_tab, - structs_sym_tab, ) return func @@ -449,23 +413,32 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t # ============================================================================ -def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab): +def func_proc(tree, compilation_context, chunks): + """Process all functions decorated with @bpf and @bpfglobal""" for func_node in chunks: + # Ignore structs and maps + # Check against the lists + if ( + func_node.name in compilation_context.structs_sym_tab + or func_node.name in compilation_context.map_sym_tab + ): + continue + + # Also check decorators to be sure + decorators = [d.id for d in func_node.decorator_list if isinstance(d, ast.Name)] + if "struct" in decorators or "map" in decorators: + continue + if is_global_function(func_node): continue func_type = get_probe_string(func_node) logger.info(f"Found probe_string of {func_node.name}: {func_type}") - func = process_bpf_chunk( - func_node, - module, - ctypes_to_ir(infer_return_type(func_node)), - map_sym_tab, - structs_sym_tab, - ) + return_type = ctypes_to_ir(infer_return_type(func_node)) + func = process_bpf_chunk(func_node, compilation_context, return_type) logger.info(f"Generating Debug Info for Function {func_node.name}") - generate_function_debug_info(func_node, module, func) + generate_function_debug_info(func_node, compilation_context.module, func) # TODO: WIP, for string assignment to fixed-size arrays diff --git a/pythonbpf/globals_pass.py b/pythonbpf/globals_pass.py index 1e977634..c9ac8fc0 100644 --- a/pythonbpf/globals_pass.py +++ b/pythonbpf/globals_pass.py @@ -7,11 +7,11 @@ logger: Logger = logging.getLogger(__name__) -# TODO: this is going to be a huge fuck of a headache in the future. -global_sym_tab = [] - -def populate_global_symbol_table(tree, module: ir.Module): +def populate_global_symbol_table(tree, compilation_context): + """ + compilation_context: CompilationContext + """ for node in tree.body: if isinstance(node, ast.FunctionDef): for dec in node.decorator_list: @@ -23,12 +23,12 @@ def populate_global_symbol_table(tree, module: ir.Module): and isinstance(dec.args[0], ast.Constant) and isinstance(dec.args[0].value, str) ): - global_sym_tab.append(node) + compilation_context.global_sym_tab.append(node) elif isinstance(dec, ast.Name) and dec.id == "bpfglobal": - global_sym_tab.append(node) + compilation_context.global_sym_tab.append(node) elif isinstance(dec, ast.Name) and dec.id == "map": - global_sym_tab.append(node) + compilation_context.global_sym_tab.append(node) return False @@ -74,9 +74,12 @@ def emit_global(module: ir.Module, node, name): return gvar -def globals_processing(tree, module): +def globals_processing(tree, compilation_context): """Process stuff decorated with @bpf and @bpfglobal except license and return the section name""" - globals_sym_tab = [] + # Local tracking for duplicate checking if needed, or we can iterate context + # But for now, we process specific nodes + + current_globals = [] for node in tree.body: # Skip non-assignment and non-function nodes @@ -90,10 +93,10 @@ def globals_processing(tree, module): continue # Check for duplicate names - if name in globals_sym_tab: + if name in current_globals: raise SyntaxError(f"ERROR: Global name '{name}' previously defined") else: - globals_sym_tab.append(name) + current_globals.append(name) if isinstance(node, ast.FunctionDef) and node.name != "LICENSE": decorators = [ @@ -108,7 +111,7 @@ def globals_processing(tree, module): node.body[0].value, (ast.Constant, ast.Name, ast.Call) ) ): - emit_global(module, node, name) + emit_global(compilation_context.module, node, name) else: raise SyntaxError(f"ERROR: Invalid syntax for {name} global") @@ -137,8 +140,9 @@ def emit_llvm_compiler_used(module: ir.Module, names: list[str]): gv.section = "llvm.metadata" -def globals_list_creation(tree, module: ir.Module): +def globals_list_creation(tree, compilation_context): collected = ["LICENSE"] + module = compilation_context.module for node in tree.body: if isinstance(node, ast.FunctionDef): diff --git a/pythonbpf/helper/__init__.py b/pythonbpf/helper/__init__.py index dcbfe24b..bd4fe174 100644 --- a/pythonbpf/helper/__init__.py +++ b/pythonbpf/helper/__init__.py @@ -1,5 +1,5 @@ from .helper_registry import HelperHandlerRegistry -from .helper_utils import reset_scratch_pool + from .bpf_helper_handler import ( handle_helper_call, emit_probe_read_kernel_str_call, @@ -28,9 +28,7 @@ def _register_helper_handler(): """Register helper call handler with the expression evaluator""" from pythonbpf.expr.expr_pass import CallHandlerRegistry - def helper_call_handler( - call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab - ): + def helper_call_handler(call, compilation_context, builder, func, local_sym_tab): """Check if call is a helper and handle it""" import ast @@ -39,17 +37,16 @@ def helper_call_handler( if HelperHandlerRegistry.has_handler(call.func.id): return handle_helper_call( call, - module, + compilation_context, builder, func, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) # Check for method calls (e.g., map.lookup()) elif isinstance(call.func, ast.Attribute): method_name = call.func.attr + map_sym_tab = compilation_context.map_sym_tab # Handle: my_map.lookup(key) if isinstance(call.func.value, ast.Name): @@ -58,12 +55,10 @@ def helper_call_handler( if HelperHandlerRegistry.has_handler(method_name): return handle_helper_call( call, - module, + compilation_context, builder, func, local_sym_tab, - map_sym_tab, - structs_sym_tab, ) return None @@ -76,7 +71,6 @@ def helper_call_handler( __all__ = [ "HelperHandlerRegistry", - "reset_scratch_pool", "handle_helper_call", "emit_probe_read_kernel_str_call", "emit_probe_read_kernel_call", diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index e59898f1..73ff7a37 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -50,12 +50,10 @@ class BPFHelperID(Enum): def bpf_ktime_get_ns_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_ktime_get_ns helper function call. @@ -77,12 +75,10 @@ def bpf_ktime_get_ns_emitter( def bpf_get_current_cgroup_id( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_get_current_cgroup_id helper function call. @@ -104,12 +100,10 @@ def bpf_get_current_cgroup_id( def bpf_map_lookup_elem_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_map_lookup_elem helper function call. @@ -119,7 +113,7 @@ def bpf_map_lookup_elem_emitter( f"Map lookup expects exactly one argument (key), got {len(call.args)}" ) key_ptr = get_or_create_ptr_from_arg( - func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab + func, compilation_context, call.args[0], builder, local_sym_tab ) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) @@ -147,12 +141,10 @@ def bpf_map_lookup_elem_emitter( def bpf_printk_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """Emit LLVM IR for bpf_printk helper function call.""" if not hasattr(func, "_fmt_counter"): @@ -165,16 +157,18 @@ def bpf_printk_emitter( if isinstance(call.args[0], ast.JoinedStr): args = handle_fstring_print( call.args[0], - module, + compilation_context.module, builder, func, local_sym_tab, - struct_sym_tab, + compilation_context.structs_sym_tab, ) elif isinstance(call.args[0], ast.Constant) and isinstance(call.args[0].value, str): # TODO: We are only supporting single arguments for now. # In case of multiple args, the first one will be taken. - args = simple_string_print(call.args[0].value, module, builder, func) + args = simple_string_print( + call.args[0].value, compilation_context.module, builder, func + ) else: raise NotImplementedError( "Only simple strings or f-strings are supported in bpf_printk." @@ -203,12 +197,10 @@ def bpf_printk_emitter( def bpf_map_update_elem_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_map_update_elem helper function call. @@ -224,10 +216,10 @@ def bpf_map_update_elem_emitter( flags_arg = call.args[2] if len(call.args) > 2 else None key_ptr = get_or_create_ptr_from_arg( - func, module, key_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab + func, compilation_context, key_arg, builder, local_sym_tab ) value_ptr = get_or_create_ptr_from_arg( - func, module, value_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab + func, compilation_context, value_arg, builder, local_sym_tab ) flags_val = get_flags_val(flags_arg, builder, local_sym_tab) @@ -262,12 +254,10 @@ def bpf_map_update_elem_emitter( def bpf_map_delete_elem_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_map_delete_elem helper function call. @@ -278,7 +268,7 @@ def bpf_map_delete_elem_emitter( f"Map delete expects exactly one argument (key), got {len(call.args)}" ) key_ptr = get_or_create_ptr_from_arg( - func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab + func, compilation_context, call.args[0], builder, local_sym_tab ) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) @@ -306,12 +296,10 @@ def bpf_map_delete_elem_emitter( def bpf_get_current_comm_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_get_current_comm helper function call. @@ -327,7 +315,7 @@ def bpf_get_current_comm_emitter( # Extract buffer pointer and size buf_ptr, buf_size = get_buffer_ptr_and_size( - buf_arg, builder, local_sym_tab, struct_sym_tab + buf_arg, builder, local_sym_tab, compilation_context ) # Validate it's a char array @@ -367,12 +355,10 @@ def bpf_get_current_comm_emitter( def bpf_get_current_pid_tgid_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_get_current_pid_tgid helper function call. @@ -394,12 +380,10 @@ def bpf_get_current_pid_tgid_emitter( def bpf_perf_event_output_handler( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_perf_event_output helper function call. @@ -412,7 +396,9 @@ def bpf_perf_event_output_handler( data_arg = call.args[0] ctx_ptr = func.args[0] # First argument to the function is ctx - data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab) + data_ptr, size_val = get_data_ptr_and_size( + data_arg, local_sym_tab, compilation_context.structs_sym_tab + ) # BPF_F_CURRENT_CPU is -1 in 32 bit flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF) @@ -445,12 +431,10 @@ def bpf_perf_event_output_handler( def bpf_ringbuf_output_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_ringbuf_output helper function call. @@ -461,7 +445,9 @@ def bpf_ringbuf_output_emitter( f"Ringbuf output expects exactly one argument, got {len(call.args)}" ) data_arg = call.args[0] - data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab) + data_ptr, size_val = get_data_ptr_and_size( + data_arg, local_sym_tab, compilation_context.structs_sym_tab + ) flags_val = ir.Constant(ir.IntType(64), 0) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) @@ -496,38 +482,32 @@ def bpf_ringbuf_output_emitter( def handle_output_helper( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Route output helper to the appropriate emitter based on map type. """ - match map_sym_tab[map_ptr.name].type: + match compilation_context.map_sym_tab[map_ptr.name].type: case BPFMapType.PERF_EVENT_ARRAY: return bpf_perf_event_output_handler( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab, - struct_sym_tab, - map_sym_tab, ) case BPFMapType.RINGBUF: return bpf_ringbuf_output_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab, - struct_sym_tab, - map_sym_tab, ) case _: logger.error("Unsupported map type for output helper.") @@ -572,12 +552,10 @@ def emit_probe_read_kernel_str_call(builder, dst_ptr, dst_size, src_ptr): def bpf_probe_read_kernel_str_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """Emit LLVM IR for bpf_probe_read_kernel_str helper.""" @@ -588,12 +566,12 @@ def bpf_probe_read_kernel_str_emitter( # Get destination buffer (char array -> i8*) dst_ptr, dst_size = get_or_create_ptr_from_arg( - func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab + func, compilation_context, call.args[0], builder, local_sym_tab ) # Get source pointer (evaluate expression) src_ptr, src_type = get_ptr_from_arg( - call.args[1], func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab + call.args[1], func, compilation_context, builder, local_sym_tab ) # Emit the helper call @@ -641,12 +619,10 @@ def emit_probe_read_kernel_call(builder, dst_ptr, dst_size, src_ptr): def bpf_probe_read_kernel_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """Emit LLVM IR for bpf_probe_read_kernel helper.""" @@ -657,12 +633,12 @@ def bpf_probe_read_kernel_emitter( # Get destination buffer (char array -> i8*) dst_ptr, dst_size = get_or_create_ptr_from_arg( - func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab + func, compilation_context, call.args[0], builder, local_sym_tab ) # Get source pointer (evaluate expression) src_ptr, src_type = get_ptr_from_arg( - call.args[1], func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab + call.args[1], func, compilation_context, builder, local_sym_tab ) # Emit the helper call @@ -680,12 +656,10 @@ def bpf_probe_read_kernel_emitter( def bpf_get_prandom_u32_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_get_prandom_u32 helper function call. @@ -710,12 +684,10 @@ def bpf_get_prandom_u32_emitter( def bpf_probe_read_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_probe_read helper function @@ -726,31 +698,25 @@ def bpf_probe_read_emitter( return dst_ptr = get_or_create_ptr_from_arg( func, - module, + compilation_context, call.args[0], builder, local_sym_tab, - map_sym_tab, - struct_sym_tab, ir.IntType(8), ) size_val = get_int_value_from_arg( call.args[1], func, - module, + compilation_context, builder, local_sym_tab, - map_sym_tab, - struct_sym_tab, ) src_ptr = get_or_create_ptr_from_arg( func, - module, + compilation_context, call.args[2], builder, local_sym_tab, - map_sym_tab, - struct_sym_tab, ir.IntType(8), ) fn_type = ir.FunctionType( @@ -783,12 +749,10 @@ def bpf_probe_read_emitter( def bpf_get_smp_processor_id_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_get_smp_processor_id helper function call. @@ -810,12 +774,10 @@ def bpf_get_smp_processor_id_emitter( def bpf_get_current_uid_gid_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_get_current_uid_gid helper function call. @@ -846,12 +808,10 @@ def bpf_get_current_uid_gid_emitter( def bpf_skb_store_bytes_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_skb_store_bytes helper function call. @@ -875,30 +835,24 @@ def bpf_skb_store_bytes_emitter( offset_val = get_int_value_from_arg( call.args[0], func, - module, + compilation_context, builder, local_sym_tab, - map_sym_tab, - struct_sym_tab, ) from_ptr = get_or_create_ptr_from_arg( func, - module, + compilation_context, call.args[1], builder, local_sym_tab, - map_sym_tab, - struct_sym_tab, args_signature[2], ) len_val = get_int_value_from_arg( call.args[2], func, - module, + compilation_context, builder, local_sym_tab, - map_sym_tab, - struct_sym_tab, ) if len(call.args) == 4: flags_val = get_flags_val(call.args[3], builder, local_sym_tab) @@ -940,12 +894,10 @@ def bpf_skb_store_bytes_emitter( def bpf_ringbuf_reserve_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_ringbuf_reserve helper function call. @@ -960,11 +912,9 @@ def bpf_ringbuf_reserve_emitter( size_val = get_int_value_from_arg( call.args[0], func, - module, + compilation_context, builder, local_sym_tab, - map_sym_tab, - struct_sym_tab, ) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) @@ -991,12 +941,10 @@ def bpf_ringbuf_reserve_emitter( def bpf_ringbuf_submit_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_ringbuf_submit helper function call. @@ -1013,12 +961,10 @@ def bpf_ringbuf_submit_emitter( data_ptr = get_or_create_ptr_from_arg( func, - module, + compilation_context, data_arg, builder, local_sym_tab, - map_sym_tab, - struct_sym_tab, ir.PointerType(ir.IntType(8)), ) @@ -1050,12 +996,10 @@ def bpf_ringbuf_submit_emitter( def bpf_get_stack_emitter( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab=None, - struct_sym_tab=None, - map_sym_tab=None, ): """ Emit LLVM IR for bpf_get_stack helper function call. @@ -1068,7 +1012,7 @@ def bpf_get_stack_emitter( buf_arg = call.args[0] flags_arg = call.args[1] if len(call.args) == 2 else None buf_ptr, buf_size = get_buffer_ptr_and_size( - buf_arg, builder, local_sym_tab, struct_sym_tab + buf_arg, builder, local_sym_tab, compilation_context ) flags_val = get_flags_val(flags_arg, builder, local_sym_tab) if isinstance(flags_val, int): @@ -1098,12 +1042,10 @@ def bpf_get_stack_emitter( def handle_helper_call( call, - module, + compilation_context, builder, func, local_sym_tab=None, - map_sym_tab=None, - struct_sym_tab=None, ): """Process a BPF helper function call and emit the appropriate LLVM IR.""" @@ -1117,14 +1059,14 @@ def invoke_helper(method_name, map_ptr=None): return handler( call, map_ptr, - module, + compilation_context, builder, func, local_sym_tab, - struct_sym_tab, - map_sym_tab, ) + map_sym_tab = compilation_context.map_sym_tab + # Handle direct function calls (e.g., print(), ktime()) if isinstance(call.func, ast.Name): return invoke_helper(call.func.id) diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index d6a76e02..fa26a2b3 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -3,7 +3,6 @@ from llvmlite import ir from pythonbpf.expr import ( - get_operand_value, eval_expr, access_struct_field, ) @@ -11,56 +10,38 @@ logger = logging.getLogger(__name__) -class ScratchPoolManager: - """Manage the temporary helper variables in local_sym_tab""" +# NOTE: ScratchPoolManager is now in context.py - def __init__(self): - self._counters = {} - @property - def counter(self): - return sum(self._counters.values()) - - def reset(self): - self._counters.clear() - logger.debug("Scratch pool counter reset to 0") - - def _get_type_name(self, ir_type): - if isinstance(ir_type, ir.PointerType): - return "ptr" - elif isinstance(ir_type, ir.IntType): - return f"i{ir_type.width}" - elif isinstance(ir_type, ir.ArrayType): - return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]" - else: - return str(ir_type).replace(" ", "") - - def get_next_temp(self, local_sym_tab, expected_type=None): - # Default to i64 if no expected type provided - type_name = self._get_type_name(expected_type) if expected_type else "i64" - if type_name not in self._counters: - self._counters[type_name] = 0 - - counter = self._counters[type_name] - temp_name = f"__helper_temp_{type_name}_{counter}" - self._counters[type_name] += 1 - - if temp_name not in local_sym_tab: - raise ValueError( - f"Scratch pool exhausted or inadequate: {temp_name}. " - f"Type: {type_name} Counter: {counter}" - ) - - logger.debug(f"Using {temp_name} for type {type_name}") - return local_sym_tab[temp_name].var, temp_name - - -_temp_pool_manager = ScratchPoolManager() # Singleton instance +def get_ptr_from_arg(arg, compilation_context, builder, local_sym_tab): + """Helper to get a pointer value from an argument.""" + # This is a bit duplicative of logic in eval_expr but simplified for helpers + # We might need to handle more cases here or defer to eval_expr + # Simple check for name + if isinstance(arg, ast.Name): + if arg.id in local_sym_tab: + sym = local_sym_tab[arg.id] + if isinstance(sym.ir_type, ir.PointerType): + return builder.load(sym.var) + # If it's an array/struct we might need GEP depending on how it was allocated + # For now assume load returns the pointer/value + return builder.load(sym.var) + + # Use eval_expr for general case + val = eval_expr( + None, + compilation_context.module, + builder, + arg, + local_sym_tab, + compilation_context.map_sym_tab, + compilation_context.structs_sym_tab, + ) + if val and isinstance(val[0].type, ir.PointerType): + return val[0] -def reset_scratch_pool(): - """Reset the scratch pool counter""" - _temp_pool_manager.reset() + return None # ============================================================================ @@ -75,11 +56,15 @@ def get_var_ptr_from_name(var_name, local_sym_tab): raise ValueError(f"Variable '{var_name}' not found in local symbol table") -def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): +def create_int_constant_ptr( + value, builder, compilation_context, local_sym_tab, int_width=64 +): """Create a pointer to an integer constant.""" int_type = ir.IntType(int_width) - ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, int_type) + ptr, temp_name = compilation_context.scratch_pool.get_next_temp( + local_sym_tab, int_type + ) logger.info(f"Using temp variable '{temp_name}' for int constant {value}") const_val = ir.Constant(int_type, value) builder.store(const_val, ptr) @@ -88,12 +73,10 @@ def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): def get_or_create_ptr_from_arg( func, - module, + compilation_context, arg, builder, local_sym_tab, - map_sym_tab, - struct_sym_tab=None, expected_type=None, ): """Extract or create pointer from the call arguments.""" @@ -102,16 +85,22 @@ def get_or_create_ptr_from_arg( sz = None if isinstance(arg, ast.Name): # Stack space is already allocated - ptr = get_var_ptr_from_name(arg.id, local_sym_tab) + if arg.id in local_sym_tab: + ptr = local_sym_tab[arg.id].var + else: + raise ValueError(f"Variable '{arg.id}' not found") elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): int_width = 64 # Default to i64 if expected_type and isinstance(expected_type, ir.IntType): int_width = expected_type.width - ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width) + ptr = create_int_constant_ptr( + arg.value, builder, compilation_context, local_sym_tab, int_width + ) elif isinstance(arg, ast.Attribute): # A struct field struct_name = arg.value.id field_name = arg.attr + struct_sym_tab = compilation_context.structs_sym_tab if not local_sym_tab or struct_name not in local_sym_tab: raise ValueError(f"Struct '{struct_name}' not found") @@ -136,7 +125,7 @@ def get_or_create_ptr_from_arg( and field_type.element.width == 8 ): ptr, sz = get_char_array_ptr_and_size( - arg, builder, local_sym_tab, struct_sym_tab, func + arg, builder, local_sym_tab, compilation_context, func ) if not ptr: raise ValueError("Failed to get char array pointer from struct field") @@ -146,13 +135,15 @@ def get_or_create_ptr_from_arg( else: # NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop # Evaluate the expression and store the result in a temp variable - val = get_operand_value( - func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab - ) + val = eval_expr(func, compilation_context, builder, arg, local_sym_tab) + if val: + val = val[0] if val is None: raise ValueError("Failed to evaluate expression for helper arg.") - ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, expected_type) + ptr, temp_name = compilation_context.scratch_pool.get_next_temp( + local_sym_tab, expected_type + ) logger.info(f"Using temp variable '{temp_name}' for expression result") if ( isinstance(val.type, ir.IntType) @@ -188,8 +179,9 @@ def get_flags_val(arg, builder, local_sym_tab): ) -def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab): +def get_data_ptr_and_size(data_arg, local_sym_tab, compilation_context): """Extract data pointer and size information for perf event output.""" + struct_sym_tab = compilation_context.structs_sym_tab if isinstance(data_arg, ast.Name): data_name = data_arg.id if local_sym_tab and data_name in local_sym_tab: @@ -213,8 +205,9 @@ def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab): ) -def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab): +def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, compilation_context): """Extract buffer pointer and size from either a struct field or variable.""" + struct_sym_tab = compilation_context.structs_sym_tab # Case 1: Struct field (obj.field) if isinstance(buf_arg, ast.Attribute): @@ -268,9 +261,10 @@ def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab): def get_char_array_ptr_and_size( - buf_arg, builder, local_sym_tab, struct_sym_tab, func=None + buf_arg, builder, local_sym_tab, compilation_context, func=None ): """Get pointer to char array and its size.""" + struct_sym_tab = compilation_context.structs_sym_tab # Struct field: obj.field if isinstance(buf_arg, ast.Attribute) and isinstance(buf_arg.value, ast.Name): @@ -351,34 +345,10 @@ def _is_char_array(ir_type): ) -def get_ptr_from_arg( - arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab -): - """Evaluate argument and return pointer value""" - - result = eval_expr( - func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab - ) - - if not result: - raise ValueError("Failed to evaluate argument") - - val, val_type = result - - if not isinstance(val_type, ir.PointerType): - raise ValueError(f"Expected pointer type, got {val_type}") - - return val, val_type - - -def get_int_value_from_arg( - arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab -): +def get_int_value_from_arg(arg, func, compilation_context, builder, local_sym_tab): """Evaluate argument and return integer value""" - result = eval_expr( - func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab - ) + result = eval_expr(func, compilation_context, builder, arg, local_sym_tab) if not result: raise ValueError("Failed to evaluate argument") diff --git a/pythonbpf/license_pass.py b/pythonbpf/license_pass.py index c3d3dc0c..425b1979 100644 --- a/pythonbpf/license_pass.py +++ b/pythonbpf/license_pass.py @@ -23,7 +23,7 @@ def emit_license(module: ir.Module, license_str: str): return gvar -def license_processing(tree, module): +def license_processing(tree, compilation_context): """Process the LICENSE function decorated with @bpf and @bpfglobal and return the section name""" count = 0 for node in tree.body: @@ -42,12 +42,14 @@ def license_processing(tree, module): and isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str) ): - emit_license(module, node.body[0].value.value) + emit_license( + compilation_context.module, node.body[0].value.value + ) return "LICENSE" else: - logger.info("ERROR: LICENSE() must return a string literal") - return None + raise SyntaxError( + "ERROR: LICENSE() must return a string literal" + ) else: - logger.info("ERROR: LICENSE already defined") - return None + raise SyntaxError("ERROR: Multiple LICENSE globals defined") return None diff --git a/pythonbpf/maps/maps_pass.py b/pythonbpf/maps/maps_pass.py index 2d0beb9d..e02efb1c 100644 --- a/pythonbpf/maps/maps_pass.py +++ b/pythonbpf/maps/maps_pass.py @@ -12,14 +12,14 @@ logger: Logger = logging.getLogger(__name__) -def maps_proc(tree, module, chunks, structs_sym_tab): +def maps_proc(tree, compilation_context, chunks): """Process all functions decorated with @map to find BPF maps""" - map_sym_tab = {} + map_sym_tab = compilation_context.map_sym_tab for func_node in chunks: if is_map(func_node): logger.info(f"Found BPF map: {func_node.name}") map_sym_tab[func_node.name] = process_bpf_map( - func_node, module, structs_sym_tab + func_node, compilation_context ) return map_sym_tab @@ -51,11 +51,11 @@ def create_bpf_map(module, map_name, map_params): return MapSymbol(type=map_params["type"], sym=map_global, params=map_params) -def _parse_map_params(rval, expected_args=None): +def _parse_map_params(rval, compilation_context, expected_args=None): """Parse map parameters from call arguments and keywords.""" params = {} - handler = VmlinuxHandlerRegistry.get_handler() + handler = compilation_context.vmlinux_handler # Parse positional arguments if expected_args: for i, arg_name in enumerate(expected_args): @@ -83,12 +83,23 @@ def _get_vmlinux_enum(handler, name): if handler and handler.is_vmlinux_enum(name): return handler.get_vmlinux_enum_value(name) + # Fallback to VmlinuxHandlerRegistry if handler invalid + # This is for backward compatibility or if refactoring isn't complete + if ( + VmlinuxHandlerRegistry.get_handler() + and VmlinuxHandlerRegistry.get_handler().is_vmlinux_enum(name) + ): + return VmlinuxHandlerRegistry.get_handler().get_vmlinux_enum_value(name) + return None + @MapProcessorRegistry.register("RingBuffer") -def process_ringbuf_map(map_name, rval, module, structs_sym_tab): +def process_ringbuf_map(map_name, rval, compilation_context): """Process a BPF_RINGBUF map declaration""" logger.info(f"Processing Ringbuf: {map_name}") - map_params = _parse_map_params(rval, expected_args=["max_entries"]) + map_params = _parse_map_params( + rval, compilation_context, expected_args=["max_entries"] + ) map_params["type"] = BPFMapType.RINGBUF # NOTE: constraints borrowed from https://docs.ebpf.io/linux/map-type/BPF_MAP_TYPE_RINGBUF/ @@ -104,42 +115,62 @@ def process_ringbuf_map(map_name, rval, module, structs_sym_tab): logger.info(f"Ringbuf map parameters: {map_params}") - map_global = create_bpf_map(module, map_name, map_params) + map_global = create_bpf_map(compilation_context.module, map_name, map_params) create_ringbuf_debug_info( - module, map_global.sym, map_name, map_params, structs_sym_tab + compilation_context.module, + map_global.sym, + map_name, + map_params, + compilation_context.structs_sym_tab, ) return map_global @MapProcessorRegistry.register("HashMap") -def process_hash_map(map_name, rval, module, structs_sym_tab): +def process_hash_map(map_name, rval, compilation_context): """Process a BPF_HASH map declaration""" logger.info(f"Processing HashMap: {map_name}") - map_params = _parse_map_params(rval, expected_args=["key", "value", "max_entries"]) + map_params = _parse_map_params( + rval, compilation_context, expected_args=["key", "value", "max_entries"] + ) map_params["type"] = BPFMapType.HASH logger.info(f"Map parameters: {map_params}") - map_global = create_bpf_map(module, map_name, map_params) + map_global = create_bpf_map(compilation_context.module, map_name, map_params) # Generate debug info for BTF - create_map_debug_info(module, map_global.sym, map_name, map_params, structs_sym_tab) + create_map_debug_info( + compilation_context.module, + map_global.sym, + map_name, + map_params, + compilation_context.structs_sym_tab, + ) return map_global @MapProcessorRegistry.register("PerfEventArray") -def process_perf_event_map(map_name, rval, module, structs_sym_tab): +def process_perf_event_map(map_name, rval, compilation_context): """Process a BPF_PERF_EVENT_ARRAY map declaration""" logger.info(f"Processing PerfEventArray: {map_name}") - map_params = _parse_map_params(rval, expected_args=["key_size", "value_size"]) + map_params = _parse_map_params( + rval, compilation_context, expected_args=["key_size", "value_size"] + ) map_params["type"] = BPFMapType.PERF_EVENT_ARRAY logger.info(f"Map parameters: {map_params}") - map_global = create_bpf_map(module, map_name, map_params) + map_global = create_bpf_map(compilation_context.module, map_name, map_params) # Generate debug info for BTF - create_map_debug_info(module, map_global.sym, map_name, map_params, structs_sym_tab) + create_map_debug_info( + compilation_context.module, + map_global.sym, + map_name, + map_params, + compilation_context.structs_sym_tab, + ) return map_global -def process_bpf_map(func_node, module, structs_sym_tab): +def process_bpf_map(func_node, compilation_context): """Process a BPF map (a function decorated with @map)""" map_name = func_node.name logger.info(f"Processing BPF map: {map_name}") @@ -158,9 +189,9 @@ def process_bpf_map(func_node, module, structs_sym_tab): if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name): handler = MapProcessorRegistry.get_processor(rval.func.id) if handler: - return handler(map_name, rval, module, structs_sym_tab) + return handler(map_name, rval, compilation_context) else: logger.warning(f"Unknown map type {rval.func.id}, defaulting to HashMap") - return process_hash_map(map_name, rval, module) + return process_hash_map(map_name, rval, compilation_context) else: raise ValueError("Function under @map must return a map") diff --git a/pythonbpf/structs/structs_pass.py b/pythonbpf/structs/structs_pass.py index d79fe0e8..fe960ee7 100644 --- a/pythonbpf/structs/structs_pass.py +++ b/pythonbpf/structs/structs_pass.py @@ -14,14 +14,17 @@ # Shall we just int64, int32 and uint32 similarly? -def structs_proc(tree, module, chunks): +def structs_proc(tree, compilation_context, chunks): """Process all class definitions to find BPF structs""" - structs_sym_tab = {} + # Use the context's symbol table + structs_sym_tab = compilation_context.structs_sym_tab + for cls_node in chunks: if is_bpf_struct(cls_node): logger.info(f"Found BPF struct: {cls_node.name}") - struct_info = process_bpf_struct(cls_node, module) + struct_info = process_bpf_struct(cls_node, compilation_context) structs_sym_tab[cls_node.name] = struct_info + return structs_sym_tab @@ -32,7 +35,7 @@ def is_bpf_struct(cls_node): ) -def process_bpf_struct(cls_node, module): +def process_bpf_struct(cls_node, compilation_context): """Process a single BPF struct definition""" fields = parse_struct_fields(cls_node)