diff --git a/aepsych/database/db.py b/aepsych/database/db.py index c9c9cc65d..421b859b2 100644 --- a/aepsych/database/db.py +++ b/aepsych/database/db.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import datetime +import io import json import logging import os @@ -440,12 +441,14 @@ def record_outcome( self._session.add(outcome_entry) self._session.commit() - def record_strat(self, master_table: tables.DBMasterTable, strat: Strategy) -> None: + def record_strat( + self, master_table: tables.DBMasterTable, strat: io.BytesIO + ) -> None: """Record a strategy in the database. Args: master_table (tables.DBMasterTable): The master table. - strat (Strategy): The strategy. + strat (BytesIO): The strategy in buffer form. """ strat_entry = tables.DbStratTable() strat_entry.strat = strat diff --git a/aepsych/server/server.py b/aepsych/server/server.py index d0a16ba92..30ca22335 100644 --- a/aepsych/server/server.py +++ b/aepsych/server/server.py @@ -1,28 +1,29 @@ #!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. and its affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - import argparse +import asyncio +import concurrent import io +import json import logging import os -import sys -import threading import traceback import warnings -from typing import Dict, Union +from typing import Any, Dict, List, Optional, Union -import aepsych.database.db as db -import aepsych.utils_logging as utils_logging import dill import numpy as np +import pandas as pd import torch -from aepsych import version +from aepsych import utils_logging, version +from aepsych.config import Config +from aepsych.database import db +from aepsych.database.tables import DBMasterTable from aepsych.server.message_handlers import MESSAGE_MAP -from aepsych.server.message_handlers.handle_ask import ask from aepsych.server.message_handlers.handle_setup import configure from aepsych.server.replay import ( get_dataframe_from_replay, @@ -30,12 +31,9 @@ get_strats_from_replay, replay, ) -from aepsych.server.sockets import BAD_REQUEST, DummySocket, PySocket -from aepsych.utils import promote_0d +from aepsych.strategy import SequentialStrategy, Strategy logger = utils_logging.getLogger(logging.INFO) -DEFAULT_DESC = "default description" -DEFAULT_NAME = "default name" def get_next_filename(folder, fname, ext): @@ -44,191 +42,84 @@ def get_next_filename(folder, fname, ext): return f"{folder}/{fname}_{n + 1}.{ext}" -class AEPsychServer(object): - def __init__(self, socket=None, database_path=None): - """Server for doing black box optimization using gaussian processes. - Keyword Arguments: - socket -- socket object that implements `send` and `receive` for json - messages (default: DummySocket()). - TODO actually make an abstract interface to subclass from here - """ - if socket is None: - self.socket = DummySocket() - else: - self.socket = socket - self.db = None +class AEPsychServer: + def __init__( + self, + host: str = "0.0.0.0", + port: int = 5555, + database_path: str = "./databases/default.db", + max_workers: Optional[int] = None, + ): + self.host = host + self.port = port + self.max_workers = max_workers + self.db: db.Database = db.Database(database_path) self.is_performing_replay = False self.exit_server_loop = False self._db_raw_record = None - self.db: db.Database = db.Database(database_path) self.skip_computations = False self.strat_names = None self.extensions = None + self._strats: List[SequentialStrategy] = [] + self._parnames: List[List[str]] = [] + self._configs: List[Config] = [] + self._master_records: List[DBMasterTable] = [] + self.strat_id = -1 + self.outcome_names: List[str] = [] if self.db.is_update_required(): self.db.perform_updates() - self._strats = [] - self._parnames = [] - self._configs = [] - self._master_records = [] - self.strat_id = -1 - self._pregen_asks = [] - self.enable_pregen = False - self.outcome_names = [] - - self.debug = False - self.receive_thread = threading.Thread( - target=self._receive_send, args=(self.exit_server_loop,), daemon=True - ) - - self.queue = [] - - def cleanup(self): - """Close the socket and terminate connection to the server. - - Returns: - None - """ - self.socket.close() - - def _receive_send(self, is_exiting: bool) -> None: - """Receive messages from the client. - - Args: - is_exiting (bool): True to terminate reception of new messages from the client, False otherwise. - - Returns: - None - """ - while True: - request = self.socket.receive(is_exiting) - if request != BAD_REQUEST: - self.queue.append(request) - if self.exit_server_loop: - break - logger.info("Terminated input thread") - - def _handle_queue(self) -> None: - """Handles the queue of messages received by the server. - - Returns: - None - """ - if self.queue: - request = self.queue.pop(0) - try: - result = self.handle_request(request) - except Exception as e: - error_message = f"Request '{request}' raised error '{e}'!" - result = f"server_error, {error_message}" - logger.error(f"{error_message}! Full traceback follows:") - logger.error(traceback.format_exc()) - self.socket.send(result) - else: - if self.can_pregen_ask and (len(self._pregen_asks) == 0): - self._pregen_asks.append(ask(self)) - - def serve(self) -> None: - """Run the server. Note that all configuration outside of socket type and port - happens via messages from the client. The server simply forwards messages from - the client to its `setup`, `ask` and `tell` methods, and responds with either - acknowledgment or other response as needed. To understand the server API, see - the docs on the methods in this class. - - Returns: - None - - Raises: - RuntimeError: if a request from a client has no request type - RuntimeError: if a request from a client has no known request type - TODO make things a little more robust to bad messages from client; this - requires resetting the req/rep queue status. - - """ - logger.info("Server up, waiting for connections!") - logger.info("Ctrl-C to quit!") - # yeah we're not sanitizing input at all - - # Start the method to accept a client connection - self.socket.accept_client() - self.receive_thread.start() - while True: - self._handle_queue() - if self.exit_server_loop: - break - # Close the socket and terminate with code 0 - self.cleanup() - sys.exit(0) - - def _unpack_strat_buffer(self, strat_buffer): - if isinstance(strat_buffer, io.BytesIO): - strat = torch.load(strat_buffer, pickle_module=dill) - strat_buffer.seek(0) - elif isinstance(strat_buffer, bytes): - warnings.warn( - "Strat buffer is not in bytes format!" - + " This is a deprecated format, loading using dill.loads.", - DeprecationWarning, - ) - strat = dill.loads(strat_buffer) - else: - raise RuntimeError("Trying to load strat in unknown format!") - return strat - - ### Properties that are set on a per-strat basis + #### Properties #### @property - def strat(self): + def strat(self) -> Optional[SequentialStrategy]: if self.strat_id == -1: return None else: return self._strats[self.strat_id] @strat.setter - def strat(self, s): + def strat(self, s: SequentialStrategy): self._strats.append(s) @property - def config(self): + def config(self) -> Optional[Config]: if self.strat_id == -1: return None else: return self._configs[self.strat_id] @config.setter - def config(self, s): + def config(self, s: Config): self._configs.append(s) @property - def parnames(self): + def parnames(self) -> List[str]: if self.strat_id == -1: return [] else: return self._parnames[self.strat_id] @parnames.setter - def parnames(self, s): + def parnames(self, s: List[str]): self._parnames.append(s) @property - def _db_master_record(self): + def _db_master_record(self) -> Optional[DBMasterTable]: if self.strat_id == -1: return None else: return self._master_records[self.strat_id] @_db_master_record.setter - def _db_master_record(self, s): + def _db_master_record(self, s: DBMasterTable): self._master_records.append(s) @property - def n_strats(self): + def n_strats(self) -> int: return len(self._strats) - @property - def can_pregen_ask(self): - return self.strat is not None and self.enable_pregen - + #### Methods to handle parameter configs #### def _tensor_to_config(self, next_x): stim_per_trial = self.strat.stimuli_per_trial dim = self.strat.dim @@ -280,8 +171,11 @@ def _config_to_tensor(self, config): return x - def _fixed_to_idx(self, fixed: Dict[str, Union[float, str]]): + def _fixed_to_idx(self, fixed: Dict[str, Union[float, str]]) -> Dict[int, Any]: # Given a dictionary of fixed parameters, turn the parameters names into indices + if self.strat is None: + raise ValueError("No strategy is set, cannot convert fixed parameters.") + dummy = np.zeros(len(self.parnames)).astype("O") for key, value in fixed.items(): idx = self.parnames.index(key) @@ -297,14 +191,211 @@ def _fixed_to_idx(self, fixed: Dict[str, Union[float, str]]): return fixed_features - def __getstate__(self): - # nuke the socket since it's not pickleble - state = self.__dict__.copy() - del state["socket"] - del state["db"] - return state + #### Methods to handle replay #### + def replay(self, uuid_to_replay: int, skip_computations: bool = False) -> None: + """Replay an experiment with a specific unique ID. This will leave the + server state at the end of the replay. + + Args: + uuid_to_replay (int): Unique ID of the experiment to replay. This is + the primary key of the experiment's master table. + skip_computations (bool): If True, skip computations during the replay. + Defaults to False. + """ + return replay(self, uuid_to_replay, skip_computations) + + def get_strats_from_replay( + self, uuid_of_replay: Optional[int] = None, force_replay: bool = False + ) -> List[Strategy]: + """Replay an experiment then return the strategies from the replay. + + Args: + uuid_to_replay (int, optional): Unique ID of the experiment to + replay. If not set, the last experiment in the database will be + used. + force_replay (bool): If True, force a replay. Defaults to False. + + Returns: + List[Union[SequentialStrategy, Strategy]]: List of strategies from + the replay. + """ + return get_strats_from_replay(self, uuid_of_replay, force_replay) + + def get_strat_from_replay( + self, uuid_of_replay: Optional[int] = None, strat_id: int = -1 + ) -> Strategy: + """Replay an experiment then return a strategy from the replay. + + Args: + uuid_to_replay (int, optional): Unique ID of the experiment to + replay. If not set, the last experiment in the database will be + used. + strat_id (int): ID of the strategy to return. Defaults to -1, which + returns the last strategy. + + Returns: + Strategy: The strategy from the replay. + """ + return get_strat_from_replay(self, uuid_of_replay, strat_id) + + def get_dataframe_from_replay( + self, uuid_of_replay: Optional[int] = None, force_replay: bool = False + ) -> pd.DataFrame: + """Replay an experiment then return the dataframe from the replay. + + Args: + uuid_to_replay (int, optional): Unique ID of the experiment to + replay. If not set, the last experiment in the database will be + used. + force_replay (bool): If True, force a replay. Defaults to False. + + Returns: + pd.DataFrame: Dataframe from the replay. + """ + return get_dataframe_from_replay(self, uuid_of_replay, force_replay) + + def _unpack_strat_buffer(self, strat_buffer): + # Unpacks a strategy buffer from the database. + if isinstance(strat_buffer, io.BytesIO): + strat = torch.load(strat_buffer, pickle_module=dill) + strat_buffer.seek(0) + elif isinstance(strat_buffer, bytes): + warnings.warn( + "Strat buffer is not in bytes format!" + + " This is a deprecated format, loading using dill.loads.", + DeprecationWarning, + ) + strat = dill.loads(strat_buffer) + else: + raise RuntimeError("Trying to load strat in unknown format!") + return strat - def write_strats(self, termination_type): + #### Method to handle async server #### + def start_blocking(self) -> None: + """Starts the server in a blocking state in the main thread. Used by the + command line interface to start the server for a client in another + process or machine.""" + asyncio.run(self.serve()) + + def start_background(self): + """Starts the server in a background thread. Used for scripts where the + client and server are in the same process.""" + raise NotImplementedError + + async def serve(self) -> None: + """Serves the server on the set IP and port. This creates a coroutine + for asyncio to handle requests asyncronously. + """ + self.server = await asyncio.start_server( + self.handle_client, self.host, self.port + ) + self.loop = asyncio.get_running_loop() + pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) + self.loop.set_default_executor(pool) + + async with self.server: + logging.info(f"Serving on {self.host}:{self.port}") + try: + await self.server.serve_forever() + except asyncio.CancelledError: + raise + except KeyboardInterrupt: + exception_type = "CTRL+C" + dump_type = "dump" + self.write_strats(exception_type) + self.generate_debug_info(exception_type, dump_type) + except RuntimeError as e: + exception_type = "RuntimeError" + dump_type = "crashdump" + self.write_strats(exception_type) + self.generate_debug_info(exception_type, dump_type) + raise RuntimeError(e) + + async def handle_client(self, reader, writer): + """Coroutine for handling a client connection. This will read messages + from the connected client and dispatch a task to handle the request on + another thread such that its blocking state does not block the server. + This coroutine will end if the client closes the connection. + + Args: + reader: asyncio.StreamReader: The stream reader for the client. + writer: asyncio.StreamWriter: The stream writer for the client. + """ + addr = writer.get_extra_info("peername") + logger.info(f"Connected to {addr}") + + try: + while True: + if self.exit_server_loop: + self.server.close() + break + rcv = await reader.read(1024 * 512) + try: + message = json.loads(rcv) + except UnicodeDecodeError as e: + logger.error(f"Malformed message: {rcv}") + logger.error(traceback.format_exc()) + result = {"error": str(e)} + return_msg = json.dumps(self._simplify_arrays(result)).encode() + writer.write(return_msg) + continue + + future = self.loop.run_in_executor(None, self.handle_request, message) + try: + result = await future + except Exception as e: + logger.error(f"Error handling message: {message}") + logger.error(traceback.format_exc()) + # Some exceptions turned into string are meaningless, so we use repr + result = {"error": e.__repr__()} + if isinstance(result, dict): + return_msg = json.dumps(self._simplify_arrays(result)).encode() + writer.write(return_msg) + else: + writer.write(str(result).encode()) + + await writer.drain() + except asyncio.CancelledError: + pass + finally: + logger.info(f"Connection closed for {addr}") + writer.close() + await writer.wait_closed() + + def handle_request(self, message: Dict[str, Any]) -> Union[Dict[str, Any], str]: + """Given a message, dispatch the correct handler and return the result. + + Args: + message (Dict[str, Any]): The message to handle. + + Returns: + Union[Dict[str, Any], str]: The result of handling the message. + """ + type_ = message["type"] + result = MESSAGE_MAP[type_](self, message) + return result + + def _simplify_arrays(self, message): + # Simplify arrays for encoding and sending a message to the client + return { + k: ( + v.tolist() + if type(v) == np.ndarray + else self._simplify_arrays(v) + if type(v) is dict + else v + ) + for k, v in message.items() + } + + #### Methods to handle exiting #### + def write_strats(self, termination_type: str) -> None: + """Pickle the stats and records them into the database. + + Args: + termination_type (str): The type of termination. This only affects + the log message. + """ if self._db_master_record is not None and self.strat is not None: logger.info(f"Dumping strats to DB due to {termination_type}.") for strat in self._strats: @@ -313,77 +404,28 @@ def write_strats(self, termination_type): buffer.seek(0) self.db.record_strat(master_table=self._db_master_record, strat=buffer) - def generate_debug_info(self, exception_type, dumptype): + def generate_debug_info(self, exception_type: str, dumptype: str) -> None: + """Generate a debug info file for the server. This will pickle the server + and save it to a file. + + Args: + exception_type (str): The type of exception that caused the server + to terminate. This only affects the log message. + dump_type (str): The type of dump. This only affects the log file. + """ fname = get_next_filename(".", dumptype, "pkl") logger.exception(f"Got {exception_type}, exiting! Server dump in {fname}") dill.dump(self, open(fname, "wb")) - def handle_request(self, request): - if "type" not in request.keys(): - raise RuntimeError(f"Request {request} contains no request type!") - else: - type = request["type"] - if type in MESSAGE_MAP.keys(): - logger.info(f"Received msg [{type}]") - ret_val = MESSAGE_MAP[type](self, request) - return ret_val - - else: - exception_message = ( - f"unknown type: {type}. Allowed types [{MESSAGE_MAP.keys()}]" - ) - - raise RuntimeError(exception_message) - - def replay(self, uuid_to_replay, skip_computations=False): - return replay(self, uuid_to_replay, skip_computations) - - def get_strats_from_replay(self, uuid_of_replay=None, force_replay=False): - return get_strats_from_replay(self, uuid_of_replay, force_replay) - - def get_strat_from_replay(self, uuid_of_replay=None, strat_id=-1): - return get_strat_from_replay(self, uuid_of_replay, strat_id) - - def get_dataframe_from_replay(self, uuid_of_replay=None, force_replay=False): - return get_dataframe_from_replay(self, uuid_of_replay, force_replay) - - -#! THIS IS WHAT START THE SERVER -def startServerAndRun( - server_class, socket=None, database_path=None, config_path=None, id_of_replay=None -): - server = server_class(socket=socket, database_path=database_path) - try: - if config_path is not None: - with open(config_path) as f: - config_str = f.read() - configure(server, config_str=config_str) - - if socket is not None: - if id_of_replay is not None: - server.replay(id_of_replay, skip_computations=True) - server.serve() - else: - if config_path is not None: - logger.info( - "You have passed in a config path but this is a replay. If there's a config in the database it will be used instead of the passed in config path." - ) - server.replay(id_of_replay) - except KeyboardInterrupt: - exception_type = "CTRL+C" - dump_type = "dump" - server.write_strats(exception_type) - server.generate_debug_info(exception_type, dump_type) - except RuntimeError as e: - exception_type = "RuntimeError" - dump_type = "crashdump" - server.write_strats(exception_type) - server.generate_debug_info(exception_type, dump_type) - raise RuntimeError(e) + def __getstate__(self): + # Called when the server is pickled, we can't pickle the DB. + state = self.__dict__.copy() + del state["db"] + return state def parse_argument(): - parser = argparse.ArgumentParser(description="AEPsych Server!") + parser = argparse.ArgumentParser(description="AEPsych Server") parser.add_argument( "--port", metavar="N", type=int, default=5555, help="port to serve on" ) @@ -415,72 +457,54 @@ def parse_argument(): "--db", type=str, help="The database to use if not the default (./databases/default.db).", - default=None, - ) - - parser.add_argument( - "-r", "--replay", type=str, help="Unique id of the experiment to replay." + default="./databases/default.db", ) parser.add_argument( - "-m", "--resume", action="store_true", help="Resume server after replay." + "-r", + "--resume", + type=str, + help="Unique id of the experiment to replay and resume the server from.", ) args = parser.parse_args() return args -def start_server(server_class, args): - logger.info("Starting the AEPsychServer") +def main(): + logger = utils_logging.getLogger() + logger.info("Starting AEPsychServer") logger.info(f"AEPsych Version: {version.__version__}") - try: - if "db" in args and args.db is not None: - database_path = args.db - if "replay" in args and args.replay is not None: - logger.info(f"Attempting to replay {args.replay}") - if args.resume is True: - sock = PySocket(port=args.port) - logger.info(f"Will resume {args.replay}") - else: - sock = None - startServerAndRun( - server_class, - socket=sock, - database_path=database_path, - uuid_of_replay=args.replay, - config_path=args.stratconfig, - ) - else: - logger.info(f"Setting the database path {database_path}") - sock = PySocket(port=args.port) - startServerAndRun( - server_class, - database_path=database_path, - socket=sock, - config_path=args.stratconfig, - ) - else: - sock = PySocket(port=args.port) - startServerAndRun(server_class, socket=sock, config_path=args.stratconfig) - except (KeyboardInterrupt, SystemExit): - logger.exception("Got Ctrl+C, exiting!") - sys.exit() - except RuntimeError as e: - fname = get_next_filename(".", "dump", "pkl") - logger.exception(f"CRASHING!! dump in {fname}") - raise RuntimeError(e) - - -def main(server_class=AEPsychServer): args = parse_argument() if args.logs: # overide logger path log_path = args.logs - logger = utils_logging.getLogger(logging.DEBUG, log_path) - logger.info(f"Saving logs to path: {log_path}") - start_server(server_class, args) + logger = utils_logging.getLogger(log_path) + logger.info(f"Saving logs to path: {log_path}") + + server = AEPsychServer( + host=args.ip, + port=args.port, + database_path=args.db, + ) + + if args.stratconfig is not None and args.resume is not None: + raise ValueError( + "Cannot configure the server with a config file and a resume from a replay at the same time." + ) + + elif args.stratconfig is not None: + configure(server, config_str=args.stratconfig) + + elif args.resume is not None: + if args.db is None: + raise ValueError("Cannot resume from a replay if no database is given.") + server.replay(args.resume, skip_computations=True) + + # Starts the server in a blocking state + server.start_blocking() if __name__ == "__main__": - main(AEPsychServer) + main() diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 494e5d72d..96425b481 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -34,8 +34,6 @@ def setUp(self): ) def tearDown(self): - self.s.cleanup() - # cleanup the db if self.s.db is not None: self.s.db.delete_db() diff --git a/tests/models/test_pairwise_probit.py b/tests/models/test_pairwise_probit.py index 9b2b6c138..359c040f3 100644 --- a/tests/models/test_pairwise_probit.py +++ b/tests/models/test_pairwise_probit.py @@ -501,8 +501,6 @@ def setUp(self): self.s = server.AEPsychServer(database_path=database_path) def tearDown(self): - self.s.cleanup() - # cleanup the db if self.s.db is not None: self.s.db.delete_db() diff --git a/tests/server/message_handlers/test_ask_handlers.py b/tests/server/message_handlers/test_ask_handlers.py index 9d3fa5c6b..30d773f90 100644 --- a/tests/server/message_handlers/test_ask_handlers.py +++ b/tests/server/message_handlers/test_ask_handlers.py @@ -8,7 +8,7 @@ import unittest -from ..test_server import BaseServerTestCase +from ..test_server import AsyncServerTestBase dummy_config = """ [common] @@ -69,7 +69,7 @@ """ -class AskHandlerTestCase(BaseServerTestCase): +class AskHandlerTestCase(AsyncServerTestBase): def test_handle_ask(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_can_model.py b/tests/server/message_handlers/test_can_model.py index 01b3c4b8a..6f5b090e0 100644 --- a/tests/server/message_handlers/test_can_model.py +++ b/tests/server/message_handlers/test_can_model.py @@ -7,10 +7,10 @@ import unittest -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class StratCanModelTestCase(BaseServerTestCase): +class StratCanModelTestCase(AsyncServerTestBase): def test_strat_can_model(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_handle_exit.py b/tests/server/message_handlers/test_handle_exit.py index 5b548bc25..bd8ff89b7 100644 --- a/tests/server/message_handlers/test_handle_exit.py +++ b/tests/server/message_handlers/test_handle_exit.py @@ -5,24 +5,30 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import asyncio import unittest -from unittest.mock import MagicMock -from ..test_server import BaseServerTestCase +from ..test_server import AsyncServerTestBase, dummy_config -class HandleExitTestCase(BaseServerTestCase): - def test_handle_exit(self): +class HandleExitTestCase(AsyncServerTestBase): + async def test_handle_exit(self): + setup_request = { + "type": "setup", + "version": "0.01", + "message": {"config_str": dummy_config}, + } + + await self.mock_client(setup_request) + request = {} request["type"] = "exit" - self.s.socket.accept_client = MagicMock() - self.s.socket.receive = MagicMock(return_value=request) - self.s.dump = MagicMock() + await self.mock_client(request) - with self.assertRaises(SystemExit) as cm: - self.s.serve() + with self.assertRaises(ConnectionRefusedError): + await asyncio.open_connection(self.s.host, self.s.port) - self.assertEqual(cm.exception.code, 0) + self.assertTrue(self.s.exit_server_loop) if __name__ == "__main__": diff --git a/tests/server/message_handlers/test_handle_finish_strategy.py b/tests/server/message_handlers/test_handle_finish_strategy.py index 9efffdb20..729421bdb 100644 --- a/tests/server/message_handlers/test_handle_finish_strategy.py +++ b/tests/server/message_handlers/test_handle_finish_strategy.py @@ -7,10 +7,10 @@ import unittest -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class ResumeTestCase(BaseServerTestCase): +class ResumeTestCase(AsyncServerTestBase): def test_handle_finish_strategy(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_handle_get_config.py b/tests/server/message_handlers/test_handle_get_config.py index d79c0697f..b173d22e5 100644 --- a/tests/server/message_handlers/test_handle_get_config.py +++ b/tests/server/message_handlers/test_handle_get_config.py @@ -9,10 +9,10 @@ from aepsych.config import Config -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class HandleExitTestCase(BaseServerTestCase): +class HandleExitTestCase(AsyncServerTestBase): def test_get_config(self): setup_request = { "type": "setup", diff --git a/tests/server/message_handlers/test_query_handlers.py b/tests/server/message_handlers/test_query_handlers.py index 2f8eaff2a..3ff9f618e 100644 --- a/tests/server/message_handlers/test_query_handlers.py +++ b/tests/server/message_handlers/test_query_handlers.py @@ -7,12 +7,12 @@ import unittest -from ..test_server import BaseServerTestCase +from ..test_server import AsyncServerTestBase # Smoke test to make sure nothing breaks. This should really be combined with # the individual query tests -class QueryHandlerTestCase(BaseServerTestCase): +class QueryHandlerTestCase(AsyncServerTestBase): def test_strat_query(self): # Annoying and complex model and output shapes config_str = """ diff --git a/tests/server/message_handlers/test_tell_handlers.py b/tests/server/message_handlers/test_tell_handlers.py index 4128b4ed6..7f68e84f5 100644 --- a/tests/server/message_handlers/test_tell_handlers.py +++ b/tests/server/message_handlers/test_tell_handlers.py @@ -9,10 +9,10 @@ import unittest from unittest.mock import MagicMock -from ..test_server import BaseServerTestCase, dummy_config +from ..test_server import AsyncServerTestBase, dummy_config -class MessageHandlerTellTests(BaseServerTestCase): +class MessageHandlerTellTests(AsyncServerTestBase): def test_tell(self): setup_request = { "type": "setup", diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 57d59da50..6c5617896 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -5,14 +5,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import asyncio import json import logging -import select import time import unittest import uuid from pathlib import Path -from unittest.mock import MagicMock +from typing import Any, Dict import aepsych.server as server import aepsych.utils_logging as utils_logging @@ -77,28 +77,52 @@ """ -class BaseServerTestCase(unittest.TestCase): - # so that this can be overridden for tests that require specific databases. +class AsyncServerTestBase(unittest.IsolatedAsyncioTestCase): @property def database_path(self): return "./{}_test_server.db".format(str(uuid.uuid4().hex)) - def setUp(self): + async def asyncSetUp(self): + self.ip = "127.0.0.1" + self.port = 5555 + # setup logger server.logger = utils_logging.getLogger(logging.DEBUG, "logs") - # random port - socket = server.sockets.PySocket(port=0) + # random datebase path name without dashes database_path = self.database_path - self.s = server.AEPsychServer(socket=socket, database_path=database_path) + self.s = server.AEPsychServer( + database_path=database_path, host=self.ip, port=self.port + ) self.db_name = database_path.split("/")[1] self.db_path = database_path - def tearDown(self): - self.s.cleanup() + try: + self.server_task = asyncio.create_task(self.s.serve()) + except OSError: + # Try 0.0.0.0 after waiting + time.sleep(5) + self.ip = "0.0.0.0" + self.s = server.AEPsychServer( + database_path=database_path, host=self.ip, port=self.port + ) + self.server_task = asyncio.create_task(self.s.serve()) + await asyncio.sleep(0.1) + + self.reader, self.writer = await asyncio.open_connection(self.ip, self.port) + + async def asyncTearDown(self): + # Stops the client + self.writer.close() - # sleep to ensure db is closed - time.sleep(0.2) + # Stops the server + self.server_task.cancel() + try: + await self.server_task + except asyncio.CancelledError: + pass + + await asyncio.sleep(0.2) # cleanup the db if self.s.db is not None: @@ -107,46 +131,18 @@ def tearDown(self): except PermissionError as e: print("Failed to deleted database: ", e) - def dummy_create_setup(self, server, request=None): - request = request or {"test": "test request"} - server._db_master_record = server.db.record_setup( - description="default description", name="default name", request=request - ) + async def mock_client(self, request: Dict[str, Any]) -> Any: + self.writer.write(json.dumps(request).encode()) + await self.writer.drain() + response = await self.reader.read(1024 * 512) + return response.decode() -class ServerTestCase(BaseServerTestCase): - def test_final_strat_serialization(self): - setup_request = { - "type": "setup", - "version": "0.01", - "message": {"config_str": dummy_config}, - } - ask_request = {"type": "ask", "message": ""} - tell_request = { - "type": "tell", - "message": {"config": {"x": [0.5]}, "outcome": 1}, - } - self.s.handle_request(setup_request) - while not self.s.strat.finished: - self.s.handle_request(ask_request) - self.s.handle_request(tell_request) - unique_id = self.s.db.get_master_records()[-1].unique_id - stored_strat = self.s.get_strat_from_replay(unique_id) - # just some spot checks that the strat's the same - # same data. We do this twice to make sure buffers are - # in a good state and we can load twice without crashing - for _ in range(2): - stored_strat = self.s.get_strat_from_replay(unique_id) - self.assertTrue((stored_strat.x == self.s.strat.x).all()) - self.assertTrue((stored_strat.y == self.s.strat.y).all()) - # same lengthscale and outputscale - self.assertEqual( - stored_strat.model.covar_module.lengthscale, - self.s.strat.model.covar_module.lengthscale, - ) +class AsyncServerTestCase(AsyncServerTestBase): + """Server functions are all async""" - def test_pandadf_dump_single(self): + async def test_pandadf_dump_single(self): setup_request = { "type": "setup", "version": "0.01", @@ -158,20 +154,22 @@ def test_pandadf_dump_single(self): "message": {"config": {"x": [0.5]}, "outcome": 1}, "extra_info": {}, } - self.s.handle_request(setup_request) + + await self.mock_client(setup_request) + expected_x = [0, 1, 2, 3] expected_z = list(reversed(expected_x)) expected_y = [x % 2 for x in expected_x] i = 0 while not self.s.strat.finished: - self.s.handle_request(ask_request) + await self.mock_client(ask_request) tell_request["message"]["config"]["x"] = [expected_x[i]] tell_request["message"]["config"]["z"] = [expected_z[i]] tell_request["message"]["outcome"] = expected_y[i] tell_request["extra_info"]["e1"] = 1 tell_request["extra_info"]["e2"] = 2 i = i + 1 - self.s.handle_request(tell_request) + await self.mock_client(tell_request) unique_id = self.s.db.get_master_records()[-1].unique_id out_df = self.s.get_dataframe_from_replay(unique_id) @@ -183,7 +181,38 @@ def test_pandadf_dump_single(self): self.assertTrue("post_mean" in out_df.columns) self.assertTrue("post_var" in out_df.columns) - def test_pandadf_dump_multistrat(self): + async def test_final_strat_serialization(self): + setup_request = { + "type": "setup", + "version": "0.01", + "message": {"config_str": dummy_config}, + } + ask_request = {"type": "ask", "message": ""} + tell_request = { + "type": "tell", + "message": {"config": {"x": [0.5]}, "outcome": 1}, + } + await self.mock_client(setup_request) + while not self.s.strat.finished: + await self.mock_client(ask_request) + await self.mock_client(tell_request) + + unique_id = self.s.db.get_master_records()[-1].unique_id + stored_strat = self.s.get_strat_from_replay(unique_id) + # just some spot checks that the strat's the same + # same data. We do this twice to make sure buffers are + # in a good state and we can load twice without crashing + for _ in range(2): + stored_strat = self.s.get_strat_from_replay(unique_id) + self.assertTrue((stored_strat.x == self.s.strat.x).all()) + self.assertTrue((stored_strat.y == self.s.strat.y).all()) + # same lengthscale and outputscale + self.assertEqual( + stored_strat.model.covar_module.lengthscale, + self.s.strat.model.covar_module.lengthscale, + ) + + async def test_pandadf_dump_multistrat(self): setup_request = { "type": "setup", "version": "0.01", @@ -199,16 +228,16 @@ def test_pandadf_dump_multistrat(self): expected_z = list(reversed(expected_x)) expected_y = [x % 2 for x in expected_x] i = 0 - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - self.s.handle_request(ask_request) + await self.mock_client(ask_request) tell_request["message"]["config"]["x"] = [expected_x[i]] tell_request["message"]["config"]["z"] = [expected_z[i]] tell_request["message"]["outcome"] = expected_y[i] tell_request["extra_info"]["e1"] = 1 tell_request["extra_info"]["e2"] = 2 i = i + 1 - self.s.handle_request(tell_request) + await self.mock_client(tell_request) unique_id = self.s.db.get_master_records()[-1].unique_id out_df = self.s.get_dataframe_from_replay(unique_id) @@ -221,7 +250,7 @@ def test_pandadf_dump_multistrat(self): self.assertTrue("post_mean" in out_df.columns) self.assertTrue("post_var" in out_df.columns) - def test_pandadf_dump_flat(self): + async def test_pandadf_dump_flat(self): """ This test handles the case where the config values are flat scalars and not lists @@ -237,20 +266,20 @@ def test_pandadf_dump_flat(self): "message": {"config": {"x": [0.5]}, "outcome": 1}, "extra_info": {}, } - self.s.handle_request(setup_request) + await self.mock_client(setup_request) expected_x = [0, 1, 2, 3] expected_z = list(reversed(expected_x)) expected_y = [x % 2 for x in expected_x] i = 0 while not self.s.strat.finished: - self.s.handle_request(ask_request) + await self.mock_client(ask_request) tell_request["message"]["config"]["x"] = expected_x[i] tell_request["message"]["config"]["z"] = expected_z[i] tell_request["message"]["outcome"] = expected_y[i] tell_request["extra_info"]["e1"] = 1 tell_request["extra_info"]["e2"] = 2 i = i + 1 - self.s.handle_request(tell_request) + await self.mock_client(tell_request) unique_id = self.s.db.get_master_records()[-1].unique_id out_df = self.s.get_dataframe_from_replay(unique_id) @@ -262,52 +291,7 @@ def test_pandadf_dump_flat(self): self.assertTrue("post_mean" in out_df.columns) self.assertTrue("post_var" in out_df.columns) - def test_receive(self): - """test_receive - verifies the receive is working when server receives unexpected messages""" - - message1 = b"\x16\x03\x01\x00\xaf\x01\x00\x00\xab\x03\x03\xa9\x80\xcc" # invalid message - message2 = b"\xec\xec\x14M\xfb\xbd\xac\xe7jF\xbe\xf9\x9bM\x92\x15b\xb5" # invalid message - message3 = {"message": {"target": "test request"}} # valid message - message_list = [message1, message2, json.dumps(message3)] - - self.s.socket.conn = MagicMock() - - for i, message in enumerate(message_list): - select.select = MagicMock(return_value=[[self.s.socket.conn], [], []]) - self.s.socket.conn.recv = MagicMock(return_value=message) - if i != 2: - self.assertEqual(self.s.socket.receive(False), BAD_REQUEST) - else: - self.assertEqual(self.s.socket.receive(False), message3) - - def test_error_handling(self): - # double brace escapes, single brace to substitute, so we end up with 3 braces - request = f"{{{BAD_REQUEST}}}" - - expected_error = f"server_error, Request '{request}' raised error ''str' object has no attribute 'keys''!" - - self.s.socket.accept_client = MagicMock() - - self.s.socket.receive = MagicMock(return_value=request) - self.s.socket.send = MagicMock() - self.s.exit_server_loop = True - with self.assertRaises(SystemExit): - self.s.serve() - self.s.socket.send.assert_called_once_with(expected_error) - - def test_queue(self): - """Test to see that the queue is being handled correctly""" - - self.s.socket.accept_client = MagicMock() - ask_request = {"type": "ask", "message": ""} - self.s.socket.receive = MagicMock(return_value=ask_request) - self.s.socket.send = MagicMock() - self.s.exit_server_loop = True - with self.assertRaises(SystemExit): - self.s.serve() - assert len(self.s.queue) == 0 - - def test_replay(self): + async def test_replay(self): exp_config = """ [common] lb = [0] @@ -341,15 +325,14 @@ def test_replay(self): } exit_request = {"message": "", "type": "exit"} - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - self.s.handle_request(ask_request) - self.s.handle_request(tell_request) + await self.mock_client(ask_request) + await self.mock_client(tell_request) - self.s.handle_request(exit_request) + await self.mock_client(exit_request) - socket = server.sockets.PySocket(port=0) - serv = server.AEPsychServer(socket=socket, database_path=self.db_path) + serv = server.AEPsychServer(database_path=self.db_path) exp_ids = [rec.unique_id for rec in serv.db.get_master_records()] serv.replay(exp_ids[-1], skip_computations=True) @@ -359,7 +342,7 @@ def test_replay(self): self.assertTrue(strat.finished) self.assertTrue(strat.x.shape[0] == 4) - def test_string_parameter(self): + async def test_string_parameter(self): string_config = """ [common] parnames = [x, y, z] @@ -405,16 +388,17 @@ def test_string_parameter(self): "type": "tell", "message": {"config": {"x": [0.5], "y": ["blue"], "z": [50]}, "outcome": 1}, } - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - response = self.s.handle_request(ask_request) + response = await self.mock_client(ask_request) + response = json.loads(response) self.assertTrue(response["config"]["y"][0] == "blue") - self.s.handle_request(tell_request) + await self.mock_client(tell_request) self.assertTrue(len(self.s.strat.lb) == 2) self.assertTrue(len(self.s.strat.ub) == 2) - def test_metadata(self): + async def test_metadata(self): setup_request = { "type": "setup", "version": "0.01", @@ -425,10 +409,10 @@ def test_metadata(self): "type": "tell", "message": {"config": {"x": [0.5]}, "outcome": 1}, } - self.s.handle_request(setup_request) + await self.mock_client(setup_request) while not self.s.strat.finished: - self.s.handle_request(ask_request) - self.s.handle_request(tell_request) + await self.mock_client(ask_request) + await self.mock_client(tell_request) master_record = self.s.db.get_master_records()[-1] extra_metadata = json.loads(master_record.extra_metadata) @@ -443,7 +427,7 @@ def test_metadata(self): self.assertTrue(extra_metadata["extra"] == "data that is arbitrary") self.assertTrue("experiment_id" not in extra_metadata) - def test_extension_server(self): + async def test_extension_server(self): extension_path = Path(__file__).parent.parent.parent extension_path = extension_path / "extensions_example" / "new_objects.py" @@ -470,8 +454,8 @@ def test_extension_server(self): "message": {"config_str": config_str}, } - with self.assertLogs(level=logging.INFO) as logs: - self.s.handle_request(setup_request) + with self.assertLogs() as logs: + await self.mock_client(setup_request) outputs = ";".join(logs.output) self.assertTrue(str(extension_path) in outputs) @@ -481,6 +465,30 @@ def test_extension_server(self): self.assertTrue(one == 1) self.assertTrue(strat.generator._base_obj.__class__.__name__ == "OnesGenerator") + async def test_receive(self): + """test_receive - verifies the receive is working when server receives unexpected messages""" + + message1 = b"\x16\x03\x01\x00\xaf\x01\x00\x00\xab\x03\x03\xa9\x80\xcc" # invalid message + message2 = b"\xec\xec\x14M\xfb\xbd\xac\xe7jF\xbe\xf9\x9bM\x92\x15b\xb5" # invalid message + message3 = {"message": {"target": "test request"}} # valid message + message_list = [message1, message2, message3] + + for i, message in enumerate(message_list): + if isinstance(message, dict): + send = json.dumps(message).encode() + else: + send = message + self.writer.write(send) + await self.writer.drain() + + response = await self.reader.read(1024 * 512) + response = response.decode() + response = json.loads(response) + if i != 2: + self.assertTrue("error" in response) # Very generic error for malformed + else: + self.assertTrue("KeyError" in response["error"]) # Specific error + if __name__ == "__main__": unittest.main() diff --git a/tests/test_datafetcher.py b/tests/test_datafetcher.py index ffe3f1980..62761ffa6 100644 --- a/tests/test_datafetcher.py +++ b/tests/test_datafetcher.py @@ -98,9 +98,6 @@ def setUp(self): # setup logger server.logger = utils_logging.getLogger(logging.DEBUG, "logs") - # random port - socket = server.sockets.PySocket(port=0) - database_path = Path(__file__).parent / "test_databases" / "1000_outcome.db" dst_db_path = Path("./{}.db".format(str(uuid.uuid4().hex))) @@ -109,7 +106,7 @@ def setUp(self): time.sleep(0.1) self.assertTrue(dst_db_path.is_file()) - self.s = server.AEPsychServer(socket=socket, database_path=dst_db_path) + self.s = server.AEPsychServer(database_path=dst_db_path) setup_message = { "type": "setup", @@ -125,8 +122,6 @@ def setUp(self): def tearDown(self): time.sleep(0.1) - - self.s.cleanup() self.s.db.delete_db() def test_create_from_config(self): diff --git a/tests/test_db.py b/tests/test_db.py index ec44dca58..0dbb98d1e 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -31,11 +31,6 @@ def tearDown(self): time.sleep(0.1) self._database.delete_db() - def test_db_create(self): - engine = self._database.get_engine() - self.assertIsNotNone(engine) - self.assertIsNotNone(self._database._engine) - def test_record_setup_basic(self): master_table = self._database.record_setup( description="test description",