diff --git a/examples/eql/result_quantifiers.md b/examples/eql/result_quantifiers.md index 5040392..903f0ae 100644 --- a/examples/eql/result_quantifiers.md +++ b/examples/eql/result_quantifiers.md @@ -28,6 +28,7 @@ from dataclasses import dataclass from typing_extensions import List from krrood.entity_query_language.entity import entity, let, the, Symbol, an +from krrood.entity_query_language.result_quantification_constraint import AtLeast, AtMost, Exactly, Range from krrood.entity_query_language.failures import MultipleSolutionFound, LessThanExpectedNumberOfSolutions, GreaterThanExpectedNumberOfSolutions @@ -94,8 +95,7 @@ You can also bound the number of results within a range using both `at_least` an query = an( entity(body := let(Body, domain=world.bodies)), - at_least=1, - at_most=3, + quantification=Range(AtLeast(1), AtMost(3)) ) print(len(list(query.evaluate()))) # -> 2 @@ -107,7 +107,7 @@ If you want an exact number of results, use `exactly`: query = an( entity(body := let(Body, domain=world.bodies)), - exactly=2, + quantification=Exactly(2), ) print(len(list(query.evaluate()))) # -> 2 @@ -126,7 +126,7 @@ The result count constraints will raise informative exceptions when the number o query = an( entity(body := let(Body, domain=world.bodies)), - at_least=3, + quantification=AtLeast(3), ) try: list(query.evaluate()) @@ -137,7 +137,7 @@ except LessThanExpectedNumberOfSolutions as e: query = an( entity(body := let(Body, domain=world.bodies)), - at_most=1, + quantification=AtMost(1), ) try: list(query.evaluate()) @@ -148,7 +148,7 @@ except GreaterThanExpectedNumberOfSolutions as e: query = an( entity(body := let(Body, domain=world.bodies)), - exactly=1, + quantification=Exactly(1), ) try: list(query.evaluate()) diff --git a/examples/eql/writing_rule_trees.md b/examples/eql/writing_rule_trees.md index d4300b9..b5da65b 100644 --- a/examples/eql/writing_rule_trees.md +++ b/examples/eql/writing_rule_trees.md @@ -121,7 +121,7 @@ from krrood.entity_query_language.predicate import HasType # Declare the variables fixed_connection = let(type_=FixedConnection, domain=world.connections) revolute_connection = let(type_=RevoluteConnection, domain=world.connections) -views = let(type_=View, domain=None) +views = inference(View)() # Define aliases for convenience handle = fixed_connection.child diff --git a/src/krrood/entity_query_language/entity.py b/src/krrood/entity_query_language/entity.py index 5aedee9..4e1d82a 100644 --- a/src/krrood/entity_query_language/entity.py +++ b/src/krrood/entity_query_language/entity.py @@ -1,6 +1,7 @@ from __future__ import annotations from enum import Enum +from typing import Callable from .symbol_graph import SymbolGraph from .utils import is_iterable @@ -32,7 +33,6 @@ Comparator, chained_logic, CanBehaveLikeAVariable, - ResultQuantifier, From, Variable, optimize_or, @@ -41,7 +41,7 @@ Exists, Literal, ) -from .conclusion import Infer +from .result_quantification_constraint import ResultQuantificationConstraint from .predicate import ( Predicate, @@ -63,23 +63,17 @@ def an( entity_: EntityType, - at_least: Optional[int] = None, - at_most: Optional[int] = None, - exactly: Optional[int] = None, + quantification: Optional[ResultQuantificationConstraint] = None, ) -> Union[An[T], T, SymbolicExpression[T]]: """ Select a single element satisfying the given entity description. :param entity_: An entity or a set expression to quantify over. - :param at_least: Optional minimum number of results. - :param at_most: Optional maximum number of results. - :param exactly: Optional exact number of results. + :param quantification: Optional quantification constraint. :return: A quantifier representing "an" element. :rtype: An[T] """ - return select_one_or_select_many_or_an( - An, entity_, _at_least_=at_least, _at_most_=at_most, _exactly_=exactly - ) + return An(entity_, _quantification_constraint_=quantification) a = an @@ -98,37 +92,7 @@ def the( :return: A quantifier representing "an" element. :rtype: The[T] """ - return select_one_or_select_many_or_an(The, entity_) - - -def select_one_or_select_many_or_an( - quantifier: Type[ResultQuantifier], - entity_: EntityType, - **kwargs, -) -> ResultQuantifier[T]: - """ - Selects one or many entities or infers the result based on the provided quantifier - and entity type. This function facilitates creating or managing quantified results - depending on the entity type and additional keyword arguments. - - :param quantifier: A type of ResultQuantifier used to quantify the entity. - :param entity_: The entity or quantifier to be selected or converted to a quantifier. - :param kwargs: Additional keyword arguments for quantifier initialization. - :return: A result quantifier of the provided type, inferred type, or directly the - one provided. - :raises ValueError: If the provided entity is invalid. - """ - if isinstance(entity_, ResultQuantifier): - if isinstance(entity_, quantifier): - return entity_ - - entity_._child_._parent_ = None - return quantifier(entity_._child_, **kwargs) - - if isinstance(entity_, (Entity, SetOf)): - return quantifier(entity_, **kwargs) - - raise ValueError(f"Invalid entity: {entity_}") + return The(entity_) def entity( @@ -191,11 +155,7 @@ def _extract_variables_and_expression( return selected_variables, expression -class DomainKind(Enum): - INFERRED = 1 - - -DomainType = Union[Iterable, TypingLiteral[DomainKind.INFERRED], None] +DomainType = Union[Iterable, None] def let( @@ -214,7 +174,7 @@ def let( which may contain unnecessarily many elements. :param type_: The type of variable. - :param domain: Iterable of potential values for the variable or an INFERRED sentinel (for rules) or None. + :param domain: Iterable of potential values for the variable or None. If None, the domain will be inferred from the SymbolGraph for Symbol types, else should not be evaluated by EQL but by another evaluator (e.g., EQL To SQL converter in Ormatic). :param name: The variable name, only required for pretty printing. @@ -229,7 +189,6 @@ def let( _type_=type_, _domain_source_=domain_source, _name__=name, - _is_inferred_=domain is DomainKind.INFERRED, ) return result @@ -245,9 +204,7 @@ def _get_domain_source_from_domain_and_type_values( :param type_: The type of the variable. :return: The domain source as a From object. """ - if domain is DomainKind.INFERRED: - domain = None - elif is_iterable(domain): + if is_iterable(domain): domain = filter(lambda x: isinstance(x, type_), domain) elif domain is None and issubclass(type_, Symbol): domain = SymbolGraph().get_instances_of_type(type_) @@ -352,7 +309,9 @@ def exists( return Exists(universal_variable, condition) -def inference(type_: Type[T]) -> Union[Variable[T], Type[T]]: +def inference( + type_: Type[T], +) -> Union[Type[T], Callable[[Any], Variable[T]]]: """ This returns a factory function that creates a new variable of the given type and takes keyword arguments for the type constructor. diff --git a/src/krrood/entity_query_language/failures.py b/src/krrood/entity_query_language/failures.py index 1eff6bf..f4b8e3f 100644 --- a/src/krrood/entity_query_language/failures.py +++ b/src/krrood/entity_query_language/failures.py @@ -1,84 +1,101 @@ +""" +This module defines some custom exception types used by the entity_query_language package. +""" + from __future__ import annotations from abc import ABC +from dataclasses import dataclass -""" -Custom exception types used by entity_query_language. -""" from typing_extensions import TYPE_CHECKING, Type +from ..utils import DataclassException + if TYPE_CHECKING: - from .symbolic import SymbolicExpression + from .symbolic import SymbolicExpression, ResultQuantifier -class QuantificationError(Exception, ABC): +@dataclass +class QuantificationNotSatisfiedError(DataclassException, ABC): """ - Represents a custom exception specific to quantification errors. + Represents a custom exception where the quantification constraints are not satisfied. This exception is used to indicate errors related to the quantification of the query results. """ + expression: ResultQuantifier + """ + The result quantifier expression where the error occurred. + """ + expected_number: int + """ + Expected number of solutions (i.e, quantification constraint value). + """ -class GreaterThanExpectedNumberOfSolutions(QuantificationError): + +@dataclass +class GreaterThanExpectedNumberOfSolutions(QuantificationNotSatisfiedError): """ Represents an error when the number of solutions exceeds the expected threshold. """ - def __init__(self, expression: SymbolicExpression, expected_number: int): - super(GreaterThanExpectedNumberOfSolutions, self).__init__( - f"More than {expected_number} solutions found for the expression {expression}." - ) + def __post_init__(self): + self.message = f"More than {self.expected_number} solutions found for the expression {self.expression}." + super().__post_init__() -class LessThanExpectedNumberOfSolutions(QuantificationError): +@dataclass +class LessThanExpectedNumberOfSolutions(QuantificationNotSatisfiedError): """ Represents an error that occurs when the number of solutions found is lower than the expected number. """ - def __init__( - self, expression: SymbolicExpression, expected_number: int, found_number: int - ): - super(LessThanExpectedNumberOfSolutions, self).__init__( - f"Found {found_number} solutions which is less than the expected {expected_number} solutions for" - f" the expression {expression}." + found_number: int + """ + The number of solutions found. + """ + + def __post_init__(self): + self.message = ( + f"Found {self.found_number} solutions which is less than the expected {self.expected_number} " + f"solutions for the expression {self.expression}." ) + super().__post_init__() +@dataclass class MultipleSolutionFound(GreaterThanExpectedNumberOfSolutions): """ Raised when a query unexpectedly yields more than one solution where a single result was expected. """ - def __init__(self, expression: SymbolicExpression): - super(MultipleSolutionFound, self).__init__(expression, 1) + expected_number: int = 1 +@dataclass class NoSolutionFound(LessThanExpectedNumberOfSolutions): """ Raised when a query does not yield any solution. """ - def __init__(self, expression: SymbolicExpression, expected_number: int = 1): - super(NoSolutionFound, self).__init__( - expression, - expected_number, - 0, - ) + expected_number: int = 1 + found_number: int = 0 -class UsageError(Exception): +@dataclass +class UsageError(DataclassException): """ Raised when there is an incorrect usage of the entity query language API. """ - def __init__(self, message: str): - super(UsageError, self).__init__(message) + ... +@dataclass class UnsupportedOperation(UsageError): """ Raised when an operation is not supported by the entity query language API. @@ -87,39 +104,68 @@ class UnsupportedOperation(UsageError): ... +@dataclass class UnsupportedNegation(UnsupportedOperation): """ Raised when negating quantifiers. """ - def __init__(self, operation_type: Type[SymbolicExpression]): - super().__init__( - f"Symbolic NOT operations on {operation_type} types" + operation_type: Type[SymbolicExpression] + """ + The type of the operation that is being negated. + """ + + def __post_init__(self): + self.message = ( + f"Symbolic NOT operations on {self.operation_type} types" f" operands are not allowed, you can negate the conditions instead," f" as negating them is most likely not what you want" f" because it is ambiguous and can be very expensive to compute." f"To Negate Conditions do:" f" `not_(condition)` instead of `not_(an(entity(..., condition)))`." ) + super().__post_init__() -class CardinalitySpecificationError(UsageError): +@dataclass +class QuantificationSpecificationError(UsageError): """ - Raised when the cardinality constraints specified on the query results are invalid or inconsistent. + Raised when the quantification constraints specified on the query results are invalid or inconsistent. """ -class CardinalityConsistencyError(CardinalitySpecificationError): +@dataclass +class QuantificationConsistencyError(QuantificationSpecificationError): """ - Raised when the cardinality constraints specified on the query results are inconsistent. + Raised when the quantification constraints specified on the query results are inconsistent. """ ... -class CardinalityValueError(CardinalityConsistencyError): +@dataclass +class NegativeQuantificationError(QuantificationConsistencyError): """ - Raised when the cardinality constraints specified on the query results are invalid. + Raised when the quantification constraints specified on the query results have a negative value. """ - ... + message: str = f"ResultQuantificationConstraint must be a non-negative integer." + + +@dataclass +class InvalidEntityType(UsageError): + """ + Raised when an invalid entity type is given to the quantification operation. + """ + + invalid_entity_type: Type + """ + The invalid entity type. + """ + + def __post_init__(self): + self.message = ( + f"The entity type {self.invalid_entity_type} is not valid. It must be a subclass of QueryObjectDescriptor class." + f"e.g. Entity, or SetOf" + ) + super().__post_init__() diff --git a/src/krrood/entity_query_language/result_quantification_constraint.py b/src/krrood/entity_query_language/result_quantification_constraint.py new file mode 100644 index 0000000..3f4f865 --- /dev/null +++ b/src/krrood/entity_query_language/result_quantification_constraint.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing_extensions import TYPE_CHECKING + +from .failures import ( + NegativeQuantificationError, + QuantificationConsistencyError, + GreaterThanExpectedNumberOfSolutions, + LessThanExpectedNumberOfSolutions, +) + +if TYPE_CHECKING: + from .symbolic import An, ResultQuantifier + + +@dataclass +class ResultQuantificationConstraint(ABC): + """ + A base class that represents a constraint for quantification. + """ + + @abstractmethod + def assert_satisfaction( + self, number_of_solutions: int, quantifier: ResultQuantifier, done: bool + ) -> None: + """ + Check if the constraint is satisfied, if not, raise a QuantificationNotSatisfiedError exception. + + :param number_of_solutions: The current number of solutions. + :param quantifier: The quantifier expression of the query. + :param done: Whether all results have been found. + :raises: QuantificationNotSatisfiedError: If the constraint is not satisfied. + """ + ... + + @abstractmethod + def __repr__(self): ... + + +@dataclass +class SingleValueQuantificationConstraint(ResultQuantificationConstraint, ABC): + """ + A class that represents a single value constraint on the result quantification. + """ + + value: int + """ + The exact value of the constraint. + """ + + def __post_init__(self): + if self.value < 0: + raise NegativeQuantificationError() + + +@dataclass +class Exactly(SingleValueQuantificationConstraint): + """ + A class that represents an exact constraint on the result quantification. + """ + + def __repr__(self): + return f"n=={self.value}" + + def assert_satisfaction( + self, number_of_solutions: int, quantifier: ResultQuantifier, done: bool + ) -> None: + if number_of_solutions > self.value: + raise GreaterThanExpectedNumberOfSolutions(quantifier, self.value) + elif done and number_of_solutions < self.value: + raise LessThanExpectedNumberOfSolutions( + quantifier, self.value, number_of_solutions + ) + + +@dataclass +class AtLeast(SingleValueQuantificationConstraint): + """ + A class that specifies a minimum number of results as a quantification constraint. + """ + + def __repr__(self): + return f"n>={self.value}" + + def assert_satisfaction( + self, number_of_solutions: int, quantifier: ResultQuantifier, done: bool + ) -> None: + if done and number_of_solutions < self.value: + raise LessThanExpectedNumberOfSolutions( + quantifier, self.value, number_of_solutions + ) + + +@dataclass +class AtMost(SingleValueQuantificationConstraint): + """ + A class that specifies a maximum number of results as a quantification constraint. + """ + + def __repr__(self): + return f"n<={self.value}" + + def assert_satisfaction( + self, number_of_solutions: int, quantifier: ResultQuantifier, done: bool + ) -> None: + if number_of_solutions > self.value: + raise GreaterThanExpectedNumberOfSolutions(quantifier, self.value) + + +@dataclass +class Range(ResultQuantificationConstraint): + """ + A class that represents a range constraint on the result quantification. + """ + + at_least: AtLeast + """ + The minimum value of the range. + """ + at_most: AtMost + """ + The maximum value of the range. + """ + + def __post_init__(self): + """ + Validate quantification constraints are consistent. + """ + if self.at_most.value < self.at_least.value: + raise QuantificationConsistencyError( + message=f"at_most {self.at_most} cannot be less than at_least {self.at_least}." + ) + + def assert_satisfaction( + self, number_of_solutions: int, quantifier: ResultQuantifier, done: bool + ) -> None: + self.at_least.assert_satisfaction(number_of_solutions, quantifier, done) + self.at_most.assert_satisfaction(number_of_solutions, quantifier, done) + + def __repr__(self): + return f"{self.at_least}<=n<={self.at_most}" diff --git a/src/krrood/entity_query_language/symbolic.py b/src/krrood/entity_query_language/symbolic.py index bee3106..e3b534c 100644 --- a/src/krrood/entity_query_language/symbolic.py +++ b/src/krrood/entity_query_language/symbolic.py @@ -45,10 +45,16 @@ UnsupportedNegation, GreaterThanExpectedNumberOfSolutions, LessThanExpectedNumberOfSolutions, - CardinalityConsistencyError, - CardinalityValueError, + InvalidEntityType, ) from .hashed_data import HashedValue, HashedIterable, T +from .result_quantification_constraint import ( + ResultQuantificationConstraint, + Exactly, + AtLeast, + AtMost, + Range, +) from .rxnode import RWXNode, ColorLegend from .symbol_graph import SymbolGraph from .utils import IDGenerator, is_iterable, generate_combinations @@ -119,6 +125,16 @@ def __getitem__(self, item): def __setitem__(self, key, value): self.bindings[key] = value + def __hash__(self): + return id(self) + + def __eq__(self, other): + return ( + self.bindings == other.bindings + and self.is_true == other.is_true + and self.operand == other.operand + ) + @dataclass(eq=False) class SymbolicExpression(Generic[T], ABC): @@ -460,11 +476,11 @@ class ResultQuantifier(CanBehaveLikeAVariable[T], ABC): """ _child_: QueryObjectDescriptor[T] - _at_least_: Optional[int] = None - _at_most_: Optional[int] = None - _exactly_: Optional[int] = None + _quantification_constraint_: Optional[ResultQuantificationConstraint] = None def __post_init__(self): + if not isinstance(self._child_, QueryObjectDescriptor): + raise InvalidEntityType(type(self._child_)) super().__post_init__() self._var_ = ( self._child_._var_ @@ -472,7 +488,6 @@ def __post_init__(self): else None ) self._node_.wrap_subtree = True - self._validate_cardinality_constraints_() @cached_property def _type_(self): @@ -495,92 +510,36 @@ def evaluate( result_count = 0 for result in map(self._process_result_, filter(lambda r: r.is_true, results)): result_count += 1 - self._assert_less_than_upper_limit_(result_count) + self._assert_satisfaction_of_quantification_constraints_( + result_count, done=False + ) yield result - self._assert_more_than_lower_limit_(result_count) + self._assert_satisfaction_of_quantification_constraints_( + result_count, done=True + ) self._reset_cache_() - def _validate_cardinality_constraints_(self): - """ - Validate cardinality constraints are consistent and non-negative. + def _assert_satisfaction_of_quantification_constraints_( + self, result_count: int, done: bool + ): """ - if self._exactly_ and (self._at_least_ or self._at_most_): - raise CardinalityConsistencyError( - f"exactly is specified, but either at_least or at_most is also specified," - f"cannot specify both." - ) - if ( - (self._at_least_ and self._at_least_ < 0) - or (self._at_most_ and self._at_most_ < 0) - or (self._exactly_ and self._exactly_ < 0) - ): - raise CardinalityValueError( - f"at_least, at_most, and exactly must be non-negative integers." - ) - if self._at_most_ and self._at_least_ and self._at_most_ < self._at_least_: - raise CardinalityValueError( - f"at_most {self._at_most_} cannot be less than at_least {self._at_least_}." - ) + Assert the satisfaction of quantification constraints. - @cached_property - def _upper_limit_(self) -> Optional[int]: + :param result_count: The current count of results + :param done: Whether all results have been processed + :raises QuantificationNotSatisfiedError: If the quantification constraints are not satisfied. """ - :return: The upper limit of the number of results if exists. - """ - if self._exactly_: - return self._exactly_ - elif self._at_most_: - return self._at_most_ - else: - return None - - @cached_property - def _lower_limit_(self) -> Optional[int]: - """ - :return: The lower limit of the number of results if exists. - """ - if self._exactly_: - return self._exactly_ - elif self._at_least_: - return self._at_least_ - else: - return None + if self._quantification_constraint_: + self._quantification_constraint_.assert_satisfaction( + result_count, self, done + ) def __repr__(self): name = f"{self.__class__.__name__}" - if self._at_least_ or self._at_most_ or self._exactly_: - name += "(" - if self._at_least_ and not self._at_most_: - name += f"n>={self._at_least_})" - elif self._at_most_ and not self._at_least_: - name += f"n<={self._at_most_})" - elif self._at_least_ and self._at_most_: - name += f"{self._at_least_}<=n<={self._at_most_})" - elif self._exactly_: - name += f"n={self._exactly_})" + if self._quantification_constraint_: + name += f"({self._quantification_constraint_})" return name - def _assert_less_than_upper_limit_(self, count: int): - """ - Assert that the count is less than the upper limit. - - :param count: - :raises GreaterThanExpectedNumberOfSolutions: If the count exceeds the upper limit. - """ - if self._upper_limit_ and count > self._upper_limit_: - raise GreaterThanExpectedNumberOfSolutions(self, self._upper_limit_) - - def _assert_more_than_lower_limit_(self, count: int): - """ - Assert that the count is more than the lower limit. - - :param count: The current count. - :raises LessThanExpectedNumberOfSolutions: If the count is less than the lower limit. - :raises NoSolutionFound: If no solution is found. - """ - if self._lower_limit_ and count < self._lower_limit_: - raise LessThanExpectedNumberOfSolutions(self, self._lower_limit_, count) - def _evaluate__( self, sources: Optional[Dict[int, HashedValue]] = None, @@ -702,9 +661,9 @@ class The(ResultQuantifier[T]): Quantifier that expects exactly one result; raises MultipleSolutionFound if more. """ - _exactly_: int = field(init=False, default=1) - _at_least_: int = field(init=False, default=None) - _at_most_: int = field(init=False, default=None) + _quantification_constraint_: ResultQuantificationConstraint = field( + init=False, default_factory=lambda: Exactly(1) + ) def evaluate( self, @@ -764,32 +723,68 @@ def _evaluate__( yield OperationResult(sources, self._is_false_, self) for values in self.get_constrained_values(sources): values = self.update_data_from_child(values) - if self.any_selected_inferred_vars_are_unbound(values): + if self.any_selected_variable_is_inferred_and_unbound(values): continue - self._warn_on_unbound_variables_(values.bindings, self.selected_variables) - if self.any_selected_not_inferred_vars_are_unbound(values): - for binding in self.generate_combinations_with_unbound_variables( - values.bindings - ): - yield OperationResult(binding, self._is_false_, self) + if self.any_selected_variable_is_unbound(values): + yield from self.evaluate_selected_variables(values.bindings) else: yield values - def any_selected_inferred_vars_are_unbound(self, values: OperationResult) -> bool: - return any( - var._id_ not in values and (isinstance(var, Variable) and var._is_inferred_) - for var in self.selected_variables - ) + def any_selected_variable_is_unbound(self, values: OperationResult) -> bool: + """ + Check if any of the selected variables is unbound. + + :param values: The current result with the current bindings. + :return: True if any of the selected variables is unbound, otherwise False. + """ + return any(var._id_ not in values for var in self.selected_variables) + + @staticmethod + def variable_is_inferred(var: CanBehaveLikeAVariable[T]) -> bool: + """ + Whether the variable is inferred or not. - def any_selected_not_inferred_vars_are_unbound( + :param var: The variable. + :return: True if the variable is inferred, otherwise False. + """ + return isinstance(var, Variable) and var._is_inferred_ + + def any_selected_variable_is_inferred_and_unbound( self, values: OperationResult ) -> bool: + """ + Check if any of the selected variables is inferred and is not bound. + + :param values: The current result with the current bindings. + :return: True if any of the selected variables is inferred and is not bound, otherwise False. + """ return any( - var._id_ not in values - and not (isinstance(var, Variable) and var._is_inferred_) + not self.variable_is_bound_or_its_children_are_bound(var, values) for var in self.selected_variables + if self.variable_is_inferred(var) ) + @lru_cache(maxsize=None) + def variable_is_bound_or_its_children_are_bound( + self, var: CanBehaveLikeAVariable[T], result: OperationResult + ) -> bool: + """ + Whether the variable is directly bound or all its children are bound. + + :param var: The variable. + :param result: The current result containing the current bindings. + :return: True if the variable is bound, otherwise False. + """ + if var._id_ in result: + return True + unique_vars = [uv.value for uv in var._unique_variables_ if uv.value is not var] + if unique_vars and all( + self.variable_is_bound_or_its_children_are_bound(uv, result) + for uv in unique_vars + ): + return True + return False + def update_data_from_child(self, child_result: OperationResult): if self._child_: self._is_false_ = child_result.is_false @@ -806,52 +801,33 @@ def update_data_from_child(self, child_result: OperationResult): def get_constrained_values( self, sources: Optional[Dict[int, HashedValue]] ) -> Iterable[OperationResult]: + """ + Evaluate the child (i.e., the conditions that constrain the domain of the selected variables). + + :param sources: The current bindings. + :return: The bindings after applying the constraints of the child. + """ if self._child_: yield from self._child_._evaluate__(sources, parent=self) else: yield from [OperationResult(sources, False, self)] - def generate_combinations_with_unbound_variables( + def evaluate_selected_variables( self, sources: Dict[int, HashedValue] - ): + ) -> Iterable[OperationResult]: + """ + Evaluate the selected variables by generating combinations of values from their evaluation generators. + + :param sources: The current bindings. + :return: An Iterable of OperationResults for each combination of values. + """ var_val_gen = { var: var._evaluate__(copy(sources), parent=self) for var in self.selected_variables } for sol in generate_combinations(var_val_gen): var_val = {var._id_: sol[var][var._id_] for var in self.selected_variables} - yield {**sources, **var_val} - - def _warn_on_unbound_variables_( - self, - sources: Dict[int, HashedValue], - selected_vars: Iterable[CanBehaveLikeAVariable], - ): - """ - Warn the user if there are unbound variables in the query descriptor, because this will result in a cartesian - product join operation. - - :param sources: The bound values after applying the conditions. - :param selected_vars: The variables selected in the query descriptor. - """ - unbound_variables = HashedIterable() - for var in selected_vars: - unbound_variables.update( - var._unique_variables_.difference(HashedIterable(values=sources)) - ) - unbound_variables_with_domain = HashedIterable() - for var in unbound_variables: - if var.value._domain_ and len(var.value._domain_.values) > 20: - if var not in self.warned_vars: - self.warned_vars.add(var) - unbound_variables_with_domain.add(var) - if unbound_variables_with_domain: - logger.warning( - f"\nCartesian Product: " - f"The following variables are not constrained " - f"{unbound_variables_with_domain.unwrapped_values}" - f"\nfor the query descriptor {self._name_}" - ) + yield OperationResult({**sources, **var_val}, self._is_false_, self) @property @lru_cache(maxsize=None) diff --git a/src/krrood/utils.py b/src/krrood/utils.py index 65d1b2e..9fdc9e9 100644 --- a/src/krrood/utils.py +++ b/src/krrood/utils.py @@ -1,5 +1,7 @@ from __future__ import annotations +from dataclasses import dataclass, field + from typing_extensions import TypeVar, Type, List T = TypeVar("T") @@ -13,3 +15,17 @@ def recursive_subclasses(cls: Type[T]) -> List[Type[T]]: return cls.__subclasses__() + [ g for s in cls.__subclasses__() for g in recursive_subclasses(s) ] + + +@dataclass +class DataclassException(Exception): + """ + A base exception class for dataclass-based exceptions. + The way this is used is by inheriting from it and setting the `message` field in the __post_init__ method, + then calling the super().__post_init__() method. + """ + + message: str = field(kw_only=True, default=None) + + def __post_init__(self): + super().__init__(self.message) diff --git a/test/test_eql/test_core/test_queries.py b/test/test_eql/test_core/test_queries.py index 5c8c785..49adc55 100644 --- a/test/test_eql/test_core/test_queries.py +++ b/test/test_eql/test_core/test_queries.py @@ -30,6 +30,13 @@ Predicate, ) from krrood.entity_query_language.symbol_graph import SymbolGraph +from krrood.entity_query_language.result_quantification_constraint import ( + ResultQuantificationConstraint, + Exactly, + AtLeast, + AtMost, + Range, +) from ...dataset.semantic_world_like_classes import ( Handle, Body, @@ -917,29 +924,25 @@ def test_unsupported_negation(handles_and_containers_world): def test_quantified_query(handles_and_containers_world): world = handles_and_containers_world - def get_quantified_query( - at_least: int = None, at_most: int = None, exactly: int = None - ): + def get_quantified_query(quantification: ResultQuantificationConstraint): query = an( entity( body := let(type_=Body, domain=world.bodies), contains(body.name, "Handle"), ), - at_least=at_least, - at_most=at_most, - exactly=exactly, + quantification=quantification, ) return query - results = list(get_quantified_query(at_least=3).evaluate()) + results = list(get_quantified_query(AtLeast(3)).evaluate()) assert len(results) == 3 - results = list(get_quantified_query(at_least=2, at_most=4).evaluate()) + results = list(get_quantified_query(Range(AtLeast(2), AtMost(4))).evaluate()) assert len(results) == 3 with pytest.raises(LessThanExpectedNumberOfSolutions): - list(get_quantified_query(at_least=4).evaluate()) + list(get_quantified_query(AtLeast(4)).evaluate()) with pytest.raises(GreaterThanExpectedNumberOfSolutions): - list(get_quantified_query(at_most=2).evaluate()) + list(get_quantified_query(AtMost(2)).evaluate()) with pytest.raises(GreaterThanExpectedNumberOfSolutions): - list(get_quantified_query(exactly=2).evaluate()) + list(get_quantified_query(Exactly(2)).evaluate()) with pytest.raises(LessThanExpectedNumberOfSolutions): - list(get_quantified_query(exactly=4).evaluate()) + list(get_quantified_query(Exactly(4)).evaluate()) diff --git a/test/test_eql/test_core/test_rules.py b/test/test_eql/test_core/test_rules.py index d8da771..d1941da 100644 --- a/test/test_eql/test_core/test_rules.py +++ b/test/test_eql/test_core/test_rules.py @@ -1,12 +1,5 @@ from krrood.entity_query_language.conclusion import Add -from krrood.entity_query_language.entity import ( - let, - an, - entity, - and_, - inference, - DomainKind, -) +from krrood.entity_query_language.entity import let, an, entity, and_, inference from krrood.entity_query_language.predicate import HasType from krrood.entity_query_language.rule import refinement, alternative, next_rule from ...dataset.semantic_world_like_classes import ( @@ -134,7 +127,7 @@ def test_rule_tree_with_multiple_refinements(doors_and_drawers_world): query = an( entity( - views := let(type_=View, domain=DomainKind.INFERRED), + views := inference(View)(), body == fixed_connection.parent, handle == fixed_connection.child, ) @@ -178,7 +171,7 @@ def test_rule_tree_with_an_alternative(doors_and_drawers_world): query = an( entity( - views := let(type_=View, domain=DomainKind.INFERRED), + views := inference(View)(), body == fixed_connection.parent, handle == fixed_connection.child, ) @@ -220,7 +213,7 @@ def test_rule_tree_with_multiple_alternatives(doors_and_drawers_world): query = an( entity( - views := let(type_=View, domain=DomainKind.INFERRED), + views := inference(View)(), body == fixed_connection.parent, handle == fixed_connection.child, body == prismatic_connection.child, @@ -272,7 +265,7 @@ def test_rule_tree_with_multiple_alternatives_optimized(doors_and_drawers_world) query = an( entity( - views := let(type_=View, domain=DomainKind.INFERRED), + views := inference(View)(), HasType(fixed_connection.child, Handle), fixed_connection.parent == prismatic_connection.child, ) @@ -337,7 +330,7 @@ def test_rule_tree_with_multiple_alternatives_better_rule_tree(doors_and_drawers query = an( entity( - views := let(type_=View, domain=DomainKind.INFERRED), + views := inference(View)(), body == fixed_connection.parent, handle == fixed_connection.child, ) @@ -389,7 +382,7 @@ def test_rule_tree_with_multiple_alternatives_better_rule_tree_optimized( query = an( entity( - views := let(type_=View, domain=DomainKind.INFERRED), + views := inference(View)(), HasType(fixed_connection.child, Handle), ) )