diff --git a/.all-contributorsrc b/.all-contributorsrc index d8c51a20d..78194e308 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -71,7 +71,8 @@ "profile": "https://github.com/Bogdan-Wiederspan", "contributions": [ "code", - "test" + "test", + "review" ] }, { @@ -153,9 +154,11 @@ "avatar_url": "https://avatars.githubusercontent.com/u/99343616?v=4", "profile": "https://github.com/aalvesan", "contributions": [ - "code" + "code", + "review" ] - }, { + }, + { "login": "philippgadow", "name": "philippgadow", "avatar_url": "https://avatars.githubusercontent.com/u/6804366?v=4", @@ -163,6 +166,15 @@ "contributions": [ "code" ] + }, + { + "login": "LuSchaller", + "name": "Lukas Schaller", + "avatar_url": "https://avatars.githubusercontent.com/u/30951523?v=4", + "profile": "https://github.com/LuSchaller", + "contributions": [ + "code" + ] } ], "commitType": "docs" diff --git a/README.md b/README.md index 410d7e1b7..43abbe91b 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ For a better overview of the tasks that are triggered by the commands below, che Daniel Savoiu
Daniel Savoiu

💻 👀 pkausw
pkausw

💻 👀 nprouvost
nprouvost

💻 ⚠️ - Bogdan-Wiederspan
Bogdan-Wiederspan

💻 ⚠️ + Bogdan-Wiederspan
Bogdan-Wiederspan

💻 ⚠️ 👀 Tobias Kramer
Tobias Kramer

💻 👀 @@ -151,8 +151,9 @@ For a better overview of the tasks that are triggered by the commands below, che JulesVandenbroeck
JulesVandenbroeck

💻 - Ana Andrade
Ana Andrade

💻 + Ana Andrade
Ana Andrade

💻 👀 philippgadow
philippgadow

💻 + Lukas Schaller
Lukas Schaller

💻 diff --git a/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py b/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py index da7e34817..f160e5e6f 100644 --- a/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py +++ b/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py @@ -14,14 +14,16 @@ apply_variable_settings, apply_process_settings, ) +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") od = maybe_import("order") +# import hist, matplotlib... for type checking only like this! import them then also locallu. +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") + def my_plot1d_func( hists: OrderedDict[od.Process, hist.Hist], @@ -45,6 +47,9 @@ def my_plot1d_func( --plot-function __cf_module_name__.plotting.example.my_plot1d_func \ --general-settings example_param=some_text """ + import mplhep + import matplotlib.pyplot as plt + # we can add arbitrary parameters via the `general_settings` parameter to access them in the # plotting function. They are automatically parsed either to a bool, float, or string print(f"the example_param has been set to '{example_param}' (type: {type(example_param)})") diff --git a/analysis_templates/cms_minimal/law.cfg b/analysis_templates/cms_minimal/law.cfg index 8721d9d1d..35a233f00 100644 --- a/analysis_templates/cms_minimal/law.cfg +++ b/analysis_templates/cms_minimal/law.cfg @@ -27,10 +27,10 @@ default_analysis: __cf_module_name__.config.analysis___cf_short_name_lc__.analys default_config: run2_2017_nano_v9 default_dataset: st_tchannel_t_4f_powheg -calibration_modules: columnflow.calibration.cms.{jets,met,tau}, __cf_module_name__.calibration.example +calibration_modules: columnflow.calibration.cms.{jets,met,tau,egamma,muon}, __cf_module_name__.calibration.example selection_modules: columnflow.selection.empty, columnflow.selection.cms.{json_filter,met_filters}, __cf_module_name__.selection.example reduction_modules: columnflow.reduction.default, __cf_module_name__.reduction.example -production_modules: columnflow.production.{categories,matching,normalization,processes}, columnflow.production.cms.{btag,electron,jet,matching,mc_weight,muon,pdf,pileup,scale,parton_shower,seeds}, __cf_module_name__.production.example +production_modules: columnflow.production.{categories,matching,normalization,processes}, columnflow.production.cms.{btag,electron,jet,matching,mc_weight,muon,pdf,pileup,scale,parton_shower,seeds,gen_particles}, __cf_module_name__.production.example categorization_modules: __cf_module_name__.categorization.example hist_production_modules: columnflow.histogramming.default, __cf_module_name__.histogramming.example ml_modules: columnflow.ml, __cf_module_name__.ml.example @@ -56,12 +56,16 @@ default_create_selection_hists: False # wether or not the ensure_proxy decorator should be skipped, even if used by task's run methods skip_ensure_proxy: False +# the name of a sandbox to use for tasks in remote jobs initially (invoked with claw when set) +default_remote_claw_sandbox: None + # some remote workflow parameter defaults # (resources like memory and disk can also be set in [resources] with more granularity) htcondor_flavor: $CF_HTCONDOR_FLAVOR htcondor_share_software: False htcondor_memory: -1 htcondor_disk: -1 +htcondor_runtime: 3h slurm_flavor: $CF_SLURM_FLAVOR slurm_partition: $CF_SLURM_PARTITION @@ -70,6 +74,9 @@ chunked_io_chunk_size: 100000 chunked_io_pool_size: 2 chunked_io_debug: False +# settings for merging parquet files in several locations +merging_row_group_size: 50000 + # csv list of task families that inherit from ChunkedReaderMixin and whose output arrays should be # checked (raising an exception) for non-finite values before saving them to disk check_finite_output: cf.CalibrateEvents, cf.SelectEvents, cf.ReduceEvents, cf.ProduceColumns @@ -98,8 +105,8 @@ lfn_sources: wlcg_fs_t2b_redirector, wlcg_fs_infn_redirector, wlcg_fs_global_red # output locations per task family # the key can consist of multple underscore-separated parts, that can each be patterns or regexes # these parts are used for the lookup from within tasks and can contain (e.g.) the analysis name, -# the config name, the task family, the dataset name, or the shift name -# (see AnalysisTask.get_config_lookup_keys() - and subclasses - for the exact order) +# the config name, the task family, the dataset name, or the shift name, for more info, see +# https://columnflow.readthedocs.io/en/latest/user_guide/best_practices.html#selecting-output-locations # values can have the following format: # for local targets : "local[, LOCAL_FS_NAME or STORE_PATH][, store_parts_modifier]" # for remote targets : "wlcg[, WLCG_FS_NAME][, store_parts_modifier]" @@ -108,8 +115,8 @@ lfn_sources: wlcg_fs_t2b_redirector, wlcg_fs_infn_redirector, wlcg_fs_global_red # the "store_parts_modifiers" can be the name of a function in the "store_parts_modifiers" aux dict # of the analysis instance, which is called with an output's store parts of an output to modify them # example: -; run3_2023__cf.CalibrateEvents__nomin*: local -; cf.CalibrateEvents: wlcg +; cfg_run3_2023__task_cf.CalibrateEvents__shift_nomin*: local +; task_cf.CalibrateEvents: wlcg [versions] @@ -117,13 +124,13 @@ lfn_sources: wlcg_fs_t2b_redirector, wlcg_fs_infn_redirector, wlcg_fs_global_red # default versions of specific tasks to pin # the key can consist of multple underscore-separated parts, that can each be patterns or regexes # these parts are used for the lookup from within tasks and can contain (e.g.) the analysis name, -# the config name, the task family, the dataset name, or the shift name -# (see AnalysisTask.get_config_lookup_keys() - and subclasses - for the exact order) +# the config name, the task family, the dataset name, or the shift name, for more info, see +# https://columnflow.readthedocs.io/en/latest/user_guide/best_practices.html#pinned-versions-in-the-analysis-config-or-law-cfg-file # note: # this lookup is skipped if the lookup based on the config instance's auxiliary data succeeded # example: -; run3_2023__cf.CalibrateEvents__nomin*: prod1 -; cf.CalibrateEvents: prod2 +; cfg_run3_2023__task_cf.CalibrateEvents__shift_nomin*: prod1 +; task_cf.CalibrateEvents: prod2 [resources] @@ -135,8 +142,8 @@ lfn_sources: wlcg_fs_t2b_redirector, wlcg_fs_infn_redirector, wlcg_fs_global_red # by the respective parameter instance at runtime # same as for [versions], the order of options is important as it defines the resolution order # example: -; run3_2023__cf.CalibrateEvents__nomin*: htcondor_memory=5GB -; run3_2023__cf.CalibrateEvents: htcondor_memory=2GB +; cfg_run3_2023__task_cf.CalibrateEvents__shift_nomin*: htcondor_memory=5GB +; cfg_run3_2023__task_cf.CalibrateEvents: htcondor_memory=2GB [job] @@ -159,6 +166,12 @@ remote_lcg_setup_el9: /cvmfs/grid.cern.ch/alma9-ui-test/etc/profile.d/setup-alma remote_lcg_setup_force: False +[target] + +# when removing target collections, use multi-threading +collection_remove_threads: 2 + + [local_fs] base: / diff --git a/analysis_templates/ghent_template/__cf_module_name__/plotting/example.py b/analysis_templates/ghent_template/__cf_module_name__/plotting/example.py index 943d3ce33..2166f1f22 100644 --- a/analysis_templates/ghent_template/__cf_module_name__/plotting/example.py +++ b/analysis_templates/ghent_template/__cf_module_name__/plotting/example.py @@ -14,14 +14,16 @@ apply_variable_settings, apply_process_settings, ) +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") od = maybe_import("order") +# import hist, matplotlib... for type checking only like this! import them then also locallu. +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") + def my_plot1d_func( hists: OrderedDict[od.Process, hist.Hist], @@ -45,6 +47,9 @@ def my_plot1d_func( --plot-function __cf_module_name__.plotting.example.my_plot1d_func \ --general-settings example_param=some_text """ + import mplhep + import matplotlib.pyplot as plt + # we can add arbitrary parameters via the `general_settings` parameter to access them in the # plotting function. They are automatically parsed either to a bool, float, or string print(f"The example_param has been set to '{example_param}' (type: {type(example_param)})") diff --git a/analysis_templates/ghent_template/__cf_module_name__/production/default.py b/analysis_templates/ghent_template/__cf_module_name__/production/default.py index a4617ef4b..fb546987c 100644 --- a/analysis_templates/ghent_template/__cf_module_name__/production/default.py +++ b/analysis_templates/ghent_template/__cf_module_name__/production/default.py @@ -16,8 +16,10 @@ np = maybe_import("numpy") ak = maybe_import("awkward") -coffea = maybe_import("coffea") -maybe_import("coffea.nanoevents.methods.nanoaod") + +# do not import coffea globally! Do this inside the function +# coffea = maybe_import("coffea") +# maybe_import("coffea.nanoevents.methods.nanoaod") @producer( diff --git a/analysis_templates/ghent_template/__cf_module_name__/selection/default.py b/analysis_templates/ghent_template/__cf_module_name__/selection/default.py index 1e237ff5b..e2ea6669b 100644 --- a/analysis_templates/ghent_template/__cf_module_name__/selection/default.py +++ b/analysis_templates/ghent_template/__cf_module_name__/selection/default.py @@ -28,15 +28,21 @@ from __cf_short_name_lc__.selection.stats import __cf_short_name_lc___increment_stats from __cf_short_name_lc__.selection.trigger import trigger_selection +# only numpy and awkward are okay to import globally np = maybe_import("numpy") ak = maybe_import("awkward") -coffea = maybe_import("coffea") -maybe_import("coffea.nanoevents.methods.nanoaod") + +# do not import coffea globally! Do this inside the function +# coffea = maybe_import("coffea") +# maybe_import("coffea.nanoevents.methods.nanoaod") logger = law.logger.get_logger(__name__) def TetraVec(arr: ak.Array) -> ak.Array: + import coffea + import coffea.nanoevents.methods.nanoaod + TetraVec = ak.zip({"pt": arr.pt, "eta": arr.eta, "phi": arr.phi, "mass": arr.mass}, with_name="PtEtaPhiMLorentzVector", behavior=coffea.nanoevents.methods.vector.behavior) diff --git a/analysis_templates/ghent_template/__cf_module_name__/selection/objects.py b/analysis_templates/ghent_template/__cf_module_name__/selection/objects.py index 568350b58..86a3952b5 100644 --- a/analysis_templates/ghent_template/__cf_module_name__/selection/objects.py +++ b/analysis_templates/ghent_template/__cf_module_name__/selection/objects.py @@ -10,7 +10,7 @@ from columnflow.util import maybe_import, four_vec from columnflow.columnar_util import set_ak_column from columnflow.selection import Selector, SelectionResult, selector -from columnflow.reduction.util import masked_sorted_indices +from columnflow.columnar_util import sorted_indices_from_mask ak = maybe_import("awkward") @@ -53,7 +53,7 @@ def muon_object( steps={}, objects={ "Muon": { - "Muon": masked_sorted_indices(mu_mask, muon.pt) + "Muon": sorted_indices_from_mask(mu_mask, muon.pt) } }, ) @@ -108,7 +108,7 @@ def electron_object( steps={}, objects={ "Electron": { - "Electron": masked_sorted_indices(e_mask, electron.pt) + "Electron": sorted_indices_from_mask(e_mask, electron.pt) } }, ) @@ -142,7 +142,7 @@ def jet_object( (dR_mask) ) - jet_indices = masked_sorted_indices(jet_mask, events.Jet.pt) + jet_indices = sorted_indices_from_mask(jet_mask, events.Jet.pt) n_jets = ak.sum(jet_mask, axis=-1) return events, SelectionResult( diff --git a/bin/cf_inspect b/bin/cf_inspect index 3c0a01f08..2b35c5449 100755 --- a/bin/cf_inspect +++ b/bin/cf_inspect @@ -1,15 +1,25 @@ #!/bin/sh action () { + # local variables local shell_is_zsh="$( [ -z "${ZSH_VERSION}" ] && echo "false" || echo "true" )" local this_file="$( ${shell_is_zsh} && echo "${(%):-%x}" || echo "${BASH_SOURCE[0]}" )" local this_dir="$( cd "$( dirname "${this_file}" )" && pwd )" + # check arguments # [ "$#" -eq 0 ] && { # echo "ERROR: at least one file must be provided" # return 1 # } - cf_sandbox venv_columnar_dev python "${this_dir}/cf_inspect.py" "$@" + # determine the sandbox to use + local cf_inspect_sandbox="${CF_INSPECT_SANDBOX:-venv_columnar_dev}" + + # run the inspection script, potentially switching to the sandbox if not already in it + if [ "${CF_VENV_NAME}" = "${cf_inspect_sandbox}" ]; then + python "${this_dir}/cf_inspect.py" "$@" + else + cf_sandbox "${cf_inspect_sandbox}" python "${this_dir}/cf_inspect.py" "$@" + fi } action "$@" diff --git a/bin/cf_inspect.py b/bin/cf_inspect.py index aeb5606d2..8e5465508 100644 --- a/bin/cf_inspect.py +++ b/bin/cf_inspect.py @@ -13,60 +13,74 @@ import pickle import awkward as ak -import coffea.nanoevents -import uproot import numpy as np # noqa from columnflow.util import ipython_shell from columnflow.types import Any -def _load_json(fname: str) -> Any: +def _load_json(fname: str, **kwargs) -> Any: with open(fname, "r") as fobj: return json.load(fobj) -def _load_pickle(fname: str) -> Any: +def _load_pickle(fname: str, **kwargs) -> Any: with open(fname, "rb") as fobj: return pickle.load(fobj) -def _load_parquet(fname: str) -> ak.Array: +def _load_parquet(fname: str, **kwargs) -> ak.Array: return ak.from_parquet(fname) -def _load_nano_root(fname: str) -> ak.Array: +def _load_nano_root(fname: str, treepath: str | None = None, **kwargs) -> ak.Array: + import uproot + import coffea.nanoevents + source = uproot.open(fname) + + # get the default treepath + if treepath is None: + for treepath in ["events", "Events"] + list(source.keys()): + treepath = treepath.split(";", 1)[0] + if treepath in source and isinstance(source[treepath], uproot.TTree): + print(f"using treepath '{treepath}' in root file {fname}") + break + else: + raise ValueError(f"no default treepath determined in {fname}") try: return coffea.nanoevents.NanoEventsFactory.from_root( source, + treepath=treepath, + delayed=False, runtime_cache=None, persistent_cache=None, ).events() except: return uproot.open(fname) - -def _load_h5(fname: str): - import h5py - return h5py.File(fname, "r") + return coffea.nanoevents.NanoEventsFactory.from_root( + source, + treepath=treepath, + mode="eager", + runtime_cache=None, + persistent_cache=None, + ).events() -def load(fname: str) -> Any: +def load(fname: str, **kwargs) -> Any: """ Load file contents based on file extension. """ basename, ext = os.path.splitext(fname) if ext == ".pickle": - return _load_pickle(fname) + return _load_pickle(fname, **kwargs) if ext == ".parquet": - return _load_parquet(fname) + return _load_parquet(fname, **kwargs) if ext == ".root": - return _load_nano_root(fname) + return _load_nano_root(fname, **kwargs) if ext == ".json": - return _load_json(fname) - if ext in [".h5", ".hdf5"]: - return _load_h5(fname) + return _load_json(fname, **kwargs) raise NotImplementedError(f"no loader implemented for extension '{ext}'") @@ -101,18 +115,22 @@ def list_content(data: Any) -> None: ap.add_argument("files", metavar="FILE", nargs="*", help="one or more supported files") ap.add_argument("--events", "-e", action="store_true", help="assume files to contain event info") ap.add_argument("--hists", "-h", action="store_true", help="assume files to contain histograms") + ap.add_argument("--treepath", "-t", type=str, help="name of the tree in ROOT files") ap.add_argument("--list", "-l", action="store_true", help="list contents of the loaded file") ap.add_argument("--help", action="help", help="show this help message and exit") args = ap.parse_args() - objects = [load(fname) for fname in args.files] + load_kwargs = { + "treepath": args.treepath, + } + objects = [load(fname, **load_kwargs) for fname in args.files] if len(objects) == 1: objects = objects[0] print("file content loaded into variable 'objects'") # interpret data - intepreted = objects + interpreted = objects if args.events: # preload common packages import awkward as ak # noqa diff --git a/bin/cf_remove_tmp b/bin/cf_remove_tmp index 0d9ba39f3..d5deb67d3 100755 --- a/bin/cf_remove_tmp +++ b/bin/cf_remove_tmp @@ -24,6 +24,8 @@ cf_remove_tmp() { fi # get the directory + local prompt + local confirm local tmp_dir="$( law config target.tmp_dir )" local ret="$?" if [ "${ret}" != "0" ]; then @@ -35,12 +37,37 @@ cf_remove_tmp() { elif [ ! -d "${tmp_dir}" ]; then >&2 echo "'law config target.tmp_dir' is not a directory" return "3" + elif [ -z "${LAW_TARGET_TMP_DIR}" ] && [ "$( cd "${tmp_dir}" && pwd )" = "${PWD}" ]; then + prompt="'law config target.tmp_dir' reports that the tmp directory is set to the current working directory '${PWD}'. Continue? (y/n) " + read -rp "${prompt}" confirm + case "${confirm}" in + [Yy]) + ;; + *) + >&2 echo "canceled" + return "4" + ;; + esac fi - # remove all files and directories in tmp_dir owned by the user + # define the search pattern local pattern="luigi-tmp-*" [ "${mode}" = "all" ] && pattern="*" - find "${tmp_dir}" -maxdepth 1 -name "${pattern}" -user "$( id -u )" -exec rm -r "{}" \; + + # ask for confirmation + prompt="Are you sure you want to delete all files in path \"${tmp_dir}\" matching \"${pattern}\"? (y/n) " + read -rp "${prompt}" confirm + case "${confirm}" in + [Yy]) + # remove all files and directories in tmp_dir owned by the user + echo "deleting files..." + find "${tmp_dir}" -maxdepth 1 -name "${pattern}" -user "$( id -u )" -print -exec rm -r "{}" \; + ;; + *) + >&2 echo "canceled" + return "4" + ;; + esac } cf_remove_tmp "$@" diff --git a/columnflow/__init__.py b/columnflow/__init__.py index dda0895c7..19177c644 100644 --- a/columnflow/__init__.py +++ b/columnflow/__init__.py @@ -6,6 +6,7 @@ import os import re +import time import logging import law @@ -17,7 +18,7 @@ ) -logger = logging.getLogger(__name__) +logger = logging.getLogger(f"{__name__}_module_loader") # version info m = re.match(r"^(\d+)\.(\d+)\.(\d+)(-.+)?$", __version__) @@ -79,62 +80,59 @@ for fs in law.config.get_expanded("outputs", "wlcg_file_systems", [], split_csv=True) ] - # initialize producers, calibrators, selectors, categorizers, ml models and stat models + # initialize producers, calibrators, selectors, reducers, categorizers, ml models, hist producers and stat models from columnflow.util import maybe_import + def load(module, group): + t0 = time.perf_counter() + maybe_import(module) + duration = law.util.human_duration(seconds=time.perf_counter() - t0) + logger.debug(f"loaded {group} module '{module}', took {duration}") + import columnflow.calibration # noqa if law.config.has_option("analysis", "calibration_modules"): for m in law.config.get_expanded("analysis", "calibration_modules", [], split_csv=True): - logger.debug(f"loading calibration module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "calibration") import columnflow.selection # noqa if law.config.has_option("analysis", "selection_modules"): for m in law.config.get_expanded("analysis", "selection_modules", [], split_csv=True): - logger.debug(f"loading selection module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "selection") import columnflow.reduction # noqa if law.config.has_option("analysis", "reduction_modules"): for m in law.config.get_expanded("analysis", "reduction_modules", [], split_csv=True): - logger.debug(f"loading reduction module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "reduction") import columnflow.production # noqa if law.config.has_option("analysis", "production_modules"): for m in law.config.get_expanded("analysis", "production_modules", [], split_csv=True): - logger.debug(f"loading production module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "production") import columnflow.histogramming # noqa if law.config.has_option("analysis", "hist_production_modules"): for m in law.config.get_expanded("analysis", "hist_production_modules", [], split_csv=True): - logger.debug(f"loading hist production module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "hist production") import columnflow.categorization # noqa if law.config.has_option("analysis", "categorization_modules"): for m in law.config.get_expanded("analysis", "categorization_modules", [], split_csv=True): - logger.debug(f"loading categorization module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "categorization") import columnflow.ml # noqa if law.config.has_option("analysis", "ml_modules"): for m in law.config.get_expanded("analysis", "ml_modules", [], split_csv=True): - logger.debug(f"loading ml module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "ml") import columnflow.inference # noqa if law.config.has_option("analysis", "inference_modules"): for m in law.config.get_expanded("analysis", "inference_modules", [], split_csv=True): - logger.debug(f"loading inference module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "inference") # preload all task modules so that task parameters are globally known and accepted if law.config.has_section("modules"): for m in law.config.options("modules"): - logger.debug(f"loading task module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "task") # cleanup del m diff --git a/columnflow/__version__.py b/columnflow/__version__.py index d3a0772fb..bf1b3f34b 100644 --- a/columnflow/__version__.py +++ b/columnflow/__version__.py @@ -20,8 +20,9 @@ "Tobias Kramer", "Matthias Schroeder", "Johannes Lange", + "Ana Andrade", ] __contact__ = "https://github.com/columnflow/columnflow" __license__ = "BSD-3-Clause" __status__ = "Development" -__version__ = "0.2.4" +__version__ = "0.3.0" diff --git a/columnflow/calibration/__init__.py b/columnflow/calibration/__init__.py index 276e22c6d..f0ed046bf 100644 --- a/columnflow/calibration/__init__.py +++ b/columnflow/calibration/__init__.py @@ -8,18 +8,55 @@ import inspect -from columnflow.types import Callable +import law + from columnflow.util import DerivableMeta from columnflow.columnar_util import TaskArrayFunction +from columnflow.types import Callable, Sequence, Any + + +class TaskArrayFunctionWithCalibratorRequirements(TaskArrayFunction): + + require_calibrators: Sequence[str] | set[str] | None = None + + def _req_calibrator(self, task: law.Task, calibrator: str) -> Any: + # hook to customize how required calibrators are requested + from columnflow.tasks.calibration import CalibrateEvents + return CalibrateEvents.req_other_calibrator(task, calibrator=calibrator) + def requires_func(self, task: law.Task, reqs: dict, **kwargs) -> None: + # no requirements for workflows in pilot mode + if callable(getattr(task, "is_workflow", None)) and task.is_workflow() and getattr(task, "pilot", False): + return -class Calibrator(TaskArrayFunction): + # add required calibrators when set + if (calibs := self.require_calibrators): + reqs["required_calibrators"] = {calib: self._req_calibrator(task, calib) for calib in calibs} + + def setup_func( + self, + task: law.Task, + reqs: dict, + inputs: dict, + reader_targets: law.util.InsertableDict, + **kwargs, + ) -> None: + if "required_calibrators" in inputs: + for calib, inp in inputs["required_calibrators"].items(): + reader_targets[f"required_calibrator_{calib}"] = inp["columns"] + + +class Calibrator(TaskArrayFunctionWithCalibratorRequirements): """ Base class for all calibrators. """ exposed = True + # register attributes for arguments accepted by decorator + mc_only: bool = False + data_only: bool = False + @classmethod def calibrator( cls, @@ -27,25 +64,26 @@ def calibrator( bases: tuple = (), mc_only: bool = False, data_only: bool = False, + require_calibrators: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ - Decorator for creating a new :py:class:`~.Calibrator` subclass with additional, optional - *bases* and attaching the decorated function to it as ``call_func``. + Decorator for creating a new :py:class:`~.Calibrator` subclass with additional, optional *bases* and attaching + the decorated function to it as ``call_func``. - When *mc_only* (*data_only*) is *True*, the calibrator is skipped and not considered by - other calibrators, selectors and producers in case they are evalauted on a - :py:class:`order.Dataset` (using the :py:attr:`dataset_inst` attribute) whose ``is_mc`` - (``is_data``) attribute is *False*. + When *mc_only* (*data_only*) is *True*, the calibrator is skipped and not considered by other calibrators, + selectors and producers in case they are evalauted on a :py:class:`order.Dataset` (using the + :py:attr:`dataset_inst` attribute) whose ``is_mc`` (``is_data``) attribute is *False*. All additional *kwargs* are added as class members of the new subclasses. :param func: Function to be wrapped and integrated into new :py:class:`Calibrator` class. :param bases: Additional bases for the new :py:class:`Calibrator`. - :param mc_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on - Monte Carlo simulation and skipped for real data. - :param data_only: Boolean flag indicating that this :py:class:`Calibrator` should only run - on real data and skipped for Monte Carlo simulation. + :param mc_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on Monte Carlo + simulation and skipped for real data. + :param data_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on real data and + skipped for Monte Carlo simulation. + :param require_calibrators: Sequence of names of other calibrators to add to the requirements. :return: New :py:class:`Calibrator` subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -55,6 +93,7 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, + "require_calibrators": require_calibrators, } # get the module name diff --git a/columnflow/calibration/cms/egamma.py b/columnflow/calibration/cms/egamma.py index fc31a289e..54993bf01 100644 --- a/columnflow/calibration/cms/egamma.py +++ b/columnflow/calibration/cms/egamma.py @@ -1,649 +1,242 @@ # coding: utf-8 """ -Egamma energy correction methods. -Source: https://twiki.cern.ch/twiki/bin/view/CMS/EgammSFandSSRun3#Scale_And_Smearings_Correctionli +CMS-specific calibrators applying electron and photon energy scale and smearing. + +1. Scale corrections are applied to data. +2. Resolution smearing is applied to simulation. +3. Both scale and resolution uncertainties are applied to simulation. + +Resources: + - https://twiki.cern.ch/twiki/bin/viewauth/CMS/EgammSFandSSRun3#Scale_And_Smearings_Correctionli + - https://egammapog.docs.cern.ch/Run3/SaS + - https://cms-analysis-corrections.docs.cern.ch/corrections_era/Run3-22CDSep23-Summer22-NanoAODv12/EGM/2025-10-22 """ from __future__ import annotations -import abc import functools +import dataclasses + import law -from dataclasses import dataclass, field from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import ak_random from columnflow.util import maybe_import, load_correction_set, DotDict -from columnflow.columnar_util import set_ak_column, flat_np_view, ak_copy, optional_column +from columnflow.columnar_util import TAFConfig, set_ak_column, full_like from columnflow.types import Any ak = maybe_import("awkward") np = maybe_import("numpy") +logger = law.logger.get_logger(__name__) + # helper set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) -@dataclass -class EGammaCorrectionConfig: - correction_set: str - value_type: str - uncertainty_type: str - compound: bool = False - corrector_kwargs: dict[str, Any] = field(default_factory=dict) - - -class egamma_scale_corrector(Calibrator): - - with_uncertainties = True - """Switch to control whether uncertainties are calculated.""" - - @property - @abc.abstractmethod - def source_field(self) -> str: - """Fields required for the current calibrator.""" - ... - - @abc.abstractmethod - def get_correction_file(self, external_files: law.FileTargetCollection) -> law.LocalFileTarget: - """Function to retrieve the correction file from the external files. - - :param external_files: File target containing the files as requested - in the current config instance under ``config_inst.x.external_files`` - """ - ... - - @abc.abstractmethod - def get_scale_config(self) -> EGammaCorrectionConfig: - """Function to retrieve the configuration for the photon energy correction.""" - ... - - def call_func(self, events: ak.Array, **kwargs) -> ak.Array: - """ - Apply energy corrections to EGamma objects in the events array. There are two types of implementations: standard - and Et dependent. - For Run2 the standard implementation is used, while for Run3 the Et dependent is recommended by the EGammaPog: - https://twiki.cern.ch/twiki/bin/viewauth/CMS/EgammSFandSSRun3?rev=41 - The Et dependendent recipe follows the example given in: - https://gitlab.cern.ch/cms-nanoAOD/jsonpog-integration/-/blob/66f581d0549e8d67fc55420d8bba15c9369fff7c/examples/egmScaleAndSmearingExample.py - - Requires an external file in the config under ``electron_ss``. Example: - - .. code-block:: python - - cfg.x.external_files = DotDict.wrap({ - "electron_ss": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-120c4271/POG/EGM/2022_Summer22//electronSS_EtDependent.json.gz", # noqa - }) - - The pairs of correction set, value and uncertainty type names, and if a compound method is used should be configured using the :py:class:`EGammaCorrectionConfig` as an - auxiliary entry in the config: - - .. code-block:: python - - cfg.x.eec = EGammaCorrectionConfig( - correction_set="EGMScale_Compound_Ele_2022preEE", - value_type="scale", - uncertainty_type="escale", - compound=True, - ) - - Derivatives of this base class require additional member variables and functions: - - - *source_field*: The field name of the EGamma objects in the events array (i.e. `Electron` - or `Photon`). - - *get_correction_file*: Function to retrieve the correction file, e.g.from - the list, of external files in the current `config_inst`. - - *get_scale_config*: Function to retrieve the configuration for the energy correction. - This config must be an instance of - :py:class:`~columnflow.calibration.cms.egamma.EGammaCorrectionConfig`. - - If no raw pt (i.e., pt before any corrections) is available, use the nominal pt. The - correction tool only supports flat arrays, so inputs are converted to a flat numpy view - first. Corrections are always applied to the raw pt, which is important if more than one - correction is applied in a row. The final corrections must be applied to the current pt. - - If :py:attr:`with_uncertainties` is set to `True`, the scale uncertainties are calculated. - The scale uncertainties are only available for simulated data. - - :param events: The events array containing EGamma objects. - :return: The events array with applied scale corrections. - - :notes: - - Varied corrections are only applied to Monte Carlo (MC) data. - - EGamma energy correction is only applied to real data. - - Changes are applied to the views and directly propagate to the original awkward - arrays. - """ - # if no raw pt (i.e. pt for any corrections) is available, use the nominal pt - if "rawPt" not in events[self.source_field].fields: - events = set_ak_column_f32(events, f"{self.source_field}.rawPt", events[self.source_field].pt) - - # the correction tool only supports flat arrays, so convert inputs to flat np view first - # corrections are always applied to the raw pt - this is important if more than - # one correction is applied in a row - pt_eval = flat_np_view(events[self.source_field].rawPt, axis=1) - - # the final corrections must be applied to the current pt though - pt_application = flat_np_view(events[self.source_field].pt, axis=1) - - broadcasted_run = ak.broadcast_arrays(events[self.source_field].pt, events.run) - run = flat_np_view(broadcasted_run[1], axis=1) - gain = flat_np_view(events[self.source_field].seedGain, axis=1) - sceta = flat_np_view(events[self.source_field].superclusterEta, axis=1) - r9 = flat_np_view(events[self.source_field].r9, axis=1) - - # prepare arguments - # (energy is part of the LorentzVector behavior) - variable_map = { - "et": pt_eval, - "eta": sceta, - "gain": gain, - "r9": r9, - "run": run, - "seedGain": gain, - "pt": pt_eval, - "AbsScEta": np.abs(sceta), - "ScEta": sceta, - **self.scale_config.corrector_kwargs, - } - args = tuple( - variable_map[inp.name] for inp in self.scale_corrector.inputs - if inp.name in variable_map - ) - - # varied corrections are only applied to MC - if self.with_uncertainties and self.dataset_inst.is_mc: - scale_uncertainties = self.scale_corrector.evaluate(self.scale_config.uncertainty_type, *args) - scales_up = (1 + scale_uncertainties) - scales_down = (1 - scale_uncertainties) - - for (direction, scales) in [("up", scales_up), ("down", scales_down)]: - # copy pt and mass - pt_varied = ak_copy(events[self.source_field].pt) - pt_view = flat_np_view(pt_varied, axis=1) - - # apply the scale variation - pt_view *= scales - - # save columns - postfix = f"scale_{direction}" - events = set_ak_column_f32(events, f"{self.source_field}.pt_{postfix}", pt_varied) - - # apply the nominal correction - # note: changes are applied to the views and directly propagate to the original ak arrays - # and do not need to be inserted into the events chunk again - # EGamma energy correction is ONLY applied to DATA - if self.dataset_inst.is_data: - scales_nom = self.scale_corrector.evaluate(self.scale_config.value_type, *args) - pt_application *= scales_nom - - return events - - def init_func(self, **kwargs) -> None: - """Function to initialize the calibrator. - - Sets the required and produced columns for the calibrator. - """ - self.uses |= { - # nano columns - f"{self.source_field}.{{seedGain,pt,eta,phi,superclusterEta,r9}}", - "run", - optional_column(f"{self.source_field}.rawPt"), - } - self.produces |= { - f"{self.source_field}.pt", - optional_column(f"{self.source_field}.rawPt"), - } - - # if we do not calculate uncertainties, this module - # should only run on observed DATA - self.data_only = not self.with_uncertainties - - # add columns with unceratinties if requested - # photon scale _uncertainties_ are only available for MC - if self.with_uncertainties and self.dataset_inst.is_mc: - self.produces |= {f"{self.source_field}.pt_scale_{{up,down}}"} - - def requires_func(self, task: law.Task, reqs: dict[str, DotDict[str, Any]], **kwargs) -> None: - """Function to add necessary requirements. - - This function add the :py:class:`~columnflow.tasks.external.BundleExternalFiles` - task to the requirements. - - :param reqs: Dictionary of requirements. - """ - if "external_files" in reqs: - return - - from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(task) - - def setup_func( - self, - task: law.Task, - reqs: dict[str, DotDict[str, Any]], - inputs: dict[str, Any], - reader_targets: law.util.InsertableDict, - **kwargs, - ) -> None: - """Setup function before event chunk loop. - - This function loads the correction file and sets up the correction tool. - Additionally, the *scale_config* is retrieved. - - :param reqs: Dictionary with resolved requirements. - :param inputs: Dictionary with inputs (not used). - :param reader_targets: Dictionary for optional additional columns to load - """ - self.scale_config = self.get_scale_config() - # create the egamma corrector - corr_file = self.get_correction_file(reqs["external_files"].files) - # init and extend the correction set - corr_set = load_correction_set(corr_file) - if self.scale_config.compound: - corr_set = corr_set.compound - self.scale_corrector = corr_set[self.scale_config.correction_set] - - -class egamma_resolution_corrector(Calibrator): - - with_uncertainties = True - """Switch to control whether uncertainties are calculated.""" - - # smearing of the energy resolution is only applied to MC - mc_only = True - """This calibrator is only applied to simulated data.""" - - deterministic_seed_index = -1 - """ use deterministic seeds for random smearing and - take the "index"-th random number per seed when not -1 +@dataclasses.dataclass +class EGammaCorrectionConfig(TAFConfig): """ + Container class to describe energy scaling and smearing configurations. Example: - @property - @abc.abstractmethod - def source_field(self) -> str: - """Fields required for the current calibrator.""" - ... - - @abc.abstractmethod - def get_correction_file(self, external_files: law.FileTargetCollection) -> law.LocalFile: - """Function to retrieve the correction file from the external files. - - :param external_files: File target containing the files as requested - in the current config instance under ``config_inst.x.external_files`` - """ - ... - - @abc.abstractmethod - def get_resolution_config(self) -> EGammaCorrectionConfig: - """Function to retrieve the configuration for the photon energy correction.""" - ... - - def call_func(self, events: ak.Array, **kwargs) -> ak.Array: - """ - Apply energy resolution corrections to EGamma objects in the events array. - - There are two types of implementations: standard and Et dependent. For Run2 the standard - implementation is used, while for Run3 the Et dependent is recommended by the EGammaPog: - https://twiki.cern.ch/twiki/bin/viewauth/CMS/EgammSFandSSRun3?rev=41 The Et dependendent - recipe follows the example given in: - https://gitlab.cern.ch/cms-nanoAOD/jsonpog-integration/-/blob/66f581d0549e8d67fc55420d8bba15c9369fff7c/examples/egmScaleAndSmearingExample.py - - Requires an external file in the config under ``electron_ss``. Example: - - .. code-block:: python - - cfg.x.external_files = DotDict.wrap({ - "electron_ss": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-120c4271/POG/EGM/2022_Summer22/electronSS_EtDependent.json.gz", # noqa - }) - - The pairs of correction set, value and uncertainty type names, and if a compound method is used should be configured using the :py:class:`EGammaCorrectionConfig` as an - auxiliary entry in the config: - - .. code-block:: python - - cfg.x.eec = EGammaCorrectionConfig( - correction_set="EGMSmearAndSyst_ElePTsplit_2022preEE", - value_type="smear", - uncertainty_type="esmear", - ) - - Derivatives of this base class require additional member variables and functions: - - - *source_field*: The field name of the EGamma objects in the events array (i.e. `Electron` or `Photon`). - - *get_correction_file*: Function to retrieve the correction file, e.g. - from the list of external files in the current `config_inst`. - - *get_resolution_config*: Function to retrieve the configuration for the energy resolution correction. - This config must be an instance of :py:class:`~columnflow.calibration.cms.egamma.EGammaCorrectionConfig`. - - If no raw pt (i.e., pt before any corrections) is available, use the nominal pt. - The correction tool only supports flat arrays, so inputs are converted to a flat numpy view first. - Corrections are always applied to the raw pt, which is important if more than one correction is applied in a - row. The final corrections must be applied to the current pt. - - If :py:attr:`with_uncertainties` is set to `True`, the resolution uncertainties are calculated. - - If :py:attr:`deterministic_seed_index` is set to a value greater than or equal to 0, deterministic seeds - are used for random smearing. The "index"-th random number per seed is taken for the nominal resolution - correction. The "index+1"-th random number per seed is taken for the up variation and the "index+2"-th random - number per seed is taken for the down variation. - - :param events: The events array containing EGamma objects. - :return: The events array with applied resolution corrections. - - :notes: - - Energy resolution correction are only to be applied to simulation. - - Changes are applied to the views and directly propagate to the original awkward arrays. - """ - - # if no raw pt (i.e. pt for any corrections) is available, use the nominal pt - if "rawPt" not in events[self.source_field].fields: - events = set_ak_column_f32(events, f"{self.source_field}.rawPt", ak_copy(events[self.source_field].pt)) - - # the correction tool only supports flat arrays, so convert inputs to flat np view first - sceta = flat_np_view(events[self.source_field].superclusterEta, axis=1) - r9 = flat_np_view(events[self.source_field].r9, axis=1) - flat_seeds = flat_np_view(events[self.source_field].deterministic_seed, axis=1) - pt = flat_np_view(events[self.source_field].rawPt, axis=1) - - # prepare arguments - variable_map = { - "AbsScEta": np.abs(sceta), - "ScEta": sceta, # 2024 version - "eta": sceta, - "r9": r9, - "pt": pt, - **self.resolution_cfg.corrector_kwargs, - } - - args = tuple( - variable_map[inp.name] - for inp in self.resolution_corrector.inputs - if inp.name in variable_map - ) - - # calculate the smearing scale - # as mentioned in the example above, allows us to apply them directly to the MC simulation. - rho = self.resolution_corrector.evaluate(self.resolution_cfg.value_type, *args) - - # varied corrections - if self.with_uncertainties and self.dataset_inst.is_mc: - rho_unc = self.resolution_corrector.evaluate(self.resolution_cfg.uncertainty_type, *args) - random_normal_number = functools.partial(ak_random, 0, 1) - smearing_func = lambda rng_array, variation: rng_array * variation + 1 - - smearing_up = ( - smearing_func( - random_normal_number(flat_seeds, rand_func=self.deterministic_normal_up), - rho + rho_unc, - ) - if self.deterministic_seed_index >= 0 - else smearing_func( - random_normal_number(rand_func=np.random.Generator(np.random.SFC64(events.event.to_list())).normal), - rho + rho_unc, - ) - ) - - smearing_down = ( - smearing_func( - random_normal_number(flat_seeds, rand_func=self.deterministic_normal_down), - rho - rho_unc, - ) - if self.deterministic_seed_index >= 0 - else smearing_func( - random_normal_number(rand_func=np.random.Generator(np.random.SFC64(events.event.to_list())).normal), - rho - rho_unc, - ) - ) - - for (direction, smear) in [("up", smearing_up), ("down", smearing_down)]: - # copy pt and mass - pt_varied = ak_copy(events[self.source_field].pt) - pt_view = flat_np_view(pt_varied, axis=1) - - # apply the scale variation - # cast ak to numpy array for convenient usage of *= - pt_view *= smear.to_numpy() - - # save columns - postfix = f"res_{direction}" - events = set_ak_column_f32(events, f"{self.source_field}.pt_{postfix}", pt_varied) - - # apply the nominal correction - # note: changes are applied to the views and directly propagate to the original ak arrays - # and do not need to be inserted into the events chunk again - # EGamma energy resolution correction is ONLY applied to MC - if self.dataset_inst.is_mc: - smearing = ( - ak_random(1, rho, flat_seeds, rand_func=self.deterministic_normal) - if self.deterministic_seed_index >= 0 - else ak_random(1, rho, rand_func=np.random.Generator( - np.random.SFC64(events.event.to_list())).normal, - ) - ) - # the final corrections must be applied to the current pt though - pt = flat_np_view(events[self.source_field].pt, axis=1) - pt *= smearing.to_numpy() - - return events - - def init_func(self, **kwargs) -> None: - """Function to initialize the calibrator. - - Sets the required and produced columns for the calibrator. - """ - self.uses |= { - # nano columns - f"{self.source_field}.{{pt,eta,phi,superclusterEta,r9}}", - optional_column(f"{self.source_field}.rawPt"), - } - self.produces |= { - f"{self.source_field}.pt", - optional_column(f"{self.source_field}.rawPt"), - } - - # add columns with unceratinties if requested - if self.with_uncertainties and self.dataset_inst.is_mc: - self.produces |= {f"{self.source_field}.pt_res_{{up,down}}"} - - def requires_func(self, task: law.Task, reqs: dict[str, DotDict[str, Any]], **kwargs) -> None: - """Function to add necessary requirements. - - This function add the :py:class:`~columnflow.tasks.external.BundleExternalFiles` - task to the requirements. - - :param reqs: Dictionary of requirements. - """ - if "external_files" in reqs: - return - - from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(task) - - def setup_func( - self, - task: law.Task, - reqs: dict[str, DotDict[str, Any]], - inputs: dict[str, Any], - reader_targets: law.util.InsertableDict, - **kwargs, - ) -> None: - """Setup function before event chunk loop. - - This function loads the correction file and sets up the correction tool. - Additionally, the *resolution_config* is retrieved. - If :py:attr:`deterministic_seed_index` is set to a value greater than or equal to 0, - random generator based on object-specific random seeds are setup. - - :param reqs: Dictionary with resolved requirements. - :param inputs: Dictionary with inputs (not used). - :param reader_targets: Dictionary for optional additional columns to load - (not used). - """ - self.resolution_cfg = self.get_resolution_config() - # create the egamma corrector - corr_file = self.get_correction_file(reqs["external_files"].files) - corr_set = load_correction_set(corr_file) - if self.resolution_cfg.compound: - corr_set = corr_set.compound - self.resolution_corrector = corr_set[self.resolution_cfg.correction_set] - - # use deterministic seeds for random smearing if requested - if self.deterministic_seed_index >= 0: - idx = self.deterministic_seed_index - bit_generator = np.random.SFC64 - - def deterministic_normal(loc, scale, seed, idx_offset=0): - return np.asarray([ - np.random.Generator(bit_generator(_seed)).normal(_loc, _scale, size=idx + 1 + idx_offset)[-1] - for _loc, _scale, _seed in zip(loc, scale, seed) - ]) - self.deterministic_normal = functools.partial(deterministic_normal, idx_offset=0) - self.deterministic_normal_up = functools.partial(deterministic_normal, idx_offset=1) - self.deterministic_normal_down = functools.partial(deterministic_normal, idx_offset=2) - - -pec = egamma_scale_corrector.derive( - "pec", cls_dict={ - "source_field": "Photon", - "with_uncertainties": True, - "get_correction_file": (lambda self, external_files: external_files.photon_ss), - "get_scale_config": (lambda self: self.config_inst.x.pec), - }, -) + .. code-block:: python -per = egamma_resolution_corrector.derive( - "per", cls_dict={ - "source_field": "Photon", - "with_uncertainties": True, - # function to determine the correction file - "get_correction_file": (lambda self, external_files: external_files.photon_ss), - # function to determine the tec config - "get_resolution_config": (lambda self: self.config_inst.x.per), - }, -) + cfg.x.ess = EGammaCorrectionConfig( + scale_correction_set="Scale", + scale_compound=True, + smear_syst_correction_set="SmearAndSyst", + systs=["scale_down", "scale_up", "smear_down", "smear_up"], + ) + """ + scale_correction_set: str + smear_syst_correction_set: str + scale_compound: bool = False + smear_syst_compound: bool = False + systs: list[str] = dataclasses.field(default_factory=lambda: ["scale_down", "scale_up", "smear_down", "smear_up"]) + corrector_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) @calibrator( - uses={per, pec}, - produces={per, pec}, + exposed=False, + # used and produced columns are defined dynamically in init function with_uncertainties=True, - get_correction_file=None, - get_scale_config=None, - get_resolution_config=None, - deterministic_seed_index=-1, + collection_name=None, # to be set in derived classes to "Electron" or "Photon" + get_scale_smear_config=None, # to be set in derived classes + get_correction_file=None, # to be set in derived classes + deterministic_seed_index=-1, # use deterministic seeds for random smearing when >=0 + store_original=False, # if original columns (pt, energyErr) should be stored as "*_uncorrected" ) -def photons(self, events: ak.Array, **kwargs) -> ak.Array: - """ - Calibrator for photons. This calibrator runs the energy scale and resolution calibrators - for photons. - - Careful! Always apply resolution before scale corrections for MC. - """ +def _egamma_scale_smear(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: + # gather inputs + coll = events[self.collection_name] + variable_map = { + "run": events.run if ak.sum(ak.num(coll, axis=1), axis=0) else [], + "pt": coll.pt, + "ScEta": coll.superclusterEta, + "AbsScEta": abs(coll.superclusterEta), + "r9": coll.r9, + "seedGain": coll.seedGain, + **self.cfg.corrector_kwargs, + } + def get_inputs(corrector, **additional_variables): + _variable_map = variable_map | additional_variables + return (_variable_map[inp.name] for inp in corrector.inputs if inp.name in _variable_map) + + # apply scale correction to data + if self.dataset_inst.is_data: + # store uncorrected values before correcting + if self.store_original: + events = set_ak_column(events, f"{self.collection_name}.pt_scale_uncorrected", coll.pt) + events = set_ak_column(events, f"{self.collection_name}.energyErr_scale_uncorrected", coll.energyErr) + + # get scaled pt + scale = self.scale_corrector.evaluate("scale", *get_inputs(self.scale_corrector)) + pt_scaled = coll.pt * scale + + # get scaled energy error + smear = self.smear_syst_corrector.evaluate("smear", *get_inputs(self.smear_syst_corrector, pt=pt_scaled)) + energy_err_scaled = (((coll.energyErr)**2 + (coll.energy * smear)**2) * scale)**0.5 + + # store columns + events = set_ak_column_f32(events, f"{self.collection_name}.pt", pt_scaled) + events = set_ak_column_f32(events, f"{self.collection_name}.energyErr", energy_err_scaled) + + # apply smearing to MC if self.dataset_inst.is_mc: - events = self[per](events, **kwargs) - - if self.with_uncertainties or self.dataset_inst.is_data: - events = self[pec](events, **kwargs) + # store uncorrected values before correcting + if self.store_original: + events = set_ak_column(events, f"{self.collection_name}.pt_smear_uncorrected", coll.pt) + events = set_ak_column(events, f"{self.collection_name}.energyErr_smear_uncorrected", coll.energyErr) + + # compute random variables in the shape of the collection once + rnd_args = (full_like(coll.pt, 0.0), full_like(coll.pt, 1.0)) + if self.use_deterministic_seeds: + rnd_args += (coll.deterministic_seed,) + rand_func = self.deterministic_normal + else: + # TODO: bit generator could be configurable + rand_func = np.random.Generator(np.random.SFC64((events.event).to_list())).normal + rnd = ak_random(*rnd_args, rand_func=rand_func) + + # helper to compute smeared pt and energy error values given a syst + def apply_smearing(syst): + # get smeared pt + smear = self.smear_syst_corrector.evaluate(syst, *get_inputs(self.smear_syst_corrector)) + smear_factor = 1.0 + smear * rnd + pt_smeared = coll.pt * smear_factor + # get smeared energy error + energy_err_smeared = (((coll.energyErr)**2 + (coll.energy * smear)**2) * smear_factor)**0.5 + # return both + return pt_smeared, energy_err_smeared + + # compute and store columns + pt_smeared, energy_err_smeared = apply_smearing("smear") + events = set_ak_column_f32(events, f"{self.collection_name}.pt", pt_smeared) + events = set_ak_column_f32(events, f"{self.collection_name}.energyErr", energy_err_smeared) + + # apply scale and smearing uncertainties to MC + if self.with_uncertainties and self.cfg.systs: + for syst in self.cfg.systs: + # exact behavior depends on syst itself + if syst in {"scale_up", "scale_down"}: + # compute scale with smeared pt and apply muliplicatively to smeared values + scale = self.smear_syst_corrector.evaluate(syst, *get_inputs(self.smear_syst_corrector, pt=pt_smeared)) # noqa: E501 + events = set_ak_column_f32(events, f"{self.collection_name}.pt_{syst}", pt_smeared * scale) + events = set_ak_column_f32(events, f"{self.collection_name}.energyErr_{syst}", energy_err_smeared * scale) # noqa: E501 + + elif syst in {"smear_up", "smear_down"}: + # compute smearing variations on original variables with same method as above + pt_smeared_syst, energy_err_smeared_syst = apply_smearing(syst) + events = set_ak_column_f32(events, f"{self.collection_name}.pt_{syst}", pt_smeared_syst) + events = set_ak_column_f32(events, f"{self.collection_name}.energyErr_{syst}", energy_err_smeared_syst) # noqa: E501 + + else: + logger.error(f"{self.cls_name} calibrator received unknown systematic '{syst}', skipping") return events -@photons.pre_init -def photons_pre_init(self, **kwargs) -> None: - # forward argument to the producers - if pec not in self.deps_kwargs: - self.deps_kwargs[pec] = dict() - if per not in self.deps_kwargs: - self.deps_kwargs[per] = dict() - self.deps_kwargs[pec]["with_uncertainties"] = self.with_uncertainties - self.deps_kwargs[per]["with_uncertainties"] = self.with_uncertainties - - self.deps_kwargs[per]["deterministic_seed_index"] = self.deterministic_seed_index - if self.get_correction_file is not None: - self.deps_kwargs[pec]["get_correction_file"] = self.get_correction_file - self.deps_kwargs[per]["get_correction_file"] = self.get_correction_file - - if self.get_resolution_config is not None: - self.deps_kwargs[per]["get_resolution_config"] = self.get_resolution_config - if self.get_scale_config is not None: - self.deps_kwargs[pec]["get_scale_config"] = self.get_scale_config - - -photons_nominal = photons.derive("photons_nominal", cls_dict={"with_uncertainties": False}) - - -eer = egamma_resolution_corrector.derive( - "eer", cls_dict={ - "source_field": "Electron", - # calculation of superclusterEta for electrons requires the deltaEtaSC - "uses": {"Electron.deltaEtaSC"}, - "with_uncertainties": True, - # function to determine the correction file - "get_correction_file": (lambda self, external_files: external_files.electron_ss), - # function to determine the tec config - "get_resolution_config": (lambda self: self.config_inst.x.eer), +@_egamma_scale_smear.init +def _egamma_scale_smear_init(self: Calibrator, **kwargs) -> None: + # store the config + self.cfg = self.get_scale_smear_config() + + # update used columns + self.uses |= {"run", f"{self.collection_name}.{{pt,eta,phi,mass,energyErr,superclusterEta,r9,seedGain}}"} + + # update produced columns + if self.dataset_inst.is_data: + self.produces |= {f"{self.collection_name}.{{pt,energyErr}}"} + if self.store_original: + self.produces |= {f"{self.collection_name}.{{pt,energyErr}}_scale_uncorrected"} + else: + self.produces |= {f"{self.collection_name}.{{pt,energyErr}}"} + if self.store_original: + self.produces |= {f"{self.collection_name}.{{pt,energyErr}}_smear_uncorrected"} + if self.with_uncertainties: + for syst in self.cfg.systs: + self.produces |= {f"{self.collection_name}.{{pt,energyErr}}_{syst}"} + + +@_egamma_scale_smear.requires +def _egamma_scale_smear_requires(self, task: law.Task, reqs: dict[str, DotDict[str, Any]], **kwargs) -> None: + if "external_files" in reqs: + return + + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(task) + + +@_egamma_scale_smear.setup +def _egamma_scale_smear_setup( + self, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: + # get and load the correction file + corr_file = self.get_correction_file(reqs["external_files"].files) + corr_set = load_correction_set(corr_file) + + # setup the correctors + get_set = lambda set_name, compound: (corr_set.compound if compound else corr_set)[set_name] + self.scale_corrector = get_set(self.cfg.scale_correction_set, self.cfg.scale_compound) + self.smear_syst_corrector = get_set(self.cfg.smear_syst_correction_set, self.cfg.smear_syst_compound) + + # use deterministic seeds for random smearing if requested + self.use_deterministic_seeds = self.deterministic_seed_index >= 0 + if self.use_deterministic_seeds: + idx = self.deterministic_seed_index + bit_generator = np.random.SFC64 + + def _deterministic_normal(loc, scale, seed, idx_offset=0): + return np.asarray([ + np.random.Generator(bit_generator(_seed)).normal(_loc, _scale, size=idx + 1 + idx_offset)[-1] + for _loc, _scale, _seed in zip(loc, scale, seed) + ]) + + # each systematic is to be evaluated with the same random number so use a fixed offset + self.deterministic_normal = functools.partial(_deterministic_normal, idx_offset=0) + + +electron_scale_smear = _egamma_scale_smear.derive( + "electron_scale_smear", + cls_dict={ + "collection_name": "Electron", + "get_scale_smear_config": lambda self: self.config_inst.x.ess, + "get_correction_file": lambda self, external_files: external_files.electron_ss, }, ) -eec = egamma_scale_corrector.derive( - "eec", cls_dict={ - "source_field": "Electron", - # calculation of superclusterEta for electrons requires the deltaEtaSC - "uses": {"Electron.deltaEtaSC"}, - "with_uncertainties": True, - "get_correction_file": (lambda self, external_files: external_files.electron_ss), - "get_scale_config": (lambda self: self.config_inst.x.eec), +photon_scale_smear = _egamma_scale_smear.derive( + "photon_scale_smear", + cls_dict={ + "collection_name": "Photon", + "get_scale_smear_config": lambda self: self.config_inst.x.gss, + "get_correction_file": lambda self, external_files: external_files.photon_ss, }, ) - - -@calibrator( - uses={eer, eec}, - produces={eer, eec}, - with_uncertainties=True, - get_correction_file=None, - get_scale_config=None, - get_resolution_config=None, - deterministic_seed_index=-1, -) -def electrons(self, events: ak.Array, **kwargs) -> ak.Array: - """ - Calibrator for electrons. This calibrator runs the energy scale and resolution calibrators - for electrons. - - Careful! Always apply resolution before scale corrections for MC. - """ - if self.dataset_inst.is_mc: - events = self[eer](events, **kwargs) - - if self.with_uncertainties or self.dataset_inst.is_data: - events = self[eec](events, **kwargs) - - return events - - -@electrons.pre_init -def electrons_pre_init(self, **kwargs) -> None: - # forward argument to the producers - if eec not in self.deps_kwargs: - self.deps_kwargs[eec] = dict() - if eer not in self.deps_kwargs: - self.deps_kwargs[eer] = dict() - self.deps_kwargs[eec]["with_uncertainties"] = self.with_uncertainties - self.deps_kwargs[eer]["with_uncertainties"] = self.with_uncertainties - - self.deps_kwargs[eer]["deterministic_seed_index"] = self.deterministic_seed_index - if self.get_correction_file is not None: - self.deps_kwargs[eec]["get_correction_file"] = self.get_correction_file - self.deps_kwargs[eer]["get_correction_file"] = self.get_correction_file - - if self.get_resolution_config is not None: - self.deps_kwargs[eer]["get_resolution_config"] = self.get_resolution_config - if self.get_scale_config is not None: - self.deps_kwargs[eec]["get_scale_config"] = self.get_scale_config - - -electrons_nominal = photons.derive("electrons_nominal", cls_dict={"with_uncertainties": False}) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index 709f253bc..09f381770 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -4,20 +4,24 @@ Jet energy corrections and jet resolution smearing. """ +from __future__ import annotations + import functools import law -from columnflow.types import Any from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import ak_random, propagate_met, sum_transverse from columnflow.production.util import attach_coffea_behavior -from columnflow.util import maybe_import, DotDict, load_correction_set +from columnflow.util import UNSET, maybe_import, DotDict, load_correction_set from columnflow.columnar_util import set_ak_column, layout_ak_array, optional_column as optional +from columnflow.types import TYPE_CHECKING, Any np = maybe_import("numpy") ak = maybe_import("awkward") -correctionlib = maybe_import("correctionlib") +if TYPE_CHECKING: + correctionlib = maybe_import("correctionlib") + logger = law.logger.get_logger(__name__) @@ -228,6 +232,7 @@ def get_jec_config_default(self: Calibrator) -> DotDict: @calibrator( uses={ + "run", optional("fixedGridRhoFastjetAll"), optional("Rho.fixedGridRhoFastjetAll"), attach_coffea_behavior, @@ -584,6 +589,8 @@ def jec_setup( "CorrelationGroupFlavor", "CorrelationGroupUncorrelated", ], + # whether the JECs for data should be era-specific + "data_per_era": True, }, }) @@ -600,25 +607,36 @@ def jec_setup( jec_cfg = self.get_jec_config() def make_jme_keys(names, jec=jec_cfg, is_data=self.dataset_inst.is_data): - if is_data: + if is_data and jec.get("data_per_era", True): + if "data_per_era" not in jec: + logger.warning_once( + f"{id(self)}_depr_jec_config_data_per_era", + "config aux 'jec' does not contain key 'data_per_era'. " + "This may be due to an outdated config. Continuing under the assumption that " + "JEC keys for data are era-specific. " + "This assumption will be removed in future versions of " + "columnflow, so please adapt the config according to the " + "documentation to remove this warning and ensure future " + "compatibility of the code.", + ) jec_era = self.dataset_inst.get_aux("jec_era", None) # if no special JEC era is specified, infer based on 'era' if jec_era is None: - jec_era = "Run" + self.dataset_inst.get_aux("era") - elif jec_era == "": - return [ - f"{jec.campaign}_{jec.version}_DATA_{name}_{jec.jet_type}" - if is_data else - f"{jec.campaign}_{jec.version}_MC_{name}_{jec.jet_type}" - for name in names - ] - - return [ - f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{name}_{jec.jet_type}" - if is_data else - f"{jec.campaign}_{jec.version}_MC_{name}_{jec.jet_type}" - for name in names - ] + era = self.dataset_inst.get_aux("era", None) + if era is None: + raise ValueError( + "JEC data key is requested to be era dependent, but neither jec_era or era " + f"auxiliary is set for dataset {self.dataset_inst.name}.", + ) + jec_era = "Run" + era + + jme_key = f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{{name}}_{jec.jet_type}" + elif is_data: + jme_key = f"{jec.campaign}_{jec.version}_DATA_{{name}}_{jec.jet_type}" + else: # MC + jme_key = f"{jec.campaign}_{jec.version}_MC_{{name}}_{jec.jet_type}" + + return [jme_key.format(name=name) for name in names] jec_keys = make_jme_keys(jec_cfg.levels) jec_keys_subset_type1_met = make_jme_keys(jec_cfg.levels_for_type1_met) @@ -720,10 +738,12 @@ def get_jer_config_default(self: Calibrator) -> DotDict: get_jec_config=get_jec_config_default, # jec uncertainty sources to propagate jer to, defaults to config when empty jec_uncertainty_sources=None, + # MET uncertainty sources to propagate jer to, defaults to None when empty + met_uncertainty_sources=None, # whether gen jet matching should be performed relative to the nominal jet pt, or the jec varied values gen_jet_matching_nominal=False, # regions where stochastic smearing is applied - stochastic_smearing_mask=lambda self, jets: ak.ones_like(jets.pt, dtype=np.bool), + stochastic_smearing_mask=lambda self, jets: ak.ones_like(jets.pt, dtype=bool), ) def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: """ @@ -820,17 +840,17 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: # extract nominal pt resolution inputs = [variable_map[inp.name] for inp in self.evaluators["jer"].inputs] - jerpt = {jer_nom: ak_evaluate(self.evaluators["jer"], *inputs)} + jer = {jer_nom: ak_evaluate(self.evaluators["jer"], *inputs)} # for simplifications below, use the same values for jer variations - jerpt[jer_up] = jerpt[jer_nom] - jerpt[jer_down] = jerpt[jer_nom] + jer[jer_up] = jer[jer_nom] + jer[jer_down] = jer[jer_nom] # extract pt resolutions evaluted for jec uncertainties for jec_var in self.jec_variations: _variable_map = variable_map | {"JetPt": events[jet_name][f"pt_{jec_var}"]} inputs = [_variable_map[inp.name] for inp in self.evaluators["jer"].inputs] - jerpt[jec_var] = ak_evaluate(self.evaluators["jer"], *inputs) + jer[jec_var] = ak_evaluate(self.evaluators["jer"], *inputs) # extract scale factors jersf = {} @@ -847,8 +867,8 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: # array with all JER scale factor variations as an additional axis # (note: axis needs to be regular for broadcasting to work correctly) - jerpt = ak.concatenate( - [jerpt[v][..., None] for v in self.jer_variations + self.jec_variations], + jer = ak.concatenate( + [jer[v][..., None] for v in self.jer_variations + self.jec_variations], axis=-1, ) jersf = ak.concatenate( @@ -865,7 +885,7 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: # compute smearing factors (stochastic method) smear_factors_stochastic = ak.where( self.stochastic_smearing_mask(events[jet_name]), - 1.0 + random_normal * jerpt * add_smear, + 1.0 + random_normal * jer * add_smear, 1.0, ) @@ -897,7 +917,7 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: # test if matched gen jets are within 3 * resolution # (no check for Delta-R matching criterion; we assume this was done during nanoAOD production to get the genJetIdx) - is_matched_pt = np.abs(pt_relative_diff) < 3 * jerpt + is_matched_pt = np.abs(pt_relative_diff) < 3 * jer is_matched_pt = ak.fill_none(is_matched_pt, False) # masked values = no gen match # compute smearing factors (scaling method) @@ -928,7 +948,7 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: if self.propagate_met: jetsum_pt_before = {} jetsum_phi_before = {} - for postfix in self.postfixes: + for postfix in self.jet_postfixes: jetsum_pt_before[postfix], jetsum_phi_before[postfix] = sum_transverse( events[jet_name][f"pt{postfix}"], events[jet_name].phi, @@ -936,7 +956,7 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: # apply the smearing # (note: this requires that postfixes and smear_factors have the same order, but this should be the case) - for i, postfix in enumerate(self.postfixes): + for i, postfix in enumerate(self.jet_postfixes): pt_name = f"pt{postfix}" m_name = f"mass{postfix}" events = set_ak_column_f32(events, f"{jet_name}.{pt_name}", events[jet_name][pt_name] * smear_factors[..., i]) @@ -952,22 +972,27 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: events = set_ak_column_f32(events, f"{met_name}.phi_unsmeared", events[met_name].phi) # propagate per variation - for postfix in self.postfixes: + for postfix in self.met_postfixes: # get pt and phi of all jets after correcting + + jet_postfix = postfix if hasattr(events[jet_name], f"pt{postfix}") else "" + + # jet variation exists, use it jetsum_pt_after, jetsum_phi_after = sum_transverse( - events[jet_name][f"pt{postfix}"], + events[jet_name][f"pt{jet_postfix}"], events[jet_name].phi, ) # propagate changes to MET met_pt, met_phi = propagate_met( - jetsum_pt_before[postfix], - jetsum_phi_before[postfix], + jetsum_pt_before[jet_postfix], + jetsum_phi_before[jet_postfix], jetsum_pt_after, jetsum_phi_after, events[met_name][f"pt{postfix}"], events[met_name][f"phi{postfix}"], ) + events = set_ak_column_f32(events, f"{met_name}.pt{postfix}", met_pt) events = set_ak_column_f32(events, f"{met_name}.phi{postfix}", met_phi) @@ -1001,7 +1026,7 @@ def jer_init(self: Calibrator, **kwargs) -> None: # prepare jer variations and postfixes self.jer_variations = ["nom", "up", "down"] - self.postfixes = ["", "_jer_up", "_jer_down"] + [f"_{jec_var}" for jec_var in self.jec_variations] + self.jet_postfixes = ["", "_jer_up", "_jer_down"] + [f"_{jec_var}" for jec_var in self.jec_variations] # register used jet columns self.uses.add(f"{self.jet_name}.{{pt,eta,phi,mass,{self.gen_jet_idx_column}}}") @@ -1021,11 +1046,23 @@ def jer_init(self: Calibrator, **kwargs) -> None: if jec_sources: self.uses |= met_jec_columns + met_sources = self.met_uncertainty_sources or [] + self.met_variations = sum(([f"{unc}_up", f"{unc}_down"] for unc in met_sources), []) + self.met_postfixes = ["", "_jer_up", "_jer_down"] + \ + [f"_{jec_var}" for jec_var in self.jec_variations] + \ + [f"_{met_source}" for met_source in self.met_variations] + + if met_sources: + self.uses |= {f"{self.met_name}.{{pt,phi}}_{met_source}" for met_source in self.met_variations} + # register produced MET columns self.produces.add(f"{self.met_name}.{{pt,phi}}{{,_jer_up,_jer_down,_unsmeared}}") if jec_sources: self.produces |= met_jec_columns + if met_sources: + self.produces |= {f"{self.met_name}.{{pt,phi}}_{met_source}" for met_source in self.met_variations} + @jer.requires def jer_requires( @@ -1122,8 +1159,6 @@ def deterministic_normal(loc, scale, seed): # @calibrator( - uses={jec, jer}, - produces={jec, jer}, # name of the jet collection to smear jet_name="Jet", # name of the associated gen jet collection (for JER smearing) @@ -1144,32 +1179,35 @@ def jets(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: :param events: awkward array containing events to process """ # apply jet energy corrections - events = self[jec](events, **kwargs) + events = self[self.jec_cls](events, **kwargs) # apply jer smearing on MC only if self.dataset_inst.is_mc: - events = self[jer](events, **kwargs) + events = self[self.jer_cls](events, **kwargs) return events -@jets.pre_init -def jets_pre_init(self: Calibrator, **kwargs) -> None: - # forward argument to the producers - self.deps_kwargs[jec]["jet_name"] = self.jet_name - self.deps_kwargs[jer]["jet_name"] = self.jet_name - self.deps_kwargs[jer]["gen_jet_name"] = self.gen_jet_name - if self.propagate_met is not None: - self.deps_kwargs[jec]["propagate_met"] = self.propagate_met - self.deps_kwargs[jer]["propagate_met"] = self.propagate_met - if self.get_jec_file is not None: - self.deps_kwargs[jec]["get_jec_file"] = self.get_jec_file - if self.get_jec_config is not None: - self.deps_kwargs[jec]["get_jec_config"] = self.get_jec_config - if self.get_jer_file is not None: - self.deps_kwargs[jer]["get_jer_file"] = self.get_jer_file - if self.get_jer_config is not None: - self.deps_kwargs[jer]["get_jer_config"] = self.get_jer_config +@jets.init +def jets_init(self: Calibrator, **kwargs) -> None: + # create custom jec and jer calibrators, using the jet name as the identifying value + def get_attrs(attrs): + cls_dict = {} + for attr in attrs: + if (value := getattr(self, attr, UNSET)) is not UNSET: + cls_dict[attr] = value + return cls_dict + + jec_attrs = ["jet_name", "gen_jet_name", "propagate_met", "get_jec_file", "get_jec_config"] + self.jec_cls = jec.derive(f"jec_{self.jet_name}", cls_dict=get_attrs(jec_attrs)) + self.uses.add(self.jec_cls) + self.produces.add(self.jec_cls) + + if self.dataset_inst.is_mc: + jer_attrs = ["jet_name", "gen_jet_name", "propagate_met", "get_jer_file", "get_jer_config"] + self.jer_cls = jer.derive(f"jer_{self.jet_name}", cls_dict=get_attrs(jer_attrs)) + self.uses.add(self.jer_cls) + self.produces.add(self.jer_cls) # explicit calibrators for standard jet collections diff --git a/columnflow/calibration/cms/met.py b/columnflow/calibration/cms/met.py index 229b4c9cb..9774c8006 100644 --- a/columnflow/calibration/cms/met.py +++ b/columnflow/calibration/cms/met.py @@ -4,9 +4,14 @@ MET corrections. """ +from __future__ import annotations + +import functools +from dataclasses import dataclass, field + import law -from columnflow.calibration import Calibrator, calibrator +from columnflow.calibration import Calibrator from columnflow.util import maybe_import, load_correction_set, DotDict from columnflow.columnar_util import set_ak_column from columnflow.types import Any @@ -15,46 +20,84 @@ ak = maybe_import("awkward") -@calibrator( - uses={"run", "PV.npvs"}, - # name of the MET collection to calibrate - met_name="MET", +# helpers +set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) + + +class _met_phi_base(Calibrator): + """" + Common base class for MET phi calibrators. + """ + + exposed = False + # function to determine the correction file - get_met_file=(lambda self, external_files: external_files.met_phi_corr), + get_met_file = lambda self, external_files: external_files.met_phi_corr + # function to determine met correction config - get_met_config=(lambda self: self.config_inst.x.met_phi_correction_set), -) -def met_phi(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: + get_met_config = lambda self: self.config_inst.x.met_phi_correction + + def requires_func(self, task: law.Task, reqs: dict[str, DotDict[str, Any]], **kwargs) -> None: + if "external_files" in reqs: + return + + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(task) + + +# +# Run 2 implementation +# + +@dataclass +class METPhiConfigRun2: + correction_set_template = r"{variable}_metphicorr_pfmet_{data_source}" + met_name: str = "MET" + keep_uncorrected: bool = False + + +@_met_phi_base.calibrator(exposed=True) +def met_phi_run2(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: """ - Performs the MET phi (type II) correction using the - :external+correctionlib:doc:`index` for events there the - uncorrected MET pt is below the beam energy (extracted from ``config_inst.campaign.ecm * 0.5``). - Requires an external file in the config under ``met_phi_corr``: + Performs the MET phi (type II) correction using :external+correctionlib:doc:`index`. Events whose uncorrected MET pt + is below the beam energy (extracted from ``config_inst.campaign.ecm * 0.5``) are skipped. Requires an external file + in the config under ``met_phi_corr``: .. code-block:: python cfg.x.external_files = DotDict.wrap({ - "met_phi_corr": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/met.json.gz", # noqa + "met_phi_corr": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-406118ec/POG/JME/2022_Summer22EE/met_xyCorrections_2022_2022EE.json.gz", # noqa }) - *get_met_file* can be adapted in a subclass in case it is stored differently in the external - files. + *get_met_file* can be adapted in a subclass in case it is stored differently in the external files. - The name of the correction set should be present as an auxiliary entry in the config: + The calibrator should be configured with an :py:class:`METPhiConfigRun2` as an auxiliary entry in the config named + ``met_phi_correction``. *get_met_config* can be adapted in a subclass in case it is stored differently in the + config. Exemplary config entry: .. code-block:: python - cfg.x.met_phi_correction_set = "{variable}_metphicorr_pfmet_{data_source}" + from columnflow.calibration.cms.met import METPhiConfigRun2 + cfg.x.met_phi_correction = METPhiConfigRun2( + met_name="MET", + correction_set_template="{variable}_metphicorr_pfmet_{data_source}", + keep_uncorrected=False, + ) - where "variable" and "data_source" are placeholders that are inserted in the - calibrator setup :py:meth:`~.met_phi.setup_func`. - *get_met_correction_set* can be adapted in a subclass in case it is stored - differently in the config. + "variable" and "data_source" are placeholders that will be replace with "pt" or "phi", and the data source of the + current dataset, respectively. - :param events: awkward array containing events to process + Resources: + - https://twiki.cern.ch/twiki/bin/view/CMS/MissingETRun2Corrections?rev=79#xy_Shift_Correction_MET_phi_modu """ - # get Met columns - met = events[self.met_name] + # get met columns + met_name = self.met_config.met_name + met = events[met_name] + + # store uncorrected values if requested + if self.met_config.keep_uncorrected: + events = set_ak_column_f32(events, f"{met_name}.pt_metphi_uncorrected", met.pt) + events = set_ak_column_f32(events, f"{met_name}.phi_metphi_uncorrected", met.phi) # copy the intial pt and phi values corr_pt = np.array(met.pt, dtype=np.float32) @@ -76,37 +119,27 @@ def met_phi(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: corr_phi[mask] = self.met_phi_corrector.evaluate(*args) # save the corrected values - events = set_ak_column(events, f"{self.met_name}.pt", corr_pt, value_type=np.float32) - events = set_ak_column(events, f"{self.met_name}.phi", corr_phi, value_type=np.float32) + events = set_ak_column_f32(events, f"{met_name}.pt", corr_pt) + events = set_ak_column_f32(events, f"{met_name}.phi", corr_phi) return events -@met_phi.init -def met_phi_init(self: Calibrator, **kwargs) -> None: - """ - Initialize the :py:attr:`met_pt_corrector` and :py:attr:`met_phi_corrector` attributes. - """ - self.uses.add(f"{self.met_name}.{{pt,phi}}") - self.produces.add(f"{self.met_name}.{{pt,phi}}") +@met_phi_run2.init +def met_phi_run2_init(self: Calibrator, **kwargs) -> None: + self.met_config = self.get_met_config() + # set used columns + self.uses.update({"run", "PV.npvs", f"{self.met_config.met_name}.{{pt,phi}}"}) -@met_phi.requires -def met_phi_requires( - self: Calibrator, - task: law.Task, - reqs: dict[str, DotDict[str, Any]], - **kwargs, -) -> None: - if "external_files" in reqs: - return + # set produced columns + self.produces.add(f"{self.met_config.met_name}.{{pt,phi}}") + if self.met_config.keep_uncorrected: + self.produces.add(f"{self.met_config.met_name}.{{pt,phi}}_metphi_uncorrected") - from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(task) - -@met_phi.setup -def met_phi_setup( +@met_phi_run2.setup +def met_phi_run2_setup( self: Calibrator, task: law.Task, reqs: dict[str, DotDict[str, Any]], @@ -114,21 +147,11 @@ def met_phi_setup( reader_targets: law.util.InsertableDict, **kwargs, ) -> None: - """ - Load the correct met files using the :py:func:`from_string` method of the - :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` - function and apply the corrections as needed. - - :param reqs: Requirement dictionary for this :py:class:`~columnflow.calibration.Calibrator` - instance - :param inputs: Additional inputs, currently not used. - :param reader_targets: Additional targets, currently not used. - """ # create the pt and phi correctors met_file = self.get_met_file(reqs["external_files"].files) correction_set = load_correction_set(met_file) - name_tmpl = self.get_met_config() + name_tmpl = self.met_config.correction_set_template self.met_pt_corrector = correction_set[name_tmpl.format( variable="pt", data_source=self.dataset_inst.data_source, @@ -138,8 +161,150 @@ def met_phi_setup( data_source=self.dataset_inst.data_source, )] - # check versions - if self.met_pt_corrector.version not in (1,): - raise Exception(f"unsuppprted met pt corrector version {self.met_pt_corrector.version}") - if self.met_phi_corrector.version not in (1,): - raise Exception(f"unsuppprted met phi corrector version {self.met_phi_corrector.version}") + +# +# Run 3 implementation +# + +@dataclass +class METPhiConfig: + correction_set: str = "met_xy_corrections" + met_name: str = "PuppiMET" + met_type: str = "PuppiMET" + keep_uncorrected: bool = False + # variations (intrinsic method uncertainties) for pt and phi + pt_phi_variations: dict[str, str] | None = field(default_factory=lambda: { + "stat_xdn": "metphi_statx_down", + "stat_xup": "metphi_statx_up", + "stat_ydn": "metphi_staty_down", + "stat_yup": "metphi_staty_up", + }) + # other variations (external uncertainties) + variations: dict[str, str] | None = field(default_factory=lambda: { + "pu_dn": "minbias_xs_down", + "pu_up": "minbias_xs_up", + }) + + +@_met_phi_base.calibrator(exposed=True) +def met_phi(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: + """ + Performs the MET phi (type II) correction using :external+correctionlib:doc:`index`. Events whose uncorrected MET pt + is below the beam energy (extracted from ``config_inst.campaign.ecm * 0.5``) are skipped. Requires an external file + in the config under ``met_phi_corr``: + + .. code-block:: python + + cfg.x.external_files = DotDict.wrap({ + "met_phi_corr": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-406118ec/POG/JME/2022_Summer22EE/met_xyCorrections_2022_2022EE.json.gz", # noqa + }) + + *get_met_file* can be adapted in a subclass in case it is stored differently in the external files. + + The calibrator should be configured with an :py:class:`METPhiConfig` as an auxiliary entry in the config named + ``met_phi_correction``. *get_met_config* can be adapted in a subclass in case it is stored differently in the + config. Exemplary config entry: + + .. code-block:: python + + from columnflow.calibration.cms.met import METPhiConfig + cfg.x.met_phi_correction = METPhiConfig( + correction_set="met_xy_corrections", + met_name="PuppiMET", + met_type="PuppiMET", + keep_uncorrected=False, + # mappings of method variation to column (pt/phi) postfixes + pt_phi_variations={ + "stat_xdn": "metphi_statx_down", + "stat_xup": "metphi_statx_up", + "stat_ydn": "metphi_staty_down", + "stat_yup": "metphi_staty_up", + }, + variations={ + "pu_dn": "minbias_xs_down", + "pu_up": "minbias_xs_up", + }, + ) + """ + # get met + met_name = self.met_config.met_name + met = events[met_name] + + # store uncorrected values if requested + if self.met_config.keep_uncorrected: + events = set_ak_column_f32(events, f"{met_name}.pt_metphi_uncorrected", met.pt) + events = set_ak_column_f32(events, f"{met_name}.phi_metphi_uncorrected", met.phi) + + # correct only events where MET pt is below the expected beam energy + mask = met.pt < (0.5 * self.config_inst.campaign.ecm * 1000) # convert TeV to GeV + + # gather variables + variable_map = { + "met_type": self.met_config.met_type, + "epoch": f"{self.config_inst.campaign.x.year}{self.config_inst.campaign.x.postfix}", + "dtmc": "DATA" if self.dataset_inst.is_data else "MC", + "met_pt": ak.values_astype(met.pt[mask], np.float32), + "met_phi": ak.values_astype(met.phi[mask], np.float32), + "npvGood": ak.values_astype(events.PV.npvsGood[mask], np.float32), + } + + # evaluate pt and phi separately + for var in ["pt", "phi"]: + # remember initial values + vals_orig = np.array(met[var], dtype=np.float32) + # loop over general variations, then pt/phi variations + # (needed since the JME correction file is inconsistent in how intrinsic and external variations are treated) + general_vars = {"nom": ""} + if self.dataset_inst.is_mc: + general_vars.update(self.met_config.variations or {}) + for variation, postfix in general_vars.items(): + pt_phi_vars = {"": ""} + if variation == "nom" and self.dataset_inst.is_mc: + pt_phi_vars.update(self.met_config.pt_phi_variations or {}) + for pt_phi_variation, pt_phi_postfix in pt_phi_vars.items(): + _postfix = postfix or pt_phi_postfix + out_var = f"{var}{_postfix and '_' + _postfix}" + # prepare evaluator inputs + _variable_map = { + **variable_map, + "pt_phi": f"{var}{pt_phi_variation and '_' + pt_phi_variation}", + "variation": variation, + } + inputs = [_variable_map[inp.name] for inp in self.met_corrector.inputs] + # evaluate and create new column + corr_vals = np.array(vals_orig) + corr_vals[mask] = self.met_corrector(*inputs) + events = set_ak_column_f32(events, f"{met_name}.{out_var}", corr_vals) + + return events + + +@met_phi.init +def met_phi_init(self: Calibrator, **kwargs) -> None: + self.met_config = self.get_met_config() + + # set used columns + self.uses.update({"PV.npvsGood", f"{self.met_config.met_name}.{{pt,phi}}"}) + + # set produced columns + self.produces.add(f"{self.met_config.met_name}.{{pt,phi}}") + if self.dataset_inst.is_mc: + for postfix in {**(self.met_config.pt_phi_variations or {}), **(self.met_config.variations or {})}.values(): + self.produces.add(f"{self.met_config.met_name}.{{pt,phi}}_{postfix}") + if self.met_config.keep_uncorrected: + self.produces.add(f"{self.met_config.met_name}.{{pt,phi}}_metphi_uncorrected") + + +@met_phi.setup +def met_phi_setup( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: + # load the corrector + met_file = self.get_met_file(reqs["external_files"].files) + correction_set = load_correction_set(met_file) + self.met_corrector = correction_set[self.met_config.correction_set] diff --git a/columnflow/calibration/cms/muon.py b/columnflow/calibration/cms/muon.py new file mode 100644 index 000000000..d096c94f1 --- /dev/null +++ b/columnflow/calibration/cms/muon.py @@ -0,0 +1,222 @@ +# coding: utf-8 + +""" +Muon calibration methods. +""" + +from __future__ import annotations + +import functools +import dataclasses +import inspect + +import law + +from columnflow.calibration import Calibrator, calibrator +from columnflow.columnar_util import TAFConfig, set_ak_column, IF_MC +from columnflow.util import maybe_import, load_correction_set, import_file, DotDict +from columnflow.types import Any + +ak = maybe_import("awkward") +np = maybe_import("numpy") + + +logger = law.logger.get_logger(__name__) + +# helper +set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) + + +@dataclasses.dataclass +class MuonSRConfig(TAFConfig): + """ + Container class to configure muon momentum scale and resolution corrections. Example: + + .. code-block:: python + + cfg.x.muon_sr = MuonSRConfig( + systs=["scale_up", "scale_down", "res_up", "res_down"], + ) + """ + systs: list[str] = dataclasses.field(default_factory=lambda: ["scale_up", "scale_down", "res_up", "res_down"]) + + +@calibrator( + uses={ + "Muon.{pt,eta,phi,mass,charge}", + IF_MC("event", "luminosityBlock", "Muon.nTrackerLayers"), + }, + # uncertainty variations added in init + produces={"Muon.pt"}, + # whether to produce also uncertainties + with_uncertainties=True, + # functions to determine the correction and tool files + get_muon_sr_file=(lambda self, external_files: external_files.muon_sr), + get_muon_sr_tool_file=(lambda self, external_files: external_files.muon_sr_tools), + # function to determine the muon config + get_muon_sr_config=(lambda self: self.config_inst.x.muon_sr), + # if the original pt columns should be stored as "pt_sr_uncorrected" + store_original=False, +) +def muon_sr( + self: Calibrator, + events: ak.Array, + **kwargs, +) -> ak.Array: + """ + Calibrator for muon scale and resolution smearing. Requires two external file in the config under the ``muon_sr`` + and ``muon_sr_tools`` keys, pointing to the json correction file and the "MuonScaRe" tools script, respectively, + + .. code-block:: python + + cfg.x.external_files = DotDict.wrap({ + "muon_sr": "/cvmfs/cms-griddata.cern.ch/cat/metadata/MUO/Run3-22CDSep23-Summer22-NanoAODv12/2025-08-14/muon_scalesmearing.json.gz", # noqa + "muon_sr_tools": "/path/to/MuonScaRe.py", + }) + + and a :py:class:`MuonSRConfig` configuration object in the auxiliary field ``muon_sr``, + + .. code-block:: python + + from columnflow.calibration.cms.muon import MuonSRConfig + cfg.x.muon_sr = MuonSRConfig( + systs=["scale_up", "scale_down", "res_up", "res_down"], + ) + + *get_muon_sr_file*, *get_muon_sr_tool_file* and *get_muon_sr_config* can be adapted in a subclass in case they are + stored differently in the config. + + Resources: + + - https://gitlab.cern.ch/cms-muonPOG/muonscarekit + - https://cms-analysis-corrections.docs.cern.ch/corrections_era/Run3-22CDSep23-Summer22-NanoAODv12/MUO/latest/#muon_scalesmearingjsongz # noqa + """ + # store the original pt column if requested + if self.store_original: + events = set_ak_column(events, "Muon.pt_sr_uncorrected", events.Muon.pt) + + # apply scale correction to data + if self.dataset_inst.is_data: + pt_scale_corr = self.muon_sr_tools.pt_scale( + 1, + events.Muon.pt, + events.Muon.eta, + events.Muon.phi, + events.Muon.charge, + self.muon_correction_set, + nested=True, + ) + events = set_ak_column_f32(events, "Muon.pt", pt_scale_corr) + + # apply scale and resolution correction to mc + if self.dataset_inst.is_mc: + pt_scale_corr = self.muon_sr_tools.pt_scale( + 0, + events.Muon.pt, + events.Muon.eta, + events.Muon.phi, + events.Muon.charge, + self.muon_correction_set, + nested=True, + ) + pt_scale_res_corr = self.muon_sr_tools.pt_resol( + pt_scale_corr, + events.Muon.eta, + events.Muon.phi, + events.Muon.nTrackerLayers, + events.event, + events.luminosityBlock, + self.muon_correction_set, + rnd_gen="np", + nested=True, + ) + events = set_ak_column_f32(events, "Muon.pt", pt_scale_res_corr) + + # apply scale and resolution uncertainties to mc + if self.with_uncertainties and self.muon_cfg.systs: + for syst in self.muon_cfg.systs: + # the sr tools use up/dn naming + sr_direction = {"up": "up", "down": "dn"}[syst.rsplit("_", 1)[-1]] + + # exact behavior depends on syst itself + if syst in {"scale_up", "scale_down"}: + pt_syst = self.muon_sr_tools.pt_scale_var( + pt_scale_res_corr, + events.Muon.eta, + events.Muon.phi, + events.Muon.charge, + sr_direction, + self.muon_correction_set, + nested=True, + ) + events = set_ak_column_f32(events, f"Muon.pt_{syst}", pt_syst) + + elif syst in {"res_up", "res_down"}: + pt_syst = self.muon_sr_tools.pt_resol_var( + pt_scale_corr, + pt_scale_res_corr, + events.Muon.eta, + sr_direction, + self.muon_correction_set, + nested=True, + ) + events = set_ak_column_f32(events, f"Muon.pt_{syst}", pt_syst) + + else: + logger.error(f"{self.cls_name} calibrator received unknown systematic '{syst}', skipping") + + return events + + +@muon_sr.init +def muon_sr_init(self: Calibrator, **kwargs) -> None: + self.muon_cfg = self.get_muon_sr_config() + + # add produced columns with unceratinties if requested + if self.dataset_inst.is_mc and self.with_uncertainties and self.muon_cfg.systs: + for syst in self.muon_cfg.systs: + self.produces.add(f"Muon.pt_{syst}") + + # original column + if self.store_original: + self.produces.add("Muon.pt_sr_uncorrected") + + +@muon_sr.requires +def muon_sr_requires( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: + if "external_files" in reqs: + return + + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(task) + + +@muon_sr.setup +def muon_sr_setup( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: + # load the correction set + muon_sr_file = self.get_muon_sr_file(reqs["external_files"].files) + self.muon_correction_set = load_correction_set(muon_sr_file) + + # also load the tools as an external package + muon_sr_tool_file = self.get_muon_sr_tool_file(reqs["external_files"].files) + self.muon_sr_tools = import_file(muon_sr_tool_file.abspath) + + # silence printing of the filter_boundaries function + spec = inspect.getfullargspec(self.muon_sr_tools.filter_boundaries) + if "silent" in spec.args or "silent" in spec.kwonlyargs: + self.muon_sr_tools.filter_boundaries = functools.partial(self.muon_sr_tools.filter_boundaries, silent=True) + + +muon_sr_nominal = muon_sr.derive("muon_sr_nominal", cls_dict={"with_uncertainties": False}) diff --git a/columnflow/calibration/cms/tau.py b/columnflow/calibration/cms/tau.py index b9ada41ef..897ebea4f 100644 --- a/columnflow/calibration/cms/tau.py +++ b/columnflow/calibration/cms/tau.py @@ -8,14 +8,14 @@ import functools import itertools -from dataclasses import dataclass, field +import dataclasses import law from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import propagate_met from columnflow.util import maybe_import, load_correction_set, DotDict -from columnflow.columnar_util import set_ak_column, flat_np_view, ak_copy +from columnflow.columnar_util import TAFConfig, set_ak_column, flat_np_view, ak_copy from columnflow.types import Any ak = maybe_import("awkward") @@ -26,11 +26,11 @@ set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) -@dataclass -class TECConfig: +@dataclasses.dataclass +class TECConfig(TAFConfig): tagger: str correction_set: str = "tau_energy_scale" - corrector_kwargs: dict[str, Any] = field(default_factory=dict) + corrector_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) @classmethod def new(cls, obj: TECConfig | tuple[str] | dict[str, str]) -> TECConfig: @@ -44,14 +44,8 @@ def new(cls, obj: TECConfig | tuple[str] | dict[str, str]) -> TECConfig: @calibrator( - uses={ - # nano columns - "nTau", "Tau.pt", "Tau.eta", "Tau.phi", "Tau.mass", "Tau.charge", "Tau.genPartFlav", - "Tau.decayMode", - }, - produces={ - "Tau.pt", "Tau.mass", - }, + uses={"Tau.{pt,eta,phi,mass,charge,genPartFlav,decayMode}"}, + produces={"Tau.{pt,mass}"}, # whether to produce also uncertainties with_uncertainties=True, # toggle for propagation to MET @@ -94,6 +88,11 @@ def tec( *get_tau_file* and *get_tec_config* can be adapted in a subclass in case they are stored differently in the config. + .. note:: + + In case you also perform the propagation from jet energy calibrations to MET, please check if the propagation of + tau energy calibrations to MET is required in your analysis! + Resources: https://twiki.cern.ch/twiki/bin/view/CMS/TauIDRecommendationForRun2?rev=113 https://gitlab.cern.ch/cms-nanoAOD/jsonpog-integration/-/blob/849c6a6efef907f4033715d52290d1a661b7e8f9/POG/TAU @@ -137,17 +136,12 @@ def tec( scales_down = np.ones_like(dm_mask, dtype=np.float32) scales_down[dm_mask] = self.tec_corrector(*args, "down") - # custom adjustment 1: reset where the matching value is unhandled - # custom adjustment 2: reset electrons faking taus where the pt is too small - mask1 = (match < 1) | (match > 5) - mask2 = ((match == 1) | (match == 3)) & (pt <= 20.0) - - # apply reset masks - mask = mask1 | mask2 - scales_nom[mask] = 1.0 + # custom adjustment: reset where the matching value is unhandled + reset_mask = (match < 1) | (match > 5) + scales_nom[reset_mask] = 1.0 if self.with_uncertainties: - scales_up[mask] = 1.0 - scales_down[mask] = 1.0 + scales_up[reset_mask] = 1.0 + scales_down[reset_mask] = 1.0 # create varied collections per decay mode if self.with_uncertainties: @@ -258,7 +252,7 @@ def tec_setup( self.tec_corrector = load_correction_set(tau_file)[self.tec_cfg.correction_set] # check versions - assert self.tec_corrector.version in [0, 1] + assert self.tec_corrector.version in {0, 1, 2} tec_nominal = tec.derive("tec_nominal", cls_dict={"with_uncertainties": False}) diff --git a/columnflow/calibration/cmsGhent/lepton_mva.py b/columnflow/calibration/cmsGhent/lepton_mva.py index 95c4199ad..739f22d81 100644 --- a/columnflow/calibration/cmsGhent/lepton_mva.py +++ b/columnflow/calibration/cmsGhent/lepton_mva.py @@ -13,8 +13,6 @@ np = maybe_import("numpy") ak = maybe_import("awkward") -coffea = maybe_import("coffea") -maybe_import("coffea.nanoevents.methods.nanoaod") @producer( diff --git a/columnflow/cms_util.py b/columnflow/cms_util.py new file mode 100644 index 000000000..2e283009f --- /dev/null +++ b/columnflow/cms_util.py @@ -0,0 +1,201 @@ +# coding: utf-8 + +""" +Collection of CMS specific helpers and utilities. +""" + +from __future__ import annotations + +__all__ = [] + +import os +import re +import copy +import pathlib +import dataclasses + +from columnflow.types import ClassVar, Generator + + +#: Default root path to CAT metadata. +cat_metadata_root = "/cvmfs/cms-griddata.cern.ch/cat/metadata" + + +@dataclasses.dataclass +class CATSnapshot: + """ + Dataclass to wrap YYYY-MM-DD stype timestamps of CAT metadata per POG stored in + "/cvmfs/cms-griddata.cern.ch/cat/metadata". No format parsing or validation is done, leaving responsibility to the + user. + """ + btv: str = "" + dc: str = "" + egm: str = "" + jme: str = "" + lum: str = "" + muo: str = "" + tau: str = "" + + def items(self) -> Generator[tuple[str, str], None, None]: + return ((k, getattr(self, k)) for k in self.__dataclass_fields__.keys()) + + +@dataclasses.dataclass +class CATInfo: + """ + Dataclass to describe and wrap information about a specific CAT-defined metadata era. + + .. code-block:: python + + CATInfo( + run=3, + era="22CDSep23-Summer22", + vnano=12, + snapshot=CATSnapshot( + btv="2025-08-20", + dc="2025-07-25", + egm="2025-04-15", + jme="2025-09-23", + lum="2024-01-31", + muo="2025-08-14", + tau="2025-10-01", + ), + # pog-specific settings + pog_directories={"dc": "Collisions22"}, + ) + """ + run: int + era: str + vnano: int + snapshot: CATSnapshot + # optional POG-specific overrides + pog_eras: dict[str, str] = dataclasses.field(default_factory=dict) + pog_directories: dict[str, str] = dataclasses.field(default_factory=dict) + + metadata_root: ClassVar[str] = cat_metadata_root + + def get_era_directory(self, pog: str = "") -> str: + """ + Returns the era directory name for a given *pog*. + + :param pog: The POG to get the era for. Leave empty if the common POG-unspecific directory name should be used. + """ + pog = pog.lower() + + # use specific directory if defined + if pog in self.pog_directories: + return self.pog_directories[pog] + + # build common directory name from run, era, and vnano + era = self.pog_eras.get(pog.lower(), self.era) if pog else self.era + return f"Run{self.run}-{era}-NanoAODv{self.vnano}" + + def get_file(self, pog: str, *paths: str | pathlib.Path) -> str: + """ + Returns the full path to a specific file or directory defined by *paths* in the CAT metadata structure for a + given *pog*. + """ + return os.path.join( + self.metadata_root, + pog.upper(), + self.get_era_directory(pog), + getattr(self.snapshot, pog.lower()), + *(str(p).strip("/") for p in paths), + ) + + +@dataclasses.dataclass +class CMSDatasetInfo: + """ + Container to wrap a CMS dataset given by its *key* with access to its components. The key should be in the format + ``//--/AOD``. + + .. code-block:: python + + d = CMSDatasetInfo.from_key("/TTtoLNu2Q_TuneCP5_13p6TeV_powheg-pythia8/RunIII2024Summer24MiniAODv6-150X_mcRun3_2024_realistic_v2-v2/MINIAODSIM") # noqa + print(d.name) # TTtoLNu2Q_TuneCP5_13p6TeV_powheg-pythia8 + print(d.campaign) # RunIII2024Summer24MiniAODv6 + print(d.campaign_version) # 150X_mcRun3_2024_realistic_v2 + print(d.dataset_version) # v2 + print(d.tier) # mini (lower case) + print(d.mc) # True + print(d.data) # False + print(d.kind) # mc + """ + name: str + campaign: str + campaign_version: str + dataset_version: str # this is usually the GT for MC + tier: str + mc: bool + + @classmethod + def from_key(cls, key: str) -> CMSDatasetInfo: + """ + Takes a dataset *key*, splits it into its components, and returns a new :py:class:`CMSDatasetInfo` instance. + + :param key: The dataset key: + :return: A new instance of :py:class:`CMSDatasetInfo`. + """ + # split + if not (m := re.match(r"^/([^/]+)/([^/-]+)-([^/-]+)-([^/-]+)/([^/-]+)AOD(SIM)?$", key)): + raise ValueError(f"invalid dataset key '{key}'") + + # create instance + return cls( + name=m.group(1), + campaign=m.group(2), + campaign_version=m.group(3), + dataset_version=m.group(4), + tier=m.group(5).lower(), + mc=m.group(6) == "SIM", + ) + + @property + def key(self) -> str: + # transform back to key format + return ( + f"/{self.name}" + f"/{self.campaign}-{self.campaign_version}-{self.dataset_version}" + f"/{self.tier.upper()}AOD{'SIM' if self.mc else ''}" + ) + + @property + def data(self) -> bool: + return not bool(self.mc) + + @data.setter + def data(self, value: bool) -> None: + self.mc = not bool(value) + + @property + def kind(self) -> str: + return "mc" if self.mc else "data" + + @kind.setter + def kind(self, value: str) -> None: + if (_value := str(value).lower()) not in {"mc", "data"}: + raise ValueError(f"invalid kind '{value}', expected 'mc' or 'data'") + self.mc = _value == "mc" + + @property + def store_path(self) -> str: + return ( + "/store" + f"/{self.kind}" + f"/{self.campaign}" + f"/{self.name}" + f"/{self.tier.upper()}AOD{'SIM' if self.mc else ''}" + f"/{self.campaign_version}-{self.dataset_version}" + ) + + def copy(self, **kwargs) -> CMSDatasetInfo: + """ + Creates a copy of this instance, allowing to override specific attributes via *kwargs*. + + :param kwargs: Attributes to override in the copy. + :return: A new instance of :py:class:`CMSDatasetInfo`. + """ + attrs = copy.deepcopy(self.__dict__) + attrs.update(kwargs) + return self.__class__(**attrs) diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index 884622cf1..00878e579 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -16,6 +16,7 @@ import enum import inspect import threading +import dataclasses import multiprocessing import multiprocessing.pool from functools import partial @@ -24,21 +25,15 @@ import law import order as od -from columnflow.types import Sequence, Callable, Any, T, Generator +from columnflow.types import Sequence, Callable, Any, T, Generator, Hashable from columnflow.util import ( - UNSET, maybe_import, classproperty, DotDict, DerivableMeta, Derivable, pattern_matcher, + UNSET, maybe_import, classproperty, DotDict, DerivableMeta, CachedDerivableMeta, Derivable, pattern_matcher, get_source_code, real_path, freeze, get_docs_url, ) np = maybe_import("numpy") ak = maybe_import("awkward") -dak = maybe_import("dask_awkward") uproot = maybe_import("uproot") -coffea = maybe_import("coffea") -maybe_import("coffea.nanoevents") -maybe_import("coffea.nanoevents.methods.base") -maybe_import("coffea.nanoevents.methods.nanoaod") -pq = maybe_import("pyarrow.parquet") # loggers @@ -1237,6 +1232,9 @@ def attach_behavior( (*skip_fields*) can contain names or name patterns of fields that are kept (filtered). *keep_fields* has priority, i.e., when it is set, *skip_fields* is not considered. """ + import coffea.nanoevents + import coffea.nanoevents.methods.nanoaod + if behavior is None: behavior = getattr(ak_array, "behavior", None) or coffea.nanoevents.methods.nanoaod.behavior if behavior is None: @@ -1635,11 +1633,6 @@ def my_other_func_init(self): """ # class-level attributes as defaults - call_func = None - pre_init_func = None - init_func = None - skip_func = None - uses = set() produces = set() check_used_columns = True @@ -1731,15 +1724,20 @@ def PRODUCES(cls) -> IOFlagged: return cls.IOFlagged(cls, cls.IOFlag.PRODUCES) @classmethod - def call(cls, func: Callable[[Any, ...], Any]) -> None: + def pre_init(cls, func: Callable[[], None]) -> None: """ - Decorator to wrap a function *func* that should be registered as :py:meth:`call_func` - which defines the main callable for processing chunks of data. The function should accept - arbitrary arguments and can return arbitrary objects. + Decorator to wrap a function *func* that should be registered as :py:meth:`pre_init_func` + which is invoked prior to any dependency creation. The function should not accept arguments. The decorator does not return the wrapped function. """ - cls.call_func = func + cls.pre_init_func = func + + def pre_init_func(self) -> None: + """ + Default pre-init function. + """ + return @classmethod def init(cls, func: Callable[[], None]) -> None: @@ -1752,16 +1750,11 @@ def init(cls, func: Callable[[], None]) -> None: """ cls.init_func = func - @classmethod - def pre_init(cls, func: Callable[[], None]) -> None: + def init_func(self) -> None: """ - Decorator to wrap a function *func* that should be registered as :py:meth:`pre_init_func` - which is invoked prior to any dependency creation. The function should not accept positional - arguments. - - The decorator does not return the wrapped function. + Default init function. """ - cls.pre_init_func = func + return @classmethod def skip(cls, func: Callable[[], bool]) -> None: @@ -1774,12 +1767,35 @@ def skip(cls, func: Callable[[], bool]) -> None: """ cls.skip_func = func + def skip_func(self) -> None: + """ + Default skip function. + """ + return + + @classmethod + def call(cls, func: Callable[[Any, ...], Any]) -> None: + """ + Decorator to wrap a function *func* that should be registered as :py:meth:`call_func` + which defines the main callable for processing chunks of data. The function should accept + arbitrary arguments and can return arbitrary objects. + + The decorator does not return the wrapped function. + """ + cls.call_func = func + + def call_func(self, *args, **kwargs) -> Any: + """ + Default call function. + """ + return + def __init__( self, - call_func: Callable | None = law.no_value, - pre_init_func: Callable | None = law.no_value, - init_func: Callable | None = law.no_value, - skip_func: Callable | None = law.no_value, + pre_init_func: Callable | law.NoValue | None = law.no_value, + init_func: Callable | law.NoValue | None = law.no_value, + skip_func: Callable | law.NoValue | None = law.no_value, + call_func: Callable | law.NoValue | None = law.no_value, check_used_columns: bool | None = None, check_produced_columns: bool | None = None, instance_cache: dict | None = None, @@ -1790,14 +1806,14 @@ def __init__( super().__init__() # add class-level attributes as defaults for unset arguments (no_value) - if call_func == law.no_value: - call_func = self.__class__.call_func if pre_init_func == law.no_value: pre_init_func = self.__class__.pre_init_func if init_func == law.no_value: init_func = self.__class__.init_func if skip_func == law.no_value: skip_func = self.__class__.skip_func + if call_func == law.no_value: + call_func = self.__class__.call_func if check_used_columns is not None: self.check_used_columns = check_used_columns if check_produced_columns is not None: @@ -1806,14 +1822,14 @@ def __init__( self.log_runtime = log_runtime # when a custom funcs are passed, bind them to this instance - if call_func: - self.call_func = call_func.__get__(self, self.__class__) if pre_init_func: self.pre_init_func = pre_init_func.__get__(self, self.__class__) if init_func: self.init_func = init_func.__get__(self, self.__class__) if skip_func: self.skip_func = skip_func.__get__(self, self.__class__) + if call_func: + self.call_func = call_func.__get__(self, self.__class__) # create instance-level sets of dependent ArrayFunction classes, # optionally with priority to sets passed in keyword arguments @@ -2202,6 +2218,30 @@ def __call__(self, *args, **kwargs) -> Any: deferred_column = ArrayFunction.DeferredColumn.deferred_column +@deferred_column +def IF_DATA(self: ArrayFunction.DeferredColumn, func: ArrayFunction) -> Any | set[Any]: + return self.get() if func.dataset_inst.is_data else None + + +@deferred_column +def IF_MC(self: ArrayFunction.DeferredColumn, func: ArrayFunction) -> Any | set[Any]: + return self.get() if func.dataset_inst.is_mc else None + + +def IF_DATASET_HAS_TAG(*args, negate: bool = False, **kwargs) -> ArrayFunction.DeferredColumn: + @deferred_column + def deferred( + self: ArrayFunction.DeferredColumn, + func: ArrayFunction, + ) -> Any | set[Any]: + return self.get() if func.dataset_inst.has_tag(*args, **kwargs) is not negate else None + + return deferred + + +IF_DATASET_NOT_HAS_TAG = partial(IF_DATASET_HAS_TAG, negate=True) + + def tagged_column( tag: str | Sequence[str] | set[str], *routes: Route | Any | set[Route | Any], @@ -2263,26 +2303,10 @@ def skip_column( return tagged_column("skip", *routes) -class TaskArrayFunctionMeta(DerivableMeta): - - def __new__(metacls, cls_name: str, bases: tuple, cls_dict: dict) -> TaskArrayFunctionMeta: - # add an instance cache if not disabled - cls_dict.setdefault("cache_instances", True) - cls_dict["_instances"] = {} if cls_dict["cache_instances"] else None - - return super().__new__(metacls, cls_name, bases, cls_dict) - - def __call__(cls, *args, **kwargs) -> TaskArrayFunction: - # when not caching instances, return right away - if not cls.cache_instances: - return super().__call__(*args, **kwargs) +class TaskArrayFunctionMeta(CachedDerivableMeta): - # build the cache key from the inst_dict in kwargs - key = freeze((cls, kwargs.get("inst_dict", {}))) - if key not in cls._instances: - cls._instances[key] = super().__call__(*args, **kwargs) - - return cls._instances[key] + def _get_inst_cache_key(cls, args: tuple, kwargs: dict) -> Hashable: + return freeze((cls, kwargs.get("inst_dict", {}))) class TaskArrayFunction(ArrayFunction, metaclass=TaskArrayFunctionMeta): @@ -2421,10 +2445,6 @@ class the normal way, or use a decorator to wrap the main callable first and by """ # class-level attributes as defaults - post_init_func = None - requires_func = None - setup_func = None - teardown_func = None sandbox = None call_force = None max_chunk_size = None @@ -2478,6 +2498,12 @@ def post_init(cls, func: Callable[[dict], None]) -> None: """ cls.post_init_func = func + def post_init_func(self, task: law.Task) -> None: + """ + Default post-init function. + """ + return + @classmethod def requires(cls, func: Callable[[dict], None]) -> None: """ @@ -2500,6 +2526,12 @@ def requires(cls, func: Callable[[dict], None]) -> None: """ cls.requires_func = func + def requires_func(self, task: law.Task, reqs: dict[str, DotDict[str, Any]]) -> None: + """ + Default requires function. + """ + return + @classmethod def setup(cls, func: Callable[[dict], None]) -> None: """ @@ -2518,6 +2550,18 @@ def setup(cls, func: Callable[[dict], None]) -> None: """ cls.setup_func = func + def setup_func( + self, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + ) -> None: + """ + Default setup function. + """ + return + @classmethod def teardown(cls, func: Callable[[dict], None]) -> None: """ @@ -2531,6 +2575,12 @@ def teardown(cls, func: Callable[[dict], None]) -> None: """ cls.teardown_func = func + def teardown_func(self, task: law.Task) -> None: + """ + Default teardown function. + """ + return + def __init__( self, *args, @@ -2669,7 +2719,7 @@ def _get_all_shifts(self, _cache: set | None = None) -> set[str]: if isinstance(shift, od.Shift): shifts.add(shift.name) elif isinstance(shift, str): - shifts.add(shift) + shifts.update(law.util.brace_expand(shift)) _cache.add(self) # add shifts of all dependent objects @@ -2901,6 +2951,19 @@ def get_min_chunk_size(self) -> int | None: return min((s for s in sizes if isinstance(s, int)), default=None) +@dataclasses.dataclass +class TAFConfig: + + def copy(self, **kwargs) -> TAFConfig: + """ + Returns a copy of this TAFConfig instance, updated by any given *kwargs*. + + :param kwargs: Attributes to update in the copied instance. + :return: The copied and updated TAFConfig instance. + """ + return self.__class__(self.__dict__ | kwargs) + + class NoThreadPool(object): """ Dummy implementation that mimics parts of the usual thread pool interface but instead of @@ -3035,7 +3098,11 @@ def __init__( # case nested nodes separated by "*.list.element.*" (rather than "*.list.item.*") are found # (to be removed in the future) if open_options.get("split_row_groups"): - nodes = ak.ak_from_parquet.metadata(path)[0] + try: + nodes = ak.ak_from_parquet.metadata(path)[0] + except: + logger.error(f"unable to read {path}") + raise cre = re.compile(r"^.+\.list\.element(|\..+)$") if any(map(cre.match, nodes)): logger.warning( @@ -3047,6 +3114,7 @@ def __init__( open_options["split_row_groups"] = False # open the file + import dask_awkward as dak self.dak_array = dak.from_parquet(path, **open_options) self.path = path @@ -3198,6 +3266,145 @@ def _materialize_via_partitions( return arr +class ChunkedParquetReader(object): + """ + Class that wraps a parquet file containing an awkward array and handles chunked reading via splitting and merging of + row groups. To allow memory efficient caching in case of overlaps between groups on disk and chunks to be read + (possibly with different sizes) this process is implemented as a one-time-only read operation. Hence, in situations + where particular chunks need to be read more than once, another instance of this class should be used. + """ + + def __init__(self, path: str, open_options: dict | None = None) -> None: + super().__init__() + if not open_options: + open_options = {} + + # store attributes + self.path = path + self.open_options = open_options.copy() + + # open and store meta data with updated open options + # (when closing the reader, this attribute is set to None) + meta_options = open_options.copy() + meta_options.pop("row_groups", None) + meta_options.pop("ignore_metadata", None) + meta_options.pop("columns", None) + try: + self.metadata = ak.metadata_from_parquet(path, **meta_options) + except: + logger.error(f"unable to read {path}") + raise + + # extract row group sizes for chunked reading + if "col_counts" not in self.metadata: + raise Exception( + f"{self.__class__.__name__}: entry 'col_counts' is missing in meta data of file '{path}', but it is " + "strictly required for chunked reading; please debug", + ) + self.group_sizes = list(self.metadata["col_counts"]) + + # compute cumulative division boundaries + divs = [0] + for s in self.group_sizes: + divs.append(divs[-1] + s) + self.group_divisions = tuple(divs) + + # fixed mapping of chunk indices to group indices, created in materialize + self.chunk_to_groups = {} + + # mapping of group indices to cache information (chunks still to be handled and a cached array) that changes + # during the read process in materialize + self.group_cache = {g: DotDict(chunks=set(), array=None) for g in range(len(self.group_sizes))} + + # locks to protect against RCs during read operations by different threads + self.chunk_to_groups_lock = threading.Lock() + self.group_locks = {g: threading.Lock() for g in self.group_cache} + + def __del__(self) -> None: + self.close() + + def __len__(self) -> int: + return self.group_divisions[-1] + + @property + def closed(self) -> bool: + return self.metadata is None + + def close(self) -> None: + self.metadata = None + if getattr(self, "group_cache", None): + for g in self.group_cache: + self.group_cache[g] = None + + def materialize( + self, + *, + chunk_index: int, + entry_start: int, + entry_stop: int, + max_chunk_size: int, + ) -> ak.Array: + # strategy: read from disk with granularity given by row group sizes + # - use chunk info to determine which groups need to be read + # - guard each read operation of a group by locks + # - add materialized groups that might overlap with another chunk in a temporary cache + # - remove cached groups eagerly once it becomes clear that no chunk will need it + + # fill the chunk -> groups mapping once + with self.chunk_to_groups_lock: + if not self.chunk_to_groups: + # note: a hare-and-tortoise algorithm could be possible to get the mapping with less + # than n^2 complexity, but for our case with ~30 chunks this should be ok (for now) + n_chunks = int(math.ceil(len(self) / max_chunk_size)) + # in case there are no entries, ensure that at least one empty chunk is created + for _chunk_index in range(max(n_chunks, 1)): + _entry_start = _chunk_index * max_chunk_size + _entry_stop = min(_entry_start + max_chunk_size, len(self)) + groups = [] + for g, (g_start, g_stop) in enumerate(zip(self.group_divisions[:-1], self.group_divisions[1:])): + # note: check strict increase of chunk size to accommodate zero-length size + if g_stop <= _entry_start < _entry_stop: + continue + if g_start >= _entry_stop > _entry_start: + break + groups.append(g) + self.group_cache[g].chunks.add(_chunk_index) + self.chunk_to_groups[_chunk_index] = groups + + # read groups one at a time and store parts that make up the chunk for concatenation + parts = [] + for g in self.chunk_to_groups[chunk_index]: + # obtain the array + with self.group_locks[g]: + # remove this chunk from the list of chunks to be handled + self.group_cache[g].chunks.remove(chunk_index) + + if self.group_cache[g].array is None: + arr = ak.from_parquet(self.path, row_groups=[g], **self.open_options) + # add to cache when there is a chunk left that will need it + if self.group_cache[g].chunks: + self.group_cache[g].array = arr + else: + arr = self.group_cache[g].array + # remove from cache when there is no chunk left that would need it + if not self.group_cache[g].chunks: + self.group_cache[g].array = None + + # add part for concatenation using entry info + div_start, div_stop = self.group_divisions[g:g + 2] + part_start = max(entry_start - div_start, 0) + part_stop = min(entry_stop - div_start, div_stop - div_start) + parts.append(arr[part_start:part_stop]) + + # construct the full array + arr = parts[0] if len(parts) == 1 else ak.concatenate(parts, axis=0) + + # cleanup + del parts + + return arr + + class ChunkedIOHandler(object): """ Allows reading one or multiple files and iterating through chunks of their content with @@ -3268,7 +3475,7 @@ class ChunkedIOHandler(object): # chunk position container ChunkPosition = namedtuple( "ChunkPosition", - ["index", "entry_start", "entry_stop", "max_chunk_size"], + ["index", "entry_start", "entry_stop", "max_chunk_size", "n_chunks"], ) # read result container @@ -3374,11 +3581,13 @@ def create_chunk_position( if n_entries == 0: entry_start = 0 entry_stop = 0 + n_chunks = 0 else: entry_start = chunk_index * chunk_size entry_stop = min((chunk_index + 1) * chunk_size, n_entries) + n_chunks = int(math.ceil(n_entries / chunk_size)) - return cls.ChunkPosition(chunk_index, entry_start, entry_stop, chunk_size) + return cls.ChunkPosition(chunk_index, entry_start, entry_stop, chunk_size, n_chunks) @classmethod def get_source_handler( @@ -3400,6 +3609,7 @@ def get_source_handler( - "coffea_root" - "coffea_parquet" - "awkward_parquet" + - "dask_awkward_parquet" """ if source_type is None: if isinstance(source, uproot.ReadOnlyDirectory): @@ -3411,7 +3621,7 @@ def get_source_handler( # priotize coffea nano events source_type = "coffea_root" elif source.endswith(".parquet"): - # priotize awkward nano events + # prioritize non-dask awkward reader source_type = "awkward_parquet" if not source_type: @@ -3445,6 +3655,13 @@ def get_source_handler( cls.close_awkward_parquet, cls.read_awkward_parquet, ) + if source_type == "dask_awkward_parquet": + return cls.SourceHandler( + source_type, + cls.open_dask_awkward_parquet, + cls.close_dask_awkward_parquet, + cls.read_dask_awkward_parquet, + ) raise NotImplementedError(f"unknown source_type '{source_type}'") @@ -3585,7 +3802,7 @@ def read_coffea_root( chunk_pos: ChunkPosition, read_options: dict | None = None, read_columns: set[str | Route] | None = None, - ) -> coffea.nanoevents.methods.base.NanoEventsArray: + ) -> ak.Array: """ Given a file location or opened uproot file, and a tree name in a 2-tuple *source_object*, returns an awkward array chunk referred to by *chunk_pos*, assuming nanoAOD structure. @@ -3593,9 +3810,11 @@ def read_coffea_root( *read_columns* are converted to strings and, if not already present, added as nested fields ``iteritems_options.filter_name`` to *read_options*. """ + import coffea.nanoevents + # default read options read_options = read_options or {} - read_options["delayed"] = False + read_options["mode"] = "eager" read_options["runtime_cache"] = None read_options["persistent_cache"] = None @@ -3640,6 +3859,8 @@ def open_coffea_parquet( Given a parquet file located at *source*, returns a 2-tuple *(source, entries)*. Passing *open_options* or *read_columns* has no effect. """ + import pyarrow.parquet as pq + return (source, pq.ParquetFile(source).metadata.num_rows) @classmethod @@ -3659,7 +3880,7 @@ def read_coffea_parquet( chunk_pos: ChunkPosition, read_options: dict | None = None, read_columns: set[str | Route] | None = None, - ) -> coffea.nanoevents.methods.base.NanoEventsArray: + ) -> ak.Array: """ Given a the location of a parquet file *source_object*, returns an awkward array chunk referred to by *chunk_pos*, assuming nanoAOD structure. *read_options* are passed to @@ -3667,8 +3888,11 @@ def read_coffea_parquet( strings and, if not already present, added as nested field ``parquet_options.read_dictionary`` to *read_options*. """ + import coffea.nanoevents + # default read options read_options = read_options or {} + read_options["mode"] = "eager" read_options["runtime_cache"] = None read_options["persistent_cache"] = None @@ -3705,20 +3929,19 @@ def open_awkward_parquet( source: str, open_options: dict | None = None, read_columns: set[str | Route] | None = None, - ) -> tuple[ak.Array, int]: + ) -> tuple[ChunkedParquetReader, int]: """ - Opens a parquet file saved at *source*, loads the content as an dask awkward array, - wrapped by a :py:class:`DaskArrayReader`, and returns a 2-tuple *(array, length)*. - *open_options* and *chunk_size* are forwarded to :py:class:`DaskArrayReader`. *read_columns* - are converted to strings and, if not already present, added as field ``columns`` to - *open_options*. + Opens a parquet file saved at *source*, loads the content as chunks of an awkward array wrapped by a + :py:class:`ChunkedParquetReader`, and returns a 2-tuple *(reader, length)*. + + *open_options* and *chunk_size* are forwarded accordingly. *read_columns* are converted to strings and, if not + already present, added as field ``columns`` to *open_options*. """ if not isinstance(source, str): raise Exception(f"'{source}' cannot be opened as awkward_parquet") # default open options open_options = open_options or {} - open_options.setdefault("split_row_groups", True) # preserve input file partitions # inject read_columns if read_columns and "columns" not in open_options: @@ -3726,12 +3949,72 @@ def open_awkward_parquet( open_options["columns"] = filter_name # load the array wrapper - arr = DaskArrayReader(source, open_options) + reader = ChunkedParquetReader(source, open_options) - return (arr, len(arr)) + return (reader, len(reader)) @classmethod def close_awkward_parquet( + cls, + source_object: ChunkedParquetReader, + ) -> None: + """ + Closes the chunked parquet reader referred to by *source_object*. + """ + source_object.close() + + @classmethod + def read_awkward_parquet( + cls, + source_object: ChunkedParquetReader, + chunk_pos: ChunkedIOHandler.ChunkPosition, + read_options: dict | None = None, + read_columns: set[str | Route] | None = None, + ) -> ak.Array: + """ + Given a :py:class:`ChunkedParquetReader` *source_object*, returns the chunk referred to by *chunk_pos* as a + full copy loaded into memory. Passing neither *read_options* nor *read_columns* has an effect. + """ + # get the materialized ak array for that chunk + return source_object.materialize( + chunk_index=chunk_pos.index, + entry_start=chunk_pos.entry_start, + entry_stop=chunk_pos.entry_stop, + max_chunk_size=chunk_pos.max_chunk_size, + ) + + @classmethod + def open_dask_awkward_parquet( + cls, + source: str, + open_options: dict | None = None, + read_columns: set[str | Route] | None = None, + ) -> tuple[DaskArrayReader, int]: + """ + Opens a parquet file saved at *source*, loads the content as an dask awkward array, wrapped by a + :py:class:`DaskArrayReader`, and returns a 2-tuple *(reader, length)*. + + *open_options* and *chunk_size* are forwarded to :py:class:`DaskArrayReader`. *read_columns* are converted to + strings and, if not already present, added as field ``columns`` to *open_options*. + """ + if not isinstance(source, str): + raise Exception(f"'{source}' cannot be opened as awkward_parquet") + + # default open options + open_options = open_options or {} + + # inject read_columns + if read_columns and "columns" not in open_options: + filter_name = [Route(s).string_column for s in read_columns] + open_options["columns"] = filter_name + + # load the array wrapper + reader = DaskArrayReader(source, open_options) + + return (reader, len(reader)) + + @classmethod + def close_dask_awkward_parquet( cls, source_object: DaskArrayReader, ) -> None: @@ -3741,7 +4024,7 @@ def close_awkward_parquet( source_object.close() @classmethod - def read_awkward_parquet( + def read_dask_awkward_parquet( cls, source_object: DaskArrayReader, chunk_pos: ChunkedIOHandler.ChunkPosition, diff --git a/columnflow/columnar_util_Ghent.py b/columnflow/columnar_util_Ghent.py index 0e985b106..6cd73cffb 100644 --- a/columnflow/columnar_util_Ghent.py +++ b/columnflow/columnar_util_Ghent.py @@ -16,13 +16,14 @@ from columnflow.columnar_util import remove_ak_column, has_ak_column ak = maybe_import("awkward") -coffea = maybe_import("coffea") def TetraVec(arr: ak.Array, keep: Sequence | str | Literal[-1] = -1) -> ak.Array: """ create a Lorentz for fector from an awkward array with pt, eta, phi, and mass fields """ + import coffea + mandatory_fields = ("pt", "eta", "phi", "mass") exclude_fields = ("x", "y", "z", "t") for field in mandatory_fields: diff --git a/columnflow/config_util.py b/columnflow/config_util.py index 3a3da34f3..d728f5800 100644 --- a/columnflow/config_util.py +++ b/columnflow/config_util.py @@ -18,7 +18,7 @@ from columnflow.util import maybe_import, get_docs_url from columnflow.columnar_util import flat_np_view, layout_ak_array -from columnflow.types import Callable, Any, Sequence +from columnflow.types import Callable, Any, Sequence, Literal ak = maybe_import("awkward") np = maybe_import("numpy") @@ -333,16 +333,27 @@ def get_shift_from_configs(configs: list[od.Config], shift: str | od.Shift, sile def get_shifts_from_sources(config: od.Config, *shift_sources: Sequence[str]) -> list[od.Shift]: """ - Takes a *config* object and returns a list of shift instances for both directions given a - sequence *shift_sources*. + Takes a *config* object and returns a list of shift instances for both directions given a sequence of + *shift_sources*. Each source should be the name of a shift source (no direction suffix) or a pattern. + + :param config: :py:class:`order.Config` object from which to retrieve the shifts. + :param shift_sources: Sequence of shift source names or patterns. + :return: List of :py:class:`order.Shift` instances obtained from the given sources. """ - return sum( - ( - [config.get_shift(f"{s}_{od.Shift.UP}"), config.get_shift(f"{s}_{od.Shift.DOWN}")] - for s in shift_sources - ), - [], - ) + # since each passed source can be a pattern, all existing sources need to be checked + # however, the order should be preserved, so loop through each pattern and check for matching sources + existing_sources = {shift.source for shift in config.shifts} + found_sources = set() + shifts = [] + for pattern in shift_sources: + for source in existing_sources: + if source not in found_sources and law.util.multi_match(source, pattern): + found_sources.add(source) + shifts += [ + config.get_shift(f"{source}_{od.Shift.UP}"), + config.get_shift(f"{source}_{od.Shift.DOWN}"), + ] + return shifts def group_shifts( @@ -467,6 +478,10 @@ class CategoryGroup: Container to store information about a group of categories, mostly used for creating combinations in :py:func:`create_category_combinations`. + .. note:: + + A group is considered a full partition of the phase space if it is both complete and non-overlapping. + :param categories: List of :py:class:`order.Category` objects or names that refer to the desired category. :param is_complete: Should be *True* if the union of category selections covers the full phase space (no gaps). :param has_overlap: Should be *False* if all categories are pairwise disjoint (no overlap). @@ -490,6 +505,7 @@ def create_category_combinations( config: od.Config, categories: dict[str, CategoryGroup | list[od.Category]], name_fn: Callable[[Any], str], + parent_mode: Literal["all", "none", "safe"] = "safe", kwargs_fn: Callable[[Any], dict] | None = None, skip_existing: bool = True, skip_fn: Callable[[dict[str, od.Category], str], bool] | None = None, @@ -500,9 +516,9 @@ def create_category_combinations( returns the number of newly created categories. *categories* should be a dictionary that maps string names to :py:class:`CategoryGroup` objects which are thin - wrappers around sequences of categories (objects or names). Group names (dictionary keys) are used as keyword - arguments in a callable *name_fn* that is supposed to return the name of newly created categories (see example - below). + wrappers around sequences of categories (objects or names) and provide additional information about the group as a + whole. Group names (dictionary keys) are used as keyword arguments in a callable *name_fn* that is supposed to + return the name of newly created categories (see example below). .. note:: @@ -510,6 +526,26 @@ def create_category_combinations( columnflow to determine whether the summation over specific categories is valid or may result in under- or over-counting when combining leaf categories. These checks may be performed by other functions and tools based on information derived from groups and stored in auxiliary fields of the newly created categories. + Given a *config* object and sequences of *categories* in a dict, creates all combinations of possible leaf + categories at different depths, connects them with parent - child relations (see :py:class:`order.Category`) and + returns the number of newly created categories. + + *categories* should be a dictionary that maps string names to :py:class:`CategoryGroup` objects which are thin + wrappers around sequences of categories (objects or names). Group names (dictionary keys) are used as keyword + arguments in a callable *name_fn* that is supposed to return the name of newly created categories (see example + below). + + All intermediate layers of categories can be built and connected automatically to one another by parent - child + category relations. The exact behavior is controlled by *parent_mode*: + + - ``"all"``: All intermediate parent category layers are created and connected. Please note that this choice + omits information about group completeness and overlaps (see :py:attr:`CategoryGroup.is_partition`) of child + categories which - in cases such as child category summation - can lead to unintended results. + - ``"none"``: No intermediate parent category layers but only leaf categories are created and connected to their + root categories. + - ``"safe"``: Intermediate parent category layers are created and connected only if the group of child + categories is both complete and non-overlapping (see :py:attr:`CategoryGroup.is_partition`). This is the + recommended choice (and the default) as it avoids unintended results as mentioned in ``"all"``. Each newly created category is instantiated with this name as well as arbitrary keyword arguments as returned by *kwargs_fn*. This function is called with the categories (in a dictionary, mapped to the sequence names as given in @@ -547,6 +583,8 @@ def kwargs_fn(categories): :param categories: Dictionary that maps group names to :py:class:`CategoryGroup` containers. :param name_fn: Callable that receives a dictionary mapping group names to categories and returns the name of the newly created category. + :param parent_mode: Controls how intermediate parent categories are created and connected. Either of ``"all"``, + ``"none"``, or ``"safe"``. :param kwargs_fn: Callable that receives a dictionary mapping group names to categories and returns a dictionary of keyword arguments that are forwarded to the category constructor. :param skip_existing: If *True*, skip the creation of a category when it already exists in *config*. @@ -557,6 +595,12 @@ def kwargs_fn(categories): :raises ValueError: If a non-unique category id is detected. :return: Number of newly created categories. """ + # check parent mode + parent_mode = parent_mode.lower() + known_parent_modes = ["all", "none", "safe"] + if parent_mode not in known_parent_modes: + raise ValueError(f"unknown parent_mode {parent_mode}, known values are {', '.join(known_parent_modes)}") + # cast categories for name, _categories in categories.items(): # ensure CategoryGroup is used @@ -567,6 +611,7 @@ def kwargs_fn(categories): f"using a list to define a sequence of categories for create_category_combinations() is depcreated " f"and will be removed in a future version, please use a CategoryGroup instance instead: {docs_url}", ) + # create a group assuming (!) it describes a full, valid phasespace partition _categories = CategoryGroup( categories=law.util.make_list(_categories), is_complete=True, @@ -582,6 +627,8 @@ def kwargs_fn(categories): unique_ids_cache = {cat.id for cat, _, _ in config.walk_categories()} n_groups = len(categories) group_names = list(categories.keys()) + safe_groups = {name for name, group in categories.items() if group.is_partition} + unsafe_groups = set(group_names) - safe_groups # nothing to do when there are less than 2 groups if n_groups < 2: @@ -593,11 +640,19 @@ def kwargs_fn(categories): if kwargs_fn and not callable(kwargs_fn): raise TypeError(f"when set, kwargs_fn must be a function, but got {kwargs_fn}") - # start combining, considering one additional groups for combinatorics at a time - for _n_groups in range(2, n_groups + 1): + # lookup table with created categories for faster access when connecting parents + created_categories: dict[str, od.Category] = {} + # start combining, considering one additional group for combinatorics at a time + # if skipping parents entirely, only consider the iteration that contains all groups + for _n_groups in ([n_groups] if parent_mode == "none" else range(2, n_groups + 1)): # build all group combinations for _group_names in itertools.combinations(group_names, _n_groups): + # when creating parents in "safe" mode, skip combinations that miss unsafe groups + # (i.e. they must be part of _group_names to be used later) + if parent_mode == "safe": + if (set(group_names) - set(_group_names)) & unsafe_groups: + continue # build the product of all categories for the given groups _categories = [categories[group_name].categories for group_name in _group_names] @@ -623,7 +678,10 @@ def kwargs_fn(categories): # create the new category cat = od.Category(name=cat_name, **kwargs) - n_created_categories += 1 + created_categories[cat_name] = cat + + # add a tag to denote this category was auto-created + cat.add_tag("auto_created_by_combinations") # ID uniqueness check: raise an error when a non-unique id is detected for a new category if isinstance(kwargs["id"], int): @@ -636,19 +694,116 @@ def kwargs_fn(categories): ) unique_ids_cache.add(kwargs["id"]) - # find direct parents and connect them - for _parent_group_names in itertools.combinations(_group_names, _n_groups - 1): + # find combinations of parents and connect them, depending on parent_mode + if parent_mode == "all": + # all direct parents, obtained by combinations with one missing group + parent_gen = itertools.combinations(_group_names, _n_groups - 1) + elif parent_mode == "none": + # only connect to root categories + parent_gen = ((name,) for name in _group_names) + else: # safe + # same as "all", but unsafe groups must be part of the combinations + def _parent_gen(): + seen = set() + # choose 1 group to sum over from _n_groups available + for names in itertools.combinations(_group_names, _n_groups - 1): + # as above, if there is at least one unsafe group missing, the parent was not created + if (set(_group_names) - set(names)) & unsafe_groups: + continue + if names and names not in seen: + seen.add(names) + yield names + # in case no parent combination was yielded, yield all root categories separately + if not seen: + yield from ((name,) for name in _group_names) + parent_gen = _parent_gen() + + # actual connections + for _parent_group_names in parent_gen: + # find the parent if len(_parent_group_names) == 1: - parent_cat_name = root_cats[_parent_group_names[0]].name + parent_cat = root_cats[_parent_group_names[0]] else: parent_cat_name = name_fn({ group_name: root_cats[group_name] for group_name in _parent_group_names }) - parent_cat = config.get_category(parent_cat_name, deep=True) + if parent_cat_name in created_categories: + parent_cat = created_categories[parent_cat_name] + else: + parent_cat = config.get_category(parent_cat_name, deep=True) + # connect parent_cat.add_category(cat) - return n_created_categories + return len(created_categories) + + +def track_category_changes( + config: od.Config, + summary_path: str | None = None, + skip_auto_created: bool = False, +) -> None: + """ + Scans the categories in *config* and saves a summary in a file located at *summary_path*. If the file exists, + the summary from a previous run is loaded first and compare to the current categories. If changes are found, a + warning is shown with details about these changes. + + Categories automatically created via :py:func:`create_category_combinations` can be skipped via *skip_auto_created*. + + :param config: :py:class:`~order.config.Config` instance to scan for categories. + :param summary_path: Path to the summary file. Defaults to "$LAW_HOME/category_summary_{config.name}.json". + :param skip_auto_created: If *True*, categories with the tag "auto_created_by_combinations" are skipped. + """ + # build summary file as law target + if not summary_path: + summary_path = law.config.law_home_path(f"category_summary_{config.name}.json") + summary_file = law.LocalFileTarget(summary_path) + + # gather category info + cat_pairs = sorted( + (cat.name, cat.id) + for cat, *_ in config.walk_categories(include_self=True) + if not skip_auto_created or not cat.has_tag("auto_created_by_combinations") + ) + cat_summary = { + "hash": law.util.create_hash(cat_pairs), + "categories": dict(cat_pairs), + } + + save_summary = True + if summary_file.exists(): + previous_summary = summary_file.load(formatter="json") + if previous_summary["hash"] == cat_summary["hash"]: + save_summary = False + else: + msgs = [ + f"the category definitions in config '{config.name}' seem to have changed based on a hash comparison, " + "ignore this message in case you knowingly adjusted categories fully aware of the changes:", + f"old hash: {previous_summary['hash']}, new hash: {cat_summary['hash']}", + ] + curr = cat_summary["categories"] + prev = previous_summary["categories"] + # track added and removed names + curr_names = set(curr) + prev_names = set(prev) + if (added_names := curr_names - prev_names): + msgs.append(f"added categories : {', '.join(sorted(added_names))}") + if (removed_names := prev_names - curr_names): + msgs.append(f"removed categories : {', '.join(sorted(removed_names))}") + # track id changes for names present in both + changed_ids = { + name: (prev[name], curr[name]) + for name in curr_names & prev_names + if prev[name] != curr[name] + } + if changed_ids: + pair_repr = lambda pair: f"{pair[0]}: {pair[1][0]} -> {pair[1][1]}" + msgs.append("changed category ids:\n - " + "\n - ".join(map(pair_repr, changed_ids.items()))) + + logger.warning_once(f"categories_changed_{config.name}", "\n".join(msgs)) + + if save_summary: + summary_file.dump(cat_summary, formatter="json", indent=4) def verify_config_processes(config: od.Config, warn: bool = False) -> None: diff --git a/columnflow/hist_util.py b/columnflow/hist_util.py index 7f16da17a..4efd3c73d 100644 --- a/columnflow/hist_util.py +++ b/columnflow/hist_util.py @@ -14,11 +14,12 @@ from columnflow.columnar_util import flat_np_view from columnflow.util import maybe_import -from columnflow.types import Any +from columnflow.types import TYPE_CHECKING, Any, Sequence -hist = maybe_import("hist") np = maybe_import("numpy") ak = maybe_import("awkward") +if TYPE_CHECKING: + hist = maybe_import("hist") logger = law.logger.get_logger(__name__) @@ -38,6 +39,8 @@ def fill_hist( determined automatically and depends on the variable axis type. In this case, shifting is applied to all continuous, non-circular axes. """ + import hist + if fill_kwargs is None: fill_kwargs = {} @@ -163,6 +166,8 @@ def get_axis_kwargs(axis: hist.axis.AxesMixin) -> dict[str, Any]: :param axis: The axis instance to extract information from. :return: The extracted information in a dict. """ + import hist + axis_attrs = ["name", "label"] traits_attrs = [] kwargs = {} @@ -213,6 +218,8 @@ def create_hist_from_variables( weight: bool = True, storage: str | None = None, ) -> hist.Hist: + import hist + histogram = hist.Hist.new # additional category axes @@ -259,6 +266,8 @@ def translate_hist_intcat_to_strcat( axis_name: str, id_map: dict[int, str], ) -> hist.Hist: + import hist + out_axes = [ ax if ax.name != axis_name else hist.axis.StrCategory( [id_map[v] for v in list(ax)], @@ -280,6 +289,8 @@ def add_missing_shifts( """ Adds missing shift bins to a histogram *h*. """ + import hist + # get the set of bins that are missing in the histogram shift_bins = set(h.axes[str_axis]) missing_shifts = set(expected_shifts_bins) - shift_bins @@ -295,3 +306,63 @@ def add_missing_shifts( h.fill(*dummy_fill, weight=0) # TODO: this might skip overflow and underflow bins h[{str_axis: hist.loc(missing_shift)}] = nominal.view() + + +def update_ax_labels(hists: list[hist.Hist], config_inst: od.Config, variable_name: str) -> None: + """ + Helper function to update the axis labels of histograms based on variable instances from + the *config_inst*. + + :param hists: List of histograms to update. + :param config_inst: Configuration instance containing variable definitions. + :param variable_name: Name of the variable to update labels for, formatted as a string + with variable names separated by hyphens (e.g., "var1-var2"). + :raises ValueError: If a variable name is not found in the histogram axes. + """ + labels = {} + for var_name in variable_name.split("-"): + var_inst = config_inst.get_variable(var_name, None) + if var_inst: + labels[var_name] = var_inst.x_title + + for h in hists: + for var_name, label in labels.items(): + ax_names = [ax.name for ax in h.axes] + if var_name in ax_names: + h.axes[var_name].label = label + else: + raise ValueError(f"variable '{var_name}' not found in histogram axes: {h.axes}") + + +def sum_hists(hists: Sequence[hist.Hist]) -> hist.Hist: + """ + Sums a sequence of histograms into a new histogram. In case axis labels differ, which typically leads to errors + ("axes not mergable"), the labels of the first histogram are used. + + :param hists: The histograms to sum. + :return: The summed histogram. + """ + hists = list(hists) + if not hists: + raise ValueError("no histograms given for summation") + + # copy the first histogram + h_sum = hists[0].copy() + if len(hists) == 1: + return h_sum + + # store labels of first histogram + axis_labels = {ax.name: ax.label for ax in h_sum.axes} + + for h in hists[1:]: + # align axis labels if needed, only copy if necessary + h_aligned_labels = None + for ax in h.axes: + if ax.name not in axis_labels or ax.label == axis_labels[ax.name]: + continue + if h_aligned_labels is None: + h_aligned_labels = h.copy() + h_aligned_labels.axes[ax.name].label = axis_labels[ax.name] + h_sum = h_sum + (h if h_aligned_labels is None else h_aligned_labels) + + return h_sum diff --git a/columnflow/histogramming/__init__.py b/columnflow/histogramming/__init__.py index 2282f94fb..f8b76ff20 100644 --- a/columnflow/histogramming/__init__.py +++ b/columnflow/histogramming/__init__.py @@ -11,16 +11,15 @@ import law import order as od -from columnflow.types import Callable +from columnflow.production import TaskArrayFunctionWithProducerRequirements from columnflow.util import DerivableMeta, maybe_import -from columnflow.columnar_util import TaskArrayFunction -from columnflow.types import Any +from columnflow.types import TYPE_CHECKING, Any, Callable, Sequence +if TYPE_CHECKING: + hist = maybe_import("hist") -hist = maybe_import("hist") - -class HistProducer(TaskArrayFunction): +class HistProducer(TaskArrayFunctionWithProducerRequirements): """ Base class for all histogram producers, i.e., functions that control the creation of histograms, event weights, and optional post-processing. @@ -58,6 +57,10 @@ class HistProducer(TaskArrayFunction): skip_compatibility_check = False exposed = True + # register attributes for arguments accepted by decorator + mc_only: bool = False + data_only: bool = False + @classmethod def hist_producer( cls, @@ -65,6 +68,7 @@ def hist_producer( bases: tuple = (), mc_only: bool = False, data_only: bool = False, + require_producers: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ @@ -83,6 +87,7 @@ def hist_producer( skipped for real data. :param data_only: Boolean flag indicating that this hist producer should only run on real data and skipped for Monte Carlo simulation. + :param require_producers: Sequence of names of other producers to add to the requirements. :return: New hist producer subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -92,6 +97,7 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, + "require_producers": require_producers, } # get the module name @@ -247,7 +253,7 @@ def run_post_process_hist(self, h: Any, task: law.Task) -> Any: return h return self.post_process_hist_func(h, task=task) - def run_post_process_merged_hist(self, h: Any, task: law.Task) -> hist.Histogram: + def run_post_process_merged_hist(self, h: Any, task: law.Task) -> hist.Hist: """ Invokes the :py:meth:`post_process_merged_hist_func` of this instance and returns its result, forwarding all arguments. diff --git a/columnflow/histogramming/default.py b/columnflow/histogramming/default.py index 8171031ef..c702b0d87 100644 --- a/columnflow/histogramming/default.py +++ b/columnflow/histogramming/default.py @@ -10,14 +10,15 @@ import order as od from columnflow.histogramming import HistProducer, hist_producer -from columnflow.util import maybe_import -from columnflow.hist_util import create_hist_from_variables, fill_hist, translate_hist_intcat_to_strcat from columnflow.columnar_util import has_ak_column, Route -from columnflow.types import Any +from columnflow.hist_util import create_hist_from_variables, fill_hist, translate_hist_intcat_to_strcat +from columnflow.util import maybe_import +from columnflow.types import TYPE_CHECKING, Any np = maybe_import("numpy") ak = maybe_import("awkward") -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") @hist_producer() @@ -39,7 +40,7 @@ def cf_default_create_hist( variables: list[od.Variable], task: law.Task, **kwargs, -) -> hist.Histogram: +) -> hist.Hist: """ Define the histogram structure for the default histogram producer. """ @@ -55,7 +56,7 @@ def cf_default_create_hist( @cf_default.fill_hist -def cf_default_fill_hist(self: HistProducer, h: hist.Histogram, data: dict[str, Any], task: law.Task) -> None: +def cf_default_fill_hist(self: HistProducer, h: hist.Hist, data: dict[str, Any], task: law.Task) -> None: """ Fill the histogram with the data. """ @@ -63,7 +64,7 @@ def cf_default_fill_hist(self: HistProducer, h: hist.Histogram, data: dict[str, @cf_default.post_process_hist -def cf_default_post_process_hist(self: HistProducer, h: hist.Histogram, task: law.Task) -> hist.Histogram: +def cf_default_post_process_hist(self: HistProducer, h: hist.Hist, task: law.Task) -> hist.Hist: """ Post-process the histogram, converting integer to string axis for consistent lookup across configs where ids might be different. diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index d5c3ab01e..0023c1352 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -13,13 +13,14 @@ import order as od import yaml -from columnflow.types import Generator, Callable, TextIO, Sequence, Any -from columnflow.util import DerivableMeta, Derivable, DotDict, is_pattern, is_regex, pattern_matcher, get_docs_url +from columnflow.types import Generator, Callable, TextIO, Sequence, Any, Hashable, Type, T +from columnflow.util import ( + CachedDerivableMeta, Derivable, DotDict, is_pattern, is_regex, pattern_matcher, get_docs_url, freeze, +) logger = law.logger.get_logger(__name__) - default_dataset = law.config.get_expanded("analysis", "default_dataset") @@ -38,16 +39,14 @@ class ParameterType(enum.Enum): rate_unconstrained = "rate_unconstrained" shape = "shape" - def __str__(self: ParameterType) -> str: - """ - Returns the string representation of the parameter type. + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.value}>" - :returns: The string representation of the parameter type. - """ + def __str__(self) -> str: return self.value @property - def is_rate(self: ParameterType) -> bool: + def is_rate(self) -> bool: """ Checks if the parameter type is a rate type. @@ -60,7 +59,7 @@ def is_rate(self: ParameterType) -> bool: } @property - def is_shape(self: ParameterType) -> bool: + def is_shape(self) -> bool: """ Checks if the parameter type is a shape type. @@ -75,35 +74,63 @@ class ParameterTransformation(enum.Enum): """ Flags denoting transformations to be applied on parameters. + Implementation details depend on the routines that apply these transformations, usually as part for a serialization + processes (such as so-called "datacards" in the CMS context). As such, the exact implementation may also differ + depending on the type of the parameter that a transformation is applied to (e.g. shape vs rate). + + The general purpose of each transformation is described below. + :cvar none: No transformation. - :cvar centralize: Centralize the parameter. - :cvar symmetrize: Symmetrize the parameter. - :cvar asymmetrize: Asymmetrize the parameter. - :cvar asymmetrize_if_large: Asymmetrize the parameter if it is large. - :cvar normalize: Normalize the parameter. - :cvar effect_from_shape: Derive effect from shape. - :cvar effect_from_rate: Derive effect from rate. + :cvar effect_from_rate: Creates shape variations for a shape-type parameter using the single- or two-valued effect + usually attributed to rate-type parameters. Only applies to shape-type parameters. + :cvar effect_from_shape: Derive the effect of a rate-type parameter using the overall, integral effect of shape + variations. Only applies to rate-type parameters. + :cvar effect_from_shape_if_flat: Same as :py:attr:`effect_from_shape`, but applies only if both shape variations are + reasonably flat. The definition of "reasonably flat" can be subject to the serialization routine. Only applies + to rate-type parameters. + :cvar symmetrize: The overall (integral) effect of up and down variations is measured and centralized, updating the + variations such that they are equidistant to the nominal one. Can apply to both rate- and shape-type parameters. + :cvar asymmetrize: The symmetric effect on a rate-type parameter (usually given as a single value) is converted into + an asymmetric representation (using two values). Only applies to rate-type parameters. + :cvar asymmetrize_if_large: Same as :py:attr:`asymmetrize`, but depending on a threshold on the size of the + symmetric effect which can be subject to the serialization routine. Only applies to rate-type parameters. + :cvar normalize: Variations of shape-type parameters are changed such that their integral effect identical to the + nominal one. Should only apply to shape-type parameters. + :cvar envelope: Builds an evelope of the up and down variations of a shape-type parameter, potentially on a + bin-by-bin basis. Only applies to shape-type parameters. + :cvar envelope_if_one_sided: Same as :py:attr:`envelope`, but only if the shape variations are one-sided following + a definition that can be subject to the serialization routine. Only applies to shape-type parameters. + :cvar envelope_enforce_two_sided: Same as :py:attr:`envelope`, but it enforces that the up (down) variation of the + constructed envelope is always above (below) the nominal one. Only applies to shape-type parameters. + :cvar flip_smaller_if_one_sided: For asymmetric rate effects (usually given by two values) that are found to be + one-sided (e.g. after applying :py:attr:`effect_from_shape`), flips the smaller effect to the other side of the + nominal value. Only applies to rate-type parameters. + :cvar flip_larger_if_one_sided: Same as :py:attr:`flip_smaller_if_one_sided`, but flips the larger effect. Only + applies to rate-type parameters. """ none = "none" - centralize = "centralize" + effect_from_rate = "effect_from_rate" + effect_from_shape = "effect_from_shape" + effect_from_shape_if_flat = "effect_from_shape_if_flat" symmetrize = "symmetrize" asymmetrize = "asymmetrize" asymmetrize_if_large = "asymmetrize_if_large" normalize = "normalize" - effect_from_shape = "effect_from_shape" - effect_from_rate = "effect_from_rate" + envelope = "envelope" + envelope_if_one_sided = "envelope_if_one_sided" + envelope_enforce_two_sided = "envelope_enforce_two_sided" + flip_smaller_if_one_sided = "flip_smaller_if_one_sided" + flip_larger_if_one_sided = "flip_larger_if_one_sided" - def __str__(self: ParameterTransformation) -> str: - """ - Returns the string representation of the parameter transformation. + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.value}>" - :returns: The string representation of the parameter transformation. - """ + def __str__(self) -> str: return self.value @property - def from_shape(self: ParameterTransformation) -> bool: + def from_shape(self) -> bool: """ Checks if the transformation is derived from shape. @@ -111,10 +138,11 @@ def from_shape(self: ParameterTransformation) -> bool: """ return self in { self.effect_from_shape, + self.effect_from_shape_if_flat, } @property - def from_rate(self: ParameterTransformation) -> bool: + def from_rate(self) -> bool: """ Checks if the transformation is derived from rate. @@ -189,7 +217,15 @@ def __str__(self) -> str: return self.value -class InferenceModel(Derivable): +class InferenceModelMeta(CachedDerivableMeta): + + def _get_inst_cache_key(cls, args: tuple, kwargs: dict) -> Hashable: + config_insts = args[0] if args else kwargs.get("config_insts", []) + config_names = tuple(sorted(config_inst.name for config_inst in config_insts)) + return freeze((cls, config_names, kwargs.get("inst_dict", {}))) + + +class InferenceModel(Derivable, metaclass=InferenceModelMeta): """ Interface to statistical inference models with connections to config objects (such as py:class:`order.Config` or :py:class:`order.Dataset`). @@ -322,11 +358,11 @@ def ignore_aliases(self, *args, **kwargs) -> bool: @classmethod def inference_model( - cls, + cls: T, func: Callable | None = None, bases: tuple[type] = (), **kwargs, - ) -> DerivableMeta | Callable: + ) -> Type[T] | Callable: """ Decorator for creating a new :py:class:`InferenceModel` subclass with additional, optional *bases* and attaching the decorated function to it as ``init_func``. All additional *kwargs* @@ -336,7 +372,7 @@ def inference_model( :param bases: Optional tuple of base classes for the new subclass. :returns: The new subclass or a decorator function. """ - def decorator(func: Callable) -> DerivableMeta: + def decorator(func: Callable) -> Type[T]: # create the class dict cls_dict = { **kwargs, @@ -364,7 +400,7 @@ def model_spec(cls) -> DotDict: Returns a dictionary representing the top-level structure of the model. - *categories*: List of :py:meth:`category_spec` objects. - - *parameter_groups*: List of :py:meth:`paramter_group_spec` objects. + - *parameter_groups*: List of :py:meth:`parameter_group_spec` objects. """ return DotDict([ ("categories", []), @@ -564,40 +600,11 @@ def parameter_config_spec( ("shift_source", str(shift_source) if shift_source else None), ]) - @classmethod - def require_shapes_for_parameter(self, param_obj: dict) -> bool: - """ - Function to check if for a certain parameter object *param_obj* varied - shapes are needed. - - :param param_obj: The parameter object to check. - :returns: *True* if varied shapes are needed, *False* otherwise. - """ - if param_obj.type.is_shape: - # the shape might be build from a rate, in which case input shapes are not required - if param_obj.transformations.any_from_rate: - return False - # in any other case, shapes are required - return True - - if param_obj.type.is_rate: - # when the rate effect is extracted from shapes, they are required - if param_obj.transformations.any_from_shape: - return True - # in any other case, shapes are not required - return False - - # other cases are not supported - raise Exception( - f"shape requirement cannot be evaluated for parameter '{param_obj.name}' with type " + - f"'{param_obj.type}' and transformations {param_obj.transformations}", - ) - - def __init__(self, config_insts: list[od.Config]) -> None: + def __init__(self, config_insts: list[od.Config] | None = None) -> None: super().__init__() # store attributes - self.config_insts = config_insts + self.config_insts = config_insts or [] # temporary attributes for as long as we issue deprecation warnings self.__config_inst = None @@ -1327,8 +1334,8 @@ def add_parameter( for process in _processes: process.parameters.append(_copy.deepcopy(parameter)) - # add to groups - if group: + # add to groups if it was added to at least one process + if group and processes and any(_processes for _processes in processes.values()): self.add_parameter_to_group(parameter.name, group) return parameter @@ -1794,8 +1801,7 @@ def remove_dangling_parameters_from_groups( match_mode: Callable = any, ) -> None: """ - Removes names of parameters from parameter groups that are not assigned to any process in - any category. + Removes names of parameters from parameter groups that are not assigned to any process in any category. :param keep_parameters: A string, pattern, or sequence of them to specify parameters to keep. :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see @@ -1805,9 +1811,7 @@ def remove_dangling_parameters_from_groups( parameter_names = self.get_parameters("*", flat=True) # get set of parameters to keep - _keep_parameters = set() - if keep_parameters: - _keep_parameters = set(self.get_parameters(keep_parameters, match_mode=match_mode, flat=True)) + _keep_parameters = law.util.make_set(keep_parameters) if keep_parameters else set() # go through groups and remove dangling parameters for group in self.parameter_groups: @@ -1816,7 +1820,7 @@ def remove_dangling_parameters_from_groups( for parameter_name in group.parameter_names if ( parameter_name in parameter_names or - (_keep_parameters and parameter_name in _keep_parameters) + law.util.multi_match(parameter_name, _keep_parameters, mode=any) ) ] diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index 032a988ce..373f0875a 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -13,19 +13,21 @@ from columnflow import __version__ as cf_version from columnflow.inference import InferenceModel, ParameterType, ParameterTransformation, FlowStrategy +from columnflow.hist_util import sum_hists from columnflow.util import DotDict, maybe_import, real_path, ensure_dir, safe_div, maybe_int -from columnflow.types import Sequence, Any, Union, Hashable +from columnflow.types import TYPE_CHECKING, Sequence, Any, Union, Hashable -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") + # type aliases for nested histogram structs + ShiftHists = dict[Union[str, tuple[str, str]], hist.Hist] # "nominal" or (param_name, "up|down") -> hists + ConfigHists = dict[str, ShiftHists] # config name -> hists + ProcHists = dict[str, ConfigHists] # process name -> hists + DatacardHists = dict[str, ProcHists] # category name -> hists -logger = law.logger.get_logger(__name__) -# type aliases for nested histogram structs -ShiftHists = dict[Union[str, tuple[str, str]], hist.Hist] # "nominal" or (param_name, "up|down") -> hists -ConfigHists = dict[str, ShiftHists] # config name -> hists -ProcHists = dict[str, ConfigHists] # process name -> hists -DatacardHists = dict[str, ProcHists] # category name -> hists +logger = law.logger.get_logger(__name__) class DatacardWriter(object): @@ -41,17 +43,123 @@ class DatacardWriter(object): At the moment, all shapes are written into the same root file and a shape line with wildcards for both bin and process resolution is created. + + As per the definition in :py:class:`ParameterTransformation`, the following parameter effect transormations are + implemented with the following details. + + - :py:attr:`ParameterTransformation.effect_from_rate`: Creates shape variations from a rate-style effect. + Shape-type parameters only. + - :py:attr:`ParameterTransformation.effect_from_shape`: Converts the integral effect of shape variations to an + asymmetric rate-style effect. Rate-type parameters only. + - :py:attr:`ParameterTransformation.effect_from_shape_if_flat`: Same as above but only applies to cases where + both shape variations are reasonably flat. The flatness per varied shape is determined by two criteria that + both must be met: 1. the maximum relative outlier of bin contents with respect to their mean (defaults to + 20%, configurable via *effect_from_shape_if_flat_max_outlier*), 2. the deviation / dispersion of bin + contents, i.e., the square root of the variance of bin contents, relative to their mean (defaults to 10%, + configurable via *effect_from_shape_if_flat_max_deviation*). The parameter should initially be of rate-type, + but in case the criteria are not met, the effect is interpreted as shape-type. + - :py:attr:`ParameterTransformation.symmetrize`: Changes up and down variations of either rate effects and + shapes to symmetrize them around the nominal value. For rate-type parameters, this has no effect if the + effect strength was provided by a single value. There is no conversion into a single value and consequently, + the result is always a two-valued effect. + - :py:attr:`ParameterTransformation.asymmetrize`: Converts single-valued to two-valued effects for rate-style + parameters. + - :py:attr:`ParameterTransformation.asymmetrize_if_large`: Same as above, with a default threshold of 20%. + Configurable via *asymmetrize_if_large_threshold*. + - :py:attr:`ParameterTransformation.normalize`: Normalizes shape variations such that their integrals match that + of the nominal shape. + - :py:attr:`ParameterTransformation.envelope`: Takes the bin-wise maximum in each direction of the up and down + variations of shape-type parameters and constructs new shapes. + - :py:attr:`ParameterTransformation.envelope_if_one_sided`: Same as above, but only in bins where up and down + contributions are one-sided. + - :py:attr:`ParameterTransformation.envelope_enforce_two_sided`: Same as :py:attr:`envelope`, but it enforces + that the up (down) variation of the constructed envelope is always above (below) the nominal one. + - :py:attr:`ParameterTransformation.flip_smaller_if_one_sided`: For asymmetric (two-valued) rate effects that + are found to be one-sided (e.g. after :py:attr:`ParameterTransformation.effect_from_shape`), flips the + smaller effect to the other side. Rate-type parameters only. + - :py:attr:`ParameterTransformation.flip_larger_if_one_sided`: Same as + :py:attr:`ParameterTransformation.flip_smaller_if_one_sided`, but flips the larger effect. Rate-type + parameters only. + + .. note:: + + If used, the transformations :py:attr:`ParameterTransformation.effect_from_rate`, + :py:attr:`ParameterTransformation.effect_from_shape`, and + :py:attr:`ParameterTransformation.effect_from_shape_if_flat` must be the first element in the sequence of + transformations to be applied. The remaining transformations are applied in order based on the outcome of the + effect conversion. """ # minimum separator between columns col_sep = " " + # specific sets of transformations + first_index_trafos = { + ParameterTransformation.effect_from_rate, + ParameterTransformation.effect_from_shape, + ParameterTransformation.effect_from_shape_if_flat, + } + shape_only_trafos = { + ParameterTransformation.effect_from_rate, + ParameterTransformation.normalize, + ParameterTransformation.envelope, + ParameterTransformation.envelope_if_one_sided, + ParameterTransformation.envelope_enforce_two_sided, + } + rate_only_trafos = { + ParameterTransformation.effect_from_shape, + ParameterTransformation.effect_from_shape_if_flat, + ParameterTransformation.asymmetrize, + ParameterTransformation.asymmetrize_if_large, + ParameterTransformation.flip_smaller_if_one_sided, + ParameterTransformation.flip_larger_if_one_sided, + } + + @classmethod + def validate_model(cls, inference_model_inst: InferenceModel, silent: bool = False) -> bool: + # perform parameter checks one after another, collect errors along the way + errors: list[str] = [] + for cat_name, proc_name, param_obj in inference_model_inst.iter_parameters(): + # check the transformations + _errors: list[str] = [] + for i, trafo in enumerate(param_obj.transformations): + if i != 0 and trafo in cls.first_index_trafos: + _errors.append( + f"parameter transformation '{trafo}' must be the first one to apply, but found at index {i}", + ) + if not param_obj.type.is_shape and trafo in cls.shape_only_trafos: + _errors.append( + f"parameter transformation '{trafo}' only applies to shape-type parameters, but found type " + f"'{param_obj.type}'", + ) + if not param_obj.type.is_rate and trafo in cls.rate_only_trafos: + _errors.append( + f"parameter transformation '{trafo}' only applies to rate-type parameters, but found type " + f"'{param_obj.type}'", + ) + errors.extend( + f"for parameter '{param_obj}' in process '{proc_name}' in category '{cat_name}': {err}" + for err in _errors + ) + + # handle errors + if errors: + if silent: + return False + errors_repr = "\n - ".join(errors) + raise ValueError(f"inference model invalid, reasons:\n - {errors_repr}") + + return True + def __init__( self, inference_model_inst: InferenceModel, histograms: DatacardHists, rate_precision: int = 4, effect_precision: int = 4, + effect_from_shape_if_flat_max_outlier: float = 0.2, + effect_from_shape_if_flat_max_deviation: float = 0.1, + asymmetrize_if_large_threshold: float = 0.2, ) -> None: super().__init__() @@ -60,6 +168,13 @@ def __init__( self.histograms = histograms self.rate_precision = rate_precision self.effect_precision = effect_precision + self.effect_precision = effect_precision + self.effect_from_shape_if_flat_max_outlier = effect_from_shape_if_flat_max_outlier + self.effect_from_shape_if_flat_max_deviation = effect_from_shape_if_flat_max_deviation + self.asymmetrize_if_large_threshold = asymmetrize_if_large_threshold + + # validate the inference model + self.validate_model(self.inference_model_inst) def write( self, @@ -152,10 +267,10 @@ def write( # tabular-style parameters blocks.tabular_parameters = [] for param_name in self.inference_model_inst.get_parameters(flat=True): - param_obj = None + types = set() effects = [] for cat_name, proc_name in flat_rates: - _param_obj = self.inference_model_inst.get_parameter( + param_obj = self.inference_model_inst.get_parameter( param_name, category=cat_name, process=proc_name, @@ -163,83 +278,108 @@ def write( ) # skip line-style parameters as they are handled separately below - if _param_obj and _param_obj.type == ParameterType.rate_unconstrained: + if param_obj and param_obj.type == ParameterType.rate_unconstrained: continue # empty effect - if _param_obj is None: + if param_obj is None: effects.append("-") continue - # compare with previous param_obj - if param_obj is None: - param_obj = _param_obj - elif _param_obj.type != param_obj.type: + # compare with previously seen types as combine cannot mix arbitrary parameter types acting differently + # on different processes + types.add(param_obj.type) + if len(types) > 1 and types != {ParameterType.rate_gauss, ParameterType.shape}: raise ValueError( - f"misconfigured parameter '{param_name}' with type '{_param_obj.type}' that was previously " - f"seen with incompatible type '{param_obj.type}'", + f"misconfigured parameter '{param_name}' with type '{param_obj.type}' that was previously " + f"seen with incompatible type(s) '{types - {param_obj.type}}'", ) # get the effect - effect = _param_obj.effect + effect = param_obj.effect # rounding helper depending on the effect precision effect_precision = ( self.effect_precision - if _param_obj.effect_precision <= 0 - else _param_obj.effect_precision + if param_obj.effect_precision <= 0 + else param_obj.effect_precision ) rnd = lambda f: round(f, effect_precision) # update and transform effects - if _param_obj.type.is_rate: - # obtain from shape effects when requested - if _param_obj.transformations.any_from_shape: - effect = shape_effects[cat_name][proc_name][param_name] - + if param_obj.type.is_rate: # apply transformations one by one - for trafo in _param_obj.transformations: - if trafo == ParameterTransformation.centralize: - # skip symmetric effects - if not isinstance(effect, tuple) and len(effect) != 2: - continue - # skip one sided effects - if not (min(effect) <= 1 <= max(effect)): - continue - d, u = effect - diff = 0.5 * (d + u) - 1.0 - effect = (effect[0] - diff, effect[1] - diff) + for trafo in param_obj.transformations: + if trafo.from_shape: + # take effect from shape variations + effect = shape_effects[cat_name][proc_name][param_name] elif trafo == ParameterTransformation.symmetrize: # skip symmetric effects - if not isinstance(effect, tuple) and len(effect) != 2: + if not isinstance(effect, tuple) or len(effect) != 2: continue # skip one sided effects if not (min(effect) <= 1 <= max(effect)): continue d, u = effect - effect = 0.5 * (u - d) + 1.0 + diff = 0.5 * (d + u) - 1.0 + effect = (effect[0] - diff, effect[1] - diff) - elif trafo == ParameterTransformation.asymmetrize or ( - trafo == ParameterTransformation.asymmetrize_if_large and - isinstance(effect, float) and - abs(effect - 1.0) >= 0.2 + elif ( + trafo == ParameterTransformation.asymmetrize or + ( + trafo == ParameterTransformation.asymmetrize_if_large and + isinstance(effect, float) and + abs(effect - 1.0) >= self.asymmetrize_if_large_threshold + ) ): # skip asymmetric effects if not isinstance(effect, float): continue effect = (2.0 - effect, effect) - elif _param_obj.type.is_shape: - # when the shape was constructed from a rate, reset the effect to 1 - if _param_obj.transformations.any_from_rate: - effect = 1.0 + elif trafo in { + ParameterTransformation.flip_smaller_if_one_sided, + ParameterTransformation.flip_larger_if_one_sided, + }: + # skip symmetric effects + if not isinstance(effect, tuple) or len(effect) != 2: + continue + flip_larger = trafo == ParameterTransformation.flip_larger_if_one_sided + flip_smaller = trafo == ParameterTransformation.flip_smaller_if_one_sided + # check sidedness and determine which of the two effect values to flip, identified by index + if max(effect) < 1.0: + # both below nominal + flip_index = int( + (effect[1] > effect[0] and flip_larger) or + (effect[1] < effect[0] and flip_smaller), + ) + elif min(effect) > 1.0: + # both above nominal + flip_index = int( + (effect[1] > effect[0] and flip_smaller) or + (effect[1] < effect[0] and flip_larger), + ) + else: + # skip one-sided effects + continue + effect = tuple(((2.0 - e) if i == flip_index else e) for i, e in enumerate(effect)) + + elif param_obj.type.is_shape: + # apply transformations one by one + for trafo in param_obj.transformations: + if trafo.from_rate: + # when the shape was constructed from a rate, reset the effect to 1 + effect = 1.0 + + # custom hook to modify the effect + effect = self.modify_parameter_effect(cat_name, proc_name, param_obj, effect) # encode the effect if isinstance(effect, (int, float)): if effect == 0.0: effects.append("-") - elif effect == 1.0 and _param_obj.type.is_shape: + elif effect == 1.0 and param_obj.type.is_shape: effects.append("1") else: effects.append(str(rnd(effect))) @@ -252,14 +392,28 @@ def write( ) # add the tabular line - if param_obj and effects: - type_str = "shape" - if param_obj.type == ParameterType.rate_gauss: - type_str = "lnN" - elif param_obj.type == ParameterType.rate_uniform: - type_str = "lnU" + if types and effects: + type_str = None + if len(types) == 1: + _type = list(types)[0] + if _type == ParameterType.rate_gauss: + type_str = "lnN" + elif _type == ParameterType.rate_uniform: + type_str = "lnU" + elif _type == ParameterType.shape: + type_str = "shape" + elif types == {ParameterType.rate_gauss, ParameterType.shape}: + # when mixing lnN and shape effects, combine expects the "shape?" type and makes the actual decision + # dependend on the presence of shape variations in the accompaying shape files, see + # https://cms-analysis.github.io/HiggsAnalysis-CombinedLimit/v10.2.X/part2/settinguptheanalysis/?h=shape%3F#template-shape-uncertainties # noqa + type_str = "shape?" + if not type_str: + raise ValueError(f"misconfigured parameter '{param_name}' with incompatible type(s) '{types}'") blocks.tabular_parameters.append([param_name, type_str, effects]) + # alphabetical, case-insensitive order by name + blocks.tabular_parameters.sort(key=lambda line: line[0].lower()) + if blocks.tabular_parameters: empty_lines.add("tabular_parameters") @@ -402,12 +556,14 @@ def handle_flow(cat_obj, h, name): # warn in case of flow content if cat_obj.flow_strategy == FlowStrategy.warn: if underflow[0]: - logger.warning( + logger.warning_once( + f"underflow_warn_{self.inference_model_inst.cls_name}_{cat_obj.name}_{name}", f"underflow content detected in category '{cat_obj.name}' for histogram " f"'{name}' ({underflow[0] / view.value.sum() * 100:.1f}% of integral)", ) if overflow[0]: - logger.warning( + logger.warning_once( + f"overflow_warn_{self.inference_model_inst.cls_name}_{cat_obj.name}_{name}", f"overflow content detected in category '{cat_obj.name}' for histogram " f"'{name}' ({overflow[0] / view.value.sum() * 100:.1f}% of integral)", ) @@ -457,13 +613,13 @@ def fill_empty(cat_obj, h): # flat list of hists for configs that contribute to this category hists: list[dict[Hashable, hist.Hist]] = [ hd for config_name, hd in config_hists.items() - if config_name in cat_obj.config_data + if not cat_obj.config_data or config_name in cat_obj.config_data ] if not hists: continue # helper to sum over them for a given shift key and an optional fallback - def sum_hists(key: Hashable, fallback_key: Hashable | None = None) -> hist.Hist: + def get_hist_sum(key: Hashable, fallback_key: Hashable | None = None) -> hist.Hist: def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: if key in hd: return hd[key] @@ -472,29 +628,54 @@ def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: raise Exception( f"'{key}' shape for process '{proc_name}' in category '{cat_name}' misconfigured: {hd}", ) - return sum(map(get, hists[1:]), get(hists[0]).copy()) + return sum_hists(map(get, hists)) + + # helper to extract sum of hists, apply scale, handle flow and fill empty bins + def load( + hist_name: str, + hist_key: Hashable, + fallback_key: Hashable | None = None, + scale: float = 1.0, + ) -> hist.Hist: + h = get_hist_sum(hist_key, fallback_key) * scale + handle_flow(cat_obj, h, hist_name) + fill_empty(cat_obj, h) + return h # get the process scale (usually 1) proc_obj = self.inference_model_inst.get_process(proc_name, category=cat_name) scale = proc_obj.scale # nominal shape - h_nom = sum_hists("nominal") * scale nom_name = nom_pattern.format(category=cat_name, process=proc_name) - fill_empty(cat_obj, h_nom) - handle_flow(cat_obj, h_nom, nom_name) + h_nom = load(nom_name, "nominal", scale=scale) out_file[nom_name] = h_nom _rates[proc_name] = h_nom.sum().value + integral = lambda h: h.sum().value # prepare effects __effects = _effects[proc_name] = OrderedDict() - # go through all parameters and check if varied shapes need to be processed + # go through all parameters and potentially handle varied shapes for _, _, param_obj in self.inference_model_inst.iter_parameters(category=cat_name, process=proc_name): + down_name = syst_pattern.format( + category=cat_name, + process=proc_name, + parameter=param_obj.name, + direction="Down", + ) + up_name = syst_pattern.format( + category=cat_name, + process=proc_name, + parameter=param_obj.name, + direction="Up", + ) + # read or create the varied histograms, or skip the parameter if param_obj.type.is_shape: # the source of the shape depends on the transformation if param_obj.transformations.any_from_rate: + # create the shape from the nominal one and an integral rate effect if isinstance(param_obj.effect, float): f_down, f_up = 2.0 - param_obj.effect, param_obj.effect elif isinstance(param_obj.effect, tuple) and len(param_obj.effect) == 2: @@ -510,26 +691,54 @@ def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: # just extract the shapes h_down = sum_hists((param_obj.name, "down"), "nominal") * scale h_up = sum_hists((param_obj.name, "up"), "nominal") * scale + # just extract the shapes from the inputs + h_down = load(down_name, (param_obj.name, "down"), "nominal", scale=scale) + h_up = load(up_name, (param_obj.name, "up"), "nominal", scale=scale) elif param_obj.type.is_rate: if param_obj.transformations.any_from_shape: # just extract the shapes - h_down = sum_hists((param_obj.name, "down"), "nominal") * scale - h_up = sum_hists((param_obj.name, "up"), "nominal") * scale + h_down = load(down_name, (param_obj.name, "down"), "nominal", scale=scale) + h_up = load(up_name, (param_obj.name, "up"), "nominal", scale=scale) + + # in case the transformation is effect_from_shape_if_flat, and any of the two variations + # do not qualify as "flat", convert the parameter to shape-type and drop all transformations + # that do not apply to shapes + if param_obj.transformations[0] == ParameterTransformation.effect_from_shape_if_flat: + # check if flatness criteria are met + for h in [h_down, h_up]: + values = h.view().value + mean, std = values.mean(), values.std() + rel_deviation = safe_div(std, mean) + max_rel_outlier = safe_div(max(abs(values - mean)), mean) + is_flat = ( + rel_deviation <= self.effect_from_shape_if_flat_max_deviation and + max_rel_outlier <= self.effect_from_shape_if_flat_max_outlier + ) + if not is_flat: + param_obj.type = ParameterType.shape + param_obj.transformations = type(param_obj.transformations)( + trafo for trafo in param_obj.transformations[1:] + if trafo not in self.rate_only_trafos + ) + break else: - # skip the parameter continue - # apply optional transformations - integral = lambda h: h.sum().value + else: + # other effect type that is not handled yet + logger.warning(f"datacard parameter '{param_obj.name}' has unsupported type '{param_obj.type}'") + continue + + # apply optional transformations one by one for trafo in param_obj.transformations: - if trafo == ParameterTransformation.centralize: + if trafo == ParameterTransformation.symmetrize: # get the absolute spread based on integrals n, d, u = integral(h_nom), integral(h_down), integral(h_up) + # skip one sided effects if not (min(d, n) <= n <= max(d, n)): - # skip one sided effects logger.info( - f"skipping shape centralization of parameter '{param_obj.name}' for process " + f"skipping shape symmetrization of parameter '{param_obj.name}' for process " f"'{proc_name}' in category '{cat_name}' as effect is one-sided", ) continue @@ -540,42 +749,74 @@ def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: elif trafo == ParameterTransformation.normalize: # normale varied hists to the nominal integral - h_down *= safe_div(integral(h_nom), integral(h_down)) - h_up *= safe_div(integral(h_nom), integral(h_up)) - - else: - # no other transormation is applied at this point - continue + n, d, u = integral(h_nom), integral(h_down), integral(h_up) + h_down *= safe_div(n, d) + h_up *= safe_div(n, u) + + elif trafo in {ParameterTransformation.envelope, ParameterTransformation.envelope_if_one_sided}: + d, u = integral(h_down), integral(h_up) + v_nom = h_nom.view() + v_down = h_down.view() + v_up = h_up.view() + # compute masks denoting at which locations a variation is abs larger than the other + diffs_up = v_up.value - v_nom.value + diffs_down = v_down.value - v_nom.value + up_mask = abs(diffs_up) > abs(diffs_down) + down_mask = abs(diffs_down) > abs(diffs_up) + # when only checking one-sided, remove True's from the masks where variations are two-sided + if trafo == ParameterTransformation.envelope_if_one_sided: + one_sided = (diffs_up * diffs_down) > 0 + up_mask &= one_sided + down_mask &= one_sided + # fill values from the larger variation + v_up.value[down_mask] = v_nom.value[down_mask] - diffs_down[down_mask] + v_up.variance[down_mask] = v_down.variance[down_mask] + v_down.value[up_mask] = v_nom.value[up_mask] - diffs_up[up_mask] + v_down.variance[up_mask] = v_up.variance[up_mask] + + elif trafo == ParameterTransformation.envelope_enforce_two_sided: + # envelope creation with enforced two-sidedness + v_nom = h_nom.view() + v_down = h_down.view() + v_up = h_up.view() + # compute masks denoting at which locations a variation is abs larger than the other + abs_diffs_up = abs(v_up.value - v_nom.value) + abs_diffs_down = abs(v_down.value - v_nom.value) + up_mask = abs_diffs_up >= abs_diffs_down + down_mask = ~up_mask + # fill values from the absolute larger variation + v_up.value[up_mask] = v_nom.value[up_mask] + abs_diffs_up[up_mask] + v_up.value[down_mask] = v_nom.value[down_mask] + abs_diffs_down[down_mask] + v_up.variance[down_mask] = v_down.variance[down_mask] + v_down.value[down_mask] = v_nom.value[down_mask] - abs_diffs_down[down_mask] + v_down.value[up_mask] = v_nom.value[up_mask] - abs_diffs_up[up_mask] + v_down.variance[up_mask] = v_up.variance[up_mask] + + # custom hook to modify the shapes + h_nom, h_down, h_up = self.modify_parameter_shape( + cat_name, + proc_name, + param_obj, + h_nom, + h_down, + h_up, + ) - # empty bins are always filled + # fill empty bins again after all transformations fill_empty(cat_obj, h_down) fill_empty(cat_obj, h_up) - # save them when they represent real shapes - if param_obj.type.is_shape: - down_name = syst_pattern.format( - category=cat_name, - process=proc_name, - parameter=param_obj.name, - direction="Down", - ) - up_name = syst_pattern.format( - category=cat_name, - process=proc_name, - parameter=param_obj.name, - direction="Up", - ) - handle_flow(cat_obj, h_down, down_name) - handle_flow(cat_obj, h_up, up_name) - out_file[down_name] = h_down - out_file[up_name] = h_up - # save the effect __effects[param_obj.name] = ( safe_div(integral(h_down), integral(h_nom)), safe_div(integral(h_up), integral(h_nom)), ) + # save them to file if they have shape-type + if param_obj.type.is_shape: + out_file[down_name] = h_down + out_file[up_name] = h_up + # data handling, first checking if data should be faked, then if real data exists if cat_obj.data_from_processes: # fake data from processes @@ -588,31 +829,36 @@ def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: if not h_data: proc_str = ",".join(map(str, cat_obj.data_from_processes)) raise Exception(f"none of requested processes '{proc_str}' found to create fake data") - h_data = sum(h_data[1:], h_data[0].copy()) data_name = data_pattern.format(category=cat_name) - fill_empty(cat_obj, h_data) + h_data = sum_hists(h_data) handle_flow(cat_obj, h_data, data_name) + h_data.view().variance = h_data.view().value out_file[data_name] = h_data _rates["data"] = float(h_data.sum().value) - elif any(cd.data_datasets for cd in cat_obj.config_data.values()): + elif proc_hists.get("data"): + # real data h_data = [] - for config_name, config_data in cat_obj.config_data.items(): - if "data" not in proc_hists or config_name not in proc_hists["data"]: + for config_name, config_hists in proc_hists["data"].items(): + if cat_obj.config_data and config_name not in cat_obj.config_data: raise Exception( - f"the inference model '{self.inference_model_inst.cls_name}' is configured to use real " - f"data for config '{config_name}' in category '{cat_name}' but no histogram received at " - f"entry ['data']['{config_name}']: {proc_hists}", + f"received real data in datacard category '{cat_name}' for config '{config_name}', but the " + f"inference model '{self.inference_model_inst.cls_name}' is not configured to use it in " + f"the config_data for that config; configured config_names are " + f"'{','.join(cat_obj.config_data.keys())}'", ) - h_data.append(proc_hists["data"][config_name]["nominal"]) + h_data.append(config_hists["nominal"]) # simply save the data histogram that was already built from the requested datasets - h_data = sum(h_data[1:], h_data[0].copy()) + h_data = sum_hists(h_data) data_name = data_pattern.format(category=cat_name) handle_flow(cat_obj, h_data, data_name) out_file[data_name] = h_data _rates["data"] = h_data.sum().value + else: + logger.warning(f"neither real data found nor fake data created in category '{cat_name}'") + return (rates, effects, nom_pattern_comb, syst_pattern_comb) @classmethod @@ -675,3 +921,45 @@ def align_rates_and_parameters( lines = cls.align_lines(rates + parameters) return lines[:n_rate_lines], lines[n_rate_lines:] + + def modify_parameter_effect( + self, + category: str, + process: str, + param_obj: DotDict, + effect: float | tuple[float, float], + ) -> float | tuple[float, float]: + """ + Custom hook to modify the effect of a parameter on a given category and process before it is encoded into the + datacard. By default, this does nothing and simply returns the given effect. + + :param category: The category name. + :param process: The process name. + :param param_obj: The parameter object, following :py:meth:`columnflow.inference.InferenceModel.parameter_spec`. + :param effect: The effect value(s) to be modified. + :returns: The modified effect value(s). + """ + return effect + + def modify_parameter_shape( + self, + category: str, + process: str, + param_obj: DotDict, + h_nom: hist.Hist, + h_down: hist.Hist, + h_up: hist.Hist, + ) -> tuple[hist.Hist, hist.Hist, hist.Hist]: + """ + Custom hook to modify the nominal and varied (down, up) shapes of a parameter on a given category and process + before they are saved to the shapes file. By default, this does nothing and simply returns the given histograms. + + :param category: The category name. + :param process: The process name. + :param param_obj: The parameter object, following :py:meth:`columnflow.inference.InferenceModel.parameter_spec`. + :param h_nom: The nominal histogram. + :param h_down: The down-varied histogram. + :param h_up: The up-varied histogram. + :returns: The modified nominal and varied (down, up) histograms. + """ + return h_nom, h_down, h_up diff --git a/columnflow/ml/__init__.py b/columnflow/ml/__init__.py index 43419ea86..e50b22bf9 100644 --- a/columnflow/ml/__init__.py +++ b/columnflow/ml/__init__.py @@ -12,11 +12,12 @@ import law import order as od -from columnflow.types import Any, Sequence -from columnflow.util import maybe_import, Derivable, DotDict, KeyValueMessage from columnflow.columnar_util import Route +from columnflow.util import maybe_import, Derivable, DotDict, KeyValueMessage +from columnflow.types import TYPE_CHECKING, Any, Sequence -ak = maybe_import("awkward") +if TYPE_CHECKING: + ak = maybe_import("awkward") class MLModel(Derivable): diff --git a/columnflow/plotting/cmsGhent/plot_functions_1d.py b/columnflow/plotting/cmsGhent/plot_functions_1d.py index 589643ca1..e6e9d4b2d 100644 --- a/columnflow/plotting/cmsGhent/plot_functions_1d.py +++ b/columnflow/plotting/cmsGhent/plot_functions_1d.py @@ -1,5 +1,7 @@ from __future__ import annotations +import math + import order as od import law from collections import OrderedDict @@ -14,13 +16,12 @@ from columnflow.plotting.plot_all import plot_all from columnflow.plotting.cmsGhent.plot_util import cumulate +from columnflow.types import TYPE_CHECKING -plt = maybe_import("matplotlib.pyplot") np = maybe_import("numpy") -mtrans = maybe_import("matplotlib.transforms") -mplhep = maybe_import("mplhep") -math = maybe_import("math") -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") def plot_multi_variables( @@ -252,6 +253,8 @@ def plot_1d_line( """ TODO. """ + import hist + n_bins = math.prod([v.n_bins for v in variable_insts]) def flatten_data(data: hist.Hist | np.ndarray): diff --git a/columnflow/plotting/cmsGhent/plot_functions_2d.py b/columnflow/plotting/cmsGhent/plot_functions_2d.py index ed58a89c9..98ac2551a 100644 --- a/columnflow/plotting/cmsGhent/plot_functions_2d.py +++ b/columnflow/plotting/cmsGhent/plot_functions_2d.py @@ -1,15 +1,12 @@ +import order as od import law from collections import OrderedDict from columnflow.util import maybe_import +from unittest.mock import patch +from functools import partial -plt = maybe_import("matplotlib.pyplot") np = maybe_import("numpy") -od = maybe_import("order") -mtrans = maybe_import("matplotlib.transforms") -mplhep = maybe_import("mplhep") -hist = maybe_import("hist") -from columnflow.plotting.plot_all import make_plot_2d from columnflow.plotting.plot_util import ( apply_variable_settings, remove_residual_axis, @@ -23,6 +20,7 @@ def merge_migration_bins(h): """ binning both axes in equal bins """ + import hist x_edges = h.axes[0].edges y_edges = h.axes[1].edges @@ -97,6 +95,10 @@ def plot_migration_matrices( keep_bins_in_bkg: bool = False, **kwargs, ): + import mplhep + import matplotlib.transforms as mtrans + import matplotlib.pyplot as plt + plt.style.use(mplhep.style.CMS) fig, axes = plt.subplots( 2, 3, @@ -154,16 +156,44 @@ def plot_migration_matrices( style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) + # # make main central migration plot - make_plot_2d(plot_config, style_config, figaxes=(fig, axes[0, 1])) + # + + central_ax = axes[0, 1] + + # apply style_config + if ax_cfg := style_config.get("ax_cfg", {}): + for tickname in ["xticks", "yticks"]: + ticks = ax_cfg.pop(tickname) + for ticksize in ["major", "minor"]: + if subticks := ticks.get(ticksize, {}): + getattr(central_ax, "set_" + tickname)(**subticks, minor=ticksize == "minor") + central_ax.set(**ax_cfg) + + if "legend_cfg" in style_config: + central_ax.legend(**style_config["legend_cfg"]) + + # annotation of category label + if annotate_kwargs := style_config.get("annotate_cfg", {}): + central_ax.annotate(**annotate_kwargs) + + if cms_label_kwargs := style_config.get("cms_label_cfg", {}): + mplhep.cms.label(ax=central_ax, **cms_label_kwargs) + + # call plot method, patching the colorbar function + # called internally by mplhep to draw the extension symbols + with patch.object(plt, "colorbar", partial(plt.colorbar, **plot_config.get("cbar_kwargs", {}))): + plot_config["hist"].plot2d(ax=central_ax, **plot_config.get("kwargs", {})) + if label_numbers: for i, x in enumerate(migrations_eq_ax.axes[0].centers): for j, y in enumerate(migrations_eq_ax.axes[1].centers): if abs(i - j) <= 1: lbl = f"{migrations_eq_ax.values()[i, j] * 100:.0f}" - axes[0, 1].text(x, y, lbl, ha="center", va="center", size="large") + central_ax.text(x, y, lbl, ha="center", va="center", size="large") - cbar = plt.colorbar(axes[0, 1].collections[0], **plot_config["cbar_kwargs"]) + cbar = plt.colorbar(central_ax.collections[0], **plot_config["cbar_kwargs"]) fix_cbar_minor_ticks(cbar) # set cbar range diff --git a/columnflow/plotting/cmsGhent/plot_util.py b/columnflow/plotting/cmsGhent/plot_util.py index 08b11643c..f3e83d8ce 100644 --- a/columnflow/plotting/cmsGhent/plot_util.py +++ b/columnflow/plotting/cmsGhent/plot_util.py @@ -1,12 +1,16 @@ from __future__ import annotations import order as od from columnflow.util import maybe_import +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") np = maybe_import("numpy") +if TYPE_CHECKING: + hist = maybe_import("hist") def cumulate(h: np.ndarray | hist.Hist, direction="below", axis: str | int | od.Variable = 0): + import hist + idx_slice = np.s_[::-1] if direction == "above" else np.s_[:] arr = h if isinstance(h, np.ndarray) else h.view(flow=False) if isinstance(axis, od.Variable): diff --git a/columnflow/plotting/cmsGhent/unrolled.py b/columnflow/plotting/cmsGhent/unrolled.py index 746f2bd0c..10b08e58e 100644 --- a/columnflow/plotting/cmsGhent/unrolled.py +++ b/columnflow/plotting/cmsGhent/unrolled.py @@ -33,6 +33,7 @@ from collections import OrderedDict import law +import order as od from columnflow.util import maybe_import from columnflow.plotting.plot_all import ( @@ -48,25 +49,25 @@ get_cms_label, get_position, ) +from columnflow.types import TYPE_CHECKING - -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") -od = maybe_import("order") -mticker = maybe_import("matplotlib.ticker") -colorsys = maybe_import("colorsys") +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") def change_saturation(hls, saturation_factor): + import colorsys + # Convert back to RGB new_rgb = colorsys.hls_to_rgb(hls[0], hls[1], saturation_factor) return new_rgb def get_new_colors(original_color, n_new_colors=2): + import colorsys + # Convert RGB to HLS hls = colorsys.rgb_to_hls(*original_color) @@ -178,6 +179,9 @@ def plot_unrolled( variable_settings: dict | None = None, **kwargs, ) -> plt.Figure: + import mplhep + import matplotlib as mpl + import matplotlib.pyplot as plt # remove shift axis from histograms if len(shift_insts) == 1: diff --git a/columnflow/plotting/plot_all.py b/columnflow/plotting/plot_all.py index 685f4627f..7f8d67ad6 100644 --- a/columnflow/plotting/plot_all.py +++ b/columnflow/plotting/plot_all.py @@ -10,7 +10,6 @@ import order as od -from columnflow.types import Sequence from columnflow.util import maybe_import, try_float from columnflow.config_util import group_shifts from columnflow.plotting.plot_util import ( @@ -21,12 +20,12 @@ apply_label_placeholders, calculate_stat_error, ) +from columnflow.types import TYPE_CHECKING, Sequence -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") def draw_stat_error_bands( @@ -35,6 +34,8 @@ def draw_stat_error_bands( norm: float | Sequence | np.ndarray = 1.0, **kwargs, ) -> None: + import hist + assert len(h.axes) == 1 # compute relative statistical errors @@ -71,6 +72,8 @@ def draw_syst_error_bands( method: str = "quadratic_sum", **kwargs, ) -> None: + import hist + assert len(h.axes) == 1 assert method in ("quadratic_sum", "envelope") @@ -80,13 +83,14 @@ def draw_syst_error_bands( # create pairs of shifts mapping from up -> down and vice versa shift_pairs = {} + shift_pairs[nominal_shift] = nominal_shift # nominal shift maps to itself for up_shift, down_shift in shift_groups.values(): shift_pairs[up_shift] = down_shift shift_pairs[down_shift] = up_shift # stack histograms separately per shift, falling back to the nominal one when missing shift_stacks: dict[od.Shift, hist.Hist] = {} - for shift_inst in sum(shift_groups.values(), []): + for shift_inst in sum(shift_groups.values(), [nominal_shift]): for _h in syst_hists: # when the shift is present, the flipped shift must exist as well shift_ax = _h.axes["shift"] @@ -119,8 +123,8 @@ def draw_syst_error_bands( down_diffs = [] for source, (up_shift, down_shift) in shift_groups.items(): # get actual differences resulting from this shift - shift_up_diff = shift_stacks[up_shift].values()[b] - h.values()[b] - shift_down_diff = shift_stacks[down_shift].values()[b] - h.values()[b] + shift_up_diff = shift_stacks[up_shift].values()[b] - shift_stacks[nominal_shift].values()[b] + shift_down_diff = shift_stacks[down_shift].values()[b] - shift_stacks[nominal_shift].values()[b] # store them depending on whether they really increase or decrease the yield up_diffs.append(max(shift_up_diff, shift_down_diff, 0)) down_diffs.append(min(shift_up_diff, shift_down_diff, 0)) @@ -168,6 +172,8 @@ def draw_stack( norm: float | Sequence | np.ndarray = 1.0, **kwargs, ) -> None: + import hist + # check if norm is a number if try_float(norm): h = hist.Stack(*[i / norm for i in h]) @@ -201,6 +207,8 @@ def draw_hist( error_type: str = "variance", **kwargs, ) -> None: + import hist + assert error_type in {"variance", "poisson_unweighted", "poisson_weighted"} if kwargs.get("color", "") is None: @@ -242,6 +250,8 @@ def draw_profile( """ Profiled histograms contains the storage type "Mean" and can therefore not be normalized """ + import hist + assert error_type in {"variance", "poisson_unweighted", "poisson_weighted"} if kwargs.get("color", "") is None: @@ -271,6 +281,8 @@ def draw_errorbars( error_type: str = "poisson_unweighted", **kwargs, ) -> None: + import hist + assert error_type in {"variance", "poisson_unweighted", "poisson_weighted"} values = h.values() / norm @@ -341,20 +353,31 @@ def plot_all( :param magnitudes: Optional float parameter that defines the displayed ymin when plotting with a logarithmic scale. :return: tuple of plot figure and axes """ + import matplotlib as mpl + import matplotlib.pyplot as plt + import mplhep + # general mplhep style plt.style.use(mplhep.style.CMS) + # use non-interactive Agg backend for plotting + mpl.use("Agg") + # setup figure and axes rax = None grid_spec = {"left": 0.15, "right": 0.95, "top": 0.95, "bottom": 0.1} grid_spec |= style_config.get("gridspec_cfg", {}) + + # Get figure size from style_config, with default values + subplots_cfg = style_config.get("subplots_cfg", {}) + if not skip_ratio: grid_spec = {"height_ratios": [3, 1], "hspace": 0, **grid_spec} - fig, axs = plt.subplots(2, 1, gridspec_kw=grid_spec, sharex=True) + fig, axs = plt.subplots(2, 1, gridspec_kw=grid_spec, sharex=True, **subplots_cfg) (ax, rax) = axs else: grid_spec.pop("height_ratios", None) - fig, ax = plt.subplots(gridspec_kw=grid_spec) + fig, ax = plt.subplots(gridspec_kw=grid_spec, **subplots_cfg) axs = (ax,) # invoke all plots methods diff --git a/columnflow/plotting/plot_functions_1d.py b/columnflow/plotting/plot_functions_1d.py index 0ec3f3730..34b6d02a7 100644 --- a/columnflow/plotting/plot_functions_1d.py +++ b/columnflow/plotting/plot_functions_1d.py @@ -11,8 +11,8 @@ from collections import OrderedDict import law +import order as od -from columnflow.types import Iterable from columnflow.util import maybe_import from columnflow.plotting.plot_all import plot_all from columnflow.plotting.plot_util import ( @@ -27,17 +27,16 @@ get_position, get_profile_variations, blind_sensitive_bins, + remove_negative_contributions, join_labels, ) -from columnflow.hist_util import add_missing_shifts +from columnflow.hist_util import add_missing_shifts, sum_hists +from columnflow.types import TYPE_CHECKING, Iterable - -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") -od = maybe_import("order") +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") def plot_variable_stack( @@ -60,19 +59,24 @@ def plot_variable_stack( hists, process_style_config = apply_process_settings(hists, process_settings) # variable-based settings (rebinning, slicing, flow handling) hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings) - # process scaling - hists = apply_process_scaling(hists) # remove data in bins where sensitivity exceeds some threshold blinding_threshold = kwargs.get("blinding_threshold", None) if blinding_threshold: hists = blind_sensitive_bins(hists, config_inst, blinding_threshold) + + # remove negative contributions per process if requested + if kwargs.get("remove_negative", None): + hists = remove_negative_contributions(hists) + + # process scaling + hists = apply_process_scaling(hists) # density scaling per bin if density: hists = apply_density(hists, density) if len(shift_insts) == 1: # when there is exactly one shift bin, we can remove the shift axis - hists = remove_residual_axis(hists, "shift", select_value=shift_insts[0].name) + hists = remove_residual_axis(hists, "shift") else: # remove shift axis of histograms that are not to be stacked unstacked_hists = { @@ -99,6 +103,9 @@ def plot_variable_stack( shape_norm, yscale, ) + # additional, plot function specific changes + if shape_norm: + default_style_config["ax_cfg"]["ylabel"] = "Normalized entries" style_config = law.util.merge_dicts( default_style_config, process_style_config, @@ -107,13 +114,51 @@ def plot_variable_stack( deep=True, ) - # additional, plot function specific changes - if shape_norm: - style_config["ax_cfg"]["ylabel"] = "Normalized entries" - return plot_all(plot_config, style_config, **kwargs) +def plot_variable_efficiency( + hists: OrderedDict, + config_inst: od.Config, + category_inst: od.Category, + variable_insts: list[od.Variable], + shift_insts: list[od.Shift] | None, + style_config: dict | None = None, + shape_norm: bool = True, + cumsum_reverse: bool = True, + **kwargs, +): + """ + This plot function allows users to plot the efficiency of a cut on a variable as a function of the cut value. + Per default, each bin shows the efficiency of requiring value >= bin edge (cumsum_reverse=True). + Setting cumsum_reverse=False will instead show the efficiency of requiring value <= bin edge. + """ + for proc_inst, proc_hist in hists.items(): + if cumsum_reverse: + proc_hist.values()[...] = np.cumsum(proc_hist.values()[..., ::-1], axis=-1)[..., ::-1] + shape_norm_func = kwargs.get("shape_norm_func", lambda h, shape_norm: h.values()[0] if shape_norm else 1) + else: + proc_hist.values()[...] = np.cumsum(proc_hist.values(), axis=-1) + shape_norm_func = kwargs.get("shape_norm_func", lambda h, shape_norm: h.values()[-1] if shape_norm else 1) + + default_style_config = { + "ax_cfg": {"ylabel": "Efficiency" if shape_norm else "Cumulative entries"}, + } + style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) + + return plot_variable_stack( + hists, + config_inst, + category_inst, + variable_insts, + shift_insts, + shape_norm=shape_norm, + shape_norm_func=shape_norm_func, + style_config=style_config, + **kwargs, + ) + + def plot_variable_variants( hists: OrderedDict, config_inst: od.Config, @@ -130,11 +175,12 @@ def plot_variable_variants( """ TODO. """ - hists = remove_residual_axis(hists, "shift") variable_inst = variable_insts[0] - hists = apply_variable_settings(hists, variable_insts, variable_settings)[0] + hists = apply_variable_settings(hists, variable_insts, variable_settings) + if kwargs.get("remove_negative", None): + hists = remove_negative_contributions(hists) if density: hists = apply_density(hists, density) @@ -201,10 +247,14 @@ def plot_shifted_variable( """ TODO. """ + import hist + variable_inst = variable_insts[0] hists, process_style_config = apply_process_settings(hists, process_settings) hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings) + if kwargs.get("remove_negative", None): + hists = remove_negative_contributions(hists) hists = apply_process_scaling(hists) if density: hists = apply_density(hists, density) @@ -215,7 +265,7 @@ def plot_shifted_variable( add_missing_shifts(h, all_shifts, str_axis="shift", nominal_bin="nominal") # create the sum of histograms over all processes - h_sum = sum(list(hists.values())[1:], list(hists.values())[0].copy()) + h_sum = sum_hists(hists.values()) # setup plotting configs plot_config = {} @@ -275,6 +325,8 @@ def plot_shifted_variable( default_style_config["rax_cfg"]["ylabel"] = "Ratio" if legend_title: default_style_config["legend_cfg"]["title"] = legend_title + if shape_norm: + style_config["ax_cfg"]["ylabel"] = "Normalized entries" style_config = law.util.merge_dicts( default_style_config, process_style_config, @@ -282,8 +334,6 @@ def plot_shifted_variable( style_config, deep=True, ) - if shape_norm: - style_config["ax_cfg"]["ylabel"] = "Normalized entries" return plot_all(plot_config, style_config, **kwargs) @@ -351,7 +401,7 @@ def plot_cutflow( }, "annotate_cfg": {"text": cat_label or ""}, "cms_label_cfg": { - "lumi": round(0.001 * config_inst.x.luminosity.get("nominal"), 2), # /pb -> /fb + "lumi": round(0.001 * config_inst.x.luminosity.get("nominal"), 1), # /pb -> /fb "com": config_inst.campaign.ecm, }, } @@ -401,6 +451,8 @@ def plot_profile( :param base_distribution_yscale: yscale of the base distributions :param skip_variations: whether to skip adding the up and down variation of the profile plot """ + import matplotlib.pyplot as plt + if len(variable_insts) != 2: raise Exception("The plot_profile function can only be used for 2-dimensional input histograms.") @@ -409,6 +461,8 @@ def plot_profile( hists, process_style_config = apply_process_settings(hists, process_settings) hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings) + if kwargs.get("remove_negative", None): + hists = remove_negative_contributions(hists) hists = apply_process_scaling(hists) if density: hists = apply_density(hists, density) diff --git a/columnflow/plotting/plot_functions_2d.py b/columnflow/plotting/plot_functions_2d.py index c611f13f4..2009586fe 100644 --- a/columnflow/plotting/plot_functions_2d.py +++ b/columnflow/plotting/plot_functions_2d.py @@ -6,13 +6,17 @@ from __future__ import annotations +__all__ = [] + from collections import OrderedDict from functools import partial from unittest.mock import patch import law +import order as od from columnflow.util import maybe_import +from columnflow.hist_util import sum_hists from columnflow.plotting.plot_util import ( remove_residual_axis, apply_variable_settings, @@ -22,14 +26,11 @@ get_position, reduce_with, ) +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") -od = maybe_import("order") -mticker = maybe_import("matplotlib.ticker") +if TYPE_CHECKING: + plt = maybe_import("matplotlib.pyplot") def plot_2d( @@ -55,6 +56,10 @@ def plot_2d( variable_settings: dict | None = None, **kwargs, ) -> plt.Figure: + import matplotlib as mpl + import matplotlib.pyplot as plt + import mplhep + # remove shift axis from histograms hists = remove_residual_axis(hists, "shift") @@ -77,7 +82,7 @@ def plot_2d( extremes = "color" # add all processes into 1 histogram - h_sum = sum(list(hists.values())[1:], list(hists.values())[0].copy()) + h_sum = sum_hists(hists.values()) if shape_norm: h_sum = h_sum / h_sum.sum().value @@ -168,7 +173,8 @@ def plot_2d( "loc": "upper right", }, "cms_label_cfg": { - "lumi": round(0.001 * config_inst.x.luminosity.get("nominal"), 2), # /pb -> /fb + "lumi": round(0.001 * config_inst.x.luminosity.get("nominal"), 1), # /pb -> /fb + "com": config_inst.campaign.ecm, }, "plot2d_cfg": { "norm": cbar_norm, @@ -272,10 +278,10 @@ def plot_2d( _scale = cbar.ax.yaxis._scale _scale.subs = [2, 3, 4, 5, 6, 7, 8, 9] cbar.ax.yaxis.set_minor_locator( - mticker.SymmetricalLogLocator(_scale.get_transform(), subs=_scale.subs), + mpl.ticker.SymmetricalLogLocator(_scale.get_transform(), subs=_scale.subs), ) cbar.ax.yaxis.set_minor_formatter( - mticker.LogFormatterSciNotation(_scale.base), + mpl.ticker.LogFormatterSciNotation(_scale.base), ) plt.tight_layout() diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 26c4fe6cc..ec5e951c4 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -6,42 +6,45 @@ from __future__ import annotations +__all__ = [] + import re -from columnflow.types import Sequence +import order as od +import scinum + from columnflow.util import maybe_import from columnflow.plotting.plot_util import get_cms_label +from columnflow.types import TYPE_CHECKING, Sequence -ak = maybe_import("awkward") -od = maybe_import("order") np = maybe_import("numpy") -sci = maybe_import("scinum") -plt = maybe_import("matplotlib.pyplot") -hep = maybe_import("mplhep") -colors = maybe_import("matplotlib.colors") +ak = maybe_import("awkward") +if TYPE_CHECKING: + plt = maybe_import("matplotlib.pyplot") + # define a CF custom color maps cf_colors = { - "cf_green_cmap": colors.ListedColormap([ + "cf_green_cmap": [ "#212121", "#242723", "#262D25", "#283426", "#2A3A26", "#2C4227", "#2E4927", "#305126", "#325A25", "#356224", "#386B22", "#3B7520", "#3F7F1E", "#43891B", "#479418", "#4C9F14", "#52AA10", "#58B60C", "#5FC207", "#67cf02", - ]), - "cf_ygb_cmap": colors.ListedColormap([ + ], + "cf_ygb_cmap": [ "#003675", "#005B83", "#008490", "#009A83", "#00A368", "#00AC49", "#00B428", "#00BC06", "#0CC300", "#39C900", "#67cf02", "#72DB02", "#7EE605", "#8DF207", "#9CFD09", "#AEFF0B", "#C1FF0E", "#D5FF10", "#EBFF12", "#FFFF14", - ]), - "cf_cmap": colors.ListedColormap([ + ], + "cf_cmap": [ "#002C9C", "#00419F", "#0056A2", "#006BA4", "#0081A7", "#0098AA", "#00ADAB", "#00B099", "#00B287", "#00B574", "#00B860", "#00BB4C", "#00BD38", "#00C023", "#00C20D", "#06C500", "#1EC800", "#36CA00", "#4ECD01", "#67cf02", - ]), - "viridis": colors.ListedColormap([ + ], + "viridis": [ "#263DA8", "#1652CC", "#1063DB", "#1171D8", "#1380D5", "#0E8ED0", "#089DCC", "#0DA7C2", "#1DAFB3", "#2DB7A3", "#52BA91", "#73BD80", "#94BE71", "#B2BC65", "#D0BA59", "#E1BF4A", "#F4C53A", "#FCD12B", "#FAE61C", "#F9F90E", - ]), + ], } @@ -111,6 +114,10 @@ def plot_cm( is not *None* and its shape doesn't match *predictions*. :raises AssertionError: If *normalization* is not one of *None*, "row", "column". """ + import matplotlib as mpl + import matplotlib.pyplot as plt + import mplhep + # defining some useful properties and output shapes true_labels = list(events.keys()) pred_labels = [s.removeprefix("score_") for s in list(events.values())[0].fields] @@ -136,7 +143,7 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: counts[ind, index] += count if not skip_uncertainties: - vecNumber = np.vectorize(lambda n, count: sci.Number(n, float(n / np.sqrt(count) if count else 0))) + vecNumber = np.vectorize(lambda n, count: scinum.Number(n, float(n / np.sqrt(count) if count else 0))) result = vecNumber(result, counts) # normalize Matrix if needed @@ -203,7 +210,7 @@ def get_errors(matrix): Useful for seperating the error from the data """ if matrix.dtype.name == "object": - get_errors_vec = np.vectorize(lambda x: x.get(sci.UP, unc=True)) + get_errors_vec = np.vectorize(lambda x: x.get(scinum.UP, unc=True)) return get_errors_vec(matrix) return np.zeros_like(matrix) @@ -219,13 +226,13 @@ def fmt(v): return "{}\n\u00B1{}".format(fmt(values[i][j]), fmt(np.nan_to_num(uncs[i][j]))) # create the plot - plt.style.use(hep.style.CMS) + plt.style.use(mplhep.style.CMS) fig, ax = plt.subplots(dpi=300) # some useful variables and functions n_processes = cm.shape[0] n_classes = cm.shape[1] - cmap = cf_colors.get(colormap, cf_colors["cf_cmap"]) + cmap = mpl.colors.ListedColormap(cf_colors.get(colormap, cf_colors["cf_cmap"])) x_labels = x_labels if x_labels else [f"out{i}" for i in range(n_classes)] y_labels = y_labels if y_labels else true_labels font_ax = 20 @@ -292,7 +299,7 @@ def fmt(v): if cms_llabel != "skip": cms_label_kwargs = get_cms_label(ax=ax, llabel=cms_llabel) cms_label_kwargs["rlabel"] = cms_rlabel - hep.cms.label(**cms_label_kwargs) + mplhep.cms.label(**cms_label_kwargs) plt.tight_layout() return fig @@ -349,6 +356,9 @@ def plot_roc( is not *None* and its shape doesn't match *predictions*. :raises ValueError: If *normalization* is not one of *None*, 'row', 'column'. """ + import matplotlib.pyplot as plt + import mplhep + # defining some useful properties and output shapes thresholds = np.linspace(0, 1, n_thresholds) weights = create_sample_weights(sample_weights, events, list(events.keys())) @@ -478,7 +488,7 @@ def auc_score(fpr: list, tpr: list, *args) -> np.float64: fpr = roc_data["fpr"] tpr = roc_data["tpr"] - plt.style.use(hep.style.CMS) + plt.style.use(mplhep.style.CMS) fig, ax = plt.subplots(dpi=300) ax.set_xlabel("FPR", loc="right", labelpad=10, fontsize=25) ax.set_ylabel("TPR", loc="top", labelpad=15, fontsize=25) @@ -499,7 +509,7 @@ def auc_score(fpr: list, tpr: list, *args) -> np.float64: if cms_llabel != "skip": cms_label_kwargs = get_cms_label(ax=ax, llabel=cms_llabel) cms_label_kwargs["rlabel"] = cms_rlabel - hep.cms.label(**cms_label_kwargs) + mplhep.cms.label(**cms_label_kwargs) plt.tight_layout() return fig diff --git a/columnflow/plotting/plot_util.py b/columnflow/plotting/plot_util.py index 64249bb96..c680cc46a 100644 --- a/columnflow/plotting/plot_util.py +++ b/columnflow/plotting/plot_util.py @@ -9,6 +9,7 @@ __all__ = [] import re +import math import operator import functools from collections import OrderedDict @@ -17,17 +18,14 @@ import order as od import scinum as sn -from columnflow.util import maybe_import, try_int, try_complex, UNSET -from columnflow.hist_util import copy_axis -from columnflow.types import Iterable, Any, Callable, Sequence, Hashable +from columnflow.util import maybe_import, try_int, try_complex, safe_div, UNSET +from columnflow.hist_util import copy_axis, sum_hists +from columnflow.types import TYPE_CHECKING, Iterable, Any, Callable, Sequence, Hashable -math = maybe_import("math") -hist = maybe_import("hist") np = maybe_import("numpy") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") -mpl = maybe_import("matplotlib") -mticker = maybe_import("matplotlib.ticker") +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") logger = law.logger.get_logger(__name__) @@ -227,7 +225,7 @@ def get_stack_integral() -> float: if scale_factor == "stack": # compute the scale factor and round h_no_shift = remove_residual_axis_single(h, "shift", select_value="nominal") - scale_factor = round_dynamic(get_stack_integral() / h_no_shift.sum().value) or 1 + scale_factor = round_dynamic(safe_div(get_stack_integral(), h_no_shift.sum().value)) or 1 if try_int(scale_factor): scale_factor = int(scale_factor) hists[proc_inst] = h * scale_factor @@ -258,6 +256,8 @@ def apply_variable_settings( applies settings from *variable_settings* dictionary to the *variable_insts*; the *rebin*, *overflow*, *underflow*, and *slice* settings are directly applied to the histograms """ + import hist + # store info gathered along application of variable settings that can be inserted to the style config variable_style_config = {} @@ -276,6 +276,14 @@ def apply_variable_settings( h = h[{var_inst.name: hist.rebin(rebin_factor)}] hists[proc_inst] = h + # overflow and underflow bins + overflow = get_attr_or_aux(var_inst, "overflow", False) + underflow = get_attr_or_aux(var_inst, "underflow", False) + if overflow or underflow: + for proc_inst, h in list(hists.items()): + h = use_flow_bins(h, var_inst.name, underflow=underflow, overflow=overflow) + hists[proc_inst] = h + # slicing slices = get_attr_or_aux(var_inst, "slice", None) if ( @@ -288,14 +296,6 @@ def apply_variable_settings( h = h[{var_inst.name: slice(slice_0, slice_1)}] hists[proc_inst] = h - # overflow and underflow bins - overflow = get_attr_or_aux(var_inst, "overflow", False) - underflow = get_attr_or_aux(var_inst, "underflow", False) - if overflow or underflow: - for proc_inst, h in list(hists.items()): - h = use_flow_bins(h, var_inst.name, underflow=underflow, overflow=overflow) - hists[proc_inst] = h - # additional x axis transformations for trafo in law.util.make_list(get_attr_or_aux(var_inst, "x_transformations", None) or []): # forced representation into equal bins @@ -317,6 +317,15 @@ def apply_variable_settings( return hists, variable_style_config +def remove_negative_contributions(hists: dict[Hashable, hist.Hist]) -> dict[Hashable, hist.Hist]: + _hists = hists.copy() + for proc_inst, h in hists.items(): + h = h.copy() + h.view().value[h.view().value < 0] = 0 + _hists[proc_inst] = h + return _hists + + def use_flow_bins( h_in: hist.Hist, axis_name: str | int, @@ -376,12 +385,12 @@ def apply_density(hists: dict, density: bool = True) -> dict: if not density: return hists - for key, hist in hists.items(): + for key, h in hists.items(): # bin area safe for multi-dimensional histograms - area = functools.reduce(operator.mul, hist.axes.widths) + area = functools.reduce(operator.mul, h.axes.widths) # scale hist by bin area - hists[key] = hist / area + hists[key] = h / area return hists @@ -392,6 +401,8 @@ def remove_residual_axis_single( max_bins: int = 1, select_value: Any = None, ) -> hist.Hist: + import hist + # force always returning a copy h = h.copy() @@ -526,6 +537,8 @@ def prepare_stack_plot_config( backgrounds with uncertainty bands, unstacked processes as lines and data entrys with errorbars. """ + import hist + # separate histograms into stack, lines and data hists mc_hists, mc_colors, mc_edgecolors, mc_labels = [], [], [], [] mc_syst_hists = [] @@ -558,17 +571,21 @@ def prepare_stack_plot_config( h_data, h_mc, h_mc_stack = None, None, None if data_hists: - h_data = sum(data_hists[1:], data_hists[0].copy()) + h_data = sum_hists(data_hists) if mc_hists: - h_mc = sum(mc_hists[1:], mc_hists[0].copy()) + h_mc = sum_hists(mc_hists) h_mc_stack = hist.Stack(*mc_hists) # setup plotting configs plot_config = OrderedDict() + # take first (non-underflow) bin + # shape_norm_func = kwargs.get("shape_norm_func", lambda h, shape_norm: h.values()[0] if shape_norm else 1) + shape_norm_func = kwargs.get("shape_norm_func", lambda h, shape_norm: sum(h.values()) if shape_norm else 1) + # draw stack if h_mc_stack is not None: - mc_norm = sum(h_mc.values()) if shape_norm else 1 + mc_norm = shape_norm_func(h_mc, shape_norm) plot_config["mc_stack"] = { "method": "draw_stack", "hist": h_mc_stack, @@ -583,7 +600,7 @@ def prepare_stack_plot_config( # draw lines for i, h in enumerate(line_hists): - line_norm = sum(h.values()) if shape_norm else 1 + line_norm = shape_norm_func(h, shape_norm) plot_config[f"line_{i}"] = plot_cfg = { "method": "draw_hist", "hist": h, @@ -607,7 +624,7 @@ def prepare_stack_plot_config( # draw statistical error for stack if h_mc_stack is not None and not hide_stat_errors: - mc_norm = sum(h_mc.values()) if shape_norm else 1 + mc_norm = shape_norm_func(h_mc, shape_norm) plot_config["mc_stat_unc"] = { "method": "draw_stat_error_bands", "hist": h_mc, @@ -617,7 +634,7 @@ def prepare_stack_plot_config( # draw systematic error for stack if h_mc_stack is not None and mc_syst_hists: - mc_norm = sum(h_mc.values()) if shape_norm else 1 + mc_norm = shape_norm_func(h_mc, shape_norm) plot_config["mc_syst_unc"] = { "method": "draw_syst_error_bands", "hist": h_mc, @@ -636,7 +653,7 @@ def prepare_stack_plot_config( # draw data if data_hists: - data_norm = sum(h_data.values()) if shape_norm else 1 + data_norm = shape_norm_func(h_data, shape_norm) plot_config["data"] = plot_cfg = { "method": "draw_errorbars", "hist": h_data, @@ -886,7 +903,10 @@ def get_profile_variations(h_in: hist.Hist, axis: int = 1) -> dict[str, hist.His return {"nominal": h_nom, "up": h_up, "down": h_down} -def fix_cbar_minor_ticks(cbar: mpl.colorbar.Colorbar): +def fix_cbar_minor_ticks(cbar): + import matplotlib as mpl + import matplotlib.ticker as mticker + if isinstance(cbar.norm, mpl.colors.SymLogNorm): _scale = cbar.ax.yaxis._scale _scale.subs = [2, 3, 4, 5, 6, 7, 8, 9] @@ -910,6 +930,8 @@ def prepare_plot_config_2d( extreme_colors: tuple[str] | None = None, colormap: str = "", ): + import matplotlib as mpl + # add all processes into 1 histogram h_sum = sum(list(hists.values())[1:], list(hists.values())[0].copy()) if shape_norm: @@ -945,6 +967,7 @@ def prepare_plot_config_2d( # based on scale type and h_sum content # log scale (turning linear for low values) + if zscale == "log": # use SymLogNorm to correctly handle both positive and negative values cbar_norm = mpl.colors.SymLogNorm( @@ -1015,6 +1038,8 @@ def prepare_style_config_2d( variable_insts: list[od.Variable], cms_label: str = "", ) -> dict: + import matplotlib as mpl + # setup style config # TODO: some kind of z-label is still missing @@ -1113,7 +1138,7 @@ def blind_sensitive_bins( # set data points in masked region to zero for proc, h in data.items(): - h.values()[..., mask] = 0 + h.values()[..., mask] = -999 h.variances()[..., mask] = 0 # merge all histograms @@ -1135,6 +1160,8 @@ def rebin_equal_width( :param axis_name: Name of the axis to rebin. :return: Tuple of the rebinned histograms and the new bin edges. """ + import hist + # get the variable axis from the first histogram assert hists for var_index, var_axis in enumerate(list(hists.values())[0].axes): @@ -1230,29 +1257,30 @@ def remove_label_placeholders( return re.sub(f"__{sel}__", "", label) -def calculate_stat_error( - hist: hist.Hist, - error_type: str, -) -> dict: +def calculate_stat_error(h: hist.Hist, error_type: str) -> np.ndarray: """ - Calculate the error to be plotted for the given histogram *hist*. + Calculate the error to be plotted for the given histogram *h*. Supported error types are: - - 'variance': the plotted error is the square root of the variance for each bin - - 'poisson_unweighted': the plotted error is the poisson error for each bin - - 'poisson_weighted': the plotted error is the poisson error for each bin, weighted by the variance - """ + - "variance": the plotted error is the square root of the variance for each bin + - "poisson_unweighted": the plotted error is the poisson error for each bin + - "poisson_weighted": the plotted error is the poisson error for each bin, weighted by the variance + """ # determine the error type if error_type == "variance": - yerr = hist.view().variance ** 0.5 + yerr = h.view().variance ** 0.5 + elif error_type in {"poisson_unweighted", "poisson_weighted"}: # compute asymmetric poisson confidence interval from hist.intervals import poisson_interval - variances = hist.view().variance if error_type == "poisson_weighted" else None - values = hist.view().value + variances = h.view().variance if error_type == "poisson_weighted" else None + values = h.view().value confidence_interval = poisson_interval(values, variances) + # negative values are considerd as blinded bins -> set confidence interval to 0 + confidence_interval[:, values < 0] = 0 + if error_type == "poisson_weighted": # might happen if some bins are empty, see https://github.com/scikit-hep/hist/blob/5edbc25503f2cb8193cc5ff1eb71e1d8fa877e3e/src/hist/intervals.py#L74 # noqa: E501 confidence_interval[np.isnan(confidence_interval)] = 0 @@ -1260,19 +1288,14 @@ def calculate_stat_error( raise ValueError("Unweighted Poisson interval calculation returned NaN values, check Hist package") # calculate the error - # yerr_lower is the lower error yerr_lower = values - confidence_interval[0] - # yerr_upper is the upper error yerr_upper = confidence_interval[1] - values - # yerr is the size of the errorbars to be plotted yerr = np.array([yerr_lower, yerr_upper]) if np.any(yerr < 0): - logger.warning( - "yerr < 0, setting to 0. " - "This should not happen, please check your histogram.", - ) + logger.warning("found yerr < 0, forcing to 0; this should not happen, please check your histogram") yerr[yerr < 0] = 0 + else: raise ValueError(f"unknown error type '{error_type}'") diff --git a/columnflow/production/__init__.py b/columnflow/production/__init__.py index 529191cf3..03ff6faf9 100644 --- a/columnflow/production/__init__.py +++ b/columnflow/production/__init__.py @@ -8,18 +8,55 @@ import inspect -from columnflow.types import Callable +import law + from columnflow.util import DerivableMeta from columnflow.columnar_util import TaskArrayFunction +from columnflow.types import Callable, Sequence, Any + + +class TaskArrayFunctionWithProducerRequirements(TaskArrayFunction): + + require_producers: Sequence[str] | set[str] | None = None + + def _req_producer(self, task: law.Task, producer: str) -> Any: + # hook to customize how required producers are requested + from columnflow.tasks.production import ProduceColumns + return ProduceColumns.req_other_producer(task, producer=producer) + def requires_func(self, task: law.Task, reqs: dict, **kwargs) -> None: + # no requirements for workflows in pilot mode + if callable(getattr(task, "is_workflow", None)) and task.is_workflow() and getattr(task, "pilot", False): + return -class Producer(TaskArrayFunction): + # add required producers when set + if (prods := self.require_producers): + reqs["required_producers"] = {prod: self._req_producer(task, prod) for prod in prods} + + def setup_func( + self, + task: law.Task, + reqs: dict, + inputs: dict, + reader_targets: law.util.InsertableDict, + **kwargs, + ) -> None: + if "required_producers" in inputs: + for prod, inp in inputs["required_producers"].items(): + reader_targets[f"required_producer_{prod}"] = inp["columns"] + + +class Producer(TaskArrayFunctionWithProducerRequirements): """ Base class for all producers. """ exposed = True + # register attributes for arguments accepted by decorator + mc_only: bool = False + data_only: bool = False + @classmethod def producer( cls, @@ -27,6 +64,7 @@ def producer( bases: tuple = (), mc_only: bool = False, data_only: bool = False, + require_producers: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ @@ -46,6 +84,7 @@ def producer( Monte Carlo simulation and skipped for real data. :param data_only: Boolean flag indicating that this :py:class:`Producer` should only run on real data and skipped for Monte Carlo simulation. + :param require_producers: Sequence of names of other producers to add to the requirements. :return: New :py:class:`Producer` subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -55,6 +94,7 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, + "require_producers": require_producers, } # get the module name diff --git a/columnflow/production/categories.py b/columnflow/production/categories.py index e862249c8..d825caf45 100644 --- a/columnflow/production/categories.py +++ b/columnflow/production/categories.py @@ -6,6 +6,9 @@ from __future__ import annotations +import functools +import operator + import law from columnflow.categorization import Categorizer @@ -35,18 +38,20 @@ def category_ids( """ Assigns each event an array of category ids. """ - category_ids = [] - mask_cash = {} + # evaluate all unique categorizers, storing their returned masks + cat_masks = {} + for categorizer in self.unique_categorizers: + events, mask = self[categorizer](events, **kwargs) + cat_masks[categorizer] = mask + # loop through categories and construct mask over all categorizers + category_ids = [] for cat_inst, categorizers in self.categorizer_map.items(): - # start with a true mask - cat_mask = np.ones(len(events), dtype=bool) - - # loop through selectors - for categorizer in categorizers: - if categorizer not in mask_cash: - events, mask_cash[categorizer] = self[categorizer](events, **kwargs) - cat_mask = cat_mask & mask_cash[categorizer] + cat_mask = functools.reduce( + operator.and_, + (cat_masks[c] for c in categorizers), + np.ones(len(events), dtype=bool), + ) # covert to nullable array with the category ids or none, then apply ak.singletons ids = ak.where(cat_mask, np.float64(cat_inst.id), np.float64(np.nan)) @@ -75,7 +80,7 @@ def category_ids_init(self: Producer, **kwargs) -> None: continue # treat all selections as lists of categorizers - for sel in law.util.make_list(cat_inst.selection): + for sel in law.util.flatten(cat_inst.selection): if Categorizer.derived_by(sel): categorizer = sel elif Categorizer.has_cls(sel): @@ -98,3 +103,6 @@ def category_ids_init(self: Producer, **kwargs) -> None: self.produces.add(categorizer) self.categorizer_map.setdefault(cat_inst, []).append(categorizer) + + # store a list of unique categorizers + self.unique_categorizers = law.util.make_unique(sum(self.categorizer_map.values(), [])) diff --git a/columnflow/production/cms/btag.py b/columnflow/production/cms/btag.py index 36a88d2fb..27c6ab023 100644 --- a/columnflow/production/cms/btag.py +++ b/columnflow/production/cms/btag.py @@ -18,6 +18,7 @@ np = maybe_import("numpy") ak = maybe_import("awkward") + logger = law.logger.get_logger(__name__) @@ -251,18 +252,17 @@ def add_weight(syst_name, syst_direction, column_name): events = add_weight("central", None, self.weight_name) for syst_name, col_name in self.btag_uncs.items(): for direction in ["up", "down"]: - name = col_name.format(year=self.config_inst.campaign.x.year) events = add_weight( syst_name, direction, - f"{self.weight_name}_{name}_{direction}", + f"{self.weight_name}_{col_name}_{direction}", ) if syst_name in ["cferr1", "cferr2"]: # for c flavor uncertainties, multiply the uncertainty with the nominal btag weight events = set_ak_column( events, - f"{self.weight_name}_{name}_{direction}", - events[self.weight_name] * events[f"{self.weight_name}_{name}_{direction}"], + f"{self.weight_name}_{col_name}_{direction}", + events[self.weight_name] * events[f"{self.weight_name}_{col_name}_{direction}"], value_type=np.float32, ) elif self.shift_is_known_jec_source: @@ -290,7 +290,7 @@ def btag_weights_post_init(self: Producer, task: law.Task, **kwargs) -> None: # NOTE: we currently setup the produced columns only during the post_init. This means # that the `produces` of this Producer will be empty during task initialization, meaning - # that this Producer would be skipped if one would directly request it on command line + # that this Producer would be skipped if one would directly request it on the command line # gather info self.btag_config = self.get_btag_config() @@ -306,14 +306,14 @@ def btag_weights_post_init(self: Producer, task: law.Task, **kwargs) -> None: self.jec_source and btag_sf_jec_source in self.btag_config.jec_sources ) - # save names of method-intrinsic uncertainties + # names of method-intrinsic uncertainties, mapped to how they are namend in produced columns self.btag_uncs = { "hf": "hf", "lf": "lf", - "hfstats1": "hfstats1_{year}", - "hfstats2": "hfstats2_{year}", - "lfstats1": "lfstats1_{year}", - "lfstats2": "lfstats2_{year}", + "hfstats1": "hfstats1", + "hfstats2": "hfstats2", + "lfstats1": "lfstats1", + "lfstats2": "lfstats2", "cferr1": "cferr1", "cferr2": "cferr2", } @@ -324,9 +324,7 @@ def btag_weights_post_init(self: Producer, task: law.Task, **kwargs) -> None: self.produces.add(self.weight_name) # all varied columns for col_name in self.btag_uncs.values(): - name = col_name.format(year=self.config_inst.campaign.x.year) - for direction in ["up", "down"]: - self.produces.add(f"{self.weight_name}_{name}_{direction}") + self.produces.add(f"{self.weight_name}_{col_name}_{{up,down}}") elif self.shift_is_known_jec_source: # jec varied column self.produces.add(f"{self.weight_name}_jec_{self.jec_source}_{shift_inst.direction}") diff --git a/columnflow/production/cms/dy.py b/columnflow/production/cms/dy.py index 52718f801..9e618c007 100644 --- a/columnflow/production/cms/dy.py +++ b/columnflow/production/cms/dy.py @@ -6,9 +6,9 @@ from __future__ import annotations -import law +import dataclasses -from dataclasses import dataclass +import law from columnflow.production import Producer, producer from columnflow.util import maybe_import, load_correction_set @@ -16,26 +16,34 @@ np = maybe_import("numpy") ak = maybe_import("awkward") -vector = maybe_import("vector") + logger = law.logger.get_logger(__name__) -@dataclass +@dataclasses.dataclass class DrellYanConfig: + # era, e.g. "2022preEE" era: str - order: str + # correction set name correction: str - unc_correction: str + # uncertainty correction set name + unc_correction: str | None = None + # generator order + order: str | None = None + # list of systematics to be considered + systs: list[str] | None = None + # functions to get the number of jets and b-tagged jets from the events in case they should be used as inputs + get_njets: callable[["dy_weights", ak.Array], ak.Array] | None = None + get_nbtags: callable[["dy_weights", ak.Array], ak.Array] | None = None + # additional columns to be loaded, e.g. as needed for njets or nbtags + used_columns: set = dataclasses.field(default_factory=set) def __post_init__(self) -> None: - if ( - not self.era or - not self.order or - not self.correction or - not self.unc_correction - ): - raise ValueError("incomplete dy_weight_config: missing era, order, correction or unc_correction") + if not self.era or not self.correction: + raise ValueError(f"{self.__class__.__name__}: missing era or correction") + if self.unc_correction and not self.order: + raise ValueError(f"{self.__class__.__name__}: when unc_correction is defined, order must be set") @producer( @@ -58,7 +66,7 @@ def gen_dilepton(self, events: ak.Array, **kwargs) -> ak.Array: (status == 1) & events.GenPart.hasFlags("fromHardProcess") ) - # taus need to have status == 2, + # taus need to have status == 2 tau_mask = ( (pdg_id == 15) & (status == 2) & events.GenPart.hasFlags("fromHardProcess") ) @@ -136,7 +144,8 @@ def dy_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Array: *get_dy_weight_file* can be adapted in a subclass in case it is stored differently in the external files. - The campaign era and name of the correction set (see link above) should be given as an auxiliary entry in the config: + The analysis config should contain an auxiliary entry *dy_weight_config* pointing to a :py:class:`DrellYanConfig` + object: .. code-block:: python @@ -149,49 +158,71 @@ def dy_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Array: *get_dy_weight_config* can be adapted in a subclass in case it is stored differently in the config. """ - # map the input variable names from the corrector to our columns variable_map = { "era": self.dy_config.era, - "order": self.dy_config.order, "ptll": events.gen_dilepton_pt, } - # initializing the list of weight variations - weights_list = [("dy_weight", "nom")] - - # appending the respective number of uncertainties to the weight list - for i in range(self.n_unc): - for shift in ("up", "down"): - tmp_tuple = (f"dy_weight{i + 1}_{shift}", f"{shift}{i + 1}") - weights_list.append(tmp_tuple) + # optionals + if self.dy_config.order: + variable_map["order"] = self.dy_config.order + if callable(self.dy_config.get_njets): + variable_map["njets"] = self.dy_config.get_njets(self, events) + if callable(self.dy_config.get_nbtags): + variable_map["nbtags"] = self.dy_config.get_nbtags(self, events) + # for compatibility + variable_map["ntags"] = variable_map["nbtags"] + + # initializing the list of weight variations (called syst in the dy files) + systs = [("nom", "")] + + # add specific uncertainties or additional systs + if self.dy_config.unc_correction: + for i in range(self.n_unc): + for direction in ["up", "down"]: + systs.append((f"{direction}{i + 1}", f"_{direction}{i + 1}")) + elif self.dy_config.systs: + for syst in self.dy_config.systs: + systs.append((syst, f"_{syst}")) # preparing the input variables for the corrector - for column_name, syst in weights_list: - variable_map_syst = {**variable_map, "syst": syst} + for syst, postfix in systs: + _variable_map = {**variable_map, "syst": syst} # evaluating dy weights given a certain era, ptll array and sytematic shift - inputs = [variable_map_syst[inp.name] for inp in self.dy_corrector.inputs] + inputs = [_variable_map[inp.name] for inp in self.dy_corrector.inputs] dy_weight = self.dy_corrector.evaluate(*inputs) # save the weights in a new column - events = set_ak_column(events, column_name, dy_weight, value_type=np.float32) + events = set_ak_column(events, f"dy_weight{postfix}", dy_weight, value_type=np.float32) return events @dy_weights.init def dy_weights_init(self: Producer) -> None: - # the number of weights in partial run 3 is always 10 if self.config_inst.campaign.x.year not in {2022, 2023}: raise NotImplementedError( f"campaign year {self.config_inst.campaign.x.year} is not yet supported by {self.cls_name}", ) - self.n_unc = 10 - # register dynamically produced weight columns - for i in range(self.n_unc): - self.produces.add(f"dy_weight{i + 1}_{{up,down}}") + # get the dy weight config + self.dy_config: DrellYanConfig = self.get_dy_weight_config() + + # declare additional used columns + if self.dy_config.used_columns: + self.uses.update(self.dy_config.used_columns) + + # declare additional produced columns + if self.dy_config.unc_correction: + # the number should always be 10 + self.n_unc = 10 + for i in range(self.n_unc): + self.produces.add(f"dy_weight{i + 1}_{{up,down}}") + elif self.dy_config.systs: + for syst in self.dy_config.systs: + self.produces.add(f"dy_weight_{syst}") @dy_weights.requires @@ -215,31 +246,26 @@ def dy_weights_setup( reader_targets: law.util.InsertableDict, ) -> None: """ - Loads the Drell-Yan weight calculator from the external files bundle and saves them in the - py:attr:`dy_corrector` attribute for simpler access in the actual callable. The number of uncertainties - is calculated, per era, by another correcter in the external file and is saved in the - py:attr:`dy_unc_corrector` attribute. + Loads the Drell-Yan weight calculator from the external files bundle and saves them in the py:attr:`dy_corrector` + attribute for simpler access in the actual callable. The number of uncertainties is calculated, per era, by another + correcter in the external file and is saved in the py:attr:`dy_unc_corrector` attribute. """ bundle = reqs["external_files"] # import all correctors from the external file correction_set = load_correction_set(self.get_dy_weight_file(bundle.files)) - # check number of fetched correctors - if len(correction_set.keys()) != 2: - raise Exception("Expected exactly two types of Drell-Yan correction") - - # create the weight and uncertainty correctors - self.dy_config: DrellYanConfig = self.get_dy_weight_config() + # create the weight corrector self.dy_corrector = correction_set[self.dy_config.correction] - self.dy_unc_corrector = correction_set[self.dy_config.unc_correction] - dy_n_unc = int(self.dy_unc_corrector.evaluate(self.dy_config.order)) - - if dy_n_unc != self.n_unc: - raise ValueError( - f"Expected {self.n_unc} uncertainties, got {dy_n_unc}", - ) + # create the uncertainty corrector + if self.dy_config.unc_correction: + self.dy_unc_corrector = correction_set[self.dy_config.unc_correction] + dy_n_unc = int(self.dy_unc_corrector.evaluate(self.dy_config.order)) + if dy_n_unc != self.n_unc: + raise ValueError( + f"Expected {self.n_unc} uncertainties, got {dy_n_unc}", + ) @producer( @@ -247,8 +273,6 @@ def dy_weights_setup( # MET information # -> only Run 3 (PuppiMET) is supported "PuppiMET.{pt,phi}", - # Number of jets (as a per-event scalar) - "Jet.{pt,phi,eta,mass}", # Gen-level boson information (full boson momentum) # -> gen_dilepton_vis.pt, gen_dilepton_vis.phi, gen_dilepton_all.pt, gen_dilepton_all.phi gen_dilepton.PRODUCES, @@ -257,6 +281,8 @@ def dy_weights_setup( "RecoilCorrMET.{pt,phi}", "RecoilCorrMET.{pt,phi}_{recoilresp,recoilres}_{up,down}", }, + # custom njet column to be used to derive corrections + njet_column=None, mc_only=True, # function to determine the recoil correction file from external files get_dy_recoil_file=(lambda self, external_files: external_files.dy_recoil_sf), @@ -291,6 +317,8 @@ def recoil_corrected_met(self: Producer, events: ak.Array, **kwargs) -> ak.Array *get_dy_recoil_config* can be adapted in a subclass in case it is stored differently in the config. """ + import vector + # steps: # 1) Build transverse vectors for MET and the generator-level boson (full and visible). # 2) Compute the recoil vector U = MET + vis - full in the transverse plane. @@ -331,12 +359,15 @@ def recoil_corrected_met(self: Producer, events: ak.Array, **kwargs) -> ak.Array uperp = -u_x * full_unit_y + u_y * full_unit_x # Determine jet multiplicity for the event (jet selection as in original) - jet_selection = ( - ((events.Jet.pt > 30) & (np.abs(events.Jet.eta) < 2.5)) | - ((events.Jet.pt > 50) & (np.abs(events.Jet.eta) >= 2.5)) - ) - selected_jets = events.Jet[jet_selection] - njet = np.asarray(ak.num(selected_jets, axis=1), dtype=np.float32) + if self.njet_column: + njet = np.asarray(events[self.njet_column], dtype=np.float32) + else: + jet_selection = ( + ((events.Jet.pt > 30) & (np.abs(events.Jet.eta) < 2.5)) | + ((events.Jet.pt > 50) & (np.abs(events.Jet.eta) >= 2.5)) + ) + selected_jets = events.Jet[jet_selection] + njet = np.asarray(ak.num(selected_jets, axis=1), dtype=np.float32) # Apply nominal recoil correction on U components # (see here: https://cms-higgs-leprare.docs.cern.ch/htt-common/V_recoil/#example-snippet) @@ -418,6 +449,14 @@ def recoil_corrected_met(self: Producer, events: ak.Array, **kwargs) -> ak.Array return events +@recoil_corrected_met.init +def recoil_corrected_met_init(self: Producer) -> None: + if self.njet_column: + self.uses.add(f"{self.njet_column}") + else: + self.uses.add("Jet.{pt,eta,phi,mass}") + + @recoil_corrected_met.requires def recoil_corrected_met_requires(self: Producer, task: law.Task, reqs: dict) -> None: # Ensure that external files are bundled. diff --git a/columnflow/production/cms/electron.py b/columnflow/production/cms/electron.py index e88d115e8..89b739daa 100644 --- a/columnflow/production/cms/electron.py +++ b/columnflow/production/cms/electron.py @@ -6,31 +6,35 @@ from __future__ import annotations -from dataclasses import dataclass +import dataclasses import law from columnflow.production import Producer, producer from columnflow.util import maybe_import, load_correction_set, DotDict -from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array -from columnflow.types import Any +from columnflow.columnar_util import set_ak_column, full_like, flat_np_view +from columnflow.types import Any, Callable np = maybe_import("numpy") ak = maybe_import("awkward") -@dataclass +@dataclasses.dataclass class ElectronSFConfig: correction: str campaign: str - working_point: str = "" + working_point: str | dict[str, Callable] = "" hlt_path: str = "" + min_pt: float = 0.0 + max_pt: float = 0.0 def __post_init__(self) -> None: if not self.working_point and not self.hlt_path: raise ValueError("either working_point or hlt_path must be set") if self.working_point and self.hlt_path: raise ValueError("only one of working_point or hlt_path must be set") + if 0.0 < self.max_pt <= self.min_pt: + raise ValueError(f"{self.__class__.__name__}: max_pt must be larger than min_pt") @classmethod def new(cls, obj: ElectronSFConfig | tuple[str, str, str]) -> ElectronSFConfig: @@ -52,11 +56,12 @@ def new(cls, obj: ElectronSFConfig | tuple[str, str, str]) -> ElectronSFConfig: # function to determine the correction file get_electron_file=(lambda self, external_files: external_files.electron_sf), # function to determine the electron weight config - get_electron_config=(lambda self: ElectronSFConfig.new(self.config_inst.x.electron_sf_names)), + get_electron_config=(lambda self: ElectronSFConfig.new(self.config_inst.x("electron_sf", self.config_inst.x("electron_sf_names", None)))), # noqa: E501 # choose if the eta variable should be the electron eta or the super cluster eta use_supercluster_eta=True, + # name of the saved weight column weight_name="electron_weight", - supported_versions=(1, 2, 3), + supported_versions={1, 2, 3}, ) def electron_weights( self: Producer, @@ -65,8 +70,7 @@ def electron_weights( **kwargs, ) -> ak.Array: """ - Creates electron weights using the correctionlib. Requires an external file in the config under - ``electron_sf``: + Creates electron weights using the correctionlib. Requires an external file in the config under ``electron_sf``: .. code-block:: python @@ -74,58 +78,96 @@ def electron_weights( "electron_sf": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/EGM/2017_UL/electron.json.gz", # noqa }) - *get_electron_file* can be adapted in a subclass in case it is stored differently in the - external files. + *get_electron_file* can be adapted in a subclass in case it is stored differently in the external files. - The name of the correction set, the year string for the weight evaluation, and the name of the - working point should be given as an auxiliary entry in the config: + The name of the correction set, the year string for the weight evaluation, and the name of the working point should + be given as an auxiliary entry in the config: .. code-block:: python - cfg.x.electron_sf_names = ElectronSFConfig( + cfg.x.electron_sf = ElectronSFConfig( correction="UL-Electron-ID-SF", campaign="2017", working_point="wp80iso", # for trigger weights use hlt_path instead ) - *get_electron_config* can be adapted in a subclass in case it is stored differently in the - config. + The *working_point* can also be a dictionary mapping working point names to functions that return a boolean mask for + the electrons. This is useful to compute scale factors for multiple working points at once, e.g. for the electron + reconstruction scale factors: - Optionally, an *electron_mask* can be supplied to compute the scale factor weight - based only on a subset of electrons. + .. code-block:: python + + cfg.x.electron_sf = ElectronSFConfig( + correction="Electron-ID-SF", + campaign="2022Re-recoE+PromptFG", + working_point={ + "RecoBelow20": lambda variable_map: variable_map["pt"] < 20.0, + "Reco20to75": lambda variable_map: (variable_map["pt"] >= 20.0) & (variable_map["pt"] < 75.0), + "RecoAbove75": lambda variable_map: variable_map["pt"] >= 75.0, + }, + ) + + *get_electron_config* can be adapted in a subclass in case it is stored differently in the config. + + Optionally, an *electron_mask* can be supplied to compute the scale factor weight based only on a subset of + electrons. """ - # flat super cluster eta/flat eta and pt views + # fold electron mask with pt cuts if given + if self.electron_config.min_pt > 0.0: + pt_mask = events.Electron.pt >= self.electron_config.min_pt + electron_mask = pt_mask if electron_mask is Ellipsis else (pt_mask & electron_mask) + if self.electron_config.max_pt > 0.0: + pt_mask = events.Electron.pt <= self.electron_config.max_pt + electron_mask = pt_mask if electron_mask is Ellipsis else (pt_mask & electron_mask) + + # prepare input variables + electrons = events.Electron[electron_mask] + eta = electrons.eta if self.use_supercluster_eta: - eta = flat_np_view(( - events.Electron.eta[electron_mask] + - events.Electron.deltaEtaSC[electron_mask] - ), axis=1) - else: - eta = flat_np_view(events.Electron.eta[electron_mask], axis=1) - pt = flat_np_view(events.Electron.pt[electron_mask], axis=1) - phi = flat_np_view(events.Electron.phi[electron_mask], axis=1) - + eta = ( + electrons.superclusterEta + if "superclusterEta" in electrons.fields + else electrons.eta + electrons.deltaEtaSC + ) variable_map = { "year": self.electron_config.campaign, - "WorkingPoint": self.electron_config.working_point, "Path": self.electron_config.hlt_path, - "pt": pt, + "pt": electrons.pt, + "phi": electrons.phi, "eta": eta, - "phi": phi, } # loop over systematics for syst, postfix in zip(self.sf_variations, ["", "_up", "_down"]): # get the inputs for this type of variation - variable_map_syst = { - **variable_map, - "ValType": syst, - } - inputs = [variable_map_syst[inp.name] for inp in self.electron_sf_corrector.inputs] - sf_flat = self.electron_sf_corrector(*inputs) - - # add the correct layout to it - sf = layout_ak_array(sf_flat, events.Electron.pt[electron_mask]) + variable_map_syst = variable_map | {"ValType": syst} + + # add working point + wp = self.electron_config.working_point + if isinstance(wp, str): + # single wp, just evaluate + variable_map_syst_wp = variable_map_syst | {"WorkingPoint": wp} + inputs = [variable_map_syst_wp[inp.name] for inp in self.electron_sf_corrector.inputs] + sf = self.electron_sf_corrector.evaluate(*inputs) + elif isinstance(wp, dict): + # mapping of wps to masks, evaluate per wp and combine + sf = full_like(eta, 1.0) + sf_flat = flat_np_view(sf) + for _wp, mask_fn in wp.items(): + mask = mask_fn(variable_map) + variable_map_syst_wp = variable_map_syst | {"WorkingPoint": _wp} + # call the corrector with the masked inputs + inputs = [ + ( + variable_map_syst_wp[inp.name][mask] + if isinstance(variable_map_syst_wp[inp.name], (np.ndarray, ak.Array)) + else variable_map_syst_wp[inp.name] + ) + for inp in self.electron_sf_corrector.inputs + ] + sf_flat[flat_np_view(mask)] = flat_np_view(self.electron_sf_corrector.evaluate(*inputs)) + else: + raise ValueError(f"unsupported working point type {type(variable_map['WorkingPoint'])}") # create the product over all electrons in one event weight = ak.prod(sf, axis=1, mask_identity=False) diff --git a/columnflow/production/cms/gen_particles.py b/columnflow/production/cms/gen_particles.py new file mode 100644 index 000000000..294af1a81 --- /dev/null +++ b/columnflow/production/cms/gen_particles.py @@ -0,0 +1,359 @@ +# coding: utf-8 + +""" +Producers that determine the generator-level particles and bring them into a structured format. This is most likely +useful for generator studies and truth definitions of physics objects. +""" + +from __future__ import annotations + +import law + +from columnflow.production import Producer, producer +from columnflow.columnar_util import set_ak_column +from columnflow.util import UNSET, maybe_import + +np = maybe_import("numpy") +ak = maybe_import("awkward") + + +logger = law.logger.get_logger(__name__) + +_keep_gen_part_fields = ["pt", "eta", "phi", "mass", "pdgId"] + + +# helper to transform generator particles by dropping / adding fields +def transform_gen_part(gen_parts: ak.Array, *, depth_limit: int, optional: bool = False) -> ak.Array: + # reduce down to relevant fields + arr = {} + for f in _keep_gen_part_fields: + if optional: + if (v := getattr(gen_parts, f, UNSET)) is not UNSET: + arr[f] = v + else: + arr[f] = getattr(gen_parts, f) + arr = ak.zip(arr, depth_limit=depth_limit) + + # remove parameters and add Lorentz vector behavior + arr = ak.without_parameters(arr) + arr = ak.with_name(arr, "PtEtaPhiMLorentzVector") + + return arr + + +@producer( + uses={ + "GenPart.{genPartIdxMother,status,statusFlags}", # required by the gen particle identification + f"GenPart.{{{','.join(_keep_gen_part_fields)}}}", # additional fields that should be read and added to gen_top + }, + produces={"gen_top.*.*"}, +) +def gen_top_lookup(self: Producer, events: ak.Array, strict: bool = True, **kwargs) -> ak.Array: + """ + Creates a new ragged column "gen_top" containing information about generator-level top quarks and their decay + products in a structured array with the following fields: + + - ``t``: list of all top quarks in the event, sorted such that top quarks precede anti-top quarks + - ``b``: list of bottom quarks from top quark decays, consistent ordering w.r.t. ``t`` (note that, in rare + cases, the decay into charm or down quarks is realized, and therefore stored in this field) + - ``w``: list of W bosons from top quark decays, consistent ordering w.r.t. ``t`` + - ``w_children``: list of W boson decay products, consistent ordering w.r.t. ``w``, the first entry is the + down-type quark or charged lepton, the second entry is the up-type quark or neutrino, and additional decay + products (e.g photons) are appended afterwards + - ``w_tau_children``: list of decay products from tau lepton decays stemming from W boson decays, however, + skipping the W boson from the tau lepton decay itself; the first entry is the tau neutrino, the second and + third entries are either the charged lepton and neutrino, or quarks or hadrons sorted by ascending absolute + pdg id; additional decay products (e.g photons) are appended afterwards + """ + # helper to extract unique values + unique_set = lambda a: set(np.unique(ak.flatten(a, axis=None))) + + # find hard top quarks + t = events.GenPart[abs(events.GenPart.pdgId) == 6] + t = t[t.hasFlags("isLastCopy")] # they are either fromHardProcess _or_ isLastCopy + + # sort them so that that top quarks come before anti-top quarks + t = t[ak.argsort(t.pdgId, axis=1, ascending=False)] + + # distinct top quark children + # (asking for isLastCopy leads to some tops that miss children, usually b's) + t_children = ak.drop_none(t.distinctChildren[t.distinctChildren.hasFlags("fromHardProcess", "isFirstCopy")]) + + # strict mode: check that there are exactly two children that are b and w + if strict: + if (tcn := unique_set(ak.num(t_children, axis=2))) != {2}: + raise Exception(f"found top quarks that have != 2 children: {tcn - {2}}") + if (tci := unique_set(abs(t_children.pdgId))) - {1, 3, 5, 24}: + raise Exception(f"found top quark children with unexpected pdgIds: {tci - {1, 3, 5, 24}}") + + # store b's (or s/d) and w's + abs_tc_ids = abs(t_children.pdgId) + b = ak.drop_none(ak.firsts(t_children[(abs_tc_ids == 1) | (abs_tc_ids == 3) | (abs_tc_ids == 5)], axis=2)) + w = ak.drop_none(ak.firsts(t_children[abs(t_children.pdgId) == 24], axis=2)) + + # distinct w children + w_children = ak.drop_none(w.distinctChildrenDeep) + + # distinguish into "hard" and additional ones + w_children_hard = w_children[(hard_mask := w_children.hasFlags("fromHardProcess"))] + w_children_rest = w_children[~hard_mask] + + # strict: check that there are exactly two hard children + if strict: + if (wcn := unique_set(ak.num(w_children_hard, axis=2))) != {2}: + raise Exception(f"found W bosons that have != 2 children: {wcn - {2}}") + + # sort them so that down-type quarks and charged leptons (odd pdgIds) come first, followed by up-type quarks and + # neutrinos (even pdgIds), then add back the remaining ones + w_children_hard = w_children_hard[ak.argsort(-(w_children_hard.pdgId % 2), axis=2)] + w_children = ak.concatenate([w_children_hard, w_children_rest], axis=2) + + # further distinguish tau decays in w_children + w_tau_children = ak.drop_none(w_children[abs(w_children.pdgId) == 15].distinctChildrenDeep) + # sort: nu tau first, photons last, rest in between sorted by ascending absolute pdgId + w_tau_nu_mask = abs(w_tau_children.pdgId) == 16 + w_tau_photon_mask = w_tau_children.pdgId == 22 + w_tau_rest = w_tau_children[~(w_tau_nu_mask | w_tau_photon_mask)] + w_tau_rest = w_tau_rest[ak.argsort(abs(w_tau_rest.pdgId), axis=3, ascending=True)] + w_tau_children = ak.concatenate( + [w_tau_children[w_tau_nu_mask], w_tau_rest, w_tau_children[w_tau_photon_mask]], + axis=3, + ) + + # zip into a single array with named fields + gen_top = ak.zip( + { + "t": transform_gen_part(t, depth_limit=2), + "b": transform_gen_part(b, depth_limit=2), + "w": transform_gen_part(w, depth_limit=2), + "w_children": transform_gen_part(w_children, depth_limit=3), + "w_tau_children": transform_gen_part(w_tau_children, depth_limit=4), + }, + depth_limit=1, + ) + + # save the column + events = set_ak_column(events, "gen_top", gen_top) + + return events + + +@producer( + uses={ + "GenPart.{genPartIdxMother,status,statusFlags}", # required by the gen particle identification + f"GenPart.{{{','.join(_keep_gen_part_fields)}}}", # additional fields that should be read and added to gen_top + }, + produces={"gen_higgs.*.*"}, +) +def gen_higgs_lookup(self: Producer, events: ak.Array, strict: bool = True, **kwargs) -> ak.Array: + """ + Creates a new ragged column "gen_higgs" containing information about generator-level Higgs bosons and their decay + products in a structured array with the following fields: + + - ``h``: list of all Higgs bosons in the event, sorted by the pdgId of their decay products such that Higgs + bosons decaying to quarks (b's) come first, followed by leptons, and then gauge bosons + - ``h_children``: list of direct Higgs boson children, consistent ordering w.r.t. ``h``, with the first entry + being the particle and the second one being the anti-particle; for Z bosons and (effective) gluons and + photons, no ordering is applied + - ``tau_children``: list of decay products from tau lepton decays coming from Higgs bosons, with the first entry + being the neutrino and the second one being the W boson + - ``tau_w_children``: list of the decay products from W boson decays from tau lepton decays, with the first + entry being the down-type quark or charged lepton, the second entry being the up-type quark or neutrino, and + additional decay products (e.g photons) are appended afterwards + - ``z_children``: not yet implemented + - ``w_children``: not yet implemented + """ + # helper to extract unique values + unique_set = lambda a: set(np.unique(ak.flatten(a, axis=None))) + + # find higgs + h = events.GenPart[events.GenPart.pdgId == 25] + h = h[h.hasFlags("fromHardProcess", "isLastCopy")] + + # sort them by increasing pdgId of their children (quarks, leptons, Z, W, effective gluons/photons) + h = h[ak.argsort(abs(ak.drop_none(ak.min(h.children.pdgId, axis=2))), axis=1, ascending=True)] + + # get distinct children + h_children = ak.drop_none(h.distinctChildren[h.distinctChildren.hasFlags("fromHardProcess", "isFirstCopy")]) + + # strict mode: check that there are exactly two children + if strict: + if (hcn := unique_set(ak.num(h_children, axis=2))) != {2}: + raise Exception(f"found Higgs bosons that have != 2 children: {hcn - {2}}") + + # sort them by decreasing pdgId + h_children = h_children[ak.argsort(h_children.pdgId, axis=2, ascending=False)] + # in strict mode, fix the children dimension to 2 + if strict: + h_children = h_children[:, :, [0, 1]] + + # further treatment of tau decays + tau_mask = h_children.pdgId[:, :, 0] == 15 + tau = ak.fill_none(h_children[ak.mask(tau_mask, tau_mask)], [], axis=1) + tau_children = tau.distinctChildrenDeep[tau.distinctChildrenDeep.hasFlags("isFirstCopy", "isTauDecayProduct")] + tau_children = ak.drop_none(tau_children) + # prepare neutrino and W boson handling + tau_nu_mask = abs(tau_children.pdgId) == 16 + tau_w_mask = abs(tau_children.pdgId) == 24 + tau_rest_mask = ~(tau_nu_mask | tau_w_mask) + tau_has_rest = ak.any(tau_rest_mask, axis=3) + # strict mode: there should always be a neutrino, and _either_ a W and nothing else _or_ no W at all + if strict: + if not ak.all(ak.any(tau_nu_mask[tau_mask], axis=3)): + raise Exception("found tau leptons without a tau neutrino among their children") + tau_has_w = ak.any(tau_w_mask, axis=3) + if not ak.all((tau_has_w ^ tau_has_rest)[tau_mask]): + raise Exception("found tau leptons with both W bosons and other decay products among their children") + # get the tau neutrino + tau_nu = tau_children[tau_nu_mask].sum(axis=3) + tau_nu = set_ak_column(tau_nu, "pdgId", ak.values_astype(16 * np.sign(tau.pdgId), np.int32)) + # get the W boson in case it is part of the tau children, otherwise build it from the sum of children + tau_w = tau_children[tau_w_mask].sum(axis=3) + if ak.any(tau_has_rest): + tau_w_rest = tau_children[tau_rest_mask].sum(axis=-1) + tau_w = ak.where(tau_has_rest, tau_w_rest, tau_w) + tau_w = set_ak_column(tau_w, "pdgId", ak.values_astype(-24 * np.sign(tau.pdgId), np.int32)) + # combine nu and w again + tau_nuw = ak.concatenate([tau_nu[..., None], tau_w[..., None]], axis=3) + # define w children + tau_w_children = ak.concatenate( + [tau_children[tau_rest_mask], ak.drop_none(ak.firsts(tau_children[tau_w_mask], axis=3).children)], + axis=2, + ) + + # children for decays other than taus are not yet implemented, so show a warning in case they are found + unhandled_ids = unique_set(abs(h_children.pdgId)) - set(range(1, 6 + 1)) - set(range(11, 16 + 1)) + if unhandled_ids: + logger.warning_once( + f"gen_higgs_undhandled_children_{'_'.join(map(str, sorted(unhandled_ids)))}", + f"found Higgs boson decays in the {self.cls_name} producer with pdgIds {unhandled_ids}, for which the " + "lookup of children is not yet implemented", + ) + + # zip into a single array with named fields + gen_higgs = ak.zip( + { + "h": transform_gen_part(h, depth_limit=2), + "h_children": transform_gen_part(h_children, depth_limit=3), + "tau_children": transform_gen_part(tau_nuw, depth_limit=4), + "tau_w_children": transform_gen_part(tau_w_children, depth_limit=4), + # "z_children": None, # not yet implemented + # "w_children": None, # not yet implemented + }, + depth_limit=1, + ) + + # save the column + events = set_ak_column(events, "gen_higgs", gen_higgs) + + return events + + +@producer( + uses={ + "GenPart.{genPartIdxMother,status,statusFlags}", # required by the gen particle identification + f"GenPart.{{{','.join(_keep_gen_part_fields)}}}", # additional fields that should be read and added to gen_top + }, + produces={"gen_dy.*.*"}, +) +def gen_dy_lookup(self: Producer, events: ak.Array, strict: bool = True, **kwargs) -> ak.Array: + """ + Creates a new ragged column "gen_dy" containing information about generator-level Z/g bosons and their decay + products in a structured array with the following fields: + + - ``z``: list of all Z/g bosons in the event, sorted by the pdgId of their decay products + - ``lep``: list of direct Z/g boson children, consistent ordering w.r.t. ``z``, with the first entry being the + lepton and the second one being the anti-lepton + - ``tau_children``: list of decay products from tau lepton decays coming from Z/g bosons, with the first entry + being the neutrino and the second one being the W boson + - ``tau_w_children``: list of the decay products from W boson decays from tau lepton decays, with the first + entry being the down-type quark or charged lepton, the second entry being the up-type quark or neutrino, and + additional decay products (e.g photons) are appended afterwards + """ + # note: in about 4% of DY events, the Z/g boson is missing, so this lookup starts at lepton level, see + # -> https://indico.cern.ch/event/1495537/contributions/6359516/attachments/3014424/5315938/HLepRare_25.02.14.pdf + # -> https://indico.cern.ch/event/1495537/contributions/6359516/attachments/3014424/5315938/HLepRare_25.02.14.pdf + + # helper to extract unique values + unique_set = lambda a: set(np.unique(ak.flatten(a, axis=None))) + + # get the e/mu and tau masks + abs_id = abs(events.GenPart.pdgId) + emu_mask = ( + ((abs_id == 11) | (abs_id == 13)) & + (events.GenPart.status == 1) & + events.GenPart.hasFlags("fromHardProcess") + ) + # taus need to have status == 2 + tau_mask = ( + (abs_id == 15) & + (events.GenPart.status == 2) & + events.GenPart.hasFlags("fromHardProcess") + ) + lep_mask = emu_mask | tau_mask + + # strict mode: there must be exactly two charged leptons per event + if strict: + if (nl := unique_set(ak.num(events.GenPart[lep_mask], axis=1))) - {2}: + raise Exception(f"found events that have != 2 charged leptons: {nl - {2}}") + + # get the leptons and sort by decreasing pdgId (lepton before anti-lepton) + lep = events.GenPart[lep_mask] + lep = lep[ak.argsort(lep.pdgId, axis=1, ascending=False)] + + # in strict mode, fix the lep dimension to 2 + if strict: + lep = lep[:, [0, 1]] + + # build the z from them + z = lep.sum(axis=-1) + z = set_ak_column(z, "pdgId", np.int32(23)) + + # further treatment of tau decays + tau = events.GenPart[tau_mask] + tau_children = tau.distinctChildren[tau.distinctChildren.hasFlags("isFirstCopy", "isTauDecayProduct")] + tau_children = ak.drop_none(tau_children) + # prepare neutrino and W boson handling + tau_nu_mask = abs(tau_children.pdgId) == 16 + tau_w_mask = abs(tau_children.pdgId) == 24 + tau_rest_mask = ~(tau_nu_mask | tau_w_mask) + tau_has_rest = ak.any(tau_rest_mask, axis=2) + # strict mode: there should always be a neutrino, and _either_ a W and nothing else _or_ no W at all + if strict: + if not ak.all(ak.any(tau_nu_mask, axis=2)): + raise Exception("found tau leptons without a tau neutrino among their children") + tau_has_w = ak.any(tau_w_mask, axis=2) + if not ak.all(tau_has_w ^ tau_has_rest): + raise Exception("found tau leptons with both W bosons and other decay products among their children") + # get the tau neutrino + tau_nu = tau_children[tau_nu_mask].sum(axis=2) + tau_nu = set_ak_column(tau_nu, "pdgId", ak.values_astype(16 * np.sign(tau.pdgId), np.int32)) + # get the W boson in case it is part of the tau children, otherwise build it from the sum of children + tau_w = tau_children[tau_w_mask].sum(axis=2) + if ak.any(tau_has_rest): + tau_w_rest = tau_children[tau_rest_mask].sum(axis=-1) + tau_w = ak.where(tau_has_rest, tau_w_rest, tau_w) + tau_w = set_ak_column(tau_w, "pdgId", ak.values_astype(-24 * np.sign(tau.pdgId), np.int32)) + # combine nu and w again + tau_nuw = ak.concatenate([tau_nu[..., None], tau_w[..., None]], axis=2) + # define w children + tau_w_children = ak.concatenate( + [tau_children[tau_rest_mask], ak.drop_none(ak.firsts(tau_children[tau_w_mask], axis=2).children)], + axis=1, + ) + + # zip into a single array with named fields + gen_dy = ak.zip( + { + "z": transform_gen_part(z, depth_limit=1), + "lep": transform_gen_part(lep, depth_limit=2), + "tau_children": transform_gen_part(tau_nuw, depth_limit=3), + "tau_w_children": transform_gen_part(tau_w_children, depth_limit=3), + }, + depth_limit=1, + ) + + # save the column + events = set_ak_column(events, "gen_dy", gen_dy) + + return events diff --git a/columnflow/production/cms/gen_top_decay.py b/columnflow/production/cms/gen_top_decay.py deleted file mode 100644 index 8e925aaa0..000000000 --- a/columnflow/production/cms/gen_top_decay.py +++ /dev/null @@ -1,90 +0,0 @@ -# coding: utf-8 - -""" -Producers that determine the generator-level particles related to a top quark decay. -""" - -from __future__ import annotations - -from columnflow.production import Producer, producer -from columnflow.util import maybe_import -from columnflow.columnar_util import set_ak_column - -ak = maybe_import("awkward") - - -@producer( - uses={"GenPart.{genPartIdxMother,pdgId,statusFlags}"}, - produces={"gen_top_decay"}, -) -def gen_top_decay_products(self: Producer, events: ak.Array, **kwargs) -> ak.Array: - """ - Creates a new ragged column "gen_top_decay" with one element per hard top quark. Each element is - a GenParticleArray with five or more objects in a distinct order: top quark, bottom quark, - W boson, down-type quark or charged lepton, up-type quark or neutrino, and any additional decay - produces of the W boson (if any, then most likly photon radiations). Per event, the structure - will be similar to: - - .. code-block:: python - - [ - # event 1 - [ - # top 1 - [t1, b1, W1, q1/l, q2/n(, additional_w_decay_products)], - # top 2 - [...], - ], - # event 2 - ... - ] - """ - # find hard top quarks - abs_id = abs(events.GenPart.pdgId) - t = events.GenPart[abs_id == 6] - t = t[t.hasFlags("isHardProcess")] - t = t[~ak.is_none(t, axis=1)] - - # distinct top quark children (b's and W's) - t_children = t.distinctChildrenDeep[t.distinctChildrenDeep.hasFlags("isHardProcess")] - - # get b's - b = t_children[abs(t_children.pdgId) == 5][:, :, 0] - - # get W's - w = t_children[abs(t_children.pdgId) == 24][:, :, 0] - - # distinct W children - w_children = w.distinctChildrenDeep[w.distinctChildrenDeep.hasFlags("isHardProcess")] - - # reorder the first two W children (leptons or quarks) so that the charged lepton / down-type - # quark is listed first (they have an odd pdgId) - w_children_firsttwo = w_children[:, :, :2] - w_children_firsttwo = w_children_firsttwo[(w_children_firsttwo.pdgId % 2 == 0) * 1] - w_children_rest = w_children[:, :, 2:] - - # concatenate to create the structure to return - groups = ak.concatenate( - [ - t[:, :, None], - b[:, :, None], - w[:, :, None], - w_children_firsttwo, - w_children_rest, - ], - axis=2, - ) - - # save the column - events = set_ak_column(events, "gen_top_decay", groups) - - return events - - -@gen_top_decay_products.skip -def gen_top_decay_products_skip(self: Producer, **kwargs) -> bool: - """ - Custom skip function that checks whether the dataset is a MC simulation containing top - quarks in the first place. - """ - return self.dataset_inst.is_data or not self.dataset_inst.has_tag("has_top") diff --git a/columnflow/production/cms/muon.py b/columnflow/production/cms/muon.py index 071b3122f..762e6b544 100644 --- a/columnflow/production/cms/muon.py +++ b/columnflow/production/cms/muon.py @@ -8,21 +8,23 @@ import law -from dataclasses import dataclass +import dataclasses from columnflow.production import Producer, producer from columnflow.util import maybe_import, load_correction_set, DotDict -from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array +from columnflow.columnar_util import set_ak_column from columnflow.types import Any np = maybe_import("numpy") ak = maybe_import("awkward") -@dataclass +@dataclasses.dataclass class MuonSFConfig: correction: str campaign: str = "" + min_pt: float = 0.0 + max_pt: float = 0.0 @classmethod def new(cls, obj: MuonSFConfig | tuple[str, str]) -> MuonSFConfig: @@ -37,18 +39,23 @@ def new(cls, obj: MuonSFConfig | tuple[str, str]) -> MuonSFConfig: return cls(**obj) raise ValueError(f"cannot convert {obj} to MuonSFConfig") + def __post_init__(self): + if 0.0 < self.max_pt <= self.min_pt: + raise ValueError(f"{self.__class__.__name__}: max_pt must be larger than min_pt") + @producer( uses={"Muon.{pt,eta}"}, - # produces in the init + # produces defined in init # only run on mc mc_only=True, # function to determine the correction file get_muon_file=(lambda self, external_files: external_files.muon_sf), # function to determine the muon weight config - get_muon_config=(lambda self: MuonSFConfig.new(self.config_inst.x.muon_sf_names)), + get_muon_config=(lambda self: MuonSFConfig.new(self.config_inst.x("muon_sf", self.config_inst.x("muon_sf_names", None)))), # noqa: E501 + # name of the saved weight column weight_name="muon_weight", - supported_versions=(1, 2), + supported_versions={1, 2}, ) def muon_weights( self: Producer, @@ -57,8 +64,7 @@ def muon_weights( **kwargs, ) -> ak.Array: """ - Creates muon weights using the correctionlib. Requires an external file in the config under - ``muon_sf``: + Creates muon weights using the correctionlib. Requires an external file in the config under ``muon_sf``: .. code-block:: python @@ -66,33 +72,37 @@ def muon_weights( "muon_sf": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/MUO/2017_UL/muon_z.json.gz", # noqa }) - *get_muon_file* can be adapted in a subclass in case it is stored differently in the external - files. + *get_muon_file* can be adapted in a subclass in case it is stored differently in the external files. - The name of the correction set and the year string for the weight evaluation should be given as - an auxiliary entry in the config: + The name of the correction set and the year string for the weight evaluation should be given as an auxiliary entry + in the config: .. code-block:: python - cfg.x.muon_sf_names = MuonSFConfig( + cfg.x.muon_sf = MuonSFConfig( correction="NUM_TightRelIso_DEN_TightIDandIPCut", campaign="2017_UL", ) *get_muon_config* can be adapted in a subclass in case it is stored differently in the config. - Optionally, a *muon_mask* can be supplied to compute the scale factor weight based only on a - subset of muons. + Optionally, a *muon_mask* can be supplied to compute the scale factor weight based only on a subset of muons. """ - # flat absolute eta and pt views - abs_eta = flat_np_view(abs(events.Muon["eta"][muon_mask]), axis=1) - pt = flat_np_view(events.Muon["pt"][muon_mask], axis=1) - + # fold muon mask with pt cuts if given + if self.muon_config.min_pt > 0.0: + pt_mask = events.Muon.pt >= self.muon_config.min_pt + muon_mask = pt_mask if muon_mask is Ellipsis else (pt_mask & muon_mask) + if self.muon_config.max_pt > 0.0: + pt_mask = events.Muon.pt <= self.muon_config.max_pt + muon_mask = pt_mask if muon_mask is Ellipsis else (pt_mask & muon_mask) + + # prepare input variables + muons = events.Muon[muon_mask] variable_map = { "year": self.muon_config.campaign, - "abseta": abs_eta, - "eta": abs_eta, - "pt": pt, + "eta": muons.eta, + "abseta": abs(muons.eta), + "pt": muons.pt, } # loop over systematics @@ -108,10 +118,7 @@ def muon_weights( "ValType": syst, # syst key in 2017 } inputs = [variable_map_syst[inp.name] for inp in self.muon_sf_corrector.inputs] - sf_flat = self.muon_sf_corrector(*inputs) - - # add the correct layout to it - sf = layout_ak_array(sf_flat, events.Muon["pt"][muon_mask]) + sf = self.muon_sf_corrector.evaluate(*inputs) # create the product over all muons in one event weight = ak.prod(sf, axis=1, mask_identity=False) diff --git a/columnflow/production/cms/scale.py b/columnflow/production/cms/scale.py index caa683566..0a44fefa2 100644 --- a/columnflow/production/cms/scale.py +++ b/columnflow/production/cms/scale.py @@ -220,10 +220,14 @@ def murmuf_envelope_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Ar # take the max/min value of all considered variations murf_weights = (events.LHEScaleWeight[non_zero_mask] / murf_nominal)[:, envelope_indices] + weights_up = np.ones(len(events), dtype=np.float32) + weights_down = np.ones(len(events), dtype=np.float32) + weights_up[non_zero_mask] = ak.max(murf_weights, axis=1) + weights_down[non_zero_mask] = ak.min(murf_weights, axis=1) # store columns events = set_ak_column_f32(events, "murmuf_envelope_weight", ones) - events = set_ak_column_f32(events, "murmuf_envelope_weight_down", ak.min(murf_weights, axis=1)) - events = set_ak_column_f32(events, "murmuf_envelope_weight_up", ak.max(murf_weights, axis=1)) + events = set_ak_column_f32(events, "murmuf_envelope_weight_down", weights_down) + events = set_ak_column_f32(events, "murmuf_envelope_weight_up", weights_up) return events diff --git a/columnflow/production/cms/seeds.py b/columnflow/production/cms/seeds.py index 09c84d8a9..f8834b8db 100644 --- a/columnflow/production/cms/seeds.py +++ b/columnflow/production/cms/seeds.py @@ -30,10 +30,6 @@ def create_seed(val: int, n_hex: int = 16) -> int: return int(hashlib.sha256(bytes(str(val), "utf-8")).hexdigest()[:-(n_hex + 1):-1], base=16) -# store a vectorized version (only interface, not actually simd'ing) -create_seed_vec = np.vectorize(create_seed, otypes=[np.uint64]) - - @producer( uses={ # global columns for event seed @@ -74,7 +70,7 @@ def deterministic_event_seeds(self, events: ak.Array, **kwargs) -> ak.Array: before invoking this producer. """ # started from an already hashed seed based on event, run and lumi info multiplied with primes - seed = create_seed_vec( + seed = self.create_seed_vec( np.asarray( self.primes[7] * ak.values_astype(events.event, np.uint64) + self.primes[5] * ak.values_astype(events.run, np.uint64) + @@ -125,7 +121,7 @@ def deterministic_event_seeds(self, events: ak.Array, **kwargs) -> ak.Array: seed = seed + primes * ak.values_astype(hashed, np.uint64) # create and store them - seed = ak.Array(create_seed_vec(np.asarray(seed))) + seed = ak.Array(self.create_seed_vec(np.asarray(seed))) events = set_ak_column(events, "deterministic_seed", seed, value_type=np.uint64) # uniqueness test across the chunk for debugging @@ -178,6 +174,9 @@ def apply_route(ak_array: ak.Array, route: Route) -> ak.Array | None: self.apply_route = apply_route + # store a vectorized version of the create_seed function (only interface, not actually simd'ing) + self.create_seed_vec = np.vectorize(create_seed, otypes=[np.uint64]) + class deterministic_object_seeds(Producer): @@ -217,7 +216,7 @@ def call_func(self, events: ak.Array, **kwargs) -> ak.Array: ) ) np_object_seed = np.asarray(ak.flatten(object_seed)) - np_object_seed[:] = create_seed_vec(np_object_seed) + np_object_seed[:] = self.create_seed_vec(np_object_seed) # store them events = set_ak_column(events, f"{self.object_field}.deterministic_seed", object_seed, value_type=np.uint64) @@ -253,6 +252,9 @@ def setup_func( # store primes in array self.primes = np.array(primes, dtype=np.uint64) + # store a vectorized version of the create_seed function (only interface, not actually simd'ing) + self.create_seed_vec = np.vectorize(create_seed, otypes=[np.uint64]) + deterministic_jet_seeds = deterministic_object_seeds.derive( "deterministic_jet_seeds", diff --git a/columnflow/production/cms/top_pt_weight.py b/columnflow/production/cms/top_pt_weight.py index bb1fb4c4e..8207414d2 100644 --- a/columnflow/production/cms/top_pt_weight.py +++ b/columnflow/production/cms/top_pt_weight.py @@ -6,13 +6,13 @@ from __future__ import annotations -from dataclasses import dataclass +import dataclasses import law from columnflow.production import Producer, producer from columnflow.util import maybe_import -from columnflow.columnar_util import set_ak_column +from columnflow.columnar_util import set_ak_column, full_like ak = maybe_import("awkward") np = maybe_import("numpy") @@ -21,134 +21,101 @@ logger = law.logger.get_logger(__name__) -@dataclass -class TopPtWeightConfig: - params: dict[str, float] - pt_max: float = 500.0 - - @classmethod - def new(cls, obj: TopPtWeightConfig | dict[str, float]) -> TopPtWeightConfig: - # backward compatibility only - if isinstance(obj, cls): - return obj - return cls(params=obj) - - -@producer( - uses={"GenPart.{pdgId,statusFlags}"}, - # requested GenPartonTop columns, passed to the *uses* and *produces* - produced_top_columns={"pt"}, - mc_only=True, - # skip the producer unless the datasets has this specified tag (no skip check performed when none) - require_dataset_tag="has_top", -) -def gen_parton_top(self: Producer, events: ak.Array, **kwargs) -> ak.Array: +@dataclasses.dataclass +class TopPtWeightFromDataConfig: """ - Produce parton-level top quarks (before showering and detector simulation). - Creates new collection named "GenPartonTop" - - *produced_top_columns* can be adapted to change the columns that will be produced - for the GenPartonTop collection. - - The function is skipped when the dataset is data or when it does not have the tag *has_top*. - - :param events: awkward array containing events to process + Container to configure the top pt reweighting parameters for the method based on fits to data. For more info, see + https://twiki.cern.ch/twiki/bin/viewauth/CMS/TopPtReweighting?rev=31#TOP_PAG_corrections_based_on_dat """ - # find parton-level top quarks - abs_id = abs(events.GenPart.pdgId) - t = events.GenPart[abs_id == 6] - t = t[t.hasFlags("isLastCopy")] - t = t[~ak.is_none(t, axis=1)] - - # save the column - events = set_ak_column(events, "GenPartonTop", t) - - return events - - -@gen_parton_top.init -def gen_parton_top_init(self: Producer, **kwargs) -> bool: - for col in self.produced_top_columns: - self.uses.add(f"GenPart.{col}") - self.produces.add(f"GenPartonTop.{col}") + params: dict[str, float] = dataclasses.field(default_factory=lambda: { + "a": 0.0615, + "a_up": 0.0615 * 1.5, + "a_down": 0.0615 * 0.5, + "b": -0.0005, + "b_up": -0.0005 * 1.5, + "b_down": -0.0005 * 0.5, + }) + pt_max: float = 500.0 -@gen_parton_top.skip -def gen_parton_top_skip(self: Producer, **kwargs) -> bool: +@dataclasses.dataclass +class TopPtWeightFromTheoryConfig: """ - Custom skip function that checks whether the dataset is a MC simulation containing top quarks in the first place - using the :py:attr:`require_dataset_tag` attribute. + Container to configure the top pt reweighting parameters for the theory-based method. For more info, see + https://twiki.cern.ch/twiki/bin/viewauth/CMS/TopPtReweighting?rev=31#TOP_PAG_corrections_based_on_the """ - # never skip if the tag is not set - if self.require_dataset_tag is None: - return False - - return self.dataset_inst.is_data or not self.dataset_inst.has_tag(self.require_dataset_tag) - - -def get_top_pt_weight_config(self: Producer) -> TopPtWeightConfig: - if self.config_inst.has_aux("top_pt_reweighting_params"): - logger.info_once( - "deprecated_top_pt_weight_config", - "the config aux field 'top_pt_reweighting_params' is deprecated and will be removed in " - "a future release, please use 'top_pt_weight' instead", + params: dict[str, float] = dataclasses.field(default_factory=lambda: { + "a": 0.103, + "b": -0.0118, + "c": -0.000134, + "d": 0.973, + }) + + +# for backward compatibility +class TopPtWeightConfig(TopPtWeightFromDataConfig): + + def __init__(self, *args, **kwargs): + logger.warning_once( + "TopPtWeightConfig is deprecated and will be removed in future versions, please use " + "TopPtWeightFromDataConfig instead to keep using the data-based method, or TopPtWeightFromTheoryConfig to " + "use the theory-based method", ) - params = self.config_inst.x.top_pt_reweighting_params - else: - params = self.config_inst.x.top_pt_weight - - return TopPtWeightConfig.new(params) + super().__init__(*args, **kwargs) @producer( - uses={"GenPartonTop.pt"}, + uses={"gen_top.t.pt"}, produces={"top_pt_weight{,_up,_down}"}, - get_top_pt_weight_config=get_top_pt_weight_config, - # skip the producer unless the datasets has this specified tag (no skip check performed when none) - require_dataset_tag="is_ttbar", + get_top_pt_weight_config=(lambda self: self.config_inst.x.top_pt_weight), ) def top_pt_weight(self: Producer, events: ak.Array, **kwargs) -> ak.Array: - """ - Compute SF to be used for top pt reweighting. + r""" + Compute SF to be used for top pt reweighting, either with information from a fit to data or from theory. See https://twiki.cern.ch/twiki/bin/view/CMS/TopPtReweighting?rev=31 for more information. - The *GenPartonTop.pt* column can be produced with the :py:class:`gen_parton_top` Producer. The - SF should *only be applied in ttbar MC* as an event weight and is computed based on the - gen-level top quark transverse momenta. - - The top pt reweighting parameters should be given as an auxiliary entry in the config: + The method to be used depends on the config entry obtained with *get_top_pt_config* which should either be of + type :py:class:`TopPtWeightFromDataConfig` or :py:class:`TopPtWeightFromTheoryConfig`. - .. code-block:: python + - data-based: $SF(p_T)=e^{a + b \cdot p_T}$ + - theory-based: $SF(p_T)=a \cdot e^{b \cdot p_T} + c \cdot p_T + d$ - cfg.x.top_pt_reweighting_params = { - "a": 0.0615, - "a_up": 0.0615 * 1.5, - "a_down": 0.0615 * 0.5, - "b": -0.0005, - "b_up": -0.0005 * 1.5, - "b_down": -0.0005 * 0.5, - } + The *gen_top.t.pt* column can be produced with the :py:class:`gen_top_lookup` producer. The SF should *only be + applied in ttbar MC* as an event weight and is computed based on the gen-level top quark transverse momenta. + The top pt weight configuration should be given as an auxiliary entry "top_pt_weight" in the config. *get_top_pt_config* can be adapted in a subclass in case it is stored differently in the config. - - :param events: awkward array containing events to process """ # check the number of gen tops - if ak.any((n_tops := ak.num(events.GenPartonTop, axis=1)) != 2): + if ak.any((n_tops := ak.num(events.gen_top.t, axis=1)) != 2): raise Exception( - f"{self.cls_name} can only run on events with two generator top quarks, but found " - f"counts of {','.join(map(str, sorted(set(n_tops))))}", + f"{self.cls_name} can only run on events with two generator top quarks, but found counts of " + f"{','.join(map(str, sorted(set(n_tops))))}", ) - # clamp top pt - top_pt = events.GenPartonTop.pt - if self.cfg.pt_max >= 0.0: + # get top pt + top_pt = events.gen_top.t.pt + if not self.theory_method and self.cfg.pt_max >= 0.0: top_pt = ak.where(top_pt > self.cfg.pt_max, self.cfg.pt_max, top_pt) - for variation in ("", "_up", "_down"): - # evaluate SF function - sf = np.exp(self.cfg.params[f"a{variation}"] + self.cfg.params[f"b{variation}"] * top_pt) + for variation in ["", "_up", "_down"]: + # evaluate SF function, implementation is method dependent + if self.theory_method: + # up variation: apply twice the effect + # down variation: no weight at all + if variation != "_down": + sf = ( + self.cfg.params["a"] * np.exp(self.cfg.params["b"] * top_pt) + + self.cfg.params["c"] * top_pt + + self.cfg.params["d"] + ) + if variation == "_up": + sf = 1.0 + 2.0 * (sf - 1.0) + elif variation == "_down": + sf = full_like(top_pt, 1.0) + else: + sf = np.exp(self.cfg.params[f"a{variation}"] + self.cfg.params[f"b{variation}"] * top_pt) # compute weight from SF product for top and anti-top weight = np.sqrt(np.prod(sf, axis=1)) @@ -163,14 +130,9 @@ def top_pt_weight(self: Producer, events: ak.Array, **kwargs) -> ak.Array: def top_pt_weight_init(self: Producer) -> None: # store the top pt weight config self.cfg = self.get_top_pt_weight_config() - - -@top_pt_weight.skip -def top_pt_weight_skip(self: Producer, **kwargs) -> bool: - """ - Skip if running on anything except ttbar MC simulation, evaluated via the :py:attr:`require_dataset_tag` attribute. - """ - if self.require_dataset_tag is None: - return self.dataset_inst.is_data - - return self.dataset_inst.is_data or not self.dataset_inst.has_tag("is_ttbar") + if not isinstance(self.cfg, (TopPtWeightFromDataConfig, TopPtWeightFromTheoryConfig)): + raise Exception( + f"{self.cls_name} expects the config entry obtained with get_top_pt_weight_config to be of type " + f"TopPtWeightFromDataConfig or TopPtWeightFromTheoryConfig, but got {type(self.cfg)}", + ) + self.theory_method = isinstance(self.cfg, TopPtWeightFromTheoryConfig) diff --git a/columnflow/production/cmsGhent/btag_weights.py b/columnflow/production/cmsGhent/btag_weights.py index 179ac438c..bb5112023 100644 --- a/columnflow/production/cmsGhent/btag_weights.py +++ b/columnflow/production/cmsGhent/btag_weights.py @@ -20,8 +20,6 @@ ak = maybe_import("awkward") np = maybe_import("numpy") -hist = maybe_import("hist") -correctionlib = maybe_import("correctionlib") logger = law.logger.get_logger(__name__) @@ -64,6 +62,8 @@ def init_btag(self: Producer, add_eff_vars=True): def setup_btag(self: Producer, task: law.Task, reqs: dict): + import correctionlib + bundle = reqs["external_files"] correction_set_btag_wp_corr = correctionlib.CorrectionSet.from_string( self.get_btag_sf(bundle.files).load(formatter="gzip").decode("utf-8"), @@ -140,6 +140,9 @@ def fixed_wp_btag_weights( # get the total number of jets in the chunk jets = events.Jet[jet_mask] if jet_mask is not None else events.Jet jets = set_ak_column(jets, "abseta", abs(jets.eta)) + # currently set hard max on pt for efficiency since overflow could not be changed in correctionlib + # (could also manually change the flow) + jets = set_ak_column(jets, "minpt", ak.where(jets.pt <= 999, jets.pt, 999)) # helper to create and store the weight def add_weight(flavour_group, systematic, variation=None): @@ -185,9 +188,7 @@ def sf_eff_wp(working_point, none_value=0.): ) eff = self.btag_eff_corrector( flat_input.hadronFlavour, - # currently set hard max on pt since overflow could not be changed in correctionlib - # (could also manually change the flow) - ak.min([flat_input.pt, 999 * ak.ones_like(flat_input.pt)], axis=0), + flat_input.minpt, flat_input.abseta, working_point, ) @@ -299,6 +300,8 @@ def fixed_wp_btag_weights_setup( inputs: dict, reader_targets: law.util.InsertableDict, ) -> None: + import correctionlib + correction_set_btag_wp_corr = setup_btag(self, task, reqs) # fix for change in nomenclature of deepJet scale factors for light hadronFlavour jets @@ -366,6 +369,7 @@ def btag_efficiency_hists( hists: DotDict | dict = None, **kwargs, ) -> ak.Array: + import hist if hists is None: return events diff --git a/columnflow/production/cmsGhent/gen_features.py b/columnflow/production/cmsGhent/gen_features.py index 2f170aef4..6667e6697 100644 --- a/columnflow/production/cmsGhent/gen_features.py +++ b/columnflow/production/cmsGhent/gen_features.py @@ -10,7 +10,6 @@ np = maybe_import("numpy") ak = maybe_import("awkward") -coffea = maybe_import("coffea") def _geometric_matching(particles1: ak.Array, particles2: ak.Array) -> (ak.Array, ak.Array): diff --git a/columnflow/production/cmsGhent/lepton.py b/columnflow/production/cmsGhent/lepton.py index 051e4358e..31e080139 100644 --- a/columnflow/production/cmsGhent/lepton.py +++ b/columnflow/production/cmsGhent/lepton.py @@ -12,8 +12,6 @@ ak = maybe_import("awkward") np = maybe_import("numpy") -hist = maybe_import("hist") -correctionlib = maybe_import("correctionlib") logger = law.logger.get_logger(__name__) diff --git a/columnflow/production/cmsGhent/trigger/hist_producer.py b/columnflow/production/cmsGhent/trigger/hist_producer.py index 0f3abaa0b..7b0840689 100644 --- a/columnflow/production/cmsGhent/trigger/hist_producer.py +++ b/columnflow/production/cmsGhent/trigger/hist_producer.py @@ -15,10 +15,12 @@ import columnflow.production.cmsGhent.trigger.util as util from columnflow.selection import SelectionResult import order as od +from columnflow.types import TYPE_CHECKING np = maybe_import("numpy") ak = maybe_import("awkward") -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") logger = law.logger.get_logger(__name__) @@ -34,6 +36,8 @@ def trigger_efficiency_hists( object_mask: dict = None, **kwargs, ) -> ak.Array: + import hist + if hists is None: logger.warning_once(self.cls_name + " did not get any histograms") return events diff --git a/columnflow/production/cmsGhent/trigger/sf_producer.py b/columnflow/production/cmsGhent/trigger/sf_producer.py index 82e320214..fbe5c03b1 100644 --- a/columnflow/production/cmsGhent/trigger/sf_producer.py +++ b/columnflow/production/cmsGhent/trigger/sf_producer.py @@ -12,10 +12,8 @@ from columnflow.columnar_util import set_ak_column, has_ak_column, Route import columnflow.production.cmsGhent.trigger.util as util - np = maybe_import("numpy") ak = maybe_import("awkward") -hist = maybe_import("hist") logger = law.logger.get_logger(__name__) diff --git a/columnflow/production/cmsGhent/trigger/uncertainties.py b/columnflow/production/cmsGhent/trigger/uncertainties.py index 6e0cec903..4ec341862 100644 --- a/columnflow/production/cmsGhent/trigger/uncertainties.py +++ b/columnflow/production/cmsGhent/trigger/uncertainties.py @@ -4,12 +4,12 @@ from columnflow.util import maybe_import from columnflow.production.cmsGhent.trigger.Koopman_test import koopman_confint import columnflow.production.cmsGhent.trigger.util as util +from columnflow.types import TYPE_CHECKING -import numpy as np - -hist = maybe_import("hist") - -Hist = hist.Hist +np = maybe_import("numpy") +if TYPE_CHECKING: + hist = maybe_import("hist") + Hist = hist.Hist def calc_stat( diff --git a/columnflow/production/cmsGhent/trigger/util.py b/columnflow/production/cmsGhent/trigger/util.py index 6054ecb1a..fc0963ec1 100644 --- a/columnflow/production/cmsGhent/trigger/util.py +++ b/columnflow/production/cmsGhent/trigger/util.py @@ -9,10 +9,13 @@ from columnflow.production import Producer from columnflow.util import maybe_import from columnflow.plotting.plot_util import use_flow_bins +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") -Hist = hist.Hist np = maybe_import("numpy") +ak = maybe_import("awkward") +if TYPE_CHECKING: + hist = maybe_import("hist") + Hist = hist.Hist logger = law.logger.get_logger(__name__) @@ -23,6 +26,8 @@ def reduce_hist( exclude: str | Collection[str] = tuple(), keepdims=True, ): + import hist + exclude = law.util.make_list(exclude) if reduce is Ellipsis: return histogram.project(*exclude) @@ -74,6 +79,8 @@ def syst_hist( syst_name: str = "", arrays: np.ndarray | tuple[np.ndarray, np.ndarray] = None, ) -> Hist: + import hist + if syst_name == "central": variations = [syst_name] else: diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 8d88985da..c118b3468 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -6,7 +6,10 @@ from __future__ import annotations -from collections import defaultdict +import copy +import itertools +import dataclasses +import collections import law import order as od @@ -15,156 +18,259 @@ from columnflow.production import Producer, producer from columnflow.util import maybe_import, DotDict from columnflow.columnar_util import set_ak_column -from columnflow.types import Any +from columnflow.types import Any, Sequence np = maybe_import("numpy") -sp = maybe_import("scipy") -maybe_import("scipy.sparse") ak = maybe_import("awkward") logger = law.logger.get_logger(__name__) -def get_inclusive_dataset(self: Producer) -> od.Dataset: +def get_stitching_datasets(self: Producer, debug: bool = False) -> tuple[od.Dataset, list[od.Dataset]]: """ - Helper function to obtain the inclusive dataset from a list of datasets that are required to stitch this - *dataset_inst*. - """ - process_map = {d.processes.get_first(): d for d in self.stitching_datasets} - - process_inst = self.dataset_inst.processes.get_first() - incl_dataset = None - while process_inst: - if process_inst in process_map: - incl_dataset = process_map[process_inst] - process_inst = process_inst.parent_processes.get_first(default=None) - - if not incl_dataset: - raise Exception("inclusive dataset not found") - - unmatched_processes = {p for p in process_map if not incl_dataset.has_process(p, deep=True)} - if unmatched_processes: - raise Exception(f"processes {unmatched_processes} not found in inclusive dataset") - - return incl_dataset + Helper function to obtain information about stitching datasets: - -def get_stitching_datasets(self: Producer) -> list[od.Dataset]: - """ - Helper function to obtain all datasets that are required to stitch this *dataset_inst*. + - the inclusive dataset, which is the dataset that contains all processes + - all datasets that are required to stitch this *dataset_inst* """ - stitching_datasets = { - d for d in self.config_inst.datasets + # first collect all datasets that are needed to stitch the current dataset + required_datasets = { + d + for d in self.config_inst.datasets if ( d.has_process(self.dataset_inst.processes.get_first(), deep=True) or self.dataset_inst.has_process(d.processes.get_first(), deep=True) ) } - return list(stitching_datasets) + + # determine the inclusive dataset + inclusive_dataset = None + for dataset_inst in required_datasets: + for other_dataset_inst in required_datasets: + if dataset_inst == other_dataset_inst: + continue + # check if the other dataset is a sub-dataset of the current one by comparing their leading process + if not dataset_inst.has_process(other_dataset_inst.processes.get_first(), deep=True): + break + else: + # if we did not break, the dataset is the inclusive one + inclusive_dataset = dataset_inst + break + if not inclusive_dataset: + raise Exception("inclusive dataset not found") + + if debug: + logger.info( + f"determined info for stitching content of dataset '{self.dataset_inst.name}':\n" + f" - inclusive dataset: {inclusive_dataset.name}\n" + f" - required datasets: {', '.join(d.name for d in required_datasets)}", + ) + + return inclusive_dataset, list(required_datasets) -def get_br_from_inclusive_dataset( +def get_br_from_inclusive_datasets( self: Producer, - inclusive_dataset: od.Dataset, - stats: dict, -) -> dict[int, float]: + process_insts: Sequence[od.Process] | set[od.Process], + dataset_selection_stats: dict[str, dict[str, float | dict[str, float]]], + merged_selection_stats: dict[str, float | dict[str, float]], + debug: bool = False, +) -> dict[od.Process, float]: """ - Helper function to compute the branching ratios from the inclusive sample. This is done with ratios of event weights - isolated per dataset and thus independent of the overall mc weight normalization. + Helper function to compute the branching ratios from sum of weights of inclusive samples. """ - # define helper variables and mapping between process ids and dataset names - proc_ds_map = { - d.processes.get_first().id: d - for d in self.config_inst.datasets - if d.name in stats.keys() - } - inclusive_proc = inclusive_dataset.processes.get_first() - N = lambda x: sn.Number(x, np.sqrt(x)) # alias for Number with counting error - - # create a dictionary "parent process id" -> {"child process id" -> "branching ratio", ...} - # each ratio is based on gen weight sums - child_brs: dict[int, dict[int, sn.Number]] = defaultdict(dict) - for proc, _, child_procs in inclusive_dataset.walk_processes(): - # the process must be covered by a dataset and should not be a leaf process - if proc.id not in proc_ds_map or proc.is_leaf_process: - continue - dataset_name = proc_ds_map[proc.id].name - - # get the mc weights for the "mother" dataset and add an entry for the process - sum_mc_weight: float = stats[dataset_name]["sum_mc_weight"] - sum_mc_weight_per_process: dict[str, float] = stats[dataset_name]["sum_mc_weight_per_process"] - # use the number of events to compute the error on the branching ratio - num_events: int = stats[dataset_name]["num_events"] - num_events_per_process: dict[str, int] = stats[dataset_name]["num_events_per_process"] - - # loop over all child processes - for child_proc in child_procs: - # skip processes that are not covered by any dataset or irrelevant for the used dataset - # (identified as leaf processes that have no occurrences in the stats - # or as non-leaf processes that are not in the stitching datasets) - is_leaf = child_proc.is_leaf_process - if ( - (is_leaf and str(child_proc.id) not in sum_mc_weight_per_process) or - (not is_leaf and child_proc.id not in proc_ds_map) - ): - + # step 1: per desired process, collect datasets that contain them + process_datasets = collections.defaultdict(set) + for process_inst in process_insts: + for dataset_name, dstats in dataset_selection_stats.items(): + if str(process_inst.id) in dstats["sum_mc_weight_per_process"]: + process_datasets[process_inst].add(self.config_inst.get_dataset(dataset_name)) + + # step 2: per dataset, collect all "lowest level" processes that are contained in them + dataset_processes = collections.defaultdict(set) + for dataset_name in dataset_selection_stats: + dataset_inst = self.config_inst.get_dataset(dataset_name) + dataset_process_inst = dataset_inst.processes.get_first() + for process_inst in process_insts: + if process_inst == dataset_process_inst or dataset_process_inst.has_process(process_inst, deep=True): + dataset_processes[dataset_inst].add(process_inst) + + # step 3: per process, structure the assigned datasets and corresponding processes in DAGs, from more inclusive down + # to more exclusive phase spaces; usually each DAG can contain multiple paths to compute the BR of a single + # process; this is resolved in step 4 + @dataclasses.dataclass + class Node: + process_inst: od.Process + dataset_inst: od.Dataset | None = None + next: set[Node] = dataclasses.field(default_factory=set) + + def __hash__(self) -> int: + return hash((self.process_inst, self.dataset_inst)) + + def str_lines(self) -> list[str]: + lines = [ + f"{self.__class__.__name__}(", + f" process={self.process_inst.name}({self.process_inst.id})", + f" dataset={self.dataset_inst.name if self.dataset_inst else 'None'}", + ] + if self.next: + lines.append(" next={") + for n in self.next: + lines.extend(f" {line}" for line in n.str_lines()) + lines.append(" }") + else: + lines.append(r" next={}") + lines.append(")") + return lines + + def __str__(self) -> str: + return "\n".join(self.str_lines()) + + process_dags = {} + for process_inst, dataset_insts in process_datasets.items(): + # first, per dataset, remember all sub (more exclusive) datasets + # (the O(n^2) is not necessarily optimal, but we are dealing with very small numbers here, thus acceptable) + sub_datasets = {} + for d_incl, d_excl in itertools.permutations(dataset_insts, 2): + if d_incl.processes.get_first().has_process(d_excl.processes.get_first(), deep=True): + sub_datasets.setdefault(d_incl, set()).add(d_excl) + # then, expand to a DAG structure + nodes = {} + excl_nodes = set() + for d_incl, d_excls in sub_datasets.items(): + for d_excl in d_excls: + if d_incl not in nodes: + nodes[d_incl] = Node(d_incl.processes.get_first(), d_incl) + if d_excl not in nodes: + nodes[d_excl] = Node(d_excl.processes.get_first(), d_excl) + nodes[d_incl].next.add(nodes[d_excl]) + excl_nodes.add(nodes[d_excl]) + # mark the root node as the head of the DAG + dag = (set(nodes.values()) - excl_nodes).pop() + # add another node to leaves that only contains the process instance + for node in excl_nodes: + if node.next or node.process_inst == process_inst: continue - - # determine relevant leaf processes that will be summed over - # (since the all stats are only derived for those) - leaf_proc_ids = ( - [child_proc.id] - if is_leaf or str(child_proc.id) in sum_mc_weight_per_process - else [ - p.id for p, _, _ in child_proc.walk_processes() - if str(p.id) in sum_mc_weight_per_process - ] - ) - - # compute the br and its uncertainty using the bare number of events - # NOTE: we assume that the uncertainty is independent of the mc weights, so we can use - # the same relative uncertainty; this is a simplification, but should be fine for most - # cases; we can improve this by switching from jsons to hists when storing sum of weights - leaf_sum = lambda d: sum(d.get(str(proc_id), 0) for proc_id in leaf_proc_ids) - br_nom = leaf_sum(sum_mc_weight_per_process) / sum_mc_weight - br_unc = N(leaf_sum(num_events_per_process)) / N(num_events) - child_brs[proc.id][child_proc.id] = sn.Number( - br_nom, - br_unc(sn.UP, unc=True, factor=True) * 1j, # same relative uncertainty + if process_inst not in nodes: + nodes[process_inst] = Node(process_inst) + node.next.add(nodes[process_inst]) + process_dags[process_inst] = dag + + # step 4: per process, compute the branching ratio for each possible path in the DAG, while keeping track of the + # statistical precision of each combination, evaluated based on the raw number of events; then pick the + # most precise path; again, there should usually be just a single path, but multiple ones are possible when + # datasets have complex overlap + def get_single_br(dataset_inst: od.Dataset, process_inst: od.Process) -> sn.Number | None: + # process_inst might refer to a mid-layer process, so check which lowest-layer processes it is made of + lowest_process_ids = ( + [process_inst.id] + if process_inst in process_insts + else [ + int(process_id_str) + for process_id_str in dataset_selection_stats[dataset_inst.name]["sum_mc_weight_per_process"] + if process_inst.has_process(int(process_id_str), deep=True) + ] + ) + # extract stats + process_sum_weights = sum( + dataset_selection_stats[dataset_inst.name]["sum_mc_weight_per_process"].get(str(process_id), 0.0) + for process_id in lowest_process_ids + ) + dataset_sum_weights = sum(dataset_selection_stats[dataset_inst.name]["sum_mc_weight_per_process"].values()) + process_num_events = sum( + dataset_selection_stats[dataset_inst.name]["num_events_per_process"].get(str(process_id), 0.0) + for process_id in lowest_process_ids + ) + dataset_num_events = sum(dataset_selection_stats[dataset_inst.name]["num_events_per_process"].values()) + # when there are no events, return None + if process_num_events == 0: + logger.warning( + f"found no events for process '{process_inst.name}' ({process_inst.id}) with subprocess ids " + f"'{','.join(map(str, lowest_process_ids))}' in selection stats of dataset {dataset_inst.name}", ) + return None + # compute the ratio of events, assuming correlated poisson counting errors since numbers come from the same + # dataset, then compute the relative uncertainty + num_ratio = ( + sn.Number(process_num_events, process_num_events**0.5) / + sn.Number(dataset_num_events, dataset_num_events**0.5) + ) + rel_unc = num_ratio(sn.UP, unc=True, factor=True) + # compute the branching ratio, using the same relative uncertainty and store using the dataset name to mark its + # limited statistics as the source of uncertainty which is important for consistent error propagation + br = sn.Number(process_sum_weights / dataset_sum_weights, {f"{dataset_inst.name}_stats": rel_unc * 1j}) + return br + + def path_repr(br_path: tuple[sn.Number, ...], dag_path: tuple[Node, ...]) -> str: + return " X ".join( + f"{node.process_inst.name} (br = {br.combine_uncertainties().str(format=3)})" + for br, node in zip(br_path, dag_path) + ) - # define actual per-process branching ratios - branching_ratios: dict[int, float] = {} - - def multiply_branching_ratios(proc_id: int, proc_br: sn.Number) -> None: - """ - Recursively multiply the branching ratios from the nested dictionary. - """ - # when the br for proc_id can be created from sub processes, calculate it via product - if proc_id in child_brs: - for child_id, child_br in child_brs[proc_id].items(): - # multiply the branching ratios assuming no correlation - prod_br = child_br.mul(proc_br, rho=0, inplace=False) - multiply_branching_ratios(child_id, prod_br) - return - - # warn the user if the relative (statistical) error is large - rel_unc = proc_br(sn.UP, unc=True, factor=True) - if rel_unc > 0.05: + process_brs = {} + process_brs_debug = {} + for process_inst, dag in process_dags.items(): + brs = [] + queue = collections.deque([(dag, (br := sn.Number(1.0, 0.0)), (br,), (dag,))]) + while queue: + node, br, br_path, dag_path = queue.popleft() + if not node.next: + brs.append((br, br_path, dag_path)) + continue + for sub_node in node.next: + sub_br = get_single_br(node.dataset_inst, sub_node.process_inst) + if sub_br is not None: + queue.append((sub_node, br * sub_br, br_path + (sub_br,), dag_path + (sub_node,))) + # combine all uncertainties + brs = [(br.combine_uncertainties(), *paths) for br, *paths in brs] + # select the most certain one + brs.sort(key=lambda tpl: tpl[0](sn.UP, unc=True, factor=True)) + best_br, best_br_path, best_dag_path = brs[0] + process_brs[process_inst] = best_br.nominal + process_brs_debug[process_inst] = (best_br.nominal, best_br(sn.UP, unc=True, factor=True)) # value and % unc + # show a warning in case the relative uncertainty is large + if (rel_unc := best_br(sn.UP, unc=True, factor=True)) > 0.1: logger.warning( - f"large error on the branching ratio for process {inclusive_proc.get_process(proc_id).name} with " - f"process id {proc_id} ({rel_unc * 100:.2f}%)", + f"large error on the branching ratio of {rel_unc * 100:.2f}% for process '{process_inst.name}' " + f"({process_inst.id}), calculated along\n {path_repr(best_br_path, best_dag_path)}", ) + # in case there were multiple values, check their compatibility with the best one and warn if they diverge + for i, (br, br_path, dag_path) in enumerate(brs[1:], 2): + abs_diff = abs(best_br.n - br.n) + rel_diff = abs_diff / best_br.n + pull = abs(best_br.n - br.n) / (best_br.u(direction="up")**2 + br.u(direction="up")**2)**0.5 + if rel_diff > 0.1 and pull > 3: + logger.warning( + f"detected diverging branching ratios between the best and the one on position {i} for process " + f"'{process_inst.name}' (abs_diff={abs_diff:.4f}, rel_diff={rel_diff:.4f}, pull={pull:.2f} ):" + f"\nbest path: {best_br.str(format=3)} from {path_repr(best_br_path, best_dag_path)}" + f"\npath {i} : {br.str(format=3)} from {path_repr(br_path, dag_path)}", + ) - # just store the nominal value - branching_ratios[proc_id] = proc_br.nominal + if debug: + from tabulate import tabulate + header = ["process name", "process id", "branching ratio", "uncertainty (%)"] + rows = [ + [ + process_inst.name, process_inst.id, process_brs_debug[process_inst][0], + f"{process_brs_debug[process_inst][1] * 100:.4f}", + ] + for process_inst in sorted(process_brs_debug) + ] + logger.info(f"extracted branching ratios from process occurrence in datasets:\n{tabulate(rows, header)}") - # fill all branching ratios - for proc_id, br in child_brs[inclusive_proc.id].items(): - multiply_branching_ratios(proc_id, br) + return process_brs - return branching_ratios + +def update_dataset_selection_stats( + self: Producer, + dataset_selection_stats: dict[str, dict[str, float | dict[str, float]]], +) -> dict[str, dict[str, float | dict[str, float]]]: + """ + Hook to optionally update the per-dataset selection stats. + """ + return dataset_selection_stats @producer( @@ -173,12 +279,16 @@ def multiply_branching_ratios(proc_id: int, proc_br: sn.Number) -> None: weight_name="normalization_weight", # which luminosity to apply, uses the value stored in the config when None luminosity=None, + # whether to normalize weights per dataset to the mean weight first (to cancel out numeric differences) + normalize_weights_per_dataset=True, # whether to allow stitching datasets allow_stitching=False, - get_xsecs_from_inclusive_dataset=False, + get_xsecs_from_inclusive_datasets=False, get_stitching_datasets=get_stitching_datasets, - get_inclusive_dataset=get_inclusive_dataset, - get_br_from_inclusive_dataset=get_br_from_inclusive_dataset, + get_br_from_inclusive_datasets=get_br_from_inclusive_datasets, + update_dataset_selection_stats=update_dataset_selection_stats, + update_dataset_selection_stats_br=None, + update_dataset_selection_stats_sum_weights=None, # only run on mc mc_only=True, ) @@ -207,33 +317,55 @@ def normalization_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Arra process_id = np.asarray(events.process_id) # ensure all ids were assigned a cross section - unique_process_ids = set(process_id) - invalid_ids = unique_process_ids - self.xs_process_ids + unique_process_ids = set(np.unique(process_id)) + invalid_ids = unique_process_ids - self.known_process_ids if invalid_ids: + invalid_names = [ + f"{self.config_inst.get_process(proc_id).name} ({proc_id})" + for proc_id in invalid_ids + ] raise Exception( - f"process_id field contains id(s) {invalid_ids} for which no cross sections were found; process ids with " - f"cross sections: {self.xs_process_ids}", + f"process_id field contains entries {', '.join(invalid_names)} for which no cross sections were found; " + f"process ids with cross sections: {self.known_process_ids}", ) # read the weight per process (defined as lumi * xsec / sum_weights) from the lookup table - process_weight = np.squeeze(np.asarray(self.process_weight_table[0, process_id].todense())) + process_weight = np.squeeze(np.asarray(self.process_weight_table[process_id].todense()), axis=-1) # compute the weight and store it norm_weight = events.mc_weight * process_weight events = set_ak_column(events, self.weight_name, norm_weight, value_type=np.float32) - # if we are stitching, we also compute the inclusive weight for debugging purposes - if ( - self.allow_stitching and - self.get_xsecs_from_inclusive_dataset and - self.dataset_inst == self.inclusive_dataset - ): + # when stitching, also compute the inclusive-only weight + if self.allow_stitching and self.dataset_inst == self.inclusive_dataset: incl_norm_weight = events.mc_weight * self.inclusive_weight events = set_ak_column(events, self.weight_name_incl, incl_norm_weight, value_type=np.float32) return events +@normalization_weights.init +def normalization_weights_init(self: Producer, **kwargs) -> None: + """ + Initializes the normalization weights producer by setting up the normalization weight column. + """ + # declare the weight name to be a produced column + self.produces.add(self.weight_name) + + # when stitching is enabled, store specific information + if self.allow_stitching: + # remember the inclusive dataset and all datasets needed to determine the weights of processes in _this_ dataset + self.inclusive_dataset, self.required_datasets = self.get_stitching_datasets() + + # potentially also store the weight needed for only using the inclusive dataset + if self.dataset_inst == self.inclusive_dataset: + self.weight_name_incl = f"{self.weight_name}_inclusive" + self.produces.add(self.weight_name_incl) + else: + self.inclusive_dataset = self.dataset_inst + self.required_datasets = [self.dataset_inst] + + @normalization_weights.requires def normalization_weights_requires( self: Producer, @@ -245,7 +377,7 @@ def normalization_weights_requires( Adds the requirements needed by the underlying py:attr:`task` to access selection stats into *reqs*. """ # check that all datasets are known - for dataset in self.stitching_datasets: + for dataset in self.required_datasets: if not self.config_inst.has_dataset(dataset): raise Exception(f"unknown dataset '{dataset}' required for normalization weights computation") @@ -256,7 +388,7 @@ def normalization_weights_requires( dataset=dataset.name, branch=-1 if task.is_workflow() else 0, ) - for dataset in self.stitching_datasets + for dataset in self.required_datasets } return reqs @@ -274,127 +406,167 @@ def normalization_weights_setup( """ Sets up objects required by the computation of normalization weights and stores them as instance attributes: + - py: attr: `process_weight_table`: A sparse array serving as a lookup table for the calculated process weights. + This weight is defined as the product of the luminosity, the cross section, divided by the sum of event + weights per process. + - py: attr: `known_process_ids`: A set of all process ids that are known by the lookup table. - py: attr: `process_weight_table`: A sparse array serving as a lookup table for the calculated process weights. This weight is defined as the product of the luminosity, the cross section, divided by the sum of event weights per process. """ + import scipy.sparse + # load the selection stats - selection_stats = { - dataset: task.cached_value( + dataset_selection_stats = { + dataset: copy.deepcopy(task.cached_value( key=f"selection_stats_{dataset}", func=lambda: inp["stats"].load(formatter="json"), - ) + )) for dataset, inp in inputs["selection_stats"].items() } - # if necessary, merge the selection stats across datasets - if len(selection_stats) > 1: - from columnflow.tasks.selection import MergeSelectionStats - merged_selection_stats = defaultdict(float) - for stats in selection_stats.values(): - MergeSelectionStats.merge_counts(merged_selection_stats, stats) - else: - merged_selection_stats = selection_stats[self.dataset_inst.name] - # determine all proceses at any depth in the stitching datasets - process_insts = { - process_inst - for dataset_inst in self.stitching_datasets - for process_inst, _, _ in dataset_inst.walk_processes() + # optionally normalize weights per dataset to their mean, to potentially align different numeric domains + norm_factor = 1.0 + if self.normalize_weights_per_dataset: + for dataset, stats in dataset_selection_stats.items(): + dataset_mean_weight = ( + sum(stats["sum_mc_weight_per_process"].values()) / + sum(stats["num_events_per_process"].values()) + ) + for process_id_str in stats["sum_mc_weight_per_process"]: + stats["sum_mc_weight_per_process"][process_id_str] /= dataset_mean_weight + if dataset == self.dataset_inst.name: + norm_factor = 1.0 / dataset_mean_weight + + # drop unused stats + dataset_selection_stats = { + dataset: {field: stats[field] for field in ["num_events_per_process", "sum_mc_weight_per_process"]} + for dataset, stats in dataset_selection_stats.items() } - # determine ids of processes that were identified in the selection stats - allowed_ids = set(map(int, merged_selection_stats["sum_mc_weight_per_process"])) - - # complain if there are processes seen/id'ed during selection that are not part of the datasets - unknown_process_ids = allowed_ids - {p.id for p in process_insts} - if unknown_process_ids: + # separately treat stats for extracting BRs and sum of mc weights + def extract_stats(*update_funcs): + # create copy + stats = copy.deepcopy(dataset_selection_stats) + # update through one of the functions + for update_func in update_funcs: + if callable(update_func): + stats = update_func(stats) + break + # merge + if len(stats) > 1: + from columnflow.tasks.selection import MergeSelectionStats + merged_stats = collections.defaultdict(float) + for _stats in stats.values(): + MergeSelectionStats.merge_counts(merged_stats, _stats) + else: + merged_stats = stats[self.dataset_inst.name] + return stats, merged_stats + + dataset_selection_stats_br, merged_selection_stats_br = extract_stats( + self.update_dataset_selection_stats_br, + self.update_dataset_selection_stats, + ) + _, merged_selection_stats_sum_weights = extract_stats( + self.update_dataset_selection_stats_sum_weights, + self.update_dataset_selection_stats, + ) + + # get all process ids and instances seen and assigned during selection of this dataset + # (i.e., all possible processes that might be encountered during event processing) + process_ids = set(map(int, dataset_selection_stats_br[self.dataset_inst.name]["sum_mc_weight_per_process"])) + process_insts = set(map(self.config_inst.get_process, process_ids)) + + # consistency check: when the main process of the current dataset is part of these "lowest level" processes, + # there should only be this single process, otherwise the manual (sub) process assignment does not match the + # general dataset -> main process info + if self.dataset_inst.processes.get_first() in process_insts and len(process_insts) > 1: raise Exception( - f"selection stats contain ids of processes that were not previously registered to the config " - f"'{self.config_inst.name}': {', '.join(map(str, unknown_process_ids))}", + f"dataset '{self.dataset_inst.name}' has main process '{self.dataset_inst.processes.get_first().name}' " + "assigned to it (likely as per cmsdb), but the dataset selection stats for this dataset contain multiple " + "sub processes, which is likely a misconfiguration of the manual sub process assignment upstream; found " + f"sub processes: {', '.join(f'{process_inst.name} ({process_inst.id})' for process_inst in process_insts)}", ) - # likewise, drop processes that were not seen during selection - process_insts = {p for p in process_insts if p.id in allowed_ids} - max_id = max(process_inst.id for process_inst in process_insts) + # setup the event weight lookup table + process_weight_table = scipy.sparse.lil_matrix((max(process_ids) + 1, 1), dtype=np.float32) - # get the luminosity - lumi = self.config_inst.x.luminosity if self.luminosity is None else self.luminosity - lumi = lumi.nominal if isinstance(lumi, sn.Number) else float(lumi) - - # create a event weight lookup table - process_weight_table = sp.sparse.lil_matrix((1, max_id + 1), dtype=np.float32) - if self.allow_stitching and self.get_xsecs_from_inclusive_dataset: - inclusive_dataset = self.inclusive_dataset - logger.debug(f"using inclusive dataset {inclusive_dataset.name} for cross section lookup") - - # extract branching ratios from inclusive dataset(s) - inclusive_proc = inclusive_dataset.processes.get_first() - if self.dataset_inst == inclusive_dataset and process_insts == {inclusive_proc}: - branching_ratios = {inclusive_proc.id: 1.0} - else: - branching_ratios = self.get_br_from_inclusive_dataset( - inclusive_dataset=inclusive_dataset, - stats=selection_stats, + def fill_weight_table(process_inst: od.Process, xsec: float, sum_weights: float) -> None: + if sum_weights == 0: + logger.warning( + f"zero sum of weights found for computing normalization weight for process '{process_inst.name}' " + f"({process_inst.id}) in dataset '{self.dataset_inst.name}', going to use weight of 0.0", ) - if not branching_ratios: - raise Exception( - f"no branching ratios could be computed based on the inclusive dataset {inclusive_dataset}", - ) + weight = 0.0 + else: + weight = norm_factor * xsec * lumi / sum_weights + process_weight_table[process_inst.id, 0] = weight - # compute the weight the inclusive dataset would have on its own without stitching - inclusive_xsec = inclusive_proc.get_xsec(self.config_inst.campaign.ecm).nominal - self.inclusive_weight = ( - lumi * inclusive_xsec / selection_stats[inclusive_dataset.name]["sum_mc_weight"] - if self.dataset_inst == inclusive_dataset - else 0 + # get the luminosity + lumi = float(self.config_inst.x.luminosity if self.luminosity is None else self.luminosity) + + # prepare info for the inclusive dataset + inclusive_proc = self.inclusive_dataset.processes.get_first() + inclusive_xsec = inclusive_proc.get_xsec(self.config_inst.campaign.ecm).nominal + + # compute the weight the inclusive dataset would have on its own without stitching + if self.allow_stitching and self.dataset_inst == self.inclusive_dataset: + inclusive_sum_weights = sum( + dataset_selection_stats[self.inclusive_dataset.name]["sum_mc_weight_per_process"].values(), + ) + self.inclusive_weight = norm_factor * inclusive_xsec * lumi / inclusive_sum_weights + + # fill weights into the lut, depending on whether stitching is allowed / needed or not + do_stitch = ( + self.allow_stitching and + self.get_xsecs_from_inclusive_datasets and + (len(process_insts) > 1 or len(self.required_datasets) > 1) + ) + if do_stitch: + logger.debug( + f"using inclusive dataset '{self.inclusive_dataset.name}' and process '{inclusive_proc.name}' for cross " + "section lookup", + ) + + # optionally run the dataset lookup again in debug mode when stitching + is_first_branch = getattr(task, "branch", None) == 0 + if is_first_branch: + self.get_stitching_datasets(debug=True) + + # extract branching ratios + branching_ratios = self.get_br_from_inclusive_datasets( + process_insts, + dataset_selection_stats_br, + merged_selection_stats_br, + debug=is_first_branch, ) # fill the process weight table - for proc_id, br in branching_ratios.items(): - sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(proc_id)] - process_weight_table[0, proc_id] = lumi * inclusive_xsec * br / sum_weights + for process_inst, br in branching_ratios.items(): + sum_weights = merged_selection_stats_sum_weights["sum_mc_weight_per_process"][str(process_inst.id)] + fill_weight_table(process_inst, br * inclusive_xsec, sum_weights) else: # fill the process weight table with per-process cross sections for process_inst in process_insts: - if self.config_inst.campaign.ecm not in process_inst.xsecs.keys(): + if self.config_inst.campaign.ecm not in process_inst.xsecs: raise KeyError( f"no cross section registered for process {process_inst} for center-of-mass energy of " f"{self.config_inst.campaign.ecm}", ) - sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(process_inst.id)] xsec = process_inst.get_xsec(self.config_inst.campaign.ecm).nominal - process_weight_table[0, process_inst.id] = lumi * xsec / sum_weights + sum_weights = merged_selection_stats_sum_weights["sum_mc_weight_per_process"][str(process_inst.id)] + fill_weight_table(process_inst, xsec, sum_weights) + # store lookup table and known process ids self.process_weight_table = process_weight_table - self.xs_process_ids = set(self.process_weight_table.rows[0]) - - -@normalization_weights.init -def normalization_weights_init(self: Producer, **kwargs) -> None: - """ - Initializes the normalization weights producer by setting up the normalization weight column. - """ - self.produces.add(self.weight_name) - if self.allow_stitching: - self.stitching_datasets = self.get_stitching_datasets() - self.inclusive_dataset = self.get_inclusive_dataset() - else: - self.stitching_datasets = [self.dataset_inst] - - if ( - self.allow_stitching and - self.get_xsecs_from_inclusive_dataset and - self.dataset_inst == self.inclusive_dataset - ): - self.weight_name_incl = f"{self.weight_name}_inclusive" - self.produces.add(self.weight_name_incl) + self.known_process_ids = process_ids stitched_normalization_weights = normalization_weights.derive( "stitched_normalization_weights", cls_dict={ "weight_name": "normalization_weight", - "get_xsecs_from_inclusive_dataset": True, + "get_xsecs_from_inclusive_datasets": True, "allow_stitching": True, }, ) @@ -402,6 +574,6 @@ def normalization_weights_init(self: Producer, **kwargs) -> None: stitched_normalization_weights_brs_from_processes = stitched_normalization_weights.derive( "stitched_normalization_weights_brs_from_processes", cls_dict={ - "get_xsecs_from_inclusive_dataset": False, + "get_xsecs_from_inclusive_datasets": False, }, ) diff --git a/columnflow/production/util.py b/columnflow/production/util.py index 5c3df8fa0..938876282 100644 --- a/columnflow/production/util.py +++ b/columnflow/production/util.py @@ -3,9 +3,10 @@ """ General producers that might be utilized in various places. """ + from __future__ import annotations -from functools import partial +import functools from columnflow.types import Iterable, Sequence, Union from columnflow.production import Producer, producer @@ -13,7 +14,6 @@ from columnflow.columnar_util import attach_coffea_behavior as attach_coffea_behavior_fn ak = maybe_import("awkward") -coffea = maybe_import("coffea") @producer(call_force=True) @@ -47,11 +47,14 @@ def attach_coffea_behavior( # general awkward array functions # -def ak_extract_fields(arr: ak.Array, fields: list[str], **kwargs): +def ak_extract_fields(arr: ak.Array, fields: list[str], optional_fields: list[str] | None = None, **kwargs): """ Build an array containing only certain `fields` of an input array `arr`, preserving behaviors. """ + if optional_fields is None: + optional_fields = [] + # reattach behavior if "behavior" not in kwargs: kwargs["behavior"] = arr.behavior @@ -60,6 +63,10 @@ def ak_extract_fields(arr: ak.Array, fields: list[str], **kwargs): { field: getattr(arr, field) for field in fields + } | { + field: getattr(arr, field) + for field in optional_fields + if field in arr.fields }, **kwargs, ) @@ -69,15 +76,21 @@ def ak_extract_fields(arr: ak.Array, fields: list[str], **kwargs): # functions for operating on lorentz vectors # -_lv_base = partial(ak_extract_fields, behavior=coffea.nanoevents.methods.nanoaod.behavior) +def _lv_base(*args, **kwargs): + # scoped partial to defer coffea import + import coffea.nanoevents + import coffea.nanoevents.methods.nanoaod + kwargs["behavior"] = coffea.nanoevents.methods.nanoaod.behavior + return ak_extract_fields(*args, **kwargs) + -lv_xyzt = partial(_lv_base, fields=["x", "y", "z", "t"], with_name="LorentzVector") +lv_xyzt = functools.partial(_lv_base, fields=["x", "y", "z", "t"], with_name="LorentzVector") lv_xyzt.__doc__ = """Construct a `LorentzVectorArray` from an input array.""" -lv_mass = partial(_lv_base, fields=["pt", "eta", "phi", "mass"], with_name="PtEtaPhiMLorentzVector") +lv_mass = functools.partial(_lv_base, fields=["pt", "eta", "phi", "mass"], with_name="PtEtaPhiMLorentzVector") lv_mass.__doc__ = """Construct a `PtEtaPhiMLorentzVectorArray` from an input array.""" -lv_energy = partial(_lv_base, fields=["pt", "eta", "phi", "energy"], with_name="PtEtaPhiELorentzVector") +lv_energy = functools.partial(_lv_base, fields=["pt", "eta", "phi", "energy"], with_name="PtEtaPhiELorentzVector") lv_energy.__doc__ = """Construct a `PtEtaPhiELorentzVectorArray` from an input array.""" diff --git a/columnflow/reduction/__init__.py b/columnflow/reduction/__init__.py index c35975c2e..e58c3ba61 100644 --- a/columnflow/reduction/__init__.py +++ b/columnflow/reduction/__init__.py @@ -8,18 +8,22 @@ import inspect -from columnflow.types import Callable +from columnflow.calibration import TaskArrayFunctionWithCalibratorRequirements from columnflow.util import DerivableMeta -from columnflow.columnar_util import TaskArrayFunction +from columnflow.types import Callable, Sequence -class Reducer(TaskArrayFunction): +class Reducer(TaskArrayFunctionWithCalibratorRequirements): """ Base class for all reducers. """ exposed = True + # register attributes for arguments accepted by decorator + mc_only: bool = False + data_only: bool = False + @classmethod def reducer( cls, @@ -27,6 +31,7 @@ def reducer( bases: tuple = (), mc_only: bool = False, data_only: bool = False, + require_calibrators: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ @@ -45,6 +50,7 @@ def reducer( for real data. :param data_only: Boolean flag indicating that this reducer should only run on real data and skipped for Monte Carlo simulation. + :param require_calibrators: Sequence of names of calibrators to add to the requirements. :return: New reducer subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -54,6 +60,7 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, + "require_calibrators": require_calibrators, } # get the module name diff --git a/columnflow/reduction/util.py b/columnflow/reduction/util.py index 4133f7a31..6dd5c7637 100644 --- a/columnflow/reduction/util.py +++ b/columnflow/reduction/util.py @@ -89,16 +89,21 @@ def create_collections_from_masks( # add collections for dst_name in dst_names: - object_mask = object_masks[src_name, dst_name] - dst_collection = events[src_name][object_mask] + object_mask = ak.drop_none(object_masks[src_name, dst_name]) + try: + dst_collection = events[src_name][object_mask] + except ValueError as e: + # check f the object mask refers to an option type + mask_type = getattr(getattr(ak.type(object_mask), "content", None), "cotent", None) + if isinstance(mask_type, ak.types.OptionType): + msg = ( + f"object mask to create dst collection '{dst_name}' from src collection '{src_name}' uses an " + f"option type '{object_mask.typestr}' which is not supported; please adjust your mask to not " + "contain missing values (most likely by using ak.drop_none() in your event selection)" + ) + raise ValueError(msg) from e + # no further custom handling, re-raise + raise e events = set_ak_column(events, dst_name, dst_collection) return events - - -def masked_sorted_indices(mask: ak.Array, sort_var: ak.Array, ascending: bool = False) -> ak.Array: - """ - Helper function to obtain the correct indices of an object mask - """ - indices = ak.argsort(sort_var, axis=-1, ascending=ascending) - return indices[mask[indices]] diff --git a/columnflow/selection/__init__.py b/columnflow/selection/__init__.py index 5f0af3ce7..1f4368fc0 100644 --- a/columnflow/selection/__init__.py +++ b/columnflow/selection/__init__.py @@ -12,9 +12,9 @@ import law import order as od -from columnflow.types import Callable, T +from columnflow.calibration import TaskArrayFunctionWithCalibratorRequirements from columnflow.util import maybe_import, DotDict, DerivableMeta -from columnflow.columnar_util import TaskArrayFunction +from columnflow.types import Callable, T, Sequence ak = maybe_import("awkward") np = maybe_import("numpy") @@ -22,13 +22,17 @@ logger = law.logger.get_logger(__name__) -class Selector(TaskArrayFunction): +class Selector(TaskArrayFunctionWithCalibratorRequirements): """ Base class for all selectors. """ exposed = False + # register attributes for arguments accepted by decorator + mc_only: bool = False + data_only: bool = False + def __init__(self: Selector, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -44,25 +48,26 @@ def selector( bases=(), mc_only: bool = False, data_only: bool = False, + require_calibrators: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ - Decorator for creating a new :py:class:`~.Selector` subclass with additional, optional - *bases* and attaching the decorated function to it as ``call_func``. + Decorator for creating a new :py:class:`~.Selector` subclass with additional, optional *bases* and attaching the + decorated function to it as ``call_func``. - When *mc_only* (*data_only*) is *True*, the selector is skipped and not considered by - other calibrators, selectors and producers in case they are evaluated on a - :py:class:`order.Dataset` (using the :py:attr:`dataset_inst` attribute) whose ``is_mc`` - (``is_data``) attribute is *False*. + When *mc_only* (*data_only*) is *True*, the selector is skipped and not considered by other calibrators, + selectors and producers in case they are evaluated on a :py:class:`order.Dataset` (using the + :py:attr:`dataset_inst` attribute) whose ``is_mc`` (``is_data``) attribute is *False*. All additional *kwargs* are added as class members of the new subclasses. :param func: Function to be wrapped and integrated into new :py:class:`Selector` class. :param bases: Additional bases for the new :py:class:`Selector`. - :param mc_only: Boolean flag indicating that this :py:class:`Selector` should only run on - Monte Carlo simulation and skipped for real data. - :param data_only: Boolean flag indicating that this :py:class:`Selector` should only run on - real data and skipped for Monte Carlo simulation. + :param mc_only: Boolean flag indicating that this :py:class:`Selector` should only run on Monte Carlo simulation + and skipped for real data. + :param data_only: Boolean flag indicating that this :py:class:`Selector` should only run on real data and + skipped for Monte Carlo simulation. + :param require_calibrators: Sequence of names of calibrators to add to the requirements. :return: New :py:class:`Selector` subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -72,6 +77,7 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, + "require_calibrators": require_calibrators, } # get the module name diff --git a/columnflow/selection/cms/json_filter.py b/columnflow/selection/cms/json_filter.py index 2b750a563..6eddb84d1 100644 --- a/columnflow/selection/cms/json_filter.py +++ b/columnflow/selection/cms/json_filter.py @@ -14,8 +14,6 @@ ak = maybe_import("awkward") np = maybe_import("numpy") -sp = maybe_import("scipy") -maybe_import("scipy.sparse") def get_lumi_file_default(self, external_files: DotDict) -> str: @@ -124,6 +122,8 @@ def json_filter_setup( :param inputs: Additional inputs, currently not used :param reader_targets: Additional targets, currently not used """ + import scipy.sparse + bundle = reqs["external_files"] # import the correction sets from the external file @@ -134,7 +134,7 @@ def json_filter_setup( max_run = max(map(int, json.keys())) # build lookup table - self.run_ls_lookup = sp.sparse.lil_matrix((max_run + 1, max_ls + 1), dtype=bool) + self.run_ls_lookup = scipy.sparse.lil_matrix((max_run + 1, max_ls + 1), dtype=bool) for run, ls_ranges in json.items(): run = int(run) for ls_range in ls_ranges: diff --git a/columnflow/selection/cmsGhent/lepton_mva_cuts.py b/columnflow/selection/cmsGhent/lepton_mva_cuts.py index a1e238530..df3772cec 100644 --- a/columnflow/selection/cmsGhent/lepton_mva_cuts.py +++ b/columnflow/selection/cmsGhent/lepton_mva_cuts.py @@ -13,7 +13,7 @@ from columnflow.columnar_util import set_ak_column, optional_column # from columnflow.production.util import attach_coffea_behavior from columnflow.selection import Selector, SelectionResult, selector -from columnflow.reduction.util import masked_sorted_indices +from columnflow.columnar_util import sorted_indices_from_mask ak = maybe_import("awkward") @@ -83,7 +83,7 @@ def lepton_mva_object( steps={}, objects={ lep: - {lep: masked_sorted_indices(events[lep][working_point[lep]], events[lep].pt)} + {lep: sorted_indices_from_mask(events[lep][working_point[lep]], events[lep].pt)} for lep in ["Muon", "Electron"] }, ) diff --git a/columnflow/selection/empty.py b/columnflow/selection/empty.py index 0be227402..563846e47 100644 --- a/columnflow/selection/empty.py +++ b/columnflow/selection/empty.py @@ -61,7 +61,7 @@ def empty( events = set_ak_column(events, "category_ids", category_ids) # empty selection result with a trivial event mask - results = SelectionResult(event=ak.Array(np.ones(len(events), dtype=np.bool_))) + results = SelectionResult(event=ak.Array(np.ones(len(events), dtype=bool))) # increment stats weight_map = { diff --git a/columnflow/tasks/calibration.py b/columnflow/tasks/calibration.py index 5842906e8..8b58be61f 100644 --- a/columnflow/tasks/calibration.py +++ b/columnflow/tasks/calibration.py @@ -18,7 +18,6 @@ class _CalibrateEvents( - # ParamsCacheMixin, CalibratorMixin, ChunkedIOMixin, law.LocalWorkflow, @@ -28,8 +27,6 @@ class _CalibrateEvents( Base classes for :py:class:`CalibrateEvents`. """ - # cache_param_sep = ["calibrator"] - class CalibrateEvents(_CalibrateEvents): """ @@ -63,9 +60,7 @@ def workflow_requires(self) -> dict: reqs["lfns"] = self.reqs.GetDatasetLFNs.req(self) # add calibrator dependent requirements - reqs["calibrator"] = law.util.make_unique(law.util.flatten( - self.calibrator_inst.run_requires(task=self), - )) + reqs["calibrator"] = law.util.make_unique(law.util.flatten(self.calibrator_inst.run_requires(task=self))) return reqs @@ -177,7 +172,12 @@ def run(self): # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( - self, sorted_chunks, output["columns"], local=True, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=sorted_chunks, + output=output["columns"], + local=True, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) diff --git a/columnflow/tasks/cms/external.py b/columnflow/tasks/cms/external.py index 03eb98220..89d527244 100644 --- a/columnflow/tasks/cms/external.py +++ b/columnflow/tasks/cms/external.py @@ -6,6 +6,11 @@ from __future__ import annotations +__all__ = [] + +import os +import glob + import luigi import law @@ -20,6 +25,8 @@ class CreatePileupWeights(ConfigTask): + task_namespace = "cf.cms" + single_config = True data_mode = luigi.ChoiceParameter( @@ -60,7 +67,7 @@ def run(self): # since this tasks uses stage-in into and stage-out from the sandbox, # prepare external files with the staged-in inputs - externals.get_files(self.input()) + externals.get_files_collection(self.input()) # read the mc profile mc_profile = self.read_mc_profile_from_cfg(externals.files.pu.mc_profile) @@ -162,3 +169,73 @@ def normalize_values(cls, values: Sequence[float]) -> list[float]: enable=["configs", "skip_configs"], attributes={"version": None}, ) + + +class CheckCATUpdates(ConfigTask, law.tasks.RunOnceTask): + """ + CMS specific task that checks for updates in the metadata managed and stored by the CAT group. See + https://cms-analysis-corrections.docs.cern.ch for more info. + + To function correctly, this task requires an auxiliary entry ``cat_info`` in the analysis config, pointing to a + :py:class:`columnflow.cms_util.CATInfo` instance that defines the era information and the current POG correction + timestamps. The task will then check in the CAT metadata structure if newer timestamps are available. + """ + + task_namespace = "cf.cms" + + version = None + + single_config = False + + def run(self): + # helpers to convert date strings to tuples for numeric comparisons + decode_date_str = lambda s: tuple(map(int, s.split("-"))) + + # loop through configs + for config_inst in self.config_insts: + with self.publish_step( + f"checking CAT metadata updates for config '{law.util.colored(config_inst.name, style='bright')}' in " + f"{config_inst.x.cat_info.metadata_root}", + ): + newest_dates = {} + updated_any = False + for pog, date_str in config_inst.x.cat_info.snapshot.items(): + if not date_str: + continue + + # get all versions in the cat directory, split by date numbers + pog_era_dir = os.path.join( + config_inst.x.cat_info.metadata_root, + pog.upper(), + config_inst.x.cat_info.get_era_directory(pog), + ) + if not os.path.isdir(pog_era_dir): + self.logger.warning(f"CAT metadata directory '{pog_era_dir}' does not exist, skipping") + continue + dates = [ + os.path.basename(path) + for path in glob.glob(os.path.join(pog_era_dir, "*-*-*")) + ] + if not dates: + raise ValueError(f"no CAT snapshots found in '{pog_era_dir}'") + + # compare with current date + latest_date_str = max(dates, key=decode_date_str) + if date_str == "latest" or decode_date_str(date_str) < decode_date_str(latest_date_str): + newest_dates[pog] = latest_date_str + updated_any = True + self.publish_message( + f"found newer {law.util.colored(pog.upper(), color='cyan')} snapshot: {date_str} -> " + f"{latest_date_str} ({os.path.join(pog_era_dir, latest_date_str)})", + ) + else: + newest_dates[pog] = date_str + + # print a new CATSnapshot line that can be copy-pasted into the config + if updated_any: + args_str = ", ".join(f"{pog}=\"{date_str}\"" for pog, date_str in newest_dates.items() if date_str) + self.publish_message( + f"{law.util.colored('new CATSnapshot line ->', style='bright')} CATSnapshot({args_str})\n", + ) + else: + self.publish_message("no updates found\n") diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index f0bfae242..abf8ec2ec 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -6,128 +6,216 @@ from __future__ import annotations +import collections + import law import order as od from columnflow.tasks.framework.base import AnalysisTask, wrapper_factory from columnflow.tasks.framework.inference import SerializeInferenceModelBase from columnflow.tasks.histograms import MergeHistograms +from columnflow.inference.cms.datacard import DatacardWriter +from columnflow.types import TYPE_CHECKING + +if TYPE_CHECKING: + from columnflow.inference.cms.datacard import DatacardHists, ShiftHists class CreateDatacards(SerializeInferenceModelBase): resolution_task_cls = MergeHistograms + datacard_writer_cls = DatacardWriter def output(self): - hooks_repr = self.hist_hooks_repr - cat_obj = self.branch_data - - def basename(name: str, ext: str) -> str: + def basename(cat_obj, name, ext): parts = [name, cat_obj.name] - if hooks_repr: + if (hooks_repr := self.hist_hooks_repr): parts.append(f"hooks_{hooks_repr}") if cat_obj.postfix is not None: parts.append(cat_obj.postfix) return f"{'__'.join(map(str, parts))}.{ext}" - return { - "card": self.target(basename("datacard", "txt")), - "shapes": self.target(basename("shapes", "root")), - } + return law.SiblingFileCollection({ + cat_obj.name: { + "card": self.target(basename(cat_obj, "datacard", "txt")), + "shapes": self.target(basename(cat_obj, "shapes", "root")), + } + for cat_obj in self.inference_model_inst.categories + }) @law.decorator.log @law.decorator.safe_output def run(self): import hist - from columnflow.inference.cms.datacard import DatacardHists, ShiftHists, DatacardWriter - # prepare inputs + # prepare inputs and outputs inputs = self.input() + outputs = self.output() - # loop over all configs required by the datacard category and gather histograms - cat_obj = self.branch_data - datacard_hists: DatacardHists = {cat_obj.name: {}} - - # step 1: gather histograms per process for each config - input_hists: dict[od.Config, dict[od.Process, hist.Hist]] = {} - for config_inst in self.config_insts: - # skip configs that are not required - if not cat_obj.config_data.get(config_inst.name): - continue - # load them - input_hists[config_inst] = self.load_process_hists(inputs, cat_obj, config_inst) - - # step 2: apply hist hooks - input_hists = self.invoke_hist_hooks(input_hists) - - # step 3: transform to nested histogram as expected by the datacard writer - for config_inst in input_hists.keys(): - config_data = cat_obj.config_data.get(config_inst.name) - - # determine leaf categories to gather - category_inst = config_inst.get_category(config_data.category) - leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] - - # start the transformation - proc_objs = list(cat_obj.processes) - if config_data.data_datasets and not cat_obj.data_from_processes: - proc_objs.append(self.inference_model_inst.process_spec(name="data")) - for proc_obj in proc_objs: - # get the corresponding process instance - if proc_obj.name == "data": - process_inst = config_inst.get_process("data") - elif config_inst.name in proc_obj.config_data: - process_inst = config_inst.get_process(proc_obj.config_data[config_inst.name].process) - else: - # skip process objects that rely on data from a different config - continue - - # extract the histogram for the process - if not (h_proc := input_hists[config_inst].get(process_inst, None)): - self.logger.warning( - f"found no histogram to model datacard process '{proc_obj.name}', please check your " - f"inference model '{self.inference_model}'", - ) + # overall strategy to load data efficiently and to write datacards: + # 1) determine which variables have to be loaded for which config (stored in a map), then loop over variables + # 2) load all histograms per config + # 3) start datacard writing by looping over datacard categories that use the specific variable + # 4) apply hist hooks + # 5) prepare histogram in the nested format expected by the datacard writer and write the card + + # step 1: gather variable info, then loop + variable_data = collections.defaultdict(set) + for config_inst, data in self.combined_config_data.items(): + for variable in data["variables"]: + variable_data[variable].add(config_inst) + + for variable, variable_config_insts in variable_data.items(): + # step 2 + input_hists: dict[od.Config, dict[od.Process, hist.Hist]] = {} + for config_inst in variable_config_insts: + data = self.combined_config_data[config_inst] + input_hists[config_inst] = self.load_process_hists( + config_inst, + { + dataset_name: list(data["mc_datasets"][dataset_name]["proc_names"]) + for dataset_name in data["mc_datasets"] + } | { + dataset_name: ["data"] + for dataset_name in data["data_datasets"] + }, + variable, + inputs[config_inst.name], + ) + + # step 3 + for cat_obj in self.inference_model_inst.categories: + # skip if the variable is not used in this category + if not any(d.variable == variable for d in cat_obj.config_data.values()): continue - - # select relevant categories - h_proc = h_proc[{ - "category": [ - hist.loc(c.name) - for c in leaf_category_insts - if c.name in h_proc.axes["category"] - ], - }][{"category": sum}] - - # create the nominal hist - datacard_hists[cat_obj.name].setdefault(proc_obj.name, {}).setdefault(config_inst.name, {}) - shift_hists: ShiftHists = datacard_hists[cat_obj.name][proc_obj.name][config_inst.name] - shift_hists["nominal"] = h_proc[{ - - "shift": hist.loc(config_inst.get_shift("nominal").name), - }] - - # no additional shifts need to be created for data - if proc_obj.name == "data": + # cross check that all configs use the same variable (should already be guarded by the model validation) + assert all(d.variable == variable for d in cat_obj.config_data.values()) + + # check which configs contribute to this category + config_insts = [ + config_inst for config_inst in self.config_insts + if config_inst.name in cat_obj.config_data + ] + if not config_insts: continue + self.publish_message(f"processing inputs for category '{cat_obj.name}' with variable '{variable}'") + + # get config-based category name + category = cat_obj.config_data[config_insts[0].name].category + + # step 4: hist hooks + _input_hists = self.invoke_hist_hooks( + {config_inst: input_hists[config_inst].copy() for config_inst in config_insts}, + hook_kwargs={"variable_name": variable, "category_name": category}, + ) + + # step 5: transform to datacard format + datacard_hists: DatacardHists = {cat_obj.name: {}} + for config_inst in _input_hists.keys(): + config_data = cat_obj.config_data.get(config_inst.name) + + # determine leaf categories to gather + category_inst = config_inst.get_category(category) + leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] + + # eagerly remove data histograms in case data is supposed to be faked from mc processes + if cat_obj.data_from_processes: + for process_inst in list(_input_hists[config_inst]): + if process_inst.is_data: + del _input_hists[config_inst][process_inst] + + # start the transformation + proc_objs = list(cat_obj.processes) + if config_data.data_datasets and not cat_obj.data_from_processes: + proc_objs.append(self.inference_model_inst.process_spec(name="data")) + for proc_obj in proc_objs: + # skip the process objects if it does not contribute to this config_inst + if config_inst.name not in proc_obj.config_data and proc_obj.name != "data": + continue + + # get all process instances (keys in _input_hists) to be combined + if proc_obj.is_dynamic: + if not (process_name := proc_obj.config_data[config_inst.name].get("process", None)): + raise ValueError( + f"dynamic datacard process object misses 'process' entry in config data for " + f"'{config_inst.name}': {proc_obj}", + ) + process_inst = config_inst.get_process(process_name) + else: + process_inst = config_inst.get_process( + proc_obj.name + if proc_obj.name == "data" + else proc_obj.config_data[config_inst.name].process, + ) + + # extract the histogram for the process + # (removed from hists to eagerly cleanup memory) + h_proc = _input_hists[config_inst].get(process_inst, None) + if h_proc is None: + self.logger.error( + f"found no histogram to model datacard process '{proc_obj.name}', please check your " + f"inference model '{self.inference_model}'", + ) + continue + + # select relevant categories + h_proc = h_proc[{ + "category": [ + hist.loc(c.name) + for c in leaf_category_insts + if c.name in h_proc.axes["category"] + ], + }] + h_proc = h_proc[{"category": sum}] - # create histograms per shift - for param_obj in proc_obj.parameters: - # skip the parameter when varied hists are not needed - if not self.inference_model_inst.require_shapes_for_parameter(param_obj): - continue - # store the varied hists - shift_source = param_obj.config_data[config_inst.name].shift_source - for d in ["up", "down"]: - shift_hists[(param_obj.name, d)] = h_proc[{ - "shift": hist.loc(config_inst.get_shift(f"{shift_source}_{d}").name), + # create the nominal hist + datacard_hists[cat_obj.name].setdefault(proc_obj.name, {}).setdefault(config_inst.name, {}) + shift_hists: ShiftHists = datacard_hists[cat_obj.name][proc_obj.name][config_inst.name] + shift_hists["nominal"] = h_proc[{ + "shift": hist.loc(config_inst.get_shift("nominal").name), }] - # forward objects to the datacard writer - outputs = self.output() - writer = DatacardWriter(self.inference_model_inst, datacard_hists) - with outputs["card"].localize("w") as tmp_card, outputs["shapes"].localize("w") as tmp_shapes: - writer.write(tmp_card.abspath, tmp_shapes.abspath, shapes_path_ref=outputs["shapes"].basename) + # no additional shifts need to be created for data + if proc_obj.name == "data": + continue + + # create histograms per shape shift + for param_obj in proc_obj.parameters: + # skip the parameter when varied hists are not needed + need_shapes = ( + (param_obj.type.is_shape and not param_obj.transformations.any_from_rate) or + (param_obj.type.is_rate and param_obj.transformations.any_from_shape) + ) + if not need_shapes: + continue + # store the varied hists + shift_source = ( + param_obj.config_data[config_inst.name].shift_source + if config_inst.name in param_obj.config_data + else None + ) + for d in ["up", "down"]: + if shift_source and f"{shift_source}_{d}" not in h_proc.axes["shift"]: + raise ValueError( + f"cannot find '{shift_source}_{d}' in shift axis of histogram for process " + f"'{proc_obj.name}' in config '{config_inst.name}' while handling parameter " + f"'{param_obj.name}' in datacard category '{cat_obj.name}', available shifts " + f"are: {list(h_proc.axes['shift'])}", + ) + shift_hists[(param_obj.name, d)] = h_proc[{ + "shift": hist.loc(f"{shift_source}_{d}" if shift_source else "nominal"), + }] + + # forward objects to the datacard writer + outp = outputs[cat_obj.name] + writer = self.datacard_writer_cls(self.inference_model_inst, datacard_hists) + with outp["card"].localize("w") as tmp_card, outp["shapes"].localize("w") as tmp_shapes: + writer.write(tmp_card.abspath, tmp_shapes.abspath, shapes_path_ref=outp["shapes"].basename) + self.publish_message(f"datacard written to {outp['card'].abspath}") + + # eager cleanup + del _input_hists + del input_hists CreateDatacardsWrapper = wrapper_factory( diff --git a/columnflow/tasks/cmsGhent/btagefficiency.py b/columnflow/tasks/cmsGhent/btagefficiency.py index 32c07cc8b..333ff05ea 100644 --- a/columnflow/tasks/cmsGhent/btagefficiency.py +++ b/columnflow/tasks/cmsGhent/btagefficiency.py @@ -17,8 +17,11 @@ from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.util import dev_sandbox, dict_add_strict, DotDict, maybe_import +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") + +if TYPE_CHECKING: + hist = maybe_import("hist") class BTagEfficiencyBase: diff --git a/columnflow/tasks/cmsGhent/selection_hists.py b/columnflow/tasks/cmsGhent/selection_hists.py index 13db8a6ed..4176a7fa7 100644 --- a/columnflow/tasks/cmsGhent/selection_hists.py +++ b/columnflow/tasks/cmsGhent/selection_hists.py @@ -14,10 +14,10 @@ PlotBase, PlotBase1D, VariablePlotSettingMixin, ProcessPlotSettingMixin, ) -from columnflow.types import Any +from columnflow.types import TYPE_CHECKING, Any - -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") class CustomDefaultVariablesMixin( @@ -189,6 +189,7 @@ def efficiency(cls, selected_counts: hist.Hist, incl: hist.Hist, **kwargs) -> hi @param incl: histogram with event counts before selection @param kwargs: keyword arguments passed to **proportion_confint** """ + import hist from statsmodels.stats.proportion import proportion_confint efficiency = selected_counts / incl.values() eff_sample_size_corr = incl.values() / incl.variances() diff --git a/columnflow/tasks/cmsGhent/trigger_scale_factors.py b/columnflow/tasks/cmsGhent/trigger_scale_factors.py index 412760b0f..006f4c044 100644 --- a/columnflow/tasks/cmsGhent/trigger_scale_factors.py +++ b/columnflow/tasks/cmsGhent/trigger_scale_factors.py @@ -6,7 +6,6 @@ from itertools import product import luigi -from columnflow.types import Any from columnflow.tasks.framework.base import Requirements, ConfigTask from columnflow.tasks.framework.mixins import ( CalibratorClassesMixin, SelectorClassMixin, DatasetsMixin, @@ -19,10 +18,11 @@ import columnflow.production.cmsGhent.trigger.util as util from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.util import dev_sandbox, dict_add_strict, maybe_import - +from columnflow.types import TYPE_CHECKING, Any np = maybe_import("numpy") -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") logger = law.logger.get_logger(__name__) diff --git a/columnflow/tasks/cutflow.py b/columnflow/tasks/cutflow.py index f134edc67..6a515c07e 100644 --- a/columnflow/tasks/cutflow.py +++ b/columnflow/tasks/cutflow.py @@ -187,6 +187,12 @@ def run(self): for cat in self.config_inst.get_leaf_categories() } + # get IDs and names of all leaf categories + leaf_category_map = { + cat.id: cat.name + for cat in self.config_inst.get_leaf_categories() + } + # create a temp dir for saving intermediate files tmp_dir = law.LocalDirectoryTarget(is_tmp=True) tmp_dir.touch() diff --git a/columnflow/tasks/external.py b/columnflow/tasks/external.py index 8f37ede77..0d901bc54 100644 --- a/columnflow/tasks/external.py +++ b/columnflow/tasks/external.py @@ -17,6 +17,7 @@ import law import order as od +from columnflow import env_is_local from columnflow.tasks.framework.base import AnalysisTask, ConfigTask, DatasetTask, wrapper_factory from columnflow.tasks.framework.parameters import user_parameter_inst from columnflow.tasks.framework.decorators import only_local_env @@ -404,7 +405,7 @@ class ExternalFile: """ location: str - subpaths: dict[str, str] = field(default_factory=str) + subpaths: dict[str, str] = field(default_factory=dict) version: str = "v1" def __str__(self) -> str: @@ -428,6 +429,11 @@ def new(cls, resource: ExternalFile | str | tuple[str] | tuple[str, str]) -> Ext return cls(location=resource[0], version=resource[1]) raise ValueError(f"invalid resource type and format: {resource}") + def __getattr__(self, attr: str) -> str: + if attr in self.subpaths: + return self.subpaths[attr] + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'") + class BundleExternalFiles(ConfigTask, law.tasks.TransferLocalFile): """ @@ -436,8 +442,8 @@ class BundleExternalFiles(ConfigTask, law.tasks.TransferLocalFile): This task is intended to download source files for other tasks, such as files containing corrections for objects, the "golden" json files, source files for the calculation of pileup weights, and others. - All information about the relevant external files is extracted from the given ``config_inst``, which must contain - the keyword ``external_files`` in the auxiliary information. This can look like this: + All information about the relevant external files is extracted from the given ``config_inst``, which must contain an + auxiliary field ``external_files`` like the following (all entries are optional and user-defined): .. code-block:: python @@ -452,7 +458,7 @@ class BundleExternalFiles(ConfigTask, law.tasks.TransferLocalFile): "electron_sf": ExternalFile(f"{SOURCE_URL}/POG/EGM/{year}{corr_postfix}_UL/electron.json.gz", version="v1"), }) - The entries in this DotDict should be :py:class:`ExternalFile` instances. + All entries should be :py:class:`ExternalFile` instances. """ single_config = True @@ -461,6 +467,11 @@ class BundleExternalFiles(ConfigTask, law.tasks.TransferLocalFile): default=5, description="number of replicas to generate; default: 5", ) + recreate = luigi.BoolParameter( + default=False, + significant=False, + description="when True, forces the recreation of the bundle even if it exists; default: False", + ) user = user_parameter_inst version = None @@ -477,18 +488,10 @@ def __init__(self, *args, **kwargs): self._file_names = None # cached dict for lazy access to files in fetched bundle - self.files_dir = None - self._files = None + self._files_collection = None @classmethod def create_unique_basename(cls, path: str | ExternalFile) -> str | dict[str, str]: - """ - Create a unique basename for a given path. When *path* is an :py:class:`ExternalFile` with one or more subpaths - defined, a dictionary mapping subpaths to unique basenames is returned. - - :param path: path or external file object. - :return: Unique basename(s). - """ if isinstance(path, str): return f"{law.util.create_hash(path)}_{os.path.basename(path)}" @@ -503,11 +506,6 @@ def create_unique_basename(cls, path: str | ExternalFile) -> str | dict[str, str @property def files_hash(self) -> str: - """ - Create a hash based on all external files. - - :return: Hash based on the flattened list of external files in the current config instance. - """ if self._files_hash is None: # take the external files and flatten them into a deterministic order, then hash def deterministic_flatten(d): @@ -522,117 +520,156 @@ def deterministic_flatten(d): @property def file_names(self) -> DotDict: - """ - Create a unique basename for each external file. - - :return: DotDict of same shape as ``external_files`` DotDict with unique basenames. - """ if self._file_names is None: self._file_names = law.util.map_struct(self.create_unique_basename, self.ext_files) return self._file_names - def get_files(self, output=None): - if self._files is None: + def get_files_collection(self, output=None) -> law.SiblingFileCollection: + if self._files_collection is None: # get the output if not output: output = self.output() - if not output.exists(): + if not output["local_files"].exists(): raise Exception( - f"accessing external files from the bundle requires the output of {self} to " - "exist, but it appears to be missing", + f"accessing external files from the bundle requires the output of {self} to exist, but it appears " + "to be missing", ) - if isinstance(output, law.FileCollection): - output = output.random_target() - self.files_dir = law.LocalDirectoryTarget(is_tmp=True) - output.load(self.files_dir, formatter="tar") + self._files_collection = output["local_files"] - # resolve basenames in the bundle directory and map to local targets - def resolve_basename(unique_basename): - return self.files_dir.child(unique_basename) + return self._files_collection - self._files = law.util.map_struct(resolve_basename, self.file_names) - - return self._files + @property + def files(self) -> DotDict: + return self.get_files_collection().targets @property - def files(self): - return self.get_files() + def files_dir(self) -> law.LocalDirectoryTarget: + return self.get_files_collection().dir def single_output(self): # required by law.tasks.TransferLocalFile return self.target(f"externals_{self.files_hash}.tgz") - @only_local_env + def output(self): + def local_target(basename): + path = os.path.join(f"externals_{self.files_hash}", basename) + is_dir = "." not in basename # simple heuristic, but type actually checked after unpacking below + return self.local_target(path, dir=is_dir) + + return DotDict( + bundle=super().output(), + local_files=law.SiblingFileCollection(law.util.map_struct(local_target, self.file_names)), + ) + + def trace_transfer_output(self, output): + return output["bundle"] + @law.decorator.notify @law.decorator.log @law.decorator.safe_output def run(self): - # create a tmp dir to work in - tmp_dir = law.LocalDirectoryTarget(is_tmp=True) - tmp_dir.touch() - - # create a scratch directory for temporary downloads that will not be bundled - scratch_dir = tmp_dir.child("scratch", type="d") - scratch_dir.touch() - - # progress callback - progress = self.create_progress_callback(len(law.util.flatten(self.ext_files))) - - # helper to fetch a single src to dst - def fetch(src, dst): - if src.startswith(("http://", "https://")): - # download via wget - wget(src, dst) - elif os.path.isfile(src): - # copy local file - shutil.copy2(src, dst) - elif os.path.isdir(src): - # copy local dir - shutil.copytree(src, dst) - else: - raise NotImplementedError(f"fetching {src} is not supported") - - # helper function to fetch generic files - def fetch_file(ext_file, counter=[0]): - if ext_file.subpaths: - # copy to scratch dir first in case a subpath is requested - basename = self.create_unique_basename(ext_file.location) - scratch_dst = os.path.join(scratch_dir.abspath, basename) - fetch(ext_file.location, scratch_dst) - # when not a directory, assume the file is an archive and unpack it - if not os.path.isdir(scratch_dst): - arc_dir = scratch_dir.child(basename.split(".")[0] + "_unpacked", type="d") - self.publish_message(f"unpacking {scratch_dst}") - law.LocalFileTarget(scratch_dst).load(arc_dir) - scratch_src = arc_dir.abspath + outputs = self.output() + + # remove the bundle if recreating + if outputs["bundle"].exists() and self.recreate: + outputs["bundle"].remove() + + # bundle only if needed + if not outputs["bundle"].exists(): + if not env_is_local: + raise RuntimeError( + f"the output bundle {outputs['bundle'].basename} is missing, but cannot be created in non-local " + "environments", + ) + + # create a tmp dir to work in + tmp_dir = law.LocalDirectoryTarget(is_tmp=True) + tmp_dir.touch() + + # create a scratch directory for temporary downloads that will not be bundled + scratch_dir = tmp_dir.child("scratch", type="d") + scratch_dir.touch() + + # progress callback + progress = self.create_progress_callback(len(law.util.flatten(self.ext_files))) + + # helper to fetch a single src to dst + def fetch(src, dst): + if src.startswith(("http://", "https://")): + # download via wget + wget(src, dst) + elif os.path.isfile(src): + # copy local file + shutil.copy2(src, dst) + elif os.path.isdir(src): + # copy local dir + shutil.copytree(src, dst) else: - scratch_src = scratch_dst - # copy all subpaths - basenames = self.create_unique_basename(ext_file) - for name, subpath in ext_file.subpaths.items(): - fetch(os.path.join(scratch_src, subpath), os.path.join(tmp_dir.abspath, basenames[name])) - else: - # copy directly to the bundle dir - src = ext_file.location - dst = os.path.join(tmp_dir.abspath, self.create_unique_basename(ext_file.location)) - fetch(src, dst) - # log - self.publish_message(f"fetched {ext_file}") - progress(counter[0]) - counter[0] += 1 - - # fetch all files and cleanup scratch dir - law.util.map_struct(fetch_file, self.ext_files) - scratch_dir.remove() - - # create the bundle - tmp = law.LocalFileTarget(is_tmp="tgz") - tmp.dump(tmp_dir, formatter="tar") - - # log the file size - bundle_size = law.util.human_bytes(tmp.stat().st_size, fmt=True) - self.publish_message(f"bundle size is {bundle_size}") - - # transfer the result - self.transfer(tmp) + err = f"cannot fetch {src}" + if src.startswith("/") and os.path.isdir("/".join(src.split("/", 2)[:2])): + err += ", file or directory does not exist" + else: + err += ", resource type is not supported" + raise NotImplementedError(err) + + # helper function to fetch generic files + def fetch_file(ext_file, counter=[0]): + if ext_file.subpaths: + # copy to scratch dir first in case a subpath is requested + basename = self.create_unique_basename(ext_file.location) + scratch_dst = os.path.join(scratch_dir.abspath, basename) + fetch(ext_file.location, scratch_dst) + # when not a directory, assume the file is an archive and unpack it + if not os.path.isdir(scratch_dst): + arc_dir = scratch_dir.child(basename.split(".")[0] + "_unpacked", type="d") + self.publish_message(f"unpacking {scratch_dst}") + law.LocalFileTarget(scratch_dst).load(arc_dir) + scratch_src = arc_dir.abspath + else: + scratch_src = scratch_dst + # copy all subpaths + basenames = self.create_unique_basename(ext_file) + for name, subpath in ext_file.subpaths.items(): + fetch(os.path.join(scratch_src, subpath), os.path.join(tmp_dir.abspath, basenames[name])) + else: + # copy directly to the bundle dir + src = ext_file.location + dst = os.path.join(tmp_dir.abspath, self.create_unique_basename(ext_file.location)) + fetch(src, dst) + # log + self.publish_message(f"fetched {ext_file}") + progress(counter[0]) + counter[0] += 1 + + # fetch all files and cleanup scratch dir + law.util.map_struct(fetch_file, self.ext_files) + scratch_dir.remove() + + # create the bundle + tmp = law.LocalFileTarget(is_tmp="tgz") + tmp.dump(tmp_dir, formatter="tar") + + # log the file size + bundle_size = law.util.human_bytes(tmp.stat().st_size, fmt=True) + self.publish_message(f"bundle size is {bundle_size}") + + # transfer the result + self.transfer(tmp, outputs["bundle"]) + + # unpack the bundle to have local files available + with self.publish_step(f"unpacking to {outputs['local_files'].dir.abspath} ..."): + outputs["local_files"].dir.remove() + bundle = outputs["bundle"] + if isinstance(bundle, law.FileCollection): + bundle = bundle.random_target() + bundle.load(outputs["local_files"].dir, formatter="tar") + + # check if unpacked files/directories are described by the correct target class + for target in outputs["local_files"]._flat_target_list: + mismatch = ( + (isinstance(target, law.FileSystemFileTarget) and not os.path.isfile(target.abspath)) or + (isinstance(target, law.FileSystemDirectoryTarget) and not os.path.isdir(target.abspath)) + ) + if mismatch: + raise Exception(f"mismatching file/directory type of unpacked target {target!r}") diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index b6128196f..4315af4e4 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -23,7 +23,7 @@ import order as od from columnflow.columnar_util import mandatory_coffea_columns, Route, ColumnCollection -from columnflow.util import is_regex, prettify, DotDict +from columnflow.util import get_docs_url, is_regex, prettify, DotDict, freeze from columnflow.types import Sequence, Callable, Any, T @@ -38,6 +38,11 @@ default_repr_max_count = law.config.get_expanded_int("analysis", "repr_max_count") default_repr_hash_len = law.config.get_expanded_int("analysis", "repr_hash_len") +# cached and parsed sections of the law config for faster lookup +_cfg_outputs_dict = None +_cfg_versions_dict = None +_cfg_resources_dict = None + # placeholder to denote a default value that is resolved dynamically RESOLVE_DEFAULT = "DEFAULT" @@ -80,6 +85,9 @@ class TaskShifts: local: set[str] = field(default_factory=set) upstream: set[str] = field(default_factory=set) + def __hash__(self) -> int: + return hash((frozenset(self.local), frozenset(self.upstream))) + class BaseTask(law.Task): @@ -127,11 +135,6 @@ class AnalysisTask(BaseTask, law.SandboxTask): exclude_params_branch = {"user"} exclude_params_workflow = {"user", "notify_slack", "notify_mattermost", "notify_custom"} - # cached and parsed sections of the law config for faster lookup - _cfg_outputs_dict = None - _cfg_versions_dict = None - _cfg_resources_dict = None - @classmethod def modify_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params = super().modify_param_values(params) @@ -176,20 +179,26 @@ def req_params(cls, inst: AnalysisTask, **kwargs) -> dict[str, Any]: _prefer_cli = law.util.make_set(kwargs.get("_prefer_cli", [])) | { "version", "workflow", "job_workers", "poll_interval", "walltime", "max_runtime", "retries", "acceptance", "tolerance", "parallel_jobs", "shuffle_jobs", "htcondor_cpus", - "htcondor_gpus", "htcondor_memory", "htcondor_disk", "htcondor_pool", "pilot", + "htcondor_gpus", "htcondor_memory", "htcondor_disk", "htcondor_pool", "pilot", "remote_claw_sandbox", } kwargs["_prefer_cli"] = _prefer_cli # build the params params = super().req_params(inst, **kwargs) - # when not explicitly set in kwargs and no global value was defined on the cli for the task - # family, evaluate and use the default value + # evaluate and use the default version in case + # - "version" is an actual parameter object of cls, and + # - "version" is not explicitly set in kwargs, and + # - no global value was defined on the cli for the task family, and + # - if cls and inst belong to the same family, they differ in the keys used for the config lookup if ( isinstance(getattr(cls, "version", None), luigi.Parameter) and "version" not in kwargs and not law.parser.global_cmdline_values().get(f"{cls.task_family}_version") and - cls.task_family != law.parser.root_task_cls().task_family + ( + cls.task_family != inst.task_family or + freeze(cls.get_config_lookup_keys(params)) != freeze(inst.get_config_lookup_keys(inst)) + ) ): default_version = cls.get_default_version(inst, params) if default_version and default_version != law.NO_STR: @@ -224,17 +233,19 @@ def _structure_cfg_items(cls, items: list[tuple[str, Any]]) -> dict[str, Any]: d[part] = {"*": d[part]} d = d[part] else: - # assign value to the last nesting level - if part in d and isinstance(d[part], dict): - d[part]["*"] = value - else: + # assign value to the last nesting level, do not overwrite + if part not in d: d[part] = value + elif isinstance(d[part], dict): + d[part]["*"] = value return items_dict @classmethod def _get_cfg_outputs_dict(cls) -> dict[str, Any]: - if cls._cfg_outputs_dict is None and law.config.has_section("outputs"): + global _cfg_outputs_dict + + if _cfg_outputs_dict is None and law.config.has_section("outputs"): # collect config item pairs skip_keys = {"wlcg_file_systems", "lfn_sources"} items = [ @@ -242,26 +253,30 @@ def _get_cfg_outputs_dict(cls) -> dict[str, Any]: for key, value in law.config.items("outputs") if value and key not in skip_keys ] - cls._cfg_outputs_dict = cls._structure_cfg_items(items) + _cfg_outputs_dict = cls._structure_cfg_items(items) - return cls._cfg_outputs_dict + return _cfg_outputs_dict @classmethod def _get_cfg_versions_dict(cls) -> dict[str, Any]: - if cls._cfg_versions_dict is None and law.config.has_section("versions"): + global _cfg_versions_dict + + if _cfg_versions_dict is None and law.config.has_section("versions"): # collect config item pairs items = [ (key, value) for key, value in law.config.items("versions") if value ] - cls._cfg_versions_dict = cls._structure_cfg_items(items) + _cfg_versions_dict = cls._structure_cfg_items(items) - return cls._cfg_versions_dict + return _cfg_versions_dict @classmethod def _get_cfg_resources_dict(cls) -> dict[str, Any]: - if cls._cfg_resources_dict is None and law.config.has_section("resources"): + global _cfg_resources_dict + + if _cfg_resources_dict is None and law.config.has_section("resources"): # helper to split resource values into key-value pairs themselves def parse(key: str, value: str) -> tuple[str, list[tuple[str, Any]]]: params = [] @@ -285,9 +300,9 @@ def parse(key: str, value: str) -> tuple[str, list[tuple[str, Any]]]: for key, value in law.config.items("resources") if value and not key.startswith("_") ] - cls._cfg_resources_dict = cls._structure_cfg_items(items) + _cfg_resources_dict = cls._structure_cfg_items(items) - return cls._cfg_resources_dict + return _cfg_resources_dict @classmethod def get_default_version(cls, inst: AnalysisTask, params: dict[str, Any]) -> str | None: @@ -354,10 +369,16 @@ def get_config_lookup_keys( else getattr(inst_or_params, "analysis", None) ) if analysis not in {law.NO_STR, None, ""}: - keys["analysis"] = analysis + prefix = "ana" + keys[prefix] = f"{prefix}_{analysis}" # add the task family - keys["task_family"] = cls.task_family + prefix = "task" + keys[prefix] = f"{prefix}_{cls.task_family}" + + # for backwards compatibility, add the task family again without the prefix + # (TODO: this should be removed in the future) + keys[f"{prefix}_compat"] = cls.task_family return keys @@ -375,7 +396,7 @@ def _dfs_key_lookup( return empty_value # the keys to use for the lookup are the flattened values of the keys dict - flat_keys = collections.deque(law.util.flatten(keys.values() if isinstance(keys, dict) else keys)) + flat_keys = law.util.flatten(keys.values() if isinstance(keys, dict) else keys) # start tree traversal using a queue lookup consisting of names and values of tree nodes, # as well as the remaining keys (as a deferred function) to compare for that particular path @@ -389,9 +410,27 @@ def _dfs_key_lookup( # check if the pattern matches any key regex = is_regex(pattern) - while _keys: - key = _keys.popleft() + for i, key in enumerate(_keys): if law.util.multi_match(key, pattern, regex=regex): + # for a limited time, show a deprecation warning when the old task family key was matched + # (old = no "task_" prefix) + # TODO: remove once deprecated + if "task_compat" in keys and key == keys["task_compat"]: + docs_url = get_docs_url( + "user_guide", + "best_practices.html", + anchor="selecting-output-locations", + ) + logger.warning_once( + "dfs_lookup_old_task_key", + f"during the lookup of a pinned location, version or resource value of a '{cls.__name__}' " + f"task, an entry matched based on the task family '{key}' that misses the new 'task_' " + "prefix; please update the pinned entries in your law.cfg file by adding the 'task_' " + f"prefix to entries that contain the task family, e.g. 'task_{key}: VALUE'; support for " + f"missing prefixes will be removed in a future version; see {docs_url} for more info", + ) + # remove the matched key from remaining lookup keys + _keys.pop(i) # when obj is not a dict, we found the value if not isinstance(obj, dict): return obj @@ -1237,6 +1276,17 @@ def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params["config_insts"] = [params["config_inst"]] else: if "config_insts" not in params and "configs" in params: + # custom pattern matching + matched_config_names = [] + for pattern in params["configs"]: + matched_config_names.extend( + config_name for config_name in analysis_inst.configs.names() + if law.util.multi_match(config_name, pattern) + ) + matched_config_names = law.util.make_unique(matched_config_names) + if matched_config_names: + params["configs"] = matched_config_names + # load config instances params["config_insts"] = list(map(analysis_inst.get_config, params["configs"])) # resolving of parameters that is required before ArrayFunctions etc. can be initialized @@ -1351,14 +1401,17 @@ def get_known_shifts( resolution_task_cls = None @classmethod - def req_params(cls, inst: law.Task, *args, **kwargs) -> dict[str, Any]: - params = super().req_params(inst, *args, **kwargs) - + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: # manually add known shifts between workflows and branches - if isinstance(inst, law.BaseWorkflow) and inst.__class__ == cls and getattr(inst, "known_shifts", None): - params["known_shifts"] = inst.known_shifts + if ( + "known_shifts" not in kwargs and + isinstance(inst, law.BaseWorkflow) and + inst.__class__ == cls and + getattr(inst, "known_shifts", None) + ): + kwargs["known_shifts"] = inst.known_shifts - return params + return super().req_params(inst, **kwargs) @classmethod def _multi_sequence_repr( @@ -1449,7 +1502,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "config", None) ) if config not in {law.NO_STR, None, ""}: - keys.insert_before("task_family", "config", config) + prefix = "cfg" + keys.insert_before("task", prefix, f"{prefix}_{config}") return keys @@ -1479,7 +1533,7 @@ def __init__(self, *args, **kwargs) -> None: @property def config_repr(self) -> str: - return "__".join(config_inst.name for config_inst in self.config_insts) + return "__".join(config_inst.name for config_inst in sorted(self.config_insts, key=lambda c: c.id)) def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() @@ -1629,7 +1683,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "shift", None) ) if shift not in (law.NO_STR, None, ""): - keys["shift"] = shift + prefix = "shift" + keys[prefix] = f"{prefix}_{shift}" return keys @@ -1688,6 +1743,21 @@ def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any] return params + @classmethod + def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values(params) + + # also add a reference to the info instance when a global shift is defined + if "dataset_inst" in params and "global_shift_inst" in params: + shift_name = params["global_shift_inst"].name + params["dataset_info_inst"] = ( + params["dataset_inst"].get_info(shift_name) + if shift_name in params["dataset_inst"].info + else params["dataset_inst"].get_info("nominal") + ) + + return params + @classmethod def get_known_shifts( cls, @@ -1720,7 +1790,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "dataset", None) ) if dataset not in {law.NO_STR, None, ""}: - keys.insert_before("shift", "dataset", dataset) + prefix = "dataset" + keys.insert_before("shift", prefix, f"{prefix}_{dataset}") return keys diff --git a/columnflow/tasks/framework/histograms.py b/columnflow/tasks/framework/histograms.py index a0ecdaa2e..81a8682b1 100644 --- a/columnflow/tasks/framework/histograms.py +++ b/columnflow/tasks/framework/histograms.py @@ -16,8 +16,10 @@ ) from columnflow.tasks.histograms import MergeHistograms, MergeShiftedHistograms from columnflow.util import dev_sandbox, maybe_import +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") class HistogramsUserBase( @@ -31,7 +33,7 @@ class HistogramsUserBase( CategoriesMixin, VariablesMixin, ): - single_config = True + single_config = False sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) @@ -42,29 +44,49 @@ def store_parts(self) -> law.util.InsertableDict: def load_histogram( self, + inputs: dict, + config: str | od.Config, dataset: str | od.Dataset, variable: str | od.Variable, + update_label: bool = True, ) -> hist.Hist: """ Helper function to load the histogram from the input for a given dataset and variable. + :param inputs: The inputs dictionary containing the histograms. + :param config: The config name or instance. :param dataset: The dataset name or instance. :param variable: The variable name or instance. + :param update_label: Whether to update the label of the variable axis in the histogram. + If True, the label will be updated based on the first config instance's variable label. :return: The loaded histogram. """ + if isinstance(dataset, od.Dataset): dataset = dataset.name if isinstance(variable, od.Variable): variable = variable.name - histogram = self.input()[dataset]["collection"][0]["hists"].targets[variable].load(formatter="pickle") + if isinstance(config, od.Config): + config = config.name + histogram = inputs[config][dataset]["collection"][0]["hists"].targets[variable].load(formatter="pickle") + + if update_label: + # get variable label from first config instance + for var_name in variable.split("-"): + label = self.config_insts[0].get_variable(var_name).x_title + ax_names = [ax.name for ax in histogram.axes] + if var_name in ax_names: + # update the label of the variable axis + histogram.axes[var_name].label = label return histogram def slice_histogram( self, histogram: hist.Hist, - processes: str | list[str], - categories: str | list[str], - shifts: str | list[str], + config_inst: od.Config, + processes: str | list[str] | None = None, + categories: str | list[str] | None = None, + shifts: str | list[str] | None = None, reduce_axes: bool = False, ) -> hist.Hist: """ @@ -85,49 +107,56 @@ def slice_histogram( def flatten_nested_list(nested_list): return [item for sublist in nested_list for item in sublist] - # transform into lists if necessary - processes = law.util.make_list(processes) - categories = law.util.make_list(categories) - shifts = law.util.make_list(shifts) - - # get all leaf categories - category_insts = list(map(self.config_inst.get_category, categories)) - leaf_category_insts = set(flatten_nested_list([ - category_inst.get_leaf_categories() or [category_inst] - for category_inst in category_insts - ])) + selection_dict = {} - # get all sub processes - process_insts = list(map(self.config_inst.get_process, processes)) - sub_process_insts = set(flatten_nested_list([ - [sub for sub, _, _ in proc.walk_processes(include_self=True)] - for proc in process_insts - ])) + if processes: + # transform into lists if necessary + processes = law.util.make_list(processes) + # get all sub processes - # get all shift instances - shift_insts = [self.config_inst.get_shift(shift) for shift in shifts] - - # work on a copy - h = histogram.copy() - - # axis selections - h = h[{ - "process": [ + process_insts = list(map(config_inst.get_process, processes)) + sub_process_insts = set(flatten_nested_list([ + [sub for sub, _, _ in proc.walk_processes(include_self=True)] + for proc in process_insts + ])) + selection_dict["process"] = [ hist.loc(p.name) for p in sub_process_insts - if p.name in h.axes["process"] - ], - "category": [ + if p.name in histogram.axes["process"] + ] + if categories: + # transform into lists if necessary + categories = law.util.make_list(categories) + + # get all leaf categories + category_insts = list(map(config_inst.get_category, categories)) + leaf_category_insts = set(flatten_nested_list([ + category_inst.get_leaf_categories() or [category_inst] + for category_inst in category_insts + ])) + selection_dict["category"] = [ hist.loc(c.name) for c in leaf_category_insts - if c.name in h.axes["category"] - ], - "shift": [ + if c.name in histogram.axes["category"] + ] + + if shifts: + # transform into lists if necessary + shifts = law.util.make_list(shifts) + + # get all shift instances + shift_insts = [config_inst.get_shift(shift) for shift in shifts] + selection_dict["shift"] = [ hist.loc(s.name) for s in shift_insts - if s.name in h.axes["shift"] - ], - }] + if s.name in histogram.axes["shift"] + ] + + # work on a copy + h = histogram.copy() + + # axis selections + h = h[selection_dict] if reduce_axes: # axis reductions @@ -154,14 +183,20 @@ def workflow_requires(self): return reqs def requires(self): + datasets = [self.datasets] if self.single_config else self.datasets return { - d: self.reqs.MergeHistograms.req_different_branching( - self, - dataset=d, - branch=-1, - _prefer_cli={"variables"}, - ) - for d in self.datasets + config_inst.name: { + d: self.reqs.MergeHistograms.req_different_branching( + self, + config=config_inst.name, + dataset=d, + branch=-1, + _prefer_cli={"variables"}, + ) + for d in datasets[i] + if config_inst.has_dataset(d) + } + for i, config_inst in enumerate(self.config_insts) } @@ -183,12 +218,17 @@ def workflow_requires(self): return reqs def requires(self): + datasets = [self.datasets] if self.single_config else self.datasets return { - d: self.reqs.MergeShiftedHistograms.req_different_branching( - self, - dataset=d, - branch=-1, - _prefer_cli={"variables"}, - ) - for d in self.datasets + config_inst.name: { + d: self.reqs.MergeShiftedHistograms.req_different_branching( + self, + config=config_inst.name, + dataset=d, + branch=-1, + _prefer_cli={"variables"}, + ) + for d in datasets[i] + } + for i, config_inst in enumerate(self.config_insts) } diff --git a/columnflow/tasks/framework/inference.py b/columnflow/tasks/framework/inference.py index e742ad0d0..508b67e17 100644 --- a/columnflow/tasks/framework/inference.py +++ b/columnflow/tasks/framework/inference.py @@ -6,6 +6,8 @@ from __future__ import annotations +import pickle + import law import order as od @@ -15,11 +17,13 @@ InferenceModelMixin, HistHookMixin, MLModelsMixin, ) from columnflow.tasks.framework.remote import RemoteWorkflow -from columnflow.tasks.histograms import MergeHistograms, MergeShiftedHistograms -from columnflow.util import dev_sandbox, DotDict, maybe_import +from columnflow.tasks.histograms import MergeShiftedHistograms from columnflow.config_util import get_datasets_from_process +from columnflow.util import dev_sandbox, DotDict, maybe_import +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") class SerializeInferenceModelBase( @@ -42,7 +46,6 @@ class SerializeInferenceModelBase( # upstream requirements reqs = Requirements( RemoteWorkflow.reqs, - MergeHistograms=MergeHistograms, MergeShiftedHistograms=MergeShiftedHistograms, ) @@ -106,139 +109,198 @@ def get_data_datasets(cls, config_inst: od.Config, cat_obj: DotDict) -> list[str ) ] + @law.workflow_property(cache=True) + def combined_config_data(self) -> dict[od.ConfigInst, dict[str, dict | set]]: + # prepare data extracted from the inference model + config_data = { + config_inst: { + # all variables used in this config in any datacard category + "variables": set(), + # plain set of names of real data datasets + "data_datasets": set(), + # per mc dataset name, the set of shift sources and the names processes to be extracted from them + "mc_datasets": {}, + } + for config_inst in self.config_insts + } + + # iterate over all model categories + for cat_obj in self.inference_model_inst.categories: + # keep track of per-category information for consistency checks + variables = set() + categories = set() + + # iterate over configs relevant for this category + config_insts = [config_inst for config_inst in self.config_insts if config_inst.name in cat_obj.config_data] + for config_inst in config_insts: + data = config_data[config_inst] + + # variables + data["variables"].add(cat_obj.config_data[config_inst.name].variable) + + # data datasets, but only if + # - data in that category is not faked from mc processes, or + # - at least one process object is dynamic (that usually means data-driven) + if not cat_obj.data_from_processes or any(proc_obj.is_dynamic for proc_obj in cat_obj.processes): + data["data_datasets"].update(self.get_data_datasets(config_inst, cat_obj)) + + # mc datasets over all process objects + # - the process is not dynamic + for proc_obj in cat_obj.processes: + mc_datasets = self.get_mc_datasets(config_inst, proc_obj) + for dataset_name in mc_datasets: + if dataset_name not in data["mc_datasets"]: + data["mc_datasets"][dataset_name] = { + "shift_sources": set(), + "proc_names": set(), + } + data["mc_datasets"][dataset_name]["proc_names"].add( + proc_obj.config_data[config_inst.name].process, + ) + + # shift sources + for param_obj in proc_obj.parameters: + if config_inst.name not in param_obj.config_data: + continue + # only add if a shift is required for this parameter + if ( + (param_obj.type.is_shape and not param_obj.transformations.any_from_rate) or + (param_obj.type.is_rate and param_obj.transformations.any_from_shape) + ): + shift_source = param_obj.config_data[config_inst.name].shift_source + for mc_dataset in mc_datasets: + data["mc_datasets"][mc_dataset]["shift_sources"].add(shift_source) + + # for consistency checks later + variables.add(cat_obj.config_data[config_inst.name].variable) + categories.add(cat_obj.config_data[config_inst.name].category) + + # consistency checks: the config-based variable and category names must be identical + if len(variables) != 1: + raise ValueError( + f"found diverging variables to be used in datacard category '{cat_obj.name}' across configs " + f"{', '.join(c.name for c in config_insts)}: {variables}", + ) + if len(categories) != 1: + raise ValueError( + f"found diverging categories to be used in datacard category '{cat_obj.name}' across configs " + f"{', '.join(c.name for c in config_insts)}: {categories}", + ) + + return config_data + def create_branch_map(self): - return list(self.inference_model_inst.categories) + # dummy branch map + return {0: None} - def _requires_cat_obj(self, cat_obj: DotDict, **req_kwargs): + def _hist_requirements(self, **kwargs): + # gather data from inference model to define requirements in the structure + # config_name -> dataset_name -> MergeHistogramsTask reqs = {} - for config_inst in self.config_insts: - if not (config_data := cat_obj.config_data.get(config_inst.name)): - continue - - # add merged shifted histograms for mc - reqs[config_inst.name] = { - proc_obj.name: { - dataset: self.reqs.MergeShiftedHistograms.req_different_branching( - self, - config=config_inst.name, - dataset=dataset, - shift_sources=tuple( - param_obj.config_data[config_inst.name].shift_source - for param_obj in proc_obj.parameters - if ( - config_inst.name in param_obj.config_data and - self.inference_model_inst.require_shapes_for_parameter(param_obj) - ) - ), - variables=(config_data.variable,), - **req_kwargs, - ) - for dataset in self.get_mc_datasets(config_inst, proc_obj) - } - for proc_obj in cat_obj.processes - if config_inst.name in proc_obj.config_data and not proc_obj.is_dynamic - } - # add merged histograms for data, but only if - # - data in that category is not faked from mc, or - # - at least one process object is dynamic (that usually means data-driven) - if ( - (not cat_obj.data_from_processes or any(proc_obj.is_dynamic for proc_obj in cat_obj.processes)) and - (data_datasets := self.get_data_datasets(config_inst, cat_obj)) - ): - reqs[config_inst.name]["data"] = { - dataset: self.reqs.MergeHistograms.req_different_branching( - self, - config=config_inst.name, - dataset=dataset, - variables=(config_data.variable,), - **req_kwargs, - ) - for dataset in data_datasets - } + for config_inst, data in self.combined_config_data.items(): + reqs[config_inst.name] = {} + # mc datasets + for dataset_name in sorted(data["mc_datasets"]): + reqs[config_inst.name][dataset_name] = self.reqs.MergeShiftedHistograms.req_different_branching( + self, + config=config_inst.name, + dataset=dataset_name, + shift_sources=tuple(sorted(data["mc_datasets"][dataset_name]["shift_sources"])), + variables=tuple(sorted(data["variables"])), + **kwargs, + ) + # data datasets + for dataset_name in sorted(data["data_datasets"]): + reqs[config_inst.name][dataset_name] = self.reqs.MergeShiftedHistograms.req_different_branching( + self, + config=config_inst.name, + dataset=dataset_name, + shift_sources=(), + variables=tuple(sorted(data["variables"])), + **kwargs, + ) return reqs def workflow_requires(self): reqs = super().workflow_requires() - - reqs["merged_hists"] = hist_reqs = {} - for cat_obj in self.branch_map.values(): - cat_reqs = self._requires_cat_obj(cat_obj) - for config_name, proc_reqs in cat_reqs.items(): - hist_reqs.setdefault(config_name, {}) - for proc_name, dataset_reqs in proc_reqs.items(): - hist_reqs[config_name].setdefault(proc_name, {}) - for dataset_name, task in dataset_reqs.items(): - hist_reqs[config_name][proc_name].setdefault(dataset_name, set()).add(task) - + reqs["merged_hists"] = self._hist_requirements() return reqs def requires(self): - cat_obj = self.branch_data - return self._requires_cat_obj(cat_obj, branch=-1, workflow="local") + return self._hist_requirements(branch=-1, workflow="local") def load_process_hists( self, - inputs: dict, - cat_obj: DotDict, config_inst: od.Config, - ) -> dict[od.Process, hist.Hist]: - # loop over all configs required by the datacard category and gather histograms - config_data = cat_obj.config_data.get(config_inst.name) + dataset_processes: dict[str, list[str]], + variable: str, + inputs: dict, + ) -> dict[str, dict[od.Process, hist.Hist]]: + import hist - # collect histograms per config process + # collect histograms per variable and process hists: dict[od.Process, hist.Hist] = {} - with self.publish_step( - f"extracting {config_data.variable} in {config_data.category} for config {config_inst.name}...", - ): - for proc_obj_name, inp in inputs[config_inst.name].items(): - if proc_obj_name == "data": - process_inst = config_inst.get_process("data") - else: - proc_obj = self.inference_model_inst.get_process(proc_obj_name, category=cat_obj.name) - process_inst = config_inst.get_process(proc_obj.config_data[config_inst.name].process) - sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] - - # loop over per-dataset inputs and extract histograms containing the process - h_proc = None - for dataset_name, _inp in inp.items(): - dataset_inst = config_inst.get_dataset(dataset_name) - - # skip when the dataset is already known to not contain any sub process - if not any(map(dataset_inst.has_process, sub_process_insts)): - self.logger.warning( - f"dataset '{dataset_name}' does not contain process '{process_inst.name}' or any of " - "its subprocesses which indicates a misconfiguration in the inference model " - f"'{self.inference_model}'", - ) - continue - - # open the histogram and work on a copy - h = _inp["collection"][0]["hists"][config_data.variable].load(formatter="pickle").copy() - - # axis selections - h = h[{ - "process": [ - hist.loc(p.name) - for p in sub_process_insts - if p.name in h.axes["process"] - ], - }] - # axis reductions - h = h[{"process": sum}] + with self.publish_step(f"extracting '{variable}' for config {config_inst.name} ..."): + for dataset_name, process_names in dataset_processes.items(): + # open the histogram and work on a copy + inp = inputs[dataset_name]["collection"][0]["hists"][variable] + try: + h = inp.load(formatter="pickle").copy() + except pickle.UnpicklingError as e: + raise Exception( + f"failed to load '{variable}' histogram for dataset '{dataset_name}' in config " + f"'{config_inst.name}' from {inp.abspath}", + ) from e + + # determine processes to extract + process_insts = [config_inst.get_process(name) for name in process_names] + + # loop over all proceses assigned to this dataset + for process_inst in process_insts: + # gather all subprocesses for a full query later + sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] + + # there must be at least one matching sub process + if not any(p.name in h.axes["process"] for p in sub_process_insts): + raise Exception(f"no '{variable}' histograms found for process '{process_inst.name}'") + + # select and reduce over relevant processes + h_proc = h[{ + "process": [hist.loc(p.name) for p in sub_process_insts if p.name in h.axes["process"]], + }] + h_proc = h_proc[{"process": sum}] + + # additional custom reductions + h_proc = self.modify_process_hist( + config_inst=config_inst, + process_inst=process_inst, + variable=variable, + h=h_proc, + ) - # add the histogram for this dataset - if h_proc is None: - h_proc = h + # store it + if process_inst in hists: + hists[process_inst] += h_proc else: - h_proc += h + hists[process_inst] = h_proc - # there must be a histogram - if h_proc is None: - raise Exception(f"no histograms found for process '{process_inst.name}'") + return hists - # save histograms mapped to processes - hists[process_inst] = h_proc + def modify_process_hist( + self, + config_inst: od.Config, + process_inst: od.Process, + variable: str, + h: hist.Hist, + ) -> hist.Hist: + """ + Hook to modify a process histogram after it has been loaded. This can be helpful to reduce memory early on. - return hists + :param config_inst: The config instance the histogram belongs to. + :param process_inst: The process instance the histogram belongs to. + :param h: The histogram to modify. + :return: The modified histogram. + """ + return h diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index bb020887f..36b99d5ad 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -8,7 +8,7 @@ import time import itertools -from collections import Counter +from collections import Counter, defaultdict import luigi import law @@ -29,6 +29,7 @@ from columnflow.types import Callable from columnflow.timing import Timer +np = maybe_import("numpy") ak = maybe_import("awkward") @@ -92,21 +93,6 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"calibrator"} return super().req_params(inst, **kwargs) - @property - def calibrator_repr(self) -> str: - """ - Return a string representation of the calibrator class. - """ - return self.build_repr(self.array_function_cls_repr(self.calibrator)) - - def store_parts(self) -> law.util.InsertableDict: - """ - :return: Dictionary with parts that will be translated into an output directory path. - """ - parts = super().store_parts() - parts.insert_after(self.config_store_anchor, "calibrator", f"calib__{self.calibrator_repr}") - return parts - @classmethod def get_config_lookup_keys( cls, @@ -121,10 +107,26 @@ def get_config_lookup_keys( else getattr(inst_or_params, "calibrator", None) ) if calibrator not in (law.NO_STR, None, ""): - keys["calibrator"] = f"calib_{calibrator}" + prefix = "calib" + keys[prefix] = f"{prefix}_{calibrator}" return keys + @property + def calibrator_repr(self) -> str: + """ + Return a string representation of the calibrator class. + """ + return self.build_repr(self.array_function_cls_repr(self.calibrator)) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "calibrator", f"calib__{self.calibrator_repr}") + return parts + class CalibratorMixin(ArrayFunctionInstanceMixin, CalibratorClassMixin): """ @@ -198,6 +200,23 @@ def get_known_shifts( super().get_known_shifts(params, shifts) + @classmethod + def req_other_calibrator(cls, inst: CalibratorMixin, **kwargs) -> CalibratorMixin: + """ + Same as :py:meth:`req` but overwrites specific arguments for instantiation that simplify requesting a different + calibrator instance. + + :param inst: The reference instance to request parameters from. + :param kwargs: Additional arguments forwarded to :py:meth:`req`. + :return: A new instance of *this* class. + """ + # calibrator_inst and known_shifts must be set to None to by-pass calibrator instance cache lookup and thus, + # also full parameter resolution + kwargs.setdefault("calibrator_inst", None) + kwargs.setdefault("known_shifts", None) + + return cls.req(inst, **kwargs) + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -209,9 +228,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.calibrator_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_calibrator_inst(self) -> None: + def teardown_calibrator_inst(self, **kwargs) -> None: if self.calibrator_inst: - self.calibrator_inst.run_teardown(task=self) + self.calibrator_inst.run_teardown(task=self, **kwargs) @property def calibrator_repr(self) -> str: @@ -305,7 +324,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "calibrators", None) ) if calibrators not in {law.NO_STR, None, "", ()}: - keys["calibrators"] = [f"calib_{calibrator}" for calibrator in calibrators] + prefix = "calib" + keys[prefix] = [f"{prefix}_{calibrator}" for calibrator in calibrators] return keys @@ -476,6 +496,25 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: } return super().req_params(inst, **kwargs) + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: SelectorClassMixin | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) + + # add the selector name + selector = ( + inst_or_params.get("selector") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "selector", None) + ) + if selector not in (law.NO_STR, None, ""): + prefix = "sel" + keys[prefix] = f"{prefix}_{selector}" + + return keys + @property def selector_repr(self) -> str: """ @@ -497,24 +536,6 @@ def store_parts(self) -> law.util.InsertableDict: parts.insert_after(self.config_store_anchor, "selector", f"sel__{self.selector_repr}") return parts - @classmethod - def get_config_lookup_keys( - cls, - inst_or_params: SelectorClassMixin | dict[str, Any], - ) -> law.util.InsertiableDict: - keys = super().get_config_lookup_keys(inst_or_params) - - # add the selector name - selector = ( - inst_or_params.get("selector") - if isinstance(inst_or_params, dict) - else getattr(inst_or_params, "selector", None) - ) - if selector not in (law.NO_STR, None, ""): - keys["selector"] = f"sel_{selector}" - - return keys - class SelectorMixin(ArrayFunctionInstanceMixin, SelectorClassMixin): """ @@ -583,6 +604,23 @@ def get_known_shifts( super().get_known_shifts(params, shifts) + @classmethod + def req_other_selector(cls, inst: SelectorMixin, **kwargs) -> SelectorMixin: + """ + Same as :py:meth:`req` but overwrites specific arguments for instantiation that simplify requesting a different + selector instance. + + :param inst: The reference instance to request parameters from. + :param kwargs: Additional arguments forwarded to :py:meth:`req`. + :return: A new instance of *this* class. + """ + # selector_inst and known_shifts must be set to None to by-pass selector instance cache lookup and thus, also + # full parameter resolution + kwargs.setdefault("selector_inst", None) + kwargs.setdefault("known_shifts", None) + + return cls.req(inst, **kwargs) + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -594,9 +632,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.selector_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_selector_inst(self) -> None: + def teardown_selector_inst(self, **kwargs) -> None: if self.selector_inst: - self.selector_inst.run_teardown(task=self) + self.selector_inst.run_teardown(task=self, **kwargs) @property def selector_repr(self) -> str: @@ -674,21 +712,6 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"reducer"} return super().req_params(inst, **kwargs) - @property - def reducer_repr(self) -> str: - """ - Return a string representation of the reducer class. - """ - return self.build_repr(self.array_function_cls_repr(self.reducer)) - - def store_parts(self) -> law.util.InsertableDict: - """ - :return: Dictionary with parts that will be translated into an output directory path. - """ - parts = super().store_parts() - parts.insert_after(self.config_store_anchor, "reducer", f"red__{self.reducer_repr}") - return parts - @classmethod def get_config_lookup_keys( cls, @@ -703,10 +726,26 @@ def get_config_lookup_keys( else getattr(inst_or_params, "reducer", None) ) if reducer not in (law.NO_STR, None, ""): - keys["reducer"] = f"red_{reducer}" + prefix = "red" + keys[prefix] = f"{prefix}_{reducer}" return keys + @property + def reducer_repr(self) -> str: + """ + Return a string representation of the reducer class. + """ + return self.build_repr(self.array_function_cls_repr(self.reducer)) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "reducer", f"red__{self.reducer_repr}") + return parts + class ReducerMixin(ArrayFunctionInstanceMixin, ReducerClassMixin): """ @@ -780,6 +819,23 @@ def get_known_shifts( super().get_known_shifts(params, shifts) + @classmethod + def req_other_reducer(cls, inst: ReducerMixin, **kwargs) -> ReducerMixin: + """ + Same as :py:meth:`req` but overwrites specific arguments for instantiation that simplify requesting a different + reducer instance. + + :param inst: The reference instance to request parameters from. + :param kwargs: Additional arguments forwarded to :py:meth:`req`. + :return: A new instance of *this* class. + """ + # reducer_inst and known_shifts must be set to None to by-pass reducer instance cache lookup and thus, also full + # parameter resolution + kwargs.setdefault("reducer_inst", None) + kwargs.setdefault("known_shifts", None) + + return cls.req(inst, **kwargs) + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -791,9 +847,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.reducer_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_reducer_inst(self) -> None: + def teardown_reducer_inst(self, **kwargs) -> None: if self.reducer_inst: - self.reducer_inst.run_teardown(task=self) + self.reducer_inst.run_teardown(task=self, **kwargs) @property def reducer_repr(self) -> str: @@ -849,21 +905,6 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"producer"} return super().req_params(inst, **kwargs) - @property - def producer_repr(self) -> str: - """ - Return a string representation of the producer class. - """ - return self.build_repr(self.array_function_cls_repr(self.producer)) - - def store_parts(self) -> law.util.InsertableDict: - """ - :return: Dictionary with parts that will be translated into an output directory path. - """ - parts = super().store_parts() - parts.insert_after(self.config_store_anchor, "producer", f"prod__{self.producer_repr}") - return parts - @classmethod def get_config_lookup_keys( cls, @@ -878,10 +919,26 @@ def get_config_lookup_keys( else getattr(inst_or_params, "producer", None) ) if producer not in (law.NO_STR, None, ""): - keys["producer"] = f"prod_{producer}" + prefix = "prod" + keys[prefix] = f"{prefix}_{producer}" return keys + @property + def producer_repr(self) -> str: + """ + Return a string representation of the producer class. + """ + return self.build_repr(self.array_function_cls_repr(self.producer)) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "producer", f"prod__{self.producer_repr}") + return parts + class ProducerMixin(ArrayFunctionInstanceMixin, ProducerClassMixin): """ @@ -955,6 +1012,23 @@ def get_known_shifts( super().get_known_shifts(params, shifts) + @classmethod + def req_other_producer(cls, inst: ProducerMixin, **kwargs) -> ProducerMixin: + """ + Same as :py:meth:`req` but overwrites specific arguments for instantiation that simplify requesting a different + producer instance. + + :param inst: The reference instance to request parameters from. + :param kwargs: Additional arguments forwarded to :py:meth:`req`. + :return: A new instance of *this* class. + """ + # producer_inst and known_shifts must be set to None to by-pass producer instance cache lookup and thus, also + # full parameter resolution + kwargs.setdefault("producer_inst", None) + kwargs.setdefault("known_shifts", None) + + return cls.req(inst, **kwargs) + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -966,9 +1040,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.producer_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_producer_inst(self) -> None: + def teardown_producer_inst(self, **kwargs) -> None: if self.producer_inst: - self.producer_inst.run_teardown(task=self) + self.producer_inst.run_teardown(task=self, **kwargs) @property def producer_repr(self) -> str: @@ -1062,7 +1136,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "producers", None) ) if producers not in {law.NO_STR, None, "", ()}: - keys["producers"] = [f"prod_{producer}" for producer in producers] + prefix = "prod" + keys[prefix] = [f"{prefix}_{producer}" for producer in producers] return keys @@ -1206,21 +1281,8 @@ def ml_model_repr(self) -> str: @classmethod def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: - """ - Get the required parameters for the task, preferring the ``--ml-model`` set on task-level - via CLI. - - This method first checks if the ``--ml-model`` parameter is set at the task-level via the command line. If it - is, this parameter is preferred and added to the '_prefer_cli' key in the kwargs dictionary. The method then - calls the 'req_params' method of the superclass with the updated kwargs. - - :param inst: The current task instance. - :param kwargs: Additional keyword arguments that may contain parameters for the task. - :return: A dictionary of parameters required for the task. - """ # prefer --ml-model set on task-level via cli kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"ml_model"} - return super().req_params(inst, **kwargs) @classmethod @@ -1278,6 +1340,25 @@ def events_used_in_training( not shift_inst.has_tag("disjoint_from_nominal") ) + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: MLModelMixinBase | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) + + # add the ml model name + ml_model = ( + inst_or_params.get("ml_model") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "ml_model", None) + ) + if ml_model not in (law.NO_STR, None, ""): + prefix = "ml" + keys[prefix] = f"{prefix}_{ml_model}" + + return keys + class MLModelTrainingMixin( MLModelMixinBase, @@ -1474,9 +1555,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.preparation_producer_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_preparation_producer_inst(self) -> None: + def teardown_preparation_producer_inst(self, **kwargs) -> None: if self.preparation_producer_inst: - self.preparation_producer_inst.run_teardown(task=self) + self.preparation_producer_inst.run_teardown(task=self, **kwargs) @classmethod def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: @@ -1564,7 +1645,6 @@ def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any] def req_params(cls, inst: law.Task, **kwargs) -> dict: # prefer --ml-models set on task-level via cli kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"ml_models"} - return super().req_params(inst, **kwargs) @property @@ -1605,6 +1685,25 @@ def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: return columns + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: MLModelsMixin | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) + + # add the ml model names + ml_models = ( + inst_or_params.get("ml_models") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "ml_models", None) + ) + if ml_models not in {law.NO_STR, None, "", ()}: + prefix = "ml" + keys[prefix] = [f"{prefix}_{ml_model}" for ml_model in ml_models] + + return keys + class HistProducerClassMixin(ArrayFunctionClassMixin): """ @@ -1651,21 +1750,6 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"hist_producer"} return super().req_params(inst, **kwargs) - @property - def hist_producer_repr(self) -> str: - """ - Return a string representation of the hist producer class. - """ - return self.build_repr(self.array_function_cls_repr(self.hist_producer)) - - def store_parts(self) -> law.util.InsertableDict: - """ - :return: Dictionary with parts that will be translated into an output directory path. - """ - parts = super().store_parts() - parts.insert_after(self.config_store_anchor, "hist_producer", f"hist__{self.hist_producer_repr}") - return parts - @classmethod def get_config_lookup_keys( cls, @@ -1680,10 +1764,26 @@ def get_config_lookup_keys( else getattr(inst_or_params, "hist_producer", None) ) if producer not in (law.NO_STR, None, ""): - keys["hist_producer"] = f"hist_{producer}" + prefix = "hist" + keys[prefix] = f"{prefix}_{producer}" return keys + @property + def hist_producer_repr(self) -> str: + """ + Return a string representation of the hist producer class. + """ + return self.build_repr(self.array_function_cls_repr(self.hist_producer)) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "hist_producer", f"hist__{self.hist_producer_repr}") + return parts + class HistProducerMixin(ArrayFunctionInstanceMixin, HistProducerClassMixin): """ @@ -1760,6 +1860,23 @@ def get_known_shifts( super().get_known_shifts(params, shifts) + @classmethod + def req_other_hist_producer(cls, inst: HistProducerMixin, **kwargs) -> HistProducerMixin: + """ + Same as :py:meth:`req` but overwrites specific arguments for instantiation that simplify requesting a different + hist producer instance. + + :param inst: The reference instance to request parameters from. + :param kwargs: Additional arguments forwarded to :py:meth:`req`. + :return: A new instance of *this* class. + """ + # hist_producer_inst and known_shifts must be set to None to by-pass hist producer instance cache lookup and + # thus, also full parameter resolution + kwargs.setdefault("hist_producer_inst", None) + kwargs.setdefault("known_shifts", None) + + return cls.req(inst, **kwargs) + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -1771,9 +1888,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.hist_producer_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_hist_producer_inst(self) -> None: + def teardown_hist_producer_inst(self, **kwargs) -> None: if self.hist_producer_inst: - self.hist_producer_inst.run_teardown(task=self) + self.hist_producer_inst.run_teardown(task=self, **kwargs) @property def hist_producer_repr(self) -> str: @@ -1808,10 +1925,9 @@ def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any] return params @classmethod - def req_params(cls, inst: law.Task, **kwargs) -> dict: + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: # prefer --inference-model set on task-level via cli kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"inference_model"} - return super().req_params(inst, **kwargs) @property @@ -2267,7 +2383,24 @@ def resolve(config_inst: od.Config, processes: Any, datasets: Any) -> tuple[list deep=True, ) else: - processes = config_inst.processes.names() + processes = list(config_inst.processes.names()) + # protect against overlap between top-level processes + to_remove = defaultdict(set) + for process_name in processes: + process = config_inst.get_process(process_name) + # check any remaining process for overlap + for child_process_name in processes: + if child_process_name == process_name: + continue + if process.has_process(child_process_name, deep=True): + to_remove[child_process_name].add(process_name) + if to_remove: + processes = [process_name for process_name in processes if process_name not in to_remove] + for removed, reasons in to_remove.items(): + reasons = ", ".join(map("'{}'".format, reasons)) + logger.warning( + f"removed '{removed}' from selected processes due to overlap with {reasons}", + ) if not processes and not cls.allow_empty_processes: raise ValueError(f"no processes found matching {processes_orig}") if datasets != law.no_value: @@ -2489,6 +2622,9 @@ class ChunkedIOMixin(ConfigTask): description="when True, checks whether columns if input arrays overlap in at least one field", ) + # number of events per row group in the merged file + merging_row_group_size = law.config.get_expanded_int("analysis", "merging_row_group_size", 50_000) + exclude_params_req = {"check_finite_output", "check_overlapping_inputs"} # define default chunk and pool sizes that can be adjusted per inheriting task @@ -2498,18 +2634,25 @@ class ChunkedIOMixin(ConfigTask): @classmethod def raise_if_not_finite(cls, ak_array: ak.Array) -> None: """ - Checks whether all values in array *ak_array* are finite. + Checks whether values of all columns in *ak_array* are finite. String and bytestring types are skipped. The check is performed using the :external+numpy:py:func:`numpy.isfinite` function. - :param ak_array: Array with events to check. + :param ak_array: Array with columns to check. :raises ValueError: If any value in *ak_array* is not finite. """ - import numpy as np from columnflow.columnar_util import get_ak_routes for route in get_ak_routes(ak_array): - if ak.any(~np.isfinite(ak.flatten(route.apply(ak_array), axis=None))): + # flatten + flat = ak.flatten(route.apply(ak_array), axis=None) + # perform parameter dependent checks + if isinstance((params := getattr(getattr(flat, "layout", None), "parameters", None)), dict): + # skip string and bytestring arrays + if params.get("__array__") in {"string", "bytestring"}: + continue + # check finiteness + if ak.any(~np.isfinite(flat)): raise ValueError(f"found one or more non-finite values in column '{route.column}' of array {ak_array}") @classmethod @@ -2613,6 +2756,7 @@ def _get_hist_hook(self, name: str) -> Callable: def invoke_hist_hooks( self, hists: dict[od.Config, dict[od.Process, Any]], + hook_kwargs: dict | None = None, ) -> dict[od.Config, dict[od.Process, Any]]: """ Invoke hooks to modify histograms before further processing such as plotting. @@ -2634,7 +2778,7 @@ def invoke_hist_hooks( # invoke it self.publish_message(f"invoking hist hook '{hook}'") - hists = func(self, hists) + hists = func(self, hists, **(hook_kwargs or {})) return hists diff --git a/columnflow/tasks/framework/parameters.py b/columnflow/tasks/framework/parameters.py index c7c70318f..782b81562 100644 --- a/columnflow/tasks/framework/parameters.py +++ b/columnflow/tasks/framework/parameters.py @@ -14,6 +14,7 @@ from columnflow.util import try_float, try_complex, DotDict, Derivable from columnflow.types import Iterable, Any + user_parameter_inst = luigi.Parameter( default=getpass.getuser(), description="the user running the current task, mainly for central schedulers to distinguish " diff --git a/columnflow/tasks/framework/remote.py b/columnflow/tasks/framework/remote.py index fd719087b..da2c59c2c 100644 --- a/columnflow/tasks/framework/remote.py +++ b/columnflow/tasks/framework/remote.py @@ -48,6 +48,10 @@ class BundleRepo(AnalysisTask, law.git.BundleGitRepository, law.tasks.TransferLo os.environ["CF_CONDA_BASE"], ] + include_files = [ + "law_user.cfg", + ] + def get_repo_path(self): # required by BundleGitRepository return os.environ["CF_REPO_BASE"] @@ -343,13 +347,25 @@ def __post_init__(self): ) +_default_remove_claw_sandbox = law.config.get_expanded("analysis", "default_remote_claw_sandbox", None) or law.NO_STR + + class RemoteWorkflowMixin(AnalysisTask): """ Mixin class for custom remote workflows adding common functionality. """ + remote_claw_sandbox = luigi.Parameter( + default=_default_remove_claw_sandbox, + significant=False, + description="the name of a non-dev sandbox to use in remote jobs for the 'claw' executable rather than using " + f"using 'law' directly; not used when empty; default: {_default_remove_claw_sandbox}", + ) + skip_destination_info: bool = False + exclude_params_req = {"remote_claw_sandbox"} + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -577,6 +593,14 @@ def add_common_configs( render=False, ) + # claw sandbox + if self.remote_claw_sandbox not in {None, "", law.NO_STR}: + if self.remote_claw_sandbox.endswith("_dev"): + raise ValueError( + f"remote_claw_sandbox must not refer to a dev sandbox, got '{self.remote_claw_sandbox}'", + ) + config.render_variables["law_exe"] = f"CLAW_SANDBOX='{self.remote_claw_sandbox}' claw" + def common_destination_info(self, info: dict[str, str]) -> dict[str, str]: """ Hook to modify the additional info printed along logs of the workflow. @@ -641,6 +665,11 @@ def handle_scheduler_message(self, msg, _attr_value=None): input_unit="GB", unit="GB", ) +_default_htcondor_runtime = law.util.parse_duration( + law.config.get_expanded("analysis", "htcondor_runtime", 3.0), + input_unit="h", + unit="h", +) class HTCondorWorkflow(RemoteWorkflowMixin, law.htcondor.HTCondorWorkflow): @@ -650,11 +679,11 @@ class HTCondorWorkflow(RemoteWorkflowMixin, law.htcondor.HTCondorWorkflow): significant=False, description="transfer job logs to the output directory; default: True", ) - max_runtime = law.DurationParameter( - default=2.0, + htcondor_runtime = law.DurationParameter( + default=_default_htcondor_runtime, unit="h", significant=False, - description="maximum runtime; default unit is hours; default: 2", + description=f"maximum runtime; default unit is hours; default: {_default_htcondor_runtime}", ) htcondor_logs = luigi.BoolParameter( default=False, @@ -708,12 +737,12 @@ class HTCondorWorkflow(RemoteWorkflowMixin, law.htcondor.HTCondorWorkflow): # parameters that should not be passed to a workflow required upstream exclude_params_req_set = { - "max_runtime", "htcondor_cpus", "htcondor_gpus", "htcondor_memory", "htcondor_disk", + "htcondor_runtime", "htcondor_cpus", "htcondor_gpus", "htcondor_memory", "htcondor_disk", } # parameters that should not be passed from workflow to branches exclude_params_branch = { - "max_runtime", "htcondor_logs", "htcondor_cpus", "htcondor_gpus", "htcondor_memory", + "htcondor_runtime", "htcondor_logs", "htcondor_cpus", "htcondor_gpus", "htcondor_memory", "htcondor_disk", "htcondor_flavor", "htcondor_share_software", } @@ -743,7 +772,7 @@ def __init__(self, *args, **kwargs): self.bundle_repo_req = self.reqs.BundleRepo.req(self) # add scheduler message handlers - self.add_message_handler("max_runtime") + self.add_message_handler("htcondor_runtime") self.add_message_handler("htcondor_logs") self.add_message_handler("htcondor_cpus") self.add_message_handler("htcondor_gpus") @@ -801,6 +830,8 @@ def htcondor_job_config(self, config, job_num, branches): batch_name += f"_{info['config']}" if "dataset" in info: batch_name += f"_{info['dataset']}" + if "shift" in info: + batch_name += f"_{info['shift']}" config.custom_content.append(("batch_name", batch_name)) # CERN settings, https://batchdocs.web.cern.ch/local/submit.html#os-selection-via-containers @@ -821,8 +852,8 @@ def htcondor_job_config(self, config, job_num, branches): config.custom_content.append(("Request_OpSysAndVer", "\"RedHat9\"")) # maximum runtime, compatible with multiple batch systems - if self.max_runtime is not None and self.max_runtime > 0: - max_runtime = int(math.floor(self.max_runtime * 3600)) - 1 + if self.htcondor_runtime is not None and self.htcondor_runtime > 0: + max_runtime = int(math.floor(self.htcondor_runtime * 3600)) - 1 config.custom_content.append(("+MaxRuntime", max_runtime)) config.custom_content.append(("+RequestRuntime", max_runtime)) @@ -876,6 +907,11 @@ def htcondor_destination_info(self, info: dict[str, str]) -> dict[str, str]: _default_slurm_flavor = law.config.get_expanded("analysis", "slurm_flavor", "maxwell") _default_slurm_partition = law.config.get_expanded("analysis", "slurm_partition", "cms-uhh") +_default_slurm_runtime = law.util.parse_duration( + law.config.get_expanded("analysis", "slurm_runtime", 3.0), + input_unit="h", + unit="h", +) class SlurmWorkflow(RemoteWorkflowMixin, law.slurm.SlurmWorkflow): @@ -885,11 +921,11 @@ class SlurmWorkflow(RemoteWorkflowMixin, law.slurm.SlurmWorkflow): significant=False, description="transfer job logs to the output directory; default: True", ) - max_runtime = law.DurationParameter( - default=2.0, + slurm_runtime = law.DurationParameter( + default=_default_slurm_runtime, unit="h", significant=False, - description="maximum runtime; default unit is hours; default: 2", + description=f"maximum runtime; default unit is hours; default: {_default_slurm_runtime}", ) slurm_partition = luigi.Parameter( default=_default_slurm_partition, @@ -905,10 +941,10 @@ class SlurmWorkflow(RemoteWorkflowMixin, law.slurm.SlurmWorkflow): ) # parameters that should not be passed to a workflow required upstream - exclude_params_req_set = {"max_runtime"} + exclude_params_req_set = {"slurm_runtime"} # parameters that should not be passed from workflow to branches - exclude_params_branch = {"max_runtime", "slurm_partition", "slurm_flavor"} + exclude_params_branch = {"slurm_runtime", "slurm_partition", "slurm_flavor"} # mapping of environment variables to render variables that are forwarded slurm_forward_env_variables = { @@ -960,9 +996,9 @@ def slurm_job_config(self, config, job_num, branches): ) # set job time - if self.max_runtime is not None: + if self.slurm_runtime is not None and self.slurm_runtime > 0: job_time = law.util.human_duration( - seconds=int(math.floor(self.max_runtime * 3600)) - 1, + seconds=int(math.floor(self.slurm_runtime * 3600)) - 1, colon_format=True, ) config.custom_content.append(("time", job_time)) diff --git a/columnflow/tasks/histograms.py b/columnflow/tasks/histograms.py index 238243f69..f57e18b14 100644 --- a/columnflow/tasks/histograms.py +++ b/columnflow/tasks/histograms.py @@ -14,7 +14,6 @@ CalibratorClassesMixin, CalibratorsMixin, SelectorClassMixin, SelectorMixin, ReducerClassMixin, ReducerMixin, ProducerClassesMixin, ProducersMixin, VariablesMixin, DatasetShiftSourcesMixin, HistProducerClassMixin, HistProducerMixin, ChunkedIOMixin, MLModelsMixin, - # ParamsCacheMixin, ) from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.tasks.framework.parameters import last_edge_inclusive_inst @@ -22,19 +21,27 @@ from columnflow.tasks.reduction import ReducedEventsUser from columnflow.tasks.production import ProduceColumns from columnflow.tasks.ml import MLEvaluation +from columnflow.hist_util import update_ax_labels, sum_hists from columnflow.util import dev_sandbox +class VariablesMixinWorkflow( + VariablesMixin, + law.LocalWorkflow, + RemoteWorkflow, +): + + def control_output_postfix(self) -> str: + return f"{super().control_output_postfix()}__vars_{self.variables_repr}" + + class _CreateHistograms( - # ParamsCacheMixin ReducedEventsUser, ProducersMixin, MLModelsMixin, HistProducerMixin, ChunkedIOMixin, - VariablesMixin, - law.LocalWorkflow, - RemoteWorkflow, + VariablesMixinWorkflow, ): """ Base classes for :py:class:`CreateHistograms`. @@ -101,11 +108,13 @@ def workflow_requires(self): self.reqs.MLEvaluation.req(self, ml_model=ml_model_inst.cls_name) for ml_model_inst in self.ml_model_insts ] + elif self.producer_insts: + # pass-through pilot workflow requirements of upstream task + t = self.reqs.ProduceColumns.req(self) + law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) - # add hist_producer dependent requirements - reqs["hist_producer"] = law.util.make_unique(law.util.flatten( - self.hist_producer_inst.run_requires(task=self), - )) + # add hist producer dependent requirements + reqs["hist_producer"] = law.util.make_unique(law.util.flatten(self.hist_producer_inst.run_requires(task=self))) return reqs @@ -242,6 +251,10 @@ def run(self): events = attach_coffea_behavior(events) events, weight = self.hist_producer_inst(events, task=self) + if len(events) == 0: + self.publish_message(f"no events found in chunk {pos}") + continue + # merge category ids and check that they are defined as leaf categories category_ids = ak.concatenate( [Route(c).apply(events) for c in self.category_id_columns], @@ -341,9 +354,7 @@ class _MergeHistograms( ProducersMixin, MLModelsMixin, HistProducerMixin, - VariablesMixin, - law.LocalWorkflow, - RemoteWorkflow, + VariablesMixinWorkflow, ): """ Base classes for :py:class:`MergeHistograms`. @@ -421,10 +432,12 @@ def requires(self): ) def output(self): - return {"hists": law.SiblingFileCollection({ - variable_name: self.target(f"hist__var_{variable_name}.pickle") - for variable_name in self.variables - })} + return { + "hists": law.SiblingFileCollection({ + variable_name: self.target(f"hist__var_{variable_name}.pickle") + for variable_name in self.variables + }), + } @law.decorator.notify @law.decorator.log @@ -443,10 +456,13 @@ def run(self): variable_names = list(hists[0].keys()) for variable_name in self.iter_progress(variable_names, len(variable_names), reach=(50, 100)): self.publish_message(f"merging histograms for '{variable_name}'") + variable_hists = [h[variable_name] for h in hists] + + # update axis labels from variable insts for consistency + update_ax_labels(variable_hists, self.config_inst, variable_name) # merge them - variable_hists = [h[variable_name] for h in hists] - merged = sum(variable_hists[1:], variable_hists[0].copy()) + merged = sum_hists(variable_hists) # post-process the merged histogram merged = self.hist_producer_inst.run_post_process_merged_hist(merged, task=self) @@ -478,9 +494,7 @@ class _MergeShiftedHistograms( ProducerClassesMixin, MLModelsMixin, HistProducerClassMixin, - VariablesMixin, - law.LocalWorkflow, - RemoteWorkflow, + VariablesMixinWorkflow, ): """ Base classes for :py:class:`MergeShiftedHistograms`. @@ -507,9 +521,10 @@ def create_branch_map(self): def workflow_requires(self): reqs = super().workflow_requires() - # add nominal and both directions per shift source - for shift in ["nominal"] + self.shifts: - reqs[shift] = self.reqs.MergeHistograms.req(self, shift=shift, _prefer_cli={"variables"}) + if not self.pilot: + # add nominal and both directions per shift source + for shift in ["nominal"] + self.shifts: + reqs[shift] = self.reqs.MergeHistograms.req(self, shift=shift, _prefer_cli={"variables"}) return reqs @@ -535,17 +550,19 @@ def run(self): outputs = self.output()["hists"].targets for variable_name, outp in self.iter_progress(outputs.items(), len(outputs)): - self.publish_message(f"merging histograms for '{variable_name}'") + with self.publish_step(f"merging histograms for '{variable_name}' ..."): + # load hists + variable_hists = [ + coll["hists"].targets[variable_name].load(formatter="pickle") + for coll in inputs.values() + ] - # load hists - variable_hists = [ - coll["hists"].targets[variable_name].load(formatter="pickle") - for coll in inputs.values() - ] + # update axis labels from variable insts for consistency + update_ax_labels(variable_hists, self.config_inst, variable_name) - # merge and write the output - merged = sum(variable_hists[1:], variable_hists[0].copy()) - outp.dump(merged, formatter="pickle") + # merge and write the output + merged = sum_hists(variable_hists) + outp.dump(merged, formatter="pickle") MergeShiftedHistogramsWrapper = wrapper_factory( diff --git a/columnflow/tasks/inspection.py b/columnflow/tasks/inspection.py index 3d9a1ce3b..11b472577 100644 --- a/columnflow/tasks/inspection.py +++ b/columnflow/tasks/inspection.py @@ -26,17 +26,42 @@ def output(self): return {"always_incomplete_dummy": self.target("dummy.txt")} def run(self): + """ + Loads histograms for all configs, variables, and datasets, + sums them up for each variable and + slices them according to the processes, categories, and shift, + The resulting histograms are stored in a dictionary with variable names as keys. + If `debugger` is set to True, an IPython debugger session is started for + interactive inspection of the histograms. + """ + inputs = self.input() + shifts = {self.shift, "nominal"} hists = {} - for dataset in self.datasets: - for variable in self.variables: - h_in = self.load_histogram(dataset, variable) - h_in = self.slice_histogram(h_in, self.processes, self.categories, self.shift) + for variable in self.variables: + for i, config_inst in enumerate(self.config_insts): + hist_per_config = None + sub_processes = self.processes[i] + for dataset in self.datasets[i]: + # sum over all histograms of the same variable and config + if hist_per_config is None: + hist_per_config = self.load_histogram(inputs, config_inst, dataset, variable) + else: + hist_per_config += self.load_histogram(inputs, config_inst, dataset, variable) + + # slice histogram per config according to the sub_processes and categories + hist_per_config = self.slice_histogram( + histogram=hist_per_config, + config_inst=config_inst, + processes=sub_processes, + categories=self.categories, + shifts=shifts, + ) if variable in hists.keys(): - hists[variable] += h_in + hists[variable] += hist_per_config else: - hists[variable] = h_in + hists[variable] = hist_per_config if self.debugger: from IPython import embed @@ -57,18 +82,42 @@ def output(self): return {"always_incomplete_dummy": self.target("dummy.txt")} def run(self): + """ + Loads histograms for all configs, variables, and datasets, + sums them up for each variable and + slices them according to the processes, categories, and shift, + The resulting histograms are stored in a dictionary with variable names as keys. + If `debugger` is set to True, an IPython debugger session is started for + interactive inspection of the histograms. + """ + inputs = self.input() shifts = ["nominal"] + self.shifts hists = {} - for dataset in self.datasets: - for variable in self.variables: - h_in = self.load_histogram(dataset, variable) - h_in = self.slice_histogram(h_in, self.processes, self.categories, shifts) + for variable in self.variables: + for i, config_inst in enumerate(self.config_insts): + hist_per_config = None + sub_processes = self.processes[i] + for dataset in self.datasets[i]: + # sum over all histograms of the same variable and config + if hist_per_config is None: + hist_per_config = self.load_histogram(inputs, config_inst, dataset, variable) + else: + hist_per_config += self.load_histogram(inputs, config_inst, dataset, variable) + + # slice histogram per config according to the sub_processes and categories + hist_per_config = self.slice_histogram( + histogram=hist_per_config, + config_inst=config_inst, + processes=sub_processes, + categories=self.categories, + shifts=shifts, + ) if variable in hists.keys(): - hists[variable] += h_in + hists[variable] += hist_per_config else: - hists[variable] = h_in + hists[variable] = hist_per_config if self.debugger: from IPython import embed diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 92cf13b9f..5ae06ec18 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -35,7 +35,6 @@ from columnflow.util import dev_sandbox, safe_div, DotDict, maybe_import from columnflow.columnar_util import set_ak_column - ak = maybe_import("awkward") @@ -808,7 +807,12 @@ def run(self): # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( - self, sorted_chunks, output["mlcolumns"], local=True, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=sorted_chunks, + output=output["mlcolumns"], + local=True, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) diff --git a/columnflow/tasks/plotting.py b/columnflow/tasks/plotting.py index 5bdf58e7e..ba91c0492 100644 --- a/columnflow/tasks/plotting.py +++ b/columnflow/tasks/plotting.py @@ -18,7 +18,6 @@ from columnflow.tasks.framework.mixins import ( CalibratorClassesMixin, SelectorClassMixin, ReducerClassMixin, ProducerClassesMixin, HistProducerClassMixin, CategoriesMixin, ShiftSourcesMixin, HistHookMixin, MLModelsMixin, - # ParamsCacheMixin, ) from columnflow.tasks.framework.plotting import ( PlotBase, PlotBase1D, PlotBase2D, ProcessPlotSettingMixin, VariablePlotSettingMixin, @@ -32,7 +31,6 @@ class _PlotVariablesBase( - # ParamsCacheMixin CalibratorClassesMixin, SelectorClassMixin, ReducerClassMixin, @@ -52,10 +50,20 @@ class _PlotVariablesBase( class PlotVariablesBase(_PlotVariablesBase): + + bypass_branch_requirements = luigi.BoolParameter( + default=False, + description="whether to skip branch requirements and only use that of the workflow; default: False", + ) + single_config = False sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + exclude_params_repr = {"bypass_branch_requirements"} + exclude_params_index = {"bypass_branch_requirements"} + exclude_params_repr = {"bypass_branch_requirements"} + exclude_index = True def store_parts(self) -> law.util.InsertableDict: @@ -75,6 +83,13 @@ def workflow_requires(self): reqs["merged_hists"] = self.requires_from_branch() return reqs + def local_workflow_pre_run(self): + # when branches are cached, reinitiate the branch tasks with dropped branch level requirements since this + # method is called from a context where the identical workflow level requirements are already resolved + if self.cache_branch_map: + self._branch_tasks = None + self.get_branch_tasks(bypass_branch_requirements=True) + @abstractmethod def get_plot_shifts(self): return @@ -94,7 +109,7 @@ def get_config_process_map(self) -> tuple[dict[od.Config, dict[od.Process, dict[ dictionaries containing the dataset-process mapping and the shifts to be considered, and a dictionary mapping process names to the shifts to be considered. """ - reqs = self.requires() + reqs = self.requires() or self.as_workflow().requires().merged_hists config_process_map = {config_inst: {} for config_inst in self.config_insts} process_shift_map = defaultdict(set) @@ -161,12 +176,13 @@ def run(self): plot_shift_names = set(shift_inst.name for shift_inst in plot_shifts) # get assignment of processes to datasets and shifts - config_process_map, _ = self.get_config_process_map() + config_process_map, process_shift_map = self.get_config_process_map() # histogram data per process copy hists: dict[od.Config, dict[od.Process, hist.Hist]] = {} with self.publish_step(f"plotting {self.branch_data.variable} in {self.branch_data.category}"): - for i, (config, dataset_dict) in enumerate(self.input().items()): + inputs = self.input() or self.workflow_input().merged_hists + for i, (config, dataset_dict) in enumerate(inputs.items()): config_inst = self.config_insts[i] category_inst = config_inst.get_category(self.branch_data.category) leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] @@ -194,7 +210,10 @@ def run(self): h = h[{"process": sum}] # create expected shift bins and fill them with the nominal histogram - add_missing_shifts(h, plot_shift_names, str_axis="shift", nominal_bin="nominal") + # change Ghent: replace all expected shifts with nominal. + # not preffered by columnflow: https://github.com/columnflow/columnflow/pull/692 + expected_shifts = plot_shift_names # & process_shift_map[process_inst.name] + add_missing_shifts(h, expected_shifts, str_axis="shift", nominal_bin="nominal") # add the histogram if process_inst in hists_config: @@ -219,7 +238,10 @@ def run(self): ) # update histograms using custom hooks - hists = self.invoke_hist_hooks(hists) + hists = self.invoke_hist_hooks( + hists, + hook_kwargs={"category_name": self.branch_data.category, "variable_name": self.branch_data.variable}, + ) # merge configs if len(self.config_insts) != 1: @@ -243,6 +265,15 @@ def run(self): _hists = OrderedDict() for process_inst in hists.keys(): h = hists[process_inst] + # determine expected shifts from the intersection of requested shifts and those known for the process + process_shifts = ( + process_shift_map[process_inst.name] + if process_inst.name in process_shift_map + else {"nominal"} + ) + expected_shifts = (process_shifts & plot_shift_names) or (process_shifts & {"nominal"}) + if not expected_shifts: + raise Exception(f"no shifts to plot found for process {process_inst.name}") # selections h = h[{ "category": [ @@ -252,7 +283,7 @@ def run(self): ], "shift": [ hist.loc(s_name) - for s_name in plot_shift_names + for s_name in expected_shifts if s_name in h.axes["shift"] ], }] @@ -300,6 +331,7 @@ class PlotVariablesBaseSingleShift( ): # use the MergeHistograms task to trigger upstream TaskArrayFunction initialization resolution_task_cls = MergeHistograms + exclude_index = True reqs = Requirements( @@ -314,28 +346,27 @@ def create_branch_map(self): for cat_name in sorted(self.categories) ] - def workflow_requires(self): - reqs = super().workflow_requires() - return reqs - def requires(self): - req = {} + reqs = {} - for i, config_inst in enumerate(self.config_insts): - sub_datasets = self.datasets[i] - req[config_inst.name] = {} - for d in sub_datasets: - if d in config_inst.datasets.names(): - req[config_inst.name][d] = self.reqs.MergeHistograms.req( - self, - config=config_inst.name, - shift=self.global_shift_insts[config_inst].name, - dataset=d, - branch=-1, - _exclude={"branches"}, - _prefer_cli={"variables"}, - ) - return req + if self.is_branch() and self.bypass_branch_requirements: + return reqs + + for config_inst, datasets in zip(self.config_insts, self.datasets): + reqs[config_inst.name] = {} + for d in datasets: + if d not in config_inst.datasets: + continue + reqs[config_inst.name][d] = self.reqs.MergeHistograms.req_different_branching( + self, + config=config_inst.name, + shift=self.global_shift_insts[config_inst].name, + dataset=d, + branch=-1, + _prefer_cli={"variables"}, + ) + + return reqs def plot_parts(self) -> law.util.InsertableDict: parts = super().plot_parts() @@ -474,26 +505,32 @@ def create_branch_map(self) -> list[DotDict]: return [DotDict(zip(keys, vals)) for vals in itertools.product(*seqs)] def requires(self): - req_cls = lambda dataset_name: ( + reqs = {} + + if self.is_branch() and self.bypass_branch_requirements: + return reqs + + req_cls = lambda dataset_name, config_inst: ( self.reqs.MergeShiftedHistograms - if self.config_inst.get_dataset(dataset_name).is_mc + if config_inst.get_dataset(dataset_name).is_mc else self.reqs.MergeHistograms ) - req = {} - for i, config_inst in enumerate(self.config_insts): - req[config_inst.name] = {} - for dataset_name in self.datasets[i]: - if dataset_name in config_inst.datasets: - req[config_inst.name][dataset_name] = req_cls(dataset_name).req( - self, - config=config_inst.name, - dataset=dataset_name, - branch=-1, - _exclude={"branches"}, - _prefer_cli={"variables"}, - ) - return req + for config_inst, datasets in zip(self.config_insts, self.datasets): + reqs[config_inst.name] = {} + for d in datasets: + if d not in config_inst.datasets: + continue + reqs[config_inst.name][d] = req_cls(d, config_inst).req( + self, + config=config_inst.name, + dataset=d, + branch=-1, + _exclude={"branches"}, + _prefer_cli={"variables"}, + ) + + return reqs def plot_parts(self) -> law.util.InsertableDict: parts = super().plot_parts() @@ -565,8 +602,8 @@ class PlotShiftedVariablesPerShift1D( class PlotShiftedVariablesPerConfig1D( - law.WrapperTask, PlotShiftedVariables1D, + law.WrapperTask, ): # force this one to be a local workflow workflow = "local" diff --git a/columnflow/tasks/production.py b/columnflow/tasks/production.py index cee2d1311..bb5f8d090 100644 --- a/columnflow/tasks/production.py +++ b/columnflow/tasks/production.py @@ -50,9 +50,7 @@ def workflow_requires(self): reqs["events"] = self.reqs.ProvideReducedEvents.req(self) # add producer dependent requirements - reqs["producer"] = law.util.make_unique(law.util.flatten( - self.producer_inst.run_requires(task=self), - )) + reqs["producer"] = law.util.make_unique(law.util.flatten(self.producer_inst.run_requires(task=self))) return reqs @@ -168,7 +166,12 @@ def run(self): # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( - self, sorted_chunks, output["columns"], local=True, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=sorted_chunks, + output=output["columns"], + local=True, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) @@ -199,6 +202,7 @@ class ProduceColumnsWrapper(_ProduceColumnsWrapperBase): producers = law.CSVParameter( default=(), description="names of producers to use; if empty, the default producer is used", + brace_expand=True, ) def __init__(self, *args, **kwargs): diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index 54098b8ab..1d62a5d71 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -13,10 +13,7 @@ import luigi from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory -from columnflow.tasks.framework.mixins import ( - CalibratorsMixin, SelectorMixin, ReducerMixin, ChunkedIOMixin, - # ParamsCacheMixin, -) +from columnflow.tasks.framework.mixins import CalibratorsMixin, SelectorMixin, ReducerMixin, ChunkedIOMixin from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.tasks.framework.decorators import on_failure from columnflow.tasks.external import GetDatasetLFNs @@ -32,7 +29,6 @@ class _ReduceEvents( - # ParamsCacheMixin, CalibratorsMixin, SelectorMixin, ReducerMixin, @@ -78,14 +74,13 @@ def workflow_requires(self): if calibrator_inst.produced_columns ] reqs["selection"] = self.reqs.SelectEvents.req(self) - # reducer dependent requirements - reqs["reducer"] = law.util.make_unique(law.util.flatten( - self.reducer_inst.run_requires(task=self), - )) else: # pass-through pilot workflow requirements of upstream task t = self.reqs.SelectEvents.req(self) - reqs = law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) + law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) + + # add reducer dependent requirements + reqs["reducer"] = law.util.make_unique(law.util.flatten(self.reducer_inst.run_requires(task=self))) return reqs @@ -199,6 +194,7 @@ def run(self): [inp.abspath for inp in inps], source_type=["coffea_root"] + (len(inps) - 1) * ["awkward_parquet"], read_columns=[read_columns, read_sel_columns] + (len(inps) - 2) * [read_columns], + chunk_size=self.reducer_inst.get_min_chunk_size(), ): # optional check for overlapping inputs within diffs if self.check_overlapping_inputs: @@ -216,12 +212,16 @@ def run(self): ) # invoke the reducer - if len(events): + if len(events) > 0: n_all += len(events) events = attach_coffea_behavior(events) events = self.reducer_inst(events, selection=sel, task=self) n_reduced += len(events) + # no need to proceed when no events are left (except for the last chunk to create empty output) + if len(events) == 0 and (output_chunks or pos.index < pos.n_chunks - 1): + continue + # remove columns events = route_filter(events) @@ -243,7 +243,12 @@ def run(self): # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( - self, sorted_chunks, output["events"], local=True, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=sorted_chunks, + output=output["events"], + local=True, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) @@ -284,16 +289,15 @@ class MergeReductionStats(_MergeReductionStats): n_inputs = luigi.IntParameter( default=10, significant=True, - description="minimal number of input files for sufficient statistics to infer merging " - "factors; default: 10", + description="minimal number of input files to infer merging factors with sufficient statistics; default: 10", ) merged_size = law.BytesParameter( default=law.NO_FLOAT, unit="MB", significant=False, - description="the maximum file size of merged files; default unit is MB; when 0, the " - "merging factor is not actually calculated from input files, but it is assumed to be 1 " - "(= no merging); default: config value 'reduced_file_size' or 512MB'", + description="the maximum file size of merged files; default unit is MB; when 0, the merging factor is not " + "actually calculated from input files, but it is assumed to be 1 (= no merging); default: config value " + "'reduced_file_size' or 512MB", ) # upstream requirements @@ -306,6 +310,12 @@ class MergeReductionStats(_MergeReductionStats): def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params = super().resolve_param_values(params) + # cap n_inputs + if "n_inputs" in params and (dataset_info_inst := params.get("dataset_info_inst")): + n_files = dataset_info_inst.n_files + if params["n_inputs"] < 0 or params["n_inputs"] > n_files: + params["n_inputs"] = n_files + # check for the default merged size if "merged_size" in params: if params["merged_size"] in {None, law.NO_FLOAT}: @@ -411,7 +421,7 @@ def get_avg_std(values): self.publish_message(f" stats of {n} input files ".center(40, "-")) self.publish_message(f"average size: {law.util.human_bytes(stats['avg_size'], fmt=True)}") deviation = stats["std_size"] / stats["avg_size"] - self.publish_message(f"deviation : {deviation * 100:.2f} % (std / avg)") + self.publish_message(f"deviation : {deviation * 100:.2f}% (std/avg)") self.publish_message(" merging info ".center(40, "-")) self.publish_message(f"target size : {self.merged_size} MB") self.publish_message(f"merging : {stats['merge_factor']} into 1") @@ -457,6 +467,9 @@ class MergeReducedEvents(_MergeReducedEvents): ReduceEvents=ReduceEvents, ) + # number of events per row group in the merged file + merging_row_group_size = law.config.get_expanded_int("analysis", "merging_row_group_size", 50_000) + @law.workflow_property(setter=True, cache=True, empty_value=0) def file_merging(self): # check if the merging stats are present @@ -476,7 +489,10 @@ def create_branch_map(self): def workflow_requires(self): reqs = super().workflow_requires() reqs["stats"] = self.reqs.MergeReductionStats.req_different_branching(self) - reqs["events"] = self.reqs.ReduceEvents.req_different_branching(self, branches=((0, -1),)) + reqs["events"] = self.reqs.ReduceEvents.req_different_branching( + self, + branches=((0, self.dataset_info_inst.n_files),), + ) return reqs def requires(self): @@ -501,8 +517,14 @@ def run(self): inputs = [inp["events"] for inp in self.input()["events"].collection.targets.values()] output = self.output()["events"] + # merge law.pyarrow.merge_parquet_task( - self, inputs, output, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=inputs, + output=output, + callback=self.create_progress_callback(len(inputs)), + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) # optionally remove initial inputs diff --git a/columnflow/tasks/selection.py b/columnflow/tasks/selection.py index 9e8d5dab7..cc2b32fae 100644 --- a/columnflow/tasks/selection.py +++ b/columnflow/tasks/selection.py @@ -13,10 +13,7 @@ from columnflow.types import Any from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory -from columnflow.tasks.framework.mixins import ( - CalibratorsMixin, SelectorMixin, ChunkedIOMixin, ProducerMixin, - # ParamsCacheMixin, -) +from columnflow.tasks.framework.mixins import CalibratorsMixin, SelectorMixin, ChunkedIOMixin, ProducerMixin from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.tasks.framework.decorators import on_failure from columnflow.tasks.external import GetDatasetLFNs @@ -25,7 +22,6 @@ from columnflow.tasks.framework.parameters import DerivableInstParameter from columnflow.production import Producer - np = maybe_import("numpy") ak = maybe_import("awkward") @@ -94,12 +90,10 @@ def workflow_requires(self): elif self.calibrator_insts: # pass-through pilot workflow requirements of upstream task t = self.reqs.CalibrateEvents.req(self) - reqs = law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) + law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) # add selector dependent requirements - reqs["selector"] = law.util.make_unique(law.util.flatten( - self.selector_inst.run_requires(task=self), - )) + reqs["selector"] = law.util.make_unique(law.util.flatten(self.selector_inst.run_requires(task=self))) return reqs @@ -190,6 +184,7 @@ def run(self): # get shift dependent aliases aliases = self.local_shift_inst.x("column_aliases", {}) + # define columns that need to be read read_columns = set(map(Route, mandatory_coffea_columns)) read_columns |= self.selector_inst.used_columns @@ -287,14 +282,24 @@ def run(self): sorted_chunks = [result_chunks[key] for key in sorted(result_chunks)] writer_opts_masks = self.get_parquet_writer_opts(repeating_values=True) law.pyarrow.merge_parquet_task( - self, sorted_chunks, outputs["results"], local=True, writer_opts=writer_opts_masks, + task=self, + inputs=sorted_chunks, + output=outputs["results"], + local=True, + writer_opts=writer_opts_masks, + target_row_group_size=self.merging_row_group_size, ) # merge the column files if write_columns: sorted_chunks = [column_chunks[key] for key in sorted(column_chunks)] law.pyarrow.merge_parquet_task( - self, sorted_chunks, outputs["columns"], local=True, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=sorted_chunks, + output=outputs["columns"], + local=True, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) # save stats diff --git a/columnflow/tasks/union.py b/columnflow/tasks/union.py index 5b52d22b3..e73784175 100644 --- a/columnflow/tasks/union.py +++ b/columnflow/tasks/union.py @@ -4,6 +4,8 @@ Task to unite columns horizontally into a single file for further, possibly external processing. """ +from __future__ import annotations + import luigi import law @@ -13,7 +15,9 @@ from columnflow.tasks.reduction import ReducedEventsUser from columnflow.tasks.production import ProduceColumns from columnflow.tasks.ml import MLEvaluation +from columnflow.columnar_util import Route from columnflow.util import dev_sandbox +from columnflow.types import Callable class _UniteColumns( @@ -33,6 +37,12 @@ class UniteColumns(_UniteColumns): sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + keep_columns_key = luigi.Parameter( + default=law.NO_STR, + description="if the 'keep_columns' auxiliary config entry for the task family 'cf.UniteColumns' is defined as " + "a dictionary, this key can selects which of the entries of columns to use; uses all columns when empty; " + "default: empty", + ) file_type = luigi.ChoiceParameter( default="parquet", choices=("parquet", "root"), @@ -47,6 +57,9 @@ class UniteColumns(_UniteColumns): MLEvaluation=MLEvaluation, ) + # a column that is evaluated to decide whether to keep or drop an event before writing + filter_events: str | Route | Callable | None = None + def workflow_requires(self): reqs = super().workflow_requires() @@ -97,7 +110,8 @@ def requires(self): @workflow_condition.output def output(self): - return {"events": self.target(f"data_{self.branch}.{self.file_type}")} + key_postfix = "" if self.keep_columns_key in {law.NO_STR, "", None} else f"_{self.keep_columns_key}" + return {"events": self.target(f"data_{self.branch}{key_postfix}.{self.file_type}")} @law.decorator.notify @law.decorator.log @@ -105,8 +119,7 @@ def output(self): @law.decorator.safe_output def run(self): from columnflow.columnar_util import ( - Route, RouteFilter, mandatory_coffea_columns, update_ak_array, sorted_ak_to_parquet, - sorted_ak_to_root, + RouteFilter, mandatory_coffea_columns, update_ak_array, sorted_ak_to_parquet, sorted_ak_to_root, ) # prepare inputs and outputs @@ -121,7 +134,18 @@ def run(self): # define columns that will be written write_columns: set[Route] = set() skip_columns: set[Route] = set() - for c in self.config_inst.x.keep_columns.get(self.task_family, ["*"]): + keep_struct = self.config_inst.x.keep_columns.get(self.task_family, ["*"]) + if isinstance(keep_struct, dict): + if self.keep_columns_key not in {law.NO_STR, "", None}: + if self.keep_columns_key not in keep_struct: + raise KeyError( + f"keep_columns_key '{self.keep_columns_key}' not found in keep_columns config entry for " + f"task family '{self.task_family}', existing keys: {list(keep_struct.keys())}", + ) + keep_struct = keep_struct[self.keep_columns_key] + else: + keep_struct = law.util.flatten(keep_struct.values()) + for c in law.util.make_unique(keep_struct): for r in self._expand_keep_column(c): if r.has_tag("skip"): skip_columns.add(r) @@ -152,6 +176,16 @@ def run(self): # add additional columns events = update_ak_array(events, *columns) + # optionally filter events + if self.filter_events: + if callable(self.filter_events): + filter_func = self.filter_events + else: + r = Route(self.filter_events) + filter_func = r.apply + mask = filter_func(events) + events = events[mask] + # remove columns events = route_filter(events) @@ -174,7 +208,7 @@ def run(self): self, sorted_chunks, output["events"], local=True, writer_opts=self.get_parquet_writer_opts(), ) else: # root - law.root.hadd_task(self, sorted_chunks, output["events"], local=True) + law.root.hadd_task(self, sorted_chunks, output["events"], local=True, hadd_args=["-O", "-f501"]) # overwrite class defaults diff --git a/columnflow/types.py b/columnflow/types.py index cfe437207..05d783daa 100644 --- a/columnflow/types.py +++ b/columnflow/types.py @@ -22,8 +22,8 @@ from collections.abc import KeysView, ValuesView # noqa from types import ModuleType, GeneratorType, GenericAlias # noqa from typing import ( # noqa - Any, Union, TypeVar, ClassVar, Sequence, Callable, Generator, TextIO, Iterable, Hashable, - Type, + TYPE_CHECKING, Any, Union, TypeVar, ClassVar, Sequence, Callable, Generator, TextIO, Iterable, Hashable, Type, + Literal, ) from typing_extensions import Annotated, _AnnotatedAlias as AnnotatedType, TypeAlias # noqa diff --git a/columnflow/util.py b/columnflow/util.py index 97d4c3b73..2db95eaf8 100644 --- a/columnflow/util.py +++ b/columnflow/util.py @@ -25,13 +25,12 @@ from functools import wraps from collections import OrderedDict from typing import Iterable -# from typing import Hashable, Callable import law import luigi from columnflow import env_is_dev, env_is_remote, docs_url, github_url -from columnflow.types import Callable, Any, Sequence, Union, ModuleType +from columnflow.types import Callable, Any, Sequence, Union, ModuleType, Type, T, Hashable #: Placeholder for an unset value. @@ -996,6 +995,31 @@ def derived_by(cls, other: DerivableMeta) -> bool: return isinstance(other, DerivableMeta) and issubclass(other, cls) +class CachedDerivableMeta(DerivableMeta): + + def __new__(metacls, cls_name: str, bases: tuple, cls_dict: dict) -> CachedDerivableMeta: + # add an instance cache if not disabled + cls_dict.setdefault("cache_instances", True) + cls_dict["_instances"] = {} if cls_dict["cache_instances"] else None + + return super().__new__(metacls, cls_name, bases, cls_dict) + + def __call__(cls: Type[T], *args, **kwargs) -> T: + # when not caching instances, return right away + if not cls.cache_instances: + return super().__call__(*args, **kwargs) + + # build the cache key from the inst_dict in kwargs + key = cls._get_inst_cache_key(args, kwargs) + if key not in cls._instances: + cls._instances[key] = super().__call__(*args, **kwargs) + + return cls._instances[key] + + def _get_inst_cache_key(cls, args: tuple, kwargs: dict) -> Hashable: + raise NotImplementedError("__get_inst_cache_key method must be implemented by the derived meta class") + + class Derivable(object, metaclass=DerivableMeta): """ Derivable base class with features provided by the meta :py:class:`DerivableMeta`. diff --git a/docs/user_guide/best_practices.md b/docs/user_guide/best_practices.md index 4390319ab..150c09bc5 100644 --- a/docs/user_guide/best_practices.md +++ b/docs/user_guide/best_practices.md @@ -22,38 +22,41 @@ For convenience, if no file system with that name was defined, `LOCAL_FS_NAME` i - `wlcg, WLCG_FS_NAME` refers to a specific remote storage system named `WLCG_FS_NAME` that should be defined in the `law.cfg` file. `TASK_IDENTIFIER` identifies the task the location should apply to. -It can be a simple task family such as `cf.CalibrateEvents`, but for larger analyses a more fine grained selection is required. +It can be a simple task family such as `task_cf.CalibrateEvents` (see the format below), but for larger analyses a more fine grained selection is required. For this purpose, `TASK_IDENTIFIER` can be a `__`-separated sequence of so-called lookup keys, e.g. ```ini [outputs] -run3_23__cf.CalibrateEvents__nominal: wlcg, wlcg_fs_run3_23 +cfg_run3_23__task_cf.CalibrateEvents__shift_nominal: wlcg, wlcg_fs_run3_23 ``` Here, three keys are defined, making use of the config name, the task family, and the name of a systematic shift. The exact selection of possible keys and their resolution order is defined by the task itself in {py:meth}:`~columnflow.tasks.framework.base.AnalysisTask.get_config_lookup_keys` (and subclasses). Most tasks, however, define their lookup keys as: -1. analysis name -2. config name -3. task family -4. dataset name -5. shift name +1. analysis name, prefixed by `ana_` +2. config name, prefixed by `cfg_` +3. task family, prefixed by `task_` +4. dataset name, prefixed by `dataset_` +5. shift name, prefixed by `shift_` 6. calibrator name, prefixed by `calib_` 7. selector name, prefixed by `sel_` -8. producer name, prefixed by `prod_` +8. reducer name, prefixed by `red_` +9. producer name, prefixed by `prod_` +10. ml model name, prefixed by `ml_` +11. hist producer name, prefixed by `hist_` When defining `TASK_IDENTIFIER`'s, not all keys need to be specified, and patterns or regular expressions (`^EXPR$`) can be used. -The definition order is **important** as the first matching definition is used. +The definition order in the config file is **important** as the first matching definition is used. This way, output locations are highly customizable. ```ini [outputs] # store all run3 outputs on a specific fs, and all other outputs locally -run3_*__cf.CalibrateEvents: wlcg, wlcg_fs_run3 -cf.CalibrateEvents: local +cfg_run3_*__task_cf.CalibrateEvents: wlcg, wlcg_fs_run3 +task_cf.CalibrateEvents: local ``` ## Controlling versions of upstream tasks @@ -90,18 +93,18 @@ Consider the following two examples for defining versions, one via auxiliary con ```python cfg.x.versions = { - "run3_*": { - "cf.CalibrateEvents": "v2", + "cfg_run3_*": { + "task_cf.CalibrateEvents": "v2", }, - "cf.CalibrateEvents": "v1", + "task_cf.CalibrateEvents": "v1", } ``` ```ini [versions] -run3_*__cf.CalibrateEvents: v2 -cf.CalibrateEvents: v1 +cfg_run3_*__task_cf.CalibrateEvents: v2 +task_cf.CalibrateEvents: v1 ``` They are **equivalent** since the `__`-separated `TASK_IDENTIFIER`'s in the `law.cfg` are internallly converted to the same nested dictionary structure. diff --git a/docs/user_guide/examples.md b/docs/user_guide/examples.md index 7d9ea5022..09849eeab 100644 --- a/docs/user_guide/examples.md +++ b/docs/user_guide/examples.md @@ -267,11 +267,10 @@ lfn_sources: local_dcache # output locations per task family # for local targets : "local[, STORE_PATH]" # for remote targets: "wlcg[, WLCG_FS_NAME]" -cf.Task1: local -cf.Task2: local, /shared/path/to/store/output -cf.Task3: /shared/path/to/store/output +task_cf.Task1: local +task_cf.Task2: local, /shared/path/to/store/output +task_cf.Task3: /shared/path/to/store/output ... - ``` It is important to redirect the setup to the custom config file by setting the ```LAW_CONFIG_FILE``` environment variable in the `setup.sh` file to the path of the custom config file as follows: diff --git a/law.cfg b/law.cfg index 95010ce73..2ba8f2ba3 100644 --- a/law.cfg +++ b/law.cfg @@ -51,6 +51,9 @@ default_create_selection_hists: True # wether or not the ensure_proxy decorator should be skipped, even if used by task's run methods skip_ensure_proxy: False +# the name of a sandbox to use for tasks in remote jobs initially (invoked with claw when set) +default_remote_claw_sandbox: None + # some remote workflow parameter defaults # (resources like memory and disk can also be set in [resources] with more granularity) htcondor_flavor: $CF_HTCONDOR_FLAVOR @@ -65,6 +68,9 @@ chunked_io_chunk_size: 100000 chunked_io_pool_size: 2 chunked_io_debug: True +# settings for merging parquet files in several locations +merging_row_group_size: 50000 + # csv list of task families that inherit from ChunkedReaderMixin and whose output arrays should be # checked (raising an exception) for non-finite values before saving them to disk # supported tasks are: cf.CalibrateEvents, cf.SelectEvents, cf.ReduceEvents, cf.ProduceColumns, @@ -269,4 +275,4 @@ wait_interval: 20 check_unfulfilled_deps: False cache_task_completion: True keep_alive: $CF_WORKER_KEEP_ALIVE -force_multiprocessing: False +force_multiprocessing: $CF_REMOTE_ENV diff --git a/modules/law b/modules/law index a02aeb3c2..73ff4fd52 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit a02aeb3c2cf7cb460e52f67490f10c50055c6606 +Subproject commit 73ff4fd520ddecff5288ee3804aa4b4c8d929858 diff --git a/sandboxes/_setup_cmssw.sh b/sandboxes/_setup_cmssw.sh index 8e874b60c..f872a5e34 100644 --- a/sandboxes/_setup_cmssw.sh +++ b/sandboxes/_setup_cmssw.sh @@ -234,12 +234,14 @@ setup_cmssw() { if command -v cf_cmssw_custom_install &> /dev/null; then echo -e "\nrunning cf_cmssw_custom_install" cf_cmssw_custom_install && - cd "${install_path}/src" && + source "/cvmfs/cms.cern.ch/cmsset_default.sh" "" && + cd "${install_path}/src" && scram b elif [ ! -z "${cf_cmssw_custom_install}" ] && [ -f "${cf_cmssw_custom_install}" ]; then echo -e "\nsourcing cf_cmssw_custom_install file" source "${cf_cmssw_custom_install}" "" && - cd "${install_path}/src" && + source "/cvmfs/cms.cern.ch/cmsset_default.sh" "" && + cd "${install_path}/src" && scram b fi ) diff --git a/sandboxes/_setup_venv.sh b/sandboxes/_setup_venv.sh index 31e8e8fd4..fc45460e1 100644 --- a/sandboxes/_setup_venv.sh +++ b/sandboxes/_setup_venv.sh @@ -248,7 +248,9 @@ setup_venv() { # install if not existing if [ ! -f "${CF_SANDBOX_FLAG_FILE}" ]; then - cf_color cyan "installing venv ${CF_VENV_NAME} from ${sandbox_file} at ${install_path}" + echo -n "$( cf_color cyan "installing venv" )" + echo -n " $( cf_color cyan_bright "${CF_VENV_NAME}" )" + echo " $( cf_color cyan "from ${sandbox_file} at ${install_path}" )" rm -rf "${install_path}" cf_create_venv "${venv_name_hashed}" diff --git a/sandboxes/cf.txt b/sandboxes/cf.txt index 46861d8c1..7c910e0ef 100644 --- a/sandboxes/cf.txt +++ b/sandboxes/cf.txt @@ -1,8 +1,7 @@ -# version 14 +# version 15 luigi~=3.6.0 -scinum~=2.2.0 +scinum~=2.2.1 six~=1.17.0 pyyaml~=6.0.2 -typing_extensions~=4.13.0 -tabulate~=0.9.0 +typing_extensions~=4.12.2 diff --git a/sandboxes/columnar.txt b/sandboxes/columnar.txt index 8d6293ecc..36cbda25c 100644 --- a/sandboxes/columnar.txt +++ b/sandboxes/columnar.txt @@ -1,14 +1,13 @@ -# version 17 +# version 18 # exact versions for core array packages -awkward==2.8.1 -uproot==5.6.0 -pyarrow==19.0.1 -dask-awkward==2025.3.0 -correctionlib==2.6.4 -coffea==2024.11.0 +awkward==2.8.9 +uproot==5.6.6 +pyarrow==21.0.0 +correctionlib==2.7.0 +coffea==2025.9.0 # minimum versions for general packages -zstandard~=0.23.0 -lz4~=4.4.3 +zstandard~=0.25.0 +lz4~=4.4.4 xxhash~=3.5.0 diff --git a/sandboxes/dev.txt b/sandboxes/dev.txt index 73ab68008..16fd687de 100644 --- a/sandboxes/dev.txt +++ b/sandboxes/dev.txt @@ -1,12 +1,13 @@ -# version 11 +# version 12 # last version to support python 3.9 ipython~=8.18.1 -pytest~=8.3.5 -pytest-cov~=6.0.0 -flake8~=7.1.2 +pytest~=8.4.2 +pytest-cov~=7.0.0 +flake8~=7.3.0 flake8-commas~=4.0.0 flake8-quotes~=3.4.0 -pipdeptree~=2.26.0 -pymarkdownlnt~=0.9.29 -uniplot~=0.17.1 +pymarkdownlnt~=0.9.32 +uniplot~=0.21.4 +pipdeptree~=2.28.0 +mermaidmro~=0.2.1 diff --git a/sandboxes/ml_tf.txt b/sandboxes/ml_tf.txt index 382f89151..ce18e6828 100644 --- a/sandboxes/ml_tf.txt +++ b/sandboxes/ml_tf.txt @@ -1,4 +1,4 @@ -# version 11 +# version 12 # use packages from columnar sandbox as baseline -r columnar.txt diff --git a/setup.sh b/setup.sh index 4d30d2664..bdcf9337b 100644 --- a/setup.sh +++ b/setup.sh @@ -376,6 +376,7 @@ cf_setup_interactive_common_variables() { query CF_VENV_SETUP_MODE_UPDATE "Automatically update virtual envs if needed" "false" [ "${CF_VENV_SETUP_MODE_UPDATE}" != "true" ] && export_and_save CF_VENV_SETUP_MODE "update" unset CF_VENV_SETUP_MODE_UPDATE + query CF_INTERACTIVE_VENV_FILE "Custom venv setup fill to use for interactive work instead of 'cf_dev'" "" "''" query CF_LOCAL_SCHEDULER "Use a local scheduler for law tasks" "true" if [ "${CF_LOCAL_SCHEDULER}" != "true" ]; then @@ -530,8 +531,12 @@ cf_setup_software_stack() { # Optional environments variables: # CF_REMOTE_ENV # When true-ish, the software stack is sourced but not built. - # CF_CI_ENV - # When true-ish, the "cf" venv is skipped and only the "cf_dev" env is built. + # CF_LOCAL_ENV + # When not true-ish, the context is not meant for local development and only the "cf_dev" venv is built. + # CF_INTERACTIVE_VENV_FILE is ignored in this case. + # CF_INTERACTIVE_VENV_FILE + # IF CF_LOCAL_ENV is true-ish, the venv setup of this file is sourced to start the interactive shell. When + # empty, defaults to ${CF_BASE}/sandboxes/cf_dev.sh. # CF_REINSTALL_SOFTWARE # When true-ish, any existing software stack is removed and freshly installed. # CF_CONDA_ARCH @@ -620,8 +625,8 @@ cf_setup_software_stack() { # conda / micromamba setup # - # not needed in CI or RTD jobs - if ! ${CF_CI_ENV} && ! ${CF_RTD_ENV}; then + # only needed in local envs + if ${CF_LOCAL_ENV}; then # base environment local conda_missing="$( [ -d "${CF_CONDA_BASE}" ] && echo "false" || echo "true" )" if ${conda_missing}; then @@ -658,6 +663,7 @@ EOF echo cf_color cyan "setting up conda / micromamba environment" micromamba install \ + gcc \ libgcc \ bash \ zsh \ @@ -691,32 +697,61 @@ EOF # - "cf" : contains the minimal stack to run tasks and is sent alongside jobs # - "cf_dev" : "cf" + additional python tools for local development (e.g. ipython) + # - custom : when CF_INTERACTIVE_VENV_FILE is set, source the venv setup from there + + source_venv() { + # all parameters must be given + local venv_file="$1" + local venv_name="$2" + # must be true or false + local use_subshell="$3" + + # source the file and catch the return code + local ret="0" + if ${use_subshell}; then + ( source "${venv_file}" "" "silent" ) + ret="$?" + else + source "${venv_file}" "" "silent" + ret="$?" + fi - show_version_warning() { - >&2 echo - >&2 echo "WARNING: your venv '$1' is not up to date, please consider updating it in a new shell with" - >&2 echo "WARNING: > CF_REINSTALL_SOFTWARE=1 source setup.sh $( ${setup_is_default} || echo "${setup_name}" )" - >&2 echo - } - - # source the production sandbox, potentially skipped in CI and RTD jobs - if ! ${CF_CI_ENV} && ! ${CF_RTD_ENV}; then - ( source "${CF_BASE}/sandboxes/cf.sh" "" "silent" ) - ret="$?" + # code 21 means "version outdated", all others are as usual if [ "${ret}" = "21" ]; then - show_version_warning "cf" + >&2 echo + >&2 echo "WARNING: your venv '${venv_name}' is not up to date, please consider updating it in a new shell with" + >&2 echo "WARNING: > CF_REINSTALL_SOFTWARE=1 source setup.sh $( ${setup_is_default} || echo "${setup_name}" )" + >&2 echo elif [ "${ret}" != "0" ]; then return "${ret}" fi + + return "0" + } + + # build the production sandbox in a subshell, only in local envs + if ${CF_LOCAL_ENV}; then + source_venv "${CF_BASE}/sandboxes/cf.sh" "cf" true || return "$?" + fi + + # check if a custom interactive venv should be used, check the file, but only in local envs + local use_custom_interactive_venv="false" + if ${CF_LOCAL_ENV} && [ ! -z "${CF_INTERACTIVE_VENV_FILE}" ]; then + # check existence + if [ ! -f "${CF_INTERACTIVE_VENV_FILE}" ]; then + >&2 echo "the interactive venv setup file ${CF_INTERACTIVE_VENV_FILE} does not exist" + return "2" + fi + use_custom_interactive_venv="true" fi - # source the dev sandbox - source "${CF_BASE}/sandboxes/cf_dev.sh" "" "silent" - ret="$?" - if [ "${ret}" = "21" ]; then - show_version_warning "cf_dev" - elif [ "${ret}" != "0" ]; then - return "${ret}" + # build the dev sandbox, using a subshell if a custom venv is given that should be sourced afterwards + source_venv "${CF_BASE}/sandboxes/cf_dev.sh" "cf_dev" "${use_custom_interactive_venv}" || return "$?" + + # source the custom interactive venv setup file if given + if ${use_custom_interactive_venv}; then + echo "activating custom interactive venv from $( cf_color magenta "${CF_INTERACTIVE_VENV_FILE}" )" + source_venv "${CF_INTERACTIVE_VENV_FILE}" "$( basename "${CF_INTERACTIVE_VENV_FILE%.*}" )" false || return "$?" fi # initialize submodules @@ -757,13 +792,21 @@ cf_setup_post_install() { # Should be true or false, indicating if the setup is run in a local environment. # CF_REPO_BASE # The base directory of the analysis repository, which is used to determine the law home and config file. + # + # Optional environment variables: + # CF_SKIP_SETUP_GIT_HOOKS + # When set to true, the setup of git hooks is skipped. + # CF_SKIP_LAW_INDEX + # When set to true, the initial indexing of law tasks is skipped. + # CF_SKIP_CHECK_TMP_DIR + # When set to true, the check of the size of the target tmp directory is skipped. # # git hooks # # only in local env - if ${CF_LOCAL_ENV}; then + if ! ${CF_SKIP_SETUP_GIT_HOOKS} && ${CF_LOCAL_ENV}; then cf_setup_git_hooks || return "$?" fi @@ -783,7 +826,9 @@ cf_setup_post_install() { complete -o bashdefault -o default -F _law_complete claw # silently index - law index -q + if ! ${CF_SKIP_LAW_INDEX}; then + law index -q + fi fi fi @@ -791,7 +836,7 @@ cf_setup_post_install() { # check the tmp directory size # - if ${CF_LOCAL_ENV} && which law &> /dev/null; then + if ! ${CF_SKIP_CHECK_TMP_DIR} && ${CF_LOCAL_ENV} && which law &> /dev/null; then cf_check_tmp_dir fi @@ -817,12 +862,16 @@ cf_check_tmp_dir() { >&2 cf_color "red" "cf_check_tmp_dir: 'law config target.tmp_dir' must not be empty" return "2" elif [ ! -d "${tmp_dir}" ]; then - >&2 cf_color "red" "cf_check_tmp_dir: 'law config target.tmp_dir' is not a directory" - return "3" + # nothing to do + return "0" fi - # compute the size - local tmp_size="$( find "${tmp_dir}" -maxdepth 1 -name "*" -user "$( id -u )" -exec du -cb {} + | grep 'total$' | cut -d $'\t' -f 1 )" + # compute the size, with a notification shown if it takes too long + ( sleep 5 && cf_color yellow "computing the size of your files in ${tmp_dir} ..." ) & + local msg_pid="$!" + local tmp_size="$( find "${tmp_dir}" -maxdepth 1 -user "$( id -u )" -exec du -cb {} + | grep 'total$' | cut -d $'\t' -f 1 | sort | head -n 1 )" + kill "${msg_pid}" 2> /dev/null + wait "${msg_pid}" 2> /dev/null # warn above 1GB with color changing when above 2GB local thresh1="1073741824" @@ -1103,6 +1152,9 @@ for flag_name in \ CF_REINSTALL_SOFTWARE \ CF_REINSTALL_HOOKS \ CF_SKIP_BANNER \ + CF_SKIP_SETUP_GIT_HOOKS \ + CF_SKIP_LAW_INDEX \ + CF_SKIP_CHECK_TMP_DIR \ CF_ON_HTCONDOR \ CF_ON_SLURM \ CF_ON_GRID \ diff --git a/tests/test_columnar_util.py b/tests/test_columnar_util.py index d06ac9882..952e27c3a 100644 --- a/tests/test_columnar_util.py +++ b/tests/test_columnar_util.py @@ -16,6 +16,7 @@ ak = maybe_import("awkward") dak = maybe_import("dask_awkward") coffea = maybe_import("coffea") +maybe_import("coffea.nanoevents") class RouteTest(unittest.TestCase): diff --git a/tests/test_inference.py b/tests/test_inference.py index 5f1bde4b4..16e4161f9 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -106,7 +106,7 @@ def test_parameter_spec(self): # Test data name = "test_parameter" type = ParameterType.rate_gauss - transformations = [ParameterTransformation.centralize, ParameterTransformation.symmetrize] + transformations = [ParameterTransformation.symmetrize] config_name = "test_config" config_shift_source = "test_shift_source" effect = 1.5 @@ -177,34 +177,3 @@ def test_parameter_group_spec_with_no_parameter_names(self): ) self.assertDictEqual(result, expected_result) - - def test_require_shapes_for_parameter_shape(self): - # No shape is required if the parameter type is a rate - types = [ParameterType.rate_gauss, ParameterType.rate_uniform, ParameterType.rate_unconstrained] - for t in types: - with self.subTest(t=t): - param_obj = DotDict( - type=t, - transformations=ParameterTransformations([ParameterTransformation.effect_from_rate]), - name="test_param", - ) - result = InferenceModel.require_shapes_for_parameter(param_obj) - self.assertFalse(result) - - # if the transformation is shape-based expect True - param_obj.transformations = ParameterTransformations([ParameterTransformation.effect_from_shape]) - result = InferenceModel.require_shapes_for_parameter(param_obj) - self.assertTrue(result) - - # No shape is required if the transformation is from a rate - param_obj = DotDict( - type=ParameterType.shape, - transformations=ParameterTransformations([ParameterTransformation.effect_from_rate]), - name="test_param", - ) - result = InferenceModel.require_shapes_for_parameter(param_obj) - self.assertFalse(result) - - param_obj.transformations = ParameterTransformations([ParameterTransformation.effect_from_shape]) - result = InferenceModel.require_shapes_for_parameter(param_obj) - self.assertTrue(result)