Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 33 additions & 28 deletions pythonbpf/allocation_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def create_targets_and_rvals(stmt):
return stmt.targets, [stmt.value]


def handle_assign_allocation(
builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab
):
def handle_assign_allocation(compilation_context, builder, stmt, local_sym_tab):
"""Handle memory allocation for assignment statements."""

logger.info(f"Handling assignment for allocation: {ast.dump(stmt)}")
Expand Down Expand Up @@ -59,7 +57,7 @@ def handle_assign_allocation(
# Determine type and allocate based on rval
if isinstance(rval, ast.Call):
_allocate_for_call(
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
builder, var_name, rval, local_sym_tab, compilation_context
)
elif isinstance(rval, ast.Constant):
_allocate_for_constant(builder, var_name, rval, local_sym_tab)
Expand All @@ -71,18 +69,17 @@ def handle_assign_allocation(
elif isinstance(rval, ast.Attribute):
# Struct field-to-variable assignment (a = dat.fld)
_allocate_for_attribute(
builder, var_name, rval, local_sym_tab, structs_sym_tab
builder, var_name, rval, local_sym_tab, compilation_context
)
else:
logger.warning(
f"Unsupported assignment value type for {var_name}: {type(rval).__name__}"
)


def _allocate_for_call(
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
):
def _allocate_for_call(builder, var_name, rval, local_sym_tab, compilation_context):
"""Allocate memory for variable assigned from a call."""
structs_sym_tab = compilation_context.structs_sym_tab

if isinstance(rval.func, ast.Name):
call_type = rval.func.id
Expand Down Expand Up @@ -149,17 +146,19 @@ def _allocate_for_call(
elif isinstance(rval.func, ast.Attribute):
# Map method calls - need double allocation for ptr handling
_allocate_for_map_method(
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
builder, var_name, rval, local_sym_tab, compilation_context
)

else:
logger.warning(f"Unsupported call function type for {var_name}")


def _allocate_for_map_method(
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
builder, var_name, rval, local_sym_tab, compilation_context
):
"""Allocate memory for variable assigned from map method (double alloc)."""
map_sym_tab = compilation_context.map_sym_tab
structs_sym_tab = compilation_context.structs_sym_tab

map_name = rval.func.value.id
method_name = rval.func.attr
Expand Down Expand Up @@ -299,6 +298,15 @@ def allocate_temp_pool(builder, max_temps, local_sym_tab):
logger.debug(f"Allocated temp variable: {temp_name}")


def _get_alignment(tmp_type):
"""Return alignment for a given type."""
if isinstance(tmp_type, ir.PointerType):
return 8
elif isinstance(tmp_type, ir.IntType):
return tmp_type.width // 8
return 8


def _allocate_for_name(builder, var_name, rval, local_sym_tab):
"""Allocate memory for variable-to-variable assignment (b = a)."""
source_var = rval.id
Expand All @@ -321,8 +329,22 @@ def _allocate_for_name(builder, var_name, rval, local_sym_tab):
)


def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_tab):
def _allocate_with_type(builder, var_name, ir_type):
"""Allocate memory for a variable with a specific type."""
var = builder.alloca(ir_type, name=var_name)
if isinstance(ir_type, ir.IntType):
var.align = ir_type.width // 8
elif isinstance(ir_type, ir.PointerType):
var.align = 8
return var


def _allocate_for_attribute(
builder, var_name, rval, local_sym_tab, compilation_context
):
"""Allocate memory for struct field-to-variable assignment (a = dat.fld)."""
structs_sym_tab = compilation_context.structs_sym_tab

if not isinstance(rval.value, ast.Name):
logger.warning(f"Complex attribute access not supported for {var_name}")
return
Expand Down Expand Up @@ -455,20 +477,3 @@ def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_
logger.info(
f"Pre-allocated {var_name} from {struct_var}.{field_name} with type {alloc_type}"
)


def _allocate_with_type(builder, var_name, ir_type):
"""Allocate variable with appropriate alignment for type."""
var = builder.alloca(ir_type, name=var_name)
var.align = _get_alignment(ir_type)
return var


def _get_alignment(ir_type):
"""Get appropriate alignment for IR type."""
if isinstance(ir_type, ir.IntType):
return ir_type.width // 8
elif isinstance(ir_type, ir.ArrayType) and isinstance(ir_type.element, ir.IntType):
return ir_type.element.width // 8
else:
return 8 # Default: pointer size
24 changes: 9 additions & 15 deletions pythonbpf/assign_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def handle_struct_field_assignment(
func, module, builder, target, rval, local_sym_tab, map_sym_tab, structs_sym_tab
func, compilation_context, builder, target, rval, local_sym_tab
):
"""Handle struct field assignment (obj.field = value)."""

Expand All @@ -24,7 +24,7 @@ def handle_struct_field_assignment(
return

struct_type = local_sym_tab[var_name].metadata
struct_info = structs_sym_tab[struct_type]
struct_info = compilation_context.structs_sym_tab[struct_type]

if field_name not in struct_info.fields:
logger.error(f"Field '{field_name}' not found in struct '{struct_type}'")
Expand All @@ -33,9 +33,7 @@ def handle_struct_field_assignment(
# Get field pointer and evaluate value
field_ptr = struct_info.gep(builder, local_sym_tab[var_name].var, field_name)
field_type = struct_info.field_type(field_name)
val_result = eval_expr(
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
)
val_result = eval_expr(func, compilation_context, builder, rval, local_sym_tab)

if val_result is None:
logger.error(f"Failed to evaluate value for {var_name}.{field_name}")
Expand All @@ -47,14 +45,12 @@ def handle_struct_field_assignment(
if _is_char_array(field_type) and _is_i8_ptr(val_type):
_copy_string_to_char_array(
func,
module,
compilation_context,
builder,
val,
field_ptr,
field_type,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
logger.info(f"Copied string to char array {var_name}.{field_name}")
return
Expand All @@ -66,14 +62,12 @@ def handle_struct_field_assignment(

def _copy_string_to_char_array(
func,
module,
compilation_context,
builder,
src_ptr,
dst_ptr,
array_type,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
):
"""Copy string (i8*) to char array ([N x i8]) using bpf_probe_read_kernel_str"""

Expand Down Expand Up @@ -109,7 +103,7 @@ def _is_i8_ptr(ir_type):


def handle_variable_assignment(
func, module, builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
func, compilation_context, builder, var_name, rval, local_sym_tab
):
"""Handle single named variable assignment."""

Expand All @@ -120,6 +114,8 @@ def handle_variable_assignment(
var_ptr = local_sym_tab[var_name].var
var_type = local_sym_tab[var_name].ir_type

structs_sym_tab = compilation_context.structs_sym_tab

# NOTE: Special case for struct initialization
if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name):
struct_name = rval.func.id
Expand All @@ -142,9 +138,7 @@ def handle_variable_assignment(
logger.info(f"Assigned char array pointer to {var_name}")
return True

val_result = eval_expr(
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
)
val_result = eval_expr(func, compilation_context, builder, rval, local_sym_tab)
if val_result is None:
logger.error(f"Failed to evaluate value for {var_name}")
return False
Expand Down
25 changes: 16 additions & 9 deletions pythonbpf/codegen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from llvmlite import ir
from .context import CompilationContext
from .license_pass import license_processing
from .functions import func_proc
from .maps import maps_proc
Expand Down Expand Up @@ -67,9 +68,10 @@ def find_bpf_chunks(tree):
return bpf_functions


def processor(source_code, filename, module):
def processor(source_code, filename, compilation_context):
tree = ast.parse(source_code, filename)
logger.debug(ast.dump(tree, indent=4))
module = compilation_context.module

bpf_chunks = find_bpf_chunks(tree)
for func_node in bpf_chunks:
Expand All @@ -81,15 +83,18 @@ def processor(source_code, filename, module):
if vmlinux_symtab:
handler = VmlinuxHandler.initialize(vmlinux_symtab)
VmlinuxHandlerRegistry.set_handler(handler)
compilation_context.vmlinux_handler = handler

populate_global_symbol_table(tree, module)
license_processing(tree, module)
globals_processing(tree, module)
structs_sym_tab = structs_proc(tree, module, bpf_chunks)
map_sym_tab = maps_proc(tree, module, bpf_chunks, structs_sym_tab)
func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab)
populate_global_symbol_table(tree, compilation_context)
license_processing(tree, compilation_context)
globals_processing(tree, compilation_context)
structs_sym_tab = structs_proc(tree, compilation_context, bpf_chunks)

globals_list_creation(tree, module)
map_sym_tab = maps_proc(tree, compilation_context, bpf_chunks)

func_proc(tree, compilation_context, bpf_chunks)

globals_list_creation(tree, compilation_context)
return structs_sym_tab, map_sym_tab


Expand All @@ -104,6 +109,8 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
module.data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128"
module.triple = "bpf"

compilation_context = CompilationContext(module)

if not hasattr(module, "_debug_compile_unit"):
debug_generator = DebugInfoGenerator(module)
debug_generator.generate_file_metadata(filename, os.path.dirname(filename))
Expand All @@ -116,7 +123,7 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
True,
)

structs_sym_tab, maps_sym_tab = processor(source, filename, module)
structs_sym_tab, maps_sym_tab = processor(source, filename, compilation_context)

wchar_size = module.add_metadata(
[
Expand Down
82 changes: 82 additions & 0 deletions pythonbpf/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from llvmlite import ir
import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pythonbpf.structs.struct_type import StructType
from pythonbpf.maps.maps_utils import MapSymbol

logger = logging.getLogger(__name__)


class ScratchPoolManager:
"""Manage the temporary helper variables in local_sym_tab"""

def __init__(self):
self._counters = {}

@property
def counter(self):
return sum(self._counters.values())

def reset(self):
self._counters.clear()
logger.debug("Scratch pool counter reset to 0")

def _get_type_name(self, ir_type):
if isinstance(ir_type, ir.PointerType):
return "ptr"
elif isinstance(ir_type, ir.IntType):
return f"i{ir_type.width}"
elif isinstance(ir_type, ir.ArrayType):
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
else:
return str(ir_type).replace(" ", "")

def get_next_temp(self, local_sym_tab, expected_type=None):
# Default to i64 if no expected type provided
type_name = self._get_type_name(expected_type) if expected_type else "i64"
if type_name not in self._counters:
self._counters[type_name] = 0

counter = self._counters[type_name]
temp_name = f"__helper_temp_{type_name}_{counter}"
self._counters[type_name] += 1

if temp_name not in local_sym_tab:
raise ValueError(
f"Scratch pool exhausted or inadequate: {temp_name}. "
f"Type: {type_name} Counter: {counter}"
)

logger.debug(f"Using {temp_name} for type {type_name}")
return local_sym_tab[temp_name].var, temp_name


class CompilationContext:
"""
Holds the state for a single compilation run.
This replaces global mutable state modules.
"""

def __init__(self, module: ir.Module):
self.module = module

# Symbol tables
self.global_sym_tab: list[ir.GlobalVariable] = []
self.structs_sym_tab: dict[str, "StructType"] = {}
self.map_sym_tab: dict[str, "MapSymbol"] = {}

# Helper management
self.scratch_pool = ScratchPoolManager()

# Vmlinux handling (optional, specialized)
self.vmlinux_handler = None # Can be VmlinuxHandler instance

# Current function context (optional, if needed globally during function processing)
self.current_func = None

def reset(self):
"""Reset state between functions if necessary, though new context per compile is preferred."""
self.scratch_pool.reset()
self.current_func = None
8 changes: 2 additions & 6 deletions pythonbpf/expr/call_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@ def set_handler(cls, handler):
cls._handler = handler

@classmethod
def handle_call(
cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
):
def handle_call(cls, call, compilation_context, builder, func, local_sym_tab):
"""Handle a call using the registered handler"""
if cls._handler is None:
return None
return cls._handler(
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
)
return cls._handler(call, compilation_context, builder, func, local_sym_tab)
Loading