Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/queens/global_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
from pathlib import Path

from queens.schedulers._dask import SHUTDOWN_CLIENTS
from queens.schedulers._scheduler import CLEANUP_SCHEDULERS
from queens.utils.ascii_art import print_banner_and_description
from queens.utils.logger_settings import reset_logging, setup_basic_logging
from queens.utils.path import PATH_TO_ROOT, create_folder_if_not_existent
Expand Down Expand Up @@ -168,8 +168,8 @@ def __exit__(self, exception_type, exception_value, traceback):
exception_value: indicates exception instance
traceback: traceback object
"""
for shutdown_client in SHUTDOWN_CLIENTS.copy():
SHUTDOWN_CLIENTS.remove(shutdown_client)
shutdown_client()
for cleanup_scheduler in CLEANUP_SCHEDULERS.copy():
CLEANUP_SCHEDULERS.remove(cleanup_scheduler)
cleanup_scheduler()

reset_logging()
31 changes: 12 additions & 19 deletions src/queens/schedulers/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@

_logger = logging.getLogger(__name__)

SHUTDOWN_CLIENTS = []


class Dask(Scheduler):
"""Abstract base class for schedulers in QUEENS.
Expand Down Expand Up @@ -83,21 +81,7 @@ def _start_cluster_and_connect_client(self):
def start_cluster_and_connect_client(self):
    """Start a Dask cluster and a client that connects to it.

    The client is created lazily: a new cluster/client pair is only
    started when no client exists yet or the previous one was closed.
    The resulting client is stored on ``self.client``.
    """
    # NOTE(review): reconstructed post-merge state of this diff hunk — the
    # deleted register_shutdown() path (global SHUTDOWN_CLIENTS) is gone;
    # cleanup is now handled via the scheduler-level CLEANUP_SCHEDULERS hook.
    if self.client is None or self.client.status == "closed":
        self.client = self._start_cluster_and_connect_client()

def evaluate(
self, samples: Iterable, function: SchedulerCallableSignature, job_ids: Iterable = None
Expand Down Expand Up @@ -171,6 +155,15 @@ def run_function(*args, **kwargs):
def restart_worker(self, worker):
"""Restart a worker."""

def shutdown_client(self):
    """Shutdown the DASK client.

    Does nothing when no client was ever started. An ``AttributeError``
    raised during shutdown (e.g. from a partially torn-down client) is
    logged as a warning instead of propagating, so cleanup can proceed.
    """
    # NOTE(review): reconstructed post-merge state of this diff hunk — the
    # old async variant ("await self.client.shutdown()") was replaced by
    # this guarded synchronous call.
    if self.client is not None:
        try:
            self.client.shutdown()
        except AttributeError as e:
            _logger.warning("AttributeError while shutting down Dask client: %s", e)

def cleanup(self):
    """Tear down Dask resources after the QUEENS run.

    First closes the Dask client, then removes the experiment
    directory in case it ended up empty.
    """
    # Close the client before touching the experiment directory.
    self.shutdown_client()
    experiment_dir = self.experiment_dir
    self.delete_experiment_dir_if_empty(experiment_dir)
20 changes: 20 additions & 0 deletions src/queens/schedulers/_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

_logger = logging.getLogger(__name__)

CLEANUP_SCHEDULERS = []


class SchedulerCallableSignature(Protocol):
"""Signature for callables which can be used with QUEENS schedulers."""
Expand Down Expand Up @@ -81,6 +83,8 @@ def __init__(self, experiment_name, experiment_dir, num_jobs, verbose=True):
self.next_job_id = 0
self.verbose = verbose

CLEANUP_SCHEDULERS.append(self.cleanup)

@abc.abstractmethod
def evaluate(
self, samples: Iterable, function: SchedulerCallableSignature, job_ids: Iterable = None
Expand Down Expand Up @@ -179,3 +183,19 @@ def get_job_ids(self, num_samples):
job_ids = self.next_job_id + np.arange(num_samples)
self.next_job_id += num_samples
return job_ids

def cleanup(self):
    """Run post-experiment housekeeping for this scheduler.

    Removes the experiment directory when nothing was written into it.
    """
    target_dir = self.experiment_dir
    self.delete_experiment_dir_if_empty(target_dir)

@staticmethod
def delete_experiment_dir_if_empty(experiment_dir):
    """Delete the experiment directory if it is empty.

    Args:
        experiment_dir (Path): Path to the experiment directory.
    """
    # Guard clause: nothing to do for a missing path or a non-directory.
    if not (experiment_dir.exists() and experiment_dir.is_dir()):
        return
    # Keep the directory as soon as a single entry is found.
    if next(experiment_dir.iterdir(), None) is not None:
        return
    experiment_dir.rmdir()
    _logger.debug("Deleted empty experiment directory '%s'.", experiment_dir)
4 changes: 4 additions & 0 deletions src/queens/schedulers/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,7 @@ def copy_files_from_experiment_dir(
self.remote_connection.copy_from_remote(
self.experiment_dir, destination, verbose, exclude, filters
)

@staticmethod
def delete_experiment_dir_if_empty(_):
    """Do nothing: the remote experiment directory will never be empty."""
44 changes: 40 additions & 4 deletions tests/integration_tests/cluster/test_dask_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,9 @@ def patch_experiments_directory(experiment_name, experiment_base_directory=None)
return patch_experiments_directory

@pytest.fixture(name="experiment_dir")
def fixture_experiment_dir(self, test_name, remote_connection, mock_experiment_dir):
    """Fixture providing the remote experiment directory."""
    # NOTE(review): reconstructed post-merge state of this diff hunk — the
    # fixture now keys the directory off test_name instead of
    # global_settings.experiment_name.
    # run_function returns a pair; only the first element (the remote
    # experiment directory) is needed here.
    experiment_dir, _ = remote_connection.run_function(mock_experiment_dir, test_name, None)
    return experiment_dir

@pytest.fixture(name="_create_experiment_dir")
Expand Down Expand Up @@ -209,6 +207,44 @@ def test_y_prompt_input_for_existing_experiment_dir(
mocker.patch("sys.stdin.readline", return_value=user_input)
Cluster(**cluster_kwargs, overwrite_existing_experiment=False)

def test_deletion_of_experiment_dir_with_files(
    self, global_settings, cluster_kwargs, remote_connection, experiment_dir
):
    """Test the deletion of an experiment directory containing files.

    The experiment directory should NOT be deleted when exiting the
    global settings context.
    """

    def experiment_dir_exists_and_contents(experiment_dir):
        """Assert that experiment directory and test file exist."""
        # Returns an (exists, contents) pair so the caller can assert on both.
        experiment_dir_exists = experiment_dir.exists()
        if not experiment_dir_exists:
            return experiment_dir_exists, []

        experiment_dir_contents = list(experiment_dir.iterdir())
        return experiment_dir_exists, experiment_dir_contents

    with global_settings:
        # Presumably constructing the scheduler populates the remote
        # experiment directory — the asserts below rely on it being non-empty.
        Cluster(**cluster_kwargs)

        # Check that remote experiment directory is not empty
        experiment_dir_exists, experiment_dir_contents_before = remote_connection.run_function(
            experiment_dir_exists_and_contents, experiment_dir
        )
        assert experiment_dir_exists
        assert any(experiment_dir_contents_before)

    # Check that remote experiment directory has not been changed
    # NOTE(review): indentation reconstructed from a scraped diff — this
    # check is placed after the context exits (i.e. after cleanup ran),
    # matching the docstring's intent; confirm against the original file.
    experiment_dir_exists, experiment_dir_contents_after = remote_connection.run_function(
        experiment_dir_exists_and_contents, experiment_dir
    )
    assert experiment_dir_exists
    for file_before, file_after in zip(
        experiment_dir_contents_before, experiment_dir_contents_after, strict=True
    ):
        assert file_before == file_after

def test_fourc_mc_cluster(
self,
third_party_inputs,
Expand Down
42 changes: 42 additions & 0 deletions tests/unit_tests/schedulers/test_experiment_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,45 @@ def fixture_create_experiment_dir(experiment_dir):
"""Create the experiment directory."""
os.mkdir(experiment_dir)
assert experiment_dir.exists()


@pytest.mark.parametrize("scheduler_class", [Local, Pool])
def test_empty_experiment_dir_is_deleted(
    global_settings, tmp_path, test_name, experiment_dir, scheduler_class
):
    """Test that an empty experiment directory is deleted.

    This should happen when exiting the global settings context.
    """
    with global_settings:
        scheduler_class(
            experiment_name=test_name,
            experiment_base_dir=tmp_path,
        )
        # Inside the context the directory exists and contains nothing.
        assert experiment_dir.exists()
        assert not any(experiment_dir.iterdir())

    # Exiting the context runs the registered scheduler cleanup, which
    # removes the empty directory.
    assert not experiment_dir.exists()


@pytest.mark.parametrize("scheduler_class", [Local, Pool])
def test_experiment_dir_with_files_is_not_deleted(
    global_settings, tmp_path, test_name, experiment_dir, scheduler_class
):
    """Test that an experiment directory containing files is not deleted.

    Such an experiment directory should NOT be deleted when exiting the
    global settings context.
    """
    with global_settings:
        scheduler_class(
            experiment_name=test_name,
            experiment_base_dir=tmp_path,
        )
        assert experiment_dir.exists()
        # Drop a file into the directory so it is non-empty at cleanup time.
        test_file = experiment_dir / "test_file.txt"
        test_file.write_text("test content")
        assert test_file.exists()

    # Cleanup must leave the non-empty directory and its contents untouched.
    assert experiment_dir.exists()
    assert test_file.exists()
Loading