Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions scisample/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Module defining an ensemble, which contains multiple samplers joined together.

Follows the same interface as the samplers.
"""
import logging

from scisample.base_sampler import BaseSampler
from scisample.samplers import new_sampler
from scisample.utils import log_and_raise_exception

LOG = logging.getLogger(__name__)


class Ensemble(BaseSampler):
"""
Class for combining multiple samplers.
"""

def __init__(self, *sampler_data):
"""
Initialize the sampler.

:param sampler_data: Sampler data to use to initialize the samplers.
"""
super(Ensemble, self).__init__(sampler_data)

self._samplers = []
for data in sampler_data:
if not isinstance(data, list):
data = [data]
self.add_samplers(*data)

self.check_validity()

def add_samplers(self, *sampler_data):
"""
Add additional samplers to the Ensemble.
A validity check will be performed after adding samples.
Additionally, the samples will be un-cached from the Ensemble.
(sampler caches will not be impacted).

:param sampler_data: Sampler data to use to add samplers.
"""
for data in sampler_data:
self._samplers.append(new_sampler(data))
self.check_validity()
self._samples = None
self._parameter_block = None

def check_validity(self):
"""
Check the validity of the underlying samplers.
"""
if not self._samplers:
log_and_raise_exception(
"No samplers requested for ensemble"
)
for sampler in self._samplers:
sampler.check_validity()
if sorted(self.parameters) != sorted(sampler.parameters):
log_and_raise_exception(
"All samplers in an ensemble must have the same "
f"parameters. Parameters from {sampler.parameters} "
f"did not match {self.parameters}."
)

@property
def parameters(self):
"""
Return a of list of the parameters being generated by the
sampler.
"""
return self._samplers[0].parameters

def get_samples(self):
"""
Get samples from the samplers.

This returns samples as a list of dictionaries, with the
sample variables as the keys:

.. code:: python

[{'b': 0.89856, 'a': 1}, {'b': 0.923223, 'a': 1}, ... ]
"""
LOG.info("Entering Ensemble.get_samples()")
if self._samples is not None:
return self._samples

self._samples = []
for sampler in self._samplers:
self._samples.extend(sampler.get_samples())

return self._samples
Loading