diff --git a/.gitignore b/.gitignore
index b8ff20397..fb2c422fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -177,6 +177,9 @@ src/datasource_toolkit/poetry.lock
 src/flask_agent/poetry.lock
 src/django_agent/poetry.lock
 src/datasource_django/poetry.lock
+src/datasource_rpc/poetry.lock
+src/agent_rpc/poetry.lock
+src/rpc_common/poetry.lock
 
 # generate file during tests
 src/django_agent/tests/test_project_agent/.forestadmin-schema.json
\ No newline at end of file
diff --git a/src/agent_rpc/README.md b/src/agent_rpc/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/agent_rpc/forestadmin/agent_rpc/__init__.py b/src/agent_rpc/forestadmin/agent_rpc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/agent_rpc/forestadmin/agent_rpc/agent.py b/src/agent_rpc/forestadmin/agent_rpc/agent.py
new file mode 100644
index 000000000..e86467366
--- /dev/null
+++ b/src/agent_rpc/forestadmin/agent_rpc/agent.py
@@ -0,0 +1,242 @@
+import asyncio
+import json
+from datetime import date, datetime
+from uuid import UUID
+
+from aiohttp import web
+from aiohttp_sse import sse_response
+from forestadmin.agent_rpc.hmac_middleware import HmacValidationError, HmacValidator
+from forestadmin.agent_rpc.options import RpcOptions
+from forestadmin.agent_toolkit.agent import Agent
+from forestadmin.agent_toolkit.options import Options
+from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType
+from forestadmin.datasource_toolkit.utils.schema import SchemaUtils
+from forestadmin.rpc_common.serializers.actions import (
+    ActionFormSerializer,
+    ActionFormValuesSerializer,
+    ActionResultSerializer,
+)
+from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer
+from forestadmin.rpc_common.serializers.collection.filter import (
+    FilterSerializer,
+    PaginatedFilterSerializer,
+    ProjectionSerializer,
+)
+from forestadmin.rpc_common.serializers.collection.record import RecordSerializer
+from forestadmin.rpc_common.serializers.schema.schema import SchemaSerializer
+from forestadmin.rpc_common.serializers.utils import CallerSerializer
+
+
+class RpcAgent(Agent):
+    def __init__(self, options: RpcOptions):
+        self.listen_addr, self.listen_port = options["listen_addr"].rsplit(":", 1)
+        agent_options: Options = {**options}  # type:ignore
+        agent_options["skip_schema_update"] = True
+        agent_options["env_secret"] = "f" * 64
+        agent_options["server_url"] = "http://fake"
+        agent_options["schema_path"] = "./.forestadmin-schema.json"
+        super().__init__(agent_options)
+        self.hmac_validator = HmacValidator(options["auth_secret"])
+
+        self.app = web.Application(middlewares=[self.hmac_middleware])
+        self.setup_routes()
+
+    @web.middleware
+    async def hmac_middleware(self, request: web.Request, handler):
+        # TODO: hmac on SSE?
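+        # Clients must send two headers on every authenticated request (a sketch of
+        # the contract, mirroring RPCRequester.mk_hmac_headers on the client side):
+        #   X_TIMESTAMP: ISO-8601 UTC timestamp used as the signed message
+        #   X_SIGNATURE: hex HMAC-SHA256 of that timestamp, keyed with auth_secret
+        # The health check ("/") and the SSE stream are exempt below.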
+        if request.url.path in ["/", "/forest/rpc/sse"]:
+            return await handler(request)
+
+        header_sign = request.headers.get("X_SIGNATURE")
+        header_timestamp = request.headers.get("X_TIMESTAMP")
+        try:
+            self.hmac_validator.validate_hmac(header_sign, header_timestamp)
+        except HmacValidationError:
+            return web.Response(status=401, text="Unauthorized from HMAC verification")
+        return await handler(request)
+
+    def setup_routes(self):
+        self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK"))  # type: ignore
+        self.app.router.add_route("GET", "/forest/rpc/sse", self.sse_handler)
+        self.app.router.add_route("GET", "/forest/rpc-schema", self.schema)
+
+        # self.app.router.add_route("POST", "/execute-native-query", self.native_query)
+        self.app.router.add_route("POST", "/forest/rpc/datasource-chart", self.render_chart)
+
+        self.app.router.add_route("POST", "/forest/rpc/{collection_name}/list", self.collection_list)
+        self.app.router.add_route("POST", "/forest/rpc/{collection_name}/create", self.collection_create)
+        self.app.router.add_route("POST", "/forest/rpc/{collection_name}/update", self.collection_update)
+        self.app.router.add_route("POST", "/forest/rpc/{collection_name}/delete", self.collection_delete)
+        self.app.router.add_route("POST", "/forest/rpc/{collection_name}/aggregate", self.collection_aggregate)
+        self.app.router.add_route("POST", "/forest/rpc/{collection_name}/action-form", self.collection_get_form)
+        self.app.router.add_route("POST", "/forest/rpc/{collection_name}/action-execute", self.collection_execute)
+        self.app.router.add_route("POST", "/forest/rpc/{collection_name}/chart", self.collection_render_chart)
+
+    async def sse_handler(self, request: web.Request) -> web.StreamResponse:
+        async with sse_response(request) as resp:
+            while resp.is_connected():
+                await resp.send("", event="heartbeat")
+                await asyncio.sleep(1)
+            data = json.dumps({"event": "RpcServerStop"})
+            await resp.send(data, event="RpcServerStop")
+        return resp
+
+    async def schema(self, request):
+        ds = await self.customizer.get_datasource()
+        return web.json_response(await SchemaSerializer(ds).serialize())
+
+    async def collection_list(self, request: web.Request):
+        body_params = await request.json()
+        ds = await self.customizer.get_datasource()
+        collection = ds.get_collection(request.match_info["collection_name"])
+        caller = CallerSerializer.deserialize(body_params["caller"])
+        filter_ = PaginatedFilterSerializer.deserialize(body_params["filter"], collection)  # type:ignore
+        projection = ProjectionSerializer.deserialize(body_params["projection"])
+
+        records = await collection.list(caller, filter_, projection)
+        records = [RecordSerializer.serialize(record) for record in records]
+        return web.json_response(records)
+
+    async def collection_create(self, request: web.Request):
+        body_params = await request.json()
+        ds = await self.customizer.get_datasource()
+
+        collection = ds.get_collection(request.match_info["collection_name"])
+        caller = CallerSerializer.deserialize(body_params["caller"])
+        data = [RecordSerializer.deserialize(r, collection) for r in body_params["data"]]  # type:ignore
+
+        records = await collection.create(caller, data)
+        records = [RecordSerializer.serialize(record) for record in records]
+        return web.json_response(records)
+
+    async def collection_update(self, request: web.Request):
+        body_params = await request.json()
+
+        ds = await self.customizer.get_datasource()
+        collection = ds.get_collection(request.match_info["collection_name"])
+        caller = CallerSerializer.deserialize(body_params["caller"])
+        filter_ = FilterSerializer.deserialize(body_params["filter"], collection)  # type:ignore
+        patch = RecordSerializer.deserialize(body_params["patch"], collection)  # type:ignore
+
+        await collection.update(caller, filter_, patch)
+        return web.Response(text="OK")
+
+    async def collection_delete(self, request: web.Request):
+        body_params = await request.json()
+        ds = await self.customizer.get_datasource()
+        collection = ds.get_collection(request.match_info["collection_name"])
+        caller = CallerSerializer.deserialize(body_params["caller"])
+        filter_ = FilterSerializer.deserialize(body_params["filter"], collection)  # type:ignore
+
+        await collection.delete(caller, filter_)
+        return web.Response(text="OK")
+
+    async def collection_aggregate(self, request: web.Request):
+        body_params = await request.json()
+        ds = await self.customizer.get_datasource()
+        collection = ds.get_collection(request.match_info["collection_name"])
+        caller = CallerSerializer.deserialize(body_params["caller"])
+        filter_ = FilterSerializer.deserialize(body_params["filter"], collection)  # type:ignore
+        aggregation = AggregationSerializer.deserialize(body_params["aggregation"])
+
+        records = await collection.aggregate(caller, filter_, aggregation)
+        return web.json_response(records)
+
+    async def collection_get_form(self, request: web.Request):
+        body_params = await request.json()
+
+        ds = await self.customizer.get_datasource()
+        collection = ds.get_collection(request.match_info["collection_name"])
+
+        caller = CallerSerializer.deserialize(body_params["caller"])
+        action_name = body_params["actionName"]
+        if body_params["filter"]:
+            filter_ = FilterSerializer.deserialize(body_params["filter"], collection)  # type:ignore
+        else:
+            filter_ = None
+        data = ActionFormValuesSerializer.deserialize(body_params["data"])
+        meta = body_params["meta"]
+
+        form = await collection.get_form(caller, action_name, data, filter_, meta)
+        return web.json_response(ActionFormSerializer.serialize(form))
+
+    async def collection_execute(self, request: web.Request):
+        body_params = await request.json()
+
+        ds = await self.customizer.get_datasource()
+        collection = ds.get_collection(request.match_info["collection_name"])
+
+        caller = CallerSerializer.deserialize(body_params["caller"])
+        action_name = body_params["actionName"]
+        if body_params["filter"]:
+            filter_ = FilterSerializer.deserialize(body_params["filter"], collection)  # type:ignore
+        else:
+            filter_ = None
+        data = ActionFormValuesSerializer.deserialize(body_params["data"])
+
+        result = await collection.execute(caller, action_name, data, filter_)
+        return web.json_response(ActionResultSerializer.serialize(result))  # type:ignore
+
+    async def collection_render_chart(self, request: web.Request):
+        body_params = await request.json()
+
+        ds = await self.customizer.get_datasource()
+        collection = ds.get_collection(request.match_info["collection_name"])
+
+        caller = CallerSerializer.deserialize(body_params["caller"])
+        name = body_params["name"]
+        record_id = body_params["recordId"]
+        ret = []
+        for i, value in enumerate(record_id):
+            type_record_id = collection.schema["fields"][SchemaUtils.get_primary_keys(collection.schema)[i]][
+                "column_type"
+            ]
+
+            if type_record_id == PrimitiveType.DATE:
+                ret.append(datetime.fromisoformat(value))
+            elif type_record_id == PrimitiveType.DATE_ONLY:
+                ret.append(date.fromisoformat(value))
+            elif type_record_id == PrimitiveType.POINT:
+                ret.append((value[0], value[1]))
+            elif type_record_id == PrimitiveType.UUID:
+                ret.append(UUID(value))
+            else:
+                ret.append(value)
+
+        result = await collection.render_chart(caller, name, ret)
+        return web.json_response(result)
+
+    async def render_chart(self, request: web.Request):
+        body_params = await request.json()
+
+        ds = await self.customizer.get_datasource()
+
+        caller = CallerSerializer.deserialize(body_params["caller"])
+        name = body_params["name"]
+
+        result = await ds.render_chart(caller, name)
+        return web.json_response(result)
+
+    # TODO: speak about; it's currently not implemented in ruby
+    # async def native_query(self, request: web.Request):
+    #     body_params = await request.text()
+    #     body_params = json.loads(body_params)
+
+    #     ds = await self.customizer.get_datasource()
+    #     connection_name = body_params["connectionName"]
+    #     native_query = body_params["nativeQuery"]
+    #     parameters = body_params["parameters"]
+
+    #     result = await ds.execute_native_query(connection_name, native_query, parameters)
+    #     return web.json_response(result)
+
+    def start(self):
+        web.run_app(self.app, host=self.listen_addr, port=int(self.listen_port))
diff --git a/src/agent_rpc/forestadmin/agent_rpc/hmac_middleware.py b/src/agent_rpc/forestadmin/agent_rpc/hmac_middleware.py
new file mode 100644
index 000000000..46bc08086
--- /dev/null
+++ b/src/agent_rpc/forestadmin/agent_rpc/hmac_middleware.py
@@ -0,0 +1,56 @@
+import datetime
+
+from forestadmin.datasource_toolkit.exceptions import ForestException
+from forestadmin.rpc_common.hmac import is_valid_hmac
+
+
+class HmacValidationError(ForestException):
+    pass
+
+
+class HmacValidator:
+    ALLOWED_TIME_DIFF = 300
+    SIGNATURE_REUSE_WINDOW = 5
+
+    def __init__(self, secret_key: str) -> None:
+        self.secret_key = secret_key
+        self.used_signatures = dict()
+
+    def validate_hmac(self, sign, timestamp):
+        """Validate the HMAC signature."""
+        if not sign or not timestamp:
+            raise HmacValidationError("Missing HMAC signature or timestamp")
+
+        self.validate_timestamp(timestamp)
+
+        if not is_valid_hmac(self.secret_key.encode("utf-8"), timestamp.encode("utf-8"), sign.encode("utf-8")):
+            raise HmacValidationError("Invalid HMAC signature")
+
+        if sign in self.used_signatures:
+            last_used = self.used_signatures[sign]
+            if (datetime.datetime.now(datetime.timezone.utc) - last_used).total_seconds() > self.SIGNATURE_REUSE_WINDOW:
+                raise HmacValidationError("HMAC signature has already been used")
+
+        self.used_signatures[sign] = datetime.datetime.now(datetime.timezone.utc)
+        self._cleanup_old_signs()
+        return True
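+
+    # Timestamps are ISO-8601 (UTC); validate_timestamp rejects anything older than
+    # ALLOWED_TIME_DIFF seconds. A signature may legitimately repeat within
+    # SIGNATURE_REUSE_WINDOW seconds (burst requests signed with the same timestamp);
+    # any later reuse is treated as a replay.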
+    def validate_timestamp(self, timestamp):
+        try:
+            current_time = datetime.datetime.fromisoformat(timestamp)
+        except Exception:
+            raise HmacValidationError("Invalid timestamp format")
+
+        if (datetime.datetime.now(datetime.timezone.utc) - current_time).total_seconds() > self.ALLOWED_TIME_DIFF:
+            raise HmacValidationError("Timestamp is too old or in the future")
+
+    def _cleanup_old_signs(self):
+        now = datetime.datetime.now(datetime.timezone.utc)
+        to_rm = []
+        for sign, last_used in self.used_signatures.items():
+            if (now - last_used).total_seconds() > self.ALLOWED_TIME_DIFF:
+                to_rm.append(sign)
+
+        for sign in to_rm:
+            del self.used_signatures[sign]
diff --git a/src/agent_rpc/forestadmin/agent_rpc/options.py b/src/agent_rpc/forestadmin/agent_rpc/options.py
new file mode 100644
index 000000000..d3eda88e9
--- /dev/null
+++ b/src/agent_rpc/forestadmin/agent_rpc/options.py
@@ -0,0 +1,6 @@
+from typing_extensions import TypedDict
+
+
+class RpcOptions(TypedDict):
+    listen_addr: str
+    auth_secret: str
diff --git a/src/agent_rpc/main.py b/src/agent_rpc/main.py
new file mode 100644
index 000000000..29681b386
--- /dev/null
+++ b/src/agent_rpc/main.py
@@ -0,0 +1,28 @@
+import asyncio
+import logging
+
+from forestadmin.agent_rpc.agent import RpcAgent
+from forestadmin.datasource_sqlalchemy.datasource import SqlAlchemyDatasource
+from sqlalchemy_models import DB_URI, Base
+
+
+def main() -> None:
+    # NOTE: "auth_secret" is required by RpcOptions and must match the secret
+    # configured on the datasource side; the value below is a local-dev placeholder.
+    agent = RpcAgent({"listen_addr": "0.0.0.0:50051", "auth_secret": "dev-only-secret"})
+    agent.add_datasource(
+        SqlAlchemyDatasource(Base, DB_URI), {"rename": lambda collection_name: f"FROMRPCAGENT_{collection_name}"}
+    )
+    agent.customize_collection("FROMRPCAGENT_address").add_field(
+        "new_fieeeld",
+        {
+            "column_type": "String",
+            "dependencies": ["pk"],
+            "get_values": lambda records, ctx: ["v" for r in records],
+        },
+    )
+    agent.start()
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    # asyncio.run(main())
+    main()
diff --git a/src/agent_rpc/pyproject.toml b/src/agent_rpc/pyproject.toml
new file mode 100644
index 000000000..7f0996bf8
--- /dev/null
+++ b/src/agent_rpc/pyproject.toml
@@ -0,0 +1,77 @@
+[build-system]
+requires = [ "poetry-core",]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "forestadmin-agent-rpc"
+description = "server for datasource rpc"
+version = "1.23.0"
+authors = [ "Julien Barreau ",]
+readme = "README.md"
+repository = "https://github.com/ForestAdmin/agent-python"
+documentation = "https://docs.forestadmin.com/developer-guide-agents-python/"
+homepage = "https://www.forestadmin.com"
+[[tool.poetry.packages]]
+include = "forestadmin"
+
+[tool.poetry.dependencies]
+python = ">=3.8,<3.14"
+typing-extensions = "~=4.2"
+forestadmin-agent-toolkit = "1.23.0"
+forestadmin-datasource-toolkit = "1.23.0"
+forestadmin-rpc-common = "1.23.0"
+
+[tool.poetry.dependencies."backports.zoneinfo"]
+version = "~=0.2.1"
+python = "<3.9"
+extras = [ "tzdata",]
+
+[tool.poetry.group.test]
+optional = true
+
+[tool.poetry.group.linter]
+optional = true
+
+[tool.poetry.group.formatter]
+optional = true
+
+[tool.poetry.group.sorter]
+optional = true
+
+[tool.poetry.group.test.dependencies]
+pytest = "~=7.1"
+pytest-asyncio = "~=0.18"
+coverage = "~=6.5"
+pytest-cov = "^4.0.0"
+sqlalchemy = ">=1.4.0"
+
+[tool.poetry.group.linter.dependencies]
+[[tool.poetry.group.linter.dependencies.flake8]]
+version = "~=5.0"
+python = "<3.8.1"
+
+[[tool.poetry.group.linter.dependencies.flake8]]
+version = "~=6.0"
+python = ">=3.8.1"
+
+[tool.poetry.group.formatter.dependencies]
+black = "~=22.10"
+
+[tool.poetry.group.sorter.dependencies]
+isort = "~=3.6"
+
+[tool.poetry.group.test.dependencies.forestadmin-datasource-toolkit]
+path = "../datasource_toolkit"
+develop = true
+
+[tool.poetry.group.test.dependencies.forestadmin-datasource-sqlalchemy]
+path = "../datasource_sqlalchemy"
+develop = true
+
+[tool.poetry.group.test.dependencies.forestadmin-agent-toolkit]
+path = "../agent_toolkit"
+develop = true
+
+[tool.poetry.group.test.dependencies.forestadmin-rpc-common]
+path = "../rpc_common"
+develop = true
diff --git 
a/src/agent_rpc/sqlalchemy_models.py b/src/agent_rpc/sqlalchemy_models.py new file mode 100644 index 000000000..261c78e39 --- /dev/null +++ b/src/agent_rpc/sqlalchemy_models.py @@ -0,0 +1,95 @@ +import enum +import os + +import sqlalchemy # type: ignore +from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, Integer, LargeBinary, String, func +from sqlalchemy.orm import relationship + +sqlite_path = os.path.abspath( + os.path.join(__file__, "..", "..", "_example", "django", "django_demo", "db_sqlalchemy.sql") +) +print(sqlite_path) +print(os.path.exists(sqlite_path)) +DB_URI = f"sqlite:///{sqlite_path}" +# DB_URI = "postgresql://flask:flask@127.0.0.1:5432/flask_scratch" + + +use_sqlalchemy_2 = sqlalchemy.__version__.split(".")[0] == "2" +if use_sqlalchemy_2: + from sqlalchemy import Uuid, create_engine + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase): + pass + +else: + from sqlalchemy import create_engine + from sqlalchemy.dialects.postgresql import UUID as Uuid + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + +engine = create_engine(DB_URI, echo=False) + + +class ORDER_STATUS(enum.Enum): + PENDING = "Pending" + DISPATCHED = "Dispatched" + DELIVERED = "Delivered" + REJECTED = "Rejected" + + +class Address(Base): + __tablename__ = "address" + pk = Column(Integer, primary_key=True) + street = Column(String(254), nullable=False) + street_number = Column(String(254), nullable=True) + city = Column(String(254), nullable=False) + country = Column(String(254), default="France", nullable=False) + zip_code = Column(String(5), nullable=False, default="75009") + customers = relationship("Customer", secondary="customers_addresses", back_populates="addresses") + + +class Customer(Base): + __tablename__ = "customer" + pk = Column(Uuid, primary_key=True) + first_name = Column(String(254), nullable=False) + last_name = Column(String(254), nullable=False) + age = Column(Integer, nullable=True) + birthday_date = Column(DateTime(timezone=True), default=func.now()) + addresses = relationship("Address", secondary="customers_addresses", back_populates="customers") + is_vip = Column(Boolean, default=False) + avatar = Column(LargeBinary, nullable=True) + + +class Order(Base): + __tablename__ = "order" + pk = Column(Integer, primary_key=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + amount = Column(Integer, nullable=False) + customer = relationship("Customer", backref="orders") + customer_id = Column(Uuid, ForeignKey("customer.pk")) + billing_address_id = Column(Integer, ForeignKey("address.pk")) + billing_address = relationship("Address", foreign_keys=[billing_address_id], backref="billing_orders") + delivering_address_id = Column(Integer, ForeignKey("address.pk")) + delivering_address = relationship("Address", foreign_keys=[delivering_address_id], backref="delivering_orders") + status = Column(Enum(ORDER_STATUS)) + + cart = relationship("Cart", uselist=False, back_populates="order") + + +class Cart(Base): + __tablename__ = "cart" + pk = Column(Integer, primary_key=True) + + name = Column(String(254), nullable=False) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + order_id = Column(Integer, ForeignKey("order.pk"), nullable=True) + order = relationship("Order", back_populates="cart") + + +class CustomersAddresses(Base): + __tablename__ = "customers_addresses" + customer_id = Column(Uuid, ForeignKey("customer.pk"), primary_key=True) + address_id = Column(Integer, 
ForeignKey("address.pk"), primary_key=True) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py index 199ab8fa5..55220e893 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py @@ -71,6 +71,17 @@ def __del__(self): if hasattr(self, "_sse_thread") and self._sse_thread.is_alive(): self._sse_thread.stop() + async def reload(self): + try: + await self.customizer._reload() + except Exception as exc: + ForestLogger.log("error", f"Error reloading agent: {exc}") + return + + self._resources = None + await self.__mk_resources() + await self.send_schema() + async def __mk_resources(self): self._resources: Resources = { "capabilities": CapabilitiesResource( @@ -219,6 +230,15 @@ async def _start(self): return ForestLogger.log("debug", "Starting agent") + await self.send_schema() + + if self.options["instant_cache_refresh"]: + self._sse_thread.start() + + ForestLogger.log("debug", "Agent started") + Agent.__IS_INITIALIZED = True + + async def send_schema(self): if self.options["skip_schema_update"] is False: try: api_map = await SchemaEmitter.get_serialized_schema( @@ -233,9 +253,3 @@ async def _start(self): ForestLogger.log("warning", "Cannot send the apimap to Forest. Are you online?") else: ForestLogger.log("warning", 'Schema update was skipped (caused by options["skip_schema_update"]=True)') - - if self.options["instant_cache_refresh"]: - self._sse_thread.start() - - ForestLogger.log("debug", "Agent started") - Agent.__IS_INITIALIZED = True diff --git a/src/datasource_rpc/README.md b/src/datasource_rpc/README.md new file mode 100644 index 000000000..8fb870bae --- /dev/null +++ b/src/datasource_rpc/README.md @@ -0,0 +1 @@ +python -m grpc_tools.protoc -I./ --python_out=. --pyi_out=. --grpc_python_out=. 
./*.proto \ No newline at end of file diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/__init__.py b/src/datasource_rpc/forestadmin/datasource_rpc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/collection.py b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py new file mode 100644 index 000000000..d12100378 --- /dev/null +++ b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py @@ -0,0 +1,187 @@ +import json +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from forestadmin.agent_toolkit.forest_logger import ForestLogger +from forestadmin.agent_toolkit.utils.context import User +from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.interfaces.actions import Action, ActionFormElement, ActionResult, ActionsScope +from forestadmin.datasource_toolkit.interfaces.chart import Chart +from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType +from forestadmin.datasource_toolkit.interfaces.query.aggregation import AggregateResult, Aggregation +from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter +from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter +from forestadmin.datasource_toolkit.interfaces.query.projections import Projection +from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias +from forestadmin.datasource_toolkit.utils.schema import SchemaUtils +from forestadmin.rpc_common.serializers.actions import ( + ActionFormSerializer, + ActionFormValuesSerializer, + ActionResultSerializer, +) +from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer +from forestadmin.rpc_common.serializers.collection.filter import ( + FilterSerializer, + PaginatedFilterSerializer, + ProjectionSerializer, +) +from forestadmin.rpc_common.serializers.collection.record import RecordSerializer +from forestadmin.rpc_common.serializers.utils import CallerSerializer + +if TYPE_CHECKING: + from forestadmin.datasource_rpc.datasource import RPCDatasource + + +class RPCCollection(Collection): + def __init__(self, name: str, datasource: "RPCDatasource"): + super().__init__(name, datasource) + self._rpc_actions = {} + self._datasource = datasource + + @property + def datasource(self) -> "RPCDatasource": + return self._datasource + + def add_rpc_action(self, name: str, action: Dict[str, Any]) -> None: + if name in self._schema["actions"]: + raise ValueError(f"Action {name} already exists in collection {self.name}") + self._schema["actions"][name] = Action( + scope=ActionsScope(action["scope"]), + description=action.get("description"), + submit_button_label=action.get("submit_button_label"), + generate_file=action.get("generate_file", False), + static_form=action["static_form"], + ) + self._rpc_actions[name] = {"form": action["form"]} + + async def list(self, caller: User, filter_: PaginatedFilter, projection: Projection) -> List[RecordsDataAlias]: + body = { + "caller": CallerSerializer.serialize(caller), + "filter": PaginatedFilterSerializer.serialize(filter_, self), + "projection": ProjectionSerializer.serialize(projection), + "collectionName": self.name, + } + ret = await self.datasource.requester.list(self.name, body) + return [RecordSerializer.deserialize(record, self) for record in ret] + + async def create(self, caller: User, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + body = { + "caller": 
CallerSerializer.serialize(caller),
+            "data": [RecordSerializer.serialize(r) for r in data],
+            "collectionName": self.name,
+        }
+        response = await self.datasource.requester.create(self.name, body)
+        return [RecordSerializer.deserialize(record, self) for record in response]
+
+    async def update(self, caller: User, filter_: Optional[Filter], patch: Dict[str, Any]) -> None:
+        body = {
+            "caller": CallerSerializer.serialize(caller),
+            "filter": FilterSerializer.serialize(filter_, self),  # type: ignore
+            "patch": RecordSerializer.serialize(patch),
+            "collectionName": self.name,
+        }
+        await self.datasource.requester.update(self.name, body)
+
+    async def delete(self, caller: User, filter_: Optional[Filter]) -> None:
+        body = {
+            "caller": CallerSerializer.serialize(caller),
+            "filter": FilterSerializer.serialize(filter_, self),  # type: ignore
+            "collectionName": self.name,
+        }
+        await self.datasource.requester.delete(self.name, body)
+
+    async def aggregate(
+        self, caller: User, filter_: Optional[Filter], aggregation: Aggregation, limit: Optional[int] = None
+    ) -> List[AggregateResult]:
+        body = {
+            "caller": CallerSerializer.serialize(caller),
+            "filter": FilterSerializer.serialize(filter_, self),  # type: ignore
+            "aggregation": AggregationSerializer.serialize(aggregation),
+            "collectionName": self.name,
+        }
+        response = await self.datasource.requester.aggregate(self.name, body)
+        return response
+
+    async def get_form(
+        self,
+        caller: User,
+        name: str,
+        data: Optional[RecordsDataAlias] = None,
+        filter_: Optional[Filter] = None,
+        meta: Optional[Dict[str, Any]] = None,
+    ) -> List[ActionFormElement]:
+        if name not in self._schema["actions"]:
+            raise ValueError(f"Action {name} does not exist in collection {self.name}")
+
+        if self._schema["actions"][name].static_form:
+            return self._rpc_actions[name]["form"]
+
+        body = {
+            "caller": CallerSerializer.serialize(caller) if caller is not None else None,
+            "filter": FilterSerializer.serialize(filter_, self) if filter_ is not None else None,
+            "data": ActionFormValuesSerializer.serialize(data),  # type: ignore
+            "meta": meta,
+            "collectionName": self.name,
+            "actionName": name,
+        }
+        response = await self.datasource.requester.get_form(self.name, body)
+        return ActionFormSerializer.deserialize(response)  # type: ignore
+
+    async def execute(
+        self,
+        caller: User,
+        name: str,
+        data: RecordsDataAlias,
+        filter_: Optional[Filter],
+    ) -> ActionResult:
+        if name not in self._schema["actions"]:
+            raise ValueError(f"Action {name} does not exist in collection {self.name}")
+
+        body = {
+            "caller": CallerSerializer.serialize(caller) if caller is not None else None,
+            "filter": FilterSerializer.serialize(filter_, self) if filter_ is not None else None,
+            "data": ActionFormValuesSerializer.serialize(data),
+            "collectionName": self.name,
+            "actionName": name,
+        }
+        response = await self.datasource.requester.execute(self.name, body)
+        return ActionResultSerializer.deserialize(response)
+
+    async def render_chart(self, caller: User, name: str, record_id: List) -> Chart:
+        if name not in self._schema["charts"].keys():
+            raise ValueError(f"Chart {name} does not exist in collection {self.name}")
+
+        ret = []
+        for i, value in enumerate(record_id):
+            pk_field = SchemaUtils.get_primary_keys(self.schema)[i]
+            type_record_id = self.schema["fields"][pk_field]["column_type"]  # type: ignore
+
+            if type_record_id == PrimitiveType.DATE:
+                ret.append(value.isoformat())
+            elif type_record_id == PrimitiveType.DATE_ONLY:
+                ret.append(value.isoformat())
+            elif type_record_id == PrimitiveType.POINT:
+                ret.append((value[0], value[1]))
+            elif type_record_id == PrimitiveType.UUID:
+                ret.append(str(value))
+            else:
+                ret.append(value)
+
+        body = {
+            "caller": CallerSerializer.serialize(caller) if caller is not None else None,
+            "name": name,
+            "collectionName": self.name,
+            "recordId": ret,
+        }
+        return await self.datasource.requester.collection_render_chart(self.name, body)
+
+    def get_native_driver(self):
+        ForestLogger.log(
+            "error",
+            "Get_native_driver is not available on RPCCollection. "
+            "Please use execute_native_query on the rpc datasource instead",
+        )
diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py
new file mode 100644
index 000000000..fee0b9788
--- /dev/null
+++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py
@@ -0,0 +1,129 @@
+import asyncio
+import hashlib
+import json
+from threading import Thread
+from typing import Any, Dict
+
+from forestadmin.agent_toolkit.forest_logger import ForestLogger
+from forestadmin.agent_toolkit.utils.context import User
+from forestadmin.datasource_rpc.collection import RPCCollection
+from forestadmin.datasource_rpc.requester import RPCRequester
+from forestadmin.datasource_toolkit.datasources import Datasource
+from forestadmin.datasource_toolkit.interfaces.chart import Chart
+from forestadmin.datasource_toolkit.utils.user_callable import call_user_function
+from forestadmin.rpc_common.serializers.schema.schema import SchemaDeserializer
+from forestadmin.rpc_common.serializers.utils import CallerSerializer
+
+
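+# The raw schema payload is hashed so introspect() can cheaply detect whether the
+# server-side schema changed since the last (re)connection and skip a full rebuild.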
+def _hash_schema(schema) -> str:
+    return hashlib.sha256(json.dumps(schema, sort_keys=True).encode()).hexdigest()
+
+
+async def create_rpc_datasource(connection_uri: str, secret_key: str, reload_method=None) -> "RPCDatasource":
+    """Create a new RPC datasource and wait for the connection to be established."""
+    datasource = RPCDatasource(connection_uri, secret_key, reload_method)
+    await datasource.wait_for_connection()
+    await datasource.introspect()
+    return datasource
+
+
+class RPCDatasource(Datasource):
+    def __init__(self, connection_uri: str, secret_key: str, reload_method=None):
+        super().__init__([])
+        self.connection_uri = connection_uri
+        self.reload_method = reload_method
+        self.secret_key = secret_key
+        self.last_schema_hash = ""
+
+        self.requester = RPCRequester(connection_uri, secret_key)
+
+        self.thread = Thread(
+            target=asyncio.run,
+            args=(self.requester.sse_connect(self.sse_callback),),
+            name="RPCDatasourceSSEThread",
+            daemon=True,
+        )
+        self.thread.start()
+
+    async def introspect(self) -> bool:
+        """Return True if the schema has changed."""
+        schema_data = await self.requester.schema()
+        if self.last_schema_hash == _hash_schema(schema_data):
+            ForestLogger.log("debug", "[RPCDatasource] Schema has not changed")
+            return False
+        self.last_schema_hash = _hash_schema(schema_data)
+
+        self._collections = {}
+        self._schema = {"charts": {}}
+        schema_data = SchemaDeserializer().deserialize(schema_data)
+        for collection_name, collection in schema_data["collections"].items():
+            self.create_collection(collection_name, collection)
+
+        self._schema["charts"] = {name: None for name in schema_data["charts"]}  # type: ignore
+        self._live_query_connections = schema_data.get("live_query_connections", [])
+        return True
+
+    def create_collection(self, collection_name, collection_schema):
+        collection = RPCCollection(collection_name, self)
+        for name, field in collection_schema["fields"].items():
+            collection.add_field(name, field)
+
+        if len(collection_schema["actions"]) > 0:
+            for action_name, action_schema in collection_schema["actions"].items():
+                collection.add_rpc_action(action_name, action_schema)
+
+        collection._schema["charts"] = {name: None for name in collection_schema["charts"]}  # type: ignore
+        collection.add_segments(collection_schema["segments"])
+
+        if collection_schema["countable"]:
+            collection.enable_count()
+            collection.enable_search()
+
+        self.add_collection(collection)
+
+    async def wait_for_connection(self):
+        """Wait for the connection to be established."""
+        await self.requester.wait_for_connection()
+        ForestLogger.log("debug", "Connection to RPC datasource established")
+
+    async def internal_reload(self):
+        has_changed = await self.introspect()
+        if has_changed and self.reload_method is not None:
+            await call_user_function(self.reload_method)
+
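+    # The SSE stream only carries heartbeats; its purpose is liveness. When it ends
+    # (e.g. the RPC agent restarts), re-introspect, propagate the reload, then spawn
+    # a fresh listener thread for the next connection.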
+    async def sse_callback(self):
+        await self.wait_for_connection()
+        await self.internal_reload()
+
+        self.thread = Thread(
+            target=asyncio.run,
+            args=(self.requester.sse_connect(self.sse_callback),),
+            name="RPCDatasourceSSEThread",
+            daemon=True,
+        )
+        self.thread.start()
+
+    # TODO: speak about; it's currently not implemented in ruby
+    # async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any:
+    #     return await self.requester.native_query(
+    #         {
+    #             "connectionName": connection_name,
+    #             "nativeQuery": native_query,
+    #             "parameters": parameters,
+    #         }
+    #     )
+    async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any:
+        raise NotImplementedError
+
+    async def render_chart(self, caller: User, name: str) -> Chart:
+        if name not in self._schema["charts"].keys():
+            raise ValueError(f"Chart {name} does not exist in this datasource")
+
+        body = {
+            "caller": CallerSerializer.serialize(caller) if caller is not None else None,
+            "name": name,
+        }
+
+        return await self.requester.datasource_render_chart(body)
diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/requester.py b/src/datasource_rpc/forestadmin/datasource_rpc/requester.py
new file mode 100644
index 000000000..de2989ff3
--- /dev/null
+++ b/src/datasource_rpc/forestadmin/datasource_rpc/requester.py
@@ -0,0 +1,196 @@
+import asyncio
+import datetime
+from typing import List
+
+import aiohttp
+from aiohttp_sse_client import client as sse_client
+from forestadmin.rpc_common.hmac import generate_hmac
+
+
+class RPCRequester:
+    def __init__(self, connection_uri: str, secret_key: str):
+        self.connection_uri = connection_uri
+        self.secret_key = secret_key
+
+    def mk_hmac_headers(self):
+        timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
+        return {
+            "X_SIGNATURE": generate_hmac(self.secret_key.encode("utf-8"), timestamp.encode("utf-8")),
+            "X_TIMESTAMP": timestamp,
+        }
+
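+    # NOTE: sse_connect() blocks on the event stream until the server goes away,
+    # then fires `callback` so RPCDatasource can re-introspect and reconnect.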
+    async def sse_connect(self, callback):
+        """Connect to the SSE stream."""
+        await self.wait_for_connection()
+
+        async with sse_client.EventSource(
+            f"http://{self.connection_uri}/forest/rpc/sse",
+        ) as event_source:
+            try:
+                async for event in event_source:
+                    pass
+            except Exception:
+                pass
+
+        await callback()
+
+    async def wait_for_connection(self):
+        """Wait for the connection to be established."""
+        async with aiohttp.ClientSession() as session:
+            while True:
+                try:
+                    async with session.get(f"http://{self.connection_uri}/") as response:
+                        if response.status == 200:
+                            return True
+                        else:
+                            raise Exception(f"Failed to connect: {response.status}")
+                except Exception:
+                    await asyncio.sleep(1)
+
+    # methods for datasource
+
+    async def schema(self) -> dict:
+        """Get the schema of the datasource."""
+        async with aiohttp.ClientSession() as session:
+            async with session.get(
+                f"http://{self.connection_uri}/forest/rpc-schema",
+                headers=self.mk_hmac_headers(),
+            ) as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    raise Exception(f"Failed to get schema: {response.status}")
+
+    # TODO: speak about; it's currently not implemented in ruby
+    # async def native_query(self, body):
+    #     """Execute a native query."""
+    #     async with aiohttp.ClientSession() as session:
+    #         async with session.post(
+    #             f"http://{self.connection_uri}/execute-native-query",
+    #             json=body,
+    #             headers=self.mk_hmac_headers(),
+    #         ) as response:
+    #             if response.status == 200:
+    #                 return await response.json()
+    #             else:
+    #                 raise Exception(f"Failed to execute native query: {response.status}")
+
+    async def datasource_render_chart(self, body):
+        """Render a chart."""
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"http://{self.connection_uri}/forest/rpc/datasource-chart",
+                json=body,
+                headers=self.mk_hmac_headers(),
+            ) as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    raise Exception(f"Failed to render chart: {response.status}")
+
+    # methods for collection
+
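+    # All collection endpoints follow the same convention as the server routes in
+    # RpcAgent.setup_routes: POST /forest/rpc/{collection_name}/<verb> with a JSON
+    # body produced by the rpc_common serializers.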
+    async def list(self, collection_name, body) -> List[dict]:
+        """List records in a collection."""
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"http://{self.connection_uri}/forest/rpc/{collection_name}/list",
+                json=body,
+                headers=self.mk_hmac_headers(),
+            ) as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    raise Exception(f"Failed to list records: {response.status}")
+
+    async def create(self, collection_name, body) -> List[dict]:
+        """Create records in a collection."""
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"http://{self.connection_uri}/forest/rpc/{collection_name}/create",
+                json=body,
+                headers=self.mk_hmac_headers(),
+            ) as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    raise Exception(f"Failed to create records: {response.status}")
+
+    async def update(self, collection_name, body):
+        """Update records in a collection."""
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"http://{self.connection_uri}/forest/rpc/{collection_name}/update",
+                json=body,
+                headers=self.mk_hmac_headers(),
+            ) as response:
+                if response.status == 200:
+                    return
+                else:
+                    raise Exception(f"Failed to update records: {response.status}")
+
+    async def delete(self, collection_name, body):
+        """Delete records in a collection."""
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"http://{self.connection_uri}/forest/rpc/{collection_name}/delete",
+                json=body,
+                headers=self.mk_hmac_headers(),
+            ) as response:
+                if response.status == 200:
+                    return
+                else:
+                    raise Exception(f"Failed to delete records: {response.status}")
+
+    async def aggregate(self, collection_name, body):
+        """Aggregate records in a collection."""
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"http://{self.connection_uri}/forest/rpc/{collection_name}/aggregate",
+                json=body,
+                headers=self.mk_hmac_headers(),
+            ) as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    raise Exception(f"Failed to aggregate records: {response.status}")
+
+    async def get_form(self, collection_name, body):
+        """Get the form for an action."""
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"http://{self.connection_uri}/forest/rpc/{collection_name}/action-form",
+                json=body,
+                headers=self.mk_hmac_headers(),
+            ) as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    raise Exception(f"Failed to get the action form: {response.status}")
+
+    async def execute(self, collection_name, body):
+        """Execute an action."""
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"http://{self.connection_uri}/forest/rpc/{collection_name}/action-execute",
+                json=body,
+                headers=self.mk_hmac_headers(),
+            ) as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    raise Exception(f"Failed to execute action: {response.status}")
+
+    async def collection_render_chart(self, collection_name, body):
+        """Render a chart."""
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f"http://{self.connection_uri}/forest/rpc/{collection_name}/chart",
+                json=body,
+                headers=self.mk_hmac_headers(),
+            ) as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    raise Exception(f"Failed to render chart: {response.status}")
diff --git a/src/datasource_rpc/pyproject.toml b/src/datasource_rpc/pyproject.toml
new file mode 100644
index 000000000..fb5b92bb6
--- /dev/null
+++ b/src/datasource_rpc/pyproject.toml
@@ -0,0 +1,73 @@
+[build-system]
+requires = [ "poetry-core",]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "forestadmin-datasource-rpc"
+description = "client for datasource rpc"
+version = "1.23.0"
+authors = [ "Julien Barreau ",]
+readme = "README.md"
+repository = "https://github.com/ForestAdmin/agent-python"
+documentation = "https://docs.forestadmin.com/developer-guide-agents-python/"
+homepage = "https://www.forestadmin.com"
+[[tool.poetry.packages]]
+include = "forestadmin"
+
+[tool.poetry.dependencies]
+python = ">=3.8,<3.14"
+typing-extensions = "~=4.2"
+forestadmin-agent-toolkit = "1.23.0"
+forestadmin-datasource-toolkit = "1.23.0"
+forestadmin-rpc-common = "1.23.0"
+aiohttp-sse-client = ">=0.2.1"
+
+[tool.poetry.dependencies."backports.zoneinfo"]
+version = "~=0.2.1"
+python = "<3.9"
+extras = [ "tzdata",]
+
+[tool.poetry.group.test]
+optional = true
+
+[tool.poetry.group.linter]
+optional = true
+
+[tool.poetry.group.formatter]
+optional = true
+
+[tool.poetry.group.sorter]
+optional = true
+
+[tool.poetry.group.test.dependencies]
+pytest = "~=7.1"
+pytest-asyncio = "~=0.18"
+coverage = "~=6.5"
+pytest-cov = "^4.0.0"
+
+[tool.poetry.group.linter.dependencies]
+[[tool.poetry.group.linter.dependencies.flake8]]
+version = "~=5.0"
+python = "<3.8.1"
+
+[[tool.poetry.group.linter.dependencies.flake8]]
+version = "~=6.0"
+python = ">=3.8.1"
+
+[tool.poetry.group.formatter.dependencies]
+black = "~=22.10"
+
+[tool.poetry.group.sorter.dependencies]
+isort = "~=3.6"
+
+[tool.poetry.group.test.dependencies.forestadmin-datasource-toolkit]
+path = "../datasource_toolkit"
+develop = true
+
+[tool.poetry.group.test.dependencies.forestadmin-agent-toolkit]
+path = "../agent_toolkit"
+develop = true
+
+[tool.poetry.group.test.dependencies.forestadmin-rpc-common]
+path = "../rpc_common"
+develop = true
diff --git 
a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py index 2904eb167..24b53475b 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py @@ -25,6 +25,17 @@ async def get_datasource(self): await self.stack.apply_queue_customization() return self.stack.datasource + async def _reload(self): + old_composite_ds = self.composite_datasource + + try: + self.composite_datasource = CompositeDatasource() + await self.stack._reload(self.composite_datasource) + except Exception as e: + self.composite_datasource = old_composite_ds + raise e + await self.get_datasource() + @property def collections(self): """Get list of customizable collections""" @@ -41,19 +52,21 @@ def add_datasource(self, datasource: Datasource, options: Optional[DataSourceOpt async def _add_datasource(): nonlocal datasource + _datasource = datasource + if "include" in _options or "exclude" in _options: - publication_decorator = PublicationDataSourceDecorator(datasource) + publication_decorator = PublicationDataSourceDecorator(_datasource) publication_decorator.keep_collections_matching( _options.get("include", []), _options.get("exclude", []) ) - datasource = publication_decorator + _datasource = publication_decorator if "rename" in _options: - rename_decorator = RenameCollectionDataSourceDecorator(datasource) + rename_decorator = RenameCollectionDataSourceDecorator(_datasource) rename_decorator.rename_collections(_options.get("rename", {})) - datasource = rename_decorator + _datasource = rename_decorator - self.composite_datasource.add_datasource(datasource) + self.composite_datasource.add_datasource(_datasource) self.stack.queue_customization(_add_datasource) return self diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py index 64ebddda3..e1e5ffbb1 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py @@ -27,6 +27,9 @@ def schema(self): def collections(self): return [self.get_collection(c.name) for c in self.child_datasource.collections] + def get_native_query_connections(self) -> List[str]: + return self.child_datasource.get_native_query_connections() + def get_collection(self, name: str) -> CollectionDecorator: collection = self.child_datasource.get_collection(name) if collection not in self._decorators: diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py index 5a1d297fd..d0f295177 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py @@ -1,5 +1,6 @@ from typing import Awaitable, Callable, List +from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.datasource_toolkit.decorators.action.collections import ActionCollectionDecorator from forestadmin.datasource_toolkit.decorators.binary.collection import BinaryCollectionDecorator from 
forestadmin.datasource_toolkit.decorators.chart.chart_datasource_decorator import ChartDataSourceDecorator @@ -28,10 +29,12 @@ class DecoratorStack: def __init__(self, datasource: Datasource) -> None: self._customizations: List = list() - last = datasource + self._applied_customizations: List = list() + self._init_stack(datasource) + def _init_stack(self, datasource: Datasource): # Step 0: Do not query datasource when we know the result with yield an empty set. - last = self.override = DatasourceDecorator(last, OverrideCollectionDecorator) # type: ignore + last = self.override = DatasourceDecorator(datasource, OverrideCollectionDecorator) # type: ignore last = self.empty = DatasourceDecorator(last, EmptyCollectionDecorator) # type: ignore # Step 1: Computed-Relation-Computed sandwich (needed because some emulated relations depend @@ -67,10 +70,59 @@ def __init__(self, datasource: Datasource) -> None: self.datasource = last + def _backup_stack(self): + backup_stack = {} + for decorator_name in [ + "_customizations", + "_applied_customizations", + "override", + "empty", + "early_computed", + "early_op_emulate", + "early_op_equivalence", + "relation", + "lazy_joins", + "late_computed", + "late_op_emulate", + "late_op_equivalence", + "search", + "segment", + "sort_emulate", + "chart", + "action", + "schema", + "write", + "hook", + "validation", + "binary", + "publication", + "rename_field", + "datasource", + ]: + backup_stack[decorator_name] = getattr(self, decorator_name) + return backup_stack + + def _restore_stack(self, backup_stack): + for decorator_name, instance in backup_stack.items(): + setattr(self, decorator_name, instance) + + async def _reload(self, datasource: Datasource): + backup_stack = self._backup_stack() + + self._customizations: List = self._applied_customizations + self._applied_customizations: List = list() + try: + self._init_stack(datasource) + await self.apply_queue_customization() + except Exception as e: + ForestLogger.log("exception", f"Error while reloading customizations: {e}, restoring previous state") + self._restore_stack(backup_stack) + raise e + def queue_customization(self, customization: Callable[[], Awaitable[None]]): self._customizations.append(customization) - async def apply_queue_customization(self): + async def apply_queue_customization(self, store_customizations=True): queued_customization = self._customizations.copy() self._customizations = [] @@ -78,4 +130,6 @@ async def apply_queue_customization(self): while len(queued_customization) > 0: customization = queued_customization.pop() await customization() - await self.apply_queue_customization() + if store_customizations: + self._applied_customizations.append(customization) + await self.apply_queue_customization(False) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/fields.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/fields.py index 04afa827b..0d39c9116 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/fields.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/fields.py @@ -59,6 +59,7 @@ class Operator(enum.Enum): "starts_with", "ends_with", "contains", + "match", "not_contains", "longer_than", "shorter_than", diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py index 80424a279..b410f1d2b 100644 --- 
a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py @@ -65,7 +65,7 @@ class AggregationGroup(TypedDict): class PlainAggregation(TypedDict): field: NotRequired[Optional[str]] - operation: PlainAggregator + operation: Union[PlainAggregator, Aggregator] groups: NotRequired[List[PlainAggregationGroup]] diff --git a/src/django_agent/forestadmin/django_agent/apps.py b/src/django_agent/forestadmin/django_agent/apps.py index 0450fabba..18ca49686 100644 --- a/src/django_agent/forestadmin/django_agent/apps.py +++ b/src/django_agent/forestadmin/django_agent/apps.py @@ -1,4 +1,3 @@ -import asyncio import importlib import sys import threading @@ -8,6 +7,7 @@ from django.conf import settings from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.datasource_django.datasource import DjangoDatasource +from forestadmin.datasource_toolkit.utils.user_callable import call_user_function from forestadmin.django_agent.agent import DjangoAgent, create_agent @@ -24,17 +24,16 @@ def is_launch_as_server() -> bool: def init_app_agent() -> Optional[DjangoAgent]: if not is_launch_as_server(): return None - agent = create_agent() if not hasattr(settings, "FOREST_AUTO_ADD_DJANGO_DATASOURCE") or settings.FOREST_AUTO_ADD_DJANGO_DATASOURCE: - agent.add_datasource(DjangoDatasource()) + DjangoAgentApp._DJANGO_AGENT.add_datasource(DjangoDatasource()) customize_fn = getattr(settings, "FOREST_CUSTOMIZE_FUNCTION", None) if customize_fn: - agent = _call_user_customize_function(customize_fn, agent) + _call_user_customize_function(customize_fn, DjangoAgentApp._DJANGO_AGENT) - if agent: - agent.start() - return agent + if DjangoAgentApp._DJANGO_AGENT: + DjangoAgentApp._DJANGO_AGENT.start() + return DjangoAgentApp._DJANGO_AGENT def _call_user_customize_function(customize_fn: Union[str, Callable[[DjangoAgent], None]], agent: DjangoAgent): @@ -44,21 +43,16 @@ def _call_user_customize_function(customize_fn: Union[str, Callable[[DjangoAgent module = importlib.import_module(module_name) customize_fn = getattr(module, fn_name) except Exception as exc: - ForestLogger.log("error", f"cannot import {customize_fn} : {exc}. Quitting forest.") + ForestLogger.log("exception", f"cannot import {customize_fn} : {exc}. Quitting forest.") return - - if callable(customize_fn): try: - ret = customize_fn(agent) - if asyncio.iscoroutine(ret): - agent.loop.run_until_complete(ret) + agent.loop.run_until_complete(call_user_function(customize_fn, agent)) except Exception as exc: ForestLogger.log( - "error", + "exception", f'error executing "FOREST_CUSTOMIZE_FUNCTION" ({customize_fn}): {exc}. Quitting forest.', ) return - return agent class DjangoAgentApp(AppConfig): @@ -81,6 +75,9 @@ def get_agent(cls) -> DjangoAgent: def ready(self): # we need to wait for other apps to be ready, for this forest app must be ready # that's why we need another thread waiting for every app to be ready + if not is_launch_as_server(): + return + DjangoAgentApp._DJANGO_AGENT = create_agent() t = threading.Thread(name="forest.wait_and_launch_agent", target=self._wait_for_all_apps_ready_and_launch_agent) t.start() diff --git a/src/rpc_common/README.md b/src/rpc_common/README.md new file mode 100644 index 000000000..0b4778bff --- /dev/null +++ b/src/rpc_common/README.md @@ -0,0 +1 @@ +python -m grpc_tools.protoc -I./ --python_out=. --pyi_out=. --grpc_python_out=. 
./proto/*.proto \ No newline at end of file diff --git a/src/rpc_common/forestadmin/rpc_common/hmac.py b/src/rpc_common/forestadmin/rpc_common/hmac.py new file mode 100644 index 000000000..bde00b433 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/hmac.py @@ -0,0 +1,12 @@ +import hashlib +import hmac + + +def generate_hmac(key: bytes, message: bytes) -> str: + # TODO: handle HMAC like ruby agent + return hmac.new(key, message, hashlib.sha256).hexdigest() + + +def is_valid_hmac(key: bytes, message: bytes, sign: bytes) -> bool: + expected_hmac = generate_hmac(key, message).encode("utf-8") + return hmac.compare_digest(expected_hmac, sign) diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/__init__.py b/src/rpc_common/forestadmin/rpc_common/serializers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/actions.py b/src/rpc_common/forestadmin/rpc_common/serializers/actions.py new file mode 100644 index 000000000..585c61071 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/serializers/actions.py @@ -0,0 +1,304 @@ +from base64 import b64decode, b64encode +from io import IOBase + +from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.interfaces.actions import ( + ActionFieldType, + ActionResult, + ActionsScope, + ErrorResult, + File, + FileResult, + RedirectResult, + SuccessResult, + WebHookResult, +) +from forestadmin.rpc_common.serializers.utils import enum_to_str_or_value + + +class ActionSerializer: + @staticmethod + async def serialize(action_name: str, collection: Collection) -> dict: + action = collection.schema["actions"][action_name] + if not action.static_form: + form = None + else: + # TODO: a dummy field here ?? 
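+            # Static forms cannot depend on a caller or runtime context, so the form
+            # is resolved once here, at schema-serialization time (hence the None args).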
+            form = await collection.get_form(None, action_name, None, None)  # type:ignore
+            # fields, layout = SchemaActionGenerator.extract_fields_and_layout(form)
+            # fields = [
+            #     await SchemaActionGenerator.build_field_schema(collection.datasource, field) for field in fields
+            # ]
+
+        return {
+            "scope": action.scope.value.lower(),
+            "is_generate_file": action.generate_file or False,
+            "static_form": action.static_form or False,
+            "description": action.description,
+            "submit_button_label": action.submit_button_label,
+            "form": ActionFormSerializer.serialize(form) if form is not None else [],
+            "execute": {},  # TODO: I'm pretty sure this is not needed
+        }
+
+    @staticmethod
+    def deserialize(action: dict) -> dict:
+        return {
+            "scope": ActionsScope(action["scope"].capitalize()),
+            "generate_file": action["is_generate_file"],
+            "static_form": action["static_form"],
+            "description": action["description"],
+            "submit_button_label": action["submit_button_label"],
+            "form": ActionFormSerializer.deserialize(action["form"]) if action["form"] is not None else [],
+        }
+
+
+class ActionFormSerializer:
+    @staticmethod
+    def serialize(form) -> list[dict]:
+        serialized_form = []
+
+        for field in form:
+            tmp_field = {}
+            if field["type"] == ActionFieldType.LAYOUT:
+                if field["component"] == "Page":
+                    tmp_field = {
+                        **field,
+                        "type": "Layout",
+                        "elements": ActionFormSerializer.serialize(field["elements"]),
+                    }
+
+                if field["component"] == "Row":
+                    tmp_field = {
+                        **field,
+                        "type": "Layout",
+                        "fields": ActionFormSerializer.serialize(field["fields"]),
+                    }
+            else:
+                tmp_field = {
+                    **{k: v for k, v in field.items()},
+                    "type": enum_to_str_or_value(field["type"]),
+                }
+                if field["type"] == ActionFieldType.FILE:
+                    tmp_field.update(ActionFormSerializer._serialize_file_field(tmp_field))
+                if field["type"] == ActionFieldType.FILE_LIST:
+                    tmp_field.update(ActionFormSerializer._serialize_file_list_field(tmp_field))
+
+            if "if_" in tmp_field:
+                tmp_field["if_condition"] = tmp_field["if_"]
+                del tmp_field["if_"]
+            serialized_form.append(tmp_field)
+
+        return serialized_form
+
+    @staticmethod
+    def _serialize_file_list_field(field: dict) -> dict:
+        ret = {}
+        if field.get("defaultValue"):
+            ret["defaultValue"] = [ActionFormSerializer._serialize_file_obj(v) for v in field["defaultValue"]]
+        if field.get("value"):
+            ret["value"] = [ActionFormSerializer._serialize_file_obj(v) for v in field["value"]]
+        return ret
+
+    @staticmethod
+    def _serialize_file_field(field: dict) -> dict:
+        ret = {}
+        if field.get("defaultValue"):
+            ret["defaultValue"] = ActionFormSerializer._serialize_file_obj(field["defaultValue"])
+        if field.get("value"):
+            ret["value"] = ActionFormSerializer._serialize_file_obj(field["value"])
+        return ret
+
+    @staticmethod
+    def _serialize_file_obj(f: File) -> dict:
+        return {
+            "mimeType": f.mime_type,
+            "name": f.name,
+            "stream": b64encode(f.buffer).decode("utf-8"),
+            "charset": f.charset,
+        }
+
+    @staticmethod
+    def _deserialize_file_obj(f: dict) -> File:
+        return File(
+            mime_type=f["mimeType"],
+            name=f["name"],
+            buffer=b64decode(f["stream"].encode("utf-8")),
+            charset=f["charset"],
+        )
+
+    @staticmethod
+    def _deserialize_file_field(field: dict) -> dict:
+        ret = {}
+        if field.get("default_value"):
+            ret["default_value"] = ActionFormSerializer._deserialize_file_obj(field["default_value"])
+        if field.get("value"):
+            ret["value"] = ActionFormSerializer._deserialize_file_obj(field["value"])
+        return ret
+
+    @staticmethod
+    def _deserialize_file_list_field(field: dict) -> dict:
+        ret = {}
+        if 
diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/collection/__init__.py b/src/rpc_common/forestadmin/rpc_common/serializers/collection/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/collection/aggregation.py b/src/rpc_common/forestadmin/rpc_common/serializers/collection/aggregation.py
new file mode 100644
index 000000000..4c81f4e87
--- /dev/null
+++ b/src/rpc_common/forestadmin/rpc_common/serializers/collection/aggregation.py
@@ -0,0 +1,63 @@
+from typing import Dict
+
+from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation, Aggregator, DateOperation
+
+
+class AggregatorSerializer:
+    @staticmethod
+    def serialize(aggregator: Aggregator) -> str:
+        if isinstance(aggregator, Aggregator):
+            return aggregator.value
+        return aggregator
+
+    @staticmethod
+    def deserialize(aggregator: str) -> Aggregator:
+        return Aggregator(aggregator)
+
+
+class DateOperationSerializer:
+    @staticmethod
+    def serialize(date_operation: DateOperation) -> str:
+        return date_operation.value
+
+    @staticmethod
+    def deserialize(date_operation: str) -> DateOperation:
+        return DateOperation(date_operation)
+
+
+class AggregationSerializer:
+    @staticmethod
+    def serialize(aggregation: Aggregation) -> Dict:
+        return {
+            "field": aggregation.field,
+            "operation": AggregatorSerializer.serialize(aggregation.operation),
+            "groups": [
+                {
+                    "field": group["field"],
+                    "operation": (
+                        DateOperationSerializer.serialize(group["operation"]) if "operation" in group else None
+                    ),
+                }
+                for group in aggregation.groups
+            ],
+        }
+
+    @staticmethod
+    def deserialize(aggregation: Dict) -> Aggregation:
+        groups = []
+        for group in aggregation["groups"]:
+            tmp = {
+                "field": group["field"],
+            }
+            if group["operation"] is not None:
+                tmp["operation"] = DateOperationSerializer.deserialize(group["operation"])
+
+            groups.append(tmp)
+
+        return Aggregation(
+            {
+                "field": aggregation["field"],
+                "operation": AggregatorSerializer.deserialize(aggregation["operation"]),
+                "groups": groups,
+            }
+        )
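A small round-trip sketch, assuming `Aggregation` accepts a plain component dict and that `DateOperation.DAY` exists in the toolkit (the field names are invented):

```python
from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation, DateOperation
from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer

# Hypothetical aggregation: count of records grouped by creation day.
agg = Aggregation(
    {"field": None, "operation": "Count", "groups": [{"field": "created_at", "operation": DateOperation.DAY}]}
)

wire = AggregationSerializer.serialize(agg)         # plain, JSON-safe dict
restored = AggregationSerializer.deserialize(wire)  # back to an Aggregation
```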
diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py b/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py
new file mode 100644
index 000000000..29ffc7d95
--- /dev/null
+++ b/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py
@@ -0,0 +1,146 @@
+from typing import Any, Dict, List
+
+from forestadmin.datasource_toolkit.collections import Collection
+from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType
+from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree
+from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import ConditionTreeBranch
+from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf
+from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter
+from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter
+from forestadmin.datasource_toolkit.interfaces.query.page import Page
+from forestadmin.datasource_toolkit.interfaces.query.projections import Projection
+from forestadmin.datasource_toolkit.interfaces.query.sort import Sort
+from forestadmin.datasource_toolkit.utils.collections import CollectionUtils
+from forestadmin.rpc_common.serializers.collection.record import RecordSerializer
+from forestadmin.rpc_common.serializers.utils import OperatorSerializer, TimezoneSerializer, camel_to_snake_case
+
+
+class ConditionTreeSerializer:
+    @staticmethod
+    def serialize(condition_tree: ConditionTree, collection: Collection) -> Dict:
+        if isinstance(condition_tree, ConditionTreeBranch):
+            return {
+                "aggregator": condition_tree.aggregator.value,
+                "conditions": [
+                    ConditionTreeSerializer.serialize(condition, collection) for condition in condition_tree.conditions
+                ],
+            }
+        elif isinstance(condition_tree, ConditionTreeLeaf):
+            return {
+                "operator": OperatorSerializer.serialize(condition_tree.operator),
+                "field": condition_tree.field,
+                "value": ConditionTreeSerializer._serialize_value(
+                    condition_tree.value, collection, condition_tree.field
+                ),
+            }
+        raise ValueError(f"Unknown condition tree type: {type(condition_tree)}")
+
+    @staticmethod
+    def _serialize_value(value: Any, collection: Collection, field_name: str) -> Any:
+        if value is None:
+            return None
+        field_schema = CollectionUtils.get_field_schema(collection, field_name)  # type:ignore
+        type_ = field_schema["column_type"]
+
+        if type_ == PrimitiveType.DATE:
+            ret = value.isoformat()
+        elif type_ == PrimitiveType.DATE_ONLY:
+            ret = value.isoformat()
+        elif type_ == PrimitiveType.POINT:
+            ret = (value[0], value[1])
+        elif type_ == PrimitiveType.UUID:
+            ret = str(value)
+        else:
+            ret = value
+
+        return ret
+
+    @staticmethod
+    def deserialize(condition_tree: Dict, collection: Collection) -> ConditionTree:
+        if "aggregator" in condition_tree:
+            return ConditionTreeBranch(
+                aggregator=condition_tree["aggregator"],
+                conditions=[
+                    ConditionTreeSerializer.deserialize(condition, collection)
+                    for condition in condition_tree["conditions"]
+                ],
+            )
+        elif "operator" in condition_tree:
+            return ConditionTreeLeaf(
+                operator=camel_to_snake_case(condition_tree["operator"]),  # type:ignore
+                field=condition_tree["field"],
+                value=ConditionTreeSerializer._deserialize_value(
+                    condition_tree.get("value"), collection, condition_tree["field"]
+                ),
+            )
+        raise ValueError(f"Unknown condition tree type: {condition_tree.keys()}")
+
+    @staticmethod
+    def _deserialize_value(value: Any, collection: Collection, field_name: str) -> Any:
+        if value is None:
+            return None
+        field_schema = CollectionUtils.get_field_schema(collection, field_name)  # type:ignore
+        return RecordSerializer._deserialize_primitive_type(field_schema["column_type"], value)  # type:ignore
+
+
+class FilterSerializer:
+    @staticmethod
+    def serialize(filter_: Filter, collection: Collection) -> Dict:
+        return {
+            "condition_tree": (
+                ConditionTreeSerializer.serialize(filter_.condition_tree, collection)
+                if filter_.condition_tree is not None
+                else None
+            ),
+            "search": filter_.search,
+            "search_extended": filter_.search_extended,
+            "segment": filter_.segment,
+            "timezone": TimezoneSerializer.serialize(filter_.timezone),
+        }
+
+    @staticmethod
+    def deserialize(filter_: Dict, collection: Collection) -> Filter:
+        return Filter(
+            {
+                "condition_tree": (
+                    ConditionTreeSerializer.deserialize(filter_["condition_tree"], collection)
+                    if filter_.get("condition_tree") is not None
+                    else None
+                ),  # type: ignore
+                "search": filter_["search"],
+                "search_extended": filter_["search_extended"],
+                "segment": filter_["segment"],
+                "timezone": TimezoneSerializer.deserialize(filter_["timezone"]),
+            }
+        )
+
+
+class PaginatedFilterSerializer:
+    @staticmethod
+    def serialize(filter_: PaginatedFilter, collection: Collection) -> Dict:
+        ret = FilterSerializer.serialize(filter_.to_base_filter(), collection)
+        ret["sort"] = filter_.sort
+        ret["page"] = {"skip": filter_.page.skip, "limit": filter_.page.limit} if filter_.page else None
+        return ret
+
+    @staticmethod
+    def deserialize(serialized_filter: Dict, collection: Collection) -> PaginatedFilter:
+        ret = FilterSerializer.deserialize(serialized_filter, collection)
+        ret = PaginatedFilter.from_base_filter(ret)
+        ret.sort = Sort(serialized_filter["sort"]) if serialized_filter.get("sort") is not None else None  # type: ignore
+        ret.page = (
+            Page(skip=serialized_filter["page"]["skip"], limit=serialized_filter["page"]["limit"])
+            if serialized_filter.get("page") is not None
+            else None
+        )  # type: ignore
+        return ret
+
+
+class ProjectionSerializer:
+    @staticmethod
+    def serialize(projection: Projection) -> List[str]:
+        return [*projection]
+
+    @staticmethod
+    def deserialize(projection: List[str]) -> Projection:
+        return Projection(*projection)
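A minimal sketch of the filter round-trip; a filter without a condition tree never dereferences the collection argument, so `None` stands in for it here (assuming `Filter` tolerates a partial component dict, as elsewhere in the toolkit):

```python
from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter
from forestadmin.rpc_common.serializers.collection.filter import FilterSerializer

f = Filter({"search": "john", "search_extended": False, "segment": None, "timezone": None})

# {'condition_tree': None, 'search': 'john', 'search_extended': False, 'segment': None, 'timezone': None}
wire = FilterSerializer.serialize(f, None)
restored = FilterSerializer.deserialize(wire, None)
```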
diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/collection/record.py b/src/rpc_common/forestadmin/rpc_common/serializers/collection/record.py
new file mode 100644
index 000000000..2a6559f35
--- /dev/null
+++ b/src/rpc_common/forestadmin/rpc_common/serializers/collection/record.py
@@ -0,0 +1,113 @@
+from datetime import date, datetime
+from typing import Any, Dict, cast
+from uuid import UUID
+
+from forestadmin.datasource_toolkit.collections import Collection
+from forestadmin.datasource_toolkit.interfaces.fields import (
+    Column,
+    PrimitiveType,
+    RelationAlias,
+    is_column,
+    is_many_to_many,
+    is_many_to_one,
+    is_one_to_many,
+    is_one_to_one,
+    is_polymorphic_many_to_one,
+    is_polymorphic_one_to_many,
+    is_polymorphic_one_to_one,
+)
+from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias
+
+
+class RecordSerializer:
+    @staticmethod
+    def serialize(record: RecordsDataAlias) -> Dict:
+        serialized_record = {}
+
+        for key, value in record.items():
+            if isinstance(value, dict):
+                serialized_record[key] = RecordSerializer.serialize(value)
+            elif isinstance(value, (set, list)):
+                serialized_record[key] = [RecordSerializer.serialize(v) for v in value]
+            elif isinstance(value, date):
+                # datetime is a subclass of date, so this branch also covers datetimes
+                serialized_record[key] = value.isoformat()
+            elif isinstance(value, UUID):
+                serialized_record[key] = str(value)
+            else:
+                serialized_record[key] = value
+
+        return serialized_record
+
+    @staticmethod
+    def deserialize(record: Dict, collection: Collection) -> RecordsDataAlias:
+        deserialized_record = {}
+        collection_schema = collection.schema
+
+        for key, value in record.items():
+            if is_column(collection_schema["fields"][key]):
+                deserialized_record[key] = RecordSerializer._deserialize_column(collection, key, value)
+            else:
+                deserialized_record[key] = RecordSerializer._deserialize_relation(collection, key, value, record)
+        return deserialized_record
+
+    @staticmethod
+    def _deserialize_relation(collection: Collection, field_name: str, value: Any, record: Dict) -> RecordsDataAlias:
+        schema_field = cast(RelationAlias, collection.schema["fields"][field_name])
+        if is_many_to_one(schema_field) or is_one_to_one(schema_field) or is_polymorphic_one_to_one(schema_field):
+            return RecordSerializer.deserialize(
+                value, collection.datasource.get_collection(schema_field["foreign_collection"])
+            )
+        elif is_many_to_many(schema_field) or is_one_to_many(schema_field) or is_polymorphic_one_to_many(schema_field):
+            return [
+                RecordSerializer.deserialize(
+                    v, collection.datasource.get_collection(schema_field["foreign_collection"])
+                )
+                for v in value
+            ]
+        elif is_polymorphic_many_to_one(schema_field):
+            if schema_field["foreign_key_type_field"] not in record:
+                raise ValueError("Cannot deserialize polymorphic many to one relation without foreign key type field")
+            return RecordSerializer.deserialize(
+                value, collection.datasource.get_collection(record[schema_field["foreign_key_type_field"]])
+            )
+        else:
+            raise ValueError(f"Unknown field type: {schema_field}")
+
+    @staticmethod
+    def _deserialize_column(collection: Collection, field_name: str, value: Any) -> RecordsDataAlias:
+        schema_field = cast(Column, collection.schema["fields"][field_name])
+
+        if isinstance(schema_field["column_type"], (dict, list)):
+            return RecordSerializer._deserialize_complex_type(schema_field["column_type"], value)
+        elif isinstance(schema_field["column_type"], PrimitiveType):
+            return RecordSerializer._deserialize_primitive_type(schema_field["column_type"], value)
+
+    @staticmethod
+    def _deserialize_primitive_type(type_: PrimitiveType, value: Any):
+        if type_ == PrimitiveType.DATE:
+            return datetime.fromisoformat(value)
+        elif type_ == PrimitiveType.DATE_ONLY:
+            return date.fromisoformat(value)
+        elif type_ == PrimitiveType.POINT:
+            return (value[0], value[1])
+        elif type_ == PrimitiveType.UUID:
+            return UUID(value)
+        elif type_ == PrimitiveType.BINARY:
+            pass  # TODO: binary
+        elif type_ == PrimitiveType.TIME_ONLY:
+            pass  # TODO: time only
+        elif isinstance(type_, PrimitiveType):
+            return value
+        else:
+            raise ValueError(f"Unknown primitive type: {type_}")
+
+    @staticmethod
+    def _deserialize_complex_type(type_, value):
+        if isinstance(type_, list):
+            return [RecordSerializer._deserialize_complex_type(type_[0], v) for v in value]
+        elif isinstance(type_, dict):
+            return {k: RecordSerializer._deserialize_complex_type(v, value[k]) for k, v in type_.items()}
+
+        return value
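`RecordSerializer.serialize` needs no collection, so it can be sketched standalone; the record below is invented:

```python
from datetime import datetime, timezone
from uuid import UUID

from forestadmin.rpc_common.serializers.collection.record import RecordSerializer

record = {
    "id": UUID("12345678-1234-5678-1234-567812345678"),
    "created_at": datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc),
    "author": {"name": "Ada"},  # nested relation records are serialized recursively
}

wire = RecordSerializer.serialize(record)
# {'id': '12345678-...', 'created_at': '2024-01-01T12:00:00+00:00', 'author': {'name': 'Ada'}}
```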
diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/schema/__init__.py b/src/rpc_common/forestadmin/rpc_common/serializers/schema/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py b/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py
new file mode 100644
index 000000000..e5806cf05
--- /dev/null
+++ b/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py
@@ -0,0 +1,258 @@
+from enum import Enum
+from typing import Any, Union, cast
+
+from forestadmin.datasource_toolkit.collections import Collection
+from forestadmin.datasource_toolkit.datasources import Datasource
+from forestadmin.datasource_toolkit.interfaces.fields import (
+    Column,
+    FieldAlias,
+    FieldType,
+    PrimitiveType,
+    RelationAlias,
+    is_column,
+    is_many_to_many,
+    is_many_to_one,
+    is_one_to_many,
+    is_one_to_one,
+    is_polymorphic_many_to_one,
+    is_polymorphic_one_to_many,
+    is_polymorphic_one_to_one,
+)
+from forestadmin.rpc_common.serializers.actions import ActionSerializer
+from forestadmin.rpc_common.serializers.utils import OperatorSerializer, enum_to_str_or_value
+
+
+def serialize_column_type(column_type: Union[Enum, Any]) -> Union[str, dict, list]:
+    if isinstance(column_type, list):
+        return [serialize_column_type(t) for t in column_type]
+
+    if isinstance(column_type, dict):
+        return {k: serialize_column_type(v) for k, v in column_type.items()}
+
+    return enum_to_str_or_value(column_type)
+
+
+def deserialize_column_type(column_type: Union[list, dict, str]) -> Union[PrimitiveType, dict, list]:
+    if isinstance(column_type, list):
+        return [deserialize_column_type(t) for t in column_type]
+
+    if isinstance(column_type, dict):
+        return {k: deserialize_column_type(v) for k, v in column_type.items()}
+
+    return PrimitiveType(column_type)
+
+
+class SchemaSerializer:
+    VERSION = "1.0"
+
+    def __init__(self, datasource: Datasource) -> None:
+        self.datasource = datasource
+
+    async def serialize(self) -> dict:
+        return {
+            "charts": sorted(self.datasource.schema["charts"].keys()),
+            "live_query_connections": sorted(self.datasource.get_native_query_connections()),
+            "collections": [
+                await self._serialize_collection(c)
+                for c in sorted(self.datasource.collections, key=lambda c: c.name)
+            ],
+        }
+
+    async def _serialize_collection(self, collection: Collection) -> dict:
+        return {
+            "name": collection.name,
+            "fields": {
+                field: await self._serialize_field(collection.schema["fields"][field])
+                for field in sorted(collection.schema["fields"].keys())
+            },
+            "actions": {
+                action_name: await ActionSerializer.serialize(action_name, collection)
+                for action_name in sorted(collection.schema["actions"].keys())
+            },
+            "segments": sorted(collection.schema["segments"]),
+            "countable": collection.schema["countable"],
+            "searchable": collection.schema["searchable"],
+            "charts": sorted(collection.schema["charts"].keys()),
+        }
+
+    async def _serialize_field(self, field: FieldAlias) -> dict:
+        if is_column(field):
+            return await self._serialize_column(field)
+        else:
+            return await self._serialize_relation(cast(RelationAlias, field))
+
+    async def _serialize_column(self, column: Column) -> dict:
+        return {
+            "type": "Column",
+            "column_type": (serialize_column_type(column["column_type"])),
+            "filter_operators": sorted([OperatorSerializer.serialize(op) for op in column.get("filter_operators", [])]),
+            "default_value": column.get("default_value"),
+            "enum_values": column.get("enum_values"),
+            "is_primary_key": column.get("is_primary_key", False),
+            "is_read_only": column.get("is_read_only", False),
+            "is_sortable": column.get("is_sortable", False),
+            "validations": [
+                {
+                    "operator": OperatorSerializer.serialize(v["operator"]),
+                    "value": v.get("value"),
+                }
+                for v in sorted(
+                    column.get("validations", []),
+                    key=lambda v: f'{OperatorSerializer.serialize(v["operator"])}_{str(v.get("value"))}',
+                )
+            ],
+        }
+
+    async def _serialize_relation(self, relation: RelationAlias) -> dict:
+        serialized_field = {}
+        if is_polymorphic_many_to_one(relation):
+            serialized_field.update(
+                {
+                    "type": "PolymorphicManyToOne",
+                    "foreign_collections": relation["foreign_collections"],
+                    "foreign_key": relation["foreign_key"],
+                    "foreign_key_type_field": relation["foreign_key_type_field"],
+                    "foreign_key_targets": relation["foreign_key_targets"],
+                }
+            )
+        elif is_many_to_one(relation):
+            serialized_field.update(
+                {
+                    "type": "ManyToOne",
+                    "foreign_collection": relation["foreign_collection"],
+                    "foreign_key": relation["foreign_key"],
+                    "foreign_key_target": relation["foreign_key_target"],
+                }
+            )
+        elif is_one_to_one(relation) or is_one_to_many(relation):
+            serialized_field.update(
+                {
+                    "type": "OneToMany" if is_one_to_many(relation) else "OneToOne",
+                    "foreign_collection": relation["foreign_collection"],
+                    "origin_key": relation["origin_key"],
+                    "origin_key_target": relation["origin_key_target"],
+                }
+            )
+        elif is_polymorphic_one_to_one(relation) or is_polymorphic_one_to_many(relation):
+            serialized_field.update(
+                {
+                    "type": "PolymorphicOneToMany" if is_polymorphic_one_to_many(relation) else "PolymorphicOneToOne",
+                    "foreign_collection": relation["foreign_collection"],
+                    "origin_key": relation["origin_key"],
+                    "origin_key_target": relation["origin_key_target"],
+                    "origin_type_field": relation["origin_type_field"],
+                    "origin_type_value": relation["origin_type_value"],
+                }
+            )
+        elif is_many_to_many(relation):
+            serialized_field.update(
+                {
+                    "type": "ManyToMany",
+                    "foreign_collections": relation["foreign_collection"],
+                    "foreign_key": relation["foreign_key"],
+                    "foreign_key_targets": relation["foreign_key_target"],
+                    "origin_key": relation["origin_key"],
+                    "origin_key_target": relation["origin_key_target"],
+                    "through_collection": relation["through_collection"],
+                }
+            )
+        return serialized_field
+
+
+class SchemaDeserializer:
+    VERSION = "1.0"
+
+    def deserialize(self, schema) -> dict:
+        return {
+            "charts": schema["charts"],
+            # "live_query_connections": schema["data"]["live_query_connections"],
+            "collections": {
+                collection["name"]: self._deserialize_collection(collection) for collection in schema["collections"]
+            },
+        }
+
+    def _deserialize_collection(self, collection: dict) -> dict:
+        return {
+            "fields": {field: self._deserialize_field(collection["fields"][field]) for field in collection["fields"]},
+            "actions": {
+                action_name: ActionSerializer.deserialize(action)
+                for action_name, action in collection["actions"].items()
+            },
+            "segments": collection["segments"],
+            "countable": collection["countable"],
+            "searchable": collection["searchable"],
+            "charts": collection["charts"],
+        }
+
+    def _deserialize_field(self, field: dict) -> dict:
+        if field["type"] == "Column":
+            return self._deserialize_column(field)
+        else:
+            return self._deserialize_relation(field)
+
+    def _deserialize_column(self, column: dict) -> dict:
+        return {
+            "type": FieldType.COLUMN,
+            "column_type": deserialize_column_type(column["column_type"]),
+            "filter_operators": [OperatorSerializer.deserialize(op) for op in column["filter_operators"]],
+            "default_value": column.get("default_value"),
+            "enum_values": column.get("enum_values"),
+            "is_primary_key": column.get("is_primary_key", False),
+            "is_read_only": column.get("is_read_only", False),
+            "is_sortable": column.get("is_sortable", False),
+            "validations": [
+                {
+                    "operator": OperatorSerializer.deserialize(op["operator"]),
+                    "value": op.get("value"),
+                }
+                for op in column.get("validations", [])
+            ],
+        }
+
+    def _deserialize_relation(self, relation: dict) -> dict:
+        if relation["type"] == "PolymorphicManyToOne":
+            return {
+                "type": FieldType("PolymorphicManyToOne"),
+                "foreign_collections": relation["foreign_collections"],
+                "foreign_key": relation["foreign_key"],
+                "foreign_key_type_field": relation["foreign_key_type_field"],
+                "foreign_key_targets": relation["foreign_key_targets"],
+            }
+        elif relation["type"] == "ManyToOne":
+            return {
+                "type": FieldType("ManyToOne"),
+                "foreign_collection": relation["foreign_collection"],
+                "foreign_key": relation["foreign_key"],
+                "foreign_key_target": relation["foreign_key_target"],
+            }
+        elif relation["type"] in ["OneToMany", "OneToOne"]:
+            return {
+                "type": FieldType("OneToOne") if relation["type"] == "OneToOne" else FieldType("OneToMany"),
+                "foreign_collection": relation["foreign_collection"],
+                "origin_key": relation["origin_key"],
+                "origin_key_target": relation["origin_key_target"],
+            }
+        elif relation["type"] in ["PolymorphicOneToMany", "PolymorphicOneToOne"]:
+            return {
+                "type": (
+                    FieldType("PolymorphicOneToOne")
+                    if relation["type"] == "PolymorphicOneToOne"
+                    else FieldType("PolymorphicOneToMany")
+                ),
+                "foreign_collection": relation["foreign_collection"],
+                "origin_key": relation["origin_key"],
+                "origin_key_target": relation["origin_key_target"],
+                "origin_type_field": relation["origin_type_field"],
+                "origin_type_value": relation["origin_type_value"],
+            }
+        elif relation["type"] == "ManyToMany":
+            return {
+                "type": FieldType("ManyToMany"),
+                "foreign_collection": relation["foreign_collections"],
+                "foreign_key": relation["foreign_key"],
+                "foreign_key_target": relation["foreign_key_targets"],
+                "origin_key": relation["origin_key"],
+                "origin_key_target": relation["origin_key_target"],
+                "through_collection": relation["through_collection"],
+            }
+        raise ValueError(f"Unsupported relation type {relation['type']}")
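A quick sketch of the column-type helpers above; nested structures are walked recursively (the field names are invented):

```python
from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType
from forestadmin.rpc_common.serializers.schema.schema import deserialize_column_type, serialize_column_type

# Nested column types (arrays / objects) round-trip through plain strings.
column_type = {"tags": [PrimitiveType.STRING], "rating": PrimitiveType.NUMBER}

wire = serialize_column_type(column_type)  # {'tags': ['String'], 'rating': 'Number'}
assert deserialize_column_type(wire) == column_type
```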
diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/utils.py b/src/rpc_common/forestadmin/rpc_common/serializers/utils.py
new file mode 100644
index 000000000..ee3a1a630
--- /dev/null
+++ b/src/rpc_common/forestadmin/rpc_common/serializers/utils.py
@@ -0,0 +1,134 @@
+import sys
+from enum import Enum
+from typing import Any, Dict, Optional, Union
+
+if sys.version_info >= (3, 9):
+    import zoneinfo
+else:
+    from backports import zoneinfo
+
+from forestadmin.agent_toolkit.utils.context import User
+from forestadmin.datasource_toolkit.interfaces.fields import Operator
+
+
+def snake_to_camel_case(snake_str: str) -> str:
+    ret = "".join(word.title() for word in snake_str.split("_"))
+    ret = f"{ret[0].lower()}{ret[1:]}"
+    return ret
+
+
+def camel_to_snake_case(camel_str: str) -> str:
+    ret = "".join(f"_{c.lower()}" if c.isupper() else c for c in camel_str)
+    return ret
+
+
+def enum_to_str_or_value(value: Union[Enum, Any]):
+    return value.value if isinstance(value, Enum) else value
+
+
+class OperatorSerializer:
+    # OPERATOR_MAPPING: Dict[str, str] = {
+    #     "present": "present",
+    #     "blank": "blank",
+    #     "missing": "missing",
+    #     "equal": "equal",
+    #     "not_equal": "notEqual",
+    #     "less_than": "lessThan",
+    #     "greater_than": "greaterThan",
+    #     "in": "in",
+    #     "not_in": "notIn",
+    #     "like": "like",
+    #     "starts_with": "startsWith",
+    #     "ends_with": "endsWith",
+    #     "contains": "contains",
+    #     "match": "match",
+    #     "not_contains": "notContains",
+    #     "longer_than": "longerThan",
+    #     "shorter_than": "shorterThan",
+    #     "before": "before",
+    #     "after": "after",
+    #     "after_x_hours_ago": "afterXHoursAgo",
+    #     "before_x_hours_ago": "beforeXHoursAgo",
+    #     "future": "future",
+    #     "past": "past",
+    #     "previous_month_to_date": "previousMonthToDate",
+    #     "previous_month": "previousMonth",
+    #     "previous_quarter_to_date": "previousQuarterToDate",
+    #     "previous_quarter": "previousQuarter",
+    #     "previous_week_to_date": "previousWeekToDate",
+    #     "previous_week": "previousWeek",
+    #     "previous_x_days_to_date": "previousXDaysToDate",
+    #     "previous_x_days": "previousXDays",
+    #     "previous_year_to_date": "previousYearToDate",
+    #     "previous_year": "previousYear",
+    #     "today": "today",
+    #     "yesterday": "yesterday",
+    #     "includes_all": "includesAll",
+    # }
+
+    @staticmethod
+    def deserialize(operator: str) -> Operator:
+        _operator = operator
+        if operator.startswith("i_"):
+            _operator = operator[2:]
+        return Operator(_operator)
+        # return Operator(camel_to_snake_case(operator))
+        # for value, serialized in OperatorSerializer.OPERATOR_MAPPING.items():
+        #     if serialized == operator:
+        #         return Operator(value)
+        # raise ValueError(f"Unknown operator: {operator}")
+
+    @staticmethod
+    def serialize(operator: Union[Operator, str]) -> str:
+        value = operator
+        if isinstance(value, Enum):
+            value = value.value
+        return value
+        # return snake_to_camel_case(value)
+        # return OperatorSerializer.OPERATOR_MAPPING[value]
+
+
+class TimezoneSerializer:
+    @staticmethod
+    def serialize(timezone: Optional[zoneinfo.ZoneInfo]) -> Optional[str]:
+        if timezone is None:
+            return None
+        return str(timezone)
+
+    @staticmethod
+    def deserialize(timezone: Optional[str]) -> Optional[zoneinfo.ZoneInfo]:
+        if timezone is None:
+            return None
+        return zoneinfo.ZoneInfo(timezone)
+
+
+class CallerSerializer:
+    @staticmethod
+    def serialize(caller: User) -> Dict:
+        return {
+            "rendering_id": caller.rendering_id,
+            "user_id": caller.user_id,
+            "tags": caller.tags,
+            "email": caller.email,
+            "first_name": caller.first_name,
+            "last_name": caller.last_name,
+            "team": caller.team,
+            "timezone": TimezoneSerializer.serialize(caller.timezone),
+            "request": {"ip": caller.request["ip"]},
+        }
+
+    @staticmethod
+    def deserialize(caller: Dict) -> User:
+        if caller is None:
+            return None
+        return User(
+            rendering_id=caller["rendering_id"],
+            user_id=caller["user_id"],
+            tags=caller["tags"],
+            email=caller["email"],
+            first_name=caller["first_name"],
+            last_name=caller["last_name"],
+            team=caller["team"],
+            timezone=TimezoneSerializer.deserialize(caller["timezone"]),  # type:ignore
+            request={"ip": caller["request"]["ip"]},
+        )
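A few self-checks for the helpers above (the values are illustrative; `Europe/Paris` is just a sample IANA name):

```python
from forestadmin.rpc_common.serializers.utils import (
    OperatorSerializer,
    TimezoneSerializer,
    camel_to_snake_case,
    snake_to_camel_case,
)

assert snake_to_camel_case("starts_with") == "startsWith"
assert camel_to_snake_case("startsWith") == "starts_with"

# Operators currently travel as their raw enum value (the camelCase mapping is commented out)...
assert OperatorSerializer.serialize("not_equal") == "not_equal"

# ...and timezones travel as IANA names.
tz = TimezoneSerializer.deserialize("Europe/Paris")
assert TimezoneSerializer.serialize(tz) == "Europe/Paris"
```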
diff --git a/src/rpc_common/pyproject.toml b/src/rpc_common/pyproject.toml
new file mode 100644
index 000000000..2e292c684
--- /dev/null
+++ b/src/rpc_common/pyproject.toml
@@ -0,0 +1,18 @@
+[build-system]
+requires = [ "poetry-core",]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "forestadmin-rpc-common"
+description = "common files for rpc server and client"
+version = "1.23.0"
+authors = [ "Julien Barreau ",]
+readme = "README.md"
+repository = "https://github.com/ForestAdmin/agent-python"
+documentation = "https://docs.forestadmin.com/developer-guide-agents-python/"
+homepage = "https://www.forestadmin.com"
+[[tool.poetry.packages]]
+include = "forestadmin"
+
+[tool.poetry.dependencies]
+python = ">=3.8,<3.14"