Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,24 @@ def build_eval_function(
definition: Union[BinaryStatement, UnaryStatement, StatementGroup],
execution_context: str = "<root>",
) -> Callable[[T], bool]:
if isinstance(definition, BinaryStatement):
# Short-circuit type checks to avoid unnecessary computation/attrlookup.
BinaryStatement_type = BinaryStatement
UnaryStatement_type = UnaryStatement
StatementGroup_type = StatementGroup

if type(definition) is BinaryStatement_type:
return build_binary_statement(definition, execution_context=execution_context)
if isinstance(definition, UnaryStatement):
if type(definition) is UnaryStatement_type:
return build_unary_statement(definition, execution_context=execution_context)
statements_functions = []
for statement_id, statement in enumerate(definition.statements):
statement_execution_context = f"{execution_context}.statements[{statement_id}]"
statements_functions.append(
build_eval_function(
statement, execution_context=statement_execution_context
)
# Only reached for StatementGroup; don't repeat isinstance check for each iteration.
statements = definition.statements
statements_functions = [
build_eval_function(
statement,
execution_context=f"{execution_context}.statements[{statement_id}]",
)
for statement_id, statement in enumerate(statements)
]
return partial(
compound_eval,
statements_functions=statements_functions,
Expand All @@ -87,19 +93,19 @@ def build_binary_statement(
definition: BinaryStatement,
execution_context: str,
) -> Callable[[Dict[str, T]], bool]:
operator = BINARY_OPERATORS[definition.comparator.type]
operator_parameters_names = [
t for t in type(definition.comparator).model_fields if t != TYPE_PARAMETER_NAME
]
operator_parameters = {
a: getattr(definition.comparator, a) for a in operator_parameters_names
}
comparator = definition.comparator
# Avoid repeated lookups and allocations by using list comprehension directly, avoiding `type()` call multiple times
model_fields = type(comparator).model_fields
operator_parameters_names = [t for t in model_fields if t != TYPE_PARAMETER_NAME]
operator_parameters = {a: getattr(comparator, a) for a in operator_parameters_names}
left_operand_builder = create_operand_builder(
definition=definition.left_operand, execution_context=execution_context
)
right_operand_builder = create_operand_builder(
definition=definition.right_operand, execution_context=execution_context
)
operator = BINARY_OPERATORS[comparator.type] # Single dict lookup

return partial(
binary_eval,
left_operand_builder=left_operand_builder,
Expand Down Expand Up @@ -225,16 +231,16 @@ def binary_eval(
def build_unary_statement(
definition: UnaryStatement, execution_context: str
) -> Callable[[Dict[str, T]], bool]:
operator = UNARY_OPERATORS[definition.operator.type]
operator_parameters_names = [
t for t in type(definition.operator).model_fields if t != TYPE_PARAMETER_NAME
]
operator_obj = definition.operator
model_fields = type(operator_obj).model_fields
operator_parameters_names = [t for t in model_fields if t != TYPE_PARAMETER_NAME]
operator_parameters = {
a: getattr(definition.operator, a) for a in operator_parameters_names
a: getattr(operator_obj, a) for a in operator_parameters_names
}
operand_builder = create_operand_builder(
definition=definition.operand, execution_context=execution_context
)
operator = UNARY_OPERATORS[operator_obj.type]
return partial(
unary_eval,
operand_builder=operand_builder,
Expand Down
Loading