From 26a717c800c343dcd6d1e81ccb986f1b6d37baec Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 12 Mar 2025 10:08:16 +0100 Subject: [PATCH 01/34] chore: add proto files --- src/rpc_common/README.md | 1 + .../rpc_common/proto/collection.proto | 26 +++ .../rpc_common/proto/collection_pb2.py | 43 ++++ .../rpc_common/proto/collection_pb2.pyi | 27 +++ .../rpc_common/proto/collection_pb2_grpc.py | 143 ++++++++++++++ .../rpc_common/proto/datasource.proto | 30 +++ .../rpc_common/proto/datasource_pb2.py | 49 +++++ .../rpc_common/proto/datasource_pb2.pyi | 42 ++++ .../rpc_common/proto/datasource_pb2_grpc.py | 186 ++++++++++++++++++ .../forestadmin/rpc_common/proto/forest.proto | 49 +++++ .../rpc_common/proto/forest_pb2.py | 56 ++++++ .../rpc_common/proto/forest_pb2.pyi | 82 ++++++++ .../rpc_common/proto/forest_pb2_grpc.py | 24 +++ src/rpc_common/pyproject.toml | 24 +++ 14 files changed, 782 insertions(+) create mode 100644 src/rpc_common/README.md create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/collection.proto create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.py create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.pyi create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/collection_pb2_grpc.py create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/datasource.proto create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.py create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.pyi create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2_grpc.py create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/forest.proto create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.py create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.pyi create mode 100644 src/rpc_common/forestadmin/rpc_common/proto/forest_pb2_grpc.py create mode 100644 src/rpc_common/pyproject.toml diff --git 
a/src/rpc_common/README.md b/src/rpc_common/README.md new file mode 100644 index 00000000..0b4778bf --- /dev/null +++ b/src/rpc_common/README.md @@ -0,0 +1 @@ +python -m grpc_tools.protoc -I./ --python_out=. --pyi_out=. --grpc_python_out=. ./proto/*.proto \ No newline at end of file diff --git a/src/rpc_common/forestadmin/rpc_common/proto/collection.proto b/src/rpc_common/forestadmin/rpc_common/proto/collection.proto new file mode 100644 index 00000000..5dc03bd6 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/collection.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package datasource_rpc; +import "google/protobuf/any.proto"; +import "google/protobuf/struct.proto"; +import "forestadmin/rpc_common/proto/forest.proto"; + + + +service Collection { + // Define your RPC methods here + rpc RenderChart (CollectionChartRequest) returns (google.protobuf.Any); + // crud + rpc Create (CreateRequest) returns (google.protobuf.Any); +} + +message CollectionChartRequest{ + Caller caller = 1; + string ChartName = 2; + google.protobuf.Any RecordId = 3; +} + +message CreateRequest{ + Caller Caller = 1; + repeated google.protobuf.Struct RecordData = 2; +} \ No newline at end of file diff --git a/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.py b/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.py new file mode 100644 index 00000000..ddd04935 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: forestadmin/rpc_common/proto/collection.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'forestadmin/rpc_common/proto/collection.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 +from forestadmin.rpc_common.proto import forest_pb2 as forestadmin_dot_rpc__common_dot_proto_dot_forest__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n-forestadmin/rpc_common/proto/collection.proto\x12\x0e\x64\x61tasource_rpc\x1a\x19google/protobuf/any.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a)forestadmin/rpc_common/proto/forest.proto\"{\n\x16\x43ollectionChartRequest\x12&\n\x06\x63\x61ller\x18\x01 \x01(\x0b\x32\x16.datasource_rpc.Caller\x12\x11\n\tChartName\x18\x02 \x01(\t\x12&\n\x08RecordId\x18\x03 \x01(\x0b\x32\x14.google.protobuf.Any\"d\n\rCreateRequest\x12&\n\x06\x43\x61ller\x18\x01 \x01(\x0b\x32\x16.datasource_rpc.Caller\x12+\n\nRecordData\x18\x02 \x03(\x0b\x32\x17.google.protobuf.Struct2\x98\x01\n\nCollection\x12K\n\x0bRenderChart\x12&.datasource_rpc.CollectionChartRequest\x1a\x14.google.protobuf.Any\x12=\n\x06\x43reate\x12\x1d.datasource_rpc.CreateRequest\x1a\x14.google.protobuf.Anyb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'forestadmin.rpc_common.proto.collection_pb2', 
_globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_COLLECTIONCHARTREQUEST']._serialized_start=165 + _globals['_COLLECTIONCHARTREQUEST']._serialized_end=288 + _globals['_CREATEREQUEST']._serialized_start=290 + _globals['_CREATEREQUEST']._serialized_end=390 + _globals['_COLLECTION']._serialized_start=393 + _globals['_COLLECTION']._serialized_end=545 +# @@protoc_insertion_point(module_scope) diff --git a/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.pyi b/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.pyi new file mode 100644 index 00000000..93b6d5b3 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.pyi @@ -0,0 +1,27 @@ +from google.protobuf import any_pb2 as _any_pb2 +from google.protobuf import struct_pb2 as _struct_pb2 +from forestadmin.rpc_common.proto import forest_pb2 as _forest_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class CollectionChartRequest(_message.Message): + __slots__ = ("caller", "ChartName", "RecordId") + CALLER_FIELD_NUMBER: _ClassVar[int] + CHARTNAME_FIELD_NUMBER: _ClassVar[int] + RECORDID_FIELD_NUMBER: _ClassVar[int] + caller: _forest_pb2.Caller + ChartName: str + RecordId: _any_pb2.Any + def __init__(self, caller: _Optional[_Union[_forest_pb2.Caller, _Mapping]] = ..., ChartName: _Optional[str] = ..., RecordId: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...) -> None: ... 
+ +class CreateRequest(_message.Message): + __slots__ = ("Caller", "RecordData") + CALLER_FIELD_NUMBER: _ClassVar[int] + RECORDDATA_FIELD_NUMBER: _ClassVar[int] + Caller: _forest_pb2.Caller + RecordData: _containers.RepeatedCompositeFieldContainer[_struct_pb2.Struct] + def __init__(self, Caller: _Optional[_Union[_forest_pb2.Caller, _Mapping]] = ..., RecordData: _Optional[_Iterable[_Union[_struct_pb2.Struct, _Mapping]]] = ...) -> None: ... diff --git a/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2_grpc.py b/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2_grpc.py new file mode 100644 index 00000000..a2c8f626 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2_grpc.py @@ -0,0 +1,143 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from forestadmin.rpc_common.proto import collection_pb2 as forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2 +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 + +GRPC_GENERATED_VERSION = '1.70.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in forestadmin/rpc_common/proto/collection_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class CollectionStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. 
+ + Args: + channel: A grpc.Channel. + """ + self.RenderChart = channel.unary_unary( + '/datasource_rpc.Collection/RenderChart', + request_serializer=forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CollectionChartRequest.SerializeToString, + response_deserializer=google_dot_protobuf_dot_any__pb2.Any.FromString, + _registered_method=True) + self.Create = channel.unary_unary( + '/datasource_rpc.Collection/Create', + request_serializer=forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CreateRequest.SerializeToString, + response_deserializer=google_dot_protobuf_dot_any__pb2.Any.FromString, + _registered_method=True) + + +class CollectionServicer(object): + """Missing associated documentation comment in .proto file.""" + + def RenderChart(self, request, context): + """Define your RPC methods here + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Create(self, request, context): + """crud + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_CollectionServicer_to_server(servicer, server): + rpc_method_handlers = { + 'RenderChart': grpc.unary_unary_rpc_method_handler( + servicer.RenderChart, + request_deserializer=forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CollectionChartRequest.FromString, + response_serializer=google_dot_protobuf_dot_any__pb2.Any.SerializeToString, + ), + 'Create': grpc.unary_unary_rpc_method_handler( + servicer.Create, + request_deserializer=forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CreateRequest.FromString, + response_serializer=google_dot_protobuf_dot_any__pb2.Any.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'datasource_rpc.Collection', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + 
server.add_registered_method_handlers('datasource_rpc.Collection', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class Collection(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def RenderChart(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/datasource_rpc.Collection/RenderChart', + forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CollectionChartRequest.SerializeToString, + google_dot_protobuf_dot_any__pb2.Any.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Create(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/datasource_rpc.Collection/Create', + forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CreateRequest.SerializeToString, + google_dot_protobuf_dot_any__pb2.Any.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/rpc_common/forestadmin/rpc_common/proto/datasource.proto b/src/rpc_common/forestadmin/rpc_common/proto/datasource.proto new file mode 100644 index 00000000..a10a0fdb --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/datasource.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package datasource_rpc; +import "google/protobuf/empty.proto"; +import "google/protobuf/any.proto"; +import "forestadmin/rpc_common/proto/forest.proto"; + +service DataSource { + // Define your RPC methods here + rpc Schema 
(google.protobuf.Empty) returns (SchemaResponse); + rpc RenderChart (DatasourceChartRequest) returns (google.protobuf.Any); + rpc ExecuteNativeQuery (NativeQueryRequest) returns (google.protobuf.Any); +} + +message SchemaResponse{ + DataSourceSchema Schema = 1; + repeated CollectionSchema Collections = 2; +} + +message DatasourceChartRequest{ + Caller caller = 1; + string ChartName = 2; +} + + +message NativeQueryRequest{ + Caller Caller = 1; + string Query = 2; + map<string, google.protobuf.Any> Parameters = 3; +} \ No newline at end of file diff --git a/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.py b/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.py new file mode 100644 index 00000000..96cf9d3f --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: forestadmin/rpc_common/proto/datasource.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'forestadmin/rpc_common/proto/datasource.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 +from forestadmin.rpc_common.proto import forest_pb2 as forestadmin_dot_rpc__common_dot_proto_dot_forest__pb2 + + +DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n-forestadmin/rpc_common/proto/datasource.proto\x12\x0e\x64\x61tasource_rpc\x1a\x1bgoogle/protobuf/empty.proto\x1a\x19google/protobuf/any.proto\x1a)forestadmin/rpc_common/proto/forest.proto\"y\n\x0eSchemaResponse\x12\x30\n\x06Schema\x18\x01 \x01(\x0b\x32 .datasource_rpc.DataSourceSchema\x12\x35\n\x0b\x43ollections\x18\x02 \x03(\x0b\x32 .datasource_rpc.CollectionSchema\"S\n\x16\x44\x61tasourceChartRequest\x12&\n\x06\x63\x61ller\x18\x01 \x01(\x0b\x32\x16.datasource_rpc.Caller\x12\x11\n\tChartName\x18\x02 \x01(\t\"\xdc\x01\n\x12NativeQueryRequest\x12&\n\x06\x43\x61ller\x18\x01 \x01(\x0b\x32\x16.datasource_rpc.Caller\x12\r\n\x05Query\x18\x02 \x01(\t\x12\x46\n\nParameters\x18\x03 \x03(\x0b\x32\x32.datasource_rpc.NativeQueryRequest.ParametersEntry\x1aG\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any:\x02\x38\x01\x32\xeb\x01\n\nDataSource\x12@\n\x06Schema\x12\x16.google.protobuf.Empty\x1a\x1e.datasource_rpc.SchemaResponse\x12K\n\x0bRenderChart\x12&.datasource_rpc.DatasourceChartRequest\x1a\x14.google.protobuf.Any\x12N\n\x12\x45xecuteNativeQuery\x12\".datasource_rpc.NativeQueryRequest\x1a\x14.google.protobuf.Anyb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'forestadmin.rpc_common.proto.datasource_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_NATIVEQUERYREQUEST_PARAMETERSENTRY']._loaded_options = None + _globals['_NATIVEQUERYREQUEST_PARAMETERSENTRY']._serialized_options = b'8\001' + _globals['_SCHEMARESPONSE']._serialized_start=164 + _globals['_SCHEMARESPONSE']._serialized_end=285 + _globals['_DATASOURCECHARTREQUEST']._serialized_start=287 + _globals['_DATASOURCECHARTREQUEST']._serialized_end=370 + _globals['_NATIVEQUERYREQUEST']._serialized_start=373 + 
_globals['_NATIVEQUERYREQUEST']._serialized_end=593 + _globals['_NATIVEQUERYREQUEST_PARAMETERSENTRY']._serialized_start=522 + _globals['_NATIVEQUERYREQUEST_PARAMETERSENTRY']._serialized_end=593 + _globals['_DATASOURCE']._serialized_start=596 + _globals['_DATASOURCE']._serialized_end=831 +# @@protoc_insertion_point(module_scope) diff --git a/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.pyi b/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.pyi new file mode 100644 index 00000000..ee33ec9e --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.pyi @@ -0,0 +1,42 @@ +from google.protobuf import empty_pb2 as _empty_pb2 +from google.protobuf import any_pb2 as _any_pb2 +from forestadmin.rpc_common.proto import forest_pb2 as _forest_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class SchemaResponse(_message.Message): + __slots__ = ("Schema", "Collections") + SCHEMA_FIELD_NUMBER: _ClassVar[int] + COLLECTIONS_FIELD_NUMBER: _ClassVar[int] + Schema: _forest_pb2.DataSourceSchema + Collections: _containers.RepeatedCompositeFieldContainer[_forest_pb2.CollectionSchema] + def __init__(self, Schema: _Optional[_Union[_forest_pb2.DataSourceSchema, _Mapping]] = ..., Collections: _Optional[_Iterable[_Union[_forest_pb2.CollectionSchema, _Mapping]]] = ...) -> None: ... + +class DatasourceChartRequest(_message.Message): + __slots__ = ("caller", "ChartName") + CALLER_FIELD_NUMBER: _ClassVar[int] + CHARTNAME_FIELD_NUMBER: _ClassVar[int] + caller: _forest_pb2.Caller + ChartName: str + def __init__(self, caller: _Optional[_Union[_forest_pb2.Caller, _Mapping]] = ..., ChartName: _Optional[str] = ...) -> None: ... 
+ +class NativeQueryRequest(_message.Message): + __slots__ = ("Caller", "Query", "Parameters") + class ParametersEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: _any_pb2.Any + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...) -> None: ... + CALLER_FIELD_NUMBER: _ClassVar[int] + QUERY_FIELD_NUMBER: _ClassVar[int] + PARAMETERS_FIELD_NUMBER: _ClassVar[int] + Caller: _forest_pb2.Caller + Query: str + Parameters: _containers.MessageMap[str, _any_pb2.Any] + def __init__(self, Caller: _Optional[_Union[_forest_pb2.Caller, _Mapping]] = ..., Query: _Optional[str] = ..., Parameters: _Optional[_Mapping[str, _any_pb2.Any]] = ...) -> None: ... diff --git a/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2_grpc.py b/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2_grpc.py new file mode 100644 index 00000000..f2ce6b02 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2_grpc.py @@ -0,0 +1,186 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from forestadmin.rpc_common.proto import datasource_pb2 as forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2 +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 + +GRPC_GENERATED_VERSION = '1.70.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in forestadmin/rpc_common/proto/datasource_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class DataSourceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Schema = channel.unary_unary( + '/datasource_rpc.DataSource/Schema', + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.SchemaResponse.FromString, + _registered_method=True) + self.RenderChart = channel.unary_unary( + '/datasource_rpc.DataSource/RenderChart', + request_serializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.DatasourceChartRequest.SerializeToString, + response_deserializer=google_dot_protobuf_dot_any__pb2.Any.FromString, + _registered_method=True) + self.ExecuteNativeQuery = channel.unary_unary( + '/datasource_rpc.DataSource/ExecuteNativeQuery', + request_serializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.NativeQueryRequest.SerializeToString, + response_deserializer=google_dot_protobuf_dot_any__pb2.Any.FromString, + _registered_method=True) + + +class DataSourceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Schema(self, request, context): + """Define your RPC methods here + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def RenderChart(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ExecuteNativeQuery(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_DataSourceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Schema': grpc.unary_unary_rpc_method_handler( + servicer.Schema, + 
request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.SchemaResponse.SerializeToString, + ), + 'RenderChart': grpc.unary_unary_rpc_method_handler( + servicer.RenderChart, + request_deserializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.DatasourceChartRequest.FromString, + response_serializer=google_dot_protobuf_dot_any__pb2.Any.SerializeToString, + ), + 'ExecuteNativeQuery': grpc.unary_unary_rpc_method_handler( + servicer.ExecuteNativeQuery, + request_deserializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.NativeQueryRequest.FromString, + response_serializer=google_dot_protobuf_dot_any__pb2.Any.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'datasource_rpc.DataSource', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('datasource_rpc.DataSource', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. 
+class DataSource(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Schema(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/datasource_rpc.DataSource/Schema', + google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.SchemaResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def RenderChart(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/datasource_rpc.DataSource/RenderChart', + forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.DatasourceChartRequest.SerializeToString, + google_dot_protobuf_dot_any__pb2.Any.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def ExecuteNativeQuery(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/datasource_rpc.DataSource/ExecuteNativeQuery', + forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.NativeQueryRequest.SerializeToString, + google_dot_protobuf_dot_any__pb2.Any.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git 
a/src/rpc_common/forestadmin/rpc_common/proto/forest.proto b/src/rpc_common/forestadmin/rpc_common/proto/forest.proto new file mode 100644 index 00000000..497415f2 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/forest.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +package datasource_rpc; +import "google/protobuf/any.proto"; +import "google/protobuf/struct.proto"; + +message Caller { + int32 RendererId = 1; + int32 UserId = 2; + map<string, string> Tags = 3; + string Email = 4; + string FirstName = 5; + string LastName = 6; + string Team = 7; + string Timezone = 8; + message Request{ + string Ip = 1; + } +} + + + +enum ActionScope{ + GLOBAL = 0; + SINGLE = 1; + BULK = 2; +} +// message ActionFormSchema{} +message ActionSchema{ + ActionScope scope = 1; + bool generateFile = 2; + optional string description = 3; + optional string submitButtonLabel = 4; +} + +message CollectionSchema{ + string name = 1; + map<string, google.protobuf.Struct> actions = 2; + map<string, google.protobuf.Struct> fields = 3; + repeated string charts = 4; + bool searchable = 5; + bool countable = 6; + repeated string segments = 7; +} +message DataSourceSchema { + repeated string Charts = 1; + repeated string ConnectionNames = 2; + repeated CollectionSchema Collections = 3; +} \ No newline at end of file diff --git a/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.py b/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.py new file mode 100644 index 00000000..6e6eee75 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: forestadmin/rpc_common/proto/forest.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'forestadmin/rpc_common/proto/forest.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)forestadmin/rpc_common/proto/forest.proto\x12\x0e\x64\x61tasource_rpc\x1a\x19google/protobuf/any.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xf4\x01\n\x06\x43\x61ller\x12\x12\n\nRendererId\x18\x01 \x01(\x05\x12\x0e\n\x06UserId\x18\x02 \x01(\x05\x12.\n\x04Tags\x18\x03 \x03(\x0b\x32 .datasource_rpc.Caller.TagsEntry\x12\r\n\x05\x45mail\x18\x04 \x01(\t\x12\x11\n\tFirstName\x18\x05 \x01(\t\x12\x10\n\x08LastName\x18\x06 \x01(\t\x12\x0c\n\x04Team\x18\x07 \x01(\t\x12\x10\n\x08Timezone\x18\x08 \x01(\t\x1a+\n\tTagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x15\n\x07Request\x12\n\n\x02Ip\x18\x01 \x01(\t\"r\n\x10\x44\x61taSourceSchema\x12\x0e\n\x06\x43harts\x18\x01 \x03(\t\x12\x17\n\x0f\x43onnectionNames\x18\x02 \x03(\t\x12\x35\n\x0b\x43ollections\x18\x03 \x03(\x0b\x32 .datasource_rpc.CollectionSchema\"\xf8\x02\n\x10\x43ollectionSchema\x12\x0c\n\x04name\x18\x01 \x01(\t\x12>\n\x07\x61\x63tions\x18\x02 \x03(\x0b\x32-.datasource_rpc.CollectionSchema.ActionsEntry\x12<\n\x06\x66ields\x18\x03 
\x03(\x0b\x32,.datasource_rpc.CollectionSchema.FieldsEntry\x12\x0e\n\x06\x63harts\x18\x04 \x03(\t\x12\x12\n\nsearchable\x18\x05 \x01(\x08\x12\x11\n\tcountable\x18\x06 \x01(\x08\x12\x10\n\x08segments\x18\x07 \x03(\t\x1aG\n\x0c\x41\x63tionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct:\x02\x38\x01\x1a\x46\n\x0b\x46ieldsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct:\x02\x38\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'forestadmin.rpc_common.proto.forest_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_CALLER_TAGSENTRY']._loaded_options = None + _globals['_CALLER_TAGSENTRY']._serialized_options = b'8\001' + _globals['_COLLECTIONSCHEMA_ACTIONSENTRY']._loaded_options = None + _globals['_COLLECTIONSCHEMA_ACTIONSENTRY']._serialized_options = b'8\001' + _globals['_COLLECTIONSCHEMA_FIELDSENTRY']._loaded_options = None + _globals['_COLLECTIONSCHEMA_FIELDSENTRY']._serialized_options = b'8\001' + _globals['_CALLER']._serialized_start=119 + _globals['_CALLER']._serialized_end=363 + _globals['_CALLER_TAGSENTRY']._serialized_start=297 + _globals['_CALLER_TAGSENTRY']._serialized_end=340 + _globals['_CALLER_REQUEST']._serialized_start=342 + _globals['_CALLER_REQUEST']._serialized_end=363 + _globals['_DATASOURCESCHEMA']._serialized_start=365 + _globals['_DATASOURCESCHEMA']._serialized_end=479 + _globals['_COLLECTIONSCHEMA']._serialized_start=482 + _globals['_COLLECTIONSCHEMA']._serialized_end=858 + _globals['_COLLECTIONSCHEMA_ACTIONSENTRY']._serialized_start=715 + _globals['_COLLECTIONSCHEMA_ACTIONSENTRY']._serialized_end=786 + _globals['_COLLECTIONSCHEMA_FIELDSENTRY']._serialized_start=788 + _globals['_COLLECTIONSCHEMA_FIELDSENTRY']._serialized_end=858 +# @@protoc_insertion_point(module_scope) diff 
--git a/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.pyi b/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.pyi new file mode 100644 index 00000000..0dcea8fa --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.pyi @@ -0,0 +1,82 @@ +from google.protobuf import any_pb2 as _any_pb2 +from google.protobuf import struct_pb2 as _struct_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Caller(_message.Message): + __slots__ = ("RendererId", "UserId", "Tags", "Email", "FirstName", "LastName", "Team", "Timezone") + class TagsEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + class Request(_message.Message): + __slots__ = ("Ip",) + IP_FIELD_NUMBER: _ClassVar[int] + Ip: str + def __init__(self, Ip: _Optional[str] = ...) -> None: ... + RENDERERID_FIELD_NUMBER: _ClassVar[int] + USERID_FIELD_NUMBER: _ClassVar[int] + TAGS_FIELD_NUMBER: _ClassVar[int] + EMAIL_FIELD_NUMBER: _ClassVar[int] + FIRSTNAME_FIELD_NUMBER: _ClassVar[int] + LASTNAME_FIELD_NUMBER: _ClassVar[int] + TEAM_FIELD_NUMBER: _ClassVar[int] + TIMEZONE_FIELD_NUMBER: _ClassVar[int] + RendererId: int + UserId: int + Tags: _containers.ScalarMap[str, str] + Email: str + FirstName: str + LastName: str + Team: str + Timezone: str + def __init__(self, RendererId: _Optional[int] = ..., UserId: _Optional[int] = ..., Tags: _Optional[_Mapping[str, str]] = ..., Email: _Optional[str] = ..., FirstName: _Optional[str] = ..., LastName: _Optional[str] = ..., Team: _Optional[str] = ..., Timezone: _Optional[str] = ...) 
-> None: ... + +class DataSourceSchema(_message.Message): + __slots__ = ("Charts", "ConnectionNames", "Collections") + CHARTS_FIELD_NUMBER: _ClassVar[int] + CONNECTIONNAMES_FIELD_NUMBER: _ClassVar[int] + COLLECTIONS_FIELD_NUMBER: _ClassVar[int] + Charts: _containers.RepeatedScalarFieldContainer[str] + ConnectionNames: _containers.RepeatedScalarFieldContainer[str] + Collections: _containers.RepeatedCompositeFieldContainer[CollectionSchema] + def __init__(self, Charts: _Optional[_Iterable[str]] = ..., ConnectionNames: _Optional[_Iterable[str]] = ..., Collections: _Optional[_Iterable[_Union[CollectionSchema, _Mapping]]] = ...) -> None: ... + +class CollectionSchema(_message.Message): + __slots__ = ("name", "actions", "fields", "charts", "searchable", "countable", "segments") + class ActionsEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: _struct_pb2.Struct + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + class FieldsEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: _struct_pb2.Struct + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... 
+ NAME_FIELD_NUMBER: _ClassVar[int] + ACTIONS_FIELD_NUMBER: _ClassVar[int] + FIELDS_FIELD_NUMBER: _ClassVar[int] + CHARTS_FIELD_NUMBER: _ClassVar[int] + SEARCHABLE_FIELD_NUMBER: _ClassVar[int] + COUNTABLE_FIELD_NUMBER: _ClassVar[int] + SEGMENTS_FIELD_NUMBER: _ClassVar[int] + name: str + actions: _containers.MessageMap[str, _struct_pb2.Struct] + fields: _containers.MessageMap[str, _struct_pb2.Struct] + charts: _containers.RepeatedScalarFieldContainer[str] + searchable: bool + countable: bool + segments: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, name: _Optional[str] = ..., actions: _Optional[_Mapping[str, _struct_pb2.Struct]] = ..., fields: _Optional[_Mapping[str, _struct_pb2.Struct]] = ..., charts: _Optional[_Iterable[str]] = ..., searchable: bool = ..., countable: bool = ..., segments: _Optional[_Iterable[str]] = ...) -> None: ... diff --git a/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2_grpc.py b/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2_grpc.py new file mode 100644 index 00000000..1b1116d8 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2_grpc.py @@ -0,0 +1,24 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + + +GRPC_GENERATED_VERSION = '1.70.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in forestadmin/rpc_common/proto/forest_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' 
+ + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) diff --git a/src/rpc_common/pyproject.toml b/src/rpc_common/pyproject.toml new file mode 100644 index 00000000..2313e73d --- /dev/null +++ b/src/rpc_common/pyproject.toml @@ -0,0 +1,24 @@ +[build-system] +requires = [ "poetry-core",] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "forestadmin-rpc-common" +description = "common files for rpc server and client" +version = "1.23.0" +authors = [ "Julien Barreau ",] +readme = "README.md" +repository = "https://github.com/ForestAdmin/agent-python" +documentation = "https://docs.forestadmin.com/developer-guide-agents-python/" +homepage = "https://www.forestadmin.com" +[[tool.poetry.packages]] +include = "forestadmin" + +[tool.poetry.dependencies] +python = ">=3.8,<3.14" +grpcio = ">1.7.0" +grpcio-tools = { version = "*", optional = true } + +[tool.poe.tasks] +cleanup = { shell = "find forestadmin/rpc_common/proto -name '*.py*' -exec rm -f {} \\;" } +build = { shell = "python -m grpc_tools.protoc -I./ --python_out=. --pyi_out=. --grpc_python_out=. 
forestadmin/rpc_common/proto/*.proto" } \ No newline at end of file From e53caceec9a1afb0ebfb67af6ced86092bb7f45d Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 12 Mar 2025 16:35:09 +0100 Subject: [PATCH 02/34] feat(agent): add a reload method to the agent --- .../forestadmin/agent_toolkit/agent.py | 14 +++++ .../datasource_customizer.py | 38 +++++++----- .../decorators/decorator_stack.py | 61 +++++++++++++++++-- 3 files changed, 95 insertions(+), 18 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py index 199ab8fa..db0faf59 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py @@ -71,6 +71,20 @@ def __del__(self): if hasattr(self, "_sse_thread") and self._sse_thread.is_alive(): self._sse_thread.stop() + async def reload(self): + try: + await self.customizer._reload() + except Exception as exc: + ForestLogger.log("error", f"Error reloading agent: {exc}") + return + + self._resources = None + Agent.__IS_INITIALIZED = False + if hasattr(self, "_sse_thread") and self._sse_thread.is_alive(): + self._sse_thread.stop() + self._sse_thread = SSECacheInvalidation(self._permission_service, self.options) + await self._start() + async def __mk_resources(self): self._resources: Resources = { "capabilities": CapabilitiesResource( diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py index 2904eb16..54b362d7 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py @@ -25,6 +25,17 @@ async def get_datasource(self): await self.stack.apply_queue_customization() return self.stack.datasource + async def 
_reload(self): + old_composite_ds = self.composite_datasource + + try: + self.composite_datasource = CompositeDatasource() + await self.stack._reload(self.composite_datasource) + except Exception as e: + self.composite_datasource = old_composite_ds + raise e + await self.get_datasource() + @property def collections(self): """Get list of customizable collections""" @@ -38,22 +49,21 @@ def add_datasource(self, datasource: Datasource, options: Optional[DataSourceOpt options (DataSourceOptions, optional): the options """ _options: DataSourceOptions = DataSourceOptions() if options is None else options + _datasource = datasource + if "include" in _options or "exclude" in _options: + publication_decorator = PublicationDataSourceDecorator(_datasource) + publication_decorator.keep_collections_matching(_options.get("include", []), _options.get("exclude", [])) + _datasource = publication_decorator + + if "rename" in _options: + rename_decorator = RenameCollectionDataSourceDecorator(_datasource) + rename_decorator.rename_collections(_options.get("rename", {})) + _datasource = rename_decorator async def _add_datasource(): - nonlocal datasource - if "include" in _options or "exclude" in _options: - publication_decorator = PublicationDataSourceDecorator(datasource) - publication_decorator.keep_collections_matching( - _options.get("include", []), _options.get("exclude", []) - ) - datasource = publication_decorator - - if "rename" in _options: - rename_decorator = RenameCollectionDataSourceDecorator(datasource) - rename_decorator.rename_collections(_options.get("rename", {})) - datasource = rename_decorator - - self.composite_datasource.add_datasource(datasource) + nonlocal _datasource + + self.composite_datasource.add_datasource(_datasource) self.stack.queue_customization(_add_datasource) return self diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py 
index 5a1d297f..6845aa1c 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py @@ -1,5 +1,6 @@ from typing import Awaitable, Callable, List +from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.datasource_toolkit.decorators.action.collections import ActionCollectionDecorator from forestadmin.datasource_toolkit.decorators.binary.collection import BinaryCollectionDecorator from forestadmin.datasource_toolkit.decorators.chart.chart_datasource_decorator import ChartDataSourceDecorator @@ -28,10 +29,12 @@ class DecoratorStack: def __init__(self, datasource: Datasource) -> None: self._customizations: List = list() - last = datasource + self._applied_customizations: List = list() + self._init_stack(datasource) + def _init_stack(self, datasource: Datasource): # Step 0: Do not query datasource when we know the result with yield an empty set. 
- last = self.override = DatasourceDecorator(last, OverrideCollectionDecorator) # type: ignore + last = self.override = DatasourceDecorator(datasource, OverrideCollectionDecorator) # type: ignore last = self.empty = DatasourceDecorator(last, EmptyCollectionDecorator) # type: ignore # Step 1: Computed-Relation-Computed sandwich (needed because some emulated relations depend @@ -67,10 +70,58 @@ def __init__(self, datasource: Datasource) -> None: self.datasource = last + def _backup_stack(self): + backup_stack = {} + for decorator_name in [ + "_customizations", + "_applied_customizations", + "override", + "empty", + "early_computed", + "early_op_emulate", + "early_op_equivalence", + "relation", + "lazy_joins", + "late_computed", + "late_op_emulate", + "late_op_equivalence", + "search", + "segment", + "sort_emulate", + "chart", + "action", + "schema", + "write", + "hook", + "validation", + "binary", + "publication", + "rename_field", + ]: + backup_stack[decorator_name] = getattr(self, decorator_name) + return backup_stack + + def _restore_stack(self, backup_stack): + for decorator_name, instance in backup_stack.items(): + setattr(self, decorator_name, instance) + + async def _reload(self, datasource: Datasource): + stack_backup = self._backup_stack() + + self._customizations: List = self._applied_customizations + self._applied_customizations: List = list() + try: + self._init_stack(datasource) + await self.apply_queue_customization() + except Exception as e: + ForestLogger.log("exception", f"Error while reloading customizations: {e}, restoring previous state") + self._restore_stack(stack_backup) + raise e + def queue_customization(self, customization: Callable[[], Awaitable[None]]): self._customizations.append(customization) - async def apply_queue_customization(self): + async def apply_queue_customization(self, store_customizations=True): queued_customization = self._customizations.copy() self._customizations = [] @@ -78,4 +129,6 @@ async def 
apply_queue_customization(self): while len(queued_customization) > 0: customization = queued_customization.pop() await customization() - await self.apply_queue_customization() + if store_customizations: + self._applied_customizations.append(customization) + await self.apply_queue_customization(False) From 3ea33fbfd38d795448a2f065903a2772cce6bdf7 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Fri, 14 Mar 2025 11:06:54 +0100 Subject: [PATCH 03/34] chore: cleanup agent reload code --- .../forestadmin/agent_toolkit/agent.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py index db0faf59..92307299 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py @@ -79,11 +79,8 @@ async def reload(self): return self._resources = None - Agent.__IS_INITIALIZED = False - if hasattr(self, "_sse_thread") and self._sse_thread.is_alive(): - self._sse_thread.stop() - self._sse_thread = SSECacheInvalidation(self._permission_service, self.options) - await self._start() + await self.__mk_resources() + await self.send_schema() async def __mk_resources(self): self._resources: Resources = { @@ -233,6 +230,15 @@ async def _start(self): return ForestLogger.log("debug", "Starting agent") + await self.send_schema() + + if self.options["instant_cache_refresh"]: + self._sse_thread.start() + + ForestLogger.log("debug", "Agent started") + Agent.__IS_INITIALIZED = True + + async def send_schema(self): if self.options["skip_schema_update"] is False: try: api_map = await SchemaEmitter.get_serialized_schema( From ff934d1809390ee3db7b4dd2f5af0f38a90e5d0b Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Fri, 14 Mar 2025 11:07:33 +0100 Subject: [PATCH 04/34] fix(rename_collection): issues while reloading --- .../datasource_customizer.py | 25 +++++++++++-------- 1 file changed, 14 
insertions(+), 11 deletions(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py index 54b362d7..24b53475 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py @@ -49,19 +49,22 @@ def add_datasource(self, datasource: Datasource, options: Optional[DataSourceOpt options (DataSourceOptions, optional): the options """ _options: DataSourceOptions = DataSourceOptions() if options is None else options - _datasource = datasource - if "include" in _options or "exclude" in _options: - publication_decorator = PublicationDataSourceDecorator(_datasource) - publication_decorator.keep_collections_matching(_options.get("include", []), _options.get("exclude", [])) - _datasource = publication_decorator - - if "rename" in _options: - rename_decorator = RenameCollectionDataSourceDecorator(_datasource) - rename_decorator.rename_collections(_options.get("rename", {})) - _datasource = rename_decorator async def _add_datasource(): - nonlocal _datasource + nonlocal datasource + _datasource = datasource + + if "include" in _options or "exclude" in _options: + publication_decorator = PublicationDataSourceDecorator(_datasource) + publication_decorator.keep_collections_matching( + _options.get("include", []), _options.get("exclude", []) + ) + _datasource = publication_decorator + + if "rename" in _options: + rename_decorator = RenameCollectionDataSourceDecorator(_datasource) + rename_decorator.rename_collections(_options.get("rename", {})) + _datasource = rename_decorator self.composite_datasource.add_datasource(_datasource) From 036a139164e38d8299cd23c85d9d7d8795400386 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Fri, 14 Mar 2025 11:08:02 +0100 Subject: 
[PATCH 05/34] fix(validation): validation rules were duplicated in schema --- .../datasource_toolkit/decorators/validation/collection.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/validation/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/validation/collection.py index 1be261a6..25cbe5d1 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/validation/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/validation/collection.py @@ -52,9 +52,7 @@ async def update(self, caller: User, filter: Optional[Filter], patch: RecordsDat def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema: fields_schema = {**sub_schema["fields"]} for name, rules in self.validations.items(): - if fields_schema[name].get("validations") is None: - fields_schema[name]["validations"] = [] - fields_schema[name]["validations"].extend(rules) + fields_schema[name]["validations"] = [*fields_schema[name].get("validations", []), *rules] return {**sub_schema, "fields": fields_schema} From c39355c1813331c2e4f36a460ef6735bcb53d0df Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Fri, 14 Mar 2025 11:08:51 +0100 Subject: [PATCH 06/34] fix(live_queries): connections weren't refresh correctly --- .../datasource_toolkit/decorators/datasource_decorator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py index 64ebddda..e1e5ffbb 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py @@ -27,6 +27,9 @@ def schema(self): def collections(self): return [self.get_collection(c.name) for c in 
self.child_datasource.collections] + def get_native_query_connections(self) -> List[str]: + return self.child_datasource.get_native_query_connections() + def get_collection(self, name: str) -> CollectionDecorator: collection = self.child_datasource.get_collection(name) if collection not in self._decorators: From 82533183d8d6086ffeb804ca5d9e0f462871fdc5 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Fri, 14 Mar 2025 11:09:08 +0100 Subject: [PATCH 07/34] chore: improve naming --- .../datasource_toolkit/decorators/decorator_stack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py index 6845aa1c..69f5fc1d 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py @@ -106,7 +106,7 @@ def _restore_stack(self, backup_stack): setattr(self, decorator_name, instance) async def _reload(self, datasource: Datasource): - stack_backup = self._backup_stack() + backup_stack = self._backup_stack() self._customizations: List = self._applied_customizations self._applied_customizations: List = list() @@ -115,7 +115,7 @@ async def _reload(self, datasource: Datasource): await self.apply_queue_customization() except Exception as e: ForestLogger.log("exception", f"Error while reloading customizations: {e}, restoring previous state") - self._restore_stack(stack_backup) + self._restore_stack(backup_stack) raise e def queue_customization(self, customization: Callable[[], Awaitable[None]]): From 185969e39e58a5978820d390b7b911ea847c1d1d Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 20 Mar 2025 14:48:04 +0100 Subject: [PATCH 08/34] chore: add common rpc package --- .../rpc_common/serializers/__init__.py | 0 .../serializers/collection/__init__.py | 0 
.../serializers/collection/aggregation.py | 63 ++++ .../serializers/collection/filter.py | 148 ++++++++ .../serializers/collection/record.py | 115 ++++++ .../rpc_common/serializers/schema/__init__.py | 0 .../rpc_common/serializers/schema/schema.py | 352 ++++++++++++++++++ .../rpc_common/serializers/utils.py | 127 +++++++ 8 files changed, 805 insertions(+) create mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/__init__.py create mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/collection/__init__.py create mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/collection/aggregation.py create mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py create mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/collection/record.py create mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/schema/__init__.py create mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py create mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/utils.py diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/__init__.py b/src/rpc_common/forestadmin/rpc_common/serializers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/collection/__init__.py b/src/rpc_common/forestadmin/rpc_common/serializers/collection/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/collection/aggregation.py b/src/rpc_common/forestadmin/rpc_common/serializers/collection/aggregation.py new file mode 100644 index 00000000..4c81f4e8 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/serializers/collection/aggregation.py @@ -0,0 +1,63 @@ +from typing import Dict + +from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation, Aggregator, DateOperation + + +class AggregatorSerializer: + @staticmethod + def serialize(aggregator: 
Aggregator) -> str: + if isinstance(aggregator, Aggregator): + return aggregator.value + return aggregator + + @staticmethod + def deserialize(aggregator: str) -> Aggregator: + return Aggregator(aggregator) + + +class DateOperationSerializer: + @staticmethod + def serialize(date_operation: DateOperation) -> str: + return date_operation.value + + @staticmethod + def deserialize(date_operation: str) -> DateOperation: + return DateOperation(date_operation) + + +class AggregationSerializer: + @staticmethod + def serialize(aggregation: Aggregation) -> Dict: + return { + "field": aggregation.field, + "operation": AggregatorSerializer.serialize(aggregation.operation), + "groups": [ + { + "field": group["field"], + "operation": ( + DateOperationSerializer.serialize(group["operation"]) if "operation" in group else None + ), + } + for group in aggregation.groups + ], + } + + @staticmethod + def deserialize(aggregation: Dict) -> Aggregation: + groups = [] + for group in aggregation["groups"]: + tmp = { + "field": group["field"], + } + if group["operation"] is not None: + tmp["operation"] = DateOperationSerializer.deserialize(group["operation"]) + + groups.append(tmp) + + return Aggregation( + { + "field": aggregation["field"], + "operation": AggregatorSerializer.deserialize(aggregation["operation"]), + "groups": groups, + } + ) diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py b/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py new file mode 100644 index 00000000..62824c0f --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py @@ -0,0 +1,148 @@ +from typing import Any, Dict, List, cast + +from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree +from 
forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import ConditionTreeBranch +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf +from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter +from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter +from forestadmin.datasource_toolkit.interfaces.query.page import Page +from forestadmin.datasource_toolkit.interfaces.query.projections import Projection +from forestadmin.datasource_toolkit.interfaces.query.sort import Sort +from forestadmin.datasource_toolkit.utils.collections import CollectionUtils +from forestadmin.rpc_common.serializers.collection.record import RecordSerializer +from forestadmin.rpc_common.serializers.utils import OperatorSerializer, TimezoneSerializer, camel_to_snake_case + + +class ConditionTreeSerializer: + @staticmethod + def serialize(condition_tree: ConditionTree, collection: Collection) -> Dict: + if isinstance(condition_tree, ConditionTreeBranch): + return { + "aggregator": condition_tree.aggregator.value, + "conditions": [ + ConditionTreeSerializer.serialize(condition, collection) for condition in condition_tree.conditions + ], + } + elif isinstance(condition_tree, ConditionTreeLeaf): + return { + "operator": OperatorSerializer.serialize(condition_tree.operator), + "field": condition_tree.field, + "value": ConditionTreeSerializer._serialize_value( + condition_tree.value, collection, condition_tree.field + ), + } + raise ValueError(f"Unknown condition tree type: {type(condition_tree)}") + + @staticmethod + def _serialize_value(value: Any, collection: Collection, field_name: str) -> Any: + if value is None: + return None + field_schema = CollectionUtils.get_field_schema(collection, field_name) #  type:ignore + type_ = field_schema["column_type"] + + if type_ == PrimitiveType.DATE: + ret = value.isoformat() + elif type_ == PrimitiveType.DATE_ONLY: + ret = 
value.isoformat() + elif type_ == PrimitiveType.DATE: + ret = value.isoformat() + elif type_ == PrimitiveType.POINT: + ret = (value[0], value[1]) + elif type_ == PrimitiveType.UUID: + ret = str(value) + else: + ret = value + + return ret + + @staticmethod + def deserialize(condition_tree: Dict, collection: Collection) -> ConditionTree: + if "aggregator" in condition_tree: + return ConditionTreeBranch( + aggregator=condition_tree["aggregator"], + conditions=[ + ConditionTreeSerializer.deserialize(condition, collection) + for condition in condition_tree["conditions"] + ], + ) + elif "operator" in condition_tree: + return ConditionTreeLeaf( + operator=camel_to_snake_case(condition_tree["operator"]), # type:ignore + field=condition_tree["field"], + value=ConditionTreeSerializer._deserialize_value( + condition_tree.get("value"), collection, condition_tree["field"] + ), + ) + raise ValueError(f"Unknown condition tree type: {condition_tree.keys()}") + + @staticmethod + def _deserialize_value(value: Any, collection: Collection, field_name: str) -> Any: + if value is None: + return None + field_schema = CollectionUtils.get_field_schema(collection, field_name) #  type:ignore + return RecordSerializer._deserialize_primitive_type(field_schema["column_type"], value) #  type:ignore + + +class FilterSerializer: + @staticmethod + def serialize(filter_: Filter, collection: Collection) -> Dict: + return { + "conditionTree": ( + ConditionTreeSerializer.serialize(filter_.condition_tree, collection) + if filter_.condition_tree is not None + else None + ), + "search": filter_.search, + "searchExtended": filter_.search_extended, + "segment": filter_.segment, + "timezone": TimezoneSerializer.serialize(filter_.timezone), + } + + @staticmethod + def deserialize(filter_: Dict, collection: Collection) -> Filter: + return Filter( + { + "condition_tree": ( + ConditionTreeSerializer.deserialize(filter_["conditionTree"], collection) + if filter_.get("conditionTree") is not None + else None + ), # 
type: ignore + "search": filter_["search"], + "search_extended": filter_["searchExtended"], + "segment": filter_["segment"], + "timezone": TimezoneSerializer.deserialize(filter_["timezone"]), + } + ) + + +class PaginatedFilterSerializer: + @staticmethod + def serialize(filter_: PaginatedFilter, collection: Collection) -> Dict: + ret = FilterSerializer.serialize(filter_.to_base_filter(), collection) + ret["sort"] = filter_.sort + ret["page"] = {"skip": filter_.page.skip, "limit": filter_.page.limit} if filter_.page else None + return ret + + @staticmethod + def deserialize(serialized_filter: Dict, collection: Collection) -> PaginatedFilter: + ret = FilterSerializer.deserialize(serialized_filter, collection) + ret = PaginatedFilter.from_base_filter(ret) + ret.sort = Sort(serialized_filter["sort"]) if serialized_filter.get("sort") is not None else None # type: ignore + ret.page = ( + Page(skip=serialized_filter["page"]["skip"], limit=serialized_filter["page"]["limit"]) + if serialized_filter.get("page") is not None + else None + ) #  type: ignore + return ret + + +class ProjectionSerializer: + @staticmethod + def serialize(projection: Projection) -> List[str]: + return [*projection] + + @staticmethod + def deserialize(projection: List[str]) -> Projection: + return Projection(*projection) diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/collection/record.py b/src/rpc_common/forestadmin/rpc_common/serializers/collection/record.py new file mode 100644 index 00000000..2a6559f3 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/serializers/collection/record.py @@ -0,0 +1,115 @@ +from datetime import date, datetime +from typing import Any, Dict, cast +from uuid import UUID + +from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.interfaces.fields import ( + Column, + PrimitiveType, + RelationAlias, + is_column, + is_many_to_many, + is_many_to_one, + is_one_to_many, + is_one_to_one, + 
is_polymorphic_many_to_one, + is_polymorphic_one_to_many, + is_polymorphic_one_to_one, +) +from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias + + +class RecordSerializer: + @staticmethod + def serialize(record: RecordsDataAlias) -> Dict: + serialized_record = {} + + for key, value in record.items(): + if isinstance(value, dict): + serialized_record[key] = RecordSerializer.serialize(value) + elif isinstance(value, set) or isinstance(value, list): + serialized_record[key] = [RecordSerializer.serialize(v) for v in value] + elif isinstance(value, date): + serialized_record[key] = value.isoformat() + # elif isinstance(value, datetime): + # serialized_record[key] = value.isoformat() + elif isinstance(value, UUID): + serialized_record[key] = str(value) + else: + serialized_record[key] = value + + return serialized_record + + @staticmethod + def deserialize(record: Dict, collection: Collection) -> RecordsDataAlias: + deserialized_record = {} + collection_schema = collection.schema + + for key, value in record.items(): + if is_column(collection_schema["fields"][key]): + deserialized_record[key] = RecordSerializer._deserialize_column(collection, key, value) + else: + deserialized_record[key] = RecordSerializer._deserialize_relation(collection, key, value, record) + return deserialized_record + + @staticmethod + def _deserialize_relation(collection: Collection, field_name: str, value: Any, record: Dict) -> RecordsDataAlias: + schema_field = cast(RelationAlias, collection.schema["fields"][field_name]) + if is_many_to_one(schema_field) or is_one_to_one(schema_field) or is_polymorphic_one_to_one(schema_field): + return RecordSerializer.deserialize( + value, collection.datasource.get_collection(schema_field["foreign_collection"]) + ) + elif is_many_to_many(schema_field) or is_one_to_many(schema_field) or is_polymorphic_one_to_many(schema_field): + return [ + RecordSerializer.deserialize( + v, 
collection.datasource.get_collection(schema_field["foreign_collection"]) + ) + for v in value + ] + elif is_polymorphic_many_to_one(schema_field): + if schema_field["foreign_key_type_field"] not in record: + raise ValueError("Cannot deserialize polymorphic many to one relation without foreign key type field") + return RecordSerializer.deserialize( + value, collection.datasource.get_collection(record[schema_field["foreign_key_type_field"]]) + ) + else: + raise ValueError(f"Unknown field type: {schema_field}") + + @staticmethod + def _deserialize_column(collection: Collection, field_name: str, value: Any) -> RecordsDataAlias: + schema_field = cast(Column, collection.schema["fields"][field_name]) + + if isinstance(schema_field["column_type"], dict) or isinstance(schema_field["column_type"], list): + return RecordSerializer._deserialize_complex_type(schema_field["column_type"], value) + elif isinstance(schema_field["column_type"], PrimitiveType): + return RecordSerializer._deserialize_primitive_type(schema_field["column_type"], value) + + @staticmethod + def _deserialize_primitive_type(type_: PrimitiveType, value: Any): + if type_ == PrimitiveType.DATE: + return datetime.fromisoformat(value) + elif type_ == PrimitiveType.DATE_ONLY: + return date.fromisoformat(value) + elif type_ == PrimitiveType.DATE: + return datetime.fromisoformat(value) + elif type_ == PrimitiveType.POINT: + return (value[0], value[1]) + elif type_ == PrimitiveType.UUID: + return UUID(value) + elif type_ == PrimitiveType.BINARY: + pass # TODO: binary + elif type_ == PrimitiveType.TIME_ONLY: + pass # TODO: time only + elif isinstance(type_, PrimitiveType): + return value + else: + raise ValueError(f"Unknown primitive type: {type_}") + + @staticmethod + def _deserialize_complex_type(type_, value): + if isinstance(type_, list): + return [RecordSerializer._deserialize_complex_type(type_[0], v) for v in value] + elif isinstance(type_, dict): + return {k: RecordSerializer._deserialize_complex_type(v, 
value[k]) for k, v in type_.items()} + + return value diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/schema/__init__.py b/src/rpc_common/forestadmin/rpc_common/serializers/schema/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py b/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py new file mode 100644 index 00000000..bc03b652 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py @@ -0,0 +1,352 @@ +import json +from enum import Enum +from typing import Any, Union, cast + +from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.interfaces.actions import ActionFieldType, ActionsScope +from forestadmin.datasource_toolkit.interfaces.fields import ( + Column, + FieldAlias, + FieldType, + PrimitiveType, + RelationAlias, + is_column, + is_many_to_many, + is_many_to_one, + is_one_to_many, + is_one_to_one, + is_polymorphic_many_to_one, + is_polymorphic_one_to_many, + is_polymorphic_one_to_one, +) +from forestadmin.rpc_common.serializers.utils import OperatorSerializer, enum_to_str_or_value, snake_to_camel_case + + +def serialize_column_type(column_type: Union[Enum, Any]) -> Union[str, dict, list]: + if isinstance(column_type, list): + return [serialize_column_type(t) for t in column_type] + + if isinstance(column_type, dict): + return {k: serialize_column_type(v) for k, v in column_type.items()} + + return enum_to_str_or_value(column_type) + + +def deserialize_column_type(column_type: Union[list, dict, str]) -> Union[PrimitiveType, dict, list]: + if isinstance(column_type, list): + return [deserialize_column_type(t) for t in column_type] + + if isinstance(column_type, dict): + return {k: deserialize_column_type(v) for k, v in column_type.items()} + + return PrimitiveType(column_type) + + +class SchemaSerializer: 
+ VERSION = "1.0" + + def __init__(self, datasource: Datasource) -> None: + self.datasource = datasource + + async def serialize(self) -> str: + value = { + "version": self.VERSION, + "data": { + "charts": sorted(self.datasource.schema["charts"].keys()), + "live_query_connections": sorted(self.datasource.get_native_query_connections()), + "collections": { + c.name: await self._serialize_collection(c) + for c in sorted(self.datasource.collections, key=lambda c: c.name) + }, + }, + } + return json.dumps(value) + + async def _serialize_collection(self, collection: Collection) -> dict: + return { + "fields": { + field: await self._serialize_field(collection.schema["fields"][field]) + for field in sorted(collection.schema["fields"].keys()) + }, + "actions": { + action_name: await self._serialize_action(action_name, collection) + for action_name in sorted(collection.schema["actions"].keys()) + }, + "segments": sorted(collection.schema["segments"]), + "countable": collection.schema["countable"], + "searchable": collection.schema["searchable"], + "charts": sorted(collection.schema["charts"].keys()), + } + + async def _serialize_field(self, field: FieldAlias) -> dict: + if is_column(field): + return await self._serialize_column(field) + else: + return await self._serialize_relation(cast(RelationAlias, field)) + + async def _serialize_column(self, column: Column) -> dict: + return { + "type": "Column", + "columnType": (serialize_column_type(column["column_type"])), + "filterOperators": sorted([OperatorSerializer.serialize(op) for op in column.get("filter_operators", [])]), + "defaultValue": column.get("default_value"), + "enumValues": column.get("enum_values"), + "isPrimaryKey": column.get("is_primary_key", False), + "isReadOnly": column.get("is_read_only", False), + "isSortable": column.get("is_sortable", False), + "validations": [ + { + "operator": OperatorSerializer.serialize(v["operator"]), + "value": v.get("value"), + } + for v in sorted( + column.get("validations", 
[]), + key=lambda v: f'{OperatorSerializer.serialize(v["operator"])}_{str(v.get("value"))}', + ) + ], + } + + async def _serialize_relation(self, relation: RelationAlias) -> dict: + serialized_field = {} + if is_polymorphic_many_to_one(relation): + serialized_field.update( + { + "type": "PolymorphicManyToOne", + "foreignCollections": relation["foreign_collections"], + "foreignKey": relation["foreign_key"], + "foreignKeyTypeField": relation["foreign_key_type_field"], + "foreignKeyTargets": relation["foreign_key_targets"], + } + ) + elif is_many_to_one(relation): + serialized_field.update( + { + "type": "ManyToOne", + "foreignCollection": relation["foreign_collection"], + "foreignKey": relation["foreign_key"], + "foreignKeyTarget": relation["foreign_key_target"], + } + ) + elif is_one_to_one(relation) or is_one_to_many(relation): + serialized_field.update( + { + "type": "OneToMany" if is_one_to_many(relation) else "OneToOne", + "foreignCollection": relation["foreign_collection"], + "originKey": relation["origin_key"], + "originKeyTarget": relation["origin_key_target"], + } + ) + elif is_polymorphic_one_to_one(relation) or is_polymorphic_one_to_many(relation): + serialized_field.update( + { + "type": "PolymorphicOneToMany" if is_polymorphic_one_to_many(relation) else "PolymorphicOneToOne", + "foreignCollection": relation["foreign_collection"], + "originKey": relation["origin_key"], + "originKeyTarget": relation["origin_key_target"], + "originTypeField": relation["origin_type_field"], + "originTypeValue": relation["origin_type_value"], + } + ) + elif is_many_to_many(relation): + serialized_field.update( + { + "type": "ManyToMany", + "foreignCollections": relation["foreign_collection"], + "foreignKey": relation["foreign_key"], + "foreignKeyTargets": relation["foreign_key_target"], + "originKey": relation["origin_key"], + "originKeyTarget": relation["origin_key_target"], + "throughCollection": relation["through_collection"], + } + ) + return serialized_field + + async 
def _serialize_action(self, action_name: str, collection: Collection) -> dict: + action = collection.schema["actions"][action_name] + if not action.static_form: + form = None + else: + form = await collection.get_form(None, action_name, None, None) # type:ignore + # fields, layout = SchemaActionGenerator.extract_fields_and_layout(form) + # fields = [ + # await SchemaActionGenerator.build_field_schema(collection.datasource, field) for field in fields + # ] + + return { + "scope": action.scope.value, + "generateFile": action.generate_file or False, + "staticForm": action.static_form or False, + "description": action.description, + "submitButtonLabel": action.submit_button_label, + "form": await self._serialize_action_form(form) if form is not None else None, + } + + async def _serialize_action_form(self, form) -> list[dict]: + serialized_form = [] + + for field in form: + if field["type"] == ActionFieldType.LAYOUT: + if field["component"] == "Page": + serialized_form.append( + {**field, "type": "Layout", "elements": await self._serialize_action_form(field["elements"])} + ) + + if field["component"] == "Row": + serialized_form.append( + {**field, "type": "Layout", "fields": await self._serialize_action_form(field["fields"])} + ) + else: + serialized_form.append( + { + **{snake_to_camel_case(k): v for k, v in field.items()}, + "type": enum_to_str_or_value(field["type"]), + } + ) + + return serialized_form + + +class SchemaDeserializer: + VERSION = "1.0" + + def deserialize(self, json_schema) -> dict: + schema = json.loads(json_schema) + if schema["version"] != self.VERSION: + raise ValueError(f"Unsupported schema version {schema['version']}") + + return { + "charts": schema["data"]["charts"], + "live_query_connections": schema["data"]["live_query_connections"], + "collections": { + name: self._deserialize_collection(collection) + for name, collection in schema["data"]["collections"].items() + }, + } + + def _deserialize_collection(self, collection: dict) -> dict: + 
return { + "fields": {field: self._deserialize_field(collection["fields"][field]) for field in collection["fields"]}, + "actions": { + action_name: self._deserialize_action(action) for action_name, action in collection["actions"].items() + }, + "segments": collection["segments"], + "countable": collection["countable"], + "searchable": collection["searchable"], + "charts": collection["charts"], + } + + def _deserialize_field(self, field: dict) -> dict: + if field["type"] == "Column": + return self._deserialize_column(field) + else: + return self._deserialize_relation(field) + + def _deserialize_column(self, column: dict) -> dict: + return { + "type": FieldType.COLUMN, + "column_type": deserialize_column_type(column["columnType"]), + "filter_operators": [OperatorSerializer.deserialize(op) for op in column["filterOperators"]], + "default_value": column.get("defaultValue"), + "enum_values": column.get("enumValues"), + "is_primary_key": column.get("isPrimaryKey", False), + "is_read_only": column.get("isReadOnly", False), + "is_sortable": column.get("isSortable", False), + "validations": [ + { + "operator": OperatorSerializer.deserialize(op["operator"]), + "value": op.get("value"), + } + for op in column.get("validations", []) + ], + } + + def _deserialize_relation(self, relation: dict) -> dict: + if relation["type"] == "PolymorphicManyToOne": + return { + "type": FieldType("PolymorphicManyToOne"), + "foreign_collections": relation["foreignCollections"], + "foreign_key": relation["foreignKey"], + "foreign_key_type_field": relation["foreignKeyTypeField"], + "foreign_key_targets": relation["foreignKeyTargets"], + } + elif relation["type"] == "ManyToOne": + return { + "type": FieldType("ManyToOne"), + "foreign_collection": relation["foreignCollection"], + "foreign_key": relation["foreignKey"], + "foreign_key_target": relation["foreignKeyTarget"], + } + elif relation["type"] in ["OneToMany", "OneToOne"]: + return { + "type": FieldType("OneToOne") if relation["type"] == 
"OneToOne" else FieldType("OneToMany"), + "foreign_collection": relation["foreignCollection"], + "origin_key": relation["originKey"], + "origin_key_target": relation["originKeyTarget"], + } + elif relation["type"] in ["PolymorphicOneToMany", "PolymorphicOneToOne"]: + return { + "type": ( + FieldType("PolymorphicOneToOne") + if relation["type"] == "PolymorphicOneToOne" + else FieldType("PolymorphicOneToMany") + ), + "foreign_collection": relation["foreignCollection"], + "origin_key": relation["originKey"], + "origin_key_target": relation["originKeyTarget"], + "origin_type_field": relation["originTypeField"], + "origin_type_value": relation["originTypeValue"], + } + elif relation["type"] == "ManyToMany": + return { + "type": FieldType("ManyToMany"), + "foreign_collection": relation["foreignCollections"], + "foreign_key": relation["foreignKey"], + "foreign_key_target": relation["foreignKeyTargets"], + "origin_key": relation["originKey"], + "origin_key_target": relation["originKeyTarget"], + "through_collection": relation["throughCollection"], + } + raise ValueError(f"Unsupported relation type {relation['type']}") + + def _deserialize_action(self, action: dict) -> dict: + return { + "scope": ActionsScope(action["scope"]), + "generate_file": action["generateFile"], + "static_form": action["staticForm"], + "description": action["description"], + "submit_button_label": action["submitButtonLabel"], + "form": self._deserialize_action_form(action["form"]) if action["form"] is not None else None, + } + + def _deserialize_action_form(self, form: list) -> list[dict]: + deserialized_form = [] + + for field in form: + if field["type"] == "Layout": + if field["component"] == "Page": + deserialized_form.append( + { + **field, + "type": ActionFieldType("Layout"), + "elements": self._deserialize_action_form(field["elements"]), + } + ) + + if field["component"] == "Row": + deserialized_form.append( + { + **field, + "type": ActionFieldType("Layout"), + "fields": 
self._deserialize_action_form(field["fields"]), + } + ) + else: + deserialized_form.append( + { + **{snake_to_camel_case(k): v for k, v in field.items()}, + "type": ActionFieldType(field["type"]), + } + ) + + return deserialized_form diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/utils.py b/src/rpc_common/forestadmin/rpc_common/serializers/utils.py new file mode 100644 index 00000000..b59fbfc5 --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/serializers/utils.py @@ -0,0 +1,127 @@ +import sys +from enum import Enum +from typing import Any, Dict, Optional, Union + +if sys.version_info >= (3, 9): + import zoneinfo +else: + from backports import zoneinfo + +from forestadmin.agent_toolkit.utils.context import User +from forestadmin.datasource_toolkit.interfaces.fields import Operator + + +def snake_to_camel_case(snake_str: str) -> str: + ret = "".join(word.title() for word in snake_str.split("_")) + ret = f"{ret[0].lower()}{ret[1:]}" + return ret + + +def camel_to_snake_case(camel_str: str) -> str: + ret = "".join(f"_{c.lower()}" if c.isupper() else c for c in camel_str) + return ret + + +def enum_to_str_or_value(value: Union[Enum, Any]): + return value.value if isinstance(value, Enum) else value + + +class OperatorSerializer: + # OPERATOR_MAPPING: Dict[str, str] = { + # "present": "present", + # "blank": "blank", + # "missing": "missing", + # "equal": "equal", + # "not_equal": "notEqual", + # "less_than": "lessThan", + # "greater_than": "greaterThan", + # "in": "in", + # "not_in": "notIn", + # "like": "like", + # "starts_with": "startsWith", + # "ends_with": "endsWith", + # "contains": "contains", + # "match": "match", + # "not_contains": "notContains", + # "longer_than": "longerThan", + # "shorter_than": "shorterThan", + # "before": "before", + # "after": "after", + # "after_x_hours_ago": "afterXHoursAgo", + # "before_x_hours_ago": "beforeXHoursAgo", + # "future": "future", + # "past": "past", + # "previous_month_to_date": 
"previousMonthToDate", + # "previous_month": "previousMonth", + # "previous_quarter_to_date": "previousQuarterToDate", + # "previous_quarter": "previousQuarter", + # "previous_week_to_date": "previousWeekToDate", + # "previous_week": "previousWeek", + # "previous_x_days_to_date": "previousXDaysToDate", + # "previous_x_days": "previousXDays", + # "previous_year_to_date": "previousYearToDate", + # "previous_year": "previousYear", + # "today": "today", + # "yesterday": "yesterday", + # "includes_all": "includesAll", + # } + + @staticmethod + def deserialize(operator: str) -> Operator: + return Operator(camel_to_snake_case(operator)) + # for value, serialized in OperatorSerializer.OPERATOR_MAPPING.items(): + # if serialized == operator: + # return Operator(value) + # raise ValueError(f"Unknown operator: {operator}") + + @staticmethod + def serialize(operator: Union[Operator, str]) -> str: + value = operator + if isinstance(value, Enum): + value = value.value + return snake_to_camel_case(value) + # return OperatorSerializer.OPERATOR_MAPPING[value] + + +class TimezoneSerializer: + @staticmethod + def serialize(timezone: Optional[zoneinfo.ZoneInfo]) -> Optional[str]: + if timezone is None: + return None + return str(timezone) + + @staticmethod + def deserialize(timezone: Optional[str]) -> Optional[zoneinfo.ZoneInfo]: + if timezone is None: + return None + return zoneinfo.ZoneInfo(timezone) + + +class CallerSerializer: + @staticmethod + def serialize(caller: User) -> Dict: + return { + "renderingId": caller.rendering_id, + "userId": caller.user_id, + "tags": caller.tags, + "email": caller.email, + "firstName": caller.first_name, + "lastName": caller.last_name, + "team": caller.team, + "timezone": TimezoneSerializer.serialize(caller.timezone), + "request": {"ip": caller.request["ip"]}, + } + + @staticmethod + def deserialize(caller: Dict) -> User: + return User( + rendering_id=caller["renderingId"], + user_id=caller["userId"], + tags=caller["tags"], + email=caller["email"], 
+ first_name=caller["firstName"], + last_name=caller["lastName"], + team=caller["team"], + timezone=TimezoneSerializer.deserialize(caller["timezone"]), # type:ignore + request={"ip": caller["request"]["ip"]}, + ) From 8c62df01f3c43c0ff3782a7e56726ade20ea74b9 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 20 Mar 2025 14:49:32 +0100 Subject: [PATCH 09/34] chore: remove forgotten code --- src/agent_toolkit/forestadmin/agent_toolkit/agent.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py index 92307299..55220e89 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py @@ -253,9 +253,3 @@ async def send_schema(self): ForestLogger.log("warning", "Cannot send the apimap to Forest. Are you online?") else: ForestLogger.log("warning", 'Schema update was skipped (caused by options["skip_schema_update"]=True)') - - if self.options["instant_cache_refresh"]: - self._sse_thread.start() - - ForestLogger.log("debug", "Agent started") - Agent.__IS_INITIALIZED = True From 678ec10592f4db6525a8a5b8aa0f30a7f390cac2 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 20 Mar 2025 14:50:06 +0100 Subject: [PATCH 10/34] fix: add match operator forgotten from literal operators --- .../forestadmin/datasource_toolkit/interfaces/fields.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/fields.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/fields.py index 04afa827..0d39c911 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/fields.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/fields.py @@ -59,6 +59,7 @@ class Operator(enum.Enum): "starts_with", "ends_with", "contains", + "match", "not_contains", "longer_than", "shorter_than", From 
f4afd088908b3b4bfa1f4fe0a5364e072a925549 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 20 Mar 2025 14:54:31 +0100 Subject: [PATCH 11/34] chore: add piece of datasource and rpc agent --- src/agent_rpc/README.md | 0 .../forestadmin/agent_rpc/__init__.py | 0 src/agent_rpc/forestadmin/agent_rpc/agent.py | 153 ++++++++++++++++++ .../forestadmin/agent_rpc/options.py | 5 + .../agent_rpc/services/__init__.py | 0 .../agent_rpc/services/datasource.py | 67 ++++++++ src/agent_rpc/main.py | 28 ++++ src/agent_rpc/pyproject.toml | 77 +++++++++ src/agent_rpc/sqlalchemy_models.py | 95 +++++++++++ src/datasource_rpc/README.md | 1 + .../forestadmin/datasource_rpc/__init__.py | 0 .../forestadmin/datasource_rpc/collection.py | 77 +++++++++ .../forestadmin/datasource_rpc/datasource.py | 107 ++++++++++++ .../forestadmin/datasource_rpc/main.py | 19 +++ .../reload_datasource_decorator.py | 63 ++++++++ .../datasource_rpc/reloadable_datasource.py | 3 + src/datasource_rpc/pyproject.toml | 72 +++++++++ 17 files changed, 767 insertions(+) create mode 100644 src/agent_rpc/README.md create mode 100644 src/agent_rpc/forestadmin/agent_rpc/__init__.py create mode 100644 src/agent_rpc/forestadmin/agent_rpc/agent.py create mode 100644 src/agent_rpc/forestadmin/agent_rpc/options.py create mode 100644 src/agent_rpc/forestadmin/agent_rpc/services/__init__.py create mode 100644 src/agent_rpc/forestadmin/agent_rpc/services/datasource.py create mode 100644 src/agent_rpc/main.py create mode 100644 src/agent_rpc/pyproject.toml create mode 100644 src/agent_rpc/sqlalchemy_models.py create mode 100644 src/datasource_rpc/README.md create mode 100644 src/datasource_rpc/forestadmin/datasource_rpc/__init__.py create mode 100644 src/datasource_rpc/forestadmin/datasource_rpc/collection.py create mode 100644 src/datasource_rpc/forestadmin/datasource_rpc/datasource.py create mode 100644 src/datasource_rpc/forestadmin/datasource_rpc/main.py create mode 100644 
src/datasource_rpc/forestadmin/datasource_rpc/reload_datasource_decorator.py create mode 100644 src/datasource_rpc/forestadmin/datasource_rpc/reloadable_datasource.py create mode 100644 src/datasource_rpc/pyproject.toml diff --git a/src/agent_rpc/README.md b/src/agent_rpc/README.md new file mode 100644 index 00000000..e69de29b diff --git a/src/agent_rpc/forestadmin/agent_rpc/__init__.py b/src/agent_rpc/forestadmin/agent_rpc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/agent_rpc/forestadmin/agent_rpc/agent.py b/src/agent_rpc/forestadmin/agent_rpc/agent.py new file mode 100644 index 00000000..eb673c9c --- /dev/null +++ b/src/agent_rpc/forestadmin/agent_rpc/agent.py @@ -0,0 +1,153 @@ +import asyncio +import json +import os +import pprint +import signal +import sys +import threading +import time +from enum import Enum + +from aiohttp import web +from aiohttp_sse import sse_response + +# import grpc +from forestadmin.agent_rpc.options import RpcOptions + +# from forestadmin.agent_rpc.services.datasource import DatasourceService +from forestadmin.agent_toolkit.agent import Agent +from forestadmin.agent_toolkit.forest_logger import ForestLogger +from forestadmin.datasource_toolkit.interfaces.fields import is_column +from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer +from forestadmin.rpc_common.serializers.collection.filter import ( + FilterSerializer, + PaginatedFilterSerializer, + ProjectionSerializer, +) +from forestadmin.rpc_common.serializers.collection.record import RecordSerializer +from forestadmin.rpc_common.serializers.schema.schema import SchemaSerializer +from forestadmin.rpc_common.serializers.utils import CallerSerializer + +# from concurrent import futures + + +# from forestadmin.rpc_common.proto import datasource_pb2_grpc + + +class RcpJsonEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, Enum): + return o.value + if isinstance(o, set): + return list(sorted(o, 
key=lambda x: x.value if isinstance(x, Enum) else str(x))) + if isinstance(o, set): + return list(sorted(o, key=lambda x: x.value if isinstance(x, Enum) else str(x))) + + try: + return super().default(o) + except Exception as exc: + print(f"error on seriliaze {o}, {type(o)}: {exc}") + + +class RpcAgent(Agent): + # TODO: options to add: + # * listen addr + def __init__(self, options: RpcOptions): + # self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) + # self.server = grpc.aio.server() + self.listen_addr, self.listen_port = options["listen_addr"].rsplit(":", 1) + self.app = web.Application() + # self.server.add_insecure_port(options["listen_addr"]) + options["skip_schema_update"] = True + options["env_secret"] = "f" * 64 + options["server_url"] = "http://fake" + options["auth_secret"] = "fake" + options["schema_path"] = "./.forestadmin-schema.json" + + super().__init__(options) + self._server_stop = False + self.setup_routes() + # signal.signal(signal.SIGUSR1, self.stop_handler) + + def setup_routes(self): + self.app.router.add_route("GET", "/sse", self.sse_handler) + self.app.router.add_route("GET", "/schema", self.schema) + self.app.router.add_route("POST", "/collection/list", self.collection_list) + self.app.router.add_route("POST", "/collection/create", self.collection_create) + self.app.router.add_route("POST", "/collection/update", self.collection_update) + self.app.router.add_route("POST", "/collection/delete", self.collection_delete) + self.app.router.add_route("POST", "/collection/aggregate", self.collection_aggregate) + self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK")) + + async def sse_handler(self, request: web.Request) -> web.StreamResponse: + async with sse_response(request) as resp: + while resp.is_connected() and not self._server_stop: + await resp.send("", event="heartbeat") + await asyncio.sleep(1) + data = json.dumps({"event": "RpcServerStop"}) + await resp.send(data, event="RpcServerStop") + return resp + 
+ async def schema(self, request): + await self.customizer.get_datasource() + + return web.Response(text=await SchemaSerializer(await self.customizer.get_datasource()).serialize()) + + async def collection_list(self, request: web.Request): + body_params = await request.json() + ds = await self.customizer.get_datasource() + collection = ds.get_collection(body_params["collectionName"]) + caller = CallerSerializer.deserialize(body_params["caller"]) + filter_ = PaginatedFilterSerializer.deserialize(body_params["filter"], collection) + projection = ProjectionSerializer.deserialize(body_params["projection"]) + + records = await collection.list(caller, filter_, projection) + records = [RecordSerializer.serialize(record) for record in records] + return web.json_response(records) + + async def collection_create(self, request: web.Request): + body_params = await request.json() + ds = await self.customizer.get_datasource() + collection = ds.get_collection(body_params["collectionName"]) + caller = CallerSerializer.deserialize(body_params["caller"]) + data = [RecordSerializer.deserialize(r, collection) for r in body_params["data"]] + + records = await collection.create(caller, data) + records = [RecordSerializer.serialize(record) for record in records] + return web.json_response(records) + + async def collection_update(self, request: web.Request): + body_params = await request.json() + ds = await self.customizer.get_datasource() + collection = ds.get_collection(body_params["collectionName"]) + caller = CallerSerializer.deserialize(body_params["caller"]) + filter_ = FilterSerializer.deserialize(body_params["filter"], collection) + patch = RecordSerializer.deserialize(body_params["patch"], collection) + + await collection.update(caller, filter_, patch) + return web.Response(text="OK") + + async def collection_delete(self, request: web.Request): + body_params = await request.json() + ds = await self.customizer.get_datasource() + collection = 
ds.get_collection(body_params["collectionName"]) + caller = CallerSerializer.deserialize(body_params["caller"]) + filter_ = FilterSerializer.deserialize(body_params["filter"], collection) + + await collection.delete(caller, filter_) + return web.Response(text="OK") + + async def collection_aggregate(self, request: web.Request): + body_params = await request.json() + ds = await self.customizer.get_datasource() + collection = ds.get_collection(body_params["collectionName"]) + caller = CallerSerializer.deserialize(body_params["caller"]) + filter_ = FilterSerializer.deserialize(body_params["filter"], collection) + aggregation = AggregationSerializer.deserialize(body_params["aggregation"]) + + records = await collection.aggregate(caller, filter_, aggregation) + # records = [RecordSerializer.serialize(record) for record in records] + return web.json_response(records) + + def start(self): + web.run_app(self.app, host=self.listen_addr, port=int(self.listen_port)) diff --git a/src/agent_rpc/forestadmin/agent_rpc/options.py b/src/agent_rpc/forestadmin/agent_rpc/options.py new file mode 100644 index 00000000..c67c921b --- /dev/null +++ b/src/agent_rpc/forestadmin/agent_rpc/options.py @@ -0,0 +1,5 @@ +from typing_extensions import TypedDict + + +class RpcOptions(TypedDict): + listen_addr: str diff --git a/src/agent_rpc/forestadmin/agent_rpc/services/__init__.py b/src/agent_rpc/forestadmin/agent_rpc/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/agent_rpc/forestadmin/agent_rpc/services/datasource.py b/src/agent_rpc/forestadmin/agent_rpc/services/datasource.py new file mode 100644 index 00000000..3fc898d9 --- /dev/null +++ b/src/agent_rpc/forestadmin/agent_rpc/services/datasource.py @@ -0,0 +1,67 @@ +import json +from enum import Enum +from typing import Any, List, Mapping + +from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.interfaces.fields import is_column +from 
forestadmin.rpc_common.proto import datasource_pb2, datasource_pb2_grpc, forest_pb2 + + +class DatasourceService(datasource_pb2_grpc.DataSourceServicer): + def __init__(self, datasource: Datasource): + self.datasource = datasource + super().__init__() + + def Schema(self, request, context) -> datasource_pb2.SchemaResponse: + collections = [] + for collection in self.datasource.collections: + fields: List[Mapping[str, Any]] = [] + for field_name, field_schema in collection.schema["fields"].items(): + field: Mapping[str, Any] = {**field_schema} + + field["type"] = field["type"].value if isinstance(field["type"], Enum) else field["type"] + + if is_column(field_schema): + field["column_type"] = ( + field["column_type"].value if isinstance(field["column_type"], Enum) else field["column_type"] + ) + field["filter_operators"] = [ + op.value if isinstance(op, Enum) else op for op in field["filter_operators"] + ] + field["validations"] = [ + { + **validation, + "operator": ( + validation["operator"].value + if isinstance(validation["operator"], Enum) + else validation["operator"] + ), + } + for validation in field["validations"] + ] + else: + continue + fields.append({field_name: json.dumps(field).encode("utf-8")}) + + try: + collections.append( + forest_pb2.CollectionSchema( + name=collection.name, + searchable=collection.schema["searchable"], + segments=collection.schema["segments"], + countable=collection.schema["countable"], + charts=collection.schema["charts"], + # fields=fields, + # actions=actions, + ) + ) + except Exception as e: + print(e) + raise e + return datasource_pb2.SchemaResponse( + Collections=collections, + Schema=forest_pb2.DataSourceSchema( + Charts=self.datasource.schema["charts"], + ConnectionNames=self.datasource.get_native_query_connections(), + ), + ) diff --git a/src/agent_rpc/main.py b/src/agent_rpc/main.py new file mode 100644 index 00000000..29681b38 --- /dev/null +++ b/src/agent_rpc/main.py @@ -0,0 +1,28 @@ +import asyncio +import logging + 
+from forestadmin.agent_rpc.agent import RpcAgent +from forestadmin.datasource_sqlalchemy.datasource import SqlAlchemyDatasource +from sqlalchemy_models import DB_URI, Base + + +def main() -> None: + agent = RpcAgent({"listen_addr": "0.0.0.0:50051"}) + agent.add_datasource( + SqlAlchemyDatasource(Base, DB_URI), {"rename": lambda collection_name: f"FROMRPCAGENT_{collection_name}"} + ) + agent.customize_collection("FROMRPCAGENT_address").add_field( + "new_fieeeld", + { + "column_type": "String", + "dependencies": ["pk"], + "get_values": lambda records, ctx: ["v" for r in records], + }, + ) + agent.start() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + # asyncio.run(main()) + main() diff --git a/src/agent_rpc/pyproject.toml b/src/agent_rpc/pyproject.toml new file mode 100644 index 00000000..7f0996bf --- /dev/null +++ b/src/agent_rpc/pyproject.toml @@ -0,0 +1,77 @@ +[build-system] +requires = [ "poetry-core",] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "forestadmin-agent-rpc" +description = "server for datasource rpc" +version = "1.23.0" +authors = [ "Julien Barreau ",] +readme = "README.md" +repository = "https://github.com/ForestAdmin/agent-python" +documentation = "https://docs.forestadmin.com/developer-guide-agents-python/" +homepage = "https://www.forestadmin.com" +[[tool.poetry.packages]] +include = "forestadmin" + +[tool.poetry.dependencies] +python = ">=3.8,<3.14" +typing-extensions = "~=4.2" +forestadmin-agent-toolkit = "1.23.0" +forestadmin-datasource-toolkit = "1.23.0" +forestadmin-rpc-common = "1.23.0" + +[tool.poetry.dependencies."backports.zoneinfo"] +version = "~=0.2.1" +python = "<3.9" +extras = [ "tzdata",] + +[tool.poetry.group.test] +optional = true + +[tool.poetry.group.linter] +optional = true + +[tool.poetry.group.formatter] +optional = true + +[tool.poetry.group.sorter] +optional = true + +[tool.poetry.group.test.dependencies] +pytest = "~=7.1" +pytest-asyncio = "~=0.18" +coverage = "~=6.5" 
+pytest-cov = "^4.0.0" +sqlalchemy = ">=1.4.0" + +[tool.poetry.group.linter.dependencies] +[[tool.poetry.group.linter.dependencies.flake8]] +version = "~=5.0" +python = "<3.8.1" + +[[tool.poetry.group.linter.dependencies.flake8]] +version = "~=6.0" +python = ">=3.8.1" + +[tool.poetry.group.formatter.dependencies] +black = "~=22.10" + +[tool.poetry.group.sorter.dependencies] +isort = "~=3.6" + +[tool.poetry.group.test.dependencies.forestadmin-datasource-toolkit] +path = "../datasource_toolkit" +develop = true + +[tool.poetry.group.test.dependencies.forestadmin-datasource-sqlalchemy] +path = "../datasource_sqlalchemy" +develop = true + +[tool.poetry.group.test.dependencies.forestadmin-agent-toolkit] +path = "../agent_toolkit" +develop = true + +[tool.poetry.group.test.dependencies.forestadmin-rpc-common] +path = "../rpc_common" +develop = true diff --git a/src/agent_rpc/sqlalchemy_models.py b/src/agent_rpc/sqlalchemy_models.py new file mode 100644 index 00000000..261c78e3 --- /dev/null +++ b/src/agent_rpc/sqlalchemy_models.py @@ -0,0 +1,95 @@ +import enum +import os + +import sqlalchemy # type: ignore +from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, Integer, LargeBinary, String, func +from sqlalchemy.orm import relationship + +sqlite_path = os.path.abspath( + os.path.join(__file__, "..", "..", "_example", "django", "django_demo", "db_sqlalchemy.sql") +) +print(sqlite_path) +print(os.path.exists(sqlite_path)) +DB_URI = f"sqlite:///{sqlite_path}" +# DB_URI = "postgresql://flask:flask@127.0.0.1:5432/flask_scratch" + + +use_sqlalchemy_2 = sqlalchemy.__version__.split(".")[0] == "2" +if use_sqlalchemy_2: + from sqlalchemy import Uuid, create_engine + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase): + pass + +else: + from sqlalchemy import create_engine + from sqlalchemy.dialects.postgresql import UUID as Uuid + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + +engine = create_engine(DB_URI, 
echo=False) + + +class ORDER_STATUS(enum.Enum): + PENDING = "Pending" + DISPATCHED = "Dispatched" + DELIVERED = "Delivered" + REJECTED = "Rejected" + + +class Address(Base): + __tablename__ = "address" + pk = Column(Integer, primary_key=True) + street = Column(String(254), nullable=False) + street_number = Column(String(254), nullable=True) + city = Column(String(254), nullable=False) + country = Column(String(254), default="France", nullable=False) + zip_code = Column(String(5), nullable=False, default="75009") + customers = relationship("Customer", secondary="customers_addresses", back_populates="addresses") + + +class Customer(Base): + __tablename__ = "customer" + pk = Column(Uuid, primary_key=True) + first_name = Column(String(254), nullable=False) + last_name = Column(String(254), nullable=False) + age = Column(Integer, nullable=True) + birthday_date = Column(DateTime(timezone=True), default=func.now()) + addresses = relationship("Address", secondary="customers_addresses", back_populates="customers") + is_vip = Column(Boolean, default=False) + avatar = Column(LargeBinary, nullable=True) + + +class Order(Base): + __tablename__ = "order" + pk = Column(Integer, primary_key=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + amount = Column(Integer, nullable=False) + customer = relationship("Customer", backref="orders") + customer_id = Column(Uuid, ForeignKey("customer.pk")) + billing_address_id = Column(Integer, ForeignKey("address.pk")) + billing_address = relationship("Address", foreign_keys=[billing_address_id], backref="billing_orders") + delivering_address_id = Column(Integer, ForeignKey("address.pk")) + delivering_address = relationship("Address", foreign_keys=[delivering_address_id], backref="delivering_orders") + status = Column(Enum(ORDER_STATUS)) + + cart = relationship("Cart", uselist=False, back_populates="order") + + +class Cart(Base): + __tablename__ = "cart" + pk = Column(Integer, primary_key=True) + + name = 
Column(String(254), nullable=False) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + order_id = Column(Integer, ForeignKey("order.pk"), nullable=True) + order = relationship("Order", back_populates="cart") + + +class CustomersAddresses(Base): + __tablename__ = "customers_addresses" + customer_id = Column(Uuid, ForeignKey("customer.pk"), primary_key=True) + address_id = Column(Integer, ForeignKey("address.pk"), primary_key=True) diff --git a/src/datasource_rpc/README.md b/src/datasource_rpc/README.md new file mode 100644 index 00000000..8fb870ba --- /dev/null +++ b/src/datasource_rpc/README.md @@ -0,0 +1 @@ +python -m grpc_tools.protoc -I./ --python_out=. --pyi_out=. --grpc_python_out=. ./*.proto \ No newline at end of file diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/__init__.py b/src/datasource_rpc/forestadmin/datasource_rpc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/collection.py b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py new file mode 100644 index 00000000..678be0fe --- /dev/null +++ b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py @@ -0,0 +1,77 @@ +import json +from typing import Any, Dict, List, Optional + +import urllib3 +from forestadmin.agent_toolkit.utils.context import User +from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.interfaces.query.aggregation import AggregateResult, Aggregation +from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter +from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter +from forestadmin.datasource_toolkit.interfaces.query.projections import Projection +from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias +from forestadmin.rpc_common.serializers.collection.aggregation 
import AggregationSerializer +from forestadmin.rpc_common.serializers.collection.filter import ( + FilterSerializer, + PaginatedFilterSerializer, + ProjectionSerializer, +) +from forestadmin.rpc_common.serializers.collection.record import RecordSerializer +from forestadmin.rpc_common.serializers.utils import CallerSerializer + + +class RPCCollection(Collection): + def __init__(self, name: str, datasource: Datasource, connection_uri: str): + super().__init__(name, datasource) + self.connection_uri = connection_uri + self.http = urllib3.PoolManager() + + async def list(self, caller: User, filter_: PaginatedFilter, projection: Projection) -> List[RecordsDataAlias]: + body = { + "caller": CallerSerializer.serialize(caller), + "filter": PaginatedFilterSerializer.serialize(filter_, self), + "projection": ProjectionSerializer.serialize(projection), + "collectionName": self.name, + } + response = self.http.request("POST", f"http://{self.connection_uri}/collection/list", body=json.dumps(body)) + return [RecordSerializer.deserialize(record, self) for record in json.loads(response.data.decode("utf-8"))] + + async def create(self, caller: User, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + body = { + "caller": CallerSerializer.serialize(caller), + "data": [RecordSerializer.serialize(r) for r in data], + "collectionName": self.name, + } + response = self.http.request("POST", f"http://{self.connection_uri}/collection/create", body=json.dumps(body)) + return [RecordSerializer.deserialize(record, self) for record in json.loads(response.data.decode("utf-8"))] + + async def update(self, caller: User, filter_: Optional[Filter], patch: Dict[str, Any]) -> None: + body = { + "caller": CallerSerializer.serialize(caller), + "filter": FilterSerializer.serialize(filter_, self), + "patch": RecordSerializer.serialize(patch), + "collectionName": self.name, + } + self.http.request("POST", f"http://{self.connection_uri}/collection/update", body=json.dumps(body)) + + async def 
delete(self, caller: User, filter_: Filter | None) -> None: + body = { + "caller": CallerSerializer.serialize(caller), + "filter": FilterSerializer.serialize(filter_, self), + "collectionName": self.name, + } + self.http.request("POST", f"http://{self.connection_uri}/collection/delete", body=json.dumps(body)) + + async def aggregate( + self, caller: User, filter_: Filter | None, aggregation: Aggregation, limit: int | None = None + ) -> List[AggregateResult]: + body = { + "caller": CallerSerializer.serialize(caller), + "filter": FilterSerializer.serialize(filter_, self), + "aggregation": AggregationSerializer.serialize(aggregation), + "collectionName": self.name, + } + response = self.http.request( + "POST", f"http://{self.connection_uri}/collection/aggregate", body=json.dumps(body) + ) + return json.loads(response.data.decode("utf-8")) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py new file mode 100644 index 00000000..bfa71558 --- /dev/null +++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py @@ -0,0 +1,107 @@ +import asyncio +import json +import os +import signal +import time +from threading import Thread, Timer + +import grpc +import urllib3 +from forestadmin.agent_toolkit.utils.context import User +from forestadmin.datasource_rpc.collection import RPCCollection +from forestadmin.datasource_rpc.reloadable_datasource import ReloadableDatasource +from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.interfaces.chart import Chart +from forestadmin.rpc_common.serializers.schema.schema import SchemaDeserializer +from sseclient import SSEClient + +# from forestadmin.rpc_common.proto import datasource_pb2_grpc +# from google.protobuf import empty_pb2 + + +class RPCDatasource(Datasource, ReloadableDatasource): + def __init__(self, connection_uri: str): + super().__init__([]) + self.connection_uri = connection_uri + # res = 
asyncio.run(self.connect_sse()) + # Timer(5, self.internal_reload).start() + # self.trigger_reload() + self.http = urllib3.PoolManager() + self.wait_for_reconnect() + self.introspect() + self.thread = Thread(target=self.run, name="RPCDatasourceSSEThread", daemon=True) + self.thread.start() + + def introspect(self): + self._collections = {} + self._schema = {"charts": {}} + response = self.http.request("GET", f"http://{self.connection_uri}/schema") + # schema_data = json.loads(response.data.decode("utf-8")) + schema_data = SchemaDeserializer().deserialize(response.data.decode("utf-8")) + for collection_name, collection in schema_data["collections"].items(): + self.create_collection(collection_name, collection) + + self._schema["charts"] = {name: None for name in schema_data["charts"]} + self._live_query_connections = schema_data["live_query_connections"] + + def create_collection(self, collection_name, collection_schema): + collection = RPCCollection(collection_name, self, self.connection_uri) + for name, field in collection_schema["fields"].items(): + collection.add_field(name, field) + self.add_collection(collection) + + def run(self): + self.wait_for_reconnect() + self.sse_client = SSEClient( + self.http.request("GET", f"http://{self.connection_uri}/sse", preload_content=False) + ) + try: + for msg in self.sse_client.events(): + if msg.event == "heartbeat": + continue + + if msg.event == "RpcServerStop": + print("RpcServerStop") + break + except: + pass + print("rpc connection to server closed") + self.wait_for_reconnect() + self.internal_reload() + + self.thread = Thread(target=self.run, name="RPCDatasourceSSEThread", daemon=True) + self.thread.start() + + # self.channel = grpc.aio.insecure_channel(self.connection_uri) + # async with grpc.aio.insecure_channel(self.connection_uri) as channel: + # stub = datasource_pb2_grpc.DataSourceStub(self.channel) + # response = await stub.Schema(empty_pb2.Empty()) + # self.channel + + return + + def wait_for_reconnect(self): 
+ while True: + try: + self.ping_connect() + break + except: + time.sleep(1) + print("reconntected") + + def connect(self): + self.ping_connect() + + def ping_connect(self): + self.http.request("GET", f"http://{self.connection_uri}/") + + def internal_reload(self): + # Timer(5, self.internal_reload).start() + print("trigger reload") + self.introspect() + self.trigger_reload() + + # def reload_agent(self): + # os.kill(os.getpid(), signal.SIGUSR1) + async def render_chart(self, caller: User, name: str) -> Chart: + raise Exception("Not implemented") diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/main.py b/src/datasource_rpc/forestadmin/datasource_rpc/main.py new file mode 100644 index 00000000..f8cafde2 --- /dev/null +++ b/src/datasource_rpc/forestadmin/datasource_rpc/main.py @@ -0,0 +1,19 @@ +import asyncio +import logging + +import grpc +from forestadmin.rpc_common.proto import datasource_pb2, datasource_pb2_grpc +from google.protobuf import empty_pb2 + + +async def run() -> None: + async with grpc.aio.insecure_channel("localhost:50051") as channel: + stub = datasource_pb2_grpc.DataSourceStub(channel) + response = await stub.Schema(empty_pb2.Empty()) + print(response) + print("datasource client received: " + str(response.Collections[0].searchable)) + + +if __name__ == "__main__": + logging.basicConfig() + asyncio.run(run()) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/reload_datasource_decorator.py b/src/datasource_rpc/forestadmin/datasource_rpc/reload_datasource_decorator.py new file mode 100644 index 00000000..687598ef --- /dev/null +++ b/src/datasource_rpc/forestadmin/datasource_rpc/reload_datasource_decorator.py @@ -0,0 +1,63 @@ +import asyncio +import hashlib +import json + +# from dictdiffer import diff as differ +from forestadmin.agent_toolkit.forest_logger import ForestLogger +from forestadmin.datasource_rpc.reloadable_datasource import ReloadableDatasource +from forestadmin.datasource_toolkit.decorators.collection_decorator import 
CollectionDecorator +from forestadmin.datasource_toolkit.decorators.datasource_decorator import DatasourceDecorator + + +class ReloadDatasourceDecorator(DatasourceDecorator): + def __init__(self, reload_method, datasource): + self.reload_method = reload_method + super().__init__(datasource, CollectionDecorator) + self.current_schema_hash = self.hash_current_schema() + self.current_schema = { + "charts": self.child_datasource.schema["charts"], + "live_query_connections": self.child_datasource.get_native_query_connections(), + "collections": [c.schema for c in sorted(self.child_datasource.collections, key=lambda c: c.name)], + } + self.current_schema_str = json.dumps( + { + "charts": self.child_datasource.schema["charts"], + "live_query_connections": self.child_datasource.get_native_query_connections(), + "collections": [c.schema for c in sorted(self.child_datasource.collections, key=lambda c: c.name)], + }, + default=str, + ) + if not isinstance(self.child_datasource, ReloadableDatasource): + raise ValueError("The child datasource must be a ReloadableDatasource") + self.child_datasource.trigger_reload = self.reload + + def reload(self): + # diffs = differ( + # self.current_schema, + # { + # "charts": self.child_datasource.schema["charts"], + # "live_query_connections": self.child_datasource.get_native_query_connections(), + # "collections": [c.schema for c in sorted(self.child_datasource.collections, key=lambda c: c.name)], + # }, + # ) + if self.current_schema_hash == self.hash_current_schema(): + ForestLogger.log("info", "Schema has not changed, skipping reload") + return + + if asyncio.iscoroutinefunction(self.reload_method): + asyncio.run(self.reload_method()) + else: + self.reload_method() + + def hash_current_schema(self): + h = hashlib.shake_256( + json.dumps( + { + "charts": self.child_datasource.schema["charts"], + "live_query_connections": self.child_datasource.get_native_query_connections(), + "collections": [c.schema for c in 
sorted(self.child_datasource.collections, key=lambda c: c.name)], + }, + default=str, + ).encode("utf-8") + ) + return h.hexdigest(20) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/reloadable_datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/reloadable_datasource.py new file mode 100644 index 00000000..5ec1e6cf --- /dev/null +++ b/src/datasource_rpc/forestadmin/datasource_rpc/reloadable_datasource.py @@ -0,0 +1,3 @@ +class ReloadableDatasource: + def trigger_reload(self): + pass diff --git a/src/datasource_rpc/pyproject.toml b/src/datasource_rpc/pyproject.toml new file mode 100644 index 00000000..17371fe8 --- /dev/null +++ b/src/datasource_rpc/pyproject.toml @@ -0,0 +1,72 @@ +[build-system] +requires = [ "poetry-core",] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "forestadmin-datasource-rpc" +description = "client for datasource rpc" +version = "1.23.0" +authors = [ "Julien Barreau ",] +readme = "README.md" +repository = "https://github.com/ForestAdmin/agent-python" +documentation = "https://docs.forestadmin.com/developer-guide-agents-python/" +homepage = "https://www.forestadmin.com" +[[tool.poetry.packages]] +include = "forestadmin" + +[tool.poetry.dependencies] +python = ">=3.8,<3.14" +typing-extensions = "~=4.2" +forestadmin-agent-toolkit = "1.23.0" +forestadmin-datasource-toolkit = "1.23.0" +forestadmin-rpc-common = "1.23.0" + +[tool.poetry.dependencies."backports.zoneinfo"] +version = "~=0.2.1" +python = "<3.9" +extras = [ "tzdata",] + +[tool.poetry.group.test] +optional = true + +[tool.poetry.group.linter] +optional = true + +[tool.poetry.group.formatter] +optional = true + +[tool.poetry.group.sorter] +optional = true + +[tool.poetry.group.test.dependencies] +pytest = "~=7.1" +pytest-asyncio = "~=0.18" +coverage = "~=6.5" +pytest-cov = "^4.0.0" + +[tool.poetry.group.linter.dependencies] +[[tool.poetry.group.linter.dependencies.flake8]] +version = "~=5.0" +python = "<3.8.1" + 
+[[tool.poetry.group.linter.dependencies.flake8]] +version = "~=6.0" +python = ">=3.8.1" + +[tool.poetry.group.formatter.dependencies] +black = "~=22.10" + +[tool.poetry.group.sorter.dependencies] +isort = "~=3.6" + +[tool.poetry.group.test.dependencies.forestadmin-datasource-toolkit] +path = "../datasource_toolkit" +develop = true + +[tool.poetry.group.test.dependencies.forestadmin-agent-toolkit] +path = "../agent_toolkit" +develop = true + +[tool.poetry.group.test.dependencies.forestadmin-rpc-common] +path = "../rpc_common" +develop = true From d9ca0b900c13d44f1445d0474aade21e3bf9ae2d Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 25 Mar 2025 11:10:24 +0100 Subject: [PATCH 12/34] chore: all methods now exists in rpc datasource/agent --- .gitignore | 3 + src/agent_rpc/forestadmin/agent_rpc/agent.py | 149 +++++++++- .../forestadmin/agent_rpc/options.py | 2 + .../forestadmin/datasource_rpc/collection.py | 258 +++++++++++++++--- .../forestadmin/datasource_rpc/datasource.py | 67 ++++- src/rpc_common/forestadmin/rpc_common/hmac.py | 11 + .../rpc_common/serializers/actions.py | 213 +++++++++++++++ .../forestadmin/rpc_common/serializers/aes.py | 22 ++ .../rpc_common/serializers/schema/schema.py | 95 +------ 9 files changed, 676 insertions(+), 144 deletions(-) create mode 100644 src/rpc_common/forestadmin/rpc_common/hmac.py create mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/actions.py create mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/aes.py diff --git a/.gitignore b/.gitignore index b8ff2039..fb2c422f 100644 --- a/.gitignore +++ b/.gitignore @@ -177,6 +177,9 @@ src/datasource_toolkit/poetry.lock src/flask_agent/poetry.lock src/django_agent/poetry.lock src/datasource_django/poetry.lock +src/datasource_rpc/poetry.lock +src/agent_rpc/poetry.lock +src/rpc_common/poetry.lock # generate file during tests src/django_agent/tests/test_project_agent/.forestadmin-schema.json \ No newline at end of file diff --git 
a/src/agent_rpc/forestadmin/agent_rpc/agent.py b/src/agent_rpc/forestadmin/agent_rpc/agent.py index eb673c9c..4257d1ce 100644 --- a/src/agent_rpc/forestadmin/agent_rpc/agent.py +++ b/src/agent_rpc/forestadmin/agent_rpc/agent.py @@ -1,12 +1,7 @@ import asyncio import json -import os -import pprint -import signal -import sys -import threading -import time from enum import Enum +from uuid import UUID from aiohttp import web from aiohttp_sse import sse_response @@ -16,8 +11,11 @@ # from forestadmin.agent_rpc.services.datasource import DatasourceService from forestadmin.agent_toolkit.agent import Agent -from forestadmin.agent_toolkit.forest_logger import ForestLogger -from forestadmin.datasource_toolkit.interfaces.fields import is_column +from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType +from forestadmin.datasource_toolkit.utils.schema import SchemaUtils +from forestadmin.rpc_common.hmac import is_valid_hmac +from forestadmin.rpc_common.serializers.actions import ActionFormSerializer, ActionResultSerializer +from forestadmin.rpc_common.serializers.aes import aes_decrypt, aes_encrypt from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer from forestadmin.rpc_common.serializers.collection.filter import ( FilterSerializer, @@ -56,20 +54,33 @@ def __init__(self, options: RpcOptions): # self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) # self.server = grpc.aio.server() self.listen_addr, self.listen_port = options["listen_addr"].rsplit(":", 1) - self.app = web.Application() + self.app = web.Application(middlewares=[self.hmac_middleware]) # self.server.add_insecure_port(options["listen_addr"]) options["skip_schema_update"] = True options["env_secret"] = "f" * 64 options["server_url"] = "http://fake" - options["auth_secret"] = "fake" + # options["auth_secret"] = "f48186505a3c5d62c27743126d6a76c1dd8b3e2d8897de19" options["schema_path"] = "./.forestadmin-schema.json" - super().__init__(options) + + 
self.aes_key = self.options["auth_secret"][:16].encode() + self.aes_iv = self.options["auth_secret"][-16:].encode() self._server_stop = False self.setup_routes() # signal.signal(signal.SIGUSR1, self.stop_handler) + @web.middleware + async def hmac_middleware(self, request: web.Request, handler): + if request.method == "POST": + body = await request.read() + if not is_valid_hmac( + self.options["auth_secret"].encode(), body, request.headers.get("X-FOREST-HMAC", "").encode("utf-8") + ): + return web.Response(status=401) + return await handler(request) + def setup_routes(self): + # self.app.middlewares.append(self.hmac_middleware) self.app.router.add_route("GET", "/sse", self.sse_handler) self.app.router.add_route("GET", "/schema", self.schema) self.app.router.add_route("POST", "/collection/list", self.collection_list) @@ -77,6 +88,12 @@ def setup_routes(self): self.app.router.add_route("POST", "/collection/update", self.collection_update) self.app.router.add_route("POST", "/collection/delete", self.collection_delete) self.app.router.add_route("POST", "/collection/aggregate", self.collection_aggregate) + self.app.router.add_route("POST", "/collection/get-form", self.collection_get_form) + self.app.router.add_route("POST", "/collection/execute", self.collection_execute) + self.app.router.add_route("POST", "/collection/render-chart", self.collection_render_chart) + + self.app.router.add_route("POST", "/execute-native-query", self.native_query) + self.app.router.add_route("POST", "/render-chart", self.render_chart) self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK")) async def sse_handler(self, request: web.Request) -> web.StreamResponse: @@ -103,11 +120,14 @@ async def collection_list(self, request: web.Request): records = await collection.list(caller, filter_, projection) records = [RecordSerializer.serialize(record) for record in records] - return web.json_response(records) + return web.Response(text=aes_encrypt(json.dumps(records), self.aes_key, 
self.aes_iv)) async def collection_create(self, request: web.Request): - body_params = await request.json() + body_params = await request.text() + body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) + body_params = json.loads(body_params) ds = await self.customizer.get_datasource() + collection = ds.get_collection(body_params["collectionName"]) caller = CallerSerializer.deserialize(body_params["caller"]) data = [RecordSerializer.deserialize(r, collection) for r in body_params["data"]] @@ -117,7 +137,10 @@ async def collection_create(self, request: web.Request): return web.json_response(records) async def collection_update(self, request: web.Request): - body_params = await request.json() + body_params = await request.text() + body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) + body_params = json.loads(body_params) + ds = await self.customizer.get_datasource() collection = ds.get_collection(body_params["collectionName"]) caller = CallerSerializer.deserialize(body_params["caller"]) @@ -147,7 +170,103 @@ async def collection_aggregate(self, request: web.Request): records = await collection.aggregate(caller, filter_, aggregation) # records = [RecordSerializer.serialize(record) for record in records] - return web.json_response(records) + return web.Response(text=aes_encrypt(json.dumps(records), self.aes_key, self.aes_iv)) + + async def collection_get_form(self, request: web.Request): + body_params = await request.text() + body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) + body_params = json.loads(body_params) + + ds = await self.customizer.get_datasource() + collection = ds.get_collection(body_params["collectionName"]) + + caller = CallerSerializer.deserialize(body_params["caller"]) if body_params["caller"] else None + action_name = body_params["actionName"] + filter_ = FilterSerializer.deserialize(body_params["filter"], collection) if body_params["filter"] else None + data = body_params["data"] + meta = body_params["meta"] 
+ + form = await collection.get_form(caller, action_name, data, filter_, meta) + return web.Response( + text=aes_encrypt(json.dumps(ActionFormSerializer.serialize(form)), self.aes_key, self.aes_iv) + ) + + async def collection_execute(self, request: web.Request): + body_params = await request.text() + body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) + body_params = json.loads(body_params) + + ds = await self.customizer.get_datasource() + collection = ds.get_collection(body_params["collectionName"]) + + caller = CallerSerializer.deserialize(body_params["caller"]) if body_params["caller"] else None + action_name = body_params["actionName"] + filter_ = FilterSerializer.deserialize(body_params["filter"], collection) if body_params["filter"] else None + data = body_params["data"] + + result = await collection.execute(caller, action_name, data, filter_) + return web.Response( + text=aes_encrypt(json.dumps(ActionResultSerializer.serialize(result)), self.aes_key, self.aes_iv) + ) + + async def collection_render_chart(self, request: web.Request): + body_params = await request.text() + body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) + body_params = json.loads(body_params) + + ds = await self.customizer.get_datasource() + collection = ds.get_collection(body_params["collectionName"]) + + caller = CallerSerializer.deserialize(body_params["caller"]) + name = body_params["name"] + record_id = body_params["recordId"] + ret = [] + for i, value in enumerate(record_id): + type_record_id = collection.schema["fields"][SchemaUtils.get_primary_keys(collection.schema)[i]][ + "column_type" + ] + + if type_record_id == PrimitiveType.DATE: + ret.append(value.fromisoformat()) + elif type_record_id == PrimitiveType.DATE_ONLY: + ret.append(value.fromisoformat()) + elif type_record_id == PrimitiveType.DATE: + ret.append(value.fromisoformat()) + elif type_record_id == PrimitiveType.POINT: + ret.append((value[0], value[1])) + elif type_record_id == 
PrimitiveType.UUID: + ret.append(UUID(value)) + else: + ret.append(value) + + result = await collection.render_chart(caller, name, record_id) + return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv)) + + async def render_chart(self, request: web.Request): + body_params = await request.text() + body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) + body_params = json.loads(body_params) + + ds = await self.customizer.get_datasource() + + caller = CallerSerializer.deserialize(body_params["caller"]) + name = body_params["name"] + + result = await ds.render_chart(caller, name) + return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv)) + + async def native_query(self, request: web.Request): + body_params = await request.text() + body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) + body_params = json.loads(body_params) + + ds = await self.customizer.get_datasource() + connection_name = body_params["connectionName"] + native_query = body_params["nativeQuery"] + parameters = body_params["parameters"] + + result = await ds.execute_native_query(connection_name, native_query, parameters) + return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv)) def start(self): web.run_app(self.app, host=self.listen_addr, port=int(self.listen_port)) diff --git a/src/agent_rpc/forestadmin/agent_rpc/options.py b/src/agent_rpc/forestadmin/agent_rpc/options.py index c67c921b..f94dbbf5 100644 --- a/src/agent_rpc/forestadmin/agent_rpc/options.py +++ b/src/agent_rpc/forestadmin/agent_rpc/options.py @@ -3,3 +3,5 @@ class RpcOptions(TypedDict): listen_addr: str + aes_key: bytes + aes_iv: bytes diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/collection.py b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py index 678be0fe..421ea00a 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/collection.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py @@ 
-2,14 +2,22 @@ from typing import Any, Dict, List, Optional import urllib3 +from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.collections import Collection from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.interfaces.actions import Action, ActionFormElement, ActionResult, ActionsScope +from forestadmin.datasource_toolkit.interfaces.chart import Chart +from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType from forestadmin.datasource_toolkit.interfaces.query.aggregation import AggregateResult, Aggregation from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter from forestadmin.datasource_toolkit.interfaces.query.projections import Projection from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias +from forestadmin.datasource_toolkit.utils.schema import SchemaUtils +from forestadmin.rpc_common.hmac import generate_hmac +from forestadmin.rpc_common.serializers.actions import ActionFormSerializer, ActionResultSerializer +from forestadmin.rpc_common.serializers.aes import aes_decrypt, aes_encrypt from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer from forestadmin.rpc_common.serializers.collection.filter import ( FilterSerializer, @@ -21,57 +29,235 @@ class RPCCollection(Collection): - def __init__(self, name: str, datasource: Datasource, connection_uri: str): + def __init__(self, name: str, datasource: Datasource, connection_uri: str, secret_key: str): super().__init__(name, datasource) self.connection_uri = connection_uri + self.secret_key = secret_key + self.aes_key = secret_key[:16].encode() + self.aes_iv = secret_key[-16:].encode() self.http = urllib3.PoolManager() + self._rpc_actions = {} + + def 
add_rpc_action(self, name: str, action: Dict[str, Any]) -> None: + if name in self._schema["actions"]: + raise ValueError(f"Action {name} already exists in collection {self.name}") + self._schema["actions"][name] = Action( + scope=ActionsScope(action["scope"]), + description=action.get("description"), + submit_button_label=action.get("submit_button_label"), + generate_file=action.get("generate_file", False), + # form=action.get("form"), + static_form=action["static_form"], + ) + self._rpc_actions[name] = {"form": action["form"]} async def list(self, caller: User, filter_: PaginatedFilter, projection: Projection) -> List[RecordsDataAlias]: - body = { - "caller": CallerSerializer.serialize(caller), - "filter": PaginatedFilterSerializer.serialize(filter_, self), - "projection": ProjectionSerializer.serialize(projection), - "collectionName": self.name, - } - response = self.http.request("POST", f"http://{self.connection_uri}/collection/list", body=json.dumps(body)) - return [RecordSerializer.deserialize(record, self) for record in json.loads(response.data.decode("utf-8"))] + body = json.dumps( + { + "caller": CallerSerializer.serialize(caller), + "filter": PaginatedFilterSerializer.serialize(filter_, self), + "projection": ProjectionSerializer.serialize(projection), + "collectionName": self.name, + } + ) + response = self.http.request( + "POST", + f"http://{self.connection_uri}/collection/list", + body=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) + ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) + + return [RecordSerializer.deserialize(record, self) for record in json.loads(ret)] async def create(self, caller: User, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - body = { - "caller": CallerSerializer.serialize(caller), - "data": [RecordSerializer.serialize(r) for r in data], - "collectionName": self.name, - } - response = self.http.request("POST", 
f"http://{self.connection_uri}/collection/create", body=json.dumps(body)) + body = json.dumps( + { + "caller": CallerSerializer.serialize(caller), + "data": [RecordSerializer.serialize(r) for r in data], + "collectionName": self.name, + } + ) + body = aes_encrypt(body, self.aes_key, self.aes_iv) + response = self.http.request( + "POST", + f"http://{self.connection_uri}/collection/create", + body=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) return [RecordSerializer.deserialize(record, self) for record in json.loads(response.data.decode("utf-8"))] async def update(self, caller: User, filter_: Optional[Filter], patch: Dict[str, Any]) -> None: - body = { - "caller": CallerSerializer.serialize(caller), - "filter": FilterSerializer.serialize(filter_, self), - "patch": RecordSerializer.serialize(patch), - "collectionName": self.name, - } - self.http.request("POST", f"http://{self.connection_uri}/collection/update", body=json.dumps(body)) + body = json.dumps( + { + "caller": CallerSerializer.serialize(caller), + "filter": FilterSerializer.serialize(filter_, self), + "patch": RecordSerializer.serialize(patch), + "collectionName": self.name, + } + ) + body = aes_encrypt(body, self.aes_key, self.aes_iv) + self.http.request( + "POST", + f"http://{self.connection_uri}/collection/update", + body=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) async def delete(self, caller: User, filter_: Filter | None) -> None: - body = { - "caller": CallerSerializer.serialize(caller), - "filter": FilterSerializer.serialize(filter_, self), - "collectionName": self.name, - } - self.http.request("POST", f"http://{self.connection_uri}/collection/delete", body=json.dumps(body)) + body = json.dumps( + { + "caller": CallerSerializer.serialize(caller), + "filter": FilterSerializer.serialize(filter_, self), + "collectionName": self.name, + } + ) + self.http.request( + "POST", + 
f"http://{self.connection_uri}/collection/delete", + body=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) async def aggregate( self, caller: User, filter_: Filter | None, aggregation: Aggregation, limit: int | None = None ) -> List[AggregateResult]: - body = { - "caller": CallerSerializer.serialize(caller), - "filter": FilterSerializer.serialize(filter_, self), - "aggregation": AggregationSerializer.serialize(aggregation), - "collectionName": self.name, - } + body = json.dumps( + { + "caller": CallerSerializer.serialize(caller), + "filter": FilterSerializer.serialize(filter_, self), + "aggregation": AggregationSerializer.serialize(aggregation), + "collectionName": self.name, + } + ) response = self.http.request( - "POST", f"http://{self.connection_uri}/collection/aggregate", body=json.dumps(body) + "POST", + f"http://{self.connection_uri}/collection/aggregate", + body=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) + ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) + return json.loads(ret) + + async def get_form( + self, + caller: User, + name: str, + data: Optional[RecordsDataAlias] = None, + filter_: Optional[Filter] = None, + meta: Optional[Dict[str, Any]] = None, + ) -> List[ActionFormElement]: + if name not in self._schema["actions"]: + raise ValueError(f"Action {name} does not exist in collection {self.name}") + + if self._schema["actions"][name].static_form: + return self._rpc_actions[name]["form"] + + body = aes_encrypt( + json.dumps( + { + "caller": CallerSerializer.serialize(caller) if caller is not None else None, + "filter": FilterSerializer.serialize(filter_, self) if filter_ is not None else None, + "data": data, + "meta": meta, + "collectionName": self.name, + "actionName": name, + } + ), + self.aes_key, + self.aes_iv, + ) + response = self.http.request( + "POST", + 
f"http://{self.connection_uri}/collection/get-form", + body=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) + ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) + ret = ActionFormSerializer.deserialize(json.loads(ret)) + return ret + + async def execute( + self, + caller: User, + name: str, + data: RecordsDataAlias, + filter_: Optional[Filter], + ) -> ActionResult: + if name not in self._schema["actions"]: + raise ValueError(f"Action {name} does not exist in collection {self.name}") + + body = aes_encrypt( + json.dumps( + { + "caller": CallerSerializer.serialize(caller) if caller is not None else None, + "filter": FilterSerializer.serialize(filter_, self) if filter_ is not None else None, + "data": data, + "collectionName": self.name, + "actionName": name, + } + ), + self.aes_key, + self.aes_iv, + ) + response = self.http.request( + "POST", + f"http://{self.connection_uri}/collection/execute", + body=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) + ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) + ret = json.loads(ret) + ret = ActionResultSerializer.deserialize(ret) + return ret + + async def render_chart(self, caller: User, name: str, record_id: List) -> Chart: + if name not in self._schema["charts"].keys(): + raise ValueError(f"Chart {name} does not exist in collection {self.name}") + + ret = [] + for i, value in enumerate(record_id): + type_record_id = self.schema["fields"][SchemaUtils.get_primary_keys(self.schema)[i]]["column_type"] + + if type_record_id == PrimitiveType.DATE: + ret.append(value.isoformat()) + elif type_record_id == PrimitiveType.DATE_ONLY: + ret.append(value.isoformat()) + elif type_record_id == PrimitiveType.DATE: + ret.append(value.isoformat()) + elif type_record_id == PrimitiveType.POINT: + ret.append((value[0], value[1])) + elif type_record_id == PrimitiveType.UUID: 
+ ret.append(str(value)) + else: + ret.append(value) + + body = aes_encrypt( + json.dumps( + { + "caller": CallerSerializer.serialize(caller) if caller is not None else None, + "name": name, + "collectionName": self.name, + "recordId": ret, + } + ), + self.aes_key, + self.aes_iv, + ) + response = self.http.request( + "POST", + f"http://{self.connection_uri}/collection/render-chart", + body=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) + ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) + ret = json.loads(ret) + return ret + + def get_native_driver(self): + ForestLogger.log( + "error", + "Get_native_driver is not available on RPCCollection. " + "Please use execute_native_query on the rpc datasource instead", ) - return json.loads(response.data.decode("utf-8")) + pass diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py index bfa71558..11a0bce5 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py @@ -4,6 +4,7 @@ import signal import time from threading import Thread, Timer +from typing import Any, Dict import grpc import urllib3 @@ -11,8 +12,13 @@ from forestadmin.datasource_rpc.collection import RPCCollection from forestadmin.datasource_rpc.reloadable_datasource import ReloadableDatasource from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.decorators.action.collections import ActionCollectionDecorator +from forestadmin.datasource_toolkit.interfaces.actions import Action, ActionsScope from forestadmin.datasource_toolkit.interfaces.chart import Chart +from forestadmin.rpc_common.hmac import generate_hmac +from forestadmin.rpc_common.serializers.aes import aes_decrypt, aes_encrypt from forestadmin.rpc_common.serializers.schema.schema import SchemaDeserializer +from 
forestadmin.rpc_common.serializers.utils import CallerSerializer from sseclient import SSEClient # from forestadmin.rpc_common.proto import datasource_pb2_grpc @@ -20,9 +26,12 @@ class RPCDatasource(Datasource, ReloadableDatasource): - def __init__(self, connection_uri: str): + def __init__(self, connection_uri: str, secret_key: str): super().__init__([]) self.connection_uri = connection_uri + self.secret_key = secret_key + self.aes_key = secret_key[:16].encode() + self.aes_iv = secret_key[-16:].encode() # res = asyncio.run(self.connect_sse()) # Timer(5, self.internal_reload).start() # self.trigger_reload() @@ -45,9 +54,17 @@ def introspect(self): self._live_query_connections = schema_data["live_query_connections"] def create_collection(self, collection_name, collection_schema): - collection = RPCCollection(collection_name, self, self.connection_uri) + collection = RPCCollection(collection_name, self, self.connection_uri, self.secret_key) for name, field in collection_schema["fields"].items(): collection.add_field(name, field) + + if len(collection_schema["actions"]) > 0: + # collection = ActionCollectionDecorator(collection, self) + for action_name, action_schema in collection_schema["actions"].items(): + collection.add_rpc_action(action_name, action_schema) + + collection._schema["charts"] = {name: None for name in collection_schema["charts"]} + self.add_collection(collection) def run(self): @@ -103,5 +120,49 @@ def internal_reload(self): # def reload_agent(self): # os.kill(os.getpid(), signal.SIGUSR1) + + async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any: + body = aes_encrypt( + json.dumps( + { + "connectionName": connection_name, + "nativeQuery": native_query, + "parameters": parameters, + } + ), + self.aes_key, + self.aes_iv, + ) + response = self.http.request( + "POST", + f"http://{self.connection_uri}/execute-native-query", + body=body, + headers={"X-FOREST-HMAC": 
generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) + ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) + ret = json.loads(ret) + return ret + async def render_chart(self, caller: User, name: str) -> Chart: - raise Exception("Not implemented") + if name not in self._schema["charts"].keys(): + raise ValueError(f"Chart {name} does not exist in this datasource") + + body = aes_encrypt( + json.dumps( + { + "caller": CallerSerializer.serialize(caller) if caller is not None else None, + "name": name, + } + ), + self.aes_key, + self.aes_iv, + ) + response = self.http.request( + "POST", + f"http://{self.connection_uri}/render-chart", + body=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) + ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) + ret = json.loads(ret) + return ret diff --git a/src/rpc_common/forestadmin/rpc_common/hmac.py b/src/rpc_common/forestadmin/rpc_common/hmac.py new file mode 100644 index 00000000..7348b5eb --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/hmac.py @@ -0,0 +1,11 @@ +import hashlib +import hmac + + +def generate_hmac(key: bytes, message: bytes) -> str: + return hmac.new(key, message, hashlib.sha256).hexdigest() + + +def is_valid_hmac(key: bytes, message: bytes, sign: bytes) -> bool: + expected_hmac = generate_hmac(key, message).encode("utf-8") + return hmac.compare_digest(expected_hmac, sign) diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/actions.py b/src/rpc_common/forestadmin/rpc_common/serializers/actions.py new file mode 100644 index 00000000..4919b1fb --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/serializers/actions.py @@ -0,0 +1,213 @@ +from base64 import b64decode, b64encode +from io import IOBase + +from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.interfaces.actions import ( + ActionFieldType, + ActionResult, + 
ActionsScope, + ErrorResult, + FileResult, + RedirectResult, + SuccessResult, + WebHookResult, +) +from forestadmin.rpc_common.serializers.utils import camel_to_snake_case, enum_to_str_or_value, snake_to_camel_case + + +class ActionSerializer: + @staticmethod + async def serialize(action_name: str, collection: Collection) -> dict: + action = collection.schema["actions"][action_name] + if not action.static_form: + form = None + else: + # TODO: a dummy field here ?? + form = await collection.get_form(None, action_name, None, None) # type:ignore + # fields, layout = SchemaActionGenerator.extract_fields_and_layout(form) + # fields = [ + # await SchemaActionGenerator.build_field_schema(collection.datasource, field) for field in fields + # ] + + return { + "scope": action.scope.value, + "generateFile": action.generate_file or False, + "staticForm": action.static_form or False, + "description": action.description, + "submitButtonLabel": action.submit_button_label, + "form": ActionFormSerializer.serialize(form) if form is not None else [], + } + + @staticmethod + def deserialize(action: dict) -> dict: + return { + "scope": ActionsScope(action["scope"]), + "generate_file": action["generateFile"], + "static_form": action["staticForm"], + "description": action["description"], + "submit_button_label": action["submitButtonLabel"], + "form": ActionFormSerializer.deserialize(action["form"]) if action["form"] is not None else [], + } + + +class ActionFormSerializer: + @staticmethod + def serialize(form) -> list[dict]: + serialized_form = [] + + for field in form: + if field["type"] == ActionFieldType.LAYOUT: + if field["component"] == "Page": + serialized_form.append( + { + **field, + "type": "Layout", + "elements": ActionFormSerializer.serialize(field["elements"]), + } + ) + + if field["component"] == "Row": + serialized_form.append( + { + **field, + "type": "Layout", + "fields": ActionFormSerializer.serialize(field["fields"]), + } + ) + else: + serialized_form.append( + { + 
**{snake_to_camel_case(k): v for k, v in field.items()}, + "type": enum_to_str_or_value(field["type"]), + } + ) + + return serialized_form + + @staticmethod + def deserialize(form: list) -> list[dict]: + deserialized_form = [] + + for field in form: + if field["type"] == "Layout": + if field["component"] == "Page": + deserialized_form.append( + { + **field, + "type": ActionFieldType("Layout"), + "elements": ActionFormSerializer.deserialize(field["elements"]), + } + ) + + if field["component"] == "Row": + deserialized_form.append( + { + **field, + "type": ActionFieldType("Layout"), + "fields": ActionFormSerializer.deserialize(field["fields"]), + } + ) + else: + deserialized_form.append( + { + **{camel_to_snake_case(k): v for k, v in field.items()}, + "type": ActionFieldType(field["type"]), + } + ) + + return deserialized_form + + +class ActionResultSerializer: + @staticmethod + def serialize(action_result: dict) -> dict: + ret = { + "type": action_result["type"], + "response_headers": action_result["response_headers"], + } + match action_result["type"]: + case "Success": + ret.update( + { + "message": action_result["message"], + "format": action_result["format"], + "invalidated": list(action_result["invalidated"]), + } + ) + case "Error": + ret.update({"message": action_result["message"], "format": action_result["format"]}) + case "Webhook": + ret.update( + { + "url": action_result["url"], + "method": action_result["method"], + "headers": action_result["headers"], + "body": action_result["body"], + } + ) + case "File": + b64_encoded = False + content = action_result["stream"] + if isinstance(content, IOBase): + content = content.read() + if isinstance(content, bytes): + content = b64encode(content).decode("utf-8") + b64_encoded = True + ret.update( + { + "mimeType": action_result["mimeType"], + "name": action_result["name"], + "stream": content, + "b64Encoded": b64_encoded, + } + ) + case "Redirect": + ret.update({"path": action_result["path"]}) + return ret + + 
@staticmethod + def deserialize(action_result: dict) -> ActionResult: + match action_result["type"]: + case "Success": + return SuccessResult( + type="Success", + message=action_result["message"], + format=action_result["format"], + invalidated=set(action_result["invalidated"]), + response_headers=action_result["response_headers"], + ) + case "Error": + return ErrorResult( + type="Error", + message=action_result["message"], + format=action_result["format"], + response_headers=action_result["response_headers"], + ) + case "Webhook": + return WebHookResult( + type="Webhook", + url=action_result["url"], + method=action_result["method"], + headers=action_result["headers"], + body=action_result["body"], + response_headers=action_result["response_headers"], + ) + case "File": + stream = action_result["stream"] + if action_result["b64Encoded"]: + stream = b64decode(stream.encode("utf-8")) + return FileResult( + type="File", + mimeType=action_result["mimeType"], + name=action_result["name"], + stream=stream, + response_headers=action_result["response_headers"], + ) + case "Redirect": + return RedirectResult( + type="Redirect", + path=action_result["path"], + response_headers=action_result["response_headers"], + ) + case _: + raise ValueError(f"Invalid action result type: {action_result['type']}") diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/aes.py b/src/rpc_common/forestadmin/rpc_common/serializers/aes.py new file mode 100644 index 00000000..94916d8d --- /dev/null +++ b/src/rpc_common/forestadmin/rpc_common/serializers/aes.py @@ -0,0 +1,22 @@ +import base64 + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + + +def aes_encrypt(message, aes_key: bytes, aes_iv: bytes): + cipher = Cipher(algorithms.AES(aes_key), modes.CBC(aes_iv), backend=default_backend()) + encryptor = cipher.encryptor() + + # Pad message to 16-byte block size + padded_message = message.ljust(16 * 
((len(message) // 16) + 1), "\0") + + encrypted = encryptor.update(padded_message.encode()) + encryptor.finalize() + return base64.b64encode(encrypted).decode() + + +def aes_decrypt(encrypted_message, aes_key: bytes, aes_iv: bytes): + cipher = Cipher(algorithms.AES(aes_key), modes.CBC(aes_iv), backend=default_backend()) + decryptor = cipher.decryptor() + decrypted_padded = decryptor.update(base64.b64decode(encrypted_message)) + decryptor.finalize() + return decrypted_padded.rstrip(b"\0").decode() diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py b/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py index bc03b652..1534204c 100644 --- a/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py +++ b/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py @@ -4,7 +4,6 @@ from forestadmin.datasource_toolkit.collections import Collection from forestadmin.datasource_toolkit.datasources import Datasource -from forestadmin.datasource_toolkit.interfaces.actions import ActionFieldType, ActionsScope from forestadmin.datasource_toolkit.interfaces.fields import ( Column, FieldAlias, @@ -20,7 +19,8 @@ is_polymorphic_one_to_many, is_polymorphic_one_to_one, ) -from forestadmin.rpc_common.serializers.utils import OperatorSerializer, enum_to_str_or_value, snake_to_camel_case +from forestadmin.rpc_common.serializers.actions import ActionSerializer +from forestadmin.rpc_common.serializers.utils import OperatorSerializer, enum_to_str_or_value def serialize_column_type(column_type: Union[Enum, Any]) -> Union[str, dict, list]: @@ -70,7 +70,7 @@ async def _serialize_collection(self, collection: Collection) -> dict: for field in sorted(collection.schema["fields"].keys()) }, "actions": { - action_name: await self._serialize_action(action_name, collection) + action_name: await ActionSerializer.serialize(action_name, collection) for action_name in sorted(collection.schema["actions"].keys()) }, "segments": 
sorted(collection.schema["segments"]), @@ -162,50 +162,6 @@ async def _serialize_relation(self, relation: RelationAlias) -> dict: ) return serialized_field - async def _serialize_action(self, action_name: str, collection: Collection) -> dict: - action = collection.schema["actions"][action_name] - if not action.static_form: - form = None - else: - form = await collection.get_form(None, action_name, None, None) # type:ignore - # fields, layout = SchemaActionGenerator.extract_fields_and_layout(form) - # fields = [ - # await SchemaActionGenerator.build_field_schema(collection.datasource, field) for field in fields - # ] - - return { - "scope": action.scope.value, - "generateFile": action.generate_file or False, - "staticForm": action.static_form or False, - "description": action.description, - "submitButtonLabel": action.submit_button_label, - "form": await self._serialize_action_form(form) if form is not None else None, - } - - async def _serialize_action_form(self, form) -> list[dict]: - serialized_form = [] - - for field in form: - if field["type"] == ActionFieldType.LAYOUT: - if field["component"] == "Page": - serialized_form.append( - {**field, "type": "Layout", "elements": await self._serialize_action_form(field["elements"])} - ) - - if field["component"] == "Row": - serialized_form.append( - {**field, "type": "Layout", "fields": await self._serialize_action_form(field["fields"])} - ) - else: - serialized_form.append( - { - **{snake_to_camel_case(k): v for k, v in field.items()}, - "type": enum_to_str_or_value(field["type"]), - } - ) - - return serialized_form - class SchemaDeserializer: VERSION = "1.0" @@ -228,7 +184,8 @@ def _deserialize_collection(self, collection: dict) -> dict: return { "fields": {field: self._deserialize_field(collection["fields"][field]) for field in collection["fields"]}, "actions": { - action_name: self._deserialize_action(action) for action_name, action in collection["actions"].items() + action_name: ActionSerializer.deserialize(action) 
+ for action_name, action in collection["actions"].items() }, "segments": collection["segments"], "countable": collection["countable"], @@ -308,45 +265,3 @@ def _deserialize_relation(self, relation: dict) -> dict: "through_collection": relation["throughCollection"], } raise ValueError(f"Unsupported relation type {relation['type']}") - - def _deserialize_action(self, action: dict) -> dict: - return { - "scope": ActionsScope(action["scope"]), - "generate_file": action["generateFile"], - "static_form": action["staticForm"], - "description": action["description"], - "submit_button_label": action["submitButtonLabel"], - "form": self._deserialize_action_form(action["form"]) if action["form"] is not None else None, - } - - def _deserialize_action_form(self, form: list) -> list[dict]: - deserialized_form = [] - - for field in form: - if field["type"] == "Layout": - if field["component"] == "Page": - deserialized_form.append( - { - **field, - "type": ActionFieldType("Layout"), - "elements": self._deserialize_action_form(field["elements"]), - } - ) - - if field["component"] == "Row": - deserialized_form.append( - { - **field, - "type": ActionFieldType("Layout"), - "fields": self._deserialize_action_form(field["fields"]), - } - ) - else: - deserialized_form.append( - { - **{snake_to_camel_case(k): v for k, v in field.items()}, - "type": ActionFieldType(field["type"]), - } - ) - - return deserialized_form From e2856c386cc55bd27f7da5d9319ffe3d8c678820 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 25 Mar 2025 11:29:27 +0100 Subject: [PATCH 13/34] chore: more tolerant with aggregation typing --- .../datasource_toolkit/interfaces/query/aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py index 80424a27..b410f1d2 100644 --- 
a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py @@ -65,7 +65,7 @@ class AggregationGroup(TypedDict): class PlainAggregation(TypedDict): field: NotRequired[Optional[str]] - operation: PlainAggregator + operation: Union[PlainAggregator, Aggregator] groups: NotRequired[List[PlainAggregationGroup]] From bb2bcc516b37b19f5241d046e83f66df3cfc0835 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 25 Mar 2025 11:37:29 +0100 Subject: [PATCH 14/34] fix(django): don't create agent in the case it won't start --- .../forestadmin/django_agent/apps.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/django_agent/forestadmin/django_agent/apps.py b/src/django_agent/forestadmin/django_agent/apps.py index 0450fabb..e3509338 100644 --- a/src/django_agent/forestadmin/django_agent/apps.py +++ b/src/django_agent/forestadmin/django_agent/apps.py @@ -8,6 +8,7 @@ from django.conf import settings from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.datasource_django.datasource import DjangoDatasource +from forestadmin.datasource_toolkit.utils.user_callable import call_user_function from forestadmin.django_agent.agent import DjangoAgent, create_agent @@ -24,17 +25,16 @@ def is_launch_as_server() -> bool: def init_app_agent() -> Optional[DjangoAgent]: if not is_launch_as_server(): return None - agent = create_agent() if not hasattr(settings, "FOREST_AUTO_ADD_DJANGO_DATASOURCE") or settings.FOREST_AUTO_ADD_DJANGO_DATASOURCE: - agent.add_datasource(DjangoDatasource()) + DjangoAgentApp._DJANGO_AGENT.add_datasource(DjangoDatasource()) customize_fn = getattr(settings, "FOREST_CUSTOMIZE_FUNCTION", None) if customize_fn: - agent = _call_user_customize_function(customize_fn, agent) + _call_user_customize_function(customize_fn, DjangoAgentApp._DJANGO_AGENT) - if agent: - agent.start() - return 
agent + if DjangoAgentApp._DJANGO_AGENT: + DjangoAgentApp._DJANGO_AGENT.start() + return DjangoAgentApp._DJANGO_AGENT def _call_user_customize_function(customize_fn: Union[str, Callable[[DjangoAgent], None]], agent: DjangoAgent): @@ -44,21 +44,16 @@ def _call_user_customize_function(customize_fn: Union[str, Callable[[DjangoAgent module = importlib.import_module(module_name) customize_fn = getattr(module, fn_name) except Exception as exc: - ForestLogger.log("error", f"cannot import {customize_fn} : {exc}. Quitting forest.") + ForestLogger.log("exception", f"cannot import {customize_fn} : {exc}. Quitting forest.") return - - if callable(customize_fn): try: - ret = customize_fn(agent) - if asyncio.iscoroutine(ret): - agent.loop.run_until_complete(ret) + agent.loop.run_until_complete(call_user_function(customize_fn, agent)) except Exception as exc: ForestLogger.log( - "error", + "exception", f'error executing "FOREST_CUSTOMIZE_FUNCTION" ({customize_fn}): {exc}. Quitting forest.', ) return - return agent class DjangoAgentApp(AppConfig): @@ -81,6 +76,9 @@ def get_agent(cls) -> DjangoAgent: def ready(self): # we need to wait for other apps to be ready, for this forest app must be ready # that's why we need another thread waiting for every app to be ready + if not is_launch_as_server(): + return + DjangoAgentApp._DJANGO_AGENT = create_agent() t = threading.Thread(name="forest.wait_and_launch_agent", target=self._wait_for_all_apps_ready_and_launch_agent) t.start() From dadea2db630e7762cc1f268af192ffeecbd96603 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 25 Mar 2025 15:37:18 +0100 Subject: [PATCH 15/34] chore: fix linting --- src/django_agent/forestadmin/django_agent/apps.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/django_agent/forestadmin/django_agent/apps.py b/src/django_agent/forestadmin/django_agent/apps.py index e3509338..18ca4968 100644 --- a/src/django_agent/forestadmin/django_agent/apps.py +++ 
b/src/django_agent/forestadmin/django_agent/apps.py @@ -1,4 +1,3 @@ -import asyncio import importlib import sys import threading From e104ee34d3475d808aff70995fe5cb06973c742b Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 26 Mar 2025 16:44:24 +0100 Subject: [PATCH 16/34] chore: add serializer for files in actions --- .../forestadmin/datasource_rpc/collection.py | 10 +- .../rpc_common/serializers/actions.py | 112 ++++++++++++++++-- 2 files changed, 107 insertions(+), 15 deletions(-) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/collection.py b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py index 421ea00a..96c9bc14 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/collection.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py @@ -16,7 +16,11 @@ from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias from forestadmin.datasource_toolkit.utils.schema import SchemaUtils from forestadmin.rpc_common.hmac import generate_hmac -from forestadmin.rpc_common.serializers.actions import ActionFormSerializer, ActionResultSerializer +from forestadmin.rpc_common.serializers.actions import ( + ActionFormSerializer, + ActionFormValuesSerializer, + ActionResultSerializer, +) from forestadmin.rpc_common.serializers.aes import aes_decrypt, aes_encrypt from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer from forestadmin.rpc_common.serializers.collection.filter import ( @@ -158,7 +162,7 @@ async def get_form( { "caller": CallerSerializer.serialize(caller) if caller is not None else None, "filter": FilterSerializer.serialize(filter_, self) if filter_ is not None else None, - "data": data, + "data": ActionFormValuesSerializer.serialize(data), "meta": meta, "collectionName": self.name, "actionName": name, @@ -192,7 +196,7 @@ async def execute( { "caller": CallerSerializer.serialize(caller) if caller is not None else None, "filter": FilterSerializer.serialize(filter_, 
self) if filter_ is not None else None, - "data": data, + "data": ActionFormValuesSerializer.serialize(data), "collectionName": self.name, "actionName": name, } diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/actions.py b/src/rpc_common/forestadmin/rpc_common/serializers/actions.py index 4919b1fb..748b89b0 100644 --- a/src/rpc_common/forestadmin/rpc_common/serializers/actions.py +++ b/src/rpc_common/forestadmin/rpc_common/serializers/actions.py @@ -7,6 +7,7 @@ ActionResult, ActionsScope, ErrorResult, + File, FileResult, RedirectResult, SuccessResult, @@ -75,15 +76,73 @@ def serialize(form) -> list[dict]: } ) else: - serialized_form.append( - { - **{snake_to_camel_case(k): v for k, v in field.items()}, - "type": enum_to_str_or_value(field["type"]), - } - ) + tmp_field = { + **{snake_to_camel_case(k): v for k, v in field.items()}, + "type": enum_to_str_or_value(field["type"]), + } + if field["type"] == ActionFieldType.FILE: + tmp_field.update(ActionFormSerializer._serialize_file_field(tmp_field)) + print(tmp_field.keys()) + if field["type"] == ActionFieldType.FILE_LIST: + tmp_field.update(ActionFormSerializer._serialize_file_list_field(tmp_field)) + serialized_form.append(tmp_field) return serialized_form + @staticmethod + def _serialize_file_list_field(field: dict) -> dict: + ret = {} + if field.get("defaultValue"): + ret["defaultValue"] = [ActionFormSerializer._serialize_file_obj(v) for v in field["defaultValue"]] + if field.get("value"): + ret["value"] = [ActionFormSerializer._serialize_file_obj(v) for v in field["value"]] + return ret + + @staticmethod + def _serialize_file_field(field: dict) -> dict: + ret = {} + if field.get("defaultValue"): + ret["defaultValue"] = ActionFormSerializer._serialize_file_obj(field["defaultValue"]) + if field.get("value"): + ret["value"] = ActionFormSerializer._serialize_file_obj(field["value"]) + return ret + + @staticmethod + def _serialize_file_obj(f: File) -> dict: + return { + "mimeType": f.mime_type, + "name": 
f.name, + "stream": b64encode(f.buffer).decode("utf-8"), + "charset": f.charset, + } + + @staticmethod + def _deserialize_file_obj(f: dict) -> File: + return File( + mime_type=f["mimeType"], + name=f["name"], + buffer=b64decode(f["stream"].encode("utf-8")), + charset=f["charset"], + ) + + @staticmethod + def _deserialize_file_field(field: dict) -> dict: + ret = {} + if field.get("default_value"): + ret["default_value"] = ActionFormSerializer._deserialize_file_obj(field["default_value"]) + if field.get("value"): + ret["value"] = ActionFormSerializer._deserialize_file_obj(field["value"]) + return ret + + @staticmethod + def _deserialize_file_list_field(field: dict) -> dict: + ret = {} + if field.get("default_value"): + ret["default_value"] = [ActionFormSerializer._deserialize_file_obj(v) for v in field["default_value"]] + if field.get("value"): + ret["value"] = [ActionFormSerializer._deserialize_file_obj(v) for v in field["value"]] + return ret + @staticmethod def deserialize(form: list) -> list[dict]: deserialized_form = [] @@ -108,16 +167,45 @@ def deserialize(form: list) -> list[dict]: } ) else: - deserialized_form.append( - { - **{camel_to_snake_case(k): v for k, v in field.items()}, - "type": ActionFieldType(field["type"]), - } - ) + tmp_field = { + **{camel_to_snake_case(k): v for k, v in field.items()}, + "type": ActionFieldType(field["type"]), + } + if tmp_field["type"] == ActionFieldType.FILE: + tmp_field.update(ActionFormSerializer._deserialize_file_field(tmp_field)) + if tmp_field["type"] == ActionFieldType.FILE_LIST: + tmp_field.update(ActionFormSerializer._deserialize_file_list_field(tmp_field)) + deserialized_form.append(tmp_field) return deserialized_form +class ActionFormValuesSerializer: + @staticmethod + def serialize(form_values: dict) -> dict: + ret = {} + for key, value in form_values.items(): + if isinstance(value, File): + ret[key] = ActionFormSerializer._serialize_file_obj(value) + elif isinstance(value, list) and all(isinstance(v, File) for v 
in value): + ret[key] = [ActionFormSerializer._serialize_file_obj(v) for v in value] + else: + ret[key] = value + return ret + + @staticmethod + def deserialize(form_values: dict) -> dict: + ret = {} + for key, value in form_values.items(): + if isinstance(value, dict) and "mimeType" in value: + ret[key] = ActionFormSerializer._deserialize_file_obj(value) + elif isinstance(value, list) and all(isinstance(v, dict) and "mimeType" in v for v in value): + ret[key] = [ActionFormSerializer._deserialize_file_obj(v) for v in value] + else: + ret[key] = value + return ret + + class ActionResultSerializer: @staticmethod def serialize(action_result: dict) -> dict: From 696837534a57bd6924d544838b84111319078a99 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 26 Mar 2025 16:44:53 +0100 Subject: [PATCH 17/34] chore: simplify the reloading --- .../forestadmin/datasource_rpc/datasource.py | 109 ++++++++---------- .../reload_datasource_decorator.py | 63 ---------- .../datasource_rpc/reloadable_datasource.py | 3 - 3 files changed, 47 insertions(+), 128 deletions(-) delete mode 100644 src/datasource_rpc/forestadmin/datasource_rpc/reload_datasource_decorator.py delete mode 100644 src/datasource_rpc/forestadmin/datasource_rpc/reloadable_datasource.py diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py index 11a0bce5..0c849484 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py @@ -1,4 +1,5 @@ import asyncio +import hashlib import json import os import signal @@ -8,50 +9,58 @@ import grpc import urllib3 +from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_rpc.collection import RPCCollection -from forestadmin.datasource_rpc.reloadable_datasource import ReloadableDatasource from 
forestadmin.datasource_toolkit.datasources import Datasource -from forestadmin.datasource_toolkit.decorators.action.collections import ActionCollectionDecorator -from forestadmin.datasource_toolkit.interfaces.actions import Action, ActionsScope from forestadmin.datasource_toolkit.interfaces.chart import Chart +from forestadmin.datasource_toolkit.utils.user_callable import call_user_function from forestadmin.rpc_common.hmac import generate_hmac from forestadmin.rpc_common.serializers.aes import aes_decrypt, aes_encrypt from forestadmin.rpc_common.serializers.schema.schema import SchemaDeserializer from forestadmin.rpc_common.serializers.utils import CallerSerializer from sseclient import SSEClient -# from forestadmin.rpc_common.proto import datasource_pb2_grpc -# from google.protobuf import empty_pb2 +def hash_schema(schema) -> str: + return hashlib.sha256(json.dumps(schema, sort_keys=True).encode()).hexdigest() -class RPCDatasource(Datasource, ReloadableDatasource): - def __init__(self, connection_uri: str, secret_key: str): + +class RPCDatasource(Datasource): + def __init__(self, connection_uri: str, secret_key: str, reload_method=None): super().__init__([]) self.connection_uri = connection_uri + self.reload_method = reload_method self.secret_key = secret_key self.aes_key = secret_key[:16].encode() self.aes_iv = secret_key[-16:].encode() - # res = asyncio.run(self.connect_sse()) - # Timer(5, self.internal_reload).start() - # self.trigger_reload() + self.last_schema_hash = "" + self.http = urllib3.PoolManager() - self.wait_for_reconnect() + self.wait_for_connection() self.introspect() - self.thread = Thread(target=self.run, name="RPCDatasourceSSEThread", daemon=True) + + self.thread = Thread(target=asyncio.run, args=(self.run(),), name="RPCDatasourceSSEThread", daemon=True) self.thread.start() - def introspect(self): + def introspect(self) -> bool: + """return true if schema has changed""" + response = self.http.request("GET", 
f"http://{self.connection_uri}/schema") + schema_data = json.loads(response.data.decode("utf-8")) + if self.last_schema_hash == hash_schema(schema_data): + ForestLogger.log("debug", "[RPCDatasource] Schema has not changed") + return False + self.last_schema_hash = hash_schema(schema_data) + self._collections = {} self._schema = {"charts": {}} - response = self.http.request("GET", f"http://{self.connection_uri}/schema") - # schema_data = json.loads(response.data.decode("utf-8")) - schema_data = SchemaDeserializer().deserialize(response.data.decode("utf-8")) + schema_data = SchemaDeserializer().deserialize(schema_data) for collection_name, collection in schema_data["collections"].items(): self.create_collection(collection_name, collection) self._schema["charts"] = {name: None for name in schema_data["charts"]} self._live_query_connections = schema_data["live_query_connections"] + return True def create_collection(self, collection_name, collection_schema): collection = RPCCollection(collection_name, self, self.connection_uri, self.secret_key) @@ -59,7 +68,6 @@ def create_collection(self, collection_name, collection_schema): collection.add_field(name, field) if len(collection_schema["actions"]) > 0: - # collection = ActionCollectionDecorator(collection, self) for action_name, action_schema in collection_schema["actions"].items(): collection.add_rpc_action(action_name, action_schema) @@ -67,60 +75,37 @@ def create_collection(self, collection_name, collection_schema): self.add_collection(collection) - def run(self): - self.wait_for_reconnect() + def wait_for_connection(self): + while True: + try: + self.http.request("GET", f"http://{self.connection_uri}/") + break + except Exception: + time.sleep(1) + ForestLogger.log("debug", "Connection to RPC datasource established") + + async def internal_reload(self): + has_changed = self.introspect() + if has_changed and self.reload_method is not None: + await call_user_function(self.reload_method) + + async def run(self): + 
self.wait_for_connection() self.sse_client = SSEClient( self.http.request("GET", f"http://{self.connection_uri}/sse", preload_content=False) ) try: for msg in self.sse_client.events(): - if msg.event == "heartbeat": - continue - - if msg.event == "RpcServerStop": - print("RpcServerStop") - break - except: + pass + except Exception: pass - print("rpc connection to server closed") - self.wait_for_reconnect() - self.internal_reload() + ForestLogger.log("info", "rpc connection to server closed") + self.wait_for_connection() + await self.internal_reload() - self.thread = Thread(target=self.run, name="RPCDatasourceSSEThread", daemon=True) + self.thread = Thread(target=asyncio.run, args=(self.run(),), name="RPCDatasourceSSEThread", daemon=True) self.thread.start() - # self.channel = grpc.aio.insecure_channel(self.connection_uri) - # async with grpc.aio.insecure_channel(self.connection_uri) as channel: - # stub = datasource_pb2_grpc.DataSourceStub(self.channel) - # response = await stub.Schema(empty_pb2.Empty()) - # self.channel - - return - - def wait_for_reconnect(self): - while True: - try: - self.ping_connect() - break - except: - time.sleep(1) - print("reconntected") - - def connect(self): - self.ping_connect() - - def ping_connect(self): - self.http.request("GET", f"http://{self.connection_uri}/") - - def internal_reload(self): - # Timer(5, self.internal_reload).start() - print("trigger reload") - self.introspect() - self.trigger_reload() - - # def reload_agent(self): - # os.kill(os.getpid(), signal.SIGUSR1) - async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any: body = aes_encrypt( json.dumps( diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/reload_datasource_decorator.py b/src/datasource_rpc/forestadmin/datasource_rpc/reload_datasource_decorator.py deleted file mode 100644 index 687598ef..00000000 --- a/src/datasource_rpc/forestadmin/datasource_rpc/reload_datasource_decorator.py +++ /dev/null @@ 
-1,63 +0,0 @@ -import asyncio -import hashlib -import json - -# from dictdiffer import diff as differ -from forestadmin.agent_toolkit.forest_logger import ForestLogger -from forestadmin.datasource_rpc.reloadable_datasource import ReloadableDatasource -from forestadmin.datasource_toolkit.decorators.collection_decorator import CollectionDecorator -from forestadmin.datasource_toolkit.decorators.datasource_decorator import DatasourceDecorator - - -class ReloadDatasourceDecorator(DatasourceDecorator): - def __init__(self, reload_method, datasource): - self.reload_method = reload_method - super().__init__(datasource, CollectionDecorator) - self.current_schema_hash = self.hash_current_schema() - self.current_schema = { - "charts": self.child_datasource.schema["charts"], - "live_query_connections": self.child_datasource.get_native_query_connections(), - "collections": [c.schema for c in sorted(self.child_datasource.collections, key=lambda c: c.name)], - } - self.current_schema_str = json.dumps( - { - "charts": self.child_datasource.schema["charts"], - "live_query_connections": self.child_datasource.get_native_query_connections(), - "collections": [c.schema for c in sorted(self.child_datasource.collections, key=lambda c: c.name)], - }, - default=str, - ) - if not isinstance(self.child_datasource, ReloadableDatasource): - raise ValueError("The child datasource must be a ReloadableDatasource") - self.child_datasource.trigger_reload = self.reload - - def reload(self): - # diffs = differ( - # self.current_schema, - # { - # "charts": self.child_datasource.schema["charts"], - # "live_query_connections": self.child_datasource.get_native_query_connections(), - # "collections": [c.schema for c in sorted(self.child_datasource.collections, key=lambda c: c.name)], - # }, - # ) - if self.current_schema_hash == self.hash_current_schema(): - ForestLogger.log("info", "Schema has not changed, skipping reload") - return - - if asyncio.iscoroutinefunction(self.reload_method): - 
asyncio.run(self.reload_method()) - else: - self.reload_method() - - def hash_current_schema(self): - h = hashlib.shake_256( - json.dumps( - { - "charts": self.child_datasource.schema["charts"], - "live_query_connections": self.child_datasource.get_native_query_connections(), - "collections": [c.schema for c in sorted(self.child_datasource.collections, key=lambda c: c.name)], - }, - default=str, - ).encode("utf-8") - ) - return h.hexdigest(20) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/reloadable_datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/reloadable_datasource.py deleted file mode 100644 index 5ec1e6cf..00000000 --- a/src/datasource_rpc/forestadmin/datasource_rpc/reloadable_datasource.py +++ /dev/null @@ -1,3 +0,0 @@ -class ReloadableDatasource: - def trigger_reload(self): - pass From aedaf69444c6f62ab29f656a0094413fe1af40ac Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 26 Mar 2025 16:45:15 +0100 Subject: [PATCH 18/34] chore: cleanup old grpc files --- .../rpc_common/proto/collection.proto | 26 --- .../rpc_common/proto/collection_pb2.py | 43 ---- .../rpc_common/proto/collection_pb2.pyi | 27 --- .../rpc_common/proto/collection_pb2_grpc.py | 143 -------------- .../rpc_common/proto/datasource.proto | 30 --- .../rpc_common/proto/datasource_pb2.py | 49 ----- .../rpc_common/proto/datasource_pb2.pyi | 42 ---- .../rpc_common/proto/datasource_pb2_grpc.py | 186 ------------------ .../forestadmin/rpc_common/proto/forest.proto | 49 ----- .../rpc_common/proto/forest_pb2.py | 56 ------ .../rpc_common/proto/forest_pb2.pyi | 82 -------- .../rpc_common/proto/forest_pb2_grpc.py | 24 --- src/rpc_common/pyproject.toml | 6 - 13 files changed, 763 deletions(-) delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/collection.proto delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.py delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.pyi delete mode 100644 
src/rpc_common/forestadmin/rpc_common/proto/collection_pb2_grpc.py delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/datasource.proto delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.py delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.pyi delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2_grpc.py delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/forest.proto delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.py delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.pyi delete mode 100644 src/rpc_common/forestadmin/rpc_common/proto/forest_pb2_grpc.py diff --git a/src/rpc_common/forestadmin/rpc_common/proto/collection.proto b/src/rpc_common/forestadmin/rpc_common/proto/collection.proto deleted file mode 100644 index 5dc03bd6..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/collection.proto +++ /dev/null @@ -1,26 +0,0 @@ -syntax = "proto3"; - -package datasource_rpc; -import "google/protobuf/any.proto"; -import "google/protobuf/struct.proto"; -import "forestadmin/rpc_common/proto/forest.proto"; - - - -service Collection { - // Define your RPC methods here - rpc RenderChart (CollectionChartRequest) returns (google.protobuf.Any); - // crud - rpc Create (CreateRequest) returns (google.protobuf.Any); -} - -message CollectionChartRequest{ - Caller caller = 1; - string ChartName = 2; - google.protobuf.Any RecordId = 3; -} - -message CreateRequest{ - Caller Caller = 1; - repeated google.protobuf.Struct RecordData = 2; -} \ No newline at end of file diff --git a/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.py b/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.py deleted file mode 100644 index ddd04935..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# NO CHECKED-IN PROTOBUF GENCODE -# source: forestadmin/rpc_common/proto/collection.proto -# Protobuf Python Version: 5.29.0 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 29, - 0, - '', - 'forestadmin/rpc_common/proto/collection.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -from forestadmin.rpc_common.proto import forest_pb2 as forestadmin_dot_rpc__common_dot_proto_dot_forest__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n-forestadmin/rpc_common/proto/collection.proto\x12\x0e\x64\x61tasource_rpc\x1a\x19google/protobuf/any.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a)forestadmin/rpc_common/proto/forest.proto\"{\n\x16\x43ollectionChartRequest\x12&\n\x06\x63\x61ller\x18\x01 \x01(\x0b\x32\x16.datasource_rpc.Caller\x12\x11\n\tChartName\x18\x02 \x01(\t\x12&\n\x08RecordId\x18\x03 \x01(\x0b\x32\x14.google.protobuf.Any\"d\n\rCreateRequest\x12&\n\x06\x43\x61ller\x18\x01 \x01(\x0b\x32\x16.datasource_rpc.Caller\x12+\n\nRecordData\x18\x02 \x03(\x0b\x32\x17.google.protobuf.Struct2\x98\x01\n\nCollection\x12K\n\x0bRenderChart\x12&.datasource_rpc.CollectionChartRequest\x1a\x14.google.protobuf.Any\x12=\n\x06\x43reate\x12\x1d.datasource_rpc.CreateRequest\x1a\x14.google.protobuf.Anyb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'forestadmin.rpc_common.proto.collection_pb2', 
_globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_COLLECTIONCHARTREQUEST']._serialized_start=165 - _globals['_COLLECTIONCHARTREQUEST']._serialized_end=288 - _globals['_CREATEREQUEST']._serialized_start=290 - _globals['_CREATEREQUEST']._serialized_end=390 - _globals['_COLLECTION']._serialized_start=393 - _globals['_COLLECTION']._serialized_end=545 -# @@protoc_insertion_point(module_scope) diff --git a/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.pyi b/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.pyi deleted file mode 100644 index 93b6d5b3..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2.pyi +++ /dev/null @@ -1,27 +0,0 @@ -from google.protobuf import any_pb2 as _any_pb2 -from google.protobuf import struct_pb2 as _struct_pb2 -from forestadmin.rpc_common.proto import forest_pb2 as _forest_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class CollectionChartRequest(_message.Message): - __slots__ = ("caller", "ChartName", "RecordId") - CALLER_FIELD_NUMBER: _ClassVar[int] - CHARTNAME_FIELD_NUMBER: _ClassVar[int] - RECORDID_FIELD_NUMBER: _ClassVar[int] - caller: _forest_pb2.Caller - ChartName: str - RecordId: _any_pb2.Any - def __init__(self, caller: _Optional[_Union[_forest_pb2.Caller, _Mapping]] = ..., ChartName: _Optional[str] = ..., RecordId: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...) -> None: ... 
- -class CreateRequest(_message.Message): - __slots__ = ("Caller", "RecordData") - CALLER_FIELD_NUMBER: _ClassVar[int] - RECORDDATA_FIELD_NUMBER: _ClassVar[int] - Caller: _forest_pb2.Caller - RecordData: _containers.RepeatedCompositeFieldContainer[_struct_pb2.Struct] - def __init__(self, Caller: _Optional[_Union[_forest_pb2.Caller, _Mapping]] = ..., RecordData: _Optional[_Iterable[_Union[_struct_pb2.Struct, _Mapping]]] = ...) -> None: ... diff --git a/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2_grpc.py b/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2_grpc.py deleted file mode 100644 index a2c8f626..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/collection_pb2_grpc.py +++ /dev/null @@ -1,143 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - -from forestadmin.rpc_common.proto import collection_pb2 as forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2 -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 - -GRPC_GENERATED_VERSION = '1.70.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in forestadmin/rpc_common/proto/collection_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) - - -class CollectionStub(object): - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. 
- - Args: - channel: A grpc.Channel. - """ - self.RenderChart = channel.unary_unary( - '/datasource_rpc.Collection/RenderChart', - request_serializer=forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CollectionChartRequest.SerializeToString, - response_deserializer=google_dot_protobuf_dot_any__pb2.Any.FromString, - _registered_method=True) - self.Create = channel.unary_unary( - '/datasource_rpc.Collection/Create', - request_serializer=forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CreateRequest.SerializeToString, - response_deserializer=google_dot_protobuf_dot_any__pb2.Any.FromString, - _registered_method=True) - - -class CollectionServicer(object): - """Missing associated documentation comment in .proto file.""" - - def RenderChart(self, request, context): - """Define your RPC methods here - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Create(self, request, context): - """crud - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_CollectionServicer_to_server(servicer, server): - rpc_method_handlers = { - 'RenderChart': grpc.unary_unary_rpc_method_handler( - servicer.RenderChart, - request_deserializer=forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CollectionChartRequest.FromString, - response_serializer=google_dot_protobuf_dot_any__pb2.Any.SerializeToString, - ), - 'Create': grpc.unary_unary_rpc_method_handler( - servicer.Create, - request_deserializer=forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CreateRequest.FromString, - response_serializer=google_dot_protobuf_dot_any__pb2.Any.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'datasource_rpc.Collection', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - 
server.add_registered_method_handlers('datasource_rpc.Collection', rpc_method_handlers) - - - # This class is part of an EXPERIMENTAL API. -class Collection(object): - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def RenderChart(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/datasource_rpc.Collection/RenderChart', - forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CollectionChartRequest.SerializeToString, - google_dot_protobuf_dot_any__pb2.Any.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Create(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/datasource_rpc.Collection/Create', - forestadmin_dot_rpc__common_dot_proto_dot_collection__pb2.CreateRequest.SerializeToString, - google_dot_protobuf_dot_any__pb2.Any.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) diff --git a/src/rpc_common/forestadmin/rpc_common/proto/datasource.proto b/src/rpc_common/forestadmin/rpc_common/proto/datasource.proto deleted file mode 100644 index a10a0fdb..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/datasource.proto +++ /dev/null @@ -1,30 +0,0 @@ -syntax = "proto3"; - -package datasource_rpc; -import "google/protobuf/empty.proto"; -import "google/protobuf/any.proto"; -import "forestadmin/rpc_common/proto/forest.proto"; - -service DataSource { - // Define your RPC methods here - rpc Schema 
(google.protobuf.Empty) returns (SchemaResponse); - rpc RenderChart (DatasourceChartRequest) returns (google.protobuf.Any); - rpc ExecuteNativeQuery (NativeQueryRequest) returns (google.protobuf.Any); -} - -message SchemaResponse{ - DataSourceSchema Schema = 1; - repeated CollectionSchema Collections = 2; -} - -message DatasourceChartRequest{ - Caller caller = 1; - string ChartName = 2; -} - - -message NativeQueryRequest{ - Caller Caller = 1; - string Query = 2; - map Parameters = 3; -} \ No newline at end of file diff --git a/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.py b/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.py deleted file mode 100644 index 96cf9d3f..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# NO CHECKED-IN PROTOBUF GENCODE -# source: forestadmin/rpc_common/proto/datasource.proto -# Protobuf Python Version: 5.29.0 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 29, - 0, - '', - 'forestadmin/rpc_common/proto/datasource.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -from forestadmin.rpc_common.proto import forest_pb2 as forestadmin_dot_rpc__common_dot_proto_dot_forest__pb2 - - -DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n-forestadmin/rpc_common/proto/datasource.proto\x12\x0e\x64\x61tasource_rpc\x1a\x1bgoogle/protobuf/empty.proto\x1a\x19google/protobuf/any.proto\x1a)forestadmin/rpc_common/proto/forest.proto\"y\n\x0eSchemaResponse\x12\x30\n\x06Schema\x18\x01 \x01(\x0b\x32 .datasource_rpc.DataSourceSchema\x12\x35\n\x0b\x43ollections\x18\x02 \x03(\x0b\x32 .datasource_rpc.CollectionSchema\"S\n\x16\x44\x61tasourceChartRequest\x12&\n\x06\x63\x61ller\x18\x01 \x01(\x0b\x32\x16.datasource_rpc.Caller\x12\x11\n\tChartName\x18\x02 \x01(\t\"\xdc\x01\n\x12NativeQueryRequest\x12&\n\x06\x43\x61ller\x18\x01 \x01(\x0b\x32\x16.datasource_rpc.Caller\x12\r\n\x05Query\x18\x02 \x01(\t\x12\x46\n\nParameters\x18\x03 \x03(\x0b\x32\x32.datasource_rpc.NativeQueryRequest.ParametersEntry\x1aG\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any:\x02\x38\x01\x32\xeb\x01\n\nDataSource\x12@\n\x06Schema\x12\x16.google.protobuf.Empty\x1a\x1e.datasource_rpc.SchemaResponse\x12K\n\x0bRenderChart\x12&.datasource_rpc.DatasourceChartRequest\x1a\x14.google.protobuf.Any\x12N\n\x12\x45xecuteNativeQuery\x12\".datasource_rpc.NativeQueryRequest\x1a\x14.google.protobuf.Anyb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'forestadmin.rpc_common.proto.datasource_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_NATIVEQUERYREQUEST_PARAMETERSENTRY']._loaded_options = None - _globals['_NATIVEQUERYREQUEST_PARAMETERSENTRY']._serialized_options = b'8\001' - _globals['_SCHEMARESPONSE']._serialized_start=164 - _globals['_SCHEMARESPONSE']._serialized_end=285 - _globals['_DATASOURCECHARTREQUEST']._serialized_start=287 - _globals['_DATASOURCECHARTREQUEST']._serialized_end=370 - _globals['_NATIVEQUERYREQUEST']._serialized_start=373 - 
_globals['_NATIVEQUERYREQUEST']._serialized_end=593 - _globals['_NATIVEQUERYREQUEST_PARAMETERSENTRY']._serialized_start=522 - _globals['_NATIVEQUERYREQUEST_PARAMETERSENTRY']._serialized_end=593 - _globals['_DATASOURCE']._serialized_start=596 - _globals['_DATASOURCE']._serialized_end=831 -# @@protoc_insertion_point(module_scope) diff --git a/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.pyi b/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.pyi deleted file mode 100644 index ee33ec9e..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2.pyi +++ /dev/null @@ -1,42 +0,0 @@ -from google.protobuf import empty_pb2 as _empty_pb2 -from google.protobuf import any_pb2 as _any_pb2 -from forestadmin.rpc_common.proto import forest_pb2 as _forest_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class SchemaResponse(_message.Message): - __slots__ = ("Schema", "Collections") - SCHEMA_FIELD_NUMBER: _ClassVar[int] - COLLECTIONS_FIELD_NUMBER: _ClassVar[int] - Schema: _forest_pb2.DataSourceSchema - Collections: _containers.RepeatedCompositeFieldContainer[_forest_pb2.CollectionSchema] - def __init__(self, Schema: _Optional[_Union[_forest_pb2.DataSourceSchema, _Mapping]] = ..., Collections: _Optional[_Iterable[_Union[_forest_pb2.CollectionSchema, _Mapping]]] = ...) -> None: ... - -class DatasourceChartRequest(_message.Message): - __slots__ = ("caller", "ChartName") - CALLER_FIELD_NUMBER: _ClassVar[int] - CHARTNAME_FIELD_NUMBER: _ClassVar[int] - caller: _forest_pb2.Caller - ChartName: str - def __init__(self, caller: _Optional[_Union[_forest_pb2.Caller, _Mapping]] = ..., ChartName: _Optional[str] = ...) -> None: ... 
- -class NativeQueryRequest(_message.Message): - __slots__ = ("Caller", "Query", "Parameters") - class ParametersEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: _any_pb2.Any - def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...) -> None: ... - CALLER_FIELD_NUMBER: _ClassVar[int] - QUERY_FIELD_NUMBER: _ClassVar[int] - PARAMETERS_FIELD_NUMBER: _ClassVar[int] - Caller: _forest_pb2.Caller - Query: str - Parameters: _containers.MessageMap[str, _any_pb2.Any] - def __init__(self, Caller: _Optional[_Union[_forest_pb2.Caller, _Mapping]] = ..., Query: _Optional[str] = ..., Parameters: _Optional[_Mapping[str, _any_pb2.Any]] = ...) -> None: ... diff --git a/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2_grpc.py b/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2_grpc.py deleted file mode 100644 index f2ce6b02..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/datasource_pb2_grpc.py +++ /dev/null @@ -1,186 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
-"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - -from forestadmin.rpc_common.proto import datasource_pb2 as forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2 -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 - -GRPC_GENERATED_VERSION = '1.70.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in forestadmin/rpc_common/proto/datasource_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) - - -class DataSourceStub(object): - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.Schema = channel.unary_unary( - '/datasource_rpc.DataSource/Schema', - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - response_deserializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.SchemaResponse.FromString, - _registered_method=True) - self.RenderChart = channel.unary_unary( - '/datasource_rpc.DataSource/RenderChart', - request_serializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.DatasourceChartRequest.SerializeToString, - response_deserializer=google_dot_protobuf_dot_any__pb2.Any.FromString, - _registered_method=True) - self.ExecuteNativeQuery = channel.unary_unary( - '/datasource_rpc.DataSource/ExecuteNativeQuery', - request_serializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.NativeQueryRequest.SerializeToString, - response_deserializer=google_dot_protobuf_dot_any__pb2.Any.FromString, - _registered_method=True) - - -class DataSourceServicer(object): - """Missing associated documentation comment in .proto file.""" - - def Schema(self, request, context): - """Define your RPC methods here - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def RenderChart(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def ExecuteNativeQuery(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_DataSourceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Schema': grpc.unary_unary_rpc_method_handler( - servicer.Schema, - 
request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.SchemaResponse.SerializeToString, - ), - 'RenderChart': grpc.unary_unary_rpc_method_handler( - servicer.RenderChart, - request_deserializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.DatasourceChartRequest.FromString, - response_serializer=google_dot_protobuf_dot_any__pb2.Any.SerializeToString, - ), - 'ExecuteNativeQuery': grpc.unary_unary_rpc_method_handler( - servicer.ExecuteNativeQuery, - request_deserializer=forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.NativeQueryRequest.FromString, - response_serializer=google_dot_protobuf_dot_any__pb2.Any.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'datasource_rpc.DataSource', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('datasource_rpc.DataSource', rpc_method_handlers) - - - # This class is part of an EXPERIMENTAL API. 
-class DataSource(object): - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def Schema(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/datasource_rpc.DataSource/Schema', - google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.SchemaResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def RenderChart(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/datasource_rpc.DataSource/RenderChart', - forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.DatasourceChartRequest.SerializeToString, - google_dot_protobuf_dot_any__pb2.Any.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def ExecuteNativeQuery(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/datasource_rpc.DataSource/ExecuteNativeQuery', - forestadmin_dot_rpc__common_dot_proto_dot_datasource__pb2.NativeQueryRequest.SerializeToString, - google_dot_protobuf_dot_any__pb2.Any.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) diff --git 
a/src/rpc_common/forestadmin/rpc_common/proto/forest.proto b/src/rpc_common/forestadmin/rpc_common/proto/forest.proto deleted file mode 100644 index 497415f2..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/forest.proto +++ /dev/null @@ -1,49 +0,0 @@ -syntax = "proto3"; - -package datasource_rpc; -import "google/protobuf/any.proto"; -import "google/protobuf/struct.proto"; - -message Caller { - int32 RendererId = 1; - int32 UserId = 2; - map Tags = 3; - string Email = 4; - string FirstName = 5; - string LastName = 6; - string Team = 7; - string Timezone = 8; - message Request{ - string Ip = 1; - } -} - - - -enum ActionScope{ - GLOBAL = 0; - SINGLE = 1; - BULK = 2; -} -// message ActionFormSchema{} -message ActionSchema{ - ActionScope scope = 1; - bool generateFile = 2; - optional string description = 3; - optional string submitButtonLabel = 3; -} - -message CollectionSchema{ - string name = 1; - map actions = 2; - map fields = 3; - repeated string charts = 4; - bool searchable = 5; - bool countable = 6; - repeated string segments = 7; -} -message DataSourceSchema { - repeated string Charts = 1; - repeated string ConnectionNames = 2; - repeated CollectionSchema Collections = 3; -} \ No newline at end of file diff --git a/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.py b/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.py deleted file mode 100644 index 6e6eee75..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# NO CHECKED-IN PROTOBUF GENCODE -# source: forestadmin/rpc_common/proto/forest.proto -# Protobuf Python Version: 5.29.0 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 29, - 0, - '', - 'forestadmin/rpc_common/proto/forest.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)forestadmin/rpc_common/proto/forest.proto\x12\x0e\x64\x61tasource_rpc\x1a\x19google/protobuf/any.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xf4\x01\n\x06\x43\x61ller\x12\x12\n\nRendererId\x18\x01 \x01(\x05\x12\x0e\n\x06UserId\x18\x02 \x01(\x05\x12.\n\x04Tags\x18\x03 \x03(\x0b\x32 .datasource_rpc.Caller.TagsEntry\x12\r\n\x05\x45mail\x18\x04 \x01(\t\x12\x11\n\tFirstName\x18\x05 \x01(\t\x12\x10\n\x08LastName\x18\x06 \x01(\t\x12\x0c\n\x04Team\x18\x07 \x01(\t\x12\x10\n\x08Timezone\x18\x08 \x01(\t\x1a+\n\tTagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x15\n\x07Request\x12\n\n\x02Ip\x18\x01 \x01(\t\"r\n\x10\x44\x61taSourceSchema\x12\x0e\n\x06\x43harts\x18\x01 \x03(\t\x12\x17\n\x0f\x43onnectionNames\x18\x02 \x03(\t\x12\x35\n\x0b\x43ollections\x18\x03 \x03(\x0b\x32 .datasource_rpc.CollectionSchema\"\xf8\x02\n\x10\x43ollectionSchema\x12\x0c\n\x04name\x18\x01 \x01(\t\x12>\n\x07\x61\x63tions\x18\x02 \x03(\x0b\x32-.datasource_rpc.CollectionSchema.ActionsEntry\x12<\n\x06\x66ields\x18\x03 
\x03(\x0b\x32,.datasource_rpc.CollectionSchema.FieldsEntry\x12\x0e\n\x06\x63harts\x18\x04 \x03(\t\x12\x12\n\nsearchable\x18\x05 \x01(\x08\x12\x11\n\tcountable\x18\x06 \x01(\x08\x12\x10\n\x08segments\x18\x07 \x03(\t\x1aG\n\x0c\x41\x63tionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct:\x02\x38\x01\x1a\x46\n\x0b\x46ieldsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct:\x02\x38\x01\x62\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'forestadmin.rpc_common.proto.forest_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_CALLER_TAGSENTRY']._loaded_options = None - _globals['_CALLER_TAGSENTRY']._serialized_options = b'8\001' - _globals['_COLLECTIONSCHEMA_ACTIONSENTRY']._loaded_options = None - _globals['_COLLECTIONSCHEMA_ACTIONSENTRY']._serialized_options = b'8\001' - _globals['_COLLECTIONSCHEMA_FIELDSENTRY']._loaded_options = None - _globals['_COLLECTIONSCHEMA_FIELDSENTRY']._serialized_options = b'8\001' - _globals['_CALLER']._serialized_start=119 - _globals['_CALLER']._serialized_end=363 - _globals['_CALLER_TAGSENTRY']._serialized_start=297 - _globals['_CALLER_TAGSENTRY']._serialized_end=340 - _globals['_CALLER_REQUEST']._serialized_start=342 - _globals['_CALLER_REQUEST']._serialized_end=363 - _globals['_DATASOURCESCHEMA']._serialized_start=365 - _globals['_DATASOURCESCHEMA']._serialized_end=479 - _globals['_COLLECTIONSCHEMA']._serialized_start=482 - _globals['_COLLECTIONSCHEMA']._serialized_end=858 - _globals['_COLLECTIONSCHEMA_ACTIONSENTRY']._serialized_start=715 - _globals['_COLLECTIONSCHEMA_ACTIONSENTRY']._serialized_end=786 - _globals['_COLLECTIONSCHEMA_FIELDSENTRY']._serialized_start=788 - _globals['_COLLECTIONSCHEMA_FIELDSENTRY']._serialized_end=858 -# @@protoc_insertion_point(module_scope) diff 
--git a/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.pyi b/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.pyi deleted file mode 100644 index 0dcea8fa..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2.pyi +++ /dev/null @@ -1,82 +0,0 @@ -from google.protobuf import any_pb2 as _any_pb2 -from google.protobuf import struct_pb2 as _struct_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class Caller(_message.Message): - __slots__ = ("RendererId", "UserId", "Tags", "Email", "FirstName", "LastName", "Team", "Timezone") - class TagsEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - class Request(_message.Message): - __slots__ = ("Ip",) - IP_FIELD_NUMBER: _ClassVar[int] - Ip: str - def __init__(self, Ip: _Optional[str] = ...) -> None: ... 
- RENDERERID_FIELD_NUMBER: _ClassVar[int] - USERID_FIELD_NUMBER: _ClassVar[int] - TAGS_FIELD_NUMBER: _ClassVar[int] - EMAIL_FIELD_NUMBER: _ClassVar[int] - FIRSTNAME_FIELD_NUMBER: _ClassVar[int] - LASTNAME_FIELD_NUMBER: _ClassVar[int] - TEAM_FIELD_NUMBER: _ClassVar[int] - TIMEZONE_FIELD_NUMBER: _ClassVar[int] - RendererId: int - UserId: int - Tags: _containers.ScalarMap[str, str] - Email: str - FirstName: str - LastName: str - Team: str - Timezone: str - def __init__(self, RendererId: _Optional[int] = ..., UserId: _Optional[int] = ..., Tags: _Optional[_Mapping[str, str]] = ..., Email: _Optional[str] = ..., FirstName: _Optional[str] = ..., LastName: _Optional[str] = ..., Team: _Optional[str] = ..., Timezone: _Optional[str] = ...) -> None: ... - -class DataSourceSchema(_message.Message): - __slots__ = ("Charts", "ConnectionNames", "Collections") - CHARTS_FIELD_NUMBER: _ClassVar[int] - CONNECTIONNAMES_FIELD_NUMBER: _ClassVar[int] - COLLECTIONS_FIELD_NUMBER: _ClassVar[int] - Charts: _containers.RepeatedScalarFieldContainer[str] - ConnectionNames: _containers.RepeatedScalarFieldContainer[str] - Collections: _containers.RepeatedCompositeFieldContainer[CollectionSchema] - def __init__(self, Charts: _Optional[_Iterable[str]] = ..., ConnectionNames: _Optional[_Iterable[str]] = ..., Collections: _Optional[_Iterable[_Union[CollectionSchema, _Mapping]]] = ...) -> None: ... - -class CollectionSchema(_message.Message): - __slots__ = ("name", "actions", "fields", "charts", "searchable", "countable", "segments") - class ActionsEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: _struct_pb2.Struct - def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... 
- class FieldsEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: _struct_pb2.Struct - def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... - NAME_FIELD_NUMBER: _ClassVar[int] - ACTIONS_FIELD_NUMBER: _ClassVar[int] - FIELDS_FIELD_NUMBER: _ClassVar[int] - CHARTS_FIELD_NUMBER: _ClassVar[int] - SEARCHABLE_FIELD_NUMBER: _ClassVar[int] - COUNTABLE_FIELD_NUMBER: _ClassVar[int] - SEGMENTS_FIELD_NUMBER: _ClassVar[int] - name: str - actions: _containers.MessageMap[str, _struct_pb2.Struct] - fields: _containers.MessageMap[str, _struct_pb2.Struct] - charts: _containers.RepeatedScalarFieldContainer[str] - searchable: bool - countable: bool - segments: _containers.RepeatedScalarFieldContainer[str] - def __init__(self, name: _Optional[str] = ..., actions: _Optional[_Mapping[str, _struct_pb2.Struct]] = ..., fields: _Optional[_Mapping[str, _struct_pb2.Struct]] = ..., charts: _Optional[_Iterable[str]] = ..., searchable: bool = ..., countable: bool = ..., segments: _Optional[_Iterable[str]] = ...) -> None: ... diff --git a/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2_grpc.py b/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2_grpc.py deleted file mode 100644 index 1b1116d8..00000000 --- a/src/rpc_common/forestadmin/rpc_common/proto/forest_pb2_grpc.py +++ /dev/null @@ -1,24 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
-"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - - -GRPC_GENERATED_VERSION = '1.70.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in forestadmin/rpc_common/proto/forest_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) diff --git a/src/rpc_common/pyproject.toml b/src/rpc_common/pyproject.toml index 2313e73d..2e292c68 100644 --- a/src/rpc_common/pyproject.toml +++ b/src/rpc_common/pyproject.toml @@ -16,9 +16,3 @@ include = "forestadmin" [tool.poetry.dependencies] python = ">=3.8,<3.14" -grpcio = ">1.7.0" -grpcio-tools = { version = "*", optional = true } - -[tool.poe.tasks] -cleanup = { shell = "find forestadmin/rpc_common/proto -name '*.py*' -exec rm -f {} \\;" } -build = { shell = "python -m grpc_tools.protoc -I./ --python_out=. --pyi_out=. --grpc_python_out=. 
forestadmin/rpc_common/proto/*.proto" } \ No newline at end of file From e9ba2e9cfe1e2f7b74faed28392e21807d84ee5a Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 26 Mar 2025 17:10:08 +0100 Subject: [PATCH 19/34] chore: remove json dump and load from serializer --- .../forestadmin/rpc_common/serializers/schema/schema.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py b/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py index 1534204c..8d644a7d 100644 --- a/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py +++ b/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py @@ -49,7 +49,7 @@ class SchemaSerializer: def __init__(self, datasource: Datasource) -> None: self.datasource = datasource - async def serialize(self) -> str: + async def serialize(self) -> dict: value = { "version": self.VERSION, "data": { @@ -61,7 +61,7 @@ async def serialize(self) -> str: }, }, } - return json.dumps(value) + return value async def _serialize_collection(self, collection: Collection) -> dict: return { @@ -166,8 +166,7 @@ async def _serialize_relation(self, relation: RelationAlias) -> dict: class SchemaDeserializer: VERSION = "1.0" - def deserialize(self, json_schema) -> dict: - schema = json.loads(json_schema) + def deserialize(self, schema) -> dict: if schema["version"] != self.VERSION: raise ValueError(f"Unsupported schema version {schema['version']}") From ab1b218b6c1849a4c987ccd919f2e470efcdaf6d Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 26 Mar 2025 17:10:42 +0100 Subject: [PATCH 20/34] chore: continue cleanup --- src/datasource_rpc/forestadmin/datasource_rpc/datasource.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py index 0c849484..0213a071 100644 --- 
a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py @@ -1,14 +1,12 @@ import asyncio import hashlib import json -import os -import signal import time -from threading import Thread, Timer +from threading import Thread from typing import Any, Dict -import grpc import urllib3 +from forestadmin.agent_toolkit.agent import Agent from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_rpc.collection import RPCCollection From 031d6d5c5e6eb69d2a7ca9ffe1f1b02a26a926b0 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 26 Mar 2025 17:28:54 +0100 Subject: [PATCH 21/34] chore: forgot to commit this --- src/agent_rpc/forestadmin/agent_rpc/agent.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/agent_rpc/forestadmin/agent_rpc/agent.py b/src/agent_rpc/forestadmin/agent_rpc/agent.py index 4257d1ce..b415fa25 100644 --- a/src/agent_rpc/forestadmin/agent_rpc/agent.py +++ b/src/agent_rpc/forestadmin/agent_rpc/agent.py @@ -14,7 +14,11 @@ from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType from forestadmin.datasource_toolkit.utils.schema import SchemaUtils from forestadmin.rpc_common.hmac import is_valid_hmac -from forestadmin.rpc_common.serializers.actions import ActionFormSerializer, ActionResultSerializer +from forestadmin.rpc_common.serializers.actions import ( + ActionFormSerializer, + ActionFormValuesSerializer, + ActionResultSerializer, +) from forestadmin.rpc_common.serializers.aes import aes_decrypt, aes_encrypt from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer from forestadmin.rpc_common.serializers.collection.filter import ( @@ -108,7 +112,7 @@ async def sse_handler(self, request: web.Request) -> web.StreamResponse: async def schema(self, request): await self.customizer.get_datasource() - return 
web.Response(text=await SchemaSerializer(await self.customizer.get_datasource()).serialize()) + return web.Response(text=json.dumps(await SchemaSerializer(await self.customizer.get_datasource()).serialize())) async def collection_list(self, request: web.Request): body_params = await request.json() @@ -183,7 +187,7 @@ async def collection_get_form(self, request: web.Request): caller = CallerSerializer.deserialize(body_params["caller"]) if body_params["caller"] else None action_name = body_params["actionName"] filter_ = FilterSerializer.deserialize(body_params["filter"], collection) if body_params["filter"] else None - data = body_params["data"] + data = ActionFormValuesSerializer.deserialize(body_params["data"]) meta = body_params["meta"] form = await collection.get_form(caller, action_name, data, filter_, meta) @@ -202,7 +206,7 @@ async def collection_execute(self, request: web.Request): caller = CallerSerializer.deserialize(body_params["caller"]) if body_params["caller"] else None action_name = body_params["actionName"] filter_ = FilterSerializer.deserialize(body_params["filter"], collection) if body_params["filter"] else None - data = body_params["data"] + data = ActionFormValuesSerializer.deserialize(body_params["data"]) result = await collection.execute(caller, action_name, data, filter_) return web.Response( From 96cf7af13bede377ca44c6de51ba1d3abc25bd03 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 26 Mar 2025 17:35:36 +0100 Subject: [PATCH 22/34] chore: add segments from rpc datasource --- src/datasource_rpc/forestadmin/datasource_rpc/datasource.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py index 0213a071..ff57d082 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py @@ -70,6 +70,7 @@ def create_collection(self, collection_name, 
collection_schema): collection.add_rpc_action(action_name, action_schema) collection._schema["charts"] = {name: None for name in collection_schema["charts"]} + collection.add_segments(collection_schema["segments"]) self.add_collection(collection) From bb1cd9d954ad2c755f7fba2246d9a3ca408345bd Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 27 Mar 2025 10:54:06 +0100 Subject: [PATCH 23/34] chore: handle searchable --- .../forestadmin/datasource_rpc/datasource.py | 2 ++ .../datasource_customizer/collection_customizer.py | 7 +++++++ .../datasource_toolkit/decorators/search/collections.py | 9 +++++++-- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py index ff57d082..259fa752 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py @@ -70,6 +70,8 @@ def create_collection(self, collection_name, collection_schema): collection.add_rpc_action(action_name, action_schema) collection._schema["charts"] = {name: None for name in collection_schema["charts"]} + collection._schema["searchable"] = collection_schema["searchable"] + collection._schema["countable"] = collection_schema["countable"] collection.add_segments(collection_schema["segments"]) self.add_collection(collection) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/collection_customizer.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/collection_customizer.py index e0da61b7..50d6d46d 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/collection_customizer.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/collection_customizer.py @@ -550,6 +550,13 @@ async def _disable_count(): self.stack.queue_customization(_disable_count) return self + def 
bypass_search(self) -> Self: + async def _bypass_search(): + cast(SearchCollectionDecorator, self.stack.search.get_collection(self.collection_name)).bypass_search() + + self.stack.queue_customization(_bypass_search) + return self + def replace_search(self, definition: SearchDefinition) -> Self: """Replace the behavior of the search bar Args: diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py index 673c5803..d4abe308 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py @@ -33,11 +33,16 @@ def __init__(self, collection: Collection, datasource: Datasource[BoundCollectio super().__init__(collection, datasource) self._replacer: SearchDefinition = None self._searchable = len(self._get_searchable_fields(self.child_collection, False)) > 0 + self._bypass_searchable = False def disable_search(self): self._searchable = False self.mark_schema_as_dirty() + def bypass_search(self): + self._bypass_searchable = True + self.mark_schema_as_dirty() + def replace_search(self, replacer: SearchDefinition): self._replacer = replacer self._searchable = True @@ -46,7 +51,7 @@ def replace_search(self, replacer: SearchDefinition): def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema: return { **sub_schema, - "searchable": self._searchable, + "searchable": self._searchable if self._bypass_searchable is False else sub_schema["searchable"], } def _default_replacer(self, search: str, extended: bool) -> ConditionTree: @@ -58,7 +63,7 @@ def _default_replacer(self, search: str, extended: bool) -> ConditionTree: async def _refine_filter( self, caller: User, _filter: Union[Optional[PaginatedFilter], Optional[Filter]] ) -> Optional[Union[PaginatedFilter, Filter]]: - if _filter is None: + if _filter is None 
or self._bypass_searchable: return _filter # Search string is not significant From 6e10e3e7be70372662a38ef7228443c9c991abc9 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 27 Mar 2025 10:57:02 +0100 Subject: [PATCH 24/34] chore: second solution of search --- .../datasource_toolkit/decorators/search/collections.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py index d4abe308..40acfe65 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py @@ -40,6 +40,7 @@ def disable_search(self): self.mark_schema_as_dirty() def bypass_search(self): + # only used when a disable search was done on server RPC self._bypass_searchable = True self.mark_schema_as_dirty() @@ -49,9 +50,11 @@ def replace_search(self, replacer: SearchDefinition): self.mark_schema_as_dirty() def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema: + if sub_schema["searchable"] is True: + self._bypass_searchable = True return { **sub_schema, - "searchable": self._searchable if self._bypass_searchable is False else sub_schema["searchable"], + "searchable": self._searchable if self._bypass_searchable is False else True, } def _default_replacer(self, search: str, extended: bool) -> ConditionTree: From 8eec33d77a21a72530a56b5e91913d4a9658f00b Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 27 Mar 2025 11:51:29 +0100 Subject: [PATCH 25/34] chore: finalize search behavior --- .../forestadmin/datasource_rpc/datasource.py | 12 ++++++++---- .../collection_customizer.py | 7 ------- .../decorators/search/collections.py | 15 ++------------- 3 files changed, 10 insertions(+), 24 deletions(-) diff --git 
a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py index 259fa752..6df48649 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py @@ -70,10 +70,12 @@ def create_collection(self, collection_name, collection_schema): collection.add_rpc_action(action_name, action_schema) collection._schema["charts"] = {name: None for name in collection_schema["charts"]} - collection._schema["searchable"] = collection_schema["searchable"] - collection._schema["countable"] = collection_schema["countable"] collection.add_segments(collection_schema["segments"]) + if collection_schema["countable"]: + collection.enable_count() + collection.enable_search() + self.add_collection(collection) def wait_for_connection(self): @@ -87,8 +89,10 @@ def wait_for_connection(self): async def internal_reload(self): has_changed = self.introspect() - if has_changed and self.reload_method is not None: - await call_user_function(self.reload_method) + if has_changed: + await Agent.get_instance().reload() + # if has_changed and self.reload_method is not None: + # await call_user_function(self.reload_method) async def run(self): self.wait_for_connection() diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/collection_customizer.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/collection_customizer.py index 50d6d46d..e0da61b7 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/collection_customizer.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/collection_customizer.py @@ -550,13 +550,6 @@ async def _disable_count(): self.stack.queue_customization(_disable_count) return self - def bypass_search(self) -> Self: - async def _bypass_search(): - cast(SearchCollectionDecorator, 
self.stack.search.get_collection(self.collection_name)).bypass_search() - - self.stack.queue_customization(_bypass_search) - return self - def replace_search(self, definition: SearchDefinition) -> Self: """Replace the behavior of the search bar Args: diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py index 40acfe65..857317ed 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py @@ -33,29 +33,18 @@ def __init__(self, collection: Collection, datasource: Datasource[BoundCollectio super().__init__(collection, datasource) self._replacer: SearchDefinition = None self._searchable = len(self._get_searchable_fields(self.child_collection, False)) > 0 - self._bypass_searchable = False def disable_search(self): self._searchable = False self.mark_schema_as_dirty() - def bypass_search(self): - # only used when a disable search was done on server RPC - self._bypass_searchable = True - self.mark_schema_as_dirty() - def replace_search(self, replacer: SearchDefinition): self._replacer = replacer self._searchable = True self.mark_schema_as_dirty() def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema: - if sub_schema["searchable"] is True: - self._bypass_searchable = True - return { - **sub_schema, - "searchable": self._searchable if self._bypass_searchable is False else True, - } + return {**sub_schema, "searchable": self._searchable} def _default_replacer(self, search: str, extended: bool) -> ConditionTree: searchable_fields = self._get_searchable_fields(self.child_collection, extended) @@ -66,7 +55,7 @@ def _default_replacer(self, search: str, extended: bool) -> ConditionTree: async def _refine_filter( self, caller: User, _filter: Union[Optional[PaginatedFilter], Optional[Filter]] ) -> 
Optional[Union[PaginatedFilter, Filter]]: - if _filter is None or self._bypass_searchable: + if _filter is None: return _filter # Search string is not significant From 76e4c19f07a397b2f1ecf48921567efd3da1bee7 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 2 Apr 2025 10:18:12 +0200 Subject: [PATCH 26/34] chore: cleanup and refactor --- src/agent_rpc/forestadmin/agent_rpc/agent.py | 110 +++------ .../forestadmin/datasource_rpc/collection.py | 220 ++++++------------ .../forestadmin/datasource_rpc/datasource.py | 138 +++++------ .../forestadmin/datasource_rpc/main.py | 19 -- .../forestadmin/datasource_rpc/requester.py | 186 +++++++++++++++ src/datasource_rpc/pyproject.toml | 1 + .../rpc_common/serializers/actions.py | 2 + .../forestadmin/rpc_common/serializers/aes.py | 22 -- .../rpc_common/serializers/utils.py | 2 + 9 files changed, 349 insertions(+), 351 deletions(-) delete mode 100644 src/datasource_rpc/forestadmin/datasource_rpc/main.py create mode 100644 src/datasource_rpc/forestadmin/datasource_rpc/requester.py delete mode 100644 src/rpc_common/forestadmin/rpc_common/serializers/aes.py diff --git a/src/agent_rpc/forestadmin/agent_rpc/agent.py b/src/agent_rpc/forestadmin/agent_rpc/agent.py index b415fa25..d6d8d40a 100644 --- a/src/agent_rpc/forestadmin/agent_rpc/agent.py +++ b/src/agent_rpc/forestadmin/agent_rpc/agent.py @@ -1,16 +1,12 @@ import asyncio import json -from enum import Enum from uuid import UUID from aiohttp import web from aiohttp_sse import sse_response - -# import grpc from forestadmin.agent_rpc.options import RpcOptions - -# from forestadmin.agent_rpc.services.datasource import DatasourceService from forestadmin.agent_toolkit.agent import Agent +from forestadmin.agent_toolkit.options import Options from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType from forestadmin.datasource_toolkit.utils.schema import SchemaUtils from forestadmin.rpc_common.hmac import is_valid_hmac @@ -30,48 +26,19 @@ from 
forestadmin.rpc_common.serializers.schema.schema import SchemaSerializer from forestadmin.rpc_common.serializers.utils import CallerSerializer -# from concurrent import futures - - -# from forestadmin.rpc_common.proto import datasource_pb2_grpc - - -class RcpJsonEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, Enum): - return o.value - if isinstance(o, set): - return list(sorted(o, key=lambda x: x.value if isinstance(x, Enum) else str(x))) - if isinstance(o, set): - return list(sorted(o, key=lambda x: x.value if isinstance(x, Enum) else str(x))) - - try: - return super().default(o) - except Exception as exc: - print(f"error on seriliaze {o}, {type(o)}: {exc}") - class RpcAgent(Agent): - # TODO: options to add: - # * listen addr def __init__(self, options: RpcOptions): - # self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) - # self.server = grpc.aio.server() self.listen_addr, self.listen_port = options["listen_addr"].rsplit(":", 1) + agent_options: Options = {**options} # type:ignore + agent_options["skip_schema_update"] = True + agent_options["env_secret"] = "f" * 64 + agent_options["server_url"] = "http://fake" + agent_options["schema_path"] = "./.forestadmin-schema.json" + super().__init__(agent_options) + self.app = web.Application(middlewares=[self.hmac_middleware]) - # self.server.add_insecure_port(options["listen_addr"]) - options["skip_schema_update"] = True - options["env_secret"] = "f" * 64 - options["server_url"] = "http://fake" - # options["auth_secret"] = "f48186505a3c5d62c27743126d6a76c1dd8b3e2d8897de19" - options["schema_path"] = "./.forestadmin-schema.json" - super().__init__(options) - - self.aes_key = self.options["auth_secret"][:16].encode() - self.aes_iv = self.options["auth_secret"][-16:].encode() - self._server_stop = False self.setup_routes() - # signal.signal(signal.SIGUSR1, self.stop_handler) @web.middleware async def hmac_middleware(self, request: web.Request, handler): @@ -80,11 +47,10 @@ async def 
hmac_middleware(self, request: web.Request, handler): if not is_valid_hmac( self.options["auth_secret"].encode(), body, request.headers.get("X-FOREST-HMAC", "").encode("utf-8") ): - return web.Response(status=401) + return web.Response(status=401, text="Unauthorized from HMAC verification") return await handler(request) def setup_routes(self): - # self.app.middlewares.append(self.hmac_middleware) self.app.router.add_route("GET", "/sse", self.sse_handler) self.app.router.add_route("GET", "/schema", self.schema) self.app.router.add_route("POST", "/collection/list", self.collection_list) @@ -98,11 +64,11 @@ def setup_routes(self): self.app.router.add_route("POST", "/execute-native-query", self.native_query) self.app.router.add_route("POST", "/render-chart", self.render_chart) - self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK")) + self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK")) # type: ignore async def sse_handler(self, request: web.Request) -> web.StreamResponse: async with sse_response(request) as resp: - while resp.is_connected() and not self._server_stop: + while resp.is_connected(): await resp.send("", event="heartbeat") await asyncio.sleep(1) data = json.dumps({"event": "RpcServerStop"}) @@ -112,29 +78,28 @@ async def sse_handler(self, request: web.Request) -> web.StreamResponse: async def schema(self, request): await self.customizer.get_datasource() - return web.Response(text=json.dumps(await SchemaSerializer(await self.customizer.get_datasource()).serialize())) + return web.json_response(await SchemaSerializer(await self.customizer.get_datasource()).serialize()) async def collection_list(self, request: web.Request): body_params = await request.json() ds = await self.customizer.get_datasource() collection = ds.get_collection(body_params["collectionName"]) caller = CallerSerializer.deserialize(body_params["caller"]) - filter_ = PaginatedFilterSerializer.deserialize(body_params["filter"], collection) + filter_ = 
PaginatedFilterSerializer.deserialize(body_params["filter"], collection) # type:ignore projection = ProjectionSerializer.deserialize(body_params["projection"]) records = await collection.list(caller, filter_, projection) records = [RecordSerializer.serialize(record) for record in records] - return web.Response(text=aes_encrypt(json.dumps(records), self.aes_key, self.aes_iv)) + return web.json_response(records) async def collection_create(self, request: web.Request): body_params = await request.text() - body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) body_params = json.loads(body_params) ds = await self.customizer.get_datasource() collection = ds.get_collection(body_params["collectionName"]) caller = CallerSerializer.deserialize(body_params["caller"]) - data = [RecordSerializer.deserialize(r, collection) for r in body_params["data"]] + data = [RecordSerializer.deserialize(r, collection) for r in body_params["data"]] # type:ignore records = await collection.create(caller, data) records = [RecordSerializer.serialize(record) for record in records] @@ -142,14 +107,13 @@ async def collection_create(self, request: web.Request): async def collection_update(self, request: web.Request): body_params = await request.text() - body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) body_params = json.loads(body_params) ds = await self.customizer.get_datasource() collection = ds.get_collection(body_params["collectionName"]) caller = CallerSerializer.deserialize(body_params["caller"]) - filter_ = FilterSerializer.deserialize(body_params["filter"], collection) - patch = RecordSerializer.deserialize(body_params["patch"], collection) + filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore + patch = RecordSerializer.deserialize(body_params["patch"], collection) # type:ignore await collection.update(caller, filter_, patch) return web.Response(text="OK") @@ -159,7 +123,7 @@ async def collection_delete(self, request: 
web.Request): ds = await self.customizer.get_datasource() collection = ds.get_collection(body_params["collectionName"]) caller = CallerSerializer.deserialize(body_params["caller"]) - filter_ = FilterSerializer.deserialize(body_params["filter"], collection) + filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore await collection.delete(caller, filter_) return web.Response(text="OK") @@ -169,53 +133,51 @@ async def collection_aggregate(self, request: web.Request): ds = await self.customizer.get_datasource() collection = ds.get_collection(body_params["collectionName"]) caller = CallerSerializer.deserialize(body_params["caller"]) - filter_ = FilterSerializer.deserialize(body_params["filter"], collection) + filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore aggregation = AggregationSerializer.deserialize(body_params["aggregation"]) records = await collection.aggregate(caller, filter_, aggregation) - # records = [RecordSerializer.serialize(record) for record in records] - return web.Response(text=aes_encrypt(json.dumps(records), self.aes_key, self.aes_iv)) + return web.json_response(records) async def collection_get_form(self, request: web.Request): body_params = await request.text() - body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) body_params = json.loads(body_params) ds = await self.customizer.get_datasource() collection = ds.get_collection(body_params["collectionName"]) - caller = CallerSerializer.deserialize(body_params["caller"]) if body_params["caller"] else None + caller = CallerSerializer.deserialize(body_params["caller"]) action_name = body_params["actionName"] - filter_ = FilterSerializer.deserialize(body_params["filter"], collection) if body_params["filter"] else None + if body_params["filter"]: + filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore + else: + filter_ = None data = 
ActionFormValuesSerializer.deserialize(body_params["data"]) meta = body_params["meta"] form = await collection.get_form(caller, action_name, data, filter_, meta) - return web.Response( - text=aes_encrypt(json.dumps(ActionFormSerializer.serialize(form)), self.aes_key, self.aes_iv) - ) + return web.json_response(ActionFormSerializer.serialize(form)) async def collection_execute(self, request: web.Request): body_params = await request.text() - body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) body_params = json.loads(body_params) ds = await self.customizer.get_datasource() collection = ds.get_collection(body_params["collectionName"]) - caller = CallerSerializer.deserialize(body_params["caller"]) if body_params["caller"] else None + caller = CallerSerializer.deserialize(body_params["caller"]) action_name = body_params["actionName"] - filter_ = FilterSerializer.deserialize(body_params["filter"], collection) if body_params["filter"] else None + if body_params["filter"]: + filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore + else: + filter_ = None data = ActionFormValuesSerializer.deserialize(body_params["data"]) result = await collection.execute(caller, action_name, data, filter_) - return web.Response( - text=aes_encrypt(json.dumps(ActionResultSerializer.serialize(result)), self.aes_key, self.aes_iv) - ) + return web.json_response(ActionResultSerializer.serialize(result)) # type:ignore async def collection_render_chart(self, request: web.Request): body_params = await request.text() - body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) body_params = json.loads(body_params) ds = await self.customizer.get_datasource() @@ -244,11 +206,10 @@ async def collection_render_chart(self, request: web.Request): ret.append(value) result = await collection.render_chart(caller, name, record_id) - return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv)) + return web.json_response(result) async 
def render_chart(self, request: web.Request): body_params = await request.text() - body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) body_params = json.loads(body_params) ds = await self.customizer.get_datasource() @@ -257,11 +218,10 @@ async def render_chart(self, request: web.Request): name = body_params["name"] result = await ds.render_chart(caller, name) - return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv)) + return web.json_response(result) async def native_query(self, request: web.Request): body_params = await request.text() - body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv) body_params = json.loads(body_params) ds = await self.customizer.get_datasource() @@ -270,7 +230,7 @@ async def native_query(self, request: web.Request): parameters = body_params["parameters"] result = await ds.execute_native_query(connection_name, native_query, parameters) - return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv)) + return web.json_response(result) def start(self): web.run_app(self.app, host=self.listen_addr, port=int(self.listen_port)) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/collection.py b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py index 96c9bc14..fed9725a 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/collection.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py @@ -1,11 +1,9 @@ import json -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional -import urllib3 from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.collections import Collection -from forestadmin.datasource_toolkit.datasources import Datasource from forestadmin.datasource_toolkit.interfaces.actions import Action, ActionFormElement, ActionResult, ActionsScope from 
forestadmin.datasource_toolkit.interfaces.chart import Chart from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType @@ -15,13 +13,11 @@ from forestadmin.datasource_toolkit.interfaces.query.projections import Projection from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias from forestadmin.datasource_toolkit.utils.schema import SchemaUtils -from forestadmin.rpc_common.hmac import generate_hmac from forestadmin.rpc_common.serializers.actions import ( ActionFormSerializer, ActionFormValuesSerializer, ActionResultSerializer, ) -from forestadmin.rpc_common.serializers.aes import aes_decrypt, aes_encrypt from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer from forestadmin.rpc_common.serializers.collection.filter import ( FilterSerializer, @@ -31,16 +27,19 @@ from forestadmin.rpc_common.serializers.collection.record import RecordSerializer from forestadmin.rpc_common.serializers.utils import CallerSerializer +if TYPE_CHECKING: + from forestadmin.datasource_rpc.datasource import RPCDatasource + class RPCCollection(Collection): - def __init__(self, name: str, datasource: Datasource, connection_uri: str, secret_key: str): + def __init__(self, name: str, datasource: "RPCDatasource"): super().__init__(name, datasource) - self.connection_uri = connection_uri - self.secret_key = secret_key - self.aes_key = secret_key[:16].encode() - self.aes_iv = secret_key[-16:].encode() - self.http = urllib3.PoolManager() self._rpc_actions = {} + self._datasource = datasource + + @property + def datasource(self) -> "RPCDatasource": + return self._datasource def add_rpc_action(self, name: str, action: Dict[str, Any]) -> None: if name in self._schema["actions"]: @@ -50,98 +49,59 @@ def add_rpc_action(self, name: str, action: Dict[str, Any]) -> None: description=action.get("description"), submit_button_label=action.get("submit_button_label"), generate_file=action.get("generate_file", False), - # 
form=action.get("form"), static_form=action["static_form"], ) self._rpc_actions[name] = {"form": action["form"]} async def list(self, caller: User, filter_: PaginatedFilter, projection: Projection) -> List[RecordsDataAlias]: - body = json.dumps( - { - "caller": CallerSerializer.serialize(caller), - "filter": PaginatedFilterSerializer.serialize(filter_, self), - "projection": ProjectionSerializer.serialize(projection), - "collectionName": self.name, - } - ) - response = self.http.request( - "POST", - f"http://{self.connection_uri}/collection/list", - body=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, - ) - ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) - - return [RecordSerializer.deserialize(record, self) for record in json.loads(ret)] + body = { + "caller": CallerSerializer.serialize(caller), + "filter": PaginatedFilterSerializer.serialize(filter_, self), + "projection": ProjectionSerializer.serialize(projection), + "collectionName": self.name, + } + ret = await self.datasource.requester.list(body=body) + return [RecordSerializer.deserialize(record, self) for record in ret] async def create(self, caller: User, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - body = json.dumps( - { - "caller": CallerSerializer.serialize(caller), - "data": [RecordSerializer.serialize(r) for r in data], - "collectionName": self.name, - } - ) - body = aes_encrypt(body, self.aes_key, self.aes_iv) - response = self.http.request( - "POST", - f"http://{self.connection_uri}/collection/create", - body=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, - ) - return [RecordSerializer.deserialize(record, self) for record in json.loads(response.data.decode("utf-8"))] + body = { + "caller": CallerSerializer.serialize(caller), + "data": [RecordSerializer.serialize(r) for r in data], + "collectionName": self.name, + } + response = await 
self.datasource.requester.create(body) + return [RecordSerializer.deserialize(record, self) for record in response] async def update(self, caller: User, filter_: Optional[Filter], patch: Dict[str, Any]) -> None: - body = json.dumps( - { - "caller": CallerSerializer.serialize(caller), - "filter": FilterSerializer.serialize(filter_, self), - "patch": RecordSerializer.serialize(patch), - "collectionName": self.name, - } - ) - body = aes_encrypt(body, self.aes_key, self.aes_iv) - self.http.request( - "POST", - f"http://{self.connection_uri}/collection/update", - body=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, - ) + body = { + "caller": CallerSerializer.serialize(caller), + "filter": FilterSerializer.serialize(filter_, self), # type: ignore + "patch": RecordSerializer.serialize(patch), + "collectionName": self.name, + } + await self.datasource.requester.update(body) async def delete(self, caller: User, filter_: Filter | None) -> None: body = json.dumps( { "caller": CallerSerializer.serialize(caller), - "filter": FilterSerializer.serialize(filter_, self), + "filter": FilterSerializer.serialize(filter_, self), # type: ignore "collectionName": self.name, } ) - self.http.request( - "POST", - f"http://{self.connection_uri}/collection/delete", - body=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, - ) + await self.datasource.requester.delete(body) async def aggregate( self, caller: User, filter_: Filter | None, aggregation: Aggregation, limit: int | None = None ) -> List[AggregateResult]: - body = json.dumps( - { - "caller": CallerSerializer.serialize(caller), - "filter": FilterSerializer.serialize(filter_, self), - "aggregation": AggregationSerializer.serialize(aggregation), - "collectionName": self.name, - } - ) - response = self.http.request( - "POST", - f"http://{self.connection_uri}/collection/aggregate", - body=body, - headers={"X-FOREST-HMAC": 
generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, - ) - ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) - return json.loads(ret) + body = { + "caller": CallerSerializer.serialize(caller), + "filter": FilterSerializer.serialize(filter_, self), # type: ignore + "aggregation": AggregationSerializer.serialize(aggregation), + "collectionName": self.name, + } + response = await self.datasource.requester.aggregate(body) + return response async def get_form( self, @@ -157,29 +117,16 @@ async def get_form( if self._schema["actions"][name].static_form: return self._rpc_actions[name]["form"] - body = aes_encrypt( - json.dumps( - { - "caller": CallerSerializer.serialize(caller) if caller is not None else None, - "filter": FilterSerializer.serialize(filter_, self) if filter_ is not None else None, - "data": ActionFormValuesSerializer.serialize(data), - "meta": meta, - "collectionName": self.name, - "actionName": name, - } - ), - self.aes_key, - self.aes_iv, - ) - response = self.http.request( - "POST", - f"http://{self.connection_uri}/collection/get-form", - body=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, - ) - ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) - ret = ActionFormSerializer.deserialize(json.loads(ret)) - return ret + body = { + "caller": CallerSerializer.serialize(caller) if caller is not None else None, + "filter": FilterSerializer.serialize(filter_, self) if filter_ is not None else None, + "data": ActionFormValuesSerializer.serialize(data), # type: ignore + "meta": meta, + "collectionName": self.name, + "actionName": name, + } + response = await self.datasource.requester.get_form(body) + return ActionFormSerializer.deserialize(response) # type: ignore async def execute( self, @@ -191,29 +138,15 @@ async def execute( if name not in self._schema["actions"]: raise ValueError(f"Action {name} does not exist in collection 
{self.name}") - body = aes_encrypt( - json.dumps( - { - "caller": CallerSerializer.serialize(caller) if caller is not None else None, - "filter": FilterSerializer.serialize(filter_, self) if filter_ is not None else None, - "data": ActionFormValuesSerializer.serialize(data), - "collectionName": self.name, - "actionName": name, - } - ), - self.aes_key, - self.aes_iv, - ) - response = self.http.request( - "POST", - f"http://{self.connection_uri}/collection/execute", - body=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, - ) - ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) - ret = json.loads(ret) - ret = ActionResultSerializer.deserialize(ret) - return ret + body = { + "caller": CallerSerializer.serialize(caller) if caller is not None else None, + "filter": FilterSerializer.serialize(filter_, self) if filter_ is not None else None, + "data": ActionFormValuesSerializer.serialize(data), + "collectionName": self.name, + "actionName": name, + } + response = await self.datasource.requester.execute(body) + return ActionResultSerializer.deserialize(response) async def render_chart(self, caller: User, name: str, record_id: List) -> Chart: if name not in self._schema["charts"].keys(): @@ -221,7 +154,8 @@ async def render_chart(self, caller: User, name: str, record_id: List) -> Chart: ret = [] for i, value in enumerate(record_id): - type_record_id = self.schema["fields"][SchemaUtils.get_primary_keys(self.schema)[i]]["column_type"] + pk_field = SchemaUtils.get_primary_keys(self.schema)[i] + type_record_id = self.schema["fields"][pk_field]["column_type"] # type: ignore if type_record_id == PrimitiveType.DATE: ret.append(value.isoformat()) @@ -236,27 +170,13 @@ async def render_chart(self, caller: User, name: str, record_id: List) -> Chart: else: ret.append(value) - body = aes_encrypt( - json.dumps( - { - "caller": CallerSerializer.serialize(caller) if caller is not None else None, - "name": 
name, - "collectionName": self.name, - "recordId": ret, - } - ), - self.aes_key, - self.aes_iv, - ) - response = self.http.request( - "POST", - f"http://{self.connection_uri}/collection/render-chart", - body=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, - ) - ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) - ret = json.loads(ret) - return ret + body = { + "caller": CallerSerializer.serialize(caller) if caller is not None else None, + "name": name, + "collectionName": self.name, + "recordId": ret, + } + return await self.datasource.requester.collection_render_chart(body) def get_native_driver(self): ForestLogger.log( diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py index 6df48649..b69dcc66 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py @@ -1,54 +1,58 @@ import asyncio import hashlib import json -import time from threading import Thread from typing import Any, Dict -import urllib3 -from forestadmin.agent_toolkit.agent import Agent from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_rpc.collection import RPCCollection +from forestadmin.datasource_rpc.requester import RPCRequester from forestadmin.datasource_toolkit.datasources import Datasource from forestadmin.datasource_toolkit.interfaces.chart import Chart from forestadmin.datasource_toolkit.utils.user_callable import call_user_function -from forestadmin.rpc_common.hmac import generate_hmac -from forestadmin.rpc_common.serializers.aes import aes_decrypt, aes_encrypt from forestadmin.rpc_common.serializers.schema.schema import SchemaDeserializer from forestadmin.rpc_common.serializers.utils import CallerSerializer -from sseclient import SSEClient -def 
hash_schema(schema) -> str: +def _hash_schema(schema) -> str: return hashlib.sha256(json.dumps(schema, sort_keys=True).encode()).hexdigest() +async def create_rpc_datasource(connection_uri: str, secret_key: str, reload_method=None) -> "RPCDatasource": + """Create a new RPC datasource and wait for the connection to be established.""" + datasource = RPCDatasource(connection_uri, secret_key, reload_method) + await datasource.wait_for_connection() + await datasource.introspect() + return datasource + + class RPCDatasource(Datasource): def __init__(self, connection_uri: str, secret_key: str, reload_method=None): super().__init__([]) self.connection_uri = connection_uri self.reload_method = reload_method self.secret_key = secret_key - self.aes_key = secret_key[:16].encode() - self.aes_iv = secret_key[-16:].encode() self.last_schema_hash = "" - self.http = urllib3.PoolManager() - self.wait_for_connection() - self.introspect() + self.requester = RPCRequester(connection_uri, secret_key) - self.thread = Thread(target=asyncio.run, args=(self.run(),), name="RPCDatasourceSSEThread", daemon=True) + # self.thread = Thread(target=asyncio.run, args=(self.run(),), name="RPCDatasourceSSEThread", daemon=True) + self.thread = Thread( + target=asyncio.run, + args=(self.requester.sse_connect(self.sse_callback),), + name="RPCDatasourceSSEThread", + daemon=True, + ) self.thread.start() - def introspect(self) -> bool: + async def introspect(self) -> bool: """return true if schema has changed""" - response = self.http.request("GET", f"http://{self.connection_uri}/schema") - schema_data = json.loads(response.data.decode("utf-8")) - if self.last_schema_hash == hash_schema(schema_data): + schema_data = await self.requester.schema() + if self.last_schema_hash == _hash_schema(schema_data): ForestLogger.log("debug", "[RPCDatasource] Schema has not changed") return False - self.last_schema_hash = hash_schema(schema_data) + self.last_schema_hash = _hash_schema(schema_data) self._collections = {} 
self._schema = {"charts": {}} @@ -56,12 +60,12 @@ def introspect(self) -> bool: for collection_name, collection in schema_data["collections"].items(): self.create_collection(collection_name, collection) - self._schema["charts"] = {name: None for name in schema_data["charts"]} + self._schema["charts"] = {name: None for name in schema_data["charts"]} # type: ignore self._live_query_connections = schema_data["live_query_connections"] return True def create_collection(self, collection_name, collection_schema): - collection = RPCCollection(collection_name, self, self.connection_uri, self.secret_key) + collection = RPCCollection(collection_name, self) for name, field in collection_schema["fields"].items(): collection.add_field(name, field) @@ -69,7 +73,7 @@ def create_collection(self, collection_name, collection_schema): for action_name, action_schema in collection_schema["actions"].items(): collection.add_rpc_action(action_name, action_schema) - collection._schema["charts"] = {name: None for name in collection_schema["charts"]} + collection._schema["charts"] = {name: None for name in collection_schema["charts"]} # type: ignore collection.add_segments(collection_schema["segments"]) if collection_schema["countable"]: @@ -78,81 +82,45 @@ def create_collection(self, collection_name, collection_schema): self.add_collection(collection) - def wait_for_connection(self): - while True: - try: - self.http.request("GET", f"http://{self.connection_uri}/") - break - except Exception: - time.sleep(1) + async def wait_for_connection(self): + """Wait for the connection to be established.""" + await self.requester.wait_for_connection() ForestLogger.log("debug", "Connection to RPC datasource established") async def internal_reload(self): - has_changed = self.introspect() - if has_changed: - await Agent.get_instance().reload() - # if has_changed and self.reload_method is not None: - # await call_user_function(self.reload_method) - - async def run(self): - self.wait_for_connection() - 
self.sse_client = SSEClient( - self.http.request("GET", f"http://{self.connection_uri}/sse", preload_content=False) - ) - try: - for msg in self.sse_client.events(): - pass - except Exception: - pass - ForestLogger.log("info", "rpc connection to server closed") - self.wait_for_connection() + has_changed = await self.introspect() + if has_changed and self.reload_method is not None: + await call_user_function(self.reload_method) + + async def sse_callback(self): + await self.wait_for_connection() await self.internal_reload() - self.thread = Thread(target=asyncio.run, args=(self.run(),), name="RPCDatasourceSSEThread", daemon=True) + self.thread = Thread( + target=asyncio.run, + args=(self.requester.sse_connect(self.sse_callback),), + name="RPCDatasourceSSEThread", + daemon=True, + ) self.thread.start() async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any: - body = aes_encrypt( - json.dumps( - { - "connectionName": connection_name, - "nativeQuery": native_query, - "parameters": parameters, - } - ), - self.aes_key, - self.aes_iv, - ) - response = self.http.request( - "POST", - f"http://{self.connection_uri}/execute-native-query", - body=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + return await self.requester.native_query( + { + "connectionName": connection_name, + "nativeQuery": native_query, + "parameters": parameters, + } ) - ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) - ret = json.loads(ret) - return ret async def render_chart(self, caller: User, name: str) -> Chart: if name not in self._schema["charts"].keys(): raise ValueError(f"Chart {name} does not exist in this datasource") - body = aes_encrypt( - json.dumps( - { - "caller": CallerSerializer.serialize(caller) if caller is not None else None, - "name": name, - } - ), - self.aes_key, - self.aes_iv, - ) - response = self.http.request( - "POST", - 
f"http://{self.connection_uri}/render-chart", - body=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, - ) - ret = aes_decrypt(response.data.decode("utf-8"), self.aes_key, self.aes_iv) - ret = json.loads(ret) - return ret + body = { + "caller": CallerSerializer.serialize(caller) if caller is not None else None, + "name": name, + } + + response = self.requester.collection_render_chart(body) + return response diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/main.py b/src/datasource_rpc/forestadmin/datasource_rpc/main.py deleted file mode 100644 index f8cafde2..00000000 --- a/src/datasource_rpc/forestadmin/datasource_rpc/main.py +++ /dev/null @@ -1,19 +0,0 @@ -import asyncio -import logging - -import grpc -from forestadmin.rpc_common.proto import datasource_pb2, datasource_pb2_grpc -from google.protobuf import empty_pb2 - - -async def run() -> None: - async with grpc.aio.insecure_channel("localhost:50051") as channel: - stub = datasource_pb2_grpc.DataSourceStub(channel) - response = await stub.Schema(empty_pb2.Empty()) - print(response) - print("datasource client received: " + str(response.Collections[0].searchable)) - - -if __name__ == "__main__": - logging.basicConfig() - asyncio.run(run()) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/requester.py b/src/datasource_rpc/forestadmin/datasource_rpc/requester.py new file mode 100644 index 00000000..3102b44d --- /dev/null +++ b/src/datasource_rpc/forestadmin/datasource_rpc/requester.py @@ -0,0 +1,186 @@ +import asyncio +from typing import List + +import aiohttp +from aiohttp_sse_client import client as sse_client +from forestadmin.rpc_common.hmac import generate_hmac + + +class RPCRequester: + def __init__(self, connection_uri: str, secret_key: str): + self.connection_uri = connection_uri + self.secret_key = secret_key + self.aes_key = secret_key[:16].encode() + self.aes_iv = secret_key[-16:].encode() + + async def sse_connect(self, callback): 
+ """Connect to the SSE stream.""" + await self.wait_for_connection() + + async with sse_client.EventSource( + f"http://{self.connection_uri}/sse", + ) as event_source: + try: + async for event in event_source: + # print(event) + pass + except Exception: + pass + + await callback() + + async def wait_for_connection(self): + """Wait for the connection to be established.""" + async with aiohttp.ClientSession() as session: + while True: + try: + async with session.get(f"http://{self.connection_uri}/") as response: + if response.status == 200: + return True + else: + raise Exception(f"Failed to connect: {response.status}") + except Exception: + await asyncio.sleep(1) + + # methods for datasource + + async def schema(self) -> dict: + """Get the schema of the datasource.""" + async with aiohttp.ClientSession() as session: + async with session.get(f"http://{self.connection_uri}/schema") as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"Failed to get schema: {response.status}") + + async def native_query(self, body): + """Execute a native query.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"http://{self.connection_uri}/execute-native-query", + json=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"Failed to execute native query: {response.status}") + + async def datasource_render_chart(self, body): + """Render a chart.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"http://{self.connection_uri}/render-chart", + json=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"Failed to render chart: {response.status}") + + # methods for collection + 
+ async def list(self, body) -> list[dict]: + """List records in a collection.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"http://{self.connection_uri}/collection/list", + json=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"Failed to list query: {response.status}") + + async def create(self, body) -> List[dict]: + """Create records in a collection.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"http://{self.connection_uri}/collection/create", + json=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"Failed to create query: {response.status}") + + async def update(self, body): + """Update records in a collection.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"http://{self.connection_uri}/collection/update", + json=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) as response: + if response.status == 200: + return + else: + raise Exception(f"Failed to update query: {response.status}") + + async def delete(self, body): + """Delete records in a collection.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"http://{self.connection_uri}/collection/delete", + json=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) as response: + if response.status == 200: + return + else: + raise Exception(f"Failed to delete query: {response.status}") + + async def aggregate(self, body): + """Aggregate records in a collection.""" + async with aiohttp.ClientSession() as session: + async with session.post( + 
f"http://{self.connection_uri}/collection/aggregate", + json=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"Failed to aggregate query: {response.status}") + + async def get_form(self, body): + """Get the form for an action.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"http://{self.connection_uri}/collection/get-form", + json=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"Failed to aggregate query: {response.status}") + + async def execute(self, body): + """Execute an action.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"http://{self.connection_uri}/collection/execute", + json=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"Failed to execute query: {response.status}") + + async def collection_render_chart(self, body): + """Render a chart.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"http://{self.connection_uri}/render-chart", + json=body, + headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + ) as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"Failed to render chart: {response.status}") diff --git a/src/datasource_rpc/pyproject.toml b/src/datasource_rpc/pyproject.toml index 17371fe8..fb5b92bb 100644 --- a/src/datasource_rpc/pyproject.toml +++ b/src/datasource_rpc/pyproject.toml @@ -20,6 +20,7 @@ typing-extensions = "~=4.2" forestadmin-agent-toolkit = "1.23.0" 
forestadmin-datasource-toolkit = "1.23.0" forestadmin-rpc-common = "1.23.0" +aiohttp-sse-client = ">=0.2.1" [tool.poetry.dependencies."backports.zoneinfo"] version = "~=0.2.1" diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/actions.py b/src/rpc_common/forestadmin/rpc_common/serializers/actions.py index 748b89b0..6f0bdfb7 100644 --- a/src/rpc_common/forestadmin/rpc_common/serializers/actions.py +++ b/src/rpc_common/forestadmin/rpc_common/serializers/actions.py @@ -183,6 +183,8 @@ def deserialize(form: list) -> list[dict]: class ActionFormValuesSerializer: @staticmethod def serialize(form_values: dict) -> dict: + if form_values is None: + return {} ret = {} for key, value in form_values.items(): if isinstance(value, File): diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/aes.py b/src/rpc_common/forestadmin/rpc_common/serializers/aes.py deleted file mode 100644 index 94916d8d..00000000 --- a/src/rpc_common/forestadmin/rpc_common/serializers/aes.py +++ /dev/null @@ -1,22 +0,0 @@ -import base64 - -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes - - -def aes_encrypt(message, aes_key: bytes, aes_iv: bytes): - cipher = Cipher(algorithms.AES(aes_key), modes.CBC(aes_iv), backend=default_backend()) - encryptor = cipher.encryptor() - - # Pad message to 16-byte block size - padded_message = message.ljust(16 * ((len(message) // 16) + 1), "\0") - - encrypted = encryptor.update(padded_message.encode()) + encryptor.finalize() - return base64.b64encode(encrypted).decode() - - -def aes_decrypt(encrypted_message, aes_key: bytes, aes_iv: bytes): - cipher = Cipher(algorithms.AES(aes_key), modes.CBC(aes_iv), backend=default_backend()) - decryptor = cipher.decryptor() - decrypted_padded = decryptor.update(base64.b64decode(encrypted_message)) + decryptor.finalize() - return decrypted_padded.rstrip(b"\0").decode() diff --git 
a/src/rpc_common/forestadmin/rpc_common/serializers/utils.py b/src/rpc_common/forestadmin/rpc_common/serializers/utils.py index b59fbfc5..409dc9f2 100644 --- a/src/rpc_common/forestadmin/rpc_common/serializers/utils.py +++ b/src/rpc_common/forestadmin/rpc_common/serializers/utils.py @@ -114,6 +114,8 @@ def serialize(caller: User) -> Dict: @staticmethod def deserialize(caller: Dict) -> User: + if caller is None: + return None return User( rendering_id=caller["renderingId"], user_id=caller["userId"], From 51f963d4138013b5bff694f9d9e1c0846caa79c0 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 2 Apr 2025 11:53:43 +0200 Subject: [PATCH 27/34] chore: rpc route are the same as ruby --- src/agent_rpc/forestadmin/agent_rpc/agent.py | 62 ++++++++++--------- .../forestadmin/datasource_rpc/collection.py | 16 ++--- .../forestadmin/datasource_rpc/datasource.py | 17 ++--- .../forestadmin/datasource_rpc/requester.py | 56 +++++++++-------- src/rpc_common/forestadmin/rpc_common/hmac.py | 1 + .../serializers/collection/filter.py | 2 +- 6 files changed, 81 insertions(+), 73 deletions(-) diff --git a/src/agent_rpc/forestadmin/agent_rpc/agent.py b/src/agent_rpc/forestadmin/agent_rpc/agent.py index d6d8d40a..765b473a 100644 --- a/src/agent_rpc/forestadmin/agent_rpc/agent.py +++ b/src/agent_rpc/forestadmin/agent_rpc/agent.py @@ -15,7 +15,6 @@ ActionFormValuesSerializer, ActionResultSerializer, ) -from forestadmin.rpc_common.serializers.aes import aes_decrypt, aes_encrypt from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer from forestadmin.rpc_common.serializers.collection.filter import ( FilterSerializer, @@ -42,6 +41,7 @@ def __init__(self, options: RpcOptions): @web.middleware async def hmac_middleware(self, request: web.Request, handler): + # TODO: handle HMAC like ruby agent if request.method == "POST": body = await request.read() if not is_valid_hmac( @@ -51,20 +51,21 @@ async def hmac_middleware(self, request: web.Request, handler): 
return await handler(request) def setup_routes(self): + self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK")) # type: ignore self.app.router.add_route("GET", "/sse", self.sse_handler) self.app.router.add_route("GET", "/schema", self.schema) - self.app.router.add_route("POST", "/collection/list", self.collection_list) - self.app.router.add_route("POST", "/collection/create", self.collection_create) - self.app.router.add_route("POST", "/collection/update", self.collection_update) - self.app.router.add_route("POST", "/collection/delete", self.collection_delete) - self.app.router.add_route("POST", "/collection/aggregate", self.collection_aggregate) - self.app.router.add_route("POST", "/collection/get-form", self.collection_get_form) - self.app.router.add_route("POST", "/collection/execute", self.collection_execute) - self.app.router.add_route("POST", "/collection/render-chart", self.collection_render_chart) - - self.app.router.add_route("POST", "/execute-native-query", self.native_query) - self.app.router.add_route("POST", "/render-chart", self.render_chart) - self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK")) # type: ignore + + # self.app.router.add_route("POST", "/execute-native-query", self.native_query) + self.app.router.add_route("POST", "/forest/rpc/datasource-chart", self.render_chart) + + self.app.router.add_route("POST", "/forest/rpc/{collection_name}/list", self.collection_list) + self.app.router.add_route("POST", "/forest/rpc/{collection_name}/create", self.collection_create) + self.app.router.add_route("POST", "/forest/rpc/{collection_name}/update", self.collection_update) + self.app.router.add_route("POST", "/forest/rpc/{collection_name}/delete", self.collection_delete) + self.app.router.add_route("POST", "/forest/rpc/{collection_name}/aggregate", self.collection_aggregate) + self.app.router.add_route("POST", "/forest/rpc/{collection_name}/action-form", self.collection_get_form) + self.app.router.add_route("POST", 
"/forest/rpc/{collection_name}/action-execute", self.collection_execute) + self.app.router.add_route("POST", "/forest/rpc/{collection_name}/chart", self.collection_render_chart) async def sse_handler(self, request: web.Request) -> web.StreamResponse: async with sse_response(request) as resp: @@ -83,7 +84,7 @@ async def schema(self, request): async def collection_list(self, request: web.Request): body_params = await request.json() ds = await self.customizer.get_datasource() - collection = ds.get_collection(body_params["collectionName"]) + collection = ds.get_collection(request.match_info["collection_name"]) caller = CallerSerializer.deserialize(body_params["caller"]) filter_ = PaginatedFilterSerializer.deserialize(body_params["filter"], collection) # type:ignore projection = ProjectionSerializer.deserialize(body_params["projection"]) @@ -97,7 +98,7 @@ async def collection_create(self, request: web.Request): body_params = json.loads(body_params) ds = await self.customizer.get_datasource() - collection = ds.get_collection(body_params["collectionName"]) + collection = ds.get_collection(request.match_info["collection_name"]) caller = CallerSerializer.deserialize(body_params["caller"]) data = [RecordSerializer.deserialize(r, collection) for r in body_params["data"]] # type:ignore @@ -110,7 +111,7 @@ async def collection_update(self, request: web.Request): body_params = json.loads(body_params) ds = await self.customizer.get_datasource() - collection = ds.get_collection(body_params["collectionName"]) + collection = ds.get_collection(request.match_info["collection_name"]) caller = CallerSerializer.deserialize(body_params["caller"]) filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore patch = RecordSerializer.deserialize(body_params["patch"], collection) # type:ignore @@ -121,7 +122,7 @@ async def collection_update(self, request: web.Request): async def collection_delete(self, request: web.Request): body_params = await request.json() ds = 
await self.customizer.get_datasource() - collection = ds.get_collection(body_params["collectionName"]) + collection = ds.get_collection(request.match_info["collection_name"]) caller = CallerSerializer.deserialize(body_params["caller"]) filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore @@ -131,7 +132,7 @@ async def collection_delete(self, request: web.Request): async def collection_aggregate(self, request: web.Request): body_params = await request.json() ds = await self.customizer.get_datasource() - collection = ds.get_collection(body_params["collectionName"]) + collection = ds.get_collection(request.match_info["collection_name"]) caller = CallerSerializer.deserialize(body_params["caller"]) filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore aggregation = AggregationSerializer.deserialize(body_params["aggregation"]) @@ -144,7 +145,7 @@ async def collection_get_form(self, request: web.Request): body_params = json.loads(body_params) ds = await self.customizer.get_datasource() - collection = ds.get_collection(body_params["collectionName"]) + collection = ds.get_collection(request.match_info["collection_name"]) caller = CallerSerializer.deserialize(body_params["caller"]) action_name = body_params["actionName"] @@ -163,7 +164,7 @@ async def collection_execute(self, request: web.Request): body_params = json.loads(body_params) ds = await self.customizer.get_datasource() - collection = ds.get_collection(body_params["collectionName"]) + collection = ds.get_collection(request.match_info["collection_name"]) caller = CallerSerializer.deserialize(body_params["caller"]) action_name = body_params["actionName"] @@ -181,7 +182,7 @@ async def collection_render_chart(self, request: web.Request): body_params = json.loads(body_params) ds = await self.customizer.get_datasource() - collection = ds.get_collection(body_params["collectionName"]) + collection = ds.get_collection(request.match_info["collection_name"]) 
caller = CallerSerializer.deserialize(body_params["caller"]) name = body_params["name"] @@ -220,17 +221,18 @@ async def render_chart(self, request: web.Request): result = await ds.render_chart(caller, name) return web.json_response(result) - async def native_query(self, request: web.Request): - body_params = await request.text() - body_params = json.loads(body_params) + # TODO: speak about; it's currently not implemented in ruby + # async def native_query(self, request: web.Request): + # body_params = await request.text() + # body_params = json.loads(body_params) - ds = await self.customizer.get_datasource() - connection_name = body_params["connectionName"] - native_query = body_params["nativeQuery"] - parameters = body_params["parameters"] + # ds = await self.customizer.get_datasource() + # connection_name = body_params["connectionName"] + # native_query = body_params["nativeQuery"] + # parameters = body_params["parameters"] - result = await ds.execute_native_query(connection_name, native_query, parameters) - return web.json_response(result) + # result = await ds.execute_native_query(connection_name, native_query, parameters) + # return web.json_response(result) def start(self): web.run_app(self.app, host=self.listen_addr, port=int(self.listen_port)) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/collection.py b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py index fed9725a..d1210037 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/collection.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/collection.py @@ -60,7 +60,7 @@ async def list(self, caller: User, filter_: PaginatedFilter, projection: Project "projection": ProjectionSerializer.serialize(projection), "collectionName": self.name, } - ret = await self.datasource.requester.list(body=body) + ret = await self.datasource.requester.list(self.name, body) return [RecordSerializer.deserialize(record, self) for record in ret] async def create(self, caller: User, data: 
List[Dict[str, Any]]) -> List[Dict[str, Any]]: @@ -69,7 +69,7 @@ async def create(self, caller: User, data: List[Dict[str, Any]]) -> List[Dict[st "data": [RecordSerializer.serialize(r) for r in data], "collectionName": self.name, } - response = await self.datasource.requester.create(body) + response = await self.datasource.requester.create(self.name, body) return [RecordSerializer.deserialize(record, self) for record in response] async def update(self, caller: User, filter_: Optional[Filter], patch: Dict[str, Any]) -> None: @@ -79,7 +79,7 @@ async def update(self, caller: User, filter_: Optional[Filter], patch: Dict[str, "patch": RecordSerializer.serialize(patch), "collectionName": self.name, } - await self.datasource.requester.update(body) + await self.datasource.requester.update(self.name, body) async def delete(self, caller: User, filter_: Filter | None) -> None: body = json.dumps( @@ -89,7 +89,7 @@ async def delete(self, caller: User, filter_: Filter | None) -> None: "collectionName": self.name, } ) - await self.datasource.requester.delete(body) + await self.datasource.requester.delete(self.name, body) async def aggregate( self, caller: User, filter_: Filter | None, aggregation: Aggregation, limit: int | None = None @@ -100,7 +100,7 @@ async def aggregate( "aggregation": AggregationSerializer.serialize(aggregation), "collectionName": self.name, } - response = await self.datasource.requester.aggregate(body) + response = await self.datasource.requester.aggregate(self.name, body) return response async def get_form( @@ -125,7 +125,7 @@ async def get_form( "collectionName": self.name, "actionName": name, } - response = await self.datasource.requester.get_form(body) + response = await self.datasource.requester.get_form(self.name, body) return ActionFormSerializer.deserialize(response) # type: ignore async def execute( @@ -145,7 +145,7 @@ async def execute( "collectionName": self.name, "actionName": name, } - response = await self.datasource.requester.execute(body) + 
response = await self.datasource.requester.execute(self.name, body) return ActionResultSerializer.deserialize(response) async def render_chart(self, caller: User, name: str, record_id: List) -> Chart: @@ -176,7 +176,7 @@ async def render_chart(self, caller: User, name: str, record_id: List) -> Chart: "collectionName": self.name, "recordId": ret, } - return await self.datasource.requester.collection_render_chart(body) + return await self.datasource.requester.collection_render_chart(self.name, body) def get_native_driver(self): ForestLogger.log( diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py index b69dcc66..9fbcd818 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py @@ -104,14 +104,17 @@ async def sse_callback(self): ) self.thread.start() + # TODO: speak about; it's currently not implemented in ruby + # async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any: + # return await self.requester.native_query( + # { + # "connectionName": connection_name, + # "nativeQuery": native_query, + # "parameters": parameters, + # } + # ) async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any: - return await self.requester.native_query( - { - "connectionName": connection_name, - "nativeQuery": native_query, - "parameters": parameters, - } - ) + raise NotImplementedError async def render_chart(self, caller: User, name: str) -> Chart: if name not in self._schema["charts"].keys(): diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/requester.py b/src/datasource_rpc/forestadmin/datasource_rpc/requester.py index 3102b44d..a474f8c1 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/requester.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/requester.py @@ -53,25 +53,27 @@ async def 
schema(self) -> dict: else: raise Exception(f"Failed to get schema: {response.status}") - async def native_query(self, body): - """Execute a native query.""" - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://{self.connection_uri}/execute-native-query", - json=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, - ) as response: - if response.status == 200: - return await response.json() - else: - raise Exception(f"Failed to execute native query: {response.status}") + # TODO: speak about; it's currently not implemented in ruby + # async def native_query(self, body): + # """Execute a native query.""" + # async with aiohttp.ClientSession() as session: + # async with session.post( + # f"http://{self.connection_uri}/execute-native-query", + # json=body, + # headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + # ) as response: + # if response.status == 200: + # return await response.json() + # else: + # raise Exception(f"Failed to execute native query: {response.status}") async def datasource_render_chart(self, body): """Render a chart.""" async with aiohttp.ClientSession() as session: async with session.post( - f"http://{self.connection_uri}/render-chart", + f"http://{self.connection_uri}/forest/rpc/datasource-chart", json=body, + # TODO: handle HMAC like ruby agent headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, ) as response: if response.status == 200: @@ -81,11 +83,11 @@ async def datasource_render_chart(self, body): # methods for collection - async def list(self, body) -> list[dict]: + async def list(self, collection_name, body) -> list[dict]: """List records in a collection.""" async with aiohttp.ClientSession() as session: async with session.post( - f"http://{self.connection_uri}/collection/list", + f"http://{self.connection_uri}/forest/rpc/{collection_name}/list", json=body, 
headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, ) as response: @@ -107,11 +109,11 @@ async def create(self, body) -> List[dict]: else: raise Exception(f"Failed to create query: {response.status}") - async def update(self, body): + async def update(self, collection_name, body): """Update records in a collection.""" async with aiohttp.ClientSession() as session: async with session.post( - f"http://{self.connection_uri}/collection/update", + f"http://{self.connection_uri}/forest/rpc/{collection_name}/update", json=body, headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, ) as response: @@ -120,11 +122,11 @@ async def update(self, body): else: raise Exception(f"Failed to update query: {response.status}") - async def delete(self, body): + async def delete(self, collection_name, body): """Delete records in a collection.""" async with aiohttp.ClientSession() as session: async with session.post( - f"http://{self.connection_uri}/collection/delete", + f"http://{self.connection_uri}/forest/rpc/{collection_name}/delete", json=body, headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, ) as response: @@ -133,11 +135,11 @@ async def delete(self, body): else: raise Exception(f"Failed to delete query: {response.status}") - async def aggregate(self, body): + async def aggregate(self, collection_name, body): """Aggregate records in a collection.""" async with aiohttp.ClientSession() as session: async with session.post( - f"http://{self.connection_uri}/collection/aggregate", + f"http://{self.connection_uri}/forest/rpc/{collection_name}/aggregate", json=body, headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, ) as response: @@ -146,11 +148,11 @@ async def aggregate(self, body): else: raise Exception(f"Failed to aggregate query: {response.status}") - async def get_form(self, body): + async def 
get_form(self, collection_name, body): """Get the form for an action.""" async with aiohttp.ClientSession() as session: async with session.post( - f"http://{self.connection_uri}/collection/get-form", + f"http://{self.connection_uri}/forest/rpc/{collection_name}/action-form", json=body, headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, ) as response: @@ -159,11 +161,11 @@ async def get_form(self, body): else: raise Exception(f"Failed to aggregate query: {response.status}") - async def execute(self, body): + async def execute(self, collection_name, body): """Execute an action.""" async with aiohttp.ClientSession() as session: async with session.post( - f"http://{self.connection_uri}/collection/execute", + f"http://{self.connection_uri}/forest/rpc/{collection_name}/action-execute", json=body, headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, ) as response: @@ -172,11 +174,11 @@ async def execute(self, body): else: raise Exception(f"Failed to execute query: {response.status}") - async def collection_render_chart(self, body): + async def collection_render_chart(self, collection_name, body): """Render a chart.""" async with aiohttp.ClientSession() as session: async with session.post( - f"http://{self.connection_uri}/render-chart", + f"http://{self.connection_uri}/forest/rpc/{collection_name}/chart", json=body, headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, ) as response: diff --git a/src/rpc_common/forestadmin/rpc_common/hmac.py b/src/rpc_common/forestadmin/rpc_common/hmac.py index 7348b5eb..bde00b43 100644 --- a/src/rpc_common/forestadmin/rpc_common/hmac.py +++ b/src/rpc_common/forestadmin/rpc_common/hmac.py @@ -3,6 +3,7 @@ def generate_hmac(key: bytes, message: bytes) -> str: + # TODO: handle HMAC like ruby agent return hmac.new(key, message, hashlib.sha256).hexdigest() diff --git 
a/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py b/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py index 62824c0f..4b7e73f2 100644 --- a/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py +++ b/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, cast +from typing import Any, Dict, List from forestadmin.datasource_toolkit.collections import Collection from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType From 5171eec72750ee16b3649e17eba76b9a50aa4c14 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 7 Apr 2025 11:25:15 +0200 Subject: [PATCH 28/34] chore: simplify PR diff --- .../datasource_toolkit/decorators/search/collections.py | 5 ++++- .../datasource_toolkit/decorators/validation/collection.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py index 857317ed..673c5803 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/search/collections.py @@ -44,7 +44,10 @@ def replace_search(self, replacer: SearchDefinition): self.mark_schema_as_dirty() def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema: - return {**sub_schema, "searchable": self._searchable} + return { + **sub_schema, + "searchable": self._searchable, + } def _default_replacer(self, search: str, extended: bool) -> ConditionTree: searchable_fields = self._get_searchable_fields(self.child_collection, extended) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/validation/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/validation/collection.py index 
25cbe5d1..1be261a6 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/validation/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/validation/collection.py @@ -52,7 +52,9 @@ async def update(self, caller: User, filter: Optional[Filter], patch: RecordsDat def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema: fields_schema = {**sub_schema["fields"]} for name, rules in self.validations.items(): - fields_schema[name]["validations"] = [*fields_schema[name].get("validations", []), *rules] + if fields_schema[name].get("validations") is None: + fields_schema[name]["validations"] = [] + fields_schema[name]["validations"].extend(rules) return {**sub_schema, "fields": fields_schema} From 4018b8333a7303099978933cd7c278d8094b4ff8 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 8 Apr 2025 14:44:45 +0200 Subject: [PATCH 29/34] chore: same hmac and urls as ruby --- src/agent_rpc/forestadmin/agent_rpc/agent.py | 24 ++++---- .../forestadmin/agent_rpc/hmac_middleware.py | 56 +++++++++++++++++++ .../forestadmin/agent_rpc/options.py | 3 +- .../forestadmin/datasource_rpc/requester.py | 36 +++++++----- 4 files changed, 94 insertions(+), 25 deletions(-) create mode 100644 src/agent_rpc/forestadmin/agent_rpc/hmac_middleware.py diff --git a/src/agent_rpc/forestadmin/agent_rpc/agent.py b/src/agent_rpc/forestadmin/agent_rpc/agent.py index 765b473a..e8646736 100644 --- a/src/agent_rpc/forestadmin/agent_rpc/agent.py +++ b/src/agent_rpc/forestadmin/agent_rpc/agent.py @@ -4,12 +4,12 @@ from aiohttp import web from aiohttp_sse import sse_response +from forestadmin.agent_rpc.hmac_middleware import HmacValidationError, HmacValidator from forestadmin.agent_rpc.options import RpcOptions from forestadmin.agent_toolkit.agent import Agent from forestadmin.agent_toolkit.options import Options from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType from forestadmin.datasource_toolkit.utils.schema 
import SchemaUtils -from forestadmin.rpc_common.hmac import is_valid_hmac from forestadmin.rpc_common.serializers.actions import ( ActionFormSerializer, ActionFormValuesSerializer, @@ -35,25 +35,29 @@ def __init__(self, options: RpcOptions): agent_options["server_url"] = "http://fake" agent_options["schema_path"] = "./.forestadmin-schema.json" super().__init__(agent_options) + self.hmac_validator = HmacValidator(options["auth_secret"]) self.app = web.Application(middlewares=[self.hmac_middleware]) self.setup_routes() @web.middleware async def hmac_middleware(self, request: web.Request, handler): - # TODO: handle HMAC like ruby agent - if request.method == "POST": - body = await request.read() - if not is_valid_hmac( - self.options["auth_secret"].encode(), body, request.headers.get("X-FOREST-HMAC", "").encode("utf-8") - ): - return web.Response(status=401, text="Unauthorized from HMAC verification") + # TODO: hmc on SSE ? + if request.url.path in ["/", "/forest/rpc/sse"]: + return await handler(request) + + header_sign = request.headers.get("X_SIGNATURE") + header_timestamp = request.headers.get("X_TIMESTAMP") + try: + self.hmac_validator.validate_hmac(header_sign, header_timestamp) + except HmacValidationError: + return web.Response(status=401, text="Unauthorized from HMAC verification") return await handler(request) def setup_routes(self): self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK")) # type: ignore - self.app.router.add_route("GET", "/sse", self.sse_handler) - self.app.router.add_route("GET", "/schema", self.schema) + self.app.router.add_route("GET", "/forest/rpc/sse", self.sse_handler) + self.app.router.add_route("GET", "/forest/rpc-schema", self.schema) # self.app.router.add_route("POST", "/execute-native-query", self.native_query) self.app.router.add_route("POST", "/forest/rpc/datasource-chart", self.render_chart) diff --git a/src/agent_rpc/forestadmin/agent_rpc/hmac_middleware.py 
b/src/agent_rpc/forestadmin/agent_rpc/hmac_middleware.py new file mode 100644 index 00000000..46bc0808 --- /dev/null +++ b/src/agent_rpc/forestadmin/agent_rpc/hmac_middleware.py @@ -0,0 +1,56 @@ +import datetime + +from forestadmin.datasource_toolkit.exceptions import ForestException +from forestadmin.rpc_common.hmac import generate_hmac, is_valid_hmac + + +class HmacValidationError(ForestException): + pass + + +class HmacValidator: + ALLOWED_TIME_DIFF = 300 + SIGNATURE_REUSE_WINDOW = 5 + + def __init__(self, secret_key: str) -> None: + self.secret_key = secret_key + self.used_signatures = dict() + + def validate_hmac(self, sign, timestamp): + """Validate the HMAC signature.""" + if not sign or not timestamp: + raise HmacValidationError("Missing HMAC signature or timestamp") + + self.validate_timestamp(timestamp) + + expected_sign = generate_hmac(self.secret_key.encode("utf-8"), timestamp.encode("utf-8")) + if not is_valid_hmac(self.secret_key.encode("utf-8"), timestamp.encode("utf-8"), expected_sign.encode("utf-8")): + raise HmacValidationError("Invalid HMAC signature") + + if sign in self.used_signatures.keys(): + last_used = self.used_signatures[sign] + if (datetime.datetime.now(datetime.timezone.utc) - last_used).total_seconds() > self.SIGNATURE_REUSE_WINDOW: + raise HmacValidationError("HMAC signature has already been used") + + self.used_signatures[sign] = datetime.datetime.now(datetime.timezone.utc) + self._cleanup_old_signs() + return True + + def validate_timestamp(self, timestamp): + try: + current_time = datetime.datetime.fromisoformat(timestamp) + except Exception: + raise HmacValidationError("Invalid timestamp format") + + if (datetime.datetime.now(datetime.timezone.utc) - current_time).total_seconds() > self.ALLOWED_TIME_DIFF: + raise HmacValidationError("Timestamp is too old or in the future") + + def _cleanup_old_signs(self): + now = datetime.datetime.now(datetime.timezone.utc) + to_rm = [] + for sign, last_used in self.used_signatures.items(): + if 
(now - last_used).total_seconds() > self.ALLOWED_TIME_DIFF: + to_rm.append(sign) + + for sign in to_rm: + del self.used_signatures[sign] diff --git a/src/agent_rpc/forestadmin/agent_rpc/options.py b/src/agent_rpc/forestadmin/agent_rpc/options.py index f94dbbf5..d3eda88e 100644 --- a/src/agent_rpc/forestadmin/agent_rpc/options.py +++ b/src/agent_rpc/forestadmin/agent_rpc/options.py @@ -3,5 +3,4 @@ class RpcOptions(TypedDict): listen_addr: str - aes_key: bytes - aes_iv: bytes + auth_secret: str diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/requester.py b/src/datasource_rpc/forestadmin/datasource_rpc/requester.py index a474f8c1..102ec9c5 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/requester.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/requester.py @@ -1,4 +1,5 @@ import asyncio +import datetime from typing import List import aiohttp @@ -13,12 +14,19 @@ def __init__(self, connection_uri: str, secret_key: str): self.aes_key = secret_key[:16].encode() self.aes_iv = secret_key[-16:].encode() + def mk_hmac_headers(self): + timestamp = datetime.datetime.now(datetime.UTC).isoformat() + return { + "X_SIGNATURE": generate_hmac(self.secret_key.encode("utf-8"), timestamp.encode("utf-8")), + "X_TIMESTAMP": timestamp, + } + async def sse_connect(self, callback): """Connect to the SSE stream.""" await self.wait_for_connection() async with sse_client.EventSource( - f"http://{self.connection_uri}/sse", + f"http://{self.connection_uri}/forest/rpc/sse", ) as event_source: try: async for event in event_source: @@ -47,7 +55,10 @@ async def wait_for_connection(self): async def schema(self) -> dict: """Get the schema of the datasource.""" async with aiohttp.ClientSession() as session: - async with session.get(f"http://{self.connection_uri}/schema") as response: + async with session.get( + f"http://{self.connection_uri}/forest/rpc-schema", + headers=self.mk_hmac_headers(), + ) as response: if response.status == 200: return await response.json() else: @@ 
-60,7 +71,7 @@ async def schema(self) -> dict: # async with session.post( # f"http://{self.connection_uri}/execute-native-query", # json=body, - # headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + # headers=self.mk_hmac_headers(), # ) as response: # if response.status == 200: # return await response.json() @@ -73,8 +84,7 @@ async def datasource_render_chart(self, body): async with session.post( f"http://{self.connection_uri}/forest/rpc/datasource-chart", json=body, - # TODO: handle HMAC like ruby agent - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + headers=self.mk_hmac_headers(), ) as response: if response.status == 200: return await response.json() @@ -89,7 +99,7 @@ async def list(self, collection_name, body) -> list[dict]: async with session.post( f"http://{self.connection_uri}/forest/rpc/{collection_name}/list", json=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + headers=self.mk_hmac_headers(), ) as response: if response.status == 200: return await response.json() @@ -102,7 +112,7 @@ async def create(self, body) -> List[dict]: async with session.post( f"http://{self.connection_uri}/collection/create", json=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + headers=self.mk_hmac_headers(), ) as response: if response.status == 200: return await response.json() @@ -115,7 +125,7 @@ async def update(self, collection_name, body): async with session.post( f"http://{self.connection_uri}/forest/rpc/{collection_name}/update", json=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + headers=self.mk_hmac_headers(), ) as response: if response.status == 200: return @@ -128,7 +138,7 @@ async def delete(self, collection_name, body): async with session.post( 
f"http://{self.connection_uri}/forest/rpc/{collection_name}/delete", json=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + headers=self.mk_hmac_headers(), ) as response: if response.status == 200: return @@ -141,7 +151,7 @@ async def aggregate(self, collection_name, body): async with session.post( f"http://{self.connection_uri}/forest/rpc/{collection_name}/aggregate", json=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + headers=self.mk_hmac_headers(), ) as response: if response.status == 200: return await response.json() @@ -154,7 +164,7 @@ async def get_form(self, collection_name, body): async with session.post( f"http://{self.connection_uri}/forest/rpc/{collection_name}/action-form", json=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + headers=self.mk_hmac_headers(), ) as response: if response.status == 200: return await response.json() @@ -167,7 +177,7 @@ async def execute(self, collection_name, body): async with session.post( f"http://{self.connection_uri}/forest/rpc/{collection_name}/action-execute", json=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + headers=self.mk_hmac_headers(), ) as response: if response.status == 200: return await response.json() @@ -180,7 +190,7 @@ async def collection_render_chart(self, collection_name, body): async with session.post( f"http://{self.connection_uri}/forest/rpc/{collection_name}/chart", json=body, - headers={"X-FOREST-HMAC": generate_hmac(self.secret_key.encode("utf-8"), body.encode("utf-8"))}, + headers=self.mk_hmac_headers(), ) as response: if response.status == 200: return await response.json() From 93f2872ea2fc16ca46d742e3a4a8029719f2f51d Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 8 Apr 2025 14:44:58 +0200 Subject: [PATCH 30/34] chore: remove old files --- 
.../agent_rpc/services/__init__.py | 0 .../agent_rpc/services/datasource.py | 67 ------------------- 2 files changed, 67 deletions(-) delete mode 100644 src/agent_rpc/forestadmin/agent_rpc/services/__init__.py delete mode 100644 src/agent_rpc/forestadmin/agent_rpc/services/datasource.py diff --git a/src/agent_rpc/forestadmin/agent_rpc/services/__init__.py b/src/agent_rpc/forestadmin/agent_rpc/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/agent_rpc/forestadmin/agent_rpc/services/datasource.py b/src/agent_rpc/forestadmin/agent_rpc/services/datasource.py deleted file mode 100644 index 3fc898d9..00000000 --- a/src/agent_rpc/forestadmin/agent_rpc/services/datasource.py +++ /dev/null @@ -1,67 +0,0 @@ -import json -from enum import Enum -from typing import Any, List, Mapping - -from forestadmin.datasource_toolkit.datasources import Datasource -from forestadmin.datasource_toolkit.interfaces.fields import is_column -from forestadmin.rpc_common.proto import datasource_pb2, datasource_pb2_grpc, forest_pb2 - - -class DatasourceService(datasource_pb2_grpc.DataSourceServicer): - def __init__(self, datasource: Datasource): - self.datasource = datasource - super().__init__() - - def Schema(self, request, context) -> datasource_pb2.SchemaResponse: - collections = [] - for collection in self.datasource.collections: - fields: List[Mapping[str, Any]] = [] - for field_name, field_schema in collection.schema["fields"].items(): - field: Mapping[str, Any] = {**field_schema} - - field["type"] = field["type"].value if isinstance(field["type"], Enum) else field["type"] - - if is_column(field_schema): - field["column_type"] = ( - field["column_type"].value if isinstance(field["column_type"], Enum) else field["column_type"] - ) - field["filter_operators"] = [ - op.value if isinstance(op, Enum) else op for op in field["filter_operators"] - ] - field["validations"] = [ - { - **validation, - "operator": ( - validation["operator"].value - if 
isinstance(validation["operator"], Enum) - else validation["operator"] - ), - } - for validation in field["validations"] - ] - else: - continue - fields.append({field_name: json.dumps(field).encode("utf-8")}) - - try: - collections.append( - forest_pb2.CollectionSchema( - name=collection.name, - searchable=collection.schema["searchable"], - segments=collection.schema["segments"], - countable=collection.schema["countable"], - charts=collection.schema["charts"], - # fields=fields, - # actions=actions, - ) - ) - except Exception as e: - print(e) - raise e - return datasource_pb2.SchemaResponse( - Collections=collections, - Schema=forest_pb2.DataSourceSchema( - Charts=self.datasource.schema["charts"], - ConnectionNames=self.datasource.get_native_query_connections(), - ), - ) From 42d5474a39dae728a2d878894e6b89b392b104fc Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 8 Apr 2025 14:45:50 +0200 Subject: [PATCH 31/34] chore: be concilient if no live query connections --- src/datasource_rpc/forestadmin/datasource_rpc/datasource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py index 9fbcd818..fee0b978 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py +++ b/src/datasource_rpc/forestadmin/datasource_rpc/datasource.py @@ -61,7 +61,7 @@ async def introspect(self) -> bool: self.create_collection(collection_name, collection) self._schema["charts"] = {name: None for name in schema_data["charts"]} # type: ignore - self._live_query_connections = schema_data["live_query_connections"] + self._live_query_connections = schema_data.get("live_query_connections", []) return True def create_collection(self, collection_name, collection_schema): From efa4124291133b29f5a5565d103bd718717ae1b3 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 8 Apr 2025 14:46:12 +0200 Subject: [PATCH 32/34] chore: missing attribute on 
backup stack --- .../forestadmin/datasource_toolkit/decorators/decorator_stack.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py index 69f5fc1d..d0f29517 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py @@ -97,6 +97,7 @@ def _backup_stack(self): "binary", "publication", "rename_field", + "datasource", ]: backup_stack[decorator_name] = getattr(self, decorator_name) return backup_stack From fd7687ff659e4302bb67eefeb3d7b6f2358a8c29 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 8 Apr 2025 14:46:38 +0200 Subject: [PATCH 33/34] chore(rpc_serilization): use same case as ruby --- .../rpc_common/serializers/actions.py | 87 +++++------ .../serializers/collection/filter.py | 10 +- .../rpc_common/serializers/schema/schema.py | 136 +++++++++--------- .../rpc_common/serializers/utils.py | 25 ++-- 4 files changed, 130 insertions(+), 128 deletions(-) diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/actions.py b/src/rpc_common/forestadmin/rpc_common/serializers/actions.py index 6f0bdfb7..585c6107 100644 --- a/src/rpc_common/forestadmin/rpc_common/serializers/actions.py +++ b/src/rpc_common/forestadmin/rpc_common/serializers/actions.py @@ -13,7 +13,7 @@ SuccessResult, WebHookResult, ) -from forestadmin.rpc_common.serializers.utils import camel_to_snake_case, enum_to_str_or_value, snake_to_camel_case +from forestadmin.rpc_common.serializers.utils import enum_to_str_or_value class ActionSerializer: @@ -31,22 +31,23 @@ async def serialize(action_name: str, collection: Collection) -> dict: # ] return { - "scope": action.scope.value, - "generateFile": action.generate_file or False, - "staticForm": action.static_form or False, + "scope": 
action.scope.value.lower(), + "is_generate_file": action.generate_file or False, + "static_form": action.static_form or False, "description": action.description, - "submitButtonLabel": action.submit_button_label, + "submit_button_label": action.submit_button_label, "form": ActionFormSerializer.serialize(form) if form is not None else [], + "execute": {}, # TODO: I'm pretty sure this is not needed } @staticmethod def deserialize(action: dict) -> dict: return { - "scope": ActionsScope(action["scope"]), - "generate_file": action["generateFile"], - "static_form": action["staticForm"], + "scope": ActionsScope(action["scope"].capitalize()), + "generate_file": action["is_generate_file"], + "static_form": action["static_form"], "description": action["description"], - "submit_button_label": action["submitButtonLabel"], + "submit_button_label": action["submit_button_label"], "form": ActionFormSerializer.deserialize(action["form"]) if action["form"] is not None else [], } @@ -57,35 +58,35 @@ def serialize(form) -> list[dict]: serialized_form = [] for field in form: + tmp_field = {} if field["type"] == ActionFieldType.LAYOUT: if field["component"] == "Page": - serialized_form.append( - { - **field, - "type": "Layout", - "elements": ActionFormSerializer.serialize(field["elements"]), - } - ) + tmp_field = { + **field, + "type": "Layout", + "elements": ActionFormSerializer.serialize(field["elements"]), + } if field["component"] == "Row": - serialized_form.append( - { - **field, - "type": "Layout", - "fields": ActionFormSerializer.serialize(field["fields"]), - } - ) + tmp_field = { + **field, + "type": "Layout", + "fields": ActionFormSerializer.serialize(field["fields"]), + } else: tmp_field = { - **{snake_to_camel_case(k): v for k, v in field.items()}, + **{k: v for k, v in field.items()}, "type": enum_to_str_or_value(field["type"]), } if field["type"] == ActionFieldType.FILE: tmp_field.update(ActionFormSerializer._serialize_file_field(tmp_field)) - print(tmp_field.keys()) if 
field["type"] == ActionFieldType.FILE_LIST: tmp_field.update(ActionFormSerializer._serialize_file_list_field(tmp_field)) - serialized_form.append(tmp_field) + + if "if_" in tmp_field: + tmp_field["if_condition"] = tmp_field["if_"] + del tmp_field["if_"] + serialized_form.append(tmp_field) return serialized_form @@ -105,7 +106,6 @@ def _serialize_file_field(field: dict) -> dict: ret["defaultValue"] = ActionFormSerializer._serialize_file_obj(field["defaultValue"]) if field.get("value"): ret["value"] = ActionFormSerializer._serialize_file_obj(field["value"]) - return ret @staticmethod def _serialize_file_obj(f: File) -> dict: @@ -148,34 +148,35 @@ def deserialize(form: list) -> list[dict]: deserialized_form = [] for field in form: + tmp_field = {} if field["type"] == "Layout": if field["component"] == "Page": - deserialized_form.append( - { - **field, - "type": ActionFieldType("Layout"), - "elements": ActionFormSerializer.deserialize(field["elements"]), - } - ) + tmp_field = { + **field, + "type": ActionFieldType("Layout"), + "elements": ActionFormSerializer.deserialize(field["elements"]), + } if field["component"] == "Row": - deserialized_form.append( - { - **field, - "type": ActionFieldType("Layout"), - "fields": ActionFormSerializer.deserialize(field["fields"]), - } - ) + tmp_field = { + **field, + "type": ActionFieldType("Layout"), + "fields": ActionFormSerializer.deserialize(field["fields"]), + } else: tmp_field = { - **{camel_to_snake_case(k): v for k, v in field.items()}, + **{k: v for k, v in field.items()}, "type": ActionFieldType(field["type"]), } if tmp_field["type"] == ActionFieldType.FILE: tmp_field.update(ActionFormSerializer._deserialize_file_field(tmp_field)) if tmp_field["type"] == ActionFieldType.FILE_LIST: tmp_field.update(ActionFormSerializer._deserialize_file_list_field(tmp_field)) - deserialized_form.append(tmp_field) + + if "if_condition" in tmp_field: + tmp_field["if_"] = tmp_field["if_condition"] + del tmp_field["if_condition"] + 
deserialized_form.append(tmp_field) return deserialized_form diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py b/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py index 4b7e73f2..29ffc7d9 100644 --- a/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py +++ b/src/rpc_common/forestadmin/rpc_common/serializers/collection/filter.py @@ -89,13 +89,13 @@ class FilterSerializer: @staticmethod def serialize(filter_: Filter, collection: Collection) -> Dict: return { - "conditionTree": ( + "condition_tree": ( ConditionTreeSerializer.serialize(filter_.condition_tree, collection) if filter_.condition_tree is not None else None ), "search": filter_.search, - "searchExtended": filter_.search_extended, + "search_extended": filter_.search_extended, "segment": filter_.segment, "timezone": TimezoneSerializer.serialize(filter_.timezone), } @@ -105,12 +105,12 @@ def deserialize(filter_: Dict, collection: Collection) -> Filter: return Filter( { "condition_tree": ( - ConditionTreeSerializer.deserialize(filter_["conditionTree"], collection) - if filter_.get("conditionTree") is not None + ConditionTreeSerializer.deserialize(filter_["condition_tree"], collection) + if filter_.get("condition_tree") is not None else None ), # type: ignore "search": filter_["search"], - "search_extended": filter_["searchExtended"], + "search_extended": filter_["search_extended"], "segment": filter_["segment"], "timezone": TimezoneSerializer.deserialize(filter_["timezone"]), } diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py b/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py index 8d644a7d..e5806cf0 100644 --- a/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py +++ b/src/rpc_common/forestadmin/rpc_common/serializers/schema/schema.py @@ -50,21 +50,21 @@ def __init__(self, datasource: Datasource) -> None: self.datasource = datasource async def serialize(self) -> dict: - value 
= { - "version": self.VERSION, - "data": { + value = ( + { "charts": sorted(self.datasource.schema["charts"].keys()), "live_query_connections": sorted(self.datasource.get_native_query_connections()), - "collections": { - c.name: await self._serialize_collection(c) + "collections": [ + await self._serialize_collection(c) for c in sorted(self.datasource.collections, key=lambda c: c.name) - }, + ], }, - } + ) return value async def _serialize_collection(self, collection: Collection) -> dict: return { + "name": collection.name, "fields": { field: await self._serialize_field(collection.schema["fields"][field]) for field in sorted(collection.schema["fields"].keys()) @@ -88,13 +88,13 @@ async def _serialize_field(self, field: FieldAlias) -> dict: async def _serialize_column(self, column: Column) -> dict: return { "type": "Column", - "columnType": (serialize_column_type(column["column_type"])), - "filterOperators": sorted([OperatorSerializer.serialize(op) for op in column.get("filter_operators", [])]), - "defaultValue": column.get("default_value"), - "enumValues": column.get("enum_values"), - "isPrimaryKey": column.get("is_primary_key", False), - "isReadOnly": column.get("is_read_only", False), - "isSortable": column.get("is_sortable", False), + "column_type": (serialize_column_type(column["column_type"])), + "filter_operators": sorted([OperatorSerializer.serialize(op) for op in column.get("filter_operators", [])]), + "default_value": column.get("default_value"), + "enum_values": column.get("enum_values"), + "is_primary_key": column.get("is_primary_key", False), + "is_read_only": column.get("is_read_only", False), + "is_sortable": column.get("is_sortable", False), "validations": [ { "operator": OperatorSerializer.serialize(v["operator"]), @@ -113,51 +113,51 @@ async def _serialize_relation(self, relation: RelationAlias) -> dict: serialized_field.update( { "type": "PolymorphicManyToOne", - "foreignCollections": relation["foreign_collections"], - "foreignKey": 
relation["foreign_key"], - "foreignKeyTypeField": relation["foreign_key_type_field"], - "foreignKeyTargets": relation["foreign_key_targets"], + "foreign_collections": relation["foreign_collections"], + "foreign_key": relation["foreign_key"], + "foreign_key_type_field": relation["foreign_key_type_field"], + "foreign_key_targets": relation["foreign_key_targets"], } ) elif is_many_to_one(relation): serialized_field.update( { "type": "ManyToOne", - "foreignCollection": relation["foreign_collection"], - "foreignKey": relation["foreign_key"], - "foreignKeyTarget": relation["foreign_key_target"], + "foreign_collection": relation["foreign_collection"], + "foreign_key": relation["foreign_key"], + "foreign_key_target": relation["foreign_key_target"], } ) elif is_one_to_one(relation) or is_one_to_many(relation): serialized_field.update( { "type": "OneToMany" if is_one_to_many(relation) else "OneToOne", - "foreignCollection": relation["foreign_collection"], - "originKey": relation["origin_key"], - "originKeyTarget": relation["origin_key_target"], + "foreign_collection": relation["foreign_collection"], + "origin_key": relation["origin_key"], + "origin_key_target": relation["origin_key_target"], } ) elif is_polymorphic_one_to_one(relation) or is_polymorphic_one_to_many(relation): serialized_field.update( { "type": "PolymorphicOneToMany" if is_polymorphic_one_to_many(relation) else "PolymorphicOneToOne", - "foreignCollection": relation["foreign_collection"], - "originKey": relation["origin_key"], - "originKeyTarget": relation["origin_key_target"], - "originTypeField": relation["origin_type_field"], - "originTypeValue": relation["origin_type_value"], + "foreign_collection": relation["foreign_collection"], + "origin_key": relation["origin_key"], + "origin_key_target": relation["origin_key_target"], + "origin_type_field": relation["origin_type_field"], + "origin_type_value": relation["origin_type_value"], } ) elif is_many_to_many(relation): serialized_field.update( { "type": 
"ManyToMany", - "foreignCollections": relation["foreign_collection"], - "foreignKey": relation["foreign_key"], - "foreignKeyTargets": relation["foreign_key_target"], - "originKey": relation["origin_key"], - "originKeyTarget": relation["origin_key_target"], - "throughCollection": relation["through_collection"], + "foreign_collections": relation["foreign_collection"], + "foreign_key": relation["foreign_key"], + "foreign_key_targets": relation["foreign_key_target"], + "origin_key": relation["origin_key"], + "origin_key_target": relation["origin_key_target"], + "through_collection": relation["through_collection"], } ) return serialized_field @@ -167,15 +167,11 @@ class SchemaDeserializer: VERSION = "1.0" def deserialize(self, schema) -> dict: - if schema["version"] != self.VERSION: - raise ValueError(f"Unsupported schema version {schema['version']}") - return { - "charts": schema["data"]["charts"], - "live_query_connections": schema["data"]["live_query_connections"], + "charts": schema["charts"], + # "live_query_connections": schema["data"]["live_query_connections"], "collections": { - name: self._deserialize_collection(collection) - for name, collection in schema["data"]["collections"].items() + collection["name"]: self._deserialize_collection(collection) for collection in schema["collections"] }, } @@ -201,13 +197,13 @@ def _deserialize_field(self, field: dict) -> dict: def _deserialize_column(self, column: dict) -> dict: return { "type": FieldType.COLUMN, - "column_type": deserialize_column_type(column["columnType"]), - "filter_operators": [OperatorSerializer.deserialize(op) for op in column["filterOperators"]], - "default_value": column.get("defaultValue"), - "enum_values": column.get("enumValues"), - "is_primary_key": column.get("isPrimaryKey", False), - "is_read_only": column.get("isReadOnly", False), - "is_sortable": column.get("isSortable", False), + "column_type": deserialize_column_type(column["column_type"]), + "filter_operators": 
[OperatorSerializer.deserialize(op) for op in column["filter_operators"]], + "default_value": column.get("default_value"), + "enum_values": column.get("enum_values"), + "is_primary_key": column.get("is_primary_key", False), + "is_read_only": column.get("is_read_only", False), + "is_sortable": column.get("is_sortable", False), "validations": [ { "operator": OperatorSerializer.deserialize(op["operator"]), @@ -221,24 +217,24 @@ def _deserialize_relation(self, relation: dict) -> dict: if relation["type"] == "PolymorphicManyToOne": return { "type": FieldType("PolymorphicManyToOne"), - "foreign_collections": relation["foreignCollections"], - "foreign_key": relation["foreignKey"], - "foreign_key_type_field": relation["foreignKeyTypeField"], - "foreign_key_targets": relation["foreignKeyTargets"], + "foreign_collections": relation["foreign_collections"], + "foreign_key": relation["foreign_key"], + "foreign_key_type_field": relation["foreign_key_type_field"], + "foreign_key_targets": relation["foreign_key_targets"], } elif relation["type"] == "ManyToOne": return { "type": FieldType("ManyToOne"), - "foreign_collection": relation["foreignCollection"], - "foreign_key": relation["foreignKey"], - "foreign_key_target": relation["foreignKeyTarget"], + "foreign_collection": relation["foreign_collection"], + "foreign_key": relation["foreign_key"], + "foreign_key_target": relation["foreign_key_target"], } elif relation["type"] in ["OneToMany", "OneToOne"]: return { "type": FieldType("OneToOne") if relation["type"] == "OneToOne" else FieldType("OneToMany"), - "foreign_collection": relation["foreignCollection"], - "origin_key": relation["originKey"], - "origin_key_target": relation["originKeyTarget"], + "foreign_collection": relation["foreign_collection"], + "origin_key": relation["origin_key"], + "origin_key_target": relation["origin_key_target"], } elif relation["type"] in ["PolymorphicOneToMany", "PolymorphicOneToOne"]: return { @@ -247,20 +243,20 @@ def _deserialize_relation(self, 
relation: dict) -> dict: if relation["type"] == "PolymorphicOneToOne" else FieldType("PolymorphicOneToMany") ), - "foreign_collection": relation["foreignCollection"], - "origin_key": relation["originKey"], - "origin_key_target": relation["originKeyTarget"], - "origin_type_field": relation["originTypeField"], - "origin_type_value": relation["originTypeValue"], + "foreign_collection": relation["foreign_collection"], + "origin_key": relation["origin_key"], + "origin_key_target": relation["origin_key_target"], + "origin_type_field": relation["origin_type_field"], + "origin_type_value": relation["origin_type_value"], } elif relation["type"] == "ManyToMany": return { "type": FieldType("ManyToMany"), - "foreign_collection": relation["foreignCollections"], - "foreign_key": relation["foreignKey"], - "foreign_key_target": relation["foreignKeyTargets"], - "origin_key": relation["originKey"], - "origin_key_target": relation["originKeyTarget"], - "through_collection": relation["throughCollection"], + "foreign_collection": relation["foreign_collections"], + "foreign_key": relation["foreign_key"], + "foreign_key_target": relation["foreign_key_targets"], + "origin_key": relation["origin_key"], + "origin_key_target": relation["origin_key_target"], + "through_collection": relation["through_collection"], } raise ValueError(f"Unsupported relation type {relation['type']}") diff --git a/src/rpc_common/forestadmin/rpc_common/serializers/utils.py b/src/rpc_common/forestadmin/rpc_common/serializers/utils.py index 409dc9f2..ee3a1a63 100644 --- a/src/rpc_common/forestadmin/rpc_common/serializers/utils.py +++ b/src/rpc_common/forestadmin/rpc_common/serializers/utils.py @@ -68,7 +68,11 @@ class OperatorSerializer: @staticmethod def deserialize(operator: str) -> Operator: - return Operator(camel_to_snake_case(operator)) + _operator = operator + if operator.startswith("i_"): + _operator = operator[2:] + return Operator(_operator) + # return Operator(camel_to_snake_case(operator)) # for value, 
serialized in OperatorSerializer.OPERATOR_MAPPING.items(): # if serialized == operator: # return Operator(value) @@ -79,7 +83,8 @@ def serialize(operator: Union[Operator, str]) -> str: value = operator if isinstance(value, Enum): value = value.value - return snake_to_camel_case(value) + return value + # return snake_to_camel_case(value) # return OperatorSerializer.OPERATOR_MAPPING[value] @@ -101,12 +106,12 @@ class CallerSerializer: @staticmethod def serialize(caller: User) -> Dict: return { - "renderingId": caller.rendering_id, - "userId": caller.user_id, + "rendering_id": caller.rendering_id, + "user_id": caller.user_id, "tags": caller.tags, "email": caller.email, - "firstName": caller.first_name, - "lastName": caller.last_name, + "first_name": caller.first_name, + "last_name": caller.last_name, "team": caller.team, "timezone": TimezoneSerializer.serialize(caller.timezone), "request": {"ip": caller.request["ip"]}, @@ -117,12 +122,12 @@ def deserialize(caller: Dict) -> User: if caller is None: return None return User( - rendering_id=caller["renderingId"], - user_id=caller["userId"], + rendering_id=caller["rendering_id"], + user_id=caller["user_id"], tags=caller["tags"], email=caller["email"], - first_name=caller["firstName"], - last_name=caller["lastName"], + first_name=caller["first_name"], + last_name=caller["last_name"], team=caller["team"], timezone=TimezoneSerializer.deserialize(caller["timezone"]), # type:ignore request={"ip": caller["request"]["ip"]}, From 818d283f4759f4535048c8d04c4e9f29266282be Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 9 Apr 2025 14:11:41 +0200 Subject: [PATCH 34/34] chore: cleanup old code --- src/datasource_rpc/forestadmin/datasource_rpc/requester.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/datasource_rpc/forestadmin/datasource_rpc/requester.py b/src/datasource_rpc/forestadmin/datasource_rpc/requester.py index 102ec9c5..de2989ff 100644 --- a/src/datasource_rpc/forestadmin/datasource_rpc/requester.py +++ 
b/src/datasource_rpc/forestadmin/datasource_rpc/requester.py @@ -11,8 +11,6 @@ class RPCRequester: def __init__(self, connection_uri: str, secret_key: str): self.connection_uri = connection_uri self.secret_key = secret_key - self.aes_key = secret_key[:16].encode() - self.aes_iv = secret_key[-16:].encode() def mk_hmac_headers(self): timestamp = datetime.datetime.now(datetime.UTC).isoformat()