diff --git a/.gitignore b/.gitignore
index f4c93f5d7..df719ac28 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,7 +14,7 @@
 dist/
 build/
 *.egg-info/
-/.dmod_client_config.yml
+/.dmod_client_config.json
 
 /python/pydevd-pycharm.egg
 
diff --git a/docker/nwm_gui/app_server/Dockerfile b/docker/nwm_gui/app_server/Dockerfile
index 10da982a5..343c812f9 100644
--- a/docker/nwm_gui/app_server/Dockerfile
+++ b/docker/nwm_gui/app_server/Dockerfile
@@ -46,15 +46,36 @@ COPY ./python/gui/MaaS ./MaaS
 COPY ./docker/nwm_gui/app_server/entrypoint.sh ./
 COPY ./docker/nwm_gui/app_server/client_debug_helper.py ./
 
-RUN echo "request-service:" > .dmod_client_config.yml \
-    && echo "  hostname: 'request-service'" >> .dmod_client_config.yml \
-    && echo "  port: 3012" >> .dmod_client_config.yml \
-    && echo "  ssl-dir: '/usr/maas_portal/ssl'" >> .dmod_client_config.yml
+ARG PYCHARM_REMOTE_DEBUG_HOST
+ARG PYCHARM_REMOTE_DEBUG_PORT
+ARG REQUEST_SERVICE_PORT
+
+ENV PYCHARM_DEBUG_EGG=/pydevd-pycharm.egg
+ENV REQ_SRV_SSL_DIR=${WORKDIR}/request_service_ssl
+# TODO: needs to be fixed ... doesn't mesh with configurability of location
+COPY ./ssl/request-service ${REQ_SRV_SSL_DIR}
+
+# TODO: move to heredoc syntax once confirmed it's reasonable to expect all environments run sufficiently recent Docker
+RUN echo '{' > .dmod_client_config.json \
+    && if [ -n "${PYCHARM_DEBUG_EGG:-}" ] && [ -n "${PYCHARM_REMOTE_DEBUG_PORT:-}" ] && [ -n "${PYCHARM_REMOTE_DEBUG_HOST:-}" ]; then \
+        echo '  "remote-debug": {' >> .dmod_client_config.json ; \
+        echo "    \"egg-path\": \"${PYCHARM_DEBUG_EGG:?}\"," >> .dmod_client_config.json ; \
+        echo "    \"port\": ${PYCHARM_REMOTE_DEBUG_PORT:?}," >> .dmod_client_config.json ; \
+        echo "    \"host\": \"${PYCHARM_REMOTE_DEBUG_HOST:?}\"" >> .dmod_client_config.json ; \
+        echo '  },' >> .dmod_client_config.json ; \
+    fi \
+    && echo '  "request-service": {' >> .dmod_client_config.json \
+    && echo '    "protocol": "wss",' >> .dmod_client_config.json \
+    && echo "    \"pem\": \"${REQ_SRV_SSL_DIR}/certificate.pem\"," >> .dmod_client_config.json \
+    && echo "    \"port\": ${REQUEST_SERVICE_PORT:?}," >> .dmod_client_config.json \
+    && echo '    "hostname": "request-service"' >> .dmod_client_config.json \
+    && echo '  }' >> .dmod_client_config.json \
+    && echo '}' >> .dmod_client_config.json
 
 # TODO: when image tagging/versioning is improved, look at keeping this in a "debug" image only
 # Copy this to have access to debugging pydevd egg
 COPY --from=sources /dmod /dmod_src
-RUN if [ -e /dmod_src/python/pydevd-pycharm.egg ]; then mv /dmod_src/python/pydevd-pycharm.egg /. ; fi \
+RUN if [ -e /dmod_src/python/pydevd-pycharm.egg ]; then mv /dmod_src/python/pydevd-pycharm.egg ${PYCHARM_DEBUG_EGG} ; fi \
     && rm -rf /dmod_src
 
 # Set the entry point so that it is run every time the container is started
diff --git a/docker/nwm_gui/docker-compose.yml b/docker/nwm_gui/docker-compose.yml
index 826b41e5d..55d085a34 100644
--- a/docker/nwm_gui/docker-compose.yml
+++ b/docker/nwm_gui/docker-compose.yml
@@ -34,6 +34,10 @@ services:
       args:
         docker_internal_registry: ${DOCKER_INTERNAL_REGISTRY:?Missing DOCKER_INTERNAL_REGISTRY value (see 'Private Docker Registry ' section in example.env)}
         comms_package_name: ${PYTHON_PACKAGE_DIST_NAME_COMMS:?}
+        # Necessary to generate the CLI's ClientConfig to support remote debugging
+        PYCHARM_REMOTE_DEBUG_HOST: ${PYCHARM_REMOTE_DEBUG_SERVER_HOST:-host.docker.internal}
+        PYCHARM_REMOTE_DEBUG_PORT: ${PYCHARM_REMOTE_DEBUG_SERVER_PORT_GUI:-55875}
+        REQUEST_SERVICE_PORT: ${DOCKER_REQUESTS_CONTAINER_PORT:-3012}
     networks:
       - request-listener-net
     # Call this when starting the container
diff --git a/example.env b/example.env
index d657f6287..160e604f6 100644
--- a/example.env
+++ b/example.env
@@ -168,8 +168,19 @@ PYTHON_PACKAGE_NAME_DATA_SERVICE=dmod.dataservice
 ## server at the level of individual (supported) services.           ##
 ########################################################################
 PYCHARM_REMOTE_DEBUG_VERSION=~=211.7628.24
-## Right now there is only support for the GUI app service
+## Flag to indicate whether remote Pycharm debugging should be active
+## in the GUI image and service.
 PYCHARM_REMOTE_DEBUG_GUI_ACTIVE=false
+## Flag for whether remote Pycharm debugging is active for data service
+PYCHARM_REMOTE_DEBUG_DATA_SERVICE_ACTIVE=false
+## Flag for whether remote Pycharm debugging is active for evaluation service
+PYCHARM_REMOTE_DEBUG_EVALUATION_SERVICE_ACTIVE=false
+## Flag for whether remote Pycharm debugging is active for partitioner service
+PYCHARM_REMOTE_DEBUG_PARTITIONER_SERVICE_ACTIVE=false
+## Flag for whether remote Pycharm debugging is active for request service
+PYCHARM_REMOTE_DEBUG_REQUEST_SERVICE_ACTIVE=false
+## Flag for whether remote Pycharm debugging is active for scheduler service
+PYCHARM_REMOTE_DEBUG_SCHEDULER_SERVICE_ACTIVE=false
 ## The debug server host for the debugged Python processes to attach to
 PYCHARM_REMOTE_DEBUG_SERVER_HOST=host.docker.internal
 ## The remote debug server port to for debugging request-service
diff --git a/python/lib/client/dmod/client/__main__.py b/python/lib/client/dmod/client/__main__.py
index da72481dd..91982ae80 100644
--- a/python/lib/client/dmod/client/__main__.py
+++ b/python/lib/client/dmod/client/__main__.py
@@ -3,14 +3,15 @@ import json
 from dmod.core.execution import AllocationParadigm
 from . import name as package_name
-from .dmod_client import YamlClientConfig, DmodClient
+from .dmod_client import ClientConfig, DmodClient
 from dmod.communication.client import get_or_create_eventloop
 from dmod.core.meta_data import ContinuousRestriction, DataCategory, DataDomain, DataFormat, DiscreteRestriction, \
     TimeRange
 from pathlib import Path
-from typing import Any, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Type
 
-DEFAULT_CLIENT_CONFIG_BASENAME = '.dmod_client_config.yml'
+
+DEFAULT_CLIENT_CONFIG_BASENAME = '.dmod_client_config.json'
 
 
 class DmodCliArgumentError(ValueError):
@@ -51,6 +52,7 @@ def _create_ngen_based_exec_parser(subcommand_container: Any, parser_name: str,
                             default=default_alloc_paradigm, help='Specify job resource allocation paradigm to use.')
     new_parser.add_argument('--catchment-ids', dest='catchments', nargs='+', help='Specify catchment subset.')
+    new_parser.add_argument('--forcings-data-id', dest='forcings_data_id', help='Specify data_id of forcings dataset.')
 
     date_format = DataDomain.get_datetime_str_format()
     print_date_format = 'YYYY-mm-dd HH:MM:SS'
@@ -59,8 +61,7 @@ def _create_ngen_based_exec_parser(subcommand_container: Any, parser_name: str,
                             help='Model time range ({} to {})'.format(print_date_format, print_date_format))
     new_parser.add_argument('hydrofabric_data_id', help='Identifier of dataset of required hydrofabric')
     new_parser.add_argument('hydrofabric_uid', help='Unique identifier of required hydrofabric')
-    new_parser.add_argument('config_data_id', help='Identifier of dataset of required realization config')
-    new_parser.add_argument('bmi_cfg_data_id', help='Identifier of dataset of required BMI init configs')
+    new_parser.add_argument('config_data_id', help='Identifier of composite config dataset with required configs')
     new_parser.add_argument('cpu_count', type=int, help='Provide the desired number of processes for the execution')
     return new_parser
 
@@ -98,10 +99,25 @@ def _handle_exec_command_args(parent_subparsers_container):
     # A parser for the 'exec' command itself, underneath the parent 'command' subparsers container
     command_parser = parent_subparsers_container.add_parser('exec')
 
-    # Subparser under the exec command's parser for handling the different workflows that might be run
-    workflow_subparsers = command_parser.add_subparsers(dest='workflow')
+    # Subparser under the exec command's parser for handling the different job workflows that might be started
+    workflow_subparsers = command_parser.add_subparsers(dest='workflow_starter')
     workflow_subparsers.required = True
 
+    # Add some parsers to deserialize a request from a JSON string, or ...
+    parser_from_json = workflow_subparsers.add_parser("from_json")
+    #parser_from_json.add_argument('--partition-config-data-id', dest='partition_cfg_data_id', default=None,
+    #                              help='Provide data_id for desired partition config dataset.')
+    parser_from_json.add_argument('job_type', choices=['ngen', 'ngen_cal'],
+                                  help="Set type of request object so it is deserialized correctly")
+    parser_from_json.add_argument('request_json',
+                                  help='JSON string for exec request object to use to start a job')
+    # ... from JSON contained within a file
+    parser_from_file = workflow_subparsers.add_parser("from_file")
+    parser_from_file.add_argument('job_type', choices=['ngen', 'ngen_cal'],
+                                  help="Set type of request object so it is deserialized correctly")
+    parser_from_file.add_argument('request_file', type=Path,
+                                  help='Path to file containing JSON exec request object to use to start a job')
+
     # Nested parser for the 'ngen' action
     parser_ngen = _create_ngen_based_exec_parser(subcommand_container=workflow_subparsers, parser_name='ngen',
                                                  default_alloc_paradigm=AllocationParadigm.get_default_selection())
@@ -165,18 +181,19 @@ def model_calibration_param(arg_val: str):
                             help='The model calibration strategy used by ngen_cal.')
 
 
-def _handle_dataset_command_args(parent_subparsers_container):
+def _handle_data_service_action_args(parent_subparsers_container):
     """
-    Handle setup of arg parsing for 'dataset' command, which allows for various operations related to datasets.
+    Handle setup of arg parsing for 'data' command, which allows for various operations related to datasets.
 
     Parameters
     ----------
     parent_subparsers_container
-        The top-level parent container for subparsers of various commands, including the 'dataset' command, to which
+        The top-level parent container for subparsers of various commands, including the 'data' command, to which
         some numbers of nested subparser containers and parsers will be added.
     """
-    # A parser for the 'dataset' command itself, underneath the parent 'command' subparsers container
-    command_parser = parent_subparsers_container.add_parser('dataset')
+    # A parser for the 'data' command itself, underneath the parent 'command' subparsers container
+    command_parser = parent_subparsers_container.add_parser('dataset',
+                                                            description="Perform various dataset-related actions.")
 
     # Subparser under the dataset command's parser for handling the different actions that might be done relating to a
     # dataset (e.g., creation or uploading of data)
@@ -187,46 +204,53 @@ def _handle_data_service_action_args(parent_subparsers_container):
     dataset_formats = [e.name for e in DataFormat]
 
     # Nested parser for the 'create' action, with required argument for dataset name, category, and format
-    parser_create = action_subparsers.add_parser('create')
+    parser_create = action_subparsers.add_parser('create', description="Create a new dataset.")
     parser_create.add_argument('name', help='Specify the name of the dataset to create.')
-    parser_create.add_argument('--paths', dest='upload_paths', nargs='+', help='Specify files/directories to upload.')
-    json_form = '{"variable": "<variable_name>", ("begin": "<value>", "end": "<value>" | "values": [<value>, ...])}'
-    restrict_help_str = 'Specify continuous or discrete domain restriction as (simplified) serialized JSON - {}'
-    parser_create.add_argument('--restriction', dest='domain_restrictions', nargs='*',
-                               help=restrict_help_str.format(json_form))
-    parser_create.add_argument('--format', dest='dataset_format', choices=dataset_formats, help='Specify dataset domain format.')
-    parser_create.add_argument('--domain-json', dest='domain_file', help='Deserialize the dataset domain from a file.')
-    parser_create.add_argument('category', choices=dataset_categories, help='Specify dataset category.')
+    parser_create.add_argument('--paths', dest='upload_paths', type=Path, nargs='+',
+                               help='Specify files/directories to upload.')
+    parser_create.add_argument('--data-root', dest='data_root', type=Path,
+                               help='Relative data root directory, used to adjust the names for uploaded items.')
+    c_json_form = '{"variable": "<variable_name>", "begin": "<value>", "end": "<value>"}'
+    d_json_form = '{"variable": "<variable_name>", "values": [<value>, ...]}'
+    c_restrict_help_str = 'Specify continuous domain restriction as (simplified) serialized JSON - {}'
+    d_restrict_help_str = 'Specify discrete domain restriction as (simplified) serialized JSON - {}'
+    # TODO: need to test that this works as expected
+    parser_create.add_argument('--continuous-restriction', type=lambda s: ContinuousRestriction(**json.loads(s)),
+                               dest='continuous_restrictions', nargs='*', help=c_restrict_help_str.format(c_json_form))
+    parser_create.add_argument('--discrete-restriction', type=lambda s: DiscreteRestriction(**json.loads(s)),
+                               dest='discrete_restrictions', nargs='*', help=d_restrict_help_str.format(d_json_form))
+    parser_create.add_argument('--format', dest='dataset_format', choices=dataset_formats, type=DataFormat.get_for_name,
+                               help='Specify dataset domain format.')
+    parser_create.add_argument('--domain-json', dest='domain_file', type=Path, help='Deserialize the dataset domain from a file.')
+    parser_create.add_argument('category', type=DataCategory.get_for_name, choices=dataset_categories, help='Specify dataset category.')
 
     # Nested parser for the 'delete' action, with required argument for dataset name
-    parser_delete = action_subparsers.add_parser('delete')
+    parser_delete = action_subparsers.add_parser('delete', description="Delete a specified (entire) dataset.")
     parser_delete.add_argument('name', help='Specify the name of the dataset to delete.')
 
     # Nested parser for the 'upload' action, with required args for dataset name and files to upload
-    parser_upload = action_subparsers.add_parser('upload')
-    parser_upload.add_argument('name', help='Specify the name of the desired dataset.')
-    parser_upload.add_argument('paths', nargs='+', help='Specify files or directories to upload.')
-
-    # Nested parser for the 'upload' action, with required args for dataset name and files to upload
-    parser_download = action_subparsers.add_parser('download')
-    parser_download.add_argument('name', help='Specify the name of the desired dataset.')
-    parser_download.add_argument('--dest', dest='download_dest', default=None,
-                                 help='Specify local destination path to save to.')
-    parser_download.add_argument('path', help='Specify a file/item within dataset to download.')
-
-    # Nested parser for the 'upload' action, with required args for dataset name and files to upload
-    parser_download_all = action_subparsers.add_parser('download_all')
-    parser_download_all.add_argument('name', help='Specify the name of the desired dataset.')
-    parser_download_all.add_argument('--directory', dest='download_dir', default=None,
-                                     help='Specify local destination directory to save to (defaults to ./')
-
-    # Nested parser for the 'list' action
-    parser_list = action_subparsers.add_parser('list')
-    listing_categories_choices = list(dataset_categories)
-    listing_categories_choices.append('all')
-    parser_list.add_argument('category', choices=listing_categories_choices, nargs='?', default='all',
+    parser_upload = action_subparsers.add_parser('upload', description="Upload local files to a dataset.")
+    parser_upload.add_argument('--data-root', dest='data_root', type=Path,
+                               help='Relative data root directory, used to adjust the names for uploaded items.')
+    parser_upload.add_argument('dataset_name', help='Specify the name of the desired dataset.')
+    parser_upload.add_argument('paths', type=Path, nargs='+', help='Specify files or directories to upload.')
+
+    # Nested parser for the 'download' action, with required args for dataset name and items to download
+    parser_download = action_subparsers.add_parser('download', description="Download some or all items from a dataset.")
+    parser_download.add_argument('--items', dest='item_names', nargs='+',
+                                 help='Specify files/items within dataset to download.')
+    parser_download.add_argument('dataset_name', help='Specify the name of the desired dataset.')
+    parser_download.add_argument('dest_dir', type=Path, help='Specify local destination directory to save to.')
+
+    # Nested parser for the 'list_datasets' action
+    parser_list = action_subparsers.add_parser('list', description="List available datasets.")
+    parser_list.add_argument('--category', dest='category', choices=dataset_categories, type=DataCategory.get_for_name,
                              help='Specify the category of dataset to list')
 
+    # Nested parser for the 'list_items' action
+    parser_list = action_subparsers.add_parser('items', description="List items within a specified dataset.")
+    parser_list.add_argument('dataset_name', help='Specify the dataset for which to list items')
+
 
 def _handle_jobs_command_args(parent_subparsers_container):
     """
@@ -242,28 +266,28 @@
     command_parser = parent_subparsers_container.add_parser('jobs')
 
     # Subparser under the jobs command's parser for handling the different query or control that might be run
-    subcommand_subparsers = command_parser.add_subparsers(dest='subcommand')
-    subcommand_subparsers.required = True
+    job_command_subparsers = command_parser.add_subparsers(dest='job_command')
+    job_command_subparsers.required = True
 
     # Nested parser for the 'list' action
-    parser_list_jobs = subcommand_subparsers.add_parser('list')
+    parser_list_jobs = job_command_subparsers.add_parser('list')
     parser_list_jobs.add_argument('--active', dest='jobs_list_active_only', action='store_true',
                                   help='List only jobs with "active" status')
 
     # Nested parser for the 'info' action
-    parser_job_info = subcommand_subparsers.add_parser('info')
+    parser_job_info = job_command_subparsers.add_parser('info')
     parser_job_info.add_argument('job_id', help='The id of the job for which to retrieve job state info')
 
     # Nested parser for the 'release' action
-    parser_job_release = subcommand_subparsers.add_parser('release')
+    parser_job_release = job_command_subparsers.add_parser('release')
     parser_job_release.add_argument('job_id', help='The id of the job for which to release resources')
 
     # Nested parser for the 'status' action
-    parser_job_status = subcommand_subparsers.add_parser('status')
+    parser_job_status = job_command_subparsers.add_parser('status')
     parser_job_status.add_argument('job_id', help='The id of the job for which to retrieve status')
 
     # Nested parser for the 'stop' action
-    parser_job_stop = subcommand_subparsers.add_parser('stop')
+    parser_job_stop = job_command_subparsers.add_parser('stop')
     parser_job_stop.add_argument('job_id', help='The id of the job to stop')
 
@@ -271,17 +295,20 @@
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, prog='dmod.client')
     parser.add_argument('--client-config',
                         help='Set path to client configuration file',
+                        type=Path,
                         dest='client_config',
-                        default=None)
+                        default=Path('.dmod_client_config.json'))
     parser.add_argument('--bypass-request-service', '-b', dest='bypass_reqsrv', action='store_true', default=False,
                         help='Attempt to connect directly to the applicable service')
+    parser.add_argument('--remote-debug', '-D', dest='remote_debug', action='store_true', default=False,
+                        help='Activate remote debugging, according to loaded client configuration.')
 
     # Top-level subparsers container, splitting off a variety of handled commands for different behavior
     # e.g., 'dataset' to do things related to datasets, like creation
     subparsers = parser.add_subparsers(dest='command')
     subparsers.required = True
 
     # Nested command parsers handling actions of dataset command
-    _handle_dataset_command_args(parent_subparsers_container=subparsers)
+    _handle_data_service_action_args(parent_subparsers_container=subparsers)
     # Nested command parsers handling config actions
     _handle_config_command_args(parent_subparsers_container=subparsers)
     # Nested command parsers handling exec actions
@@ -324,178 +351,20 @@ def find_client_config(basenames: Optional[List[str]] = None, dirs: Optional[Lis
     return existing[0] if len(existing) > 0 else None
 
 
-def _process_uploads(upload_path_str: Optional[List[str]]) -> Tuple[List[Path], List[Path]]:
-    """
-    Process the given list of string representations of paths, returning a tuple of lists of processed and bad paths.
-
-    Process the given list of string representations of paths, converting the initial list to a second list of
-    ::class:`Path` objects.  Then, derive a third list from the second, consisting of any "bad" paths that do not
-    exist.  Return a tuple containing the second and third lists.
-
-    If the param is ``None``, it will be treated as an empty list, resulting in a tuple of two empty lists returned.
-
-    Parameters
-    ----------
-    upload_path_str : Optional[List[str]]
-        A list of string forms of paths to process.
-
-    Returns
-    -------
-    Tuple[List[Path], List[Path]]
-        Tuple of two lists of ::class:`Path`, with the first list being all derived from the param, and the second
-        being a list of any non-existing ::class:`Path` objects within the first list.
-    """
-    # Convert the string form of the given paths to Path objects
-    upload_paths = [] if upload_path_str is None else [Path(p) for p in upload_path_str]
-    # Track bad paths, though, so that we can bail if there are any
-    bad_paths = [p for p in upload_paths if not p.exists()]
-    return upload_paths, bad_paths
-
-
-def _process_domain_restriction_args(domain_restriction_strs: List[str]) -> Tuple[List[ContinuousRestriction], List[DiscreteRestriction]]:
-    """
-    Process serialized JSON strings to restriction objects.
-
-    Strings are expected to be in either the standard serialized format from each type's ``to_dict`` function, or,
-    for continuous restrictions, in a similar, truncated form that can be converted to the standard format by using
-    ::method:`ContinuousRestriction.convert_truncated_serial_form`.
-
-    Parameters
-    ----------
-    domain_restriction_strs : List[str]
-        List of JSON strings, where strings are serialized restriction objects, possibly in a simplified format for
-        ::class:`ContinuousRestriction`
-
-    Returns
-    -------
-    Tuple[List[ContinuousRestriction], List[DiscreteRestriction]]
-        A tuple of two lists of restriction objects, with the first being continuous and the second discrete.
-    """
-    discrete_restrictions = []
-    continuous_restrictions = []
-    for json_str in domain_restriction_strs:
-        json_obj = json.loads(json_str)
-        discrete_restrict = DiscreteRestriction.factory_init_from_deserialized_json(json_obj)
-        if discrete_restrict is not None:
-            discrete_restrictions.append(discrete_restrict)
-            continue
-        continuous_restrict = ContinuousRestriction.factory_init_from_deserialized_json(json_obj)
-        if continuous_restrict is not None:
-            continuous_restrictions.append(continuous_restrict)
-            continue
-        # Try this as well so continuous restrictions can use simpler format
-        continuous_restrict = ContinuousRestriction.factory_init_from_deserialized_json(
-            ContinuousRestriction.convert_truncated_serial_form(json_obj))
-        if continuous_restrict is not None:
-            continuous_restrictions.append(continuous_restrict)
-            continue
-    return continuous_restrictions, discrete_restrictions
-
-
-def execute_dataset_command(parsed_args, client: DmodClient):
+def execute_dataset_command(args, client: DmodClient):
     async_loop = get_or_create_eventloop()
 
-    if parsed_args.action == 'create':
-        category = DataCategory.get_for_name(parsed_args.category)
-        upload_paths, bad_paths = _process_uploads(parsed_args.upload_paths)
-        if len(bad_paths):
-            raise RuntimeError('Aborted before dataset {} created; invalid upload paths: {}'.format(parsed_args.name,
-                                                                                                    bad_paths))
-        # Proceed with create, and raising error on failure
-
-        key_args = dict()
-
-        # If we have a domain file, parse it, and use it as the only key args
-        if parsed_args.domain_file is not None:
-            domain_file = Path(parsed_args.domain_file)
-            domain = DataDomain.factory_init_from_deserialized_json(json.load(domain_file.open()))
-            if domain is None:
-                raise RuntimeError("Could not deserialize data domain from file {}".format(domain_file))
-            key_args['domain'] = domain
-        else:
-            # Otherwise, start by processing any serialized restrictions provided on the command line
-            c_restricts, d_restricts = _process_domain_restriction_args(parsed_args.domain_restrictions)
-            # With restrictions processed, proceed to generating keyword args for the client's create function
-            data_format = DataFormat.get_for_name(parsed_args.dataset_format)
-            if data_format is None:
-                msg = 'Failed to create dataset {} due to unparseable data format'
-                raise RuntimeError(msg.format(parsed_args.name, parsed_args.dataset_format))
-            else:
-                key_args['data_format'] = data_format
-            # Finally, assemble the key args we will use
-            if d_restricts:
-                key_args['discrete_restrictions'] = d_restricts
-            if c_restricts:
-                key_args['continuous_restrictions'] = c_restricts
-
-        if not async_loop.run_until_complete(client.create_dataset(parsed_args.name, category, **key_args)):
-            raise RuntimeError('Failed to create dataset {}'.format(parsed_args.name))
-        # Display message if create succeeded and there was nothing to upload
-        elif len(upload_paths) == 0:
-            print('Dataset {} of category {} created successfully'.format(parsed_args.name, category))
-        # Handle uploads if there are some after create succeeded, but if those failed ...
-        if not async_loop.run_until_complete(client.upload_to_dataset(parsed_args.name, upload_paths)):
-            raise RuntimeError('Dataset {} created, but upload of data failed from paths {}'.format(parsed_args.name,
-                                                                                                    upload_paths))
-        # Lastly (i.e., if uploading did work)
-        else:
-            print('Dataset {} of category {} created successfully, and uploaded {}'.format(parsed_args.name, category,
                                                                                            upload_paths))
-    elif parsed_args.action == 'list':
-        category = None if parsed_args.category == 'all' else DataCategory.get_for_name(parsed_args.category)
-        dataset_names = async_loop.run_until_complete(client.list_datasets(category))
-        if len(dataset_names) == 0:
-            print('No existing datasets were found.')
-        else:
-            for d in dataset_names:
-                print(d)
-
-    elif parsed_args.action == 'delete':
-        if not async_loop.run_until_complete(client.delete_dataset(parsed_args.name)):
-            raise RuntimeError('Failed to delete dataset {}'.format(parsed_args.name))
-
-    elif parsed_args.action == 'upload':
-        upload_paths, bad_paths = _process_uploads(parsed_args.paths)
-        if len(bad_paths):
-            raise RuntimeError("Can't upload to {} - invalid upload paths: {}".format(parsed_args.name, bad_paths))
-        elif not async_loop.run_until_complete(client.upload_to_dataset(parsed_args.name, upload_paths)):
-            raise RuntimeError('Upload of data to {} failed from paths {}'.format(parsed_args.name, upload_paths))
-        else:
-            print('Upload succeeded.')
-
-    elif parsed_args.action == 'download':
-        if parsed_args.download_dest is None:
-            dest = Path('./{}'.format(Path(parsed_args.path).name))
-        else:
-            dest = Path(parsed_args.download_dest)
-        if not dest.parent.is_dir():
-            raise RuntimeError("Cannot download file to {}: parent directory doesn't exist.".format(dest))
-        if dest.exists():
-            raise RuntimeError("Cannot download file to {}: file already exists".format(dest))
-        if not async_loop.run_until_complete(client.download_from_dataset(dataset_name=parsed_args.name,
-                                                                          item_name=parsed_args.path, dest=dest)):
-            msg = 'Download of {} data to {} failed from locations {}'
-            raise RuntimeError(msg.format(parsed_args.name, dest, parsed_args.path))
-        else:
-            print('Downloaded {} to local file {}.'.format(parsed_args.path, dest))
-
-    elif parsed_args.action == 'download_all':
-        dest_dir = Path(parsed_args.download_dir) if parsed_args.download_dir is not None else Path(parsed_args.name)
-        if dest_dir.exists():
-            if dest_dir.is_dir():
-                dest_dir_orig = dest_dir.name
-                old_backup = dest_dir.rename(dest_dir.parent.joinpath('.{}_old'.format(dest_dir_orig)))
-                dest_dir = dest_dir.parent.joinpath(dest_dir_orig)
-                dest_dir.mkdir()
-                print("Backing up existing {} to {}".format(dest_dir, old_backup))
-            else:
-                RuntimeError("Can't download files to directory named '{}': this is an existing file".format(dest_dir))
-        if not async_loop.run_until_complete(client.download_dataset(dataset_name=parsed_args.name, dest_dir=dest_dir)):
-            msg = 'Download of dataset {} to directory {} failed.'
-            raise RuntimeError(msg.format(parsed_args.name, dest_dir))
-        else:
-            print('Downloaded {} contents to local directory {}.'.format(parsed_args.name, dest_dir))
-    else:
-        raise RuntimeError("Bad dataset command action '{}'".format(parsed_args.action))
+    try:
+        result = async_loop.run_until_complete(client.data_service_action(**(vars(args))))
+        print(result)
+    except ValueError as e:
+        print(str(e))
+        exit(1)
+    except NotImplementedError as e:
+        print(str(e))
+        exit(1)
+    except Exception as e:
+        print("ERROR: Encountered {} - {}".format(e.__class__.__name__, str(e)))
+        exit(1)
 
 
 def execute_config_command(parsed_args, client: DmodClient):
@@ -507,24 +376,16 @@
         raise RuntimeError("Bad client command action '{}'".format(parsed_args.action))
 
 
-def execute_jobs_command(args, client: DmodClient):
+def execute_job_command(args, client: DmodClient):
     async_loop = get_or_create_eventloop()
     try:
-        if args.subcommand == 'info':
-            result = async_loop.run_until_complete(client.request_job_info(**(vars(args))))
-        elif args.subcommand == 'list':
-            result = async_loop.run_until_complete(client.request_jobs_list(**(vars(args))))
-        elif args.subcommand == 'release':
-            result = async_loop.run_until_complete(client.request_job_release(**(vars(args))))
-        elif args.subcommand == 'status':
-            result = async_loop.run_until_complete(client.request_job_status(**(vars(args))))
-        elif args.subcommand == 'stop':
-            result = async_loop.run_until_complete(client.request_job_stop(**(vars(args))))
-        else:
-            raise DmodCliArgumentError()
+        result = async_loop.run_until_complete(client.job_command(**(vars(args))))
         print(result)
-    except DmodCliArgumentError as e:
-        print("ERROR: Unsupported jobs subcommand {}".format(args.subcommand))
+    except ValueError as e:
+        print(str(e))
+        exit(1)
+    except NotImplementedError as e:
+        print(str(e))
         exit(1)
     except Exception as e:
         print("ERROR: Encountered {} - {}".format(e.__class__.__name__, str(e)))
 
 
 def execute_workflow_command(args, client: DmodClient):
@@ -533,16 +394,43 @@
     async_loop = get_or_create_eventloop()
-    # TODO: aaraney
-    if args.workflow == 'ngen':
-        result = async_loop.run_until_complete(client.submit_ngen_request(**(vars(args))))
-        print(result)
-    elif args.workflow == "ngen_cal":
-        result = async_loop.run_until_complete(client.submit_ngen_cal_request(**(vars(args))))
+    try:
+        result = async_loop.run_until_complete(client.execute_job(**(vars(args))))
         print(result)
-    else:
-        print("ERROR: Unsupported execution workflow {}".format(args.workflow))
+    except ValueError as e:
+        print(str(e))
         exit(1)
+    except Exception as e:
+        print(f"Encountered {e.__class__.__name__}: {str(e)}")
+        exit(1)
+
+
+# TODO: (later) add something to TransportLayerClient to check if it supports multiplexing
+
+
+def _load_debugger_and_settrace(debug_cfg):
+    """
+    Helper function to append the path of the Pycharm debug egg to the system path, import it, and set the remote
+    debug trace.
+
+    Parameters
+    ----------
+    debug_cfg
+        The remote debugging config object (or ``None``, in which case nothing is done).
+
+    Returns
+    -------
+    bool
+        Whether the remote debug trace was set successfully.
+    """
+    if debug_cfg is None:
+        return False
+    import sys
+    sys.path.append(str(debug_cfg.egg_path))
+    import pydevd_pycharm
+    try:
+        pydevd_pycharm.settrace(debug_cfg.debug_host, port=debug_cfg.port, stdoutToServer=True, stderrToServer=True)
+        return True
+    except Exception as error:
+        print(f'Warning: could not set debugging trace to {debug_cfg.debug_host} on {debug_cfg.port!s} due to'
+              f' {error.__class__.__name__} - {error!s}')
+        return False
 
 
 def main():
@@ -553,7 +441,15 @@
         exit(1)
 
     try:
-        client = DmodClient(client_config=YamlClientConfig(client_config_path), bypass_request_service=args.bypass_reqsrv)
+
+        client_config = ClientConfig.parse_file(client_config_path)
+        if args.remote_debug and client_config.pycharm_debug_config is not None:
+            _load_debugger_and_settrace(debug_cfg=client_config.pycharm_debug_config)
+        elif args.remote_debug:
+            print("ERROR: received arg to activate remote debugging, but client config lacks debugging parameters.")
+            exit(1)
+
+        client = DmodClient(client_config=client_config, bypass_request_service=args.bypass_reqsrv)
 
         if args.command == 'config':
             execute_config_command(args, client)
@@ -562,12 +458,12 @@
         elif args.command == 'exec':
             execute_workflow_command(args, client)
         elif args.command == 'jobs':
-            execute_jobs_command(args, client)
+            execute_job_command(args, client)
         else:
             raise ValueError("Unsupported command {}".format(args.command))
 
     except Exception as error:
-        print("ERROR: {}".format(error))
+        print(f"ERROR: {error!s}")
         exit(1)
diff --git a/python/lib/client/dmod/client/_version.py b/python/lib/client/dmod/client/_version.py
index 0404d8103..abeeedbf5 100644
--- a/python/lib/client/dmod/client/_version.py
+++ b/python/lib/client/dmod/client/_version.py
@@ -1 +1 @@
-__version__ = '0.3.0'
+__version__ = '0.4.0'
diff --git a/python/lib/client/dmod/client/client_config.py b/python/lib/client/dmod/client/client_config.py
index aa042dc43..7063eb460 100644
--- a/python/lib/client/dmod/client/client_config.py
+++ b/python/lib/client/dmod/client/client_config.py
@@ -1,139 +1,49 @@
-import yaml
-from abc import ABC, abstractmethod
+from dmod.core.serializable import Serializable
 from pathlib import Path
 from typing import Optional
-
-
-class ClientConfig(ABC):
-    _CONFIG_KEY_DATA_SERVICE = 'data-service'
-    _CONFIG_KEY_HOSTNAME = 'hostname'
-    _CONFIG_KEY_PORT = 'port'
-    _CONFIG_KEY_REQUEST_SERVICE = 'request-service'
-    _CONFIG_KEY_SSL_DIR = 'ssl-dir'
-
-    def __init__(self, *args, **kwargs):
-        super(ClientConfig, self).__init__(*args, **kwargs)
-
-    @property
-    @abstractmethod
-    def config_file(self) -> Path:
-        pass
-
-    @property
-    @abstractmethod
-    def dataservice_endpoint_uri(self) -> Optional[str]:
-        pass
-
-    @property
-    @abstractmethod
-    def dataservice_ssl_dir(self) -> Optional[Path]:
-        pass
-
-    @property
-    @abstractmethod
-    def requests_endpoint_uri(self) -> str:
-        pass
-
-    @property
-    @abstractmethod
-    def requests_ssl_dir(self) -> Path:
-        pass
-
-    @abstractmethod
-    def print_config(self):
-        pass
-
-
-class YamlClientConfig(ClientConfig):
-    """
-    A subtype of ::class:`ClientConfig` backed by configuration details loaded from a YAML file.
-    """
-
-    @classmethod
-    def generate_endpoint_uri(cls, hostname: str, port: int):
-        return 'wss://{}:{}'.format(hostname, port)
-
-    @classmethod
-    def get_service_ssl_dir(cls, backing_config: dict, service_key: str):
-        dir_path = Path(backing_config[service_key][cls._CONFIG_KEY_SSL_DIR])
-        if dir_path.is_dir():
-            return dir_path
-        else:
-            raise RuntimeError("Non-existing {} SSL directory configured ({})".format(service_key, dir_path))
-
-    def __init__(self, client_config_file: Path, *args, **kwargs):
-        """
-        Initialize this instance.
-
-        Parameters
-        ----------
-        client_config_file : Path
-            The path to the backing YAML configuration file that must be loaded.
-        args
-            Ordered args to pass to the superclass init function.
-        kwargs
-            Keyword args to pass to the superclass init function.
-        """
-        super(ClientConfig, self).__init__(*args, **kwargs)
-        self._config_file = client_config_file
-
-        self._backing_config = None
-        """ A backing config that must be loaded from the given file before other properties are accessible. """
-        with self._config_file.open() as file:
-            self._backing_config = yaml.safe_load(file)
-
-        self._requests_endpoint_uri = self.generate_endpoint_uri(self.requests_hostname, self.requests_port)
-        self._requests_ssl_dir = self.get_service_ssl_dir(self._backing_config, self._CONFIG_KEY_REQUEST_SERVICE)
-
-        self._dataservice_endpoint_uri = None
-        self._dataservice_ssl_dir = None
-
-    @property
-    def config_file(self) -> Path:
-        return self._config_file
-
-    @property
-    def dataservice_endpoint_uri(self) -> Optional[str]:
-        if self._dataservice_endpoint_uri is None and self.dataservice_hostname is not None \
-                and self.dataservice_port is not None:
-            self._dataservice_endpoint_uri = self.generate_endpoint_uri(self.dataservice_hostname,
-                                                                        self.dataservice_port)
-        return self._dataservice_endpoint_uri
-
-    @property
-    def dataservice_hostname(self) -> Optional[str]:
-        if self._CONFIG_KEY_DATA_SERVICE in self._backing_config:
-            return self._backing_config[self._CONFIG_KEY_DATA_SERVICE][self._CONFIG_KEY_HOSTNAME]
-        else:
-            return None
-
-    @property
-    def dataservice_port(self) -> Optional[int]:
-        if self._CONFIG_KEY_DATA_SERVICE in self._backing_config:
-            return self._backing_config[self._CONFIG_KEY_DATA_SERVICE][self._CONFIG_KEY_PORT]
-        else:
-            return None
-
-    @property
-    def dataservice_ssl_dir(self) -> Optional[Path]:
-        if self._dataservice_ssl_dir is None and self._CONFIG_KEY_DATA_SERVICE in self._backing_config:
-            self._dataservice_ssl_dir = self.get_service_ssl_dir(self._backing_config, self._CONFIG_KEY_DATA_SERVICE)
-        return self._dataservice_ssl_dir
-
-    @property
-    def requests_endpoint_uri(self) -> str:
-        return self._requests_endpoint_uri
-
-    @property
-    def requests_hostname(self) -> str:
-        return self._backing_config[self._CONFIG_KEY_REQUEST_SERVICE][self._CONFIG_KEY_HOSTNAME]
-
-    @property
-    def requests_port(self) -> int:
-        return self._backing_config[self._CONFIG_KEY_REQUEST_SERVICE][self._CONFIG_KEY_PORT]
-
-    @property
-    def requests_ssl_dir(self) -> Path:
-        return self._requests_ssl_dir
-
-    def print_config(self):
-        print(self._config_file.read_text())
+from pydantic import Field, validator
+
+
+class ConnectionConfig(Serializable):
+
+    active: bool = Field(True, description="Whether this configured connection should be active and initialized")
+    endpoint_protocol: str = Field(description="The protocol in this config", alias="protocol")
+    endpoint_host: str = Field(description="The configured hostname", alias="hostname")
+    endpoint_port: int = Field(description="The configured host port", alias="port")
+    cafile: Optional[Path] = Field(None, description="Optional path to CA certificates PEM file.", alias="pem")
+    capath: Optional[Path] = Field(None, description="Optional path to directory containing CA certificates PEM files.",
+                                   alias="ssl-dir")
+    use_default_context: bool = False
+
+
+class PycharmRemoteDebugConfig(Serializable):
+    egg_path: Path = Field(description="Path to egg for remote Pycharm debugging", alias="egg-path")
+    debug_host: str = Field("host.docker.internal", description="Debug host to connect back to for remote debugger",
+                            alias="host")
+    port: int = Field(55875, description="Port to connect back to for remote debugging")
+
+    @validator("egg_path")
+    def validate_egg_path(cls, value):
+        if not isinstance(value, Path):
+            value = Path(value)
+        if not value.exists():
+            raise RuntimeError(f"No file exists at '{value!s}' received by {cls.__name__} for egg path!")
+        elif not value.is_file():
+            raise RuntimeError(f"{cls.__name__} received '{value!s}' for egg path, but this is not a regular file!")
+        return value
+
+
+class ClientConfig(Serializable):
+    pycharm_debug_config: Optional[PycharmRemoteDebugConfig] = Field(None,
+                                                                     description="Config for remote Pycharm debugging",
+                                                                     alias="remote-debug")
+    request_service: ConnectionConfig = Field(description="The config for connecting to the request service",
+                                              alias="request-service")
+    data_service: Optional[ConnectionConfig] = Field(None, description="The config for connecting to the data service",
+                                                     alias="data-service")
+
+    @validator("request_service")
+    def validate_request_service_connection_active(cls, value):
+        if not value.active:
+            raise RuntimeError(f"{cls.__name__} must have request service config set to 'active'!")
+        return value
diff --git a/python/lib/client/dmod/client/dmod_client.py b/python/lib/client/dmod/client/dmod_client.py
index 8fd2134b5..60dba51c3 100644
--- a/python/lib/client/dmod/client/dmod_client.py
+++ b/python/lib/client/dmod/client/dmod_client.py
@@ -1,277 +1,213 @@
-from dmod.core.execution import AllocationParadigm
-from dmod.core.meta_data import DataCategory, DataDomain, DataFormat, DiscreteRestriction
-from .request_clients import DatasetClient, DatasetExternalClient, DatasetInternalClient, NgenRequestClient, NgenCalRequestClient
-from .client_config import YamlClientConfig
-from datetime import datetime
+import json
+
+from dmod.communication import AuthClient, TransportLayerClient, WebSocketClient
+from dmod.core.common import get_subclasses
+from dmod.core.serializable import ResultIndicator
+from dmod.core.meta_data import DataDomain
+from .request_clients import DataServiceClient, JobClient
+from .client_config import ClientConfig
 from pathlib import Path
-from typing import List, Optional
+from typing import Type
+
+
+def determine_transport_client_type(protocol: str,
+                                    *prioritized_subtypes: Type[TransportLayerClient]) -> Type[TransportLayerClient]:
+    """
+    Determine the specific subclass type of ::class:`TransportLayerClient` appropriate for a specified URI protocol.
+
+    To allow for control when there are potential multiple subtypes that would support the same protocol, specific
+    ::class:`TransportLayerClient` subclasses can be given as variable positional arguments.  These will be
+    prioritized and examined first.  After that, the order of examined subtypes is subject to the runtime order of
+    the search for concrete ::class:`TransportLayerClient` subclasses.
+
+    Parameters
+    ----------
+    protocol : str
+        A URI protocol substring value.
+ @property + def client_config(self) -> ClientConfig: + return self._client_config - Additionally, keyword arguments are forwarded in the call to the ::attribute:`dataset_client` property's - ::method:`DatasetClient.create_dataset` function. This includes the aforementioned kwargs for a creating a - default ::class:`DataDomain`, but only if they are otherwise ignored because a valid domain arg was provided. + async def data_service_action(self, action: str, **kwargs) -> ResultIndicator: + """ + Perform a supported data service action. Parameters ---------- - dataset_name : str - The name of the dataset. - category : DataCategory - The dataset category. - domain : Optional[DataDomain] - The semi-optional (depending on keyword args) domain for the dataset. + action : str + The action selection of interest. kwargs - Other optional keyword args. - - Keyword Args - ---------- - data_format : DataFormat - An optional data format, used if no ``domain`` is provided - continuous_restrictions : List[ContinuousRestrictions] - An optional list of continuous domain restrictions, used if no ``domain`` is provided - discrete_restrictions : List[DiscreteRestrictions] - An optional list of discrete domain restrictions, used if no ``domain`` is provided Returns ------- - bool - Whether creation was successful. + ResultIndicator + An indication of whether the requested action was performed successfully. """ - # If a domain wasn't passed, generate one from the kwargs, or raise and exception if we can't - if domain is None: - data_format = kwargs.pop('data_format', None) - if data_format is None: - msg = "Client can't create dataset with `None` for {}, nor generate a default {} without a provided {}" - raise ValueError(msg.format(DataDomain.__name__, DataDomain.__name__, DataFormat.__name__)) - print_msg = "INFO: no {} provided; dataset will be created with a basic default domain using format {}" - print(print_msg.format(DataDomain.__name__, data_format.name)) - # If neither provided, bootstrap a basic restriction on the first index variable in the data format - if not ('discrete_restrictions' in kwargs or 'continuous_restrictions' in kwargs): - c_restricts = None - d_restricts = [DiscreteRestriction(variable=data_format.indices[0], values=[])] - # If at least one is provided, use whatever was passed, and fallback to None for the other if needed + try: + if action == 'create': + # Do a little extra here to get the domain + if 'domain' in kwargs: + domain = kwargs.pop('domain') + elif 'domain_file' in kwargs: + with kwargs['domain_file'].open() as domain_file: + domain_json = json.load(domain_file) + domain = DataDomain.factory_init_from_deserialized_json(domain_json) + else: + domain = DataDomain(**kwargs) + return await self.data_service_client.create_dataset(domain=domain, **kwargs) + elif action == 'delete': + return await self.data_service_client.delete_dataset(**kwargs) + elif action == 'upload': + return await self.data_service_client.upload_to_dataset(**kwargs) + elif action == 'download': + return await self.data_service_client.retrieve_from_dataset(**kwargs) + elif action == 'list': + return await self.data_service_client.get_dataset_names(**kwargs) + elif action == 'items': + return await self.data_service_client.get_dataset_item_names(**kwargs) else: - c_restricts = list(kwargs.pop('continuous_restrictions')) if 'continuous_restrictions' in kwargs else [] - d_restricts = list(kwargs.pop('discrete_restrictions')) if 'discrete_restrictions' in kwargs else [] - domain = 
DataDomain(data_format=data_format, continuous_restrictions=c_restricts, - discrete_restrictions=d_restricts) - # Finally, ask the client to create the dataset, passing the details - return await self.dataset_client.create_dataset(dataset_name, category, domain, **kwargs) + raise ValueError(f"Unsupported data service action to {self.__class__.__name__}: {action}") + except NotImplementedError: + raise NotImplementedError(f"Impl of supported data action {action} not yet in {self.__class__.__name__}") @property - def dataset_client(self) -> DatasetClient: - if self._dataset_client is None: - if self._bypass_request_service: - if self.client_config.dataservice_endpoint_uri is None: - raise RuntimeError("Cannot bypass request service without data service config details") - self._dataset_client = DatasetInternalClient(self.client_config.dataservice_endpoint_uri, - self.client_config.dataservice_ssl_dir) + def data_service_client(self) -> DataServiceClient: + if self._data_service_client is None: + if self.client_config.data_service is not None and self.client_config.data_service.active: + t_client_type = determine_transport_client_type(self.client_config.data_service.endpoint_protocol) + t_client = t_client_type(**self.client_config.data_service.dict()) + self._data_service_client = DataServiceClient(t_client, self._auth_client) else: - self._dataset_client = DatasetExternalClient(self.requests_endpoint_uri, self.requests_ssl_dir) - return self._dataset_client - - @property - def ngen_request_client(self) -> NgenRequestClient: - if self._ngen_client is None: - self._ngen_client = NgenRequestClient(self.requests_endpoint_uri, self.requests_ssl_dir) - return self._ngen_client + self._data_service_client = DataServiceClient(self._get_transport_client(), self._auth_client) + return self._data_service_client @property - def ngen_cal_request_client(self) -> NgenCalRequestClient: - if self._ngen_cal_client is None: - self._ngen_cal_client = NgenCalRequestClient(self.requests_endpoint_uri, self.requests_ssl_dir) - return self._ngen_cal_client - - async def delete_dataset(self, dataset_name: str, **kwargs): - return await self.dataset_client.delete_dataset(dataset_name, **kwargs) - - async def download_dataset(self, dataset_name: str, dest_dir: Path) -> bool: - return await self.dataset_client.download_dataset(dataset_name=dataset_name, dest_dir=dest_dir) - - async def download_from_dataset(self, dataset_name: str, item_name: str, dest: Path) -> bool: - return await self.dataset_client.download_from_dataset(dataset_name=dataset_name, item_name=item_name, - dest=dest) - - async def list_datasets(self, category: Optional[DataCategory] = None): - return await self.dataset_client.list_datasets(category) + def job_client(self) -> JobClient: + if self._job_client is None: + self._job_client = JobClient(transport_client=self._get_transport_client(), auth_client=self._auth_client) + return self._job_client - async def request_job_info(self, job_id: str, *args, **kwargs) -> dict: + async def execute_job(self, workflow: str, **kwargs) -> ResultIndicator: """ - Request the full state of the provided job, formatted as a JSON dictionary. + Submit a requested job defined by the provided ``kwargs``. - Parameters - ---------- - job_id : str - The id of the job in question. - args - (Unused) variable positional args. - kwargs - (Unused) variable keyword args. - - Returns - ------- - dict - The full state of the provided job, formatted as a JSON dictionary. 
- """ - # TODO: implement - raise NotImplementedError('{} function "request_job_info" not implemented yet'.format(self.__class__.__name__)) + Currently supported job workflows are: + - ``ngen`` : submit a job request to execute a ngen model exec job + - ``ngen_cal`` : submit a job request to execute a ngen-cal model calibration job + - ``from_json`` : submit a provided job request, given in serialized JSON form + - ``from_file`` : submit a provided job request, serialized to JSON form and saved in the given file - async def request_job_release(self, job_id: str, *args, **kwargs) -> bool: - """ - Request the allocated resources for the provided job be released. + For most supported workflows, ``kwargs`` should contain necessary params for initializing a request object of + the correct type. However, for ``workflow`` values ``from_json`` or ``from_file``, ``kwargs`` should instead + contain params for deserializing the right type of request, either directly or from a provided file. Parameters ---------- - job_id : str - The id of the job in question. - args - (Unused) variable positional args. + workflow: str + The type of workflow, as a string, which should correspond to parsed CLI options. kwargs - (Unused) variable keyword args. + Dynamic keyword args used to produce a request object to initiate a job, which vary by workflow. Returns ------- - bool - Whether there had been allocated resources for the job, all of which are now released. + The result of the request to run the job. """ - # TODO: implement - raise NotImplementedError('{} function "request_job_release" not implemented yet'.format(self.__class__.__name__)) - - async def request_job_status(self, job_id: str, *args, **kwargs) -> str: + if workflow == 'from_json': + return await self.job_client.submit_request_from_json(**kwargs) + if workflow == 'from_file': + return await self.job_client.submit_request_from_file(**kwargs) + if workflow == 'ngen': + return await self.job_client.submit_ngen_request(**kwargs) + elif workflow == "ngen_cal": + return await self.job_client.submit_ngen_cal_request(**kwargs) + else: + raise ValueError(f"Unsupported job execution workflow {workflow}") + + async def job_command(self, command: str, **kwargs) -> ResultIndicator: """ - Request the status of the provided job, represented in string form. + Submit a request that performs a particular job command. - Parameters - ---------- - job_id : str - The id of the job in question. - args - (Unused) variable positional args. - kwargs - (Unused) variable keyword args. - - Returns - ------- - str - The status of the provided job, represented in string form. - """ - # TODO: implement - raise NotImplementedError('{} function "request_job_status" not implemented yet'.format(self.__class__.__name__)) - - async def request_job_stop(self, job_id: str, *args, **kwargs) -> bool: - """ - Request the provided job be stopped; i.e., transitioned to the ``STOPPED`` exec step. + Supported commands are: + - ``list`` : get a list of ids of existing jobs (supports optional ``jobs_list_active_only`` in ``kwargs``) + - ``info`` : get information on a particular job (requires ``job_id`` in ``kwargs``) + - ``release`` : request allocated resources for a job be released (requires ``job_id`` in ``kwargs``) + - ``status`` : get the status of a particular job (requires ``job_id`` in ``kwargs``) + - ``stop`` : request the provided job be stopped (requires ``job_id`` in ``kwargs``) Parameters ---------- - job_id : str - The id of the job in question. 
- args - (Unused) variable positional args. + command : str + A string indicating the particular job command to run. kwargs - (Unused) variable keyword args. + Other required/optional parameters as needed/desired for the particular job command to be run. Returns ------- - bool - Whether the job was stopped as requested. - """ - # TODO: implement - raise NotImplementedError('{} function "request_job_stop" not implemented yet'.format(self.__class__.__name__)) - - async def request_jobs_list(self, jobs_list_active_only: bool, *args, **kwargs) -> List[str]: + ResultIndicator + An indicator of the results of attempting to run the command. """ - Request a list of ids of existing jobs. - - Parameters - ---------- - jobs_list_active_only : bool - Whether to exclusively include jobs with "active" status values. - args - (Unused) variable positional args. - kwargs - (Unused) variable keyword args. - - Returns - ------- - List[str] - A list of ids of existing jobs. - """ - # TODO: implement - raise NotImplementedError('{} function "request_jobs_list" not implemented yet'.format(self.__class__.__name__)) - - @property - def requests_endpoint_uri(self) -> str: - return self.client_config.requests_endpoint_uri - - @property - def requests_ssl_dir(self) -> Path: - return self.client_config.requests_ssl_dir - - async def submit_ngen_request(self, start: datetime, end: datetime, hydrofabric_data_id: str, hydrofabric_uid: str, - cpu_count: int, realization_cfg_data_id: str, bmi_cfg_data_id: str, - partition_cfg_data_id: Optional[str] = None, cat_ids: Optional[List[str]] = None, - allocation_paradigm: Optional[AllocationParadigm] = None, *args, **kwargs): - return await self.ngen_request_client.request_exec(start, end, hydrofabric_data_id, hydrofabric_uid, - cpu_count, realization_cfg_data_id, bmi_cfg_data_id, - partition_cfg_data_id, cat_ids, allocation_paradigm) - - async def submit_ngen_cal_request(self, start: datetime, end: datetime, hydrofabric_data_id: str, hydrofabric_uid: str, - cpu_count: int, realization_cfg_data_id: str, bmi_cfg_data_id: str, ngen_cal_cfg_data_id: str, - partition_cfg_data_id: Optional[str] = None, cat_ids: Optional[List[str]] = None, - allocation_paradigm: Optional[AllocationParadigm] = None, *args, **kwargs): - return await self.ngen_cal_request_client.request_exec(start, end, hydrofabric_data_id, hydrofabric_uid, - cpu_count, realization_cfg_data_id, bmi_cfg_data_id, ngen_cal_cfg_data_id, - partition_cfg_data_id, cat_ids, allocation_paradigm) + try: + if command == 'info': + return await self.job_client.request_job_info(**kwargs) + elif command == 'list': + return await self.job_client.request_jobs_list(**kwargs) + elif command == 'release': + return await self.job_client.request_job_release(**kwargs) + elif command == 'status': + return await self.job_client.request_job_status(**kwargs) + elif command == 'stop': + return await self.job_client.request_job_stop(**kwargs) + else: + raise ValueError(f"Unsupported job command to {self.__class__.__name__}: {command}") + except NotImplementedError: + raise NotImplementedError(f"Supported command {command} not yet implemented by {self.__class__.__name__}") def print_config(self): - print(self.client_config.print_config()) - - async def upload_to_dataset(self, dataset_name: str, paths: List[Path]) -> bool: - """ - Upload data a dataset. - - Parameters - ---------- - dataset_name : str - The name of the dataset. - paths : List[Path] - List of one or more paths of files to upload or directories containing files to upload. 
- - Returns - ------- - bool - Whether uploading was successful - """ - return await self.dataset_client.upload_to_dataset(dataset_name, paths) + print(self.client_config.json(by_alias=True, exclude_none=True, indent=2)) def validate_config(self): # TODO: diff --git a/python/lib/client/dmod/client/request_clients.py b/python/lib/client/dmod/client/request_clients.py index 543cfc4d5..773301cbc 100644 --- a/python/lib/client/dmod/client/request_clients.py +++ b/python/lib/client/dmod/client/request_clients.py @@ -1,600 +1,852 @@ from abc import ABC, abstractmethod -from dmod.communication import DataServiceClient, ExternalRequestClient, ManagementAction, ModelExecRequestClient, \ - NGENRequest, NGENRequestResponse, \ - NgenCalibrationRequest, NgenCalibrationResponse -from dmod.communication.client import R +from dmod.communication import (AuthClient, InvalidMessageResponse, ManagementAction, NGENRequest, NGENRequestResponse, + NgenCalibrationRequest, NgenCalibrationResponse, TransportLayerClient) +from dmod.communication.client import ConnectionContextClient from dmod.communication.dataset_management_message import DatasetManagementMessage, DatasetManagementResponse, \ MaaSDatasetManagementMessage, MaaSDatasetManagementResponse, QueryType, DatasetQuery from dmod.communication.data_transmit_message import DataTransmitMessage, DataTransmitResponse -from dmod.communication.session import Session +from dmod.core.exception import DmodRuntimeError from dmod.core.meta_data import DataCategory, DataDomain +from dmod.core.serializable import BasicResultIndicator, ResultIndicator from pathlib import Path -from typing import List, Optional, Tuple, Type, Union -from typing_extensions import Self +from typing import Dict, List, Optional, Sequence, Tuple, Type, Union import json -import websockets #import logging #logger = logging.getLogger("gui_log") -class NgenRequestClient(ModelExecRequestClient[NGENRequest, NGENRequestResponse]): +class JobClient: - # In particular needs - endpoint_uri: str, ssl_directory: Path - def __init__(self, cached_session_file: Optional[Path] = None, *args, **kwargs): + def __init__(self, transport_client: TransportLayerClient, auth_client: AuthClient, *args, **kwargs): super().__init__(*args, **kwargs) - if cached_session_file is None: - self._cached_session_file = Path.home().joinpath('.dmod_client_session') + self._transport_client: TransportLayerClient = transport_client + self._auth_client: AuthClient = auth_client + + async def _submit_job_request(self, request) -> str: + if await self._auth_client.apply_auth(request): + # Some clients may be async context managers + if isinstance(self._transport_client, ConnectionContextClient): + async with self._transport_client as t_client: + await t_client.async_send(data=str(request)) + return await t_client.async_recv() + else: + await self._transport_client.async_send(data=str(request)) + return await self._transport_client.async_recv() else: - self._cached_session_file = cached_session_file + msg = f"{self.__class__.__name__} could not use {self._auth_client.__class__.__name__} to authenticate " \ + f"{request.__class__.__name__}" + raise RuntimeError(msg) - # TODO: need some integration tests for this and CLI main and arg parsing - async def request_exec(self, *args, **kwargs) -> NGENRequestResponse: - await self._async_acquire_session_info() - request = NGENRequest(session_secret=self.session_secret, *args, **kwargs) - return await self.async_make_request(request) + async def get_jobs_list(self, active_only: bool) -> 
List[str]:
+        """
+        Get a list of ids of existing jobs.

+        A convenience wrapper around ::method:`request_jobs_list` that returns just the list of job ids rather than the
+        full ::class:`ResultIndicator` object.

-class NgenCalRequestClient(ModelExecRequestClient[NgenCalibrationRequest, NgenCalibrationResponse]):
+        Parameters
+        ----------
+        active_only : bool
+            Whether only the ids of active jobs should be included.

-    # In particular needs - endpoint_uri: str, ssl_directory: Path
-    def __init__(self, cached_session_file: Optional[Path] = None, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        if cached_session_file is None:
-            self._cached_session_file = Path.home().joinpath('.dmod_client_session')
-        else:
-            self._cached_session_file = cached_session_file
+        Returns
+        -------
+        List[str]
+            The list of ids of existing jobs.

-    # TODO: need some integration tests for this and CLI main and arg parsing
-    async def request_exec(self, *args, **kwargs) -> NgenCalibrationResponse:
-        await self._async_acquire_session_info()
-        request = NgenCalibrationRequest(session_secret=self.session_secret, *args, **kwargs)
-        return await self.async_make_request(request)
+        Raises
+        -------
+        RuntimeError
+            If the indicator returned by ::method:`request_jobs_list` has ``success`` value of ``False``.

+        See Also
+        -------
+        request_jobs_list
+        """
+        indicator = await self.request_jobs_list(jobs_list_active_only=active_only)
+        if not indicator.success:
+            raise RuntimeError(f"{self.__class__.__name__} received failure indicator getting list of jobs.")
+        else:
+            return indicator.data

-class DatasetClient(ABC):
+    # TODO: this is going to need some adjustments to the type hinting
+    async def request_job_info(self, job_id: str, *args, **kwargs) -> ResultIndicator:
+        """
+        Request the full state of the provided job, formatted as a JSON dictionary.

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.last_response = None
+        Parameters
+        ----------
+        job_id : str
+            The id of the job in question.
+        args
+            (Unused) variable positional args.
+        kwargs
+            (Unused) variable keyword args.

-    def _parse_list_of_dataset_names_from_response(self, response: DatasetManagementResponse) -> List[str]:
+        Returns
+        -------
+        ResultIndicator
+            An indicator of success of the request that, when successful, contains the full state of the provided job,
+            formatted as a JSON dictionary, in the ``data`` attribute.
         """
-        Parse an included list of dataset names from a received management response.
+        # TODO: implement
+        raise NotImplementedError('{} function "request_job_info" not implemented yet'.format(self.__class__.__name__))

-        Note that an unsuccessful response, or a response (of the correct type) that does not explicitly include the
-        expected data attribute with dataset names, will result in an empty list being returned. However, an unexpected
-        type for the parameter will cause a ::class:`RuntimeError`.
+    async def request_job_release(self, job_id: str, *args, **kwargs) -> ResultIndicator:
+        """
+        Request the allocated resources for the provided job be released.

         Parameters
         ----------
-        response : DatasetManagementResponse
-            The response message from which to parse dataset names.
+        job_id : str
+            The id of the job in question.
+        args
+            (Unused) variable positional args.
+        kwargs
+            (Unused) variable keyword args.

         Returns
         -------
-        List[str]
-            The list of parsed dataset names.
+        ResultIndicator
+            An indicator of whether there had been allocated resources for the job, all of which are now released.
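+
+        Examples
+        --------
+        Intended usage once implemented (an illustrative sketch only; assumes an initialized ``JobClient`` named
+        ``job_client`` and the id of an existing job)::
+
+            indicator = await job_client.request_job_release(job_id='some-job-id')
+            if indicator.success:
+                print('Allocated resources for the job have been released')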
+ """ + # TODO: implement + raise NotImplementedError('{} function "request_job_release" not implemented yet'.format(self.__class__.__name__)) - Raises - ------- - RuntimeError - Raised if the parameter is not a ::class:`DatasetManagementResponse` (or subtype) object. + async def request_job_status(self, job_id: str, *args, **kwargs) -> BasicResultIndicator: """ - if not isinstance(response, DatasetManagementResponse): - msg = "Can't parse list of datasets from non-{} (received a {} object)" - raise RuntimeError(msg.format(DatasetManagementResponse.__name__, response.__class__.__name__)) - # Consider these as valid cases, and treat them as just not listing any datasets - elif not response.success or response.data is None or response.data.datasets is None: - return [] - else: - return response.data.datasets + Request the status of the provided job. - @abstractmethod - async def create_dataset(self, name: str, category: DataCategory, domain: DataDomain, **kwargs) -> bool: - pass + The status value will be serialized to a string and included as the ``data`` attribute of the returned + ::class:`ResultIndicator`. - @abstractmethod - async def delete_dataset(self, name: str, **kwargs) -> bool: - pass + Parameters + ---------- + job_id : str + The id of the job in question. + args + (Unused) variable positional args. + kwargs + (Unused) variable keyword args. - @abstractmethod - async def download_dataset(self, dataset_name: str, dest_dir: Path) -> bool: + Returns + ------- + BasicResultIndicator + An indicator that, when successful, includes as ``data`` the serialized status string of the provided job. """ - Download an entire dataset to a local directory. + # TODO: implement + raise NotImplementedError('{} function "request_job_status" not implemented yet'.format(self.__class__.__name__)) + + async def request_job_stop(self, job_id: str, *args, **kwargs) -> ResultIndicator: + """ + Request the provided job be stopped; i.e., transitioned to the ``STOPPED`` exec step. Parameters ---------- - dataset_name : str - The dataset of interest. - dest_dir : Path - Path to the local directory to which to save the dataset's data. + job_id : str + The id of the job in question. + args + (Unused) variable positional args. + kwargs + (Unused) variable keyword args. Returns ------- - bool - Whether the download was successful. + ResultIndicator + An indicator of whether the job was stopped as requested. """ - pass + # TODO: implement + raise NotImplementedError('{} function "request_job_stop" not implemented yet'.format(self.__class__.__name__)) - @abstractmethod - async def download_from_dataset(self, dataset_name: str, item_name: str, dest: Path) -> bool: + async def request_jobs_list(self, jobs_list_active_only: bool, *args, **kwargs) -> BasicResultIndicator: """ - Download a specific item within a dataset to a local path. + Request a list of ids of existing jobs. - Exactly what an "item" is is implementation specific, and should be documented. + The list of ids will be included as the ``data`` attribute of the returned ::class:`ResultIndicator`. Parameters ---------- - dataset_name : str - The dataset of interest. - item_name : str - The name of the item within a dataset to download. - dest : Path - A local path at which to save the downloaded item. + jobs_list_active_only : bool + Whether to exclusively include jobs with "active" status values. + args + (Unused) variable positional args. + kwargs + (Unused) variable keyword args. 
         Returns
         -------
+        BasicResultIndicator
+            An indicator that, when successful, includes as ``data`` the list of ids of existing jobs.
+
+        See Also
+        -------
+        get_jobs_list
+        """
+        # TODO: implement
+        raise NotImplementedError('{} function "request_jobs_list" not implemented yet'.format(self.__class__.__name__))
+
+    async def submit_ngen_request(self, **kwargs) -> NGENRequestResponse:
+        return NGENRequestResponse.factory_init_from_deserialized_json(
+            json.loads(await self._submit_job_request(request=NGENRequest(request_body=kwargs, **kwargs))))
+
+    async def submit_ngen_cal_request(self, **kwargs) -> NgenCalibrationResponse:
+        return NgenCalibrationResponse.factory_init_from_deserialized_json(
+            json.loads(await self._submit_job_request(request=NgenCalibrationRequest(request_body=kwargs, **kwargs))))
+
+    async def submit_request_from_file(self, job_type: str, request_file: Path, **kwargs) -> ResultIndicator:
+        """
+        Submit a serialized job request stored in the given file.
+
+        Parameters
+        ----------
+        job_type : str
+            String representation of the type of request: either ``ngen`` or ``ngen_cal``.
+        request_file : Path
+            The supplied file containing a JSON string, which should be a serialized, applicable request object.
         """
+        return await self.submit_request_from_json(job_type=job_type, request_json=request_file.read_text())
+
+    async def submit_request_from_json(self, job_type: str, request_json: Union[dict, str], **kwargs) -> ResultIndicator:
+        """
+        Submit a supplied, serialized job request.
+
+        Parameters
+        ----------
+        job_type : str
+            String representation of the type of request: either ``ngen`` or ``ngen_cal``.
+        request_json : Union[dict, str]
+            The serialized representation of a request as JSON, either as a ``str`` or already inflated to a JSON
+            ``dict`` object.
+        """
+        if isinstance(request_json, str):
+            request_json = json.loads(request_json)
+
+        if job_type == 'ngen':
+            return NGENRequestResponse.factory_init_from_deserialized_json(
+                json.loads(await self._submit_job_request(NGENRequest.factory_init_from_deserialized_json(request_json))))
+        elif job_type == 'ngen_cal':
+            return NgenCalibrationResponse.factory_init_from_deserialized_json(
+                json.loads(await self._submit_job_request(NgenCalibrationRequest.factory_init_from_deserialized_json(request_json))))
+        else:
+            raise RuntimeError(f"Invalid job type indicator for serialized job request: {job_type}")
+
+
+class DataTransferAgent(ABC):
+
+    @abstractmethod
+    async def download_dataset_item(self, dataset_name: str, item_name: str, dest: Path):
         pass

     @abstractmethod
-    async def list_datasets(self, category: Optional[DataCategory] = None) -> List[str]:
+    async def upload_dataset_item(self, dataset_name: str, item_name: str, source: Path) -> DatasetManagementResponse:
         pass

+    @property
     @abstractmethod
-    async def upload_to_dataset(self, dataset_name: str, paths: List[Path]) -> bool:
+    def uses_auth(self) -> bool:
         """
-        Upload data a dataset.
+        Whether this particular agent instance uses authentication when interacting with the other party.

-        Parameters
-        ----------
-        dataset_name : str
-            The name of the dataset.
-        paths : List[Path]
-            List of one or more paths of files to upload or directories containing files to upload.
+        Clients that use auth will have been initialized with an ::class:`AuthClient` for authenticating interactions.

         Returns
         -------
         bool
-            Whether uploading was successful
+            Whether this particular agent uses authentication in interactions.
""" pass -class DatasetInternalClient(DatasetClient, DataServiceClient): +class SimpleDataTransferAgent(DataTransferAgent): - @classmethod - def get_response_subtype(cls) -> Type[R]: - return DatasetManagementResponse - - def __init__(self, *args, **kwargs): + def __init__(self, transport_client: TransportLayerClient, auth_client: Optional[AuthClient] = None, *args, **kwargs): super().__init__(*args, **kwargs) + self._transport_client: TransportLayerClient = transport_client + self._auth_client: Optional[AuthClient] = auth_client - async def create_dataset(self, name: str, category: DataCategory, domain: DataDomain, **kwargs) -> bool: - # TODO: (later) consider also adding param for data to be added - request = DatasetManagementMessage(action=ManagementAction.CREATE, domain=domain, dataset_name=name, - category=category) - self.last_response = await self.async_make_request(request) - return self.last_response is not None and self.last_response.success + async def _transfer_receiver(self): + """ + Receive a series of data transmit messages, with the transfer already initiated. - async def delete_dataset(self, name: str, **kwargs) -> bool: - request = DatasetManagementMessage(action=ManagementAction.DELETE, dataset_name=name) - self.last_response = await self.async_make_request(request) - return self.last_response is not None and self.last_response.success + Note that this generator expects that the first message received be the first ::class:`DataTransmitMessage`. + Additionally, after the last ``yield``, which will be the final necessary ::class:`DataTransmitMessage` that has + ``is_last`` value of ``True``, the transport client should expect to immediately receive the final + ::class:`DatasetManagementResponse` message closing the request. - async def download_dataset(self, dataset_name: str, dest_dir: Path) -> bool: + Yields + ------- + DataTransmitMessage + The next data transmit message in the transfer series. + """ + incoming_obj: DataTransmitMessage = None + + while incoming_obj is None or not incoming_obj.is_last: + # TODO: may need to make messages at the transport level session aware to make this work with shared connections/client + incoming_data = await self._transport_client.async_recv() + incoming_obj = DataTransmitMessage.factory_init_from_deserialized_json(json.loads(incoming_data)) + if not isinstance(incoming_obj, DataTransmitMessage): + await self._transport_client.async_send(str(InvalidMessageResponse())) + raise DmodRuntimeError(f"{self.__class__.__name__} could not deserialize DataTransmitMessage in data " + f"transfer receipt attempt") + # TODO: confirm that this works as expected/needed after the last data message + reply_obj = DataTransmitResponse.create_for_received(received_msg=incoming_obj) + await self._transport_client.async_send(str(reply_obj)) + yield incoming_obj + + async def _request_prep(self, dataset_name: str, item_name: str, action: ManagementAction) -> Tuple[DatasetManagementMessage, Type[DatasetManagementResponse]]: """ - Download an entire dataset to a local directory. + Prep a download or upload initial request. Parameters ---------- - dataset_name : str - The dataset of interest. - dest_dir : Path - Path to the local directory to which to save the dataset's data. + dataset_name + item_name + action Returns ------- - bool - Whether the download was successful. 
+        Tuple[DatasetManagementMessage, Type[DatasetManagementResponse]]
+            A tuple of two items:
+                - the prepared initial request
+                - the appropriate type for response objects, depending on whether authentication is being used
         """
-        try:
-            dest_dir.mkdir(parents=True, exist_ok=True)
-        except:
-            return False
-        success = True
-        query = DatasetQuery(query_type=QueryType.LIST_FILES)
-        request = DatasetManagementMessage(action=ManagementAction.QUERY, dataset_name=dataset_name, query=query)
-        self.last_response: DatasetManagementResponse = await self.async_make_request(request)
-        # TODO: (later) need to formalize this a little better than just here (and whereever it is serialized)
-        results = self.last_response.query_results
-        for item, dest in [(f, dest_dir.joinpath(f)) for f in (results['files'] if 'files' in results else [])]:
-            dest.parent.mkdir(exist_ok=True)
-            success = success and await self.download_from_dataset(dataset_name=dataset_name, item_name=item, dest=dest)
-        return success
-
-    async def download_from_dataset(self, dataset_name: str, item_name: str, dest: Path) -> bool:
+        req_params = {'action': action, 'dataset_name': dataset_name, 'data_location': item_name}
+
+        if self.uses_auth:
+            # This will be replaced as soon as we call apply_auth, but some string is required for __init__
+            req_params['session_secret'] = ''
+            request = MaaSDatasetManagementMessage(**req_params)
+            # TODO: (later) implement and use DmodAuthenticationFailure, though possibly down the apply_auth call stack
+            if not await self._auth_client.apply_auth(request):
+                msg = f'{self.__class__.__name__} could not apply auth to {request.__class__.__name__}'
+                raise DmodRuntimeError(msg)
+            else:
+                return request, MaaSDatasetManagementResponse
+        else:
+            return DatasetManagementMessage(**req_params), DatasetManagementResponse
+
+    async def download_dataset_item(self, dataset_name: str, item_name: str, dest: Path) -> DatasetManagementResponse:
         if dest.exists():
-            return False
+            reason = 'Destination File Exists'
+            msg = f'{self.__class__.__name__} could not download dataset item to existing path {str(dest)}'
+            return DatasetManagementResponse(success=False, reason=reason, message=msg)
+
         try:
             dest.parent.mkdir(parents=True, exist_ok=True)
         except:
-            return False
-        request = DatasetManagementMessage(action=ManagementAction.REQUEST_DATA, dataset_name=dataset_name,
-                                           data_location=item_name)
-        self.last_response: DatasetManagementResponse = await self.async_make_request(request)
+            reason = 'Unable to Create Parent Directory'
+            msg = f'{self.__class__.__name__} could not create parent directory for downloading item to {str(dest)}'
+            return DatasetManagementResponse(success=False, reason=reason, message=msg)
+
+        try:
+            request, final_response_type = await self._request_prep(dataset_name=dataset_name, item_name=item_name,
+                                                                    action=ManagementAction.REQUEST_DATA)
+            # TODO: (later) implement and use DmodAuthenticationFailure
+        except DmodRuntimeError as e:
+            reason = f'{self.__class__.__name__} Download Auth Failure'
+            return MaaSDatasetManagementResponse(success=False, reason=reason, message=str(e))
+
         with dest.open('w') as file:
-            for page in range(1, (self.last_response.total_pages + 1)):
-                request = DatasetManagementMessage(action=ManagementAction.DOWNLOAD_DATA, dataset_name=dataset_name,
-                                                   data_location=item_name, page=page)
-                self.last_response: DatasetManagementResponse = await self.async_make_request(request)
-                file.write(self.last_response.file_data)
+            # Do initial request outside of generator
+            await self._transport_client.async_send(data=str(request))
+            async for received_data_msg in self._transfer_receiver():
+                data = received_data_msg.data
+                while data:
+                    bytes_written = file.write(data)
+                    data = data[bytes_written:]
+        final_data = await self._transport_client.async_recv()

-    async def list_datasets(self, category: Optional[DataCategory] = None) -> List[str]:
-        action = ManagementAction.LIST_ALL if category is None else ManagementAction.SEARCH
-        request = DatasetManagementMessage(action=action, category=category)
-        self.last_response = await self.async_make_request(request)
-        return self._parse_list_of_dataset_names_from_response(self.last_response)
+        try:
+            final_response_json = json.loads(final_data)
+        except Exception as e:
+            msg = f"{self.__class__.__name__} failed with {e.__class__.__name__} parsing `{final_data}` to JSON"
+            return final_response_type(success=False, reason="JSON Parse Failure On Final Response", message=msg)
+
+        final_response = final_response_type.factory_init_from_deserialized_json(final_response_json)
+        if final_response is None:
+            return final_response_type(success=False, reason="Failed to Deserialize Final Response")
+        else:
+            return final_response

-    async def upload_to_dataset(self, dataset_name: str, paths: List[Path]) -> bool:
+    async def upload_dataset_item(self, dataset_name: str, item_name: str, source: Path) -> DatasetManagementResponse:
+        if not source.is_file():
+            return DatasetManagementResponse(success=False, reason="Dataset Upload File Not Found",
+                                             message=f"File {source!s} does not exist")
+        try:
+            message, final_response_type = await self._request_prep(dataset_name=dataset_name, item_name=item_name,
+                                                                    action=ManagementAction.ADD_DATA)
+            # TODO: (later) implement and use DmodAuthenticationFailure
+        except DmodRuntimeError as e:
+            reason = f'{self.__class__.__name__} Upload Request Auth Failure'
+            return MaaSDatasetManagementResponse(success=False, reason=reason, message=str(e))
+
+        chunk_size = 1024
+
+        with source.open('r') as file:
+            last_send = False
+            raw_chunk = file.read(chunk_size)
+
+            while True:
+                await self._transport_client.async_send(data=str(message))
+
+                response_json = json.loads(await self._transport_client.async_recv())
+                response = final_response_type.factory_init_from_deserialized_json(response_json)
+                if response is not None:
+                    return response
+                elif last_send:
+                    msg = f"{self.__class__.__name__} should have received {final_response_type.__name__} here"
+                    raise DmodRuntimeError(msg)
+
+                response = DataTransmitResponse.factory_init_from_deserialized_json(response_json)
+                if response is None:
+                    msg = f"{self.__class__.__name__} couldn't parse response to request to upload {source!s}"
+                    return final_response_type(success=False, reason="Unparseable Upload Init", message=msg)
+                elif not response.success:
+                    msg = f"{self.__class__.__name__} received {response.__class__.__name__} indicating failure"
+                    return final_response_type(success=False, reason="Failed Upload Transfer", message=msg)
+
+                # Look ahead to see if this is the last transmission ...
+                next_chunk = file.read(chunk_size)
+                # ... keep track if it is last ...
+                last_send = not bool(next_chunk)
+                # ... and also note in message if it is last ...
+                message = DataTransmitMessage(data=raw_chunk, series_uuid=response.series_uuid, is_last=last_send)
+
+                # Then once that chunk is sent, bump the look-ahead to the current
+                raw_chunk = next_chunk
+
+    @property
+    def uses_auth(self) -> bool:
         """
-        Upload data a dataset.
+        Whether this particular agent instance uses authentication when interacting with the other party.

-        Parameters
-        ----------
-        dataset_name : str
-            The name of the dataset.
-        paths : List[Path]
-            List of one or more paths of files to upload or directories containing files to upload.
+        Clients that use auth will have been initialized with an ::class:`AuthClient` for authenticating interactions.

         Returns
         -------
         bool
-            Whether uploading was successful
+            Whether this particular agent uses authentication in interactions.
         """
-        # TODO: *********************************************
-        raise NotImplementedError('Function upload_to_dataset not implemented')
-
+        return self._auth_client is not None

-class DatasetExternalClient(DatasetClient,
-                            ExternalRequestClient[MaaSDatasetManagementMessage, MaaSDatasetManagementResponse]):
-    """
-    Client for authenticated communication sessions via ::class:`MaaSDatasetManagementMessage` instances.
-    """

-    # In particular needs - endpoint_uri: str, ssl_directory: Path
-    def __init__(self, *args, cache_session: bool = True, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._cached_session_file: Optional[Path] = (
-            Path.home().joinpath(".dmod_client_session") if cache_session else None
-        )
+class DataServiceClient:

     @classmethod
-    def from_session(cls, *, endpoint_uri: str, ssl_directory: Path, session: Session, cache_session: bool = True, **kwargs) -> Self:
+    def extract_dataset_names(cls, response: DatasetManagementResponse) -> List[str]:
         """
-        Create a `DatasetExternalClient` from an existing `Session` instance.
+        Parse response object for an included list of dataset names.
+
+        Parse a received ::class:`DatasetManagementResponse` for a list of dataset names. Generally, this should be the
+        response to a request of either the ``LIST_ALL`` or ``SEARCH`` ::class:`ManagementAction` value.
+
+        An unsuccessful response or a response that does not contain the dataset names data will return an empty list.
+
+        Parameters
+        ----------
+        response : DatasetManagementResponse
+            The response message from which to parse dataset names.
+
+        Returns
+        -------
+        List[str]
+            The list of parsed dataset names.

-        Note, the passed `Session` object will not be written to disk even if the `cache_session`
-        flag is present. However, if the `Session` instance expires and a new session is acquired,
-        it will be cached if `cache_session` is set.
+        Raises
+        -------
+        DmodRuntimeError
+            Raised if the parameter is not a ::class:`DatasetManagementResponse` (or subtype) object.
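+
+        Examples
+        --------
+        An illustrative sketch, using the same serialized layout exercised by the package's unit tests::
+
+            response = DatasetManagementResponse.factory_init_from_deserialized_json(
+                {'success': True, 'reason': 'List Assembled', 'message': '',
+                 'data': {'datasets': ['dataset_1'], 'action': 'LIST_ALL', 'is_awaiting': False}})
+            names = DataServiceClient.extract_dataset_names(response)
+            # names == ['dataset_1']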
""" - client = cls(endpoint_uri=endpoint_uri, ssl_directory=ssl_directory, cache_session=cache_session, **kwargs) - client._session_id = session.session_id - client._session_secret = session.session_secret - client._session_created = session.created - client._is_new_session = False - return client + if not isinstance(response, DatasetManagementResponse): + raise DmodRuntimeError(f"{cls.__name__} can't parse list of datasets from {response.__class__.__name__}") + # Consider these as valid cases, and treat them as just not listing any datasets + elif not response.success or response.data is None or response.data.datasets is None: + return [] + else: + return response.data.datasets + + def __init__(self, transport_client: TransportLayerClient, auth_client: Optional[AuthClient] = None, *args, **kwargs): + super().__init__(*args, **kwargs) + self._transport_client: TransportLayerClient = transport_client + self._auth_client: Optional[AuthClient] = auth_client - def _acquire_session_info(self, use_current_values: bool = True, force_new: bool = False): + async def _process_request(self, request: DatasetManagementMessage) -> DatasetManagementResponse: """ - Attempt to set the session information properties needed to submit a maas request. + Reusable, general helper function to process a custom-assembled request for the data service. + + Function recreates request as auth-supporting type and applies auth when necessary, then processes the send and + parsing of the response. Parameters ---------- - use_current_values : bool - Whether to use currently held attribute values for session details, if already not None (disregarded if - ``force_new`` is ``True``). - force_new : bool - Whether to force acquiring a new session, regardless of data available is available on an existing session. + request : DatasetManagementMessage + The assembled request message, of a type that does not yet (and cannot) have any auth applied. Returns ------- - bool - whether session details were acquired and set successfully + DatasetManagementResponse + A response to the request from the service to delete the given dataset, which may actually be a + ::class:`MaaSDatasetManagementMessage` depending on whether this type uses auth. + + Raises + ------- + DmodRuntimeError + If the response from the service cannot be deserialized successfully to the expected response type. 
+
+        See Also
+        -------
+        uses_auth
         """
-        #logger.info("{}._acquire_session_info: getting session info".format(self.__class__.__name__)
-        if not force_new and use_current_values and self._session_id and self._session_secret and self._session_created:
-            #logger.info('Using previously acquired session details (new session not forced)')
-            return True
+        if self.uses_auth:
+            request = MaaSDatasetManagementMessage.factory_create(mgmt_msg=request, session_secret='')
+            all_required_auth_is_applied = await self._auth_client.apply_auth(request)
+            response_type = MaaSDatasetManagementResponse
         else:
-            #logger.info("Session from JobRequestClient: force_new={}".format(force_new))
-            tmp = self._acquire_new_session()
-            #logger.info("Session Info Return: {}".format(tmp))
-            return tmp
-
-    async def _async_acquire_session_info(self, use_current_values: bool = True, force_new: bool = False):
-        if (
-            use_current_values
-            and not force_new
-            and self._cached_session_file is not None
-            and self._cached_session_file.exists()
-        ):
-            try:
-                session_id, secret, created = self.parse_session_auth_text(
-                    self._cached_session_file.read_text()
-                )
-                self._session_id = session_id
-                self._session_secret = secret
-                self._session_create = created
-            except Exception as e:
-                # TODO: consider logging; for now, just don't bail and move on to logic for new session
-                pass
-
-        if (
-            not force_new
-            and use_current_values
-            and self._session_id
-            and self._session_secret
-            and self._session_created
-        ):
-            # logger.info('Using previously acquired session details (new session not forced)')
-            return True
+            # If no auth was required, treat as though all required auth was applied
+            all_required_auth_is_applied = True
+            response_type = DatasetManagementResponse
+
+        if not all_required_auth_is_applied:
+            reason = f'{self.__class__.__name__} Request Auth Failure'
+            msg = f'{self.__class__.__name__} could not apply auth to {request.__class__.__name__}'
+            return response_type(success=False, reason=reason, message=msg)
+
+        # Some clients may be async context managers
+        if isinstance(self._transport_client, ConnectionContextClient):
+            async with self._transport_client as t_client:
+                await t_client.async_send(data=str(request))
+                response_data = await t_client.async_recv()
         else:
-            # TODO: look at if there needs to be an addition to connection count, active connections, or something here
-            tmp = await self._async_acquire_new_session(
-                cached_session_file=self._cached_session_file
-            )
-            # logger.info("Session Info Return: {}".format(tmp))
-            return tmp
+            await self._transport_client.async_send(data=str(request))
+            response_data = await self._transport_client.async_recv()
+
+        response_obj = response_type.factory_init_from_deserialized_json(json.loads(response_data))
+        if not isinstance(response_obj, response_type):
+            msg = f"{self.__class__.__name__} could not deserialize {response_type.__name__} from raw response data" \
+                  f" '{response_data}'"
+            raise DmodRuntimeError(msg)
+        else:
+            return response_obj

-    def _process_data_download_iteration(self, raw_received_data: str) -> Tuple[bool, Union[DataTransmitMessage, MaaSDatasetManagementResponse]]:
+    # TODO: better integrate uploading initial data into the create request itself
+    async def create_dataset(self, name: str, category: DataCategory, domain: DataDomain,
+                             upload_paths: Optional[List[Path]] = None, data_root: Optional[Path] = None,
+                             **kwargs) -> DatasetManagementResponse:
         """
-        Helper function for processing a single iteration of the process of downloading data.
+        Create a dataset from the given parameters.

-        Function process the received param, assumed to be received from the data service via a websocket connection,
-        by loading it to JSON and attempting to deserialize it, first to a ::class:`MaaSDatasetManagementResponse`, then
-        to a ::class:`DataTransmitMessage`. If both fail, a ::class:`MaaSDatasetManagementResponse` indicating failure
-        is created.
+        Parameters
+        ----------
+        name : str
+            The name for the dataset.
+        category : DataCategory
+            The category for the dataset.
+        domain : DataDomain
+            The defined domain for the dataset.
+        upload_paths : Optional[List[Path]]
+            Optional list of paths of files to upload.
+        data_root : Optional[Path]
+            A relative data root directory, used to adjust the names for uploaded items.
+        kwargs

-        To minimize later processing, a tuple is instead returned, containing not only the obtained message, but also
-        whether it contains transmitted data. Note that the obtained message is the second tuple item.
+        Returns
+        -------
+        DatasetManagementResponse
+            A response to the request for the service to create a new dataset, which may actually be a
+            ::class:`MaaSDatasetManagementResponse` depending on whether this instance uses auth.
+
+        Raises
+        -------
+        DmodRuntimeError
+            If the response from the service cannot be deserialized successfully to the expected response type.
+
+        See Also
+        -------
+        upload_to_dataset
+        uses_auth
+        """
+        request = DatasetManagementMessage(action=ManagementAction.CREATE, domain=domain, dataset_name=name,
+                                           category=category)
+        try:
+            create_response = await self._process_request(request=request)
+        except DmodRuntimeError as e:
+            raise DmodRuntimeError(f"DMOD error when creating dataset: {str(e)}")
+
+        if not create_response.success or not upload_paths:
+            return create_response
+
+        upload_response = await self.upload_to_dataset(dataset_name=name, paths=upload_paths, data_root=data_root)
+        if upload_response.success:
+            return create_response
+        else:
+            create_response.success = False
+            create_response.reason = "Initial Uploads Failed"
+            create_response.message = f"Dataset {name} was created, but upload failures occurred: `{upload_response!s}`"
+            return create_response
+
+    async def does_dataset_exist(self, dataset_name: str, **kwargs) -> bool:
+        """
+        Helper function to test whether a dataset of the given name/id exists.

         Parameters
         ----------
-        raw_received_data : str
-            The raw message text data, received over a websocket connection to the data service, expected to be either a
-            serialized ::class:`DataTransmitMessage` or ::class:`MaaSDatasetManagementResponse`.
+        dataset_name : str
+            The hypothetical dataset name of interest.

         Returns
         -------
-        Tuple[bool, Union[DataTransmitMessage, MaaSDatasetManagementResponse]]
-            A tuple of whether the returned message for data transmission (i.e., contains data) and a returned message
-            that either contains download data or is a management response indicating the download process is finished.
+        bool
+            Whether a dataset of the given name exists with the data service.
         """
+        # FIXME: optimize this more effectively later.
+        return dataset_name in await self.list_datasets()
+
+    async def delete_dataset(self, name: str, **kwargs) -> DatasetManagementResponse:
+        """
+        Delete a dataset.
+
+        Parameters
+        ----------
+        name : str
+            The unique name of the dataset to delete.
+        kwargs
+
+        Returns
+        -------
+        DatasetManagementResponse
+            A response to the request from the service to delete the given dataset, which may actually be a
+            ::class:`MaaSDatasetManagementResponse` depending on whether this type uses auth.
+
+        Raises
+        -------
+        DmodRuntimeError
+            If the response from the service cannot be deserialized successfully to the expected response type.
+
+        See Also
+        -------
+        uses_auth
+        """
+        request = DatasetManagementMessage(action=ManagementAction.DELETE, dataset_name=name)
         try:
-            received_as_json = json.loads(raw_received_data)
-        except:
-            received_as_json = ''
-
-        # Try to deserialize to this type 1st; if message is something else (e.g., more data), we'll get None,
-        # but if message deserializes to this kind of object, then this will be the last (and only) message
-        received_message = MaaSDatasetManagementResponse.factory_init_from_deserialized_json(received_as_json)
-        if received_message is not None:
-            return False, received_message
-        # If this wasn't deserialized to a response before, and wasn't to a data transmit just now, then bail
-        received_message = DataTransmitMessage.factory_init_from_deserialized_json(received_as_json)
-        if received_message is None:
-            message_obj = MaaSDatasetManagementResponse(success=False, action=ManagementAction.REQUEST_DATA,
-                                                        reason='Unparseable Message')
-            return False, message_obj
-        else:
-            return True, received_message
+            return await self._process_request(request=request)
+        except DmodRuntimeError as e:
+            raise DmodRuntimeError(f"DMOD error when deleting dataset: {str(e)}")

-    def _update_after_valid_response(self, response: MaaSDatasetManagementResponse):
+    # TODO: this needs a storage client instead of having to figure out where/how to "put" the data
+    async def get_dataset_names(self, category: Optional[DataCategory] = None, **kwargs) -> DatasetManagementResponse:
         """
-        Perform any required internal updates immediately after a request gets back a successful, valid response.
+        Get a list of the names of datasets, optionally filtering to a specific category.
+
+        Parameters
+        ----------
+        category : DataCategory
+            An optional ::class:`DataCategory` by which to exclusively filter.

-        This provides a way of extending the behavior of this type specifically regarding the ::method:make_maas_request
-        function. Any updates specific to the type, which should be performed after a request receives back a valid,
-        successful response object, can be implemented here.
+        Returns
+        -------
+        DatasetManagementResponse
+            The returned response object itself that, when successful, contains a list of dataset names.
+
+        See Also
+        -------
+        list_datasets
+        """
+        action = ManagementAction.LIST_ALL if category is None else ManagementAction.SEARCH
+        request = DatasetManagementMessage(action=action, category=category)
+        try:
+            return await self._process_request(request=request)
+        except DmodRuntimeError as e:
+            raise DmodRuntimeError(f"DMOD error when getting dataset names: {str(e)}")
+
+    async def get_dataset_item_names(self, dataset_name: str, **kwargs) -> DatasetManagementResponse:
+        """
+        Request the name/id of all items in the given dataset.

         Parameters
         ----------
-        response : MaaSDatasetManagementResponse
-            The response triggering the update.
+        dataset_name : str
+            The name/id of the dataset of interest.
+
+        Returns
+        -------
+        DatasetManagementResponse
+            A response containing item names if successful, or indicating failure.
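+
+        Examples
+        --------
+        Illustrative usage (a sketch; assumes an initialized client and an existing dataset name)::
+
+            response = await client.get_dataset_item_names(dataset_name='example-dataset')
+            if response.success:
+                item_names = response.query_results.get('files', [])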
+        """
+        request = DatasetManagementMessage(action=ManagementAction.QUERY, dataset_name=dataset_name,
+                                           query=DatasetQuery(query_type=QueryType.LIST_FILES))
+        try:
+            return await self._process_request(request=request)
+        except DmodRuntimeError as e:
+            raise DmodRuntimeError(f"DMOD error when getting dataset item names: {str(e)}")
+
+    async def list_datasets(self, category: Optional[DataCategory] = None, **kwargs) -> List[str]:
+        """
+        Convenience method to list datasets, optionally filtering to a specific category.
+
+        Function simply makes a nested call to ::method:`get_dataset_names` and parses the names of the datasets using
+        ::method:`extract_dataset_names`.
+
+        Parameters
+        ----------
+        category : DataCategory
+            An optional ::class:`DataCategory` by which to exclusively filter.
+
+        Returns
+        -------
+        List[str]
+            The list of dataset names, or an empty list if the request was not successful.

         See Also
         -------
-        ::method:make_maas_request
+        extract_dataset_names
+        get_dataset_names
         """
-        # TODO: think about if anything is needed for this
-        pass
+        return self.extract_dataset_names(response=await self.get_dataset_names(category=category, **kwargs))

-    async def _upload_file(self, dataset_name: str, path: Path, item_name: str) -> bool:
+    async def list_dataset_items(self, dataset_name: str, **kwargs) -> List[str]:
         """
-        Upload a single file to the dataset
+        Convenience method to get the list of items within a dataset.

         Parameters
         ----------
         dataset_name : str
-            The name of the destination dataset.
-        path : Path
-            The path of the local file to upload.
-        item_name : str
-            The name of the destination dataset item in which to place the data.
+            The name/id of the dataset of interest.
+
         Returns
         -------
-        bool
-            Whether the data upload was successful.
+        List[str]
+            The list of item names, or an empty list if the request was not successful.
+
+        See Also
+        -------
+        get_dataset_item_names
         """
-        await self._async_acquire_session_info()
-        #raw_data = path.read_bytes()
-        chunk_size = 1024
-        message = MaaSDatasetManagementMessage(action=ManagementAction.ADD_DATA, dataset_name=dataset_name,
-                                               session_secret=self.session_secret, data_location=item_name)
-        async with websockets.connect(self.endpoint_uri, ssl=self.client_ssl_context) as websocket:
-            with path.open() as file:
-                raw_chunk = file.read(chunk_size)
-                while True:
-                    await websocket.send(str(message))
-                    response_json = json.loads(await websocket.recv())
-                    response = MaaSDatasetManagementResponse.factory_init_from_deserialized_json(response_json)
-                    if response is not None:
-                        self.last_response = response
-                        return response.success
-                    response = DataTransmitResponse.factory_init_from_deserialized_json(response_json)
-                    if response is None:
-                        return False
-                    if not response.success:
-                        self.last_response = response
-                        return response.success
-                    # If here, we must have gotten a transmit response indicating we can send more data, so prime the next
-                    # sending message for the start of the loop
-                    next_chunk = file.read(chunk_size)
-                    message = DataTransmitMessage(data=raw_chunk, series_uuid=response.series_uuid,
-                                                  is_last=not bool(next_chunk))
-                    raw_chunk = next_chunk
-
-    async def _upload_dir(self, dataset_name: str, dir_path: Path, item_name_prefix: str = '') -> bool:
-        """
-        Upload the contents of a local directory to a dataset.
+        response = await self.get_dataset_item_names(dataset_name=dataset_name, **kwargs)
+        return response.query_results.get('files', []) if response.success else []
+
+    async def retrieve_from_dataset(self, dataset_name: str, dest_dir: Path,
+                                    item_names: Optional[Union[str, Sequence[str]]] = None, **kwargs) -> ResultIndicator:
+        """
+        Download data from either all or specific item(s) within a dataset to a local path.
+
+        Items are saved with the same name, relative to the ``dest_dir`` directory. If item names have '/' characters,
+        it is assumed they were meant to emulate file system paths in the dataset storage location. As such, these are
+        separated and treated as nested directories within ``dest_dir`` and created as needed.

         Parameters
         ----------
         dataset_name : str
-            The name of the dataset.
-        dir_path : Path
-            The path of the local directory containing data files to upload.
-        item_name_prefix : str
-            A prefix to append to the name of destination items (otherwise equal to the local data files basename), used
-            to make recursive calls to this function on subdirectories and emulate the local directory structure.
+            The dataset of interest.
+        dest_dir : Path
+            A local directory path under which to save the downloaded item(s).
+        item_names : Optional[Union[str, Sequence[str]]] = None
+            The name(s) of specific item(s) within a dataset to download; if ``None`` (the default), download all items.

         Returns
         -------
-        bool
-            Whether data upload was successful.
-        """
-        success = True
-        for child in dir_path.iterdir():
-            if child.is_dir():
-                new_prefix = '{}{}/'.format(item_name_prefix, child.name)
-                success = success and await self._upload_dir(dataset_name=dataset_name, dir_path=child,
-                                                             item_name_prefix=new_prefix)
-            else:
-                success = success and await self._upload_file(dataset_name=dataset_name, path=child,
-                                                              item_name='{}{}'.format(item_name_prefix, child.name))
-        return success
-
-    async def create_dataset(self, name: str, category: DataCategory, domain: DataDomain, **kwargs) -> bool:
-        await self._async_acquire_session_info()
-        # TODO: (later) consider also adding param for data to be added
-        request = MaaSDatasetManagementMessage(session_secret=self.session_secret, action=ManagementAction.CREATE,
-                                               domain=domain, dataset_name=name, category=category)
-        self.last_response = await self.async_make_request(request)
-        return self.last_response is not None and self.last_response.success
-
-    async def delete_dataset(self, name: str, **kwargs) -> bool:
-        await self._async_acquire_session_info()
-        request = MaaSDatasetManagementMessage(session_secret=self.session_secret, action=ManagementAction.DELETE,
-                                               dataset_name=name)
-        self.last_response = await self.async_make_request(request)
-        return self.last_response is not None and self.last_response.success
-
-    async def download_dataset(self, dataset_name: str, dest_dir: Path) -> bool:
-        await self._async_acquire_session_info()
-        try:
-            dest_dir.mkdir(parents=True, exist_ok=True)
-        except:
-            return False
-        success = True
-        query = DatasetQuery(query_type=QueryType.LIST_FILES)
-        request = MaaSDatasetManagementMessage(action=ManagementAction.QUERY, dataset_name=dataset_name, query=query,
-                                               session_secret=self.session_secret)
-        self.last_response: MaaSDatasetManagementResponse = await self.async_make_request(request)
-        for item, dest in [(filename, dest_dir.joinpath(filename)) for filename in self.last_response.query_results]:
-            dest.parent.mkdir(exist_ok=True)
-            success = success and await self.download_from_dataset(dataset_name=dataset_name, item_name=item, dest=dest)
-        return success
-
-    async def download_from_dataset(self, dataset_name: str, item_name: str, dest: Path) -> bool:
-        await self._async_acquire_session_info()
-        if dest.exists():
-            return False
-        try:
-            dest.parent.mkdir(parents=True, exist_ok=True)
-        except:
-            return False
-
-        request = MaaSDatasetManagementMessage(action=ManagementAction.REQUEST_DATA, dataset_name=dataset_name,
-                                               session_secret=self.session_secret, data_location=item_name)
-        async with websockets.connect(self.endpoint_uri, ssl=self.client_ssl_context) as websocket:
-            # Do this once outside loop, so we don't open a file for writing to which nothing is written
-            await websocket.send(str(request))
-            has_data, message_object = self._process_data_download_iteration(await websocket.recv())
-            if not has_data:
-                return message_object
-
-            # Here, we will have our first piece of data to write, so open file and start our loop
-            with dest.open('w') as file:
-                while True:
-                    file.write(message_object.data)
-                    # Do basically same as above, except here send message to acknowledge data just written was received
-                    await websocket.send(str(DataTransmitResponse(success=True, reason='Data Received',
-                                                                  series_uuid=message_object.series_uuid)))
-                    has_data, message_object = self._process_data_download_iteration(await websocket.recv())
-                    if not has_data:
-                        return message_object
-
-    async def list_datasets(self, category: Optional[DataCategory] = None) -> List[str]:
-        await self._async_acquire_session_info()
-        action = ManagementAction.LIST_ALL if category is None else ManagementAction.SEARCH
-        request = MaaSDatasetManagementMessage(session_secret=self.session_secret, action=action, category=category)
-        self.last_response = await self.async_make_request(request)
-        return self._parse_list_of_dataset_names_from_response(self.last_response)
+        ResultIndicator
+            A result indicator indicating whether downloading was successful.
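+
+        Examples
+        --------
+        Illustrative usage (a sketch; assumes an initialized client, an existing dataset, and an existing local
+        directory)::
+
+            result = await client.retrieve_from_dataset(dataset_name='example-dataset',
+                                                        dest_dir=Path('/tmp/dmod_data'))
+            if not result.success:
+                print(f"Retrieval failed: {result.reason}")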
+ """ + if not dest_dir.exists(): + return BasicResultIndicator(success=False, reason="No Dest Dir", message=f"'{dest_dir!s}' doesn't exist") + elif not dest_dir.is_dir(): + return BasicResultIndicator(success=False, reason="Bad Dest", message=f"Non-dir '{dest_dir!s}' exists") + elif not await self.does_dataset_exist(dataset_name=dataset_name): + return BasicResultIndicator(success=False, reason="Dataset Does Not Exist", + message=f"No existing dataset '{dataset_name}' was found") + + if not item_names: + item_names = await self.list_dataset_items(dataset_name) + else: + unrecognized = [i for i in item_names if i not in set(await self.list_dataset_items(dataset_name))] + if unrecognized: + return BasicResultIndicator(success=False, reason="Can't Get Unrecognized Items", data=unrecognized) + + failed_items: Dict[str, DatasetManagementResponse] = dict() + # TODO: see if we can perhaps have multiple agents and thread pool if multiplexing is available + tx_agent = SimpleDataTransferAgent(transport_client=self._transport_client, auth_client=self._auth_client) + + for i in item_names: + result = await tx_agent.download_dataset_item(dataset_name=dataset_name, item_name=i, + dest=dest_dir.joinpath(i)) + if not result.success: + failed_items[i] = result + + if len(failed_items) == 0: + return BasicResultIndicator(success=True, reason="Retrieval Complete") + else: + return BasicResultIndicator(success=False, reason=f"{len(failed_items)!s} Failures", + data=failed_items) - async def upload_to_dataset(self, dataset_name: str, paths: List[Path]) -> bool: + async def upload_to_dataset(self, dataset_name: str, paths: Union[Path, List[Path]], + data_root: Optional[Path] = None, **kwargs) -> ResultIndicator: """ Upload data a dataset. + A ``data_root`` param can optionally be supplied to adjust uploaded item names. E.g., if ``paths`` contains + the file ``/home/username/data_dir_1/file_1``, then by default its contents will be uploaded to the dataset item + named "/home/username/data_dir_1/file_1". However, if ``data_root`` is set to, e.g., + ``/home/username/data_dir_1``, then the uploaded item will instead be named simply "file_1". + Parameters ---------- dataset_name : str The name of the dataset. - paths : List[Path] - List of one or more paths of files to upload or directories containing files to upload. + paths : Union[Path, List[Path]] + Path or list of paths of files to upload. + data_root : Optional[Path] + A relative data root directory, used to adjust the names for uploaded items. 
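+        kwargs
+            (Unused) variable keyword args.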
         Returns
         -------
-        bool
-            Whether uploading was successful
+        ResultIndicator
+            An indicator of whether uploading was successful
         """
-        # Don't do anything if any paths are bad
-        if len([p for p in paths if not p.exists()]) > 0:
-            raise RuntimeError('Upload failed due to invalid non-existing paths being received')
+        # TODO: see if we can perhaps have multiple agents and thread pool if multiplexing is available
+        tx_agent = SimpleDataTransferAgent(transport_client=self._transport_client, auth_client=self._auth_client)

+        if isinstance(paths, Path):
+            paths = [paths]
+
+        not_exist = list()
+        not_file = list()

-        success = True
-        # For all individual files
         for p in paths:
-            if p.is_file():
-                success = success and await self._upload_file(dataset_name=dataset_name, path=p, item_name=p.name)
-            else:
-                success = success and await self._upload_dir(dataset_name=dataset_name, dir_path=p)
-        return success
+            if not p.exists():
+                not_exist.append(p)
+            elif not p.is_file():
+                not_file.append(p)

-    @property
-    def errors(self):
-        # TODO: think about this more
-        return self._errors
+        if len(not_exist) > 0:
+            return BasicResultIndicator(success=False, reason="Non-Existing Upload Paths", data=not_exist)
+        elif len(not_file) > 0:
+            return BasicResultIndicator(success=False, reason="Non-File Upload Paths", data=not_file)

-    @property
-    def info(self):
-        # TODO: think about this more
-        return self._info
+        items = {str(p): p for p in paths} if data_root is None else {str(p.relative_to(data_root)): p for p in paths}
+
+        failed_items = dict()
+
+        for name, file in items.items():
+            response = await tx_agent.upload_dataset_item(dataset_name=dataset_name, item_name=name, source=file)
+            if not response.success:
+                failed_items[name] = response
+
+        if len(failed_items) == 0:
+            return BasicResultIndicator(success=True, reason="Upload Complete", message=f"{len(paths)!s} items")
+        else:
+            return BasicResultIndicator(success=False, reason=f"{len(failed_items)!s} Failed Uploads", data=failed_items)

     @property
-    def warnings(self):
-        # TODO: think about this more
-        return self._warnings
\ No newline at end of file
+    def uses_auth(self) -> bool:
+        """
+        Whether this particular client instance uses auth when interacting with the service.
+
+        Clients that use auth will have been initialized with an ::class:`AuthClient` for authenticating interactions.
+
+        Returns
+        -------
+        bool
+            Whether this particular client instance uses auth when interacting with the service.
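+
+        Examples
+        --------
+        A sketch of how this flag drives response handling internally (mirrors the logic in ``_process_request``)::
+
+            response_type = MaaSDatasetManagementResponse if client.uses_auth else DatasetManagementResponse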
+ """ + return self._auth_client is not None diff --git a/python/lib/client/dmod/test/test_client_config.py b/python/lib/client/dmod/test/test_client_config.py new file mode 100644 index 000000000..372881038 --- /dev/null +++ b/python/lib/client/dmod/test/test_client_config.py @@ -0,0 +1,34 @@ +import unittest +from ..client.client_config import ClientConfig, ConnectionConfig +from pathlib import Path +from typing import Dict + + +class TestClientConfig(unittest.TestCase): + + def setUp(self) -> None: + self._test_config_files: Dict[int, Path] = dict() + self._test_configs: Dict[int, ClientConfig] = dict() + + # Example 0 + ex_idx = 0 + self._test_config_files[ex_idx] = Path(__file__).parent.joinpath("testing_config.json") + self._test_configs[ex_idx] = ClientConfig.parse_file(self._test_config_files[ex_idx]) + + def tearDown(self) -> None: + pass + + def test_request_service_0_a(self): + ex_idx = 0 + cfg_obj = self._test_configs[ex_idx] + + self.assertIsInstance(cfg_obj.request_service, ConnectionConfig) + + def test_request_service_0_b(self): + ex_idx = 0 + cfg_obj = self._test_configs[ex_idx] + + self.assertEqual(cfg_obj.request_service.endpoint_protocol, "wss") + + + diff --git a/python/lib/client/dmod/test/test_dataset_client.py b/python/lib/client/dmod/test/test_dataset_client.py index b266658c7..38d4e47ec 100644 --- a/python/lib/client/dmod/test/test_dataset_client.py +++ b/python/lib/client/dmod/test/test_dataset_client.py @@ -1,46 +1,48 @@ import unittest -from ..client.request_clients import DataCategory, DatasetClient, DatasetManagementResponse, MaaSDatasetManagementResponse +from ..client.request_clients import (DataCategory, DataDomain, DataServiceClient, DatasetManagementResponse, + MaaSDatasetManagementResponse, ManagementAction, ResultIndicator) from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Sequence, Union -class SimpleMockDatasetClient(DatasetClient): +class SimpleMockDataServiceClient(DataServiceClient): """ - Mock subtype, primarily for testing base implementation of ::method:`_parse_list_of_dataset_names_from_response`. + Mock subtype, primarily for testing base implementation of ::method:`extract_dataset_names`. """ def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + #super().__init__(*args, **kwargs) + pass - async def create_dataset(self, name: str, category: DataCategory) -> bool: - """ Mock implementation, always returning ``False``. """ - return False - - async def delete_dataset(self, name: str, **kwargs) -> bool: - """ Mock implementation, always returning ``False``. """ - return False + async def create_dataset(self, name: str, category: DataCategory, domain: DataDomain, + upload_paths: Optional[List[Path]] = None, data_root: Optional[Path] = None, + **kwargs) -> DatasetManagementResponse: + """ Mock implementation, always returning an unsuccessful result. """ + return DatasetManagementResponse(success=False, action=ManagementAction.CREATE, reason="Mock") - async def download_dataset(self, dataset_name: str, dest_dir: Path) -> bool: + async def delete_dataset(self, name: str, **kwargs) -> DatasetManagementResponse: """ Mock implementation, always returning ``False``. 
""" - return False + return DatasetManagementResponse(success=False, action=ManagementAction.DELETE, reason="Mock") - async def download_from_dataset(self, dataset_name: str, item_name: str, dest: Path) -> bool: + async def retrieve_from_dataset(self, dataset_name: str, dest_dir: Path, + item_names: Optional[Union[str, Sequence[str]]] = None, **kwargs) -> ResultIndicator: """ Mock implementation, always returning ``False``. """ - return False + return DatasetManagementResponse(success=False, action=ManagementAction.REQUEST_DATA, reason="Mock") - async def list_datasets(self, category: Optional[DataCategory] = None) -> List[str]: + async def list_datasets(self, category: Optional[DataCategory] = None, **kwargs) -> List[str]: """ Mock implementation, always returning an empty list. """ return [] - async def upload_to_dataset(self, dataset_name: str, paths: List[Path]) -> bool: + async def upload_to_dataset(self, dataset_name: str, paths: Union[Path, List[Path]], + data_root: Optional[Path] = None, **kwargs) -> ResultIndicator: """ Mock implementation, always returning ``False``. """ - return False + return DatasetManagementResponse(success=False, action=ManagementAction.ADD_DATA, reason="Mock") class TestDatasetClient(unittest.TestCase): def setUp(self) -> None: - self.client = SimpleMockDatasetClient() + self.client = SimpleMockDataServiceClient() self.example_responses = [] self.example_dataset_names_lists = [] @@ -69,42 +71,42 @@ def setUp(self) -> None: {'success': True, 'reason': 'List Assembled', 'message': '', 'data': {'datasets': list(dataset_names_list), 'action': 'LIST_ALL', 'is_awaiting': False}})) - def test__parse_list_of_dataset_names_from_response_0_a(self): + def test_extract_dataset_names_0_a(self): """ Test example 0 with base ::class:`DatasetManagementResponse` and empty list is parsed correctly """ ex_indx = 0 expected_names = self.example_dataset_names_lists[ex_indx] response = self.example_responses[ex_indx] - dataset_names = self.client._parse_list_of_dataset_names_from_response(response) + dataset_names = self.client.extract_dataset_names(response) self.assertEqual(expected_names, dataset_names) - def test__parse_list_of_dataset_names_from_response_1_a(self): + def test_extract_dataset_names_1_a(self): """ Test example 1 with base ::class:`DatasetManagementResponse` and non-empty list is parsed correctly """ ex_indx = 1 expected_names = self.example_dataset_names_lists[ex_indx] response = self.example_responses[ex_indx] - dataset_names = self.client._parse_list_of_dataset_names_from_response(response) + dataset_names = self.client.extract_dataset_names(response) self.assertEqual(expected_names, dataset_names) - def test__parse_list_of_dataset_names_from_response_2_a(self): + def test_extract_dataset_names_2_a(self): """ Test example 2 with subtype ::class:`MaaSDatasetManagementResponse` and empty list is parsed correctly """ ex_indx = 2 expected_names = self.example_dataset_names_lists[ex_indx] response = self.example_responses[ex_indx] - dataset_names = self.client._parse_list_of_dataset_names_from_response(response) + dataset_names = self.client.extract_dataset_names(response) self.assertEqual(expected_names, dataset_names) - def test__parse_list_of_dataset_names_from_response_3_a(self): + def test_extract_dataset_names_3_a(self): """ Test example 3 w/ subtype ::class:`MaaSDatasetManagementResponse` and non-empty list is parsed correctly """ ex_indx = 3 expected_names = self.example_dataset_names_lists[ex_indx] response = self.example_responses[ex_indx] - 
dataset_names = self.client._parse_list_of_dataset_names_from_response(response) + dataset_names = self.client.extract_dataset_names(response) self.assertEqual(expected_names, dataset_names) diff --git a/python/lib/client/dmod/test/test_dmod_client.py b/python/lib/client/dmod/test/test_dmod_client.py new file mode 100644 index 000000000..e6c16a154 --- /dev/null +++ b/python/lib/client/dmod/test/test_dmod_client.py @@ -0,0 +1,30 @@ +import unittest +from ..client.client_config import ClientConfig +from ..client.dmod_client import DmodClient +from ..client.request_clients import JobClient +from pathlib import Path +from typing import Dict + + +class TestDmodClient(unittest.TestCase): + + def setUp(self) -> None: + self._test_config_files: Dict[int, Path] = dict() + self._test_configs: Dict[int, ClientConfig] = dict() + self._test_clients: Dict[int, DmodClient] = dict() + + # Example 0 + ex_idx = 0 + self._test_config_files[ex_idx] = Path(__file__).parent.joinpath("testing_config.json") + self._test_configs[ex_idx] = ClientConfig.parse_file(self._test_config_files[ex_idx]) + self._test_clients[ex_idx] = DmodClient(client_config=self._test_configs[ex_idx]) + + def tearDown(self) -> None: + pass + + def test_job_client_0_a(self): + """ Make sure a valid job client is initialized for config example 0. """ + ex_idx = 0 + client = self._test_clients[ex_idx] + + self.assertIsInstance(client.job_client, JobClient) diff --git a/python/lib/client/dmod/test/testing_config.json b/python/lib/client/dmod/test/testing_config.json new file mode 100644 index 000000000..b2a09e5e0 --- /dev/null +++ b/python/lib/client/dmod/test/testing_config.json @@ -0,0 +1,14 @@ +{ + "request-service": { + "protocol": "wss", + "hostname": "127.0.0.1", + "port": 18012, + "ssl-dir": "/home/user/dmod/ssl/requestservice" + }, + "data-service": { + "protocol": "https", + "hostname": "127.0.0.1", + "port": 18015, + "ssl-dir": "/home/user/dmod/ssl/dataservice" + } +} \ No newline at end of file diff --git a/python/lib/client/setup.py b/python/lib/client/setup.py index 3052639b1..90baf658b 100644 --- a/python/lib/client/setup.py +++ b/python/lib/client/setup.py @@ -22,6 +22,7 @@ license='', include_package_data=True, #install_requires=['websockets', 'jsonschema'],vi - install_requires=['dmod-core>=0.1.0', 'websockets>=8.1', 'pyyaml', 'dmod-communication>=0.11.0', 'dmod-externalrequests>=0.3.0'], + install_requires=['dmod-core>=0.11.0', 'websockets>=8.1', 'pydantic>=1.10.8,~=1.10', 'dmod-communication>=0.16.0', + 'dmod-externalrequests>=0.3.0'], packages=find_namespace_packages(include=['dmod.*'], exclude=['dmod.test']) ) diff --git a/python/lib/communication/dmod/communication/_version.py b/python/lib/communication/dmod/communication/_version.py index 00d1ab54f..8911e95ca 100644 --- a/python/lib/communication/dmod/communication/_version.py +++ b/python/lib/communication/dmod/communication/_version.py @@ -1 +1 @@ -__version__ = '0.15.2' +__version__ = '0.16.0' diff --git a/python/lib/communication/dmod/communication/client.py b/python/lib/communication/dmod/communication/client.py index dd05b1cec..f99e50120 100644 --- a/python/lib/communication/dmod/communication/client.py +++ b/python/lib/communication/dmod/communication/client.py @@ -7,10 +7,12 @@ from asyncio import AbstractEventLoop from deprecated import deprecated from pathlib import Path -from typing import Generic, Optional, Type, TypeVar, Union +from typing import Generic, List, Optional, Type, TypeVar, Union import websockets +from dmod.core.exception import 
DmodRuntimeError + from .maas_request import ExternalRequest, ExternalRequestResponse from .message import AbstractInitRequest, Response from .partition_request import PartitionResponse @@ -26,6 +28,9 @@ CONN = TypeVar("CONN") +ResponseTypes = Union[Type[Response], List[Type[Response]]] + + def get_or_create_eventloop() -> AbstractEventLoop: """ Retrieves an async event loop @@ -536,17 +541,17 @@ def __init__(self, *, self._transport_client = transport_client self._default_response_type: Optional[Type[Response]] = default_response_type - def _process_request_response(self, response_str: str, response_type: Optional[Type[Response]] = None) -> Response: + def _process_request_response(self, response_str: str, response_type: Optional[ResponseTypes] = None) -> Response: """ - Process the serial form of a response returned by ::method:`async_send` into a response object. + Process the serial form of a response into a response object. Parameters ---------- response_str : str The string returned by a request made via ::method:`async_send`. - response_type: Optional[Type[Response]] - An optional class type for the response that, if ``None`` (the default) is replaced with the default - provided at initialization. + response_type: Optional[ResponseTypes] + One or more optional class types for the response that, if ``None`` (the default) is replaced with the + default provided at initialization. Returns ------- @@ -558,35 +563,33 @@ def _process_request_response(self, response_str: str, response_type: Optional[T async_send """ if response_type is None: - response_type = self._default_response_type + response_type = [self._default_response_type] + elif not isinstance(response_type, list): + response_type = [response_type] - response_json = {} try: # Consume the response confirmation by deserializing first to JSON, then from this to a response object response_json = json.loads(response_str) - try: - response_object = response_type.factory_init_from_deserialized_json(response_json) - if response_object is None: - msg = f'********** {self.__class__.__name__} could not deserialize {response_type.__name__} ' \ - f'from raw websocket response: `{str(response_str)}`' - reason = f'{self.__class__.__name__} Could Not Deserialize To {response_type.__name__}' - response_object = response_type(success=False, reason=reason, message=msg, data=response_json) - except Exception as e2: - msg = f'********** While deserializing {response_type.__name__}, {self.__class__.__name__} ' \ - f'encountered {e2.__class__.__name__}: {str(e2)}' - reason = f'{self.__class__.__name__} {e2.__class__.__name__} Deserialize {response_type.__name__}' - response_object = response_type(success=False, reason=reason, message=msg, data=response_json) except Exception as e: - reason = 'Invalid JSON Response' - msg = f'Encountered {e.__class__.__name__} loading response to JSON: {str(e)}' - response_object = response_type(success=False, reason=reason, message=msg, data=response_json) - - if not response_object.success: - logging.error(response_object.message) - logging.debug(f'{self.__class__.__name__} returning {str(response_type)} {response_str}') + raise DmodRuntimeError(f"{self.__class__.__name__} could not parse JSON due to {e.__class__.__name__} " + f"({e!s}); raw response was: `{response_str}`") + response_object = None + try: + for t in response_type: + response_object = t.factory_init_from_deserialized_json(response_json) + if response_object is not None: + break + except Exception as e2: + raise DmodRuntimeError( + 
diff --git a/python/lib/communication/dmod/communication/data_transmit_message.py b/python/lib/communication/dmod/communication/data_transmit_message.py
index ec861498e..c923f7040 100644
--- a/python/lib/communication/dmod/communication/data_transmit_message.py
+++ b/python/lib/communication/dmod/communication/data_transmit_message.py
@@ -2,7 +2,7 @@
 from .message import AbstractInitRequest, MessageEventType, Response
 from pydantic import Field
 from typing import ClassVar, Type, Union
-from typing_extensions import TypeAlias
+from typing_extensions import Self, TypeAlias
 from uuid import UUID
 
 
@@ -47,6 +47,30 @@ class DataTransmitResponse(Response):
     series of which it is a part.
     """
 
+    @classmethod
+    def create_for_received(cls, received_msg: DataTransmitMessage, success: bool = True,
+                            reason: str = 'Data Received', message: str = "") -> Self:
+        """
+        Create an appropriate response object that corresponds to the received incoming message.
+
+        Parameters
+        ----------
+        received_msg : DataTransmitMessage
+            The received transmit message for which a response needs to be generated.
+        success : bool
+            The ``success`` value for the created response (``True`` by default).
+        reason : str
+            The ``reason`` value for the created response ("Data Received" by default).
+        message : str
+            The ``message`` value for the created response ("" by default).
+
+        Returns
+        -------
+        Self
+            The generated response object.
+        """
+        return cls(series_uuid=received_msg.series_uuid, success=success, reason=reason, message=message)
+
     response_to_type: ClassVar[Type[AbstractInitRequest]] = DataTransmitMessage
 
     data: DataTransmitResponseBody
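
Usage illustration for the new factory method (a sketch only; the websocket and enclosing loop are hypothetical, while the str(...) serialization matches how messages are sent elsewhere in this changeset):

    import json

    async def ack_one_chunk(websocket):
        # Acknowledge a single received transmit chunk with a matching-series response.
        raw = await websocket.recv()
        chunk = DataTransmitMessage.factory_init_from_deserialized_json(json.loads(raw))
        ack = DataTransmitResponse.create_for_received(chunk)
        await websocket.send(str(ack))
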
""" + @classmethod + def create_for_received(cls, received_msg: DataTransmitMessage, success: bool = True, + reason: str = 'Data Received', message: str = "") -> Self: + """ + Create an appropriate response object that corresponds to the received incoming message. + + Parameters + ---------- + received_msg : DataTransmitMessage + The received transmit message for which a response needs to be generated. + success : bool + The ``success`` value for the created response (``True`` by default). + reason : str + The ``reason`` value for the created response ("Data Received" by default). + message : str + The ``message`` value for the created response ("" by default). + + Returns + ------- + Self + The generated response object. + """ + return cls(series_uuid=received_msg.series_uuid, success=success, reason=reason, message=message) + response_to_type: ClassVar[Type[AbstractInitRequest]] = DataTransmitMessage data: DataTransmitResponseBody diff --git a/python/lib/communication/dmod/communication/dataset_management_message.py b/python/lib/communication/dmod/communication/dataset_management_message.py index 476454c13..bdbcc052b 100644 --- a/python/lib/communication/dmod/communication/dataset_management_message.py +++ b/python/lib/communication/dmod/communication/dataset_management_message.py @@ -413,10 +413,7 @@ class Config: @classmethod def factory_create(cls, mgmt_msg: DatasetManagementMessage, session_secret: str) -> 'MaaSDatasetManagementMessage': - return cls(session_secret=session_secret, action=mgmt_msg.management_action, dataset_name=mgmt_msg.dataset_name, - is_read_only_dataset=mgmt_msg.is_read_only_dataset, category=mgmt_msg.data_category, - domain=mgmt_msg.data_domain, data_location=mgmt_msg.data_location, - is_pending_data=mgmt_msg.is_pending_data) + return cls(session_secret=session_secret, **mgmt_msg.to_dict()) @classmethod def factory_init_correct_response_subtype(cls, json_obj: dict) -> 'MaaSDatasetManagementResponse': diff --git a/python/lib/communication/setup.py b/python/lib/communication/setup.py index 56ac781c5..cea733fec 100644 --- a/python/lib/communication/setup.py +++ b/python/lib/communication/setup.py @@ -21,7 +21,7 @@ url='', license='', include_package_data=True, - install_requires=['dmod-core>=0.10.0', 'websockets>=8.1', 'jsonschema', 'redis', 'pydantic>=1.10.8,~=1.10', + install_requires=['dmod-core>=0.11.0', 'websockets>=8.1', 'jsonschema', 'redis', 'pydantic>=1.10.8,~=1.10', 'Deprecated', 'ngen-config@git+https://github.com/noaa-owp/ngen-cal@master#egg=ngen-config&subdirectory=python/ngen_conf'], packages=find_namespace_packages(include=['dmod.*'], exclude=['dmod.test']) diff --git a/python/lib/core/dmod/core/_version.py b/python/lib/core/dmod/core/_version.py index 85b551d30..121d6890a 100644 --- a/python/lib/core/dmod/core/_version.py +++ b/python/lib/core/dmod/core/_version.py @@ -1 +1 @@ -__version__ = '0.10.2' +__version__ = '0.11.0' \ No newline at end of file diff --git a/python/lib/core/dmod/core/serializable.py b/python/lib/core/dmod/core/serializable.py index 7ed8e486b..02da62148 100644 --- a/python/lib/core/dmod/core/serializable.py +++ b/python/lib/core/dmod/core/serializable.py @@ -1,7 +1,7 @@ from abc import ABC from numbers import Number from enum import Enum -from typing import Any, Callable, ClassVar, Dict, Type, TypeVar, TYPE_CHECKING, Union, Optional +from typing import Any, Callable, ClassVar, Dict, List, Type, TypeVar, TYPE_CHECKING, Union, Optional from typing_extensions import Self, TypeAlias from pydantic import BaseModel, Field from 
functools import lru_cache @@ -25,6 +25,8 @@ SelfFieldSerializer: TypeAlias = Callable[[M, T], R] FieldSerializer = Union[SelfFieldSerializer[M, Any], FnSerializer[Any]] +SimpleData = Union[int, float, bool, str] + class Serializable(BaseModel, ABC): """ @@ -367,9 +369,11 @@ class ResultIndicator(Serializable, ABC): class BasicResultIndicator(ResultIndicator): """ - Bare-bones, concrete implementation of ::class:`ResultIndicator`. + Bare-bones, concrete implementation of ::class:`ResultIndicator` that also supports carrying simple data. """ + data: Optional[Union[SimpleData, Dict[str, SimpleData], List[SimpleData]]] + # NOTE: function below are intentionally not methods on `Serializable` to avoid subclasses # overriding their behavior. diff --git a/python/lib/core/dmod/test/test_basic_result_indicator.py b/python/lib/core/dmod/test/test_basic_result_indicator.py new file mode 100644 index 000000000..a5207cd7c --- /dev/null +++ b/python/lib/core/dmod/test/test_basic_result_indicator.py @@ -0,0 +1,91 @@ +import unittest +from ..core.serializable import BasicResultIndicator + + +class TestBasicResultIndicator(unittest.TestCase): + + def setUp(self) -> None: + self.ex_objs = [] + + # Example 0: Successful, with no message and no data + self.ex_objs.append(BasicResultIndicator(success=True, reason="Successful")) + + # Example 1: Successful, with message but no data + self.ex_objs.append(BasicResultIndicator(success=True, reason="Successful", message="This worked")) + + # Example 2: Successful, with message and list of ints in data + data_item = list(range(5)) + self.ex_objs.append(BasicResultIndicator(success=True, reason="Successful", message="This worked", + data=data_item)) + + # Example 3: Failed, with message and dict of int values in data (keys are value as string, prefixed by "i-") + data_item = {f"i-{i!s}": i for i in range(5)} + self.ex_objs.append(BasicResultIndicator(success=True, reason="Successful", message="This worked", + data=data_item)) + + # Example 4: Successful, with message and list of floats in data + data_item = [0.0, 1.0] + self.ex_objs.append(BasicResultIndicator(success=True, reason="Successful", message="This worked", + data=data_item)) + + # Example 5: Successful, with message and list of strings in data + data_item = ["one", "two"] + self.ex_objs.append(BasicResultIndicator(success=True, reason="Successful", message="This worked", + data=data_item)) + + # Example 6: Successful, with message and list of bools in data + data_item = [True, False] + self.ex_objs.append(BasicResultIndicator(success=True, reason="Successful", message="This worked", + data=data_item)) + + def tearDown(self) -> None: + pass + + def test_data_0_a(self): + """ Test that an object that is not initialized with a ``data`` param gets a ``None`` value for it. """ + ex_idx = 0 + obj = self.ex_objs[ex_idx] + + self.assertIsNone(obj.data) + + def test_data_2_a(self): + """ Test that an object with a ``data`` param has it. """ + ex_idx = 2 + obj = self.ex_objs[ex_idx] + + self.assertIsInstance(obj.data, list) + + def test_data_2_b(self): + """ Test that an object with a ``data`` param has expected ``int`` values. """ + ex_idx = 2 + obj = self.ex_objs[ex_idx] + + self.assertEqual(obj.data, [0, 1, 2, 3, 4]) + + def test_data_3_b(self): + """ Test that an object with a ``data`` diction param has expected values. 
""" + ex_idx = 3 + obj = self.ex_objs[ex_idx] + + self.assertEqual(obj.data, {"i-0": 0, "i-1": 1, "i-2": 2, "i-3": 3, "i-4": 4}) + + def test_data_4_b(self): + """ Test that an object with a ``data`` param has expected float values. """ + ex_idx = 4 + obj = self.ex_objs[ex_idx] + + self.assertEqual(obj.data, [0.0, 1.0]) + + def test_data_5_b(self): + """ Test that an object with a ``data`` param has expected string values. """ + ex_idx = 5 + obj = self.ex_objs[ex_idx] + + self.assertEqual(obj.data, ["one", "two"]) + + def test_data_6_b(self): + """ Test that an object with a ``data`` param has expected bool values. """ + ex_idx = 6 + obj = self.ex_objs[ex_idx] + + self.assertEqual(obj.data, [True, False]) diff --git a/python/lib/externalrequests/dmod/externalrequests/maas_request_handlers.py b/python/lib/externalrequests/dmod/externalrequests/maas_request_handlers.py index 913d9a45b..c2c1c957b 100644 --- a/python/lib/externalrequests/dmod/externalrequests/maas_request_handlers.py +++ b/python/lib/externalrequests/dmod/externalrequests/maas_request_handlers.py @@ -11,6 +11,7 @@ from dmod.communication.dataset_management_message import MaaSDatasetManagementMessage, MaaSDatasetManagementResponse, \ ManagementAction from dmod.communication.data_transmit_message import DataTransmitMessage, DataTransmitResponse +from dmod.core.exception import DmodRuntimeError from pathlib import Path from typing import Optional, Tuple @@ -269,48 +270,51 @@ async def determine_required_access_types(self, request: MaaSDatasetManagementMe # FIXME: for now, just use the default type (which happens to be "everything") return self._default_required_access_type, - async def _handle_data_download(self, client_websocket, service_websocket) -> MaaSDatasetManagementResponse: + async def _handle_data_download(self, download_request: MaaSDatasetManagementMessage, client_websocket) -> MaaSDatasetManagementResponse: series_uuid = None + # This might be data transmission, or it might be a management response message + possible_responses = [MaaSDatasetManagementResponse, DataTransmitMessage] + service_response = self.service_client.async_make_request(download_request, possible_responses) while True: - # This might be data transmission, or it might be a management response message - raw_service_response = await service_websocket.recv() - service_response_json = json.loads(raw_service_response) - mgmt_response = MaaSDatasetManagementResponse.factory_init_from_deserialized_json(service_response_json) - if mgmt_response is not None: - return mgmt_response - data_transmit_msg = DataTransmitMessage.factory_init_from_deserialized_json(service_response_json) + if isinstance(service_response, MaaSDatasetManagementResponse): + return service_response + + assert isinstance(service_response, DataTransmitMessage) + if series_uuid is None: - series_uuid = data_transmit_msg.series_uuid - elif data_transmit_msg.series_uuid != series_uuid: - raise RuntimeError("Data series UUID for data transmit does not match expected.") - await client_websocket.send(raw_service_response) + series_uuid = service_response.series_uuid + elif service_response.series_uuid != series_uuid: + raise DmodRuntimeError("Data series UUID for data transmit does not match expected.") + + await client_websocket.send(str(service_response)) raw_client_response = await client_websocket.recv() data_response = DataTransmitResponse.factory_init_from_deserialized_json(json.loads(raw_client_response)) if data_response.series_uuid != series_uuid: raise RuntimeError("Data series UUID 
for data receipt does not match expected.") - await service_websocket.send(raw_client_response) + service_response = self.service_client.async_make_request(data_response, possible_responses) - async def _handle_data_upload(self, client_websocket, service_websocket) -> MaaSDatasetManagementResponse: + async def _handle_data_upload(self, upload_request: MaaSDatasetManagementMessage, client_websocket, service_websocket) -> MaaSDatasetManagementResponse: series_uuid = None + # This might be DataTransmitResponse, or it might be a management response message + possible_responses = [MaaSDatasetManagementResponse, DataTransmitResponse] + service_response = self.service_client.async_make_request(upload_request, possible_responses) while True: - # Await a DataTransmitResponse with success indicating ready to receive - # TODO: update Data service to do this - raw_service_response = await service_websocket.recv() - service_response_json = json.loads(raw_service_response) - mgmt_response = MaaSDatasetManagementResponse.factory_init_from_deserialized_json(service_response_json) - if mgmt_response is not None: - return mgmt_response - data_transmit_response = DataTransmitResponse.factory_init_from_deserialized_json(service_response_json) + if isinstance(service_response, MaaSDatasetManagementResponse): + return service_response + + assert isinstance(service_response, DataTransmitResponse) + if series_uuid is None: - series_uuid = data_transmit_response.series_uuid - elif data_transmit_response.series_uuid != series_uuid: - raise RuntimeError("Data series UUID for data upload response does not match expected.") - await client_websocket.send(raw_service_response) + series_uuid = service_response.series_uuid + elif service_response.series_uuid != series_uuid: + raise DmodRuntimeError("Data series UUID for data upload response does not match expected.") + + await client_websocket.send(str(service_response)) raw_client_response = await client_websocket.recv() data_transmit_msg = DataTransmitMessage.factory_init_from_deserialized_json(json.loads(raw_client_response)) if data_transmit_msg.series_uuid != series_uuid: - raise RuntimeError("Data series UUID for data upload transmit does not match expected.") - await service_websocket.send(raw_client_response) + raise RuntimeError("Data series UUID for data upload does not match expected.") + service_response = self.service_client.async_make_request(data_transmit_msg, possible_responses) async def handle_request(self, request: MaaSDatasetManagementMessage, **kwargs) -> MaaSDatasetManagementResponse: # Need receiver websocket (i.e. 
DMOD client side) as kwarg @@ -318,19 +322,16 @@ async def handle_request(self, request: MaaSDatasetManagementMessage, **kwargs) if not is_authorized: return MaaSDatasetManagementResponse(success=False, reason=reason.name, message=msg) # In this case, we actually can pass the request as-is straight through (i.e., after confirming authorization) - async with self.service_client as client: - # Have to handle these two slightly differently, since multiple message will be going over the websocket - if request.management_action == ManagementAction.REQUEST_DATA: - await client.connection.send(str(request)) - mgmt_response = await self._handle_data_download(client_websocket=kwargs['upstream_websocket'], - service_websocket=client.connection) - elif request.management_action == ManagementAction.ADD_DATA: - await client.connection.send(str(request)) - mgmt_response = await self._handle_data_upload(client_websocket=kwargs['upstream_websocket'], - service_websocket=client.connection) - else: - mgmt_response = await client.async_make_request(request) - logging.debug("************* {} received response:\n{}".format(self.__class__.__name__, str(mgmt_response))) + # Have to handle these two slightly differently, since multiple message will be going over the websocket + if request.management_action == ManagementAction.REQUEST_DATA: + mgmt_response = await self._handle_data_download(download_request=request, + client_websocket=kwargs['upstream_websocket']) + elif request.management_action == ManagementAction.ADD_DATA: + mgmt_response = await self._handle_data_upload(upload_request=request, + client_websocket=kwargs['upstream_websocket']) + else: + mgmt_response = await self.service_client.async_make_request(request) + logging.debug("************* {} received response:\n{}".format(self.__class__.__name__, str(mgmt_response))) # Likewise, can just send back the response from the internal service client return MaaSDatasetManagementResponse.factory_create(mgmt_response) diff --git a/python/services/requestservice/dmod/requestservice/service.py b/python/services/requestservice/dmod/requestservice/service.py index daa14e338..ebbb38fe7 100755 --- a/python/services/requestservice/dmod/requestservice/service.py +++ b/python/services/requestservice/dmod/requestservice/service.py @@ -111,10 +111,11 @@ def __init__(self, listen_host='', # FIXME: implement real authorizer self.authorizer = self.authenticator - self._scheduler_client = SchedulerClient(transport_client=WebSocketClient(endpoint_host=self.scheduler_host, - endpoint_port=self.scheduler_port, - capath=self.scheduler_client_ssl_dir)) - """SchedulerClient: Client for interacting with scheduler, which also is a context manager for connections.""" + # TODO: make sure this isn't still needed (or shouldn't be re-added) + #self._scheduler_client = SchedulerClient(transport_client=WebSocketClient(endpoint_host=self.scheduler_host, + # endpoint_port=self.scheduler_port, + # capath=self.scheduler_client_ssl_dir)) + #"""SchedulerClient: Client for interacting with scheduler, which also is a context manager for connections.""" self._auth_handler: AuthHandler = AuthHandler(session_manager=self._session_manager, authenticator=self.authenticator,