Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 67 additions & 21 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

RUN_DEFAULTS_KEY = "strax_defaults"
TEMP_DATA_TYPE_PREFIX = "_temp_"
NOT_ALLOWED_PLUGINS = (strax.LoopPlugin, strax.OverlapWindowPlugin)

# use tqdm as loaded in utils (from tqdm.notebook when in a jupyter env)
tqdm = strax.utils.tqdm
Expand Down Expand Up @@ -986,18 +987,45 @@ def __add_lineage_to_plugin(
# Set chunk_number in the lineage
if chunk_number is not None:
for d_depends in plugin.depends_on:
if d_depends in chunk_number:
if len(plugin.depends_on) > 1:
raise ValueError(
"Can not assign chunk_number for multi-dependencies plugins "
"because it is not clear which input should be assigned."
)
not_allowed_plugins = (strax.LoopPlugin, strax.OverlapWindowPlugin)
if issubclass(plugin.__class__, not_allowed_plugins):
dependencies = self.get_dependencies(d_depends) | {d_depends}
for d in chunk_number.keys():
if d not in dependencies:
continue
if issubclass(plugin.__class__, NOT_ALLOWED_PLUGINS):
raise ValueError(
f"Can not assign chunk_number for {plugin.__class__} "
f"because it is subclass of one of {not_allowed_plugins}!"
f"Can not load per-chunk storage from {d} for {plugin.__class__} "
f"because it is subclass of one of {NOT_ALLOWED_PLUGINS}!"
)
if d_depends in chunk_number:
if len(plugin.depends_on) > 1:
for d in plugin.depends_on:
dependencies = self.get_dependencies(d) | {d}
msg = (
f"Can not assign chunk_number for {plugin.__class__} "
"because it has multiple dependencies and one of the "
f"dependencies {d} does not (eventually) depend on {d_depends}."
)
mask = d_depends in dependencies
if not mask:
raise ValueError(msg)
# Make sure other dependencies depend on the same per-chunk data_type
for shortest in [False, True]:
levels = {
_d: self.tree_levels[shortest][_d]["level"]
for _d in dependencies
}
mask &= (
len(
[
k
for k, v in levels.items()
if v == levels.get(d_depends, -1)
]
)
== 1
)
if not mask:
raise ValueError(msg)
configs.setdefault("chunk_number", {})
if d_depends in configs["chunk_number"]:
raise ValueError(
Expand Down Expand Up @@ -2848,32 +2876,50 @@ def tree_levels(self):
if self._fixed_level_cache is not None and context_hash in self._fixed_level_cache:
return self._fixed_level_cache[context_hash]

def _get_levels(data_type=None, results=None):
def _get_levels(data_type=None, results=None, shortest=False):
"""Get the level data_type in the context."""
if results is None:
results = dict()
for k in [data_type] if data_type else self._plugin_class_registry.keys():
if k in results:
continue
results[k] = dict()
_v = self._plugin_class_registry[k]()
if _v.depends_on:
results[k]["level"] = (
max(_get_levels(d, results)[d]["level"] for d in _v.depends_on) + 1
)
if shortest:
results[k]["level"] = (
min(
_get_levels(d, results, shortest=shortest)[d]["level"]
for d in _v.depends_on
)
+ 1
)
else:
results[k]["level"] = (
max(
_get_levels(d, results, shortest=shortest)[d]["level"]
for d in _v.depends_on
)
+ 1
)
else:
results[k]["level"] = 0
results[k]["class"] = self._plugin_class_registry[k].__name__
results[k]["index"] = _v.provides.index(k)
return results

# Sort the results by level, class, and index in provides
_results = sorted(
_get_levels().items(), key=lambda x: (x[1]["level"], x[1]["class"], x[1]["index"])
)
results = dict()
for shortest in [False, True]:
results[shortest] = sorted(
_get_levels(shortest=shortest).items(),
key=lambda x: (x[1]["level"], x[1]["class"], x[1]["index"]),
)

# Assign order to the results
for order, (key, value) in enumerate(_results):
value["order"] = order
results = dict(_results)
# Assign order to the results
for order, (key, value) in enumerate(results[shortest]):
value["order"] = order
results[shortest] = dict(results[shortest])

if self._fixed_level_cache is None:
self._fixed_level_cache = {context_hash: results}
Expand Down