Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/eql/result_quantifiers.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ from dataclasses import dataclass
from typing_extensions import List

from krrood.entity_query_language.entity import entity, let, the, Symbol, an
from krrood.entity_query_language.result_quantification_constraint import AtLeast, AtMost, Exactly, Range
from krrood.entity_query_language.failures import MultipleSolutionFound, LessThanExpectedNumberOfSolutions, GreaterThanExpectedNumberOfSolutions


Expand Down Expand Up @@ -94,8 +95,7 @@ You can also bound the number of results within a range using both `at_least` an

query = an(
entity(body := let(Body, domain=world.bodies)),
at_least=1,
at_most=3,
quantification=Range(AtLeast(1), AtMost(3))
)

print(len(list(query.evaluate()))) # -> 2
Expand All @@ -107,7 +107,7 @@ If you want an exact number of results, use `exactly`:

query = an(
entity(body := let(Body, domain=world.bodies)),
exactly=2,
quantification=Exactly(2),
)

print(len(list(query.evaluate()))) # -> 2
Expand All @@ -126,7 +126,7 @@ The result count constraints will raise informative exceptions when the number o

query = an(
entity(body := let(Body, domain=world.bodies)),
at_least=3,
quantification=AtLeast(3),
)
try:
list(query.evaluate())
Expand All @@ -137,7 +137,7 @@ except LessThanExpectedNumberOfSolutions as e:

query = an(
entity(body := let(Body, domain=world.bodies)),
at_most=1,
quantification=AtMost(1),
)
try:
list(query.evaluate())
Expand All @@ -148,7 +148,7 @@ except GreaterThanExpectedNumberOfSolutions as e:

query = an(
entity(body := let(Body, domain=world.bodies)),
exactly=1,
quantification=Exactly(1),
)
try:
list(query.evaluate())
Expand Down
2 changes: 1 addition & 1 deletion examples/eql/writing_rule_trees.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ from krrood.entity_query_language.predicate import HasType
# Declare the variables
fixed_connection = let(type_=FixedConnection, domain=world.connections)
revolute_connection = let(type_=RevoluteConnection, domain=world.connections)
views = let(type_=View, domain=None)
views = inference(View)()

# Define aliases for convenience
handle = fixed_connection.child
Expand Down
65 changes: 12 additions & 53 deletions src/krrood/entity_query_language/entity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import Callable

from .symbol_graph import SymbolGraph
from .utils import is_iterable
Expand Down Expand Up @@ -32,7 +33,6 @@
Comparator,
chained_logic,
CanBehaveLikeAVariable,
ResultQuantifier,
From,
Variable,
optimize_or,
Expand All @@ -41,7 +41,7 @@
Exists,
Literal,
)
from .conclusion import Infer
from .result_quantification_constraint import ResultQuantificationConstraint

from .predicate import (
Predicate,
Expand All @@ -63,23 +63,17 @@

def an(
entity_: EntityType,
at_least: Optional[int] = None,
at_most: Optional[int] = None,
exactly: Optional[int] = None,
quantification: Optional[ResultQuantificationConstraint] = None,
) -> Union[An[T], T, SymbolicExpression[T]]:
"""
Select a single element satisfying the given entity description.

:param entity_: An entity or a set expression to quantify over.
:param at_least: Optional minimum number of results.
:param at_most: Optional maximum number of results.
:param exactly: Optional exact number of results.
:param quantification: Optional quantification constraint.
:return: A quantifier representing "an" element.
:rtype: An[T]
"""
return select_one_or_select_many_or_an(
An, entity_, _at_least_=at_least, _at_most_=at_most, _exactly_=exactly
)
return An(entity_, _quantification_constraint_=quantification)


a = an
Expand All @@ -98,37 +92,7 @@ def the(
:return: A quantifier representing "an" element.
:rtype: The[T]
"""
return select_one_or_select_many_or_an(The, entity_)


def select_one_or_select_many_or_an(
quantifier: Type[ResultQuantifier],
entity_: EntityType,
**kwargs,
) -> ResultQuantifier[T]:
"""
Selects one or many entities or infers the result based on the provided quantifier
and entity type. This function facilitates creating or managing quantified results
depending on the entity type and additional keyword arguments.

:param quantifier: A type of ResultQuantifier used to quantify the entity.
:param entity_: The entity or quantifier to be selected or converted to a quantifier.
:param kwargs: Additional keyword arguments for quantifier initialization.
:return: A result quantifier of the provided type, inferred type, or directly the
one provided.
:raises ValueError: If the provided entity is invalid.
"""
if isinstance(entity_, ResultQuantifier):
if isinstance(entity_, quantifier):
return entity_

entity_._child_._parent_ = None
return quantifier(entity_._child_, **kwargs)

if isinstance(entity_, (Entity, SetOf)):
return quantifier(entity_, **kwargs)

raise ValueError(f"Invalid entity: {entity_}")
return The(entity_)


def entity(
Expand Down Expand Up @@ -191,11 +155,7 @@ def _extract_variables_and_expression(
return selected_variables, expression


class DomainKind(Enum):
INFERRED = 1


DomainType = Union[Iterable, TypingLiteral[DomainKind.INFERRED], None]
DomainType = Union[Iterable, None]


def let(
Expand All @@ -214,7 +174,7 @@ def let(
which may contain unnecessarily many elements.

:param type_: The type of variable.
:param domain: Iterable of potential values for the variable or an INFERRED sentinel (for rules) or None.
:param domain: Iterable of potential values for the variable or None.
If None, the domain will be inferred from the SymbolGraph for Symbol types, else should not be evaluated by EQL
but by another evaluator (e.g., EQL To SQL converter in Ormatic).
:param name: The variable name, only required for pretty printing.
Expand All @@ -229,7 +189,6 @@ def let(
_type_=type_,
_domain_source_=domain_source,
_name__=name,
_is_inferred_=domain is DomainKind.INFERRED,
)

return result
Expand All @@ -245,9 +204,7 @@ def _get_domain_source_from_domain_and_type_values(
:param type_: The type of the variable.
:return: The domain source as a From object.
"""
if domain is DomainKind.INFERRED:
domain = None
elif is_iterable(domain):
if is_iterable(domain):
domain = filter(lambda x: isinstance(x, type_), domain)
elif domain is None and issubclass(type_, Symbol):
domain = SymbolGraph().get_instances_of_type(type_)
Expand Down Expand Up @@ -352,7 +309,9 @@ def exists(
return Exists(universal_variable, condition)


def inference(type_: Type[T]) -> Union[Variable[T], Type[T]]:
def inference(
type_: Type[T],
) -> Union[Type[T], Callable[[Any], Variable[T]]]:
"""
This returns a factory function that creates a new variable of the given type and takes keyword arguments for the
type constructor.
Expand Down
Loading