From 1ebac8ced16e9bfe59aa79c750dc5c7cea593a5d Mon Sep 17 00:00:00 2001 From: JulesVandenbroeck <93740577+JulesVandenbroeck@users.noreply.github.com> Date: Mon, 28 Apr 2025 15:01:13 +0200 Subject: [PATCH 001/123] Fix to jer application on jec variations (#665) * initial commit, fix to jer application to jec variations * change smearing factor to a callable function and calculate smearing factor for each jec variation. * update jets and mets definitions to ensure deep copy of original event array is taken * add jec-specfic columns to uses * Vectorized jer application over jec variations (#92) * Simplify jer init. * Overhaul vectorized jer processing. * Minor sources fix in jec. * move jec_variations, jer_variations, and postfixes to jer_init. Also include jec_ prefix to jec_variations as jec_variations is only used for registering uses and produces and storing jer variations in a dictionary. * change jer_random_normal variable name to random_normal --------- Co-authored-by: juvanden Co-authored-by: Marcel Rieger --- columnflow/calibration/cms/jets.py | 256 +++++++++++++++++------------ columnflow/calibration/util.py | 18 ++ 2 files changed, 169 insertions(+), 105 deletions(-) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index bd910264b..e7eb6b330 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -10,7 +10,7 @@ from columnflow.types import Any from columnflow.calibration import Calibrator, calibrator -from columnflow.calibration.util import ak_random, propagate_met +from columnflow.calibration.util import ak_random, propagate_met, sum_transverse from columnflow.production.util import attach_coffea_behavior from columnflow.util import maybe_import, InsertableDict, DotDict from columnflow.columnar_util import set_ak_column, layout_ak_array, optional_column as optional @@ -134,7 +134,7 @@ def get_jerc_file_default(self: Calibrator, external_files: DotDict) -> str: :param external_files: Dictionary containing the information about the file location :return: path or url to correction file(s) - """ # noqa + """ # noqa # get config try_attrs = ("get_jec_config", "get_jer_config") @@ -320,7 +320,7 @@ def jec( values to the missing transverse energy (MET) using :py:func:`~columnflow.calibration.util.propagate_met` for events where ``met.eta > *min_eta_met_prop*``. - """ # noqa + """ # noqa # use local variable for convenience jet_name = self.jet_name @@ -481,7 +481,8 @@ def jec_init(self: Calibrator) -> None: sources = self.uncertainty_sources if sources is None: - sources = jec_cfg.uncertainty_sources + sources = jec_cfg.uncertainty_sources or [] + self.uncertainty_sources = sources # register used jet columns self.uses.add(f"{self.jet_name}.{{pt,eta,phi,mass,area,rawFactor}}") @@ -597,20 +598,15 @@ def make_jme_keys(names, jec=jec_cfg, is_data=self.dataset_inst.is_data): for name in names ] - # take sources from constructor or config - sources = self.uncertainty_sources - if sources is None: - sources = jec_cfg.uncertainty_sources - jec_keys = make_jme_keys(jec_cfg.levels) jec_keys_subset_type1_met = make_jme_keys(jec_cfg.levels_for_type1_met) - junc_keys = make_jme_keys(sources, is_data=False) # uncertainties only stored as MC keys + junc_keys = make_jme_keys(self.uncertainty_sources, is_data=False) # uncertainties only stored as MC keys # store the evaluators self.evaluators = { "jec": get_evaluators(correction_set, jec_keys), "jec_subset_type1_met": get_evaluators(correction_set, jec_keys_subset_type1_met), - "junc": dict(zip(sources, get_evaluators(correction_set, junc_keys))), + "junc": dict(zip(self.uncertainty_sources, get_evaluators(correction_set, junc_keys))), } @@ -698,6 +694,12 @@ def get_jer_config_default(self: Calibrator) -> DotDict: get_jer_file=get_jerc_file_default, # function to determine the jer configuration dict get_jer_config=get_jer_config_default, + # function to determine the jec configuration dict + get_jec_config=get_jec_config_default, + # jec uncertainty sources to propagate jer to, defaults to config when empty + jec_uncertainty_sources=None, + # whether gen jet matching should be performed relative to the nominal jet pt, or the jec varied values + gen_jet_matching_nominal=False, ) def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: """ @@ -742,22 +744,41 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: *get_jer_config* can be adapted in a subclass in case it is stored differently in the config. + The nominal JER smearing is performed on nominal jets as well as those varied as a result of jet energy corrections. + For this purpose, *get_jec_config* and *jec_uncertainty_sources* can be defined to control the considered + variations. Consequently, the matching of jets to gen jets which depends on pt values of the former is subject to a + choice regarding which pt values to use. If *gen_jet_matching_nominal* is *True*, the nominal pt values are used, + and the jec varied pt values otherwise. + Throws an error if running on data. :param events: awkward array containing events to process - """ # noqa + """ # noqa # use local variables for convenience jet_name = self.jet_name gen_jet_name = self.gen_jet_name + met_name = self.met_name # fail when running on data if self.dataset_inst.is_data: raise ValueError("attempt to apply jet energy resolution smearing in data") + # prepare variations + jer_nom, jer_up, jer_down = self.jer_variations + # save the unsmeared properties in case they are needed later events = set_ak_column_f32(events, f"{jet_name}.pt_unsmeared", events[jet_name].pt) events = set_ak_column_f32(events, f"{jet_name}.mass_unsmeared", events[jet_name].mass) + # normally distributed random numbers per jet for use in stochastic smearing below + random_normal = ( + ak_random(0, 1, events[jet_name].deterministic_seed, rand_func=self.deterministic_normal) + if self.deterministic_seed_index >= 0 + else ak_random(0, 1, rand_func=np.random.Generator( + np.random.SFC64(events.event.to_list())).normal, + ) + ) + # obtain rho, which might be located at different routes, depending on the nano version rho = ( events.fixedGridRhoFastjetAll @@ -765,50 +786,60 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: events.Rho.fixedGridRhoFastjetAll ) - # variable naming convention + # prepare evaluator variables variable_map = { "JetEta": events[jet_name].eta, "JetPt": events[jet_name].pt, "Rho": rho, + "systematic": jer_nom, } - # pt resolution + # extract nominal pt resolution inputs = [variable_map[inp.name] for inp in self.evaluators["jer"].inputs] - jer = ak_evaluate(self.evaluators["jer"], *inputs) + jerpt = {jer_nom: ak_evaluate(self.evaluators["jer"], *inputs)} - # JER scale factors and systematic variations + # for simplifications below, use the same values for jer variations + jerpt[jer_up] = jerpt[jer_nom] + jerpt[jer_down] = jerpt[jer_nom] + + # extract pt resolutions evaluted for jec uncertainties + for jec_var in self.jec_variations: + _variable_map = variable_map | {"JetPt": events[jet_name][f"pt_{jec_var}"]} + inputs = [_variable_map[inp.name] for inp in self.evaluators["jer"].inputs] + jerpt[jec_var] = ak_evaluate(self.evaluators["jer"], *inputs) + + # extract scale factors jersf = {} - for syst in ("nom", "up", "down"): - variable_map_syst = dict(variable_map, systematic=syst) - inputs = [variable_map_syst[inp.name] for inp in self.evaluators["sf"].inputs] - jersf[syst] = ak_evaluate(self.evaluators["sf"], *inputs) + for jer_var in self.jer_variations: + _variable_map = variable_map | {"systematic": jer_var} + inputs = [_variable_map[inp.name] for inp in self.evaluators["sf"].inputs] + jersf[jer_var] = ak_evaluate(self.evaluators["sf"], *inputs) + + # extract scale factors for jec uncertainties + for jec_var in self.jec_variations: + _variable_map = variable_map | {"JetPt": events[jet_name][f"pt_{jec_var}"]} + inputs = [_variable_map[inp.name] for inp in self.evaluators["sf"].inputs] + jersf[jec_var] = ak_evaluate(self.evaluators["sf"], *inputs) # array with all JER scale factor variations as an additional axis # (note: axis needs to be regular for broadcasting to work correctly) + jerpt = ak.concatenate( + [jerpt[v][..., None] for v in self.jer_variations + self.jec_variations], + axis=-1, + ) jersf = ak.concatenate( - [jersf[syst][..., None] for syst in ("nom", "up", "down")], + [jersf[v][..., None] for v in self.jer_variations + self.jec_variations], axis=-1, ) # -- stochastic smearing - # normally distributed random numbers according to JER - jer_random_normal = ( - ak_random(0, jer, events[jet_name].deterministic_seed, rand_func=self.deterministic_normal) - if self.deterministic_seed_index >= 0 - else ak_random(0, jer, rand_func=np.random.Generator( - np.random.SFC64(events.event.to_list())).normal, - ) - ) # scale random numbers according to JER SF - jersf2_m1 = jersf ** 2 - 1 + jersf2_m1 = jersf**2 - 1 add_smear = np.sqrt(ak.where(jersf2_m1 < 0, 0, jersf2_m1)) - # broadcast over JER SF variations - jer_random_normal, jersf_z = ak.broadcast_arrays(jer_random_normal, add_smear) - # compute smearing factors (stochastic method) - smear_factors_stochastic = 1.0 + jer_random_normal * add_smear + smear_factors_stochastic = 1.0 + random_normal * jerpt * add_smear # -- scaling method (using gen match) @@ -824,124 +855,139 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: ) # gen jets that match the reconstructed jets - matched_gen_jets = padded_gen_jets[valid_gen_jet_idxs] + matched_gen_jet = padded_gen_jets[valid_gen_jet_idxs] # compute the relative (reco - gen) pt difference - pt_relative_diff = (events[jet_name].pt - matched_gen_jets.pt) / events[jet_name].pt + if self.gen_jet_matching_nominal: + # use nominal pt for matching + match_pt = events[jet_name].pt + else: + # concatenate varied pt values for broadcasting + pt_names = ["pt" for _ in self.jer_variations] + [f"pt_{jec_var}" for jec_var in self.jec_variations] + match_pt = ak.concatenate([events[jet_name][pt_name][..., None] for pt_name in pt_names], axis=-1) + pt_relative_diff = 1 - matched_gen_jet.pt / match_pt # test if matched gen jets are within 3 * resolution - is_matched_pt = np.abs(pt_relative_diff) < 3 * jer + # (no check for Delta-R matching criterion; we assume this was done during nanoAOD production to get the genJetIdx) + is_matched_pt = np.abs(pt_relative_diff) < 3 * jerpt is_matched_pt = ak.fill_none(is_matched_pt, False) # masked values = no gen match - # (no check for Delta-R matching criterion; we assume this was done during - # nanoAOD production to get the `genJetIdx`) - - # broadcast over JER SF variations - pt_relative_diff, jersf = ak.broadcast_arrays(pt_relative_diff, jersf) - # compute smearing factors (scaling method) smear_factors_scaling = 1.0 + (jersf - 1.0) * pt_relative_diff # -- hybrid smearing: take smear factors from scaling if there was a match, # otherwise take the stochastic ones - smear_factors = ak.where( - is_matched_pt[:, :, None], - smear_factors_scaling, - smear_factors_stochastic, - ) + smear_factors = ak.where(is_matched_pt, smear_factors_scaling, smear_factors_stochastic) # ensure array is not nullable (avoid ambiguity on Arrow/Parquet conversion) smear_factors = ak.fill_none(smear_factors, 0.0) - # store pt and phi of the full jet system + # to allow for code simplifications below, store the reference pt and mass columns upon which the smearing is based + # into the events array for cases where it shouldn't already exist + for direction in ["up", "down"]: + events = set_ak_column_f32(events, f"{jet_name}.pt_jer_{direction}", events[jet_name].pt) + events = set_ak_column_f32(events, f"{jet_name}.mass_jer_{direction}", events[jet_name].mass) + # when propagating met, do the same for respective columns + if self.propagate_met: + events = set_ak_column_f32(events, f"{met_name}.pt_jer_{direction}", events[met_name].pt) + events = set_ak_column_f32(events, f"{met_name}.phi_jer_{direction}", events[met_name].phi) + + # when propagating met, before smearing is applied, store pt and phi of the full jet system for all variations using + # string postfixes as keys if self.propagate_met: - jetsum = events[jet_name].sum(axis=1) - jetsum_pt_before = jetsum.pt - jetsum_phi_before = jetsum.phi - - # apply the smearing factors to the pt and mass - # (note: apply variations first since they refer to the original pt) - events = set_ak_column_f32(events, f"{jet_name}.pt_jer_up", events[jet_name].pt * smear_factors[:, :, 1]) - events = set_ak_column_f32(events, f"{jet_name}.mass_jer_up", events[jet_name].mass * smear_factors[:, :, 1]) - events = set_ak_column_f32(events, f"{jet_name}.pt_jer_down", events[jet_name].pt * smear_factors[:, :, 2]) - events = set_ak_column_f32(events, f"{jet_name}.mass_jer_down", events[jet_name].mass * smear_factors[:, :, 2]) - events = set_ak_column_f32(events, f"{jet_name}.pt", events[jet_name].pt * smear_factors[:, :, 0]) - events = set_ak_column_f32(events, f"{jet_name}.mass", events[jet_name].mass * smear_factors[:, :, 0]) + jetsum_pt_before = {} + jetsum_phi_before = {} + for postfix in self.postfixes: + jetsum_pt_before[postfix], jetsum_phi_before[postfix] = sum_transverse( + events[jet_name][f"pt{postfix}"], + events[jet_name].phi, + ) + + # apply the smearing + # (note: this requires that postfixes and smear_factors have the same order, but this should be the case) + for i, postfix in enumerate(self.postfixes): + pt_name = f"pt{postfix}" + m_name = f"mass{postfix}" + events = set_ak_column_f32(events, f"{jet_name}.{pt_name}", events[jet_name][pt_name] * smear_factors[..., i]) + events = set_ak_column_f32(events, f"{jet_name}.{m_name}", events[jet_name][m_name] * smear_factors[..., i]) # recover coffea behavior events = self[attach_coffea_behavior](events, collections=[jet_name], **kwargs) # met propagation if self.propagate_met: - # save unsmeared quantities - events = set_ak_column_f32(events, f"{self.met_name}.pt_unsmeared", events[self.met_name].pt) - events = set_ak_column_f32(events, f"{self.met_name}.phi_unsmeared", events[self.met_name].phi) - - # get pt and phi of all jets after correcting - jetsum = events[jet_name].sum(axis=1) - jetsum_pt_after = jetsum.pt - jetsum_phi_after = jetsum.phi - - # propagate changes to MET - met_pt, met_phi = propagate_met( - jetsum_pt_before, - jetsum_phi_before, - jetsum_pt_after, - jetsum_phi_after, - events[self.met_name].pt, - events[self.met_name].phi, - ) - events = set_ak_column_f32(events, f"{self.met_name}.pt", met_pt) - events = set_ak_column_f32(events, f"{self.met_name}.phi", met_phi) + events = set_ak_column_f32(events, f"{met_name}.pt_unsmeared", events[met_name].pt) + events = set_ak_column_f32(events, f"{met_name}.phi_unsmeared", events[met_name].phi) + + # propagate per variation + for postfix in self.postfixes: + # get pt and phi of all jets after correcting + jetsum_pt_after, jetsum_phi_after = sum_transverse( + events[jet_name][f"pt{postfix}"], + events[jet_name].phi, + ) - # syst variations on top of corrected MET - met_pt_up, met_phi_up = propagate_met( - jetsum_pt_after, - jetsum_phi_after, - events[jet_name].pt_jer_up, - events[jet_name].phi, - met_pt, - met_phi, - ) - met_pt_down, met_phi_down = propagate_met( - jetsum_pt_after, - jetsum_phi_after, - events[jet_name].pt_jer_down, - events[jet_name].phi, - met_pt, - met_phi, - ) - events = set_ak_column_f32(events, f"{self.met_name}.pt_jer_up", met_pt_up) - events = set_ak_column_f32(events, f"{self.met_name}.pt_jer_down", met_pt_down) - events = set_ak_column_f32(events, f"{self.met_name}.phi_jer_up", met_phi_up) - events = set_ak_column_f32(events, f"{self.met_name}.phi_jer_down", met_phi_down) + # propagate changes to MET + met_pt, met_phi = propagate_met( + jetsum_pt_before[postfix], + jetsum_phi_before[postfix], + jetsum_pt_after, + jetsum_phi_after, + events[met_name][f"pt{postfix}"], + events[met_name][f"phi{postfix}"], + ) + events = set_ak_column_f32(events, f"{met_name}.pt{postfix}", met_pt) + events = set_ak_column_f32(events, f"{met_name}.phi{postfix}", met_phi) return events @jer.init def jer_init(self: Calibrator) -> None: + # add jec_cfg for applying nominal smearing to jec variations + jec_cfg = self.get_jec_config() + jec_sources = self.jec_uncertainty_sources + if jec_sources is None: + jec_sources = jec_cfg.uncertainty_sources or [] + self.jec_uncertainty_sources = jec_sources + + # prepare jec variations + self.jec_variations = sum(([f"jec_{unc}_up", f"jec_{unc}_down"] for unc in self.jec_uncertainty_sources), []) + + jet_jec_columns = {f"{self.jet_name}.{{pt,mass}}_{jec_source}" for jec_source in self.jec_variations} + met_jec_columns = {f"{self.met_name}.{{pt,phi}}_{jec_source}" for jec_source in self.jec_variations} + # determine gen-level jet index column lower_first = lambda s: s[0].lower() + s[1:] if s else s self.gen_jet_idx_column = lower_first(self.gen_jet_name) + "Idx" + # prepare jer variations and postfixes + self.jer_variations = ["nom", "up", "down"] + self.postfixes = ["", "_jer_up", "_jer_down"] + [f"_{jec_var}" for jec_var in self.jec_variations] + # register used jet columns self.uses.add(f"{self.jet_name}.{{pt,eta,phi,mass,{self.gen_jet_idx_column}}}") - - # register used gen jet columns self.uses.add(f"{self.gen_jet_name}.{{pt,eta,phi}}") + if jec_sources: + self.uses |= jet_jec_columns # register produced jet columns self.produces.add(f"{self.jet_name}.{{pt,mass}}{{,_unsmeared,_jer_up,_jer_down}}") + if jec_sources: + self.produces |= jet_jec_columns - # register produced MET columns + # additional columns when propagating MET if self.propagate_met: # register used MET columns self.uses.add(f"{self.met_name}.{{pt,phi}}") + if jec_sources: + self.uses |= met_jec_columns # register produced MET columns self.produces.add(f"{self.met_name}.{{pt,phi}}{{,_jer_up,_jer_down,_unsmeared}}") + if jec_sources: + self.produces |= met_jec_columns @jer.requires diff --git a/columnflow/calibration/util.py b/columnflow/calibration/util.py index ac20de9bb..b95e0b20a 100644 --- a/columnflow/calibration/util.py +++ b/columnflow/calibration/util.py @@ -39,6 +39,24 @@ def ak_random(*args, rand_func: Callable) -> ak.Array: return ak.from_numpy(np_randvals) +def sum_transverse(pt: ak.Array, phi: ak.Array) -> tuple[ak.Array, ak.Array]: + """ + Helper function to compute the sum of transverse vectors given their pt and phi values. + + :param pt: Transverse momentum of the vector(s). + :param phi: Azimuthal angle of the vector(s). + :return: Tuple containing the transverse momentum and azimuthal angle of the sum of the vectors. + """ + px_sum = ak.sum(pt * np.cos(phi), axis=-1) + py_sum = ak.sum(pt * np.sin(phi), axis=-1) + + # compute new components + pt_sum = (px_sum**2.0 + py_sum**2.0)**0.5 + phi_sum = np.arctan2(py_sum, px_sum) + + return pt_sum, phi_sum + + def propagate_met( jet_pt1: (ak.Array), jet_phi1: ak.Array, From ab1372be66040767d14426fe5ecce192e8b90f9e Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Tue, 27 May 2025 12:46:30 +0200 Subject: [PATCH 002/123] Refactoring for 0.3 release (#628) * Fix task inehritance, adjust store parts. * Typo. * Revert stray changes. * Add store_part_anchor. * Re-purpose store part anchor for config part only. * Define config_store_anchor on ConfigTask for subclasses. * Fix inheritance order in datacard task. * TAF init refactoring draft. * Adapt template analysis. * Add comment. * Add review comments by @mafrahm. * Start. * Minor cleanup. * Port ConfigTask and ShiftTask. * Propagate ConfigTask changes to mixins and other tasks. * Update inference interface and tasks for multi config inputs. * Update hist hook handling. * Fix hist hook lookup. * Typo. * Update docstring. * Add union and intersection modes to default resolution. * Overhaul find_config_objects. * Update columnflow/tasks/framework/base.py Co-authored-by: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> * Port config lookup to variables, review comment. * Typo. * Use dicts. * Update readme, fix typo in config task. * Update DatasetsProcessesMixin and ShiftSourcesMixin. * Improve loop over configs in shift validation. * Merge MultiConfigPlotting into refactor/taf_init (#630) * implement MultiConfigTask * disable TaskArrayFunction init when there is not config inst * fix MLEvaluation reqs * update template and CSPM parameter description * fix tests * add warning when cspm defaults are set in config_inst * tmp * fixes of PlotShiftMixin task * fixes in PlotShiftMixin * modify PlotVariables1d run method for multi config * reintroduce MLModelsMixin to plotting * add analysis_inst instead of config_inst in preparation_producer * resolve processes and variables per config * add PlotVariablesPerConfig wrapper tasks * split between DatasetsProcessesMixin and MultiConfigDatasetsProcessesMixin + cleanup * make ShiftSourcesMixin work with multiple configs * add PlotShiftedVariablesPerConfig1D mixin * fix bug when dataset is missing in first config * fix mixins from CreateDatacards * add default config to MultiConfigTask * move defaults to analysis_inst in AnalysisTask * move process and variable settings back to config inst * fixes to the two previous commits... * move multi config resolving function to AnalysisTask * review comments * set weight_producer in analysis inst in template * fix lint in template * decouple ShiftTask and ConfigTask and remove PlotShiftMixin * remove checking shifts for all reqs * move default categories, variables, and inference model to analysis inst * cleanup in VariablesMixin * remove config_inst from get_known_shifts signature * resolve shifts per config * store shift and category names instead of ids in histograms * fix hist tests * fix shifted variable plot func * handle missing shift bins in plot_shifted_variables * fix missing shift bins in plotting task * fill category as int and transform to str later * add growth to translated axis * fix and extend hist_util tests * loop over variables when switching to strcat * same for cutflow (+lint) * cleanup * fix resolving of ml model insts * fix process order in plotting * fix HistogramsUser task class inheritance * allow resolving selector step defaults * move selector_steps default to analysis_inst * fix bug in obtaining unique category ids * fir MRO in CreateHistograms * Update columnflow/tasks/plotting.py Co-authored-by: Philip Keicher <26219567+pkausw@users.noreply.github.com> * feature/MultiConfigPlotting * cleanup and reintroduce ML mixins from MultiConfigPlotting * fix murmuf_envelope Producer * remove config_inst from get_known_shifts signature * fix WeightProducerClassMixin inheritance and add MultiConfigDatasetsProcessesShiftSourcesMixin * decouple ShiftTask from ConfigTask * bugfixes and linting * fix PlotShiftedVariables * centralize definition of CSPW representations * fixes in ML tasks * fix inconsistencies after merging * add tests for default and group resolving * remove single_config tag from VariablesMixin * move CSPM groups to analysis inst in template * fix category/variable resolving and add resolving tests * streamline tests * cleanup and fix param resolving * add tests for process resolving * extend resolving tests and fix dataset/process resolving * remove duplicate lines * include shift inst tests --------- Co-authored-by: localusers user Co-authored-by: Philip Keicher <26219567+pkausw@users.noreply.github.com> * Refactor/taf init simplified shift validation (#641) * revert changes to ShiftSourcesMixin * simplify shift resolving as much as possible * streamline resolve_shifts function * Refactor/taf init (reorganized resolution + fix ml pipeline) (#643) * first draft for reordered TAF initialization and param resolution * remove shift bins that were not requested from branch map * fix single config tasks (yields) and cleanup * reintroduce ML training pipeline * switch to DatasetsProcessesMixin * recreate dependencies in run_post_init * Cleanup * perform shift resolution only if not yet done * fix single config datasets/processes resolving * move DatasetsProcessesMixin to PlotProcessBase and fallback to nominal shift in reqs * move logger messages into debug mode * revert 5c515bb076db1d601a472b6c70b93fccaeab74df * fallback branch to -1 if not existent * move default CSPs back to cofig inst * fix param resolving in wrapper_factory * update resolution class and function names * fix bug (pass shift name instead of inst in reqs) * Apply suggestions from code review Co-authored-by: Marcel Rieger * add resolve_instances for InferenceModelMixin * minor refactoring --------- Co-authored-by: Marcel Rieger * Improve known_shifts caching between workflow and branches. * Fixes edge cases. * Fix default resolution. * Refactor/taf init fixes (#645) * add missing MLEvaluation reqs * add producer_inst to ProduceColumns.reqs in ML pipeline * load ML columns in histograms and union tasks * locate shift name instead of id in histograms * Typo. * Adjust inference model tests. (#646) * Fix TAF post init order (#647) * Correct taf post init order. * Fix selector steps default. * Fix typo. * Add reducer interface. (#648) * Add reducer interface. * Additional reducer fallback to cf_default. * Add hist prodcer interface. (#650) * Cleanup top pt weight producer. (#625) * Cleanup top pt weight producer. * Add TopPtWeightConfig. * Update columnflow/production/cms/top_pt_weight.py Co-authored-by: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> --------- Co-authored-by: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> * Documentation update for refactoring (#652) * Start docs update. * Update README. * Add TAF docs. * Finish TAF docs, start transition. * Finish tafs in transition guide. * Finish changed task names docs. * Add multi-config update instructions. * Finish transition guide for reducers. * Finish inference model transition docs. * Finish transition docs. * Lint. * Systematic shift plotting (#649) * Update shift plots. * Fix id/name handling. * Address review comments by @mafrahm. * Update variable names, add comments. * Update sandboxes. * Update law. * Code harmonization * Apply review comment. * Use process names in hist axes. (#657) * Use process names in hist axes. * Apply axes conversion to remaining spots. * Add configurable string representations. * Add missing docstring. * Optimize hist filling, code alignment. * Feature: Add mechanism to transform hist into version with equally spaced b ins (#627) * added mechanism to transform hist into version with equally spaced bins, also added keyword to rotate xticks label * linter * added forgotten keyword argument in the preration of the config * correct typo, add new arguments to kwargs and change default x_ticks * Refactor rebinning function. * Simplify axis settings. * Feedback process and variable updates to style config. * Move x axis transformations to 'apply_variable_settings'. --------- Co-authored-by: Nathan Prouvost Co-authored-by: Marcel Rieger Co-authored-by: Marcel R. * Add and use only_local_env decorator. * Make lumi in normalization weight producer configurable. * Minor fixes and consistency. * Fix config lookup for taf classes mixins. (#669) * CMS jet id producer (#661) * Add cms-related jet id producer. * Fix bit check. * Allow subpaths in external files. (#663) * Allow subpaths in external files. * Minor de-nesting. * Maintain subpaths type. * Eager taf teardown when call function fails. (#662) * Eager taf teardown when call function fails. * Gracefully trigger teardown via decorator. * minor fixes and streamlining (#671) * Unambiguous hashing. * fix plotting with single varied shift (#672) * remove flag from MergeHistograms * fix plotting single varied shift in PlotVariables1D * Update columnflow/tasks/histograms.py --------- Co-authored-by: Marcel Rieger * Update law. * adding dy weights producer (#622) * adding dy weights producer * redifining masks and adding dy_weights_init * adding dy_order input * adding order to DrellYanConfig * adding order to DrellYanConfig * adding check for dy order in cfg * add missing self.dy_unc_corrector * update dy weight producer * linting dy recoil producer * remove duplicate entry in dy recoil weights * fix logic in DY recoil vis dilepton selection * format with black * passed flake8 * linting * update recoil corrections by removing helper functions * fix linter * fix bug with import InsertableDict * Apply suggestions from code review Co-authored-by: Marcel Rieger * add suggestions from review to DY producer --------- Co-authored-by: Paul Philipp Gadow Co-authored-by: philippgadow Co-authored-by: Marcel Rieger * fix PlotCutflow task and requirements * Update columnflow/tasks/selection.py Co-authored-by: Marcel Rieger * Shift-conform dy outputs. * fix ml_model repr * Rename recoil_corrections to recoil_corrected_met. * Apply new recommendation for egamma calibration (#674) * added more kwargs for config, that are necessary to handle run2 and run3 recommendation at the same time. * added more variables for variable maps, switched application of smearing to a standardized version, that results in the same result but is more robust. * removed comments * removed version check * change rand_func to separate normal_up, down variant * rewrap docstring and point to EGammaPog recommendation and example file * switched to concrete arguments in config and feedforward this change * add example into docstring about how to use the calibrator in combination with the config * Apply suggestions from code review --------- Co-authored-by: Marcel Rieger * append scale label when not passing placeholder * implement own errorbar calculation (#675) * implement own errorbar calculation * make poisson error calculation independent of histogram shape * Apply suggestions from code review Co-authored-by: Marcel Rieger * change function name * Apply comments from review Co-authored-by: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> --------- Co-authored-by: Marcel Rieger Co-authored-by: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> * Fix flow handling for fake data in datacards. * Allow skipping histogram checks. * Fix used columns of btag weight producer. * minor plotting fixes * Fix norm_weight_producer_inst in MergeSelectionMasks. * Improve transition guide. * parton shower weights (#676) * init commit. To see commit history check scalefactor-development branch in GhentAnalysis fork * remove the cmsGhent folder and add parton_shower.py to production/cms * Update columnflow/production/cms/parton_shower.py Co-authored-by: Marcel Rieger * Update columnflow/production/cms/parton_shower.py Co-authored-by: Marcel Rieger * add parton_shower to columnflow-cms specific production modules --------- Co-authored-by: juvanden Co-authored-by: Marcel Rieger * Minor alignment. * Minor cleanup of electron code. * Fix typos in egamma calibrators. * Enable jet_id producer for data. * Hotfix ps weights when variations are missing. * Add cf_remove_tmp tool. * fix typo. * Fix shift selection for plotting. * Fixes for docs (pdf figures not displayed) (#679) * docs: evince-previewer -> evince evince-previewer is the print preview of evince * added filter to upload svg files to lfs * docs: converted all pdf plots to (additional) svg using `for f in *.pdf; do pdf2svg $f ${f%.pdf}.svg; done` uploaded to lfs * docs: using wildcard extensions for plot file names such that the html generation uses svg, while others (e.g. latex) may still use pdf Before, the image display in the browser was broken and only a link to the pdf file was shown (supposedly the alt text). * Rename histograming -> histogramming. (#680) * Add missing local_env check to BundleExternalFiles task. * allow diverging producers in MLEvaluation (#681) * hotfix: update producer_insts based on evaluation_producers * hotfix: update hists with remove_residual_axis function * allow passing mask to apply JER smearing only to a subset of jets * update faulty import in cms_minimal template * hotfix: allow running ml pipeline without preparation_producer * fix padding when ak.max returns None * update jer_horn_handling calibrator * cast undefined_category_ids to str before raising the error to avoid TypeError * Improve tmp file removal. * Update law. * Add preparation producer post init. * Avoid full config copy in plotting. * More verbose leaf category check exceptions. * set scale_factor to 1 instead of 0 (#685) * allow skipping preparation_producer in MLEvaluation (#686) Co-authored-by: Marcel Rieger * Save lepton pair pdg id in gen_dilepton producer. * Add structure for category groups. * Add warning. * Typo. * Add warn flag to CategoryGroup. --------- Co-authored-by: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Co-authored-by: localusers user Co-authored-by: Philip Keicher <26219567+pkausw@users.noreply.github.com> Co-authored-by: Bogdan-Wiederspan <79155113+Bogdan-Wiederspan@users.noreply.github.com> Co-authored-by: Nathan Prouvost Co-authored-by: Ana Andrade <99343616+aalvesan@users.noreply.github.com> Co-authored-by: Paul Philipp Gadow Co-authored-by: philippgadow Co-authored-by: Mathis Frahm Co-authored-by: Nathan Prouvost <49162277+nprouvost@users.noreply.github.com> Co-authored-by: JulesVandenbroeck <93740577+JulesVandenbroeck@users.noreply.github.com> Co-authored-by: juvanden Co-authored-by: Johannes Lange Co-authored-by: Philip Daniel Keicher --- .flake8 | 4 +- .gitattributes | 1 + .markdownlint | 12 +- README.md | 43 +- .../__cf_module_name__/calibration/example.py | 12 +- .../config/analysis___cf_short_name_lc__.py | 15 +- .../{weight => histogramming}/__init__.py | 0 .../histogramming/example.py | 43 + .../__cf_module_name__/plotting/example.py | 4 +- .../__cf_module_name__/production/example.py | 26 +- .../__cf_module_name__/reduction/__init__.py | 1 + .../__cf_module_name__/reduction/example.py | 27 + .../__cf_module_name__/selection/example.py | 5 +- .../__cf_module_name__/weight/example.py | 43 - analysis_templates/cms_minimal/law.cfg | 9 +- bin/cf_inspect.py | 22 +- bin/cf_remove_tmp | 46 + bin/cf_sandbox | 9 +- columnflow/__init__.py | 36 +- columnflow/calibration/__init__.py | 54 +- columnflow/calibration/cms/egamma.py | 275 +- columnflow/calibration/cms/jets.py | 91 +- columnflow/calibration/cms/jets_coffea.py | 728 ---- columnflow/calibration/cms/met.py | 34 +- columnflow/calibration/cms/tau.py | 57 +- columnflow/columnar_util.py | 538 ++- columnflow/config_util.py | 211 +- columnflow/hist_util.py | 188 +- columnflow/histogramming/__init__.py | 261 ++ columnflow/histogramming/default.py | 153 + columnflow/inference/__init__.py | 581 ++- columnflow/inference/cms/datacard.py | 167 +- columnflow/ml/__init__.py | 53 +- columnflow/plotting/plot_all.py | 289 +- columnflow/plotting/plot_functions_1d.py | 173 +- columnflow/plotting/plot_functions_2d.py | 25 +- columnflow/plotting/plot_util.py | 502 ++- columnflow/production/__init__.py | 54 +- columnflow/production/categories.py | 18 +- columnflow/production/cms/btag.py | 57 +- columnflow/production/cms/dy.py | 446 +++ columnflow/production/cms/electron.py | 48 +- columnflow/production/cms/gen_top_decay.py | 8 +- columnflow/production/cms/jet.py | 187 +- columnflow/production/cms/muon.py | 29 +- columnflow/production/cms/parton_shower.py | 91 + columnflow/production/cms/pdf.py | 2 +- columnflow/production/cms/pileup.py | 37 +- columnflow/production/cms/scale.py | 30 +- columnflow/production/cms/seeds.py | 23 +- columnflow/production/cms/supercluster_eta.py | 34 - columnflow/production/cms/top_pt_weight.py | 106 +- columnflow/production/normalization.py | 114 +- columnflow/reduction/__init__.py | 105 + columnflow/reduction/default.py | 92 + columnflow/{selection => reduction}/util.py | 43 +- columnflow/selection/__init__.py | 54 +- columnflow/selection/cms/jets.py | 29 +- columnflow/selection/cms/json_filter.py | 22 +- columnflow/selection/cms/met_filters.py | 5 +- columnflow/selection/empty.py | 2 +- columnflow/selection/matching.py | 169 - columnflow/selection/stats.py | 23 +- columnflow/tasks/calibration.py | 41 +- columnflow/tasks/cms/external.py | 4 +- columnflow/tasks/cms/inference.py | 325 +- columnflow/tasks/cutflow.py | 180 +- columnflow/tasks/external.py | 168 +- columnflow/tasks/framework/base.py | 1462 +++++--- columnflow/tasks/framework/decorators.py | 104 +- columnflow/tasks/framework/histograms.py | 42 +- columnflow/tasks/framework/inference.py | 244 ++ columnflow/tasks/framework/mixins.py | 3127 +++++++++-------- columnflow/tasks/framework/parameters.py | 35 +- columnflow/tasks/framework/plotting.py | 41 +- columnflow/tasks/framework/remote.py | 25 + .../tasks/framework/remote_bootstrap.sh | 2 + columnflow/tasks/histograms.py | 262 +- columnflow/tasks/ml.py | 261 +- columnflow/tasks/plotting.py | 527 ++- columnflow/tasks/production.py | 46 +- columnflow/tasks/reduction.py | 219 +- columnflow/tasks/selection.py | 150 +- columnflow/tasks/union.py | 37 +- columnflow/tasks/yields.py | 39 +- columnflow/types.py | 2 +- columnflow/util.py | 70 +- columnflow/weight/__init__.py | 119 - columnflow/weight/all_weights.py | 82 - columnflow/weight/empty.py | 17 - docs/Makefile | 2 +- docs/api/calibration/cms/index.rst | 11 +- docs/api/calibration/cms/jets_coffea.rst | 44 - docs/api/histogramming/default.rst | 9 + docs/api/histogramming/index.rst | 14 + docs/api/index.rst | 7 +- docs/api/reduction/default.rst | 9 + docs/api/reduction/index.rst | 14 + docs/api/selection/index.rst | 4 +- docs/api/selection/matching.rst | 9 - docs/api/selection/util.rst | 9 - docs/api/types.rst | 3 +- docs/api/weights/all_weights.rst | 9 - docs/api/weights/empty.rst | 9 - docs/api/weights/index.rst | 14 - docs/conf.py | 16 +- ...lot__proc_st__cat_incl__var_cf_jet1_pt.svg | 3 + ...lot__proc_tt__cat_incl__var_cf_jet1_pt.svg | 3 + ...2_a2211e799f__cat_incl__var_cf_jet1_pt.svg | 3 + ...2_a2211e799f__cat_incl__var_cf_jet1_pt.svg | 3 + ...2_a2211e799f__cat_incl__var_cf_jet1_pt.svg | 3 + ..._analy__1__12a17bf79c__cutflow__cat_2j.svg | 3 + ...naly__1__12a17bf79c__cutflow__cat_incl.svg | 3 + ...11e799f__unc_mu__cat_incl__var_jet1_pt.svg | 3 + ...2211e799f__unc_mu__cat_incl__var_n_jet.svg | 3 + ..._a2211e799f__cat_incl__var_jet1_pt__c1.svg | 3 + ..._2_a2211e799f__cat_incl__var_n_jet__c1.svg | 3 + ...__proc_3_7727a49dc2__cat_2j__var_n_jet.svg | 3 + ...proc_3_7727a49dc2__cat_incl__var_n_jet.svg | 3 + ...proc_3_7727a49dc2__cat_incl__var_n_jet.svg | 3 + ..._a2211e799f__cat_incl__var_jet1_pt__c3.svg | 3 + ..._2_a2211e799f__cat_incl__var_n_jet__c3.svg | 3 + ..._a2211e799f__cat_incl__var_jet1_pt__c2.svg | 3 + ..._2_a2211e799f__cat_incl__var_n_jet__c2.svg | 3 + ...2211e799f__cat_incl__var_jet1_pt-n_jet.svg | 3 + ...2211e799f__cat_incl__var_n_jet-jet1_pt.svg | 3 + docs/requirements.txt | 18 +- docs/user_guide/02_03_transition.md | 291 ++ .../building_blocks/hist_producers.md | 5 + docs/user_guide/building_blocks/index.rst | 2 + docs/user_guide/building_blocks/reducers.md | 5 + docs/user_guide/debugging.md | 32 +- docs/user_guide/index.rst | 4 +- docs/user_guide/plotting.md | 44 +- docs/user_guide/structure.md | 8 +- docs/user_guide/task_array_functions.md | 230 ++ law.cfg | 19 +- modules/law | 2 +- sandboxes/_setup_cmssw.sh | 2 +- sandboxes/_setup_venv.sh | 2 +- sandboxes/cf.txt | 10 +- sandboxes/columnar.txt | 13 +- sandboxes/dev.txt | 12 +- sandboxes/ml_tf.txt | 4 +- setup.sh | 12 +- tests/run_tests | 9 + tests/test_base_tasks.py | 391 +++ tests/test_hist_util.py | 56 +- tests/test_inference.py | 214 +- 149 files changed, 9846 insertions(+), 5979 deletions(-) rename analysis_templates/cms_minimal/__cf_module_name__/{weight => histogramming}/__init__.py (100%) create mode 100644 analysis_templates/cms_minimal/__cf_module_name__/histogramming/example.py create mode 100644 analysis_templates/cms_minimal/__cf_module_name__/reduction/__init__.py create mode 100644 analysis_templates/cms_minimal/__cf_module_name__/reduction/example.py delete mode 100644 analysis_templates/cms_minimal/__cf_module_name__/weight/example.py create mode 100755 bin/cf_remove_tmp delete mode 100644 columnflow/calibration/cms/jets_coffea.py create mode 100644 columnflow/histogramming/__init__.py create mode 100644 columnflow/histogramming/default.py create mode 100644 columnflow/production/cms/dy.py create mode 100644 columnflow/production/cms/parton_shower.py delete mode 100644 columnflow/production/cms/supercluster_eta.py create mode 100644 columnflow/reduction/__init__.py create mode 100644 columnflow/reduction/default.py rename columnflow/{selection => reduction}/util.py (55%) delete mode 100644 columnflow/selection/matching.py create mode 100644 columnflow/tasks/framework/inference.py delete mode 100644 columnflow/weight/__init__.py delete mode 100644 columnflow/weight/all_weights.py delete mode 100644 columnflow/weight/empty.py delete mode 100644 docs/api/calibration/cms/jets_coffea.rst create mode 100644 docs/api/histogramming/default.rst create mode 100644 docs/api/histogramming/index.rst create mode 100644 docs/api/reduction/default.rst create mode 100644 docs/api/reduction/index.rst delete mode 100644 docs/api/selection/matching.rst delete mode 100644 docs/api/selection/util.rst delete mode 100644 docs/api/weights/all_weights.rst delete mode 100644 docs/api/weights/empty.rst delete mode 100644 docs/api/weights/index.rst create mode 100644 docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_st__cat_incl__var_cf_jet1_pt.svg create mode 100644 docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_tt__cat_incl__var_cf_jet1_pt.svg create mode 100644 docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step0_Initial__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg create mode 100644 docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step1_jet__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg create mode 100644 docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step2_muon__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg create mode 100644 docs/plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_2j.svg create mode 100644 docs/plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_incl.svg create mode 100644 docs/plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_jet1_pt.svg create mode 100644 docs/plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_n_jet.svg create mode 100644 docs/plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c1.svg create mode 100644 docs/plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c1.svg create mode 100644 docs/plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_2j__var_n_jet.svg create mode 100644 docs/plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_incl__var_n_jet.svg create mode 100644 docs/plots/cf.PlotVariables1D_tpl_config_analy__1__4601e8554b__plot__proc_3_7727a49dc2__cat_incl__var_n_jet.svg create mode 100644 docs/plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c3.svg create mode 100644 docs/plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c3.svg create mode 100644 docs/plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c2.svg create mode 100644 docs/plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c2.svg create mode 100644 docs/plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt-n_jet.svg create mode 100644 docs/plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_n_jet-jet1_pt.svg create mode 100644 docs/user_guide/02_03_transition.md create mode 100644 docs/user_guide/building_blocks/hist_producers.md create mode 100644 docs/user_guide/building_blocks/reducers.md create mode 100644 docs/user_guide/task_array_functions.md create mode 100644 tests/test_base_tasks.py diff --git a/.flake8 b/.flake8 index f9ba7fdc2..4e6d17f02 100644 --- a/.flake8 +++ b/.flake8 @@ -1,7 +1,7 @@ [flake8] -# line length of 100 is recommended, but set it to a forgiving value -max-line-length = 120 +# line length of 120 is recommended, but set it to a forgiving value +max-line-length = 121 # codes of errors to ignore ignore = E128, E306, E402, E722, E731, E741, W504, Q003 diff --git a/.gitattributes b/.gitattributes index 541e4940c..a9dceeaab 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,3 +4,4 @@ *.jpeg filter=lfs diff=lfs merge=lfs -text *.root filter=lfs diff=lfs merge=lfs -text *.ico filter=lfs diff=lfs merge=lfs -text +*.svg filter=lfs diff=lfs merge=lfs -text diff --git a/.markdownlint b/.markdownlint index 3ac4e55de..b87f1b383 100644 --- a/.markdownlint +++ b/.markdownlint @@ -1,10 +1,12 @@ plugins: - # disable max line length + md001: + enabled: False md013: + # disable max line length enabled: False - md033: - allowed_elements: "!--,![CDATA[,!DOCTYPE,table,h1,p,img" + md024: + siblings_only: True md026: punctuation: ".,;!。,;!" - md001: - enabled: False + md033: + allowed_elements: "!--,![CDATA[,!DOCTYPE,table,h1,a,p,img,div" diff --git a/README.md b/README.md index ffd97c418..a7921d8e8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

- - + light logo + dark logo

@@ -35,25 +35,39 @@ Original source hosted at [GitHub](https://github.com/columnflow/columnflow). -## Note on current development +## ❗️ Note on v0.2 → v0.3 transition -This project is currently in a beta phase. +The 0.3 release introduces many performance fixes and new features such as + +- a new interface for all *task array functions* (calibrators, selectors, producers, etc.), +- support for plotting data of multiple data taking campaigns at once, +- a simplified machine learning interface, and +- statistical inference models with support for merging data of different campaigns. + +However, some of these changes are potentially breaking existing code. +Checkout the [v0.2 → v0.3 transition guide](https://columnflow.readthedocs.io/en/latest/user_guide/02_03_transition.html) as well as the [release notes](https://github.com/columnflow/columnflow/releases/tag/v0.3.0) for a detailed overview of the changes and how to adapt your code. + +Version 0.2 continues to be available via the [`legacy/v0.2`](https://github.com/columnflow/columnflow/tree/legacy/v0.2) branch, with the latest release being [v0.2.5](https://github.com/columnflow/columnflow/releases/tag/v0.2.5). + +## 🚧 Note on current development + +This project is in an advanced beta phase. The project setup, suggested workflows, definitions of particular tasks, and the signatures of various helper classes and functions are mostly frozen but could still be subject to changes in the near future. -At this point (July 2024), various large-scale analyses based upon columnflow are being developed, and in the process, help test and verify various aspects of its core. -The first major release with a largely frozen API is expected in the fall of 2024. -However, if you would like to join early on, contribute or just give it a spin, feel free to get in touch! +Various large-scale analyses based upon columnflow have been performed, others are being developed, and in the process, help test and verify various aspects of the framework. -![Columnflow analytics](https://repobeats.axiom.co/api/embed/b6ebc5ba41019de55eb48e195eecb438890442c8.svg "Columnflow analytics") +
+ Columnflow analytics +
-## Quickstart +## ⏩ Quickstart To create an analysis using columnflow, it is recommended to start from a predefined template (located in [analysis_templates](https://github.com/columnflow/columnflow/tree/master/analysis_templates)). The following command (no previous git clone required) interactively asks for a handful of names and settings, and creates a minimal, yet fully functioning project structure for you! @@ -103,15 +117,18 @@ Setup successfull! The next steps are: For a better overview of the tasks that are triggered by the commands below, checkout the current (yet stylized) [task graph](https://github.com/columnflow/columnflow/wiki#default-task-graph). -## Projects using columnflow +## 💯 Projects using columnflow - [hh2bbtautau](https://github.com/uhh-cms/hh2bbtautau): HH → bb𝜏𝜏 analysis with CMS. - [hh2bbww](https://github.com/uhh-cms/hh2bbww): HH → bbWW analysis with CMS. - [topmass](https://github.com/uhh-cms/topmass): Top quark mass measurement with CMS. - [mttbar](https://github.com/uhh-cms/mttbar): Search for heavy resonances in ttbar events with CMS. -- [analysis playground](https://github.com/uhh-cms/analysis_playground): A testing playground for HEP analyses. +- [analysis playground](https://github.com/uhh-cms/AZH2inv): TODO +- [topsf](https://github.com/uhh-cms/topsf): Top tagging scale factor measurement. +- [hto4l](https://github.com/uhh-cms/hto4l): H → ZZ → 4l analysis with CMS. +- [DiJetJERC](https://github.com/uhh-cms/DiJetJERC): Di-jet analysis with CMS. -## Contributors +## 🙏 Contributors @@ -119,7 +136,7 @@ For a better overview of the tasks that are triggered by the commands below, che - + diff --git a/analysis_templates/cms_minimal/__cf_module_name__/calibration/example.py b/analysis_templates/cms_minimal/__cf_module_name__/calibration/example.py index 2daf545f8..3c2b451d2 100644 --- a/analysis_templates/cms_minimal/__cf_module_name__/calibration/example.py +++ b/analysis_templates/cms_minimal/__cf_module_name__/calibration/example.py @@ -14,16 +14,8 @@ @calibrator( - uses={ - deterministic_seeds, - "Jet.{pt,eta,phi,mass}", - }, - produces={ - deterministic_seeds, - "Jet.pt", "Jet.mass", - "Jet.pt_jec_up", "Jet.mass_jec_up", - "Jet.pt_jec_down", "Jet.mass_jec_down", - }, + uses={deterministic_seeds, "Jet.{pt,eta,phi,mass}"}, + produces={deterministic_seeds, "Jet.{pt,mass}{,_jec_up,_jec_down}"}, ) def example(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: # a) "correct" Jet.pt by scaling four momenta by 1.1 (pt<30) or 0.9 (pt<=30) diff --git a/analysis_templates/cms_minimal/__cf_module_name__/config/analysis___cf_short_name_lc__.py b/analysis_templates/cms_minimal/__cf_module_name__/config/analysis___cf_short_name_lc__.py index 75ea9b6e9..5fbc11d9b 100644 --- a/analysis_templates/cms_minimal/__cf_module_name__/config/analysis___cf_short_name_lc__.py +++ b/analysis_templates/cms_minimal/__cf_module_name__/config/analysis___cf_short_name_lc__.py @@ -50,6 +50,9 @@ # named function hooks that can modify store_parts of task outputs if needed ana.x.store_parts_modifiers = {} +# histogramming hooks, invoked before creating plots when --hist-hook parameter set +ana.x.hist_hooks = {} + # # setup configs @@ -108,11 +111,13 @@ # verify that the root process of all datasets is part of any of the registered processes verify_config_processes(cfg, warn=True) -# default objects, such as calibrator, selector, producer, ml model, inference model, etc +# default objects, such as calibrator, selector, reducer, producer, ml model, inference model, etc cfg.x.default_calibrator = "example" cfg.x.default_selector = "example" +cfg.x.default_selector_steps = [] +cfg.x.default_reducer = "cf_default" cfg.x.default_producer = "example" -cfg.x.default_weight_producer = "example" +cfg.x.default_hist_producer = "cf_default" cfg.x.default_ml_model = None cfg.x.default_inference_model = "example" cfg.x.default_categories = ("incl",) @@ -165,15 +170,15 @@ # calibrator groups for conveniently looping over certain calibrators # (used during calibration) -cfg.x.calibrator_groups = {} +ana.x.calibrator_groups = {} # producer groups for conveniently looping over certain producers # (used during the ProduceColumns task) -cfg.x.producer_groups = {} +ana.x.producer_groups = {} # ml_model groups for conveniently looping over certain ml_models # (used during the machine learning tasks) -cfg.x.ml_model_groups = {} +ana.x.ml_model_groups = {} # custom method and sandbox for determining dataset lfns cfg.x.get_dataset_lfns = None diff --git a/analysis_templates/cms_minimal/__cf_module_name__/weight/__init__.py b/analysis_templates/cms_minimal/__cf_module_name__/histogramming/__init__.py similarity index 100% rename from analysis_templates/cms_minimal/__cf_module_name__/weight/__init__.py rename to analysis_templates/cms_minimal/__cf_module_name__/histogramming/__init__.py diff --git a/analysis_templates/cms_minimal/__cf_module_name__/histogramming/example.py b/analysis_templates/cms_minimal/__cf_module_name__/histogramming/example.py new file mode 100644 index 000000000..1190acb92 --- /dev/null +++ b/analysis_templates/cms_minimal/__cf_module_name__/histogramming/example.py @@ -0,0 +1,43 @@ +# coding: utf-8 + +""" +Example histogram producer. +""" + +from columnflow.histogramming import HistProducer +from columnflow.histogramming.default import cf_default +from columnflow.util import maybe_import +from columnflow.config_util import get_shifts_from_sources +from columnflow.columnar_util import Route + +ak = maybe_import("awkward") +np = maybe_import("numpy") + + +# extend columnflow's default hist producer +@cf_default.hist_producer() +def example(self: HistProducer, events: ak.Array, **kwargs) -> ak.Array: + # build the full event weight + weight = ak.Array(np.ones(len(events), dtype=np.float32)) + + if self.dataset_inst.is_mc and len(events): + for column in self.weight_columns: + weight = weight * Route(column).apply(events) + + return events, weight + + +@example.init +def example_init(self: HistProducer) -> None: + self.weight_columns = {} + + if self.dataset_inst.is_data: + return + + # store column names referring to weights to multiply + self.weight_columns |= {"normalization_weight", "muon_weight"} + self.uses |= self.weight_columns + + # declare shifts that the produced event weight depends on + shift_sources = {"mu"} + self.shifts |= set(get_shifts_from_sources(self.config_inst, *shift_sources)) diff --git a/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py b/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py index 6044a6e3c..da7e34817 100644 --- a/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py +++ b/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py @@ -47,10 +47,10 @@ def my_plot1d_func( """ # we can add arbitrary parameters via the `general_settings` parameter to access them in the # plotting function. They are automatically parsed either to a bool, float, or string - print(f"The example_param has been set to '{example_param}' (type: {type(example_param)})") + print(f"the example_param has been set to '{example_param}' (type: {type(example_param)})") # call helper function to remove shift axis from histogram - remove_residual_axis(hists, "shift") + hists = remove_residual_axis(hists, "shift") # call helper functions to apply the variable_settings and process_settings variable_inst = variable_insts[0] diff --git a/analysis_templates/cms_minimal/__cf_module_name__/production/example.py b/analysis_templates/cms_minimal/__cf_module_name__/production/example.py index 59b83b819..b00a9e26a 100644 --- a/analysis_templates/cms_minimal/__cf_module_name__/production/example.py +++ b/analysis_templates/cms_minimal/__cf_module_name__/production/example.py @@ -12,7 +12,7 @@ from columnflow.production.cms.seeds import deterministic_seeds from columnflow.production.cms.mc_weight import mc_weight from columnflow.production.cms.muon import muon_weights -from columnflow.selection.util import create_collections_from_masks +from columnflow.reduction.util import create_collections_from_masks from columnflow.util import maybe_import from columnflow.columnar_util import EMPTY_FLOAT, Route, set_ak_column from columnflow.production.util import attach_coffea_behavior @@ -26,14 +26,8 @@ @producer( - uses={ - # nano columns - attach_coffea_behavior, "Jet.{pt,eta,phi,mass}", - }, - produces={ - # new columns - attach_coffea_behavior, "ht", "n_jet", "dijet.{pt,mass,dr}", - }, + uses={"Jet.{pt,eta,phi,mass}"}, + produces={"ht", "n_jet", "dijet.{pt,mass,dr}"}, ) def jet_features(self: Producer, events: ak.Array, **kwargs) -> ak.Array: @@ -42,7 +36,7 @@ def jet_features(self: Producer, events: ak.Array, **kwargs) -> ak.Array: events = set_ak_column(events, "n_jet", ak.num(events.Jet.pt, axis=1), value_type=np.int32) # attach coffea behaviour - events = self[attach_coffea_behavior](events, collections={}, **kwargs) + events = attach_coffea_behavior(events, collections={}, **kwargs) # object padding (Note that after padding, ak.num(events.Jet.pt, axis=1) would always be >= 2) events = set_ak_column(events, "Jet", ak.pad_none(events.Jet, 2)) @@ -57,16 +51,8 @@ def jet_features(self: Producer, events: ak.Array, **kwargs) -> ak.Array: @producer( - uses={ - mc_weight, category_ids, - # nano columns - "Jet.pt", - }, - produces={ - mc_weight, category_ids, - # new columns - "cutflow.jet1_pt", - }, + uses={mc_weight, category_ids, "Jet.{pt,phi}"}, + produces={mc_weight, category_ids, "cutflow.jet1_pt"}, ) def cutflow_features( self: Producer, diff --git a/analysis_templates/cms_minimal/__cf_module_name__/reduction/__init__.py b/analysis_templates/cms_minimal/__cf_module_name__/reduction/__init__.py new file mode 100644 index 000000000..57d631c3f --- /dev/null +++ b/analysis_templates/cms_minimal/__cf_module_name__/reduction/__init__.py @@ -0,0 +1 @@ +# coding: utf-8 diff --git a/analysis_templates/cms_minimal/__cf_module_name__/reduction/example.py b/analysis_templates/cms_minimal/__cf_module_name__/reduction/example.py new file mode 100644 index 000000000..a09906527 --- /dev/null +++ b/analysis_templates/cms_minimal/__cf_module_name__/reduction/example.py @@ -0,0 +1,27 @@ +# coding: utf-8 + +""" +Exemplary reduction methods that can run on-top of columnflow's default reduction. +""" + +from columnflow.reduction import Reducer, reducer +from columnflow.reduction.default import cf_default +from columnflow.util import maybe_import +from columnflow.columnar_util import set_ak_column + +ak = maybe_import("awkward") + + +@reducer( + uses={cf_default, "Jet.hadronFlavour"}, + produces={cf_default, "Jet.from_b_hadron"}, +) +def example(self: Reducer, events: ak.Array, selection: ak.Array, **kwargs) -> ak.Array: + # run cf's default reduction which handles event selection and collection creation + events = self[cf_default](events, selection, **kwargs) + + # compute and store additional columns after the default reduction + # (so only on a subset of the events and objects which might be computationally lighter) + events = set_ak_column(events, "Jet.from_b_hadron", abs(events.Jet.hadronFlavour) == 5, value_type=bool) + + return events diff --git a/analysis_templates/cms_minimal/__cf_module_name__/selection/example.py b/analysis_templates/cms_minimal/__cf_module_name__/selection/example.py index dd427317c..736fdefd8 100644 --- a/analysis_templates/cms_minimal/__cf_module_name__/selection/example.py +++ b/analysis_templates/cms_minimal/__cf_module_name__/selection/example.py @@ -24,7 +24,6 @@ # (not selectable from the command line but used by other, exposed selectors) # - @selector( uses={"Muon.{pt,eta,phi,mass}"}, ) @@ -91,6 +90,7 @@ def jet_selection_init(self: Selector) -> None: if shift_inst.has_tag(("jec", "jer")) } + # # exposed selectors # (those that can be invoked from the command line) @@ -100,8 +100,7 @@ def jet_selection_init(self: Selector) -> None: @selector( uses={ # selectors / producers called within _this_ selector - mc_weight, cutflow_features, process_ids, muon_selection, jet_selection, - increment_stats, + mc_weight, cutflow_features, process_ids, muon_selection, jet_selection, increment_stats, }, produces={ # selectors / producers whose newly created columns should be kept diff --git a/analysis_templates/cms_minimal/__cf_module_name__/weight/example.py b/analysis_templates/cms_minimal/__cf_module_name__/weight/example.py deleted file mode 100644 index be7158119..000000000 --- a/analysis_templates/cms_minimal/__cf_module_name__/weight/example.py +++ /dev/null @@ -1,43 +0,0 @@ -# coding: utf-8 - -""" -Example event weight producer. -""" - -from columnflow.weight import WeightProducer, weight_producer -from columnflow.util import maybe_import -from columnflow.config_util import get_shifts_from_sources -from columnflow.columnar_util import Route - -ak = maybe_import("awkward") -np = maybe_import("numpy") - - -@weight_producer( - # both used columns and dependent shifts are defined in init below - # only run on mc - mc_only=True, -) -def example(self: WeightProducer, events: ak.Array, **kwargs) -> ak.Array: - # build the full event weight - weight = ak.Array(np.ones(len(events), dtype=np.float32)) - for column in self.weight_columns: - weight = weight * Route(column).apply(events) - - return events, weight - - -@example.init -def example_init(self: WeightProducer) -> None: - # store column names referring to weights to multiply - self.weight_columns = { - "normalization_weight", - "muon_weight", - } - self.uses |= self.weight_columns - - # declare shifts that the produced event weight depends on - shift_sources = { - "mu", - } - self.shifts |= set(get_shifts_from_sources(self.config_inst, *shift_sources)) diff --git a/analysis_templates/cms_minimal/law.cfg b/analysis_templates/cms_minimal/law.cfg index 55eef2720..8ce50143d 100644 --- a/analysis_templates/cms_minimal/law.cfg +++ b/analysis_templates/cms_minimal/law.cfg @@ -28,10 +28,11 @@ default_config: run2_2017_nano_v9 default_dataset: st_tchannel_t_4f_powheg calibration_modules: columnflow.calibration.cms.{jets,met,tau}, __cf_module_name__.calibration.example -selection_modules: columnflow.selection.{empty}, columnflow.selection.cms.{json_filter,met_filters}, __cf_module_name__.selection.example -production_modules: columnflow.production.{categories,matching,normalization,processes}, columnflow.production.cms.{btag,electron,jet,matching,mc_weight,muon,pdf,pileup,scale,seeds}, __cf_module_name__.production.example +selection_modules: columnflow.selection.empty, columnflow.selection.cms.{json_filter,met_filters}, __cf_module_name__.selection.example +reduction_modules: columnflow.reduction.default, __cf_module_name__.reduction.example +production_modules: columnflow.production.{categories,matching,normalization,processes}, columnflow.production.cms.{btag,electron,jet,matching,mc_weight,muon,pdf,pileup,scale,parton_shower,seeds}, __cf_module_name__.production.example categorization_modules: __cf_module_name__.categorization.example -weight_production_modules: columnflow.weight.{empty,all_weights}, __cf_module_name__.weight.example +hist_production_modules: columnflow.histogramming.default, __cf_module_name__.histogramming.example ml_modules: columnflow.ml, __cf_module_name__.ml.example inference_modules: columnflow.inference, __cf_module_name__.inference.example @@ -71,7 +72,7 @@ chunked_io_debug: False # csv list of task families that inherit from ChunkedReaderMixin and whose output arrays should be # checked (raising an exception) for non-finite values before saving them to disk -check_finite_output: cf.CalibrateEvents, cf.SelectEvents, cf.ProduceColumns +check_finite_output: cf.CalibrateEvents, cf.SelectEvents, cf.ReduceEvents, cf.ProduceColumns # how to treat inexistent selector steps passed to cf.CreateCutflowHistograms: throw an error, # silently skip them, or add a dummy step to the output (allowed values: raise, ignore, dummy) diff --git a/bin/cf_inspect.py b/bin/cf_inspect.py index df13ccfaf..fb080b8e9 100644 --- a/bin/cf_inspect.py +++ b/bin/cf_inspect.py @@ -74,6 +74,7 @@ def list_content(data: Any) -> None: import argparse ap = argparse.ArgumentParser( + add_help=False, description=( "Utility script for quickly loading event arrays, histograms or other supported " "objects from files for interactive processing.\n\n" @@ -85,9 +86,12 @@ def list_content(data: Any) -> None: ), ) + ap.register("action", "help", argparse._HelpAction) ap.add_argument("files", metavar="FILE", nargs="+", help="one or more supported files") ap.add_argument("--events", "-e", action="store_true", help="assume files to contain event info") + ap.add_argument("--hists", "-h", action="store_true", help="assume files to contain histograms") ap.add_argument("--list", "-l", action="store_true", help="list contents of the loaded file") + ap.add_argument("--help", action="help", help="show this help message and exit") args = ap.parse_args() @@ -99,14 +103,24 @@ def list_content(data: Any) -> None: # interpret data intepreted = objects if args.events: - events = objects - interpreted = events - print("events loaded from objects[0] into variable 'events'") - # preload common packages import awkward as ak # noqa import numpy as np # noqa + events = interpreted = objects + print("events loaded from objects[0] into variable 'events'") + + elif args.hists: + # preload common packages + import hist # noqa + + if isinstance(objects, hist.Hist): + h = interpreted = objects + print("histogram loaded from objects[0] into variable 'h'") + else: + hists = interpreted = objects + print("histograms loaded from objects[0] into variable 'hists'") + # list content if args.list: list_content(interpreted) diff --git a/bin/cf_remove_tmp b/bin/cf_remove_tmp new file mode 100755 index 000000000..0d9ba39f3 --- /dev/null +++ b/bin/cf_remove_tmp @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +cf_remove_tmp() { + # Removes all files owned by the user in the directory referred to by `law config target.tmp_dir` (usually identical + # to $LAW_TARGET_TMP_DIR). + # + # Arguments: + # 1. mode: optional, when "all" files are removed rather than just files starting with luigi-tmp-*. + + # zsh options + local shell_is_zsh="$( [ -z "${ZSH_VERSION}" ] && echo "false" || echo "true" )" + if ${shell_is_zsh}; then + emulate -L bash + setopt globdots + fi + + # get the mode + local mode="$1" + if [ ! -z "${mode}" ]; then + if [ "${mode}" != "all" ]; then + >&2 echo "invalid mode '${mode}', use 'all' or leave empty" + return "1" + fi + fi + + # get the directory + local tmp_dir="$( law config target.tmp_dir )" + local ret="$?" + if [ "${ret}" != "0" ]; then + >&2 echo "'law config target.tmp_dir' failed with error code ${ret}" + return "${ret}" + elif [ -z "${tmp_dir}" ]; then + >&2 echo "'law config target.tmp_dir' must not be empty" + return "2" + elif [ ! -d "${tmp_dir}" ]; then + >&2 echo "'law config target.tmp_dir' is not a directory" + return "3" + fi + + # remove all files and directories in tmp_dir owned by the user + local pattern="luigi-tmp-*" + [ "${mode}" = "all" ] && pattern="*" + find "${tmp_dir}" -maxdepth 1 -name "${pattern}" -user "$( id -u )" -exec rm -r "{}" \; +} + +cf_remove_tmp "$@" diff --git a/bin/cf_sandbox b/bin/cf_sandbox index 00a0682f5..10e3c8c2f 100755 --- a/bin/cf_sandbox +++ b/bin/cf_sandbox @@ -74,12 +74,19 @@ cf_sandbox() { local interactive="false" [ -z "${cmd}" ] && cmd="bash -l" && interactive="true" + # escape some characters + cmd="${cmd//\{/\\\{}" + cmd="${cmd//\}/\\\}}" + + # create an unespaced representation + local cmd_repr="${cmd//\\/}" + # load tools CF_SKIP_SETUP="1" source "${CF_BASE}/setup.sh" "" || return "$?" # run it echo "$( cf_color green "sandbox" ): ${sandbox_file_repr}" - echo "$( cf_color green "command" ): ${cmd}" + echo "$( cf_color green "command" ): ${cmd_repr}" cf_color magenta "--- entering sandbox -----------------------------" ${interactive} && echo "(ctrl+d or type 'exit' to leave)" diff --git a/columnflow/__init__.py b/columnflow/__init__.py index e1369bb60..dda0895c7 100644 --- a/columnflow/__init__.py +++ b/columnflow/__init__.py @@ -23,6 +23,12 @@ m = re.match(r"^(\d+)\.(\d+)\.(\d+)(-.+)?$", __version__) version = tuple(map(int, m.groups()[:3])) + (m.group(4),) +#: Location of the documentation. +docs_url = os.getenv("CF_DOCS_URL", "https://columnflow.readthedocs.io/en/latest") + +#: Location of the repository on github. +github_url = os.getenv("CF_GITHUB_URL", "https://github.com/columnflow/columnflow") + #: Boolean denoting whether the environment is in a local environment (based on ``CF_LOCAL_ENV``). env_is_local = law.util.flag_to_bool(os.getenv("CF_LOCAL_ENV", "false")) @@ -76,18 +82,6 @@ # initialize producers, calibrators, selectors, categorizers, ml models and stat models from columnflow.util import maybe_import - import columnflow.production # noqa - if law.config.has_option("analysis", "production_modules"): - for m in law.config.get_expanded("analysis", "production_modules", [], split_csv=True): - logger.debug(f"loading production module '{m}'") - maybe_import(m.strip()) - - import columnflow.weight # noqa - if law.config.has_option("analysis", "weight_production_modules"): - for m in law.config.get_expanded("analysis", "weight_production_modules", [], split_csv=True): - logger.debug(f"loading weight production module '{m}'") - maybe_import(m.strip()) - import columnflow.calibration # noqa if law.config.has_option("analysis", "calibration_modules"): for m in law.config.get_expanded("analysis", "calibration_modules", [], split_csv=True): @@ -100,6 +94,24 @@ logger.debug(f"loading selection module '{m}'") maybe_import(m.strip()) + import columnflow.reduction # noqa + if law.config.has_option("analysis", "reduction_modules"): + for m in law.config.get_expanded("analysis", "reduction_modules", [], split_csv=True): + logger.debug(f"loading reduction module '{m}'") + maybe_import(m.strip()) + + import columnflow.production # noqa + if law.config.has_option("analysis", "production_modules"): + for m in law.config.get_expanded("analysis", "production_modules", [], split_csv=True): + logger.debug(f"loading production module '{m}'") + maybe_import(m.strip()) + + import columnflow.histogramming # noqa + if law.config.has_option("analysis", "hist_production_modules"): + for m in law.config.get_expanded("analysis", "hist_production_modules", [], split_csv=True): + logger.debug(f"loading hist production module '{m}'") + maybe_import(m.strip()) + import columnflow.categorization # noqa if law.config.has_option("analysis", "categorization_modules"): for m in law.config.get_expanded("analysis", "categorization_modules", [], split_csv=True): diff --git a/columnflow/calibration/__init__.py b/columnflow/calibration/__init__.py index 22d793bf4..276e22c6d 100644 --- a/columnflow/calibration/__init__.py +++ b/columnflow/calibration/__init__.py @@ -8,10 +8,9 @@ import inspect -from columnflow.types import Callable, Sequence +from columnflow.types import Callable from columnflow.util import DerivableMeta from columnflow.columnar_util import TaskArrayFunction -from columnflow.config_util import expand_shift_sources class Calibrator(TaskArrayFunction): @@ -28,8 +27,6 @@ def calibrator( bases: tuple = (), mc_only: bool = False, data_only: bool = False, - nominal_only: bool = False, - shifts_only: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ @@ -41,11 +38,6 @@ def calibrator( :py:class:`order.Dataset` (using the :py:attr:`dataset_inst` attribute) whose ``is_mc`` (``is_data``) attribute is *False*. - When *nominal_only* is *True* or *shifts_only* is set, the calibrator is skipped and not - considered by other calibrators, selectors and producers in case they are evalauted on a - :py:class:`order.Shift` (using the :py:attr:`global_shift_inst` attribute) whose name does - not match. - All additional *kwargs* are added as class members of the new subclasses. :param func: Function to be wrapped and integrated into new :py:class:`Calibrator` class. @@ -54,10 +46,6 @@ def calibrator( Monte Carlo simulation and skipped for real data. :param data_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on real data and skipped for Monte Carlo simulation. - :param nominal_only: Boolean flag indicating that this :py:class:`Calibrator` should only - run on the nominal shift and skipped on any other shifts. - :param shifts_only: Shift names that this :py:class:`Calibrator` should only run on, - skipping all other shifts. :return: New :py:class:`Calibrator` subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -67,8 +55,6 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, - "nominal_only": nominal_only, - "shifts_only": shifts_only, } # get the module name @@ -82,45 +68,23 @@ def decorator(func: Callable) -> DerivableMeta: def update_cls_dict(cls_name, cls_dict, get_attr): mc_only = get_attr("mc_only") data_only = get_attr("data_only") - nominal_only = get_attr("nominal_only") - shifts_only = get_attr("shifts_only") - - # prepare shifts_only - if shifts_only: - shifts_only_expanded = set(expand_shift_sources(shifts_only)) - if shifts_only_expanded != shifts_only: - shifts_only = shifts_only_expanded - cls_dict["shifts_only"] = shifts_only # optionally add skip function if mc_only and data_only: raise Exception(f"calibrator {cls_name} received both mc_only and data_only") - if nominal_only and shifts_only: + if (mc_only or data_only) and cls_dict.get("skip_func"): raise Exception( - f"calibrator {cls_name} received both nominal_only and shifts_only", + f"calibrator {cls_name} received custom skip_func, but either mc_only or " + "data_only are set", ) - if mc_only or data_only or nominal_only or shifts_only: - if cls_dict.get("skip_func"): - raise Exception( - f"calibrator {cls_name} received custom skip_func, but either mc_only, " - "data_only, nominal_only or shifts_only are set", - ) if "skip_func" not in cls_dict: - def skip_func(self): + def skip_func(self, **kwargs) -> bool: # check mc_only and data_only - if getattr(self, "dataset_inst", None): - if mc_only and not self.dataset_inst.is_mc: - return True - if data_only and not self.dataset_inst.is_data: - return True - - # check nominal_only and shifts_only - if getattr(self, "global_shift_inst", None): - if nominal_only and not self.global_shift_inst.is_nominal: - return True - if shifts_only and self.global_shift_inst.name not in shifts_only: - return True + if mc_only and not self.dataset_inst.is_mc: + return True + if data_only and not self.dataset_inst.is_data: + return True # in all other cases, do not skip return False diff --git a/columnflow/calibration/cms/egamma.py b/columnflow/calibration/cms/egamma.py index 5b0bf7ed7..b879e300c 100644 --- a/columnflow/calibration/cms/egamma.py +++ b/columnflow/calibration/cms/egamma.py @@ -14,10 +14,8 @@ from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import ak_random -from columnflow.util import maybe_import, InsertableDict -from columnflow.columnar_util import ( - set_ak_column, flat_np_view, ak_copy, optional_column, -) +from columnflow.util import maybe_import, load_correction_set, DotDict +from columnflow.columnar_util import set_ak_column, flat_np_view, ak_copy, optional_column from columnflow.types import Any ak = maybe_import("awkward") @@ -30,7 +28,10 @@ @dataclass class EGammaCorrectionConfig: - correction_set: str = "Scale" + correction_set: str + value_type: str + uncertainty_type: str + compound: bool = False corrector_kwargs: dict[str, Any] = field(default_factory=dict) @@ -46,7 +47,7 @@ def source_field(self) -> str: ... @abc.abstractmethod - def get_correction_file(self, external_files: law.FileTargetCollection) -> law.LocalFile: + def get_correction_file(self, external_files: law.FileTargetCollection) -> law.LocalFileTarget: """Function to retrieve the correction file from the external files. :param external_files: File target containing the files as requested @@ -59,30 +60,49 @@ def get_scale_config(self) -> EGammaCorrectionConfig: """Function to retrieve the configuration for the photon energy correction.""" ... - def call_func( - self, - events: ak.Array, - **kwargs, - ) -> ak.Array: + def call_func(self, events: ak.Array, **kwargs) -> ak.Array: """ - Apply energy corrections to EGamma objects in the events array. + Apply energy corrections to EGamma objects in the events array. There are two types of implementations: standard + and Et dependent. + For Run2 the standard implementation is used, while for Run3 the Et dependent is recommended by the EGammaPog: + https://twiki.cern.ch/twiki/bin/viewauth/CMS/EgammSFandSSRun3?rev=41 + The Et dependendent recipe follows the example given in: + https://gitlab.cern.ch/cms-nanoAOD/jsonpog-integration/-/blob/66f581d0549e8d67fc55420d8bba15c9369fff7c/examples/egmScaleAndSmearingExample.py - This implementation follows the recommendations from the EGamma POG: - https://twiki.cern.ch/twiki/bin/view/CMS/EgammSFandSSRun3#Scale_And_Smearings_Example + Requires an external file in the config under ``electron_ss``. Example: - Derivatives of this base class require additional member variables and - functions: + .. code-block:: python - - *source_field*: The field name of the EGamma objects in the events array (i.e. `Electron` or `Photon`). - - *get_correction_file*: Function to retrieve the correction file, e.g. - from the list of external files in the current `config_inst`. + cfg.x.external_files = DotDict.wrap({ + "electron_ss": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-120c4271/POG/EGM/2022_Summer22//electronSS_EtDependent.json.gz", # noqa + }) + + The pairs of correction set, value and uncertainty type names, and if a compound method is used should be configured using the :py:class:`EGammaCorrectionConfig` as an + auxiliary entry in the config: + + .. code-block:: python + + cfg.x.eec = EGammaCorrectionConfig( + correction_set="EGMScale_Compound_Ele_2022preEE", + value_type="scale", + uncertainty_type="escale", + compound=True, + ) + + Derivatives of this base class require additional member variables and functions: + + - *source_field*: The field name of the EGamma objects in the events array (i.e. `Electron` + or `Photon`). + - *get_correction_file*: Function to retrieve the correction file, e.g.from + the list, of external files in the current `config_inst`. - *get_scale_config*: Function to retrieve the configuration for the energy correction. - This config must be an instance of :py:class:`~columnflow.calibration.cms.egamma.EGammaCorrectionConfig`. + This config must be an instance of + :py:class:`~columnflow.calibration.cms.egamma.EGammaCorrectionConfig`. - If no raw pt (i.e., pt before any corrections) is available, use the nominal pt. - The correction tool only supports flat arrays, so inputs are converted to a flat numpy view first. - Corrections are always applied to the raw pt, which is important if more than one correction is applied in a - row. The final corrections must be applied to the current pt. + If no raw pt (i.e., pt before any corrections) is available, use the nominal pt. The + correction tool only supports flat arrays, so inputs are converted to a flat numpy view + first. Corrections are always applied to the raw pt, which is important if more than one + correction is applied in a row. The final corrections must be applied to the current pt. If :py:attr:`with_uncertainties` is set to `True`, the scale uncertainties are calculated. The scale uncertainties are only available for simulated data. @@ -93,15 +113,13 @@ def call_func( :notes: - Varied corrections are only applied to Monte Carlo (MC) data. - EGamma energy correction is only applied to real data. - - Changes are applied to the views and directly propagate to the original awkward arrays. + - Changes are applied to the views and directly propagate to the original awkward + arrays. """ - # if no raw pt (i.e. pt for any corrections) is available, use the nominal pt - if "rawPt" not in events[self.source_field].fields: - events = set_ak_column_f32( - events, f"{self.source_field}.rawPt", events[self.source_field].pt, - ) + events = set_ak_column_f32(events, f"{self.source_field}.rawPt", events[self.source_field].pt) + # the correction tool only supports flat arrays, so convert inputs to flat np view first # corrections are always applied to the raw pt - this is important if more than # one correction is applied in a row @@ -110,16 +128,13 @@ def call_func( # the final corrections must be applied to the current pt though pt_application = flat_np_view(events[self.source_field].pt, axis=1) - broadcasted_run = ak.broadcast_arrays( - events[self.source_field].pt, events.run, - ) + broadcasted_run = ak.broadcast_arrays(events[self.source_field].pt, events.run) run = flat_np_view(broadcasted_run[1], axis=1) gain = flat_np_view(events[self.source_field].seedGain, axis=1) sceta = flat_np_view(events[self.source_field].superclusterEta, axis=1) r9 = flat_np_view(events[self.source_field].r9, axis=1) # prepare arguments - # we use pt as et since there depends in linear (following the recoomendations) # (energy is part of the LorentzVector behavior) variable_map = { "et": pt_eval, @@ -127,6 +142,10 @@ def call_func( "gain": gain, "r9": r9, "run": run, + "seedGain": gain, + "pt": pt_eval, + "AbsScEta": np.abs(sceta), + "ScEta": sceta, **self.scale_config.corrector_kwargs, } args = tuple( @@ -136,7 +155,7 @@ def call_func( # varied corrections are only applied to MC if self.with_uncertainties and self.dataset_inst.is_mc: - scale_uncertainties = self.scale_corrector("total_uncertainty", *args) + scale_uncertainties = self.scale_corrector.evaluate(self.scale_config.uncertainty_type, *args) scales_up = (1 + scale_uncertainties) scales_down = (1 - scale_uncertainties) @@ -150,21 +169,19 @@ def call_func( # save columns postfix = f"scale_{direction}" - events = set_ak_column_f32( - events, f"{self.source_field}.pt_{postfix}", pt_varied, - ) + events = set_ak_column_f32(events, f"{self.source_field}.pt_{postfix}", pt_varied) # apply the nominal correction # note: changes are applied to the views and directly propagate to the original ak arrays # and do not need to be inserted into the events chunk again # EGamma energy correction is ONLY applied to DATA if self.dataset_inst.is_data: - scales_nom = self.scale_corrector("total_correction", *args) + scales_nom = self.scale_corrector.evaluate(self.scale_config.value_type, *args) pt_application *= scales_nom return events - def init_func(self) -> None: + def init_func(self, **kwargs) -> None: """Function to initialize the calibrator. Sets the required and produced columns for the calibrator. @@ -186,11 +203,10 @@ def init_func(self) -> None: # add columns with unceratinties if requested # photon scale _uncertainties_ are only available for MC - if self.with_uncertainties and getattr(self, "dataset_inst", None): - if self.dataset_inst.is_mc: - self.produces |= {f"{self.source_field}.pt_scale_{{up,down}}"} + if self.with_uncertainties and self.dataset_inst.is_mc: + self.produces |= {f"{self.source_field}.pt_scale_{{up,down}}"} - def requires_func(self, reqs: dict) -> None: + def requires_func(self, task: law.Task, reqs: dict[str, DotDict[str, Any]], **kwargs) -> None: """Function to add necessary requirements. This function add the :py:class:`~columnflow.tasks.external.BundleExternalFiles` @@ -198,14 +214,19 @@ def requires_func(self, reqs: dict) -> None: :param reqs: Dictionary of requirements. """ + if "external_files" in reqs: + return + from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) def setup_func( self, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: """Setup function before event chunk loop. @@ -215,21 +236,15 @@ def setup_func( :param reqs: Dictionary with resolved requirements. :param inputs: Dictionary with inputs (not used). :param reader_targets: Dictionary for optional additional columns to load - (not used). """ - bundle = reqs["external_files"] self.scale_config = self.get_scale_config() - # create the egamma corrector - import correctionlib - correctionlib.highlevel.Correction.__call__ = correctionlib.highlevel.Correction.evaluate - correction_set = correctionlib.CorrectionSet.from_string( - self.get_correction_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) - self.scale_corrector = correction_set[self.scale_config.correction_set] - - # check versions - assert self.scale_corrector.version in [0, 1, 2] + corr_file = self.get_correction_file(reqs["external_files"].files) + # init and extend the correction set + corr_set = load_correction_set(corr_file) + if self.scale_config.compound: + corr_set = corr_set.compound + self.scale_corrector = corr_set[self.scale_config.correction_set] class egamma_resolution_corrector(Calibrator): @@ -266,19 +281,36 @@ def get_resolution_config(self) -> EGammaCorrectionConfig: """Function to retrieve the configuration for the photon energy correction.""" ... - def call_func( - self, - events: ak.Array, - **kwargs, - ) -> ak.Array: + def call_func(self, events: ak.Array, **kwargs) -> ak.Array: """ Apply energy resolution corrections to EGamma objects in the events array. - This implementation follows the recommendations from the EGamma POG: - https://twiki.cern.ch/twiki/bin/view/CMS/EgammSFandSSRun3#Scale_And_Smearings_Example + There are two types of implementations: standard and Et dependent. For Run2 the standard + implementation is used, while for Run3 the Et dependent is recommended by the EGammaPog: + https://twiki.cern.ch/twiki/bin/viewauth/CMS/EgammSFandSSRun3?rev=41 The Et dependendent + recipe follows the example given in: + https://gitlab.cern.ch/cms-nanoAOD/jsonpog-integration/-/blob/66f581d0549e8d67fc55420d8bba15c9369fff7c/examples/egmScaleAndSmearingExample.py + + Requires an external file in the config under ``electron_ss``. Example: + + .. code-block:: python + + cfg.x.external_files = DotDict.wrap({ + "electron_ss": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-120c4271/POG/EGM/2022_Summer22/electronSS_EtDependent.json.gz", # noqa + }) + + The pairs of correction set, value and uncertainty type names, and if a compound method is used should be configured using the :py:class:`EGammaCorrectionConfig` as an + auxiliary entry in the config: + + .. code-block:: python - Derivatives of this base class require additional member variables and - functions: + cfg.x.eec = EGammaCorrectionConfig( + correction_set="EGMSmearAndSyst_ElePTsplit_2022preEE", + value_type="smear", + uncertainty_type="esmear", + ) + + Derivatives of this base class require additional member variables and functions: - *source_field*: The field name of the EGamma objects in the events array (i.e. `Electron` or `Photon`). - *get_correction_file*: Function to retrieve the correction file, e.g. @@ -308,57 +340,60 @@ def call_func( # if no raw pt (i.e. pt for any corrections) is available, use the nominal pt if "rawPt" not in events[self.source_field].fields: - events = set_ak_column_f32( - events, f"{self.source_field}.rawPt", ak_copy(events[self.source_field].pt), - ) + events = set_ak_column_f32(events, f"{self.source_field}.rawPt", ak_copy(events[self.source_field].pt)) # the correction tool only supports flat arrays, so convert inputs to flat np view first - sceta = flat_np_view(events[self.source_field].superclusterEta, axis=1) r9 = flat_np_view(events[self.source_field].r9, axis=1) flat_seeds = flat_np_view(events[self.source_field].deterministic_seed, axis=1) + pt = flat_np_view(events[self.source_field].rawPt, axis=1) # prepare arguments - # we use pt as et since there depends in linear (following the recoomendations) - # (energy is part of the LorentzVector behavior) variable_map = { + "AbsScEta": np.abs(sceta), "eta": sceta, "r9": r9, - **self.resolution_config.corrector_kwargs, + "pt": pt, + **self.resolution_cfg.corrector_kwargs, } + args = tuple( - variable_map[inp.name] for inp in self.resolution_corrector.inputs + variable_map[inp.name] + for inp in self.resolution_corrector.inputs if inp.name in variable_map ) # calculate the smearing scale - rho = self.resolution_corrector("rho", *args) - - # -- stochastic smearing - # normally distributed random numbers according to EGamma resolution + # as mentioned in the example above, allows us to apply them directly to the MC simulation. + rho = self.resolution_corrector.evaluate(self.resolution_cfg.value_type, *args) # varied corrections if self.with_uncertainties and self.dataset_inst.is_mc: - rho_unc = self.resolution_corrector("err_rho", *args) + rho_unc = self.resolution_corrector.evaluate(self.resolution_cfg.uncertainty_type, *args) + random_normal_number = functools.partial(ak_random, 0, 1) + smearing_func = lambda rng_array, variation: rng_array * variation + 1 + smearing_up = ( - ak_random( - 1, rho + rho_unc, flat_seeds, - rand_func=self.deterministic_normal_up, + smearing_func( + random_normal_number(flat_seeds, rand_func=self.deterministic_normal_up), + rho + rho_unc, ) if self.deterministic_seed_index >= 0 - else ak_random(1, rho + rho_unc, rand_func=np.random.Generator( - np.random.SFC64(events.event.to_list())).normal, + else smearing_func( + random_normal_number(rand_func=np.random.Generator(np.random.SFC64(events.event.to_list())).normal), + rho + rho_unc, ) ) smearing_down = ( - ak_random( - 1, rho - rho_unc, flat_seeds, - rand_func=self.deterministic_normal_down, + smearing_func( + random_normal_number(flat_seeds, rand_func=self.deterministic_normal_down), + rho - rho_unc, ) if self.deterministic_seed_index >= 0 - else ak_random(1, rho - rho_unc, rand_func=np.random.Generator( - np.random.SFC64(events.event.to_list())).normal, + else smearing_func( + random_normal_number(rand_func=np.random.Generator(np.random.SFC64(events.event.to_list())).normal), + rho - rho_unc, ) ) @@ -373,9 +408,7 @@ def call_func( # save columns postfix = f"res_{direction}" - events = set_ak_column_f32( - events, f"{self.source_field}.pt_{postfix}", pt_varied, - ) + events = set_ak_column_f32(events, f"{self.source_field}.pt_{postfix}", pt_varied) # apply the nominal correction # note: changes are applied to the views and directly propagate to the original ak arrays @@ -395,7 +428,7 @@ def call_func( return events - def init_func(self) -> None: + def init_func(self, **kwargs) -> None: """Function to initialize the calibrator. Sets the required and produced columns for the calibrator. @@ -411,11 +444,10 @@ def init_func(self) -> None: } # add columns with unceratinties if requested - if self.with_uncertainties and getattr(self, "dataset_inst", None): - if self.dataset_inst.is_mc: - self.produces |= {f"{self.source_field}.pt_res_{{up,down}}"} + if self.with_uncertainties and self.dataset_inst.is_mc: + self.produces |= {f"{self.source_field}.pt_res_{{up,down}}"} - def requires_func(self, reqs: dict) -> None: + def requires_func(self, task: law.Task, reqs: dict[str, DotDict[str, Any]], **kwargs) -> None: """Function to add necessary requirements. This function add the :py:class:`~columnflow.tasks.external.BundleExternalFiles` @@ -423,10 +455,20 @@ def requires_func(self, reqs: dict) -> None: :param reqs: Dictionary of requirements. """ + if "external_files" in reqs: + return + from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) - def setup_func(self, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: + def setup_func( + self, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, + ) -> None: """Setup function before event chunk loop. This function loads the correction file and sets up the correction tool. @@ -439,24 +481,19 @@ def setup_func(self, reqs: dict, inputs: dict, reader_targets: InsertableDict) - :param reader_targets: Dictionary for optional additional columns to load (not used). """ - bundle = reqs["external_files"] - self.resolution_config = self.get_resolution_config() - + self.resolution_cfg = self.get_resolution_config() # create the egamma corrector - import correctionlib - correctionlib.highlevel.Correction.__call__ = correctionlib.highlevel.Correction.evaluate - correction_set = correctionlib.CorrectionSet.from_string( - self.get_correction_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) - self.resolution_corrector = correction_set[self.resolution_config.correction_set] - - # check versions - assert self.resolution_corrector.version in [0, 1, 2] + corr_file = self.get_correction_file(reqs["external_files"].files) + corr_set = load_correction_set(corr_file) + if self.resolution_cfg.compound: + corr_set = corr_set.compound + self.resolution_corrector = corr_set[self.resolution_cfg.correction_set] # use deterministic seeds for random smearing if requested if self.deterministic_seed_index >= 0: idx = self.deterministic_seed_index bit_generator = np.random.SFC64 + def deterministic_normal(loc, scale, seed, idx_offset=0): return np.asarray([ np.random.Generator(bit_generator(_seed)).normal(_loc, _scale, size=idx + 1 + idx_offset)[-1] @@ -513,10 +550,9 @@ def photons(self, events: ak.Array, **kwargs) -> ak.Array: return events -@photons.init -def photons_init(self) -> None: +@photons.pre_init +def photons_pre_init(self, **kwargs) -> None: # forward argument to the producers - if pec not in self.deps_kwargs: self.deps_kwargs[pec] = dict() if per not in self.deps_kwargs: @@ -588,10 +624,9 @@ def electrons(self, events: ak.Array, **kwargs) -> ak.Array: return events -@electrons.init -def electrons_init(self) -> None: +@electrons.pre_init +def electrons_pre_init(self, **kwargs) -> None: # forward argument to the producers - if eec not in self.deps_kwargs: self.deps_kwargs[eec] = dict() if eer not in self.deps_kwargs: diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index e7eb6b330..24462732f 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -12,7 +12,7 @@ from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import ak_random, propagate_met, sum_transverse from columnflow.production.util import attach_coffea_behavior -from columnflow.util import maybe_import, InsertableDict, DotDict +from columnflow.util import maybe_import, DotDict, load_correction_set from columnflow.columnar_util import set_ak_column, layout_ak_array, optional_column as optional np = maybe_import("numpy") @@ -476,7 +476,7 @@ def correct_jets(*, pt, eta, phi, area, rho, evaluator_key="jec"): @jec.init -def jec_init(self: Calibrator) -> None: +def jec_init(self: Calibrator, **kwargs) -> None: jec_cfg = self.get_jec_config() sources = self.uncertainty_sources @@ -513,16 +513,28 @@ def jec_init(self: Calibrator) -> None: @jec.requires -def jec_requires(self: Calibrator, reqs: dict) -> None: +def jec_requires( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: if "external_files" in reqs: return from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) @jec.setup -def jec_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: +def jec_setup( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: """ Load the correct jec files using the :py:func:`from_string` method of the :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` @@ -573,13 +585,9 @@ def jec_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: Insert :param inputs: Additional inputs, currently not used :param reader_targets: TODO: add documentation """ - bundle = reqs["external_files"] - # import the correction sets from the external file - import correctionlib - correction_set = correctionlib.CorrectionSet.from_string( - self.get_jec_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) + jec_file = self.get_jec_file(reqs["external_files"].files) + correction_set = load_correction_set(jec_file) # compute JEC keys from config information jec_cfg = self.get_jec_config() @@ -700,6 +708,8 @@ def get_jer_config_default(self: Calibrator) -> DotDict: jec_uncertainty_sources=None, # whether gen jet matching should be performed relative to the nominal jet pt, or the jec varied values gen_jet_matching_nominal=False, + # regions where stochastic smearing is applied + stochastic_smearing_mask=lambda self, jets: ak.ones_like(jets.pt, dtype=np.bool), ) def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: """ @@ -796,17 +806,17 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: # extract nominal pt resolution inputs = [variable_map[inp.name] for inp in self.evaluators["jer"].inputs] - jerpt = {jer_nom: ak_evaluate(self.evaluators["jer"], *inputs)} + jer = {jer_nom: ak_evaluate(self.evaluators["jer"], *inputs)} # for simplifications below, use the same values for jer variations - jerpt[jer_up] = jerpt[jer_nom] - jerpt[jer_down] = jerpt[jer_nom] + jer[jer_up] = jer[jer_nom] + jer[jer_down] = jer[jer_nom] # extract pt resolutions evaluted for jec uncertainties for jec_var in self.jec_variations: _variable_map = variable_map | {"JetPt": events[jet_name][f"pt_{jec_var}"]} inputs = [_variable_map[inp.name] for inp in self.evaluators["jer"].inputs] - jerpt[jec_var] = ak_evaluate(self.evaluators["jer"], *inputs) + jer[jec_var] = ak_evaluate(self.evaluators["jer"], *inputs) # extract scale factors jersf = {} @@ -823,8 +833,8 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: # array with all JER scale factor variations as an additional axis # (note: axis needs to be regular for broadcasting to work correctly) - jerpt = ak.concatenate( - [jerpt[v][..., None] for v in self.jer_variations + self.jec_variations], + jer = ak.concatenate( + [jer[v][..., None] for v in self.jer_variations + self.jec_variations], axis=-1, ) jersf = ak.concatenate( @@ -839,7 +849,11 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: add_smear = np.sqrt(ak.where(jersf2_m1 < 0, 0, jersf2_m1)) # compute smearing factors (stochastic method) - smear_factors_stochastic = 1.0 + random_normal * jerpt * add_smear + smear_factors_stochastic = ak.where( + self.stochastic_smearing_mask(events[jet_name]), + 1.0 + random_normal * jer * add_smear, + 1.0, + ) # -- scaling method (using gen match) @@ -869,7 +883,7 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: # test if matched gen jets are within 3 * resolution # (no check for Delta-R matching criterion; we assume this was done during nanoAOD production to get the genJetIdx) - is_matched_pt = np.abs(pt_relative_diff) < 3 * jerpt + is_matched_pt = np.abs(pt_relative_diff) < 3 * jer is_matched_pt = ak.fill_none(is_matched_pt, False) # masked values = no gen match # compute smearing factors (scaling method) @@ -943,8 +957,14 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: return events +jer_horn_handling = jer.derive("jer_horn_handling", cls_dict={ + # source: https://cms-jerc.web.cern.ch/Recommendations/#note-25eta30 + "stochastic_smearing_mask": lambda self, jets: (abs(jets.eta) < 2.5) | (abs(jets.eta) > 3.0), +}) + + @jer.init -def jer_init(self: Calibrator) -> None: +def jer_init(self: Calibrator, **kwargs) -> None: # add jec_cfg for applying nominal smearing to jec variations jec_cfg = self.get_jec_config() jec_sources = self.jec_uncertainty_sources @@ -991,16 +1011,28 @@ def jer_init(self: Calibrator) -> None: @jer.requires -def jer_requires(self: Calibrator, reqs: dict) -> None: +def jer_requires( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: if "external_files" in reqs: return from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) @jer.setup -def jer_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: +def jer_setup( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: """ Load the correct jer files using the :py:func:`from_string` method of the :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` function and apply the @@ -1033,13 +1065,9 @@ def jer_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: Insert :param inputs: Additional inputs, currently not used. :param reader_targets: TODO: add documentation. """ - bundle = reqs["external_files"] - # import the correction sets from the external file - import correctionlib - correction_set = correctionlib.CorrectionSet.from_string( - self.get_jer_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) + jer_file = self.get_jer_file(reqs["external_files"].files) + correction_set = load_correction_set(jer_file) # compute JER keys from config information jer_cfg = self.get_jer_config() @@ -1058,6 +1086,7 @@ def jer_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: Insert if self.deterministic_seed_index >= 0: idx = self.deterministic_seed_index bit_generator = np.random.SFC64 + def deterministic_normal(loc, scale, seed): return np.asarray([ np.random.Generator(bit_generator(_seed)).normal(_loc, _scale, size=idx + 1)[-1] @@ -1107,8 +1136,8 @@ def jets(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: return events -@jets.init -def jets_init(self: Calibrator) -> None: +@jets.pre_init +def jets_pre_init(self: Calibrator, **kwargs) -> None: # forward argument to the producers self.deps_kwargs[jec]["jet_name"] = self.jet_name self.deps_kwargs[jer]["jet_name"] = self.jet_name diff --git a/columnflow/calibration/cms/jets_coffea.py b/columnflow/calibration/cms/jets_coffea.py deleted file mode 100644 index e37b9ac18..000000000 --- a/columnflow/calibration/cms/jets_coffea.py +++ /dev/null @@ -1,728 +0,0 @@ -# coding: utf-8 - -""" -Calibration methods for jets using :external+coffea:doc:`coffea ` -functions. -""" - -import os -import functools - -import law - -from columnflow.util import maybe_import, memoize, InsertableDict -from columnflow.calibration import Calibrator, calibrator -from columnflow.calibration.util import propagate_met, ak_random -from columnflow.production.util import attach_coffea_behavior -from columnflow.columnar_util import set_ak_column -from typing import Iterable, Callable, Type - -np = maybe_import("numpy") -ak = maybe_import("awkward") - -coffea_extractor = maybe_import("coffea.lookup_tools.extractor") -coffea_jetmet_tools = maybe_import("coffea.jetmet_tools") -coffea_txt_converters = maybe_import("coffea.lookup_tools.txt_converters") - - -# -# first, some utility functions -# - -set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) - - -def get_basenames(struct: Iterable) -> Iterable: - """ - Replace full file paths in an arbitrary struct by the file basenames. - - The function loops through the structure and extracts the base name using a combination of - :py:func:`os.path.splitext` and :py:func:`os.path.basename`. The loop itself is done using the - :external+law:py:func:`law.util.map_struct` function. - - :param struct: Iterable of arbitrary nested structure containing full file paths - - :return: Iterable of same structure as *struct* containing only basenames of paths. - """ - return law.util.map_struct( - lambda p: os.path.splitext(os.path.basename(p[0] if isinstance(p, tuple) else p))[0], - struct, - ) - - -@memoize -def get_lookup_provider( - files: list, - conversion_func: Callable, - provider_cls: Type, - names: list[str or tuple[str, str]] = None, -) -> Type: - """ - Create a coffea helper object for looking up information in files of various formats. - - This function reads in the *files* containing lookup tables (e.g. JEC text files), extracts the - table of values ("weights") using the conversion function *conversion_func* implemented in - coffea, and uses them to construct a helper object of type *provider_cls* that can be passed - event data to yield the lookup values (e.g. a - :external+coffea:py:class:`~coffea.jetmet_tools.FactorizedJetCorrector` or - :external+coffea:py:class:`~coffea.jetmet_tools.JetCorrectionUncertainty`). - - Optionally, a list of *names* can be supplied to select only a subset of weight tables for - constructing the provider object (the default is to use all of them). This is intended to be - useful for e.g. selecting only a particular set of jet energy uncertainties from an - "UncertaintySources" file. By convention, the *names* always start with the basename of the file - that contains the corresponding weight table. - - Entries in *names* may also be tuples of the form (*src_name*, *dst_name*), in which case the - *src_name* will be replaced by *dst_name* when passing the names to the *provider_cls*. - - The user must ensure that the *files* can be parsed by the *conversion_func* supplied, and that - the information contained in the files is meaningful in connection with the *provider_cls*. - - :param files: List of files containing lookup tables (e.g. JEC text files). - :param conversion_func: ``Callable`` that extracts the table of weights from the files in - *files*. Must return an *Iterable* that provides a :py:meth:`items` method that returns a - structure like (name, type), value - :param provider_cls: Class method that is used to construct the *provider* instance that finally - provides the weights for the events. Examples: - :external+coffea:py:class:`~coffea.jetmet_tools.FactorizedJetCorrector`, - :external+coffea:py:class:`~coffea.jetmet_tools.JetCorrectionUncertainty` - :param names: Optional list of weight names to include, see text above. - :raises ValueError: If *names* contains weight names that are not present in the source file - :return: helper class that provides the weights for the events of same type as *provider_cls* - (e.g. - :external+coffea:py:class:`~coffea.jetmet_tools.FactorizedJetCorrector`, - :external+coffea:py:class:`~coffea.jetmet_tools.JetCorrectionUncertainty`) - """ - # the extractor reads the information contained in the files - extractor = coffea_extractor.extractor() - - # files contain one or more lookup tables, each identified by a name - all_names = [] - for file_ in files: - # the actual file parsing is done here - weights = conversion_func(file_) - for (name, type_), value in weights.items(): - extractor.add_weight_set(name, type_, value) - all_names.append(name) - - extractor.finalize() - - # if user provided explicit names, check that corresponding - # weight tables have been read - if names is not None: - src_dst_names = [n if isinstance(n, tuple) else (n, n) for n in names] - unknown_names = set(src_name for src_name, _ in src_dst_names) - set(all_names) - if unknown_names: - unknown_names = ", ".join(sorted(list(unknown_names))) - available = ", ".join(sorted(list(all_names))) - raise ValueError( - f"no weight tables found for the following names: {unknown_names}, " - f"available: {available}", - ) - # TODO: I don't think the code works correctly if *names* is a list of - # strings, since further down below the code explicitly needs a tuple - # structure. We will probably need something like the following here - - # names = src_dst_names - else: - names = [(n, n) for n in all_names] - - # the evaluator does the actual lookup for each separate name - evaluator = extractor.make_evaluator() - - # the provider combines lookup results from multiple names - provider = provider_cls(**{ - dst_name: evaluator[src_name] - for src_name, dst_name in names - }) - - return provider - - -# -# Jet energy corrections -# - -@calibrator( - uses={ - "Jet.pt", "Jet.eta", "Jet.phi", "Jet.mass", "Jet.area", "Jet.rawFactor", - "Jet.jetId", - "Rho.fixedGridRhoFastjetAll", "fixedGridRhoFastjetAll", - attach_coffea_behavior, - }, - produces={ - "Jet.pt", "Jet.mass", "Jet.rawFactor", - }, - # custom uncertainty sources, defaults to config when empty - uncertainty_sources=None, - # toggle for propagation to MET - propagate_met=True, -) -def jec_coffea( - self: Calibrator, - events: ak.Array, - min_pt_met_prop: float = 15.0, - max_eta_met_prop: float = 5.2, - **kwargs, -) -> ak.Array: - """ - Apply jet energy corrections and calculate shifts for jet energy uncertainty sources. - - :param events: awkward array containing events to process - :param min_pt_met_prop: If *propagate_met* variable is ``True`` propagate the updated jet values - to the missing transverse energy (MET) using - :py:func:`~columnflow.calibration.util.propagate_met` for events where - ``met.pt > min_pt_met_prop``. - :param max_eta_met_prop: If *propagate_met* variable is ``True`` propagate the updated jet - values to the missing transverse energy (MET) using - :py:func:`~columnflow.calibration.util.propagate_met` for events where - ``met.eta > min_eta_met_prop``. - """ - # calculate uncorrected pt, mass - events = set_ak_column_f32(events, "Jet.pt_raw", events.Jet.pt * (1 - events.Jet.rawFactor)) - events = set_ak_column_f32(events, "Jet.mass_raw", events.Jet.mass * (1 - events.Jet.rawFactor)) - - # build/retrieve lookup providers for JECs and uncertainties - # NOTE: could also be moved to `jec_setup`, but keep here in case the provider ever needs - # to change based on the event content (JEC change in the middle of a run) - jec_provider = get_lookup_provider( - self.jec_files, - coffea_txt_converters.convert_jec_txt_file, - coffea_jetmet_tools.FactorizedJetCorrector, - names=self.jec_names, - ) - jec_provider_only_l1 = get_lookup_provider( - self.jec_files_only_l1, - coffea_txt_converters.convert_jec_txt_file, - coffea_jetmet_tools.FactorizedJetCorrector, - names=self.jec_names_only_l1, - ) - if self.junc_names: - junc_provider = get_lookup_provider( - self.junc_files, - coffea_txt_converters.convert_junc_txt_file, - coffea_jetmet_tools.JetCorrectionUncertainty, - names=self.junc_names, - ) - - # obtain rho, which might be located at different routes, depending on the nano version - rho = ( - events.fixedGridRhoFastjetAll - if "fixedGridRhoFastjetAll" in events.fields else - events.Rho.fixedGridRhoFastjetAll - ) - - # look up JEC correction factors - jec_factors = jec_provider.getCorrection( - JetEta=events.Jet.eta, - JetPt=events.Jet.pt_raw, - JetA=events.Jet.area, - Rho=rho, - ) - jec_factors_only_l1 = jec_provider_only_l1.getCorrection( - JetEta=events.Jet.eta, - JetPt=events.Jet.pt_raw, - JetA=events.Jet.area, - Rho=rho, - ) - - # apply the new factors with only L1 corrections - events = set_ak_column_f32(events, "Jet.pt", events.Jet.pt_raw * jec_factors_only_l1) - events = set_ak_column_f32(events, "Jet.mass", events.Jet.mass_raw * jec_factors_only_l1) - events = self[attach_coffea_behavior](events, collections=["Jet"], **kwargs) - - # store pt and phi of the full jet system for MET propagation, including a selection in raw info - # see https://twiki.cern.ch/twiki/bin/view/CMS/JECAnalysesRecommendations?rev=19#Minimum_jet_selection_cuts - if self.propagate_met: - met_prop_mask = (events.Jet.pt_raw > min_pt_met_prop) & (abs(events.Jet.eta) < max_eta_met_prop) - jetsum = events.Jet[met_prop_mask].sum(axis=1) - jet_pt_only_l1 = jetsum.pt - jet_phi_only_l1 = jetsum.phi - - # full jet correction with all levels - events = set_ak_column_f32(events, "Jet.pt", events.Jet.pt_raw * jec_factors) - events = set_ak_column_f32(events, "Jet.mass", events.Jet.mass_raw * jec_factors) - events = set_ak_column_f32(events, "Jet.rawFactor", (1 - events.Jet.pt_raw / events.Jet.pt)) - events = self[attach_coffea_behavior](events, collections=["Jet"], **kwargs) - - # nominal met propagation - if self.propagate_met: - # get pt and phi of all jets after correcting - jetsum = events.Jet[met_prop_mask].sum(axis=1) - jet_pt_all_levels = jetsum.pt - jet_phi_all_levels = jetsum.phi - - # propagate changes from L2 corrections and onwards (i.e. no L1) to MET - met_pt, met_phi = propagate_met( - jet_pt_only_l1, - jet_phi_only_l1, - jet_pt_all_levels, - jet_phi_all_levels, - events.RawMET.pt, - events.RawMET.phi, - ) - events = set_ak_column_f32(events, "MET.pt", met_pt) - events = set_ak_column_f32(events, "MET.phi", met_phi) - - # look up JEC uncertainties - if self.junc_names: - jec_uncertainties = junc_provider.getUncertainty( - JetEta=events.Jet.eta, - JetPt=events.Jet.pt_raw, - ) - for name, jec_unc_factors in jec_uncertainties: - # jec_unc_factors[I_EVT][I_JET][I_VAR] - events = set_ak_column_f32(events, f"Jet.pt_jec_{name}_up", events.Jet.pt * jec_unc_factors[:, :, 0]) - events = set_ak_column_f32(events, f"Jet.pt_jec_{name}_down", events.Jet.pt * jec_unc_factors[:, :, 1]) - events = set_ak_column_f32(events, f"Jet.mass_jec_{name}_up", events.Jet.mass * jec_unc_factors[:, :, 0]) - events = set_ak_column_f32(events, f"Jet.mass_jec_{name}_down", events.Jet.mass * jec_unc_factors[:, :, 1]) - - # shifted MET propagation - if self.propagate_met: - jet_pt_up = events.Jet[met_prop_mask][f"pt_jec_{name}_up"] - jet_pt_down = events.Jet[met_prop_mask][f"pt_jec_{name}_down"] - met_pt_up, met_phi_up = propagate_met( - jet_pt_all_levels, - jet_phi_all_levels, - jet_pt_up, - events.Jet[met_prop_mask].phi, - met_pt, - met_phi, - ) - met_pt_down, met_phi_down = propagate_met( - jet_pt_all_levels, - jet_phi_all_levels, - jet_pt_down, - events.Jet[met_prop_mask].phi, - met_pt, - met_phi, - ) - events = set_ak_column_f32(events, f"MET.pt_jec_{name}_up", met_pt_up) - events = set_ak_column_f32(events, f"MET.pt_jec_{name}_down", met_pt_down) - events = set_ak_column_f32(events, f"MET.phi_jec_{name}_up", met_phi_up) - events = set_ak_column_f32(events, f"MET.phi_jec_{name}_down", met_phi_down) - - return events - - -@jec_coffea.init -def jec_coffea_init(self: Calibrator) -> None: - sources = self.uncertainty_sources - if sources is None: - sources = self.config_inst.x.jec.uncertainty_sources - - # add shifted jet variables - self.produces |= { - f"Jet.{shifted_var}_jec_{junc_name}_{junc_dir}" - for shifted_var in ("pt", "mass") - for junc_name in sources - for junc_dir in ("up", "down") - } - - # add MET variables - if self.propagate_met: - self.uses |= {"RawMET.pt", "RawMET.phi"} - self.produces |= {"MET.pt", "MET.phi"} - - # add shifted MET variables - self.produces |= { - f"MET.{shifted_var}_jec_{junc_name}_{junc_dir}" - for shifted_var in ("pt", "phi") - for junc_name in sources - for junc_dir in ("up", "down") - } - - -@jec_coffea.requires -def jec_coffea_requires(self: Calibrator, reqs: dict) -> None: - if "external_files" in reqs: - return - - from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) - - -@jec_coffea.setup -def jec_coffea_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: - """ - Determine correct JEC files for task based on config/dataset and inject them into the calibrator - function call. - - :param reqs: Requirement dictionary for this :py:class:`~columnflow.calibration.Calibrator` - instance. - :param inputs: Additional inputs, currently not used. - :param reader_targets: TODO: add docs - - :raises ValueError: If module is provided with more than one JEC uncertainty source file. - """ - # get external files bundle that contains JEC text files - bundle = reqs["external_files"] - - # make selector for JEC text files based on sample type (and era for data) - if self.dataset_inst.is_data: - jec_era = self.dataset_inst.get_aux("jec_era", None) - # if no special JEC era is specified, infer based on 'era' - if jec_era is None: - jec_era = "Run" + self.dataset_inst.get_aux("era") - - resolve_samples = lambda x: x.data[jec_era] - else: - resolve_samples = lambda x: x.mc - - # store jec files with all correction levels - self.jec_files = [ - t.path - for t in resolve_samples(bundle.files.jec).values() - ] - self.jec_names = list(zip( - get_basenames(self.jec_files), - get_basenames(resolve_samples(self.config_inst.x.external_files.jec).values()), - )) - - # store jec files with only L1* corrections for MET propagation - self.jec_files_only_l1 = [ - t.path - for level, t in resolve_samples(bundle.files.jec).items() - if level.startswith("L1") - ] - self.jec_names_only_l1 = list(zip( - get_basenames(self.jec_files_only_l1), - get_basenames([ - src - for level, src in resolve_samples(self.config_inst.x.external_files.jec).items() - if level.startswith("L1") - ]), - )) - - # store uncertainty - self.junc_files = [ - t.path - for t in resolve_samples(bundle.files.junc) - ] - self.junc_names = list(zip( - get_basenames(self.junc_files), - get_basenames(resolve_samples(self.config_inst.x.external_files.junc)), - )) - - # ensure exactly one 'UncertaintySources' file is passed - if len(self.junc_names) != 1: - raise ValueError( - f"expected exactly one 'UncertaintySources' file, got {len(self.junc_names)}", - ) - - sources = self.uncertainty_sources - if sources is None: - sources = self.config_inst.x.jec.uncertainty_sources - - # update the weight names to include the uncertainty sources specified in the config - self.junc_names = [ - (f"{basename}_{src}", f"{orig_basename}_{src}") - for basename, orig_basename in self.junc_names - for src in sources - ] - - -# custom jec calibrator that only runs nominal correction -jec_coffea_nominal = jec_coffea.derive("jec_coffea_nominal", cls_dict={"uncertainty_sources": []}) - - -# -# Jet energy resolution smearing -# - -@calibrator( - uses={ - "Jet.pt", "Jet.eta", "Jet.phi", "Jet.mass", "Jet.genJetIdx", - "Rho.fixedGridRhoFastjetAll", "fixedGridRhoFastjetAll", - "GenJet.pt", "GenJet.eta", "GenJet.phi", - "MET.pt", "MET.phi", - attach_coffea_behavior, - }, - produces={ - "Jet.pt", "Jet.mass", - "Jet.pt_unsmeared", "Jet.mass_unsmeared", - "Jet.pt_jer_up", "Jet.pt_jer_down", "Jet.mass_jer_up", "Jet.mass_jer_down", - "MET.pt", "MET.phi", - "MET.pt_jer_up", "MET.pt_jer_down", "MET.phi_jer_up", "MET.phi_jer_down", - }, - # toggle for propagation to MET - propagate_met=True, - # only run on mc - mc_only=True, -) -def jer_coffea(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: - """ - Apply jet energy resolution smearing and calculate shifts for Jet Energy Resolution (JER) scale - factor variations. - - Follows the recommendations given in https://twiki.cern.ch/twiki/bin/viewauth/CMS/JetResolution. - - The module applies the scale factors associated to the JER and performs the stochastic smearing - to make the energy resolution in simulation more realistic. - - :param events: awkward array containing events to process - """ - # save the unsmeared properties in case they are needed later - events = set_ak_column_f32(events, "Jet.pt_unsmeared", events.Jet.pt) - events = set_ak_column_f32(events, "Jet.mass_unsmeared", events.Jet.mass) - - # use event numbers in chunk to seed random number generator - # TODO: use seeds! - rand_gen = np.random.Generator(np.random.SFC64(events.event.to_list())) - - # obtain rho, which might be located at different routes, depending on the nano version - rho = ( - events.fixedGridRhoFastjetAll - if "fixedGridRhoFastjetAll" in events.fields else - events.Rho.fixedGridRhoFastjetAll - ) - - # build/retrieve lookup providers for JECs and uncertainties - # NOTE: could also be moved to `jer_setup`, but keep here in case the provider ever needs - # to change based on the event content (JER change in the middle of a run) - jer_provider = get_lookup_provider( - self.jer_files, - coffea_txt_converters.convert_jr_txt_file, - coffea_jetmet_tools.JetResolution, - names=self.jer_names, - ) - jersf_provider = get_lookup_provider( - self.jersf_files, - coffea_txt_converters.convert_jersf_txt_file, - coffea_jetmet_tools.JetResolutionScaleFactor, - names=self.jersf_names, - ) - - # look up jet energy resolutions - # jer[I_EVT][I_JET] - jer = ak.materialized(jer_provider.getResolution( - JetEta=events.Jet.eta, - JetPt=events.Jet.pt, - Rho=rho, - )) - - # look up jet energy resolution scale factors - # jersf[I_EVT][I_JET][I_VAR] - jersf = jersf_provider.getScaleFactor( - JetEta=events.Jet.eta, - JetPt=events.Jet.pt, - ) - - # -- stochastic smearing - - # normally distributed random numbers according to JER - jer_random_normal = ak_random(0, jer, rand_func=rand_gen.normal) - - # scale random numbers according to JER SF - jersf2_m1 = jersf ** 2 - 1 - add_smear = np.sqrt(ak.where(jersf2_m1 < 0, 0, jersf2_m1)) - - # broadcast over JER SF variations - jer_random_normal, jersf_z = ak.broadcast_arrays(jer_random_normal, add_smear) - - # compute smearing factors (stochastic method) - smear_factors_stochastic = 1.0 + jer_random_normal * add_smear - - # -- scaling method (using gen match) - - # mask negative gen jet indices (= no gen match) - valid_gen_jet_idxs = ak.mask(events.Jet.genJetIdx, events.Jet.genJetIdx >= 0) - - # pad list of gen jets to prevent index error on match lookup - padded_gen_jets = ak.pad_none(events.GenJet, ak.max(valid_gen_jet_idxs) + 1) - - # gen jets that match the reconstructed jets - matched_gen_jets = padded_gen_jets[valid_gen_jet_idxs] - - # compute the relative (reco - gen) pt difference - pt_relative_diff = (events.Jet.pt - matched_gen_jets.pt) / events.Jet.pt - - # test if matched gen jets are within 3 * resolution - is_matched_pt = np.abs(pt_relative_diff) < 3 * jer - is_matched_pt = ak.fill_none(is_matched_pt, False) # masked values = no gen match - - # (no check for Delta-R matching criterion; we assume this was done during - # nanoAOD production to get the `genJetIdx`) - - # broadcast over JER SF variations - pt_relative_diff, jersf = ak.broadcast_arrays(pt_relative_diff, jersf) - - # compute smearing factors (scaling method) - smear_factors_scaling = 1.0 + (jersf - 1.0) * pt_relative_diff - - # -- hybrid smearing: take smear factors from scaling if there was a match, - # otherwise take the stochastic ones - smear_factors = ak.where( - is_matched_pt[:, :, None], - smear_factors_scaling, - smear_factors_stochastic, - ) - - # ensure array is not nullable (avoid ambiguity on Arrow/Parquet conversion) - smear_factors = ak.fill_none(smear_factors, 0.0) - - # store pt and phi of the full jet system - if self.propagate_met: - jetsum = events.Jet.sum(axis=1) - jet_pt_before = jetsum.pt - jet_phi_before = jetsum.phi - - # apply the smearing factors to the pt and mass - # (note: apply variations first since they refer to the original pt) - events = set_ak_column_f32(events, "Jet.pt_jer_up", events.Jet.pt * smear_factors[:, :, 1]) - events = set_ak_column_f32(events, "Jet.mass_jer_up", events.Jet.mass * smear_factors[:, :, 1]) - events = set_ak_column_f32(events, "Jet.pt_jer_down", events.Jet.pt * smear_factors[:, :, 2]) - events = set_ak_column_f32(events, "Jet.mass_jer_down", events.Jet.mass * smear_factors[:, :, 2]) - events = set_ak_column_f32(events, "Jet.pt", events.Jet.pt * smear_factors[:, :, 0]) - events = set_ak_column_f32(events, "Jet.mass", events.Jet.mass * smear_factors[:, :, 0]) - - # recover coffea behavior - events = self[attach_coffea_behavior](events, collections=["Jet"], **kwargs) - - # met propagation - if self.propagate_met: - # save unsmeared quantities - events = set_ak_column_f32(events, "MET.pt_unsmeared", events.MET.pt) - events = set_ak_column_f32(events, "MET.phi_unsmeared", events.MET.phi) - - # get pt and phi of all jets after correcting - jetsum = events.Jet.sum(axis=1) - jet_pt_after = jetsum.pt - jet_phi_after = jetsum.phi - - # propagate changes to MET - met_pt, met_phi = propagate_met( - jet_pt_before, - jet_phi_before, - jet_pt_after, - jet_phi_after, - events.MET.pt, - events.MET.phi, - ) - met_pt_up, met_phi_up = propagate_met( - jet_pt_after, - jet_phi_after, - events.Jet.pt_jer_up, - events.Jet.phi, - met_pt, - met_phi, - ) - met_pt_down, met_phi_down = propagate_met( - jet_pt_after, - jet_phi_after, - events.Jet.pt_jer_down, - events.Jet.phi, - met_pt, - met_phi, - ) - events = set_ak_column_f32(events, "MET.pt", met_pt) - events = set_ak_column_f32(events, "MET.phi", met_phi) - events = set_ak_column_f32(events, "MET.pt_jer_up", met_pt_up) - events = set_ak_column_f32(events, "MET.pt_jer_down", met_pt_down) - events = set_ak_column_f32(events, "MET.phi_jer_up", met_phi_up) - events = set_ak_column_f32(events, "MET.phi_jer_down", met_phi_down) - - return events - - -@jer_coffea.init -def jer_coffea_init(self: Calibrator) -> None: - if not self.propagate_met: - return - - self.uses |= { - "MET.pt", "MET.phi", - } - self.produces |= { - "MET.pt", "MET.phi", "MET.pt_jer_up", "MET.pt_jer_down", "MET.phi_jer_up", - "MET.phi_jer_down", "MET.pt_unsmeared", "MET.phi_unsmeared", - } - - -@jer_coffea.requires -def jer_coffea_requires(self: Calibrator, reqs: dict) -> None: - if "external_files" in reqs: - return - - from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) - - -@jer_coffea.setup -def jer_coffea_setup( - self: Calibrator, - reqs: dict, inputs: dict, - reader_targets: InsertableDict, -) -> None: - """ - Determine correct JER files for task based on config/dataset and inject them into the calibrator - function call. - - :param reqs: Requirement dictionary for this :py:class:`~columnflow.calibration.Calibrator` - instance. - :param inputs: Additional inputs, currently not used. - :param reader_targets: TODO: add docs. - - :raises ValueError: If module is provided with more than one JER uncertainty source file. - """ - # get external files bundle that contains JR text files - bundle = reqs["external_files"] - - resolve_sample = lambda x: x.mc - - # pass text files to calibrator method - for key in ("jer", "jersf"): - # pass the paths to the text files that contain the corrections/uncertainties - files = [ - t.path for t in resolve_sample(bundle.files[key]) - ] - setattr(self, f"{key}_files", files) - - # also pass a list of tuples encoding the correspondence between the - # file basenames on disk (as determined by `BundleExternalFiles`) and the - # original file basenames (needed by coffea to identify the weights correctly) - basenames = get_basenames(files) - orig_basenames = get_basenames(resolve_sample(self.config_inst.x.external_files[key])) - setattr(self, f"{key}_names", list(zip(basenames, orig_basenames))) - - # ensure exactly one file is passed - if len(files) != 1: - raise ValueError( - f"Expected exactly one file for key '{key}', got {len(files)}.", - ) - - -# -# General jets calibrator -# - -@calibrator( - uses={jec_coffea, jer_coffea}, - produces={jec_coffea, jer_coffea}, - # toggle for propagation to MET - propagate_met=True, -) -def jets_coffea(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: - """ - Instance of :py:class:`~columnflow.calibration.Calibrator` that does all relevant calibrations - for jets, i.e. JEC and JER. For more information, see :py:class:`~.jec_coffea` and - :py:class:`~.jer_coffea`. - - :param events: awkward array containing events to process. - """ - # apply jet energy corrections - events = self[jec_coffea](events, **kwargs) - - # apply jer smearing on MC only - if self.dataset_inst.is_mc: - events = self[jer_coffea](events, **kwargs) - - return events - - -@jets_coffea.init -def jets_coffea_init(self: Calibrator) -> None: - # forward the propagate_met argument to the producers - self.deps_kwargs[jec_coffea] = {"propagate_met": self.propagate_met} - self.deps_kwargs[jer_coffea] = {"propagate_met": self.propagate_met} diff --git a/columnflow/calibration/cms/met.py b/columnflow/calibration/cms/met.py index 01b6ea9ef..229b4c9cb 100644 --- a/columnflow/calibration/cms/met.py +++ b/columnflow/calibration/cms/met.py @@ -4,9 +4,12 @@ MET corrections. """ +import law + from columnflow.calibration import Calibrator, calibrator -from columnflow.util import maybe_import +from columnflow.util import maybe_import, load_correction_set, DotDict from columnflow.columnar_util import set_ak_column +from columnflow.types import Any np = maybe_import("numpy") ak = maybe_import("awkward") @@ -80,7 +83,7 @@ def met_phi(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: @met_phi.init -def met_phi_init(self: Calibrator) -> None: +def met_phi_init(self: Calibrator, **kwargs) -> None: """ Initialize the :py:attr:`met_pt_corrector` and :py:attr:`met_phi_corrector` attributes. """ @@ -89,16 +92,28 @@ def met_phi_init(self: Calibrator) -> None: @met_phi.requires -def met_phi_requires(self: Calibrator, reqs: dict) -> None: +def met_phi_requires( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: if "external_files" in reqs: return from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) @met_phi.setup -def met_phi_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: dict) -> None: +def met_phi_setup( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: """ Load the correct met files using the :py:func:`from_string` method of the :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` @@ -109,13 +124,10 @@ def met_phi_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: di :param inputs: Additional inputs, currently not used. :param reader_targets: Additional targets, currently not used. """ - bundle = reqs["external_files"] - # create the pt and phi correctors - import correctionlib - correction_set = correctionlib.CorrectionSet.from_string( - self.get_met_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) + met_file = self.get_met_file(reqs["external_files"].files) + correction_set = load_correction_set(met_file) + name_tmpl = self.get_met_config() self.met_pt_corrector = correction_set[name_tmpl.format( variable="pt", diff --git a/columnflow/calibration/cms/tau.py b/columnflow/calibration/cms/tau.py index 292064e35..b9ada41ef 100644 --- a/columnflow/calibration/cms/tau.py +++ b/columnflow/calibration/cms/tau.py @@ -10,9 +10,11 @@ import itertools from dataclasses import dataclass, field +import law + from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import propagate_met -from columnflow.util import maybe_import, InsertableDict +from columnflow.util import maybe_import, load_correction_set, DotDict from columnflow.columnar_util import set_ak_column, flat_np_view, ak_copy from columnflow.types import Any @@ -31,10 +33,7 @@ class TECConfig: corrector_kwargs: dict[str, Any] = field(default_factory=dict) @classmethod - def new( - cls, - obj: TECConfig | tuple[str] | dict[str, str], - ) -> TECConfig: + def new(cls, obj: TECConfig | tuple[str] | dict[str, str]) -> TECConfig: # purely for backwards compatibility with the old tuple format that accepted the two # working point values if isinstance(obj, tuple) and len(obj) == 2: @@ -119,8 +118,8 @@ def tec( "eta": eta[dm_mask], "dm": dm[dm_mask], "genmatch": match[dm_mask], - "id": self.tec_config.tagger, - **self.tec_config.corrector_kwargs, + "id": self.tec_cfg.tagger, + **self.tec_cfg.corrector_kwargs, } args = tuple( variable_map[inp.name] for inp in self.tec_corrector.inputs @@ -210,8 +209,8 @@ def tec( @tec.init -def tec_init(self: Calibrator) -> None: - self.tec_config: TECConfig = self.get_tec_config() +def tec_init(self: Calibrator, **kwargs) -> None: + self.tec_cfg = self.get_tec_config() # add nominal met columns of propagating nominal tec if self.propagate_met: @@ -226,33 +225,37 @@ def tec_init(self: Calibrator) -> None: src_fields += [f"{self.met_name}.{var}" for var in ["pt", "phi"]] self.produces |= { - f"{field}_tec_{match}_dm{dm}_{direction}" - for field, match, dm, direction in itertools.product( - src_fields, - ["jet", "e"], - [0, 1, 10, 11], - ["up", "down"], - ) + f"{field}_tec_{{jet,e}}_dm{{0,1,10,11}}_{{up,down}}" + for field in src_fields } @tec.requires -def tec_requires(self: Calibrator, reqs: dict) -> None: +def tec_requires( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: + if "external_files" in reqs: + return + from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) @tec.setup -def tec_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: - bundle = reqs["external_files"] - +def tec_setup( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: # create the tec corrector - import correctionlib - correctionlib.highlevel.Correction.__call__ = correctionlib.highlevel.Correction.evaluate - correction_set = correctionlib.CorrectionSet.from_string( - self.get_tau_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) - self.tec_corrector = correction_set[self.tec_config.correction_set] + tau_file = self.get_tau_file(reqs["external_files"].files) + self.tec_corrector = load_correction_set(tau_file)[self.tec_cfg.correction_set] # check versions assert self.tec_corrector.version in [0, 1] diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index 99112d3e2..f9155b7be 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -23,12 +23,11 @@ import law import order as od -from law.util import InsertableDict from columnflow.types import Sequence, Callable, Any, T, Generator from columnflow.util import ( UNSET, maybe_import, classproperty, DotDict, DerivableMeta, Derivable, pattern_matcher, - get_source_code, real_path, + get_source_code, real_path, freeze, get_docs_url, ) np = maybe_import("numpy") @@ -586,6 +585,7 @@ class ColumnCollection(enum.Flag): ALL_FROM_CALIBRATOR = enum.auto() ALL_FROM_CALIBRATORS = enum.auto() ALL_FROM_SELECTOR = enum.auto() + ALL_FROM_REDUCER = enum.auto() ALL_FROM_PRODUCER = enum.auto() ALL_FROM_PRODUCERS = enum.auto() ALL_FROM_ML_EVALUATION = enum.auto() @@ -1356,13 +1356,13 @@ def ak_copy(ak_array: ak.Array) -> ak.Array: class RouteFilter(object): """ - Shallow helper class that handles removal of routes in an awkward array that do not match those - in *keep_routes*. Each route can either be a :py:class:`Route` instance, or anything that is - accepted by its constructor. Example: + Shallow helper class that handles the filtering of routes in an awkward array that do (not) match those in + *remove_routes* (*keep_routes*). Each route can either be a :py:class:`Route` instance, or anything that is accepted + by its constructor. Example: .. code-block:: python - route_filter = RouteFilter(["Jet.pt", "Jet.eta"]) + route_filter = RouteFilter(keep=["Jet.pt", "Jet.eta"]) events = route_filter(events) print(get_ak_routes(events)) @@ -1371,42 +1371,60 @@ class RouteFilter(object): # "Jet.eta", # ] - .. py:attribute:: keep_routes + .. py:attribute:: keep type: list The routes to keep. - .. py:attribute:: remove_routes + .. py:attribute:: remove - type: None, set + type: list - A set of :py:class:`Route` instances that are removed, defined after the first call to this - instance. + The routes to remove. Evaluated after routes to keep. """ - def __init__(self, keep_routes: Sequence[Route | str]): + def __init__( + self, + *, + keep: Sequence[Route | str] | None = None, + remove: Sequence[Route | str] | None = None, + cache: bool = True, + ) -> None: super().__init__() - self.keep_routes = list(keep_routes) - self.remove_routes = None + self.keep = list(keep) if keep else [] + self.remove = list(remove) if remove else [] + self.cache = cache + + self._remove_routes = None def __call__(self, ak_array: ak.Array) -> ak.Array: - # manually remove colums that should not be kept - if self.remove_routes is None: - # convert routes to keep into string columns for pattern checks - keep_columns = [Route(route).column for route in self.keep_routes] + # potentially get routes to remove from cache + remove_routes = self._remove_routes if self.cache else None + + # determine when empty + if remove_routes is None: + # convert routes to strings for pattern checks + keep_columns = [Route(route).column for route in self.keep] + remove_columns = [Route(route).column for route in self.remove] # determine routes to remove - self.remove_routes = { - route - for route in get_ak_routes(ak_array) - if not law.util.multi_match(route.column, keep_columns) + remove_routes = { + r for r in get_ak_routes(ak_array) + if ( + (keep_columns and not law.util.multi_match(r.column, keep_columns)) or + (remove_columns and law.util.multi_match(r.column, remove_columns)) + ) } # apply the filtering - for route in self.remove_routes: - ak_array = remove_ak_column(ak_array, route) + for r in remove_routes: + ak_array = remove_ak_column(ak_array, r) + + # add routes to cache for next call + if self.cache and self._remove_routes is None: + self._remove_routes = remove_routes return ak_array @@ -1464,9 +1482,11 @@ def call_func(self, events): instance-level, the full sets of :py:attr:`used_columns` and :py:attr:`produced_columns` are simply resolvable through attributes. - *call_func* defines the function being invoked when the instance is *called*. An additional - initialization function can be wrapped through a decorator (similiar to ``property`` setters) as - shown in the example below. They constitute a mechanism to update the :py:attr:`uses` and + *call_func* defines the function being invoked when the instance is *called*. Additional + initialization functions can be registered through a decorator (similiar to ``property`` + setters) as shown in the example below. ``pre_init`` is invoked prior to any dependecy + registration. One use case is the forwarding of dependency keyword arguments via + :py:attr:`deps_kwargs`. ``init`` constitutes a mechanism to update the :py:attr:`uses` and :py:attr:`produces` sets to declare dependencies in a more dynamic way. .. code-block:: python @@ -1594,6 +1614,12 @@ def my_other_func_init(self): The wrapped function to be called on arrays. + .. py:attribute: pre_init_func + + type: callable + + The registered function called before any dependency initialization, or *None*. + .. py:attribute: init_func type: callable @@ -1610,6 +1636,7 @@ def my_other_func_init(self): # class-level attributes as defaults call_func = None + pre_init_func = None init_func = None skip_func = None @@ -1725,6 +1752,17 @@ def init(cls, func: Callable[[], None]) -> None: """ cls.init_func = func + @classmethod + def pre_init(cls, func: Callable[[], None]) -> None: + """ + Decorator to wrap a function *func* that should be registered as :py:meth:`pre_init_func` + which is invoked prior to any dependency creation. The function should not accept positional + arguments. + + The decorator does not return the wrapped function. + """ + cls.pre_init_func = func + @classmethod def skip(cls, func: Callable[[], bool]) -> None: """ @@ -1739,6 +1777,7 @@ def skip(cls, func: Callable[[], bool]) -> None: def __init__( self, call_func: Callable | None = law.no_value, + pre_init_func: Callable | None = law.no_value, init_func: Callable | None = law.no_value, skip_func: Callable | None = law.no_value, check_used_columns: bool | None = None, @@ -1753,6 +1792,8 @@ def __init__( # add class-level attributes as defaults for unset arguments (no_value) if call_func == law.no_value: call_func = self.__class__.call_func + if pre_init_func == law.no_value: + pre_init_func = self.__class__.pre_init_func if init_func == law.no_value: init_func = self.__class__.init_func if skip_func == law.no_value: @@ -1767,6 +1808,8 @@ def __init__( # when a custom funcs are passed, bind them to this instance if call_func: self.call_func = call_func.__get__(self, self.__class__) + if pre_init_func: + self.pre_init_func = pre_init_func.__get__(self, self.__class__) if init_func: self.init_func = init_func.__get__(self, self.__class__) if skip_func: @@ -1808,8 +1851,8 @@ def __init__( # dictionary of dependency class to instance, set in create_dependencies self.deps = DotDict() - # dictionary of keyword arguments mapped to dependenc classes to be forwarded to their init - self.deps_kwargs = defaultdict(dict) # TODO: avoid using `defaultdict` + # dictionary of keyword arguments mapped to dependent classes to are forwarded to their init + self.deps_kwargs = defaultdict(DotDict) # deferred part of the initialization if deferred_init: @@ -1821,6 +1864,23 @@ def __getitem__(self, dep_cls: DerivableMeta) -> ArrayFunction: """ return self.deps[dep_cls] + def _iter_dependency_set(self, objects: set[Any]) -> Generator[Any, None, None]: + q = deque(objects) + while q: + obj = q.popleft() + # evaluate deferred columns + if isinstance(obj, self.DeferredColumn): + obj = obj(self) + # when obj is falsy, skip it + if not obj: + continue + # when obj is a set of objects, include them in the iteration + if isinstance(obj, set): + q.extendleft(obj) + continue + # yield it + yield obj + def has_dep(self, dep_cls: DerivableMeta) -> bool: """ Returns whether a dependency of class *dep_cls* is present. @@ -1861,10 +1921,14 @@ def walk_deps( # add the next dependencies extend(dep.deps.values()) - def deferred_init(self, instance_cache: dict | None = None) -> dict: + def deferred_init(self, instance_cache: dict | None = None) -> None: """ Controls the deferred part of the initialization process. """ + # run this instance's pre init function which might update dep kwargs + if callable(self.pre_init_func): + self.pre_init_func() + # create dependencies once instance_cache = instance_cache or {} self.create_dependencies(instance_cache) @@ -1874,16 +1938,7 @@ def deferred_init(self, instance_cache: dict | None = None) -> dict: self.init_func() # instantiate dependencies again, but only perform updates - # self.create_dependencies(instance_cache, only_update=True) - - # NOTE: the above does not correctly propagate `deps_kwargs` to the dependencies. - # As a workaround, instantiate all dependencies fully a second time by - # invalidating the instance cache and setting `only_update` to False - instance_cache = {} - self.create_dependencies(instance_cache, only_update=False) - - # NOTE: return value currently not being used anywhere -> remove? - return instance_cache + self.create_dependencies(instance_cache, only_update=True) def create_dependencies( self, @@ -1892,7 +1947,7 @@ def create_dependencies( ) -> None: """ Walks through all dependencies configured in the :py:attr:`_dependency_sets` and fills - :py:attr:`deps` as well as separate sets, corresponding to the classes defined in + :py:attr:`deps` as well as separate sets, corresponding to the attributes defined in :py:attr:`_dependency_sets` (e.g. :py:attr:`uses` -> :py:attr:`uses_instances`). *instance_cache* is a dictionary that is serves as a cache to prevent same classes being @@ -1901,30 +1956,33 @@ def create_dependencies( def add_dep(cls_or_inst): is_cls = ArrayFunction.derived_by(cls_or_inst) cls = cls_or_inst if is_cls else cls_or_inst.__class__ - if not only_update or cls not in self.deps: - # create or get the instance - if is_cls: - # use the cache - if cls not in instance_cache: - # create the instance first without its deps, then cache it but do not - # create its own deps yet within the deferred init - inst = self.instantiate_dependency(cls, deferred_init=False) - instance_cache[cls] = inst - inst = instance_cache[cls] - else: - inst = cls_or_inst - # optionally skip the instance - if callable(inst.skip_func) and inst.skip_func(): - self.deps.pop(cls, None) - return None + # skip when the dependency is already present + if only_update and cls in self.deps: + return self.deps[cls] + + # create or get the instance + if is_cls: + # use the cache + if cls not in instance_cache: + # create the instance first without its deps, then cache it but do not + # create its own deps yet within the deferred init + inst = self.instantiate_dependency(cls, deferred_init=False) + instance_cache[cls] = inst + inst = instance_cache[cls] + else: + inst = cls_or_inst - # run the deferred init that creates its own deps - inst.deferred_init(instance_cache) + # optionally skip the instance + if callable(inst.skip_func) and inst.skip_func(): + self.deps.pop(cls, None) + return None - # store it - self.deps[cls] = inst + # run the deferred init that creates its own deps + inst.deferred_init(instance_cache) + # store and return it + self.deps[cls] = inst return self.deps[cls] # track dependent classes that are handled in the following @@ -1936,40 +1994,21 @@ def add_dep(cls_or_inst): instances.clear() # go through all dependent objects and create instances of classes, considering caching - objs = list(getattr(self, attr)) - while objs: - obj = objs.pop(0) - - # obj might be a deferred column - if isinstance(obj, self.DeferredColumn): - obj = obj(self) - - # when obj is falsy, skip it - if not obj: - continue - - # when obj is a set of objects, i.e., it cannot be understood as a Route, - # extend the loop and start over, otherwise handle obj as is - if isinstance(obj, set): - objs = list(obj) + objs - continue - - # handle other types + for obj in self._iter_dependency_set(getattr(self, attr)): + # handle array functions if ArrayFunction.derived_by(obj) or isinstance(obj, ArrayFunction): obj = add_dep(obj) if obj: added_deps.append(obj.__class__) instances.add(self.IOFlagged(obj, self.IOFlag.AUTO)) + # handle flagged array functions elif isinstance(obj, self.IOFlagged): obj = self.IOFlagged(add_dep(obj.wrapped), obj.io_flag) if obj: added_deps.append(obj.wrapped.__class__) instances.add(obj) - else: - instances.add(obj) - # synchronize dependencies # this might remove deps that were present in self.deps already before this method is called # but that were not added in the loop above @@ -1978,11 +2017,7 @@ def add_dep(cls_or_inst): if cls not in added_deps: del self.deps[cls] - def instantiate_dependency( - self, - cls: DerivableMeta, - **kwargs: Any, - ) -> ArrayFunction: + def instantiate_dependency(self, cls: DerivableMeta, **kwargs) -> ArrayFunction: """ Controls the instantiation of a dependency given by its *cls* and arbitrary *kwargs*. The latter update optional keyword arguments in :py:attr:`self.deps_kwargs` and are then @@ -2030,18 +2065,30 @@ def _get_columns( # declare _this_ call cached _cache.add(self.IOFlagged(self, io_flag)) - # add columns of all dependent objects - for obj in (self.uses_instances if io_flag == self.IOFlag.USES else self.produces_instances): - if isinstance(obj, (ArrayFunction, self.IOFlagged)): - # don't propagate to dependencies - if not dependencies: - continue + # add normal columns + objs = self.uses if io_flag == self.IOFlag.USES else self.produces + for obj in self._iter_dependency_set(objs): + # skip objects that do not refer to simple columns + if ArrayFunction.derived_by(obj) or isinstance(obj, (ArrayFunction, self.IOFlagged)): + continue + + if isinstance(obj, str): + # expand braces in strings + columns |= set(map(Route, law.util.brace_expand(obj))) + else: + # let Route handle everything else + columns.add(Route(obj)) - flagged = obj + # add columns of all dependent instances + if dependencies: + objs = (self.uses_instances if io_flag == self.IOFlag.USES else self.produces_instances) + for obj in objs: if isinstance(obj, ArrayFunction): flagged = self.IOFlagged(obj, io_flag) elif obj.io_flag == self.IOFlag.AUTO: flagged = self.IOFlagged(obj.wrapped, io_flag) + else: + flagged = obj # skip when already cached if flagged in _cache: @@ -2049,12 +2096,6 @@ def _get_columns( # add the columns columns |= flagged.wrapped._get_columns(flagged.io_flag, _cache=_cache) - elif isinstance(obj, str): - # expand braces in strings - columns |= set(map(Route, law.util.brace_expand(obj))) - else: - # let Route handle it - columns.add(Route(obj)) return columns @@ -2063,19 +2104,25 @@ def _get_used_columns(self, _cache: set | None = None) -> set[Route]: @property def used_columns(self) -> set[Route]: - return self._get_used_columns() + try: + return self._get_used_columns() + except AttributeError as e: + raise Exception(str(e)) from e def _get_produced_columns(self, _cache: set | None = None) -> set[Route]: return self._get_columns(io_flag=self.IOFlag.PRODUCES, _cache=_cache) @property def produced_columns(self) -> set[Route]: - return self._get_produced_columns() + try: + return self._get_produced_columns() + except AttributeError as e: + raise Exception(str(e)) from e def _check_columns(self, ak_array: ak.Array, io_flag: IOFlag) -> None: """ - Check if awkward array contains at least one column matching each - entry in 'uses' or 'produces' and raise Exception if none were found. + Check if awkward array contains at least one column matching each entry in 'uses' or + 'produces' and raise Exception if none were found. """ # get own columns, i.e, routes that are non-optional routes = [ @@ -2111,12 +2158,12 @@ def __call__(self, *args, **kwargs) -> Any: """ # check if the call_func is callable if not callable(self.call_func): - raise Exception(f"call_func of {self} is not callable") + raise Exception(f"call_func of '{self}' is not callable") # raise in case the call is actually being skipped if callable(self.skip_func) and self.skip_func(): raise Exception( - f"skip_func of {self} returned True, cannot invoke call_func; skip_func code: \n\n" + f"skip_func of '{self}' returned True, cannot invoke call_func; skip_func code:\n\n" f"{get_source_code(self.skip_func, indent=4)}", ) @@ -2213,7 +2260,29 @@ def skip_column( return tagged_column("skip", *routes) -class TaskArrayFunction(ArrayFunction): +class TaskArrayFunctionMeta(DerivableMeta): + + def __new__(metacls, cls_name: str, bases: tuple, cls_dict: dict) -> TaskArrayFunctionMeta: + # add an instance cache if not disabled + cls_dict.setdefault("cache_instances", True) + cls_dict["_instances"] = {} if cls_dict["cache_instances"] else None + + return super().__new__(metacls, cls_name, bases, cls_dict) + + def __call__(cls, *args, **kwargs) -> TaskArrayFunction: + # when not caching instances, return right away + if not cls.cache_instances: + return super().__call__(*args, **kwargs) + + # build the cache key from the inst_dict in kwargs + key = freeze((cls, kwargs.get("inst_dict", {}))) + if key not in cls._instances: + cls._instances[key] = super().__call__(*args, **kwargs) + + return cls._instances[key] + + +class TaskArrayFunction(ArrayFunction, metaclass=TaskArrayFunctionMeta): """ Subclass of :py:class:`ArrayFunction` providing an interface to certain task features such as declaring dependent or produced shifts, task requirements, and defining a custom setup @@ -2228,33 +2297,33 @@ class TaskArrayFunction(ArrayFunction): `py:attr:`shifts` itself. As opposed to more basic :py:class:`ArrayFunction`'s, instances of *this* class have a direct - interface to tasks and can influence their behavior - and vice-versa. For this purpose, custom - task requirements, and a setup of objects resulting from these requirements can be defined in a - similar, programmatic way. Also, they might define an optional *sandbox* that is required to run - this array function. + interface to tasks and can influence their behavior - and vice-versa. For this purpose, a + ``post-init`` hook receiving the task, custom task requirements, the setup of objects resulting + from these requirements, and a teardown method can be defined in a similar, programmatic way. + Also, they might define an optional *sandbox* that is required to run this array function. Exmple: .. code-block:: python - class my_func(ArrayFunction): + class my_func(TaskArrayFunction): uses = {"Jet.pt"} produces = {"Jet.pt_weighted"} - def call_func(self, events): + def call_func(self, events, **kwargs): # self.weights is defined below events["Jet", "pt_weighted"] = events.Jet.pt * self.weights # define requirements that (e.g.) compute the weights @my_func.requires - def requires(self, reqs): + def requires(self, task, reqs): # fill the requirements dict reqs["weights_task"] = SomeWeightsTask.req(self.task) reqs["columns_task"] = SomeColumnsTask.req(self.task) # define the setup step that loads event weights from the required task @my_func.setup - def setup(self, reqs, inputs, reader_targets): + def setup(self, task, reqs, inputs, reader_targets): # load the weights once, inputs is corresponding to what we added to reqs above weights = inputs["weights_task"].load(formatter="json") @@ -2300,6 +2369,13 @@ class the normal way, or use a decorator to wrap the main callable first and by The resolved, flat set of dependent or produced shifts. + .. py:attribute: post_init_func + + type: callable + + The registered function defining actions to be taken after the initialization and before + the actual setup, or *None*. + .. py:attribute: requires_func type: callable @@ -2312,6 +2388,12 @@ class the normal way, or use a decorator to wrap the main callable first and by The registered function performing the custom setup step, or *None*. + .. py:attribute:: teardown_func + + type: callable + + The registered function performing a custom teardown step, or *None*. + .. py:attribute:: sandbox type: str, None @@ -2336,20 +2418,16 @@ class the normal way, or use a decorator to wrap the main callable first and by """ # class-level attributes as defaults + post_init_func = None requires_func = None setup_func = None + teardown_func = None sandbox = None call_force = None max_chunk_size = None shifts = set() _dependency_sets = ArrayFunction._dependency_sets | {"shifts"} - def __str__(self) -> str: - """ - Returns a string representation of this TaskArrayFunction instance. - """ - return self.cls_name - @staticmethod def pick_cached_result(cached_result: T, *args, **kwargs) -> T: """ @@ -2377,13 +2455,34 @@ def pick_cached_result(cached_result: T, *args, **kwargs) -> T: # otherwise, also return all but the first cached return value return events, *cached_result[1:] + @classmethod + def post_init(cls, func: Callable[[dict], None]) -> None: + """ + Decorator to wrap a function *func* that should be registered as :py:meth:`post_init_func` + which receives the task instance. The function should accept one argument: + + - *task*, the invoking task instance. + + The decorator does not return the wrapped function. + + .. note:: + + When the task invoking the requirement is workflow, be aware that both the actual + workflow instance as well as branch tasks might call the wrapped function. When the + requirements should differ between them, make sure to use the + :py:meth:`BaseWorkflow.is_workflow` and :py:meth:`BaseWorkflow.is_branch` methods to + distinguish the cases. + """ + cls.post_init_func = func + @classmethod def requires(cls, func: Callable[[dict], None]) -> None: """ Decorator to wrap a function *func* that should be registered as :py:meth:`requires_func` - which is used to define additional task requirements. The function should accept one - positional argument: + which is used to define additional task requirements. The function should accept two + arguments: + - *task*, the invoking task instance. - *reqs*, a dictionary into which requirements should be inserted. The decorator does not return the wrapped function. @@ -2402,9 +2501,10 @@ def requires(cls, func: Callable[[dict], None]) -> None: def setup(cls, func: Callable[[dict], None]) -> None: """ Decorator to wrap a function *func* that should be registered as :py:meth:`setup_func` - which is used to perform a custom setup of objects. The function should accept two - positional arguments: + which is used to perform a custom setup of objects. The function should accept four + arguments: + - *task*, the invoking task instance. - *reqs*, a dictionary containing the required tasks as defined by the custom :py:meth:`requires_func`. - *inputs*, a dictionary containing the outputs created by the tasks in *reqs*. @@ -2415,11 +2515,26 @@ def setup(cls, func: Callable[[dict], None]) -> None: """ cls.setup_func = func + @classmethod + def teardown(cls, func: Callable[[dict], None]) -> None: + """ + Decorator to wrap a function *func* that should be registered as :py:meth:`teardown_func` + which is used to perform a custom teardown of objects at the end of processing. The function + should accept one argument: + + - *task*, the invoking task instance. + + The decorator does not return the wrapped function. + """ + cls.teardown_func = func + def __init__( self, *args, + post_init_func: Callable | law.NoValue | None = law.no_value, requires_func: Callable | law.NoValue | None = law.no_value, setup_func: Callable | law.NoValue | None = law.no_value, + teardown_func: Callable | law.NoValue | None = law.no_value, sandbox: str | law.NoValue | None = law.no_value, call_force: bool | law.NoValue | None = law.no_value, max_chunk_size: int | None = law.no_value, @@ -2433,10 +2548,14 @@ def __init__( super().__init__(*args, **kwargs) # add class-level attributes as defaults for unset arguments (no_value) + if post_init_func == law.no_value: + post_init_func = self.__class__.post_init_func if requires_func == law.no_value: requires_func = self.__class__.requires_func if setup_func == law.no_value: setup_func = self.__class__.setup_func + if teardown_func == law.no_value: + teardown_func = self.__class__.teardown_func if sandbox == law.no_value: sandbox = self.__class__.sandbox if call_force == law.no_value: @@ -2447,10 +2566,17 @@ def __init__( pick_cached_result = self.__class__.pick_cached_result # when custom funcs are passed, bind them to this instance + if post_init_func: + self.post_init_func = post_init_func.__get__(self, self.__class__) if requires_func: self.requires_func = requires_func.__get__(self, self.__class__) if setup_func: self.setup_func = setup_func.__get__(self, self.__class__) + if teardown_func: + self.teardown_func = teardown_func.__get__(self, self.__class__) + + # remember if certain custom functions were called + self._post_init_called = False # other attributes self.sandbox = sandbox @@ -2469,14 +2595,33 @@ def __getattr__(self, attr: str) -> Any: if attr in self.inst_dict: return self.inst_dict[attr] + # extra warnings for a limited period of time to ensure a smooth transition to the new + # task array function interface + if attr in {"task", "global_shift_inst", "local_shift_inst"}: + docs_url1 = get_docs_url("user_guide", "task_array_functions.html") + docs_url2 = get_docs_url("user_guide", "02_03_transition.html") + logger.warning_once( + f"taf_interface_deprected_{attr}", + f"direct access to attribute '{attr}' was removed in favor of a) using the 'task' " + "instance passed as an argument to most task array function hooks, or b) using the " + "correct task array function hook for the specific use case (e.g. pre_init() or " + f"post_init() instead of init()); see {docs_url1} and {docs_url2} for more info", + ) + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'") + def __str__(self) -> str: + """ + Returns a string representation of this TaskArrayFunction instance. + """ + return self.cls_name + def instantiate_dependency(self, cls: DerivableMeta, **kwargs: Any) -> TaskArrayFunction: """ Controls the instantiation of a dependency given by its *cls* and arbitrary *kwargs*, updated by *this* instances :py:attr:`inst_dict`. """ - # add inst_dict when cls is a TaskArrayFunction itself + # add a reference to the same inst_dict when cls is a TaskArrayFunction itself if TaskArrayFunction.derived_by(cls): kwargs.setdefault("inst_dict", self.inst_dict) @@ -2534,82 +2679,140 @@ def _get_all_shifts(self, _cache: set | None = None) -> set[str]: @property def all_shifts(self) -> set[str]: - return self._get_all_shifts() + try: + return self._get_all_shifts() + except AttributeError as e: + raise Exception(str(e)) from e + + def run_post_init( + self, + task: law.Task, + force: bool = False, + _cache: set | None = None, + ) -> None: + """ + Recursively runs the :py:meth:`post_init_func` of this instance and all dependencies. + """ + # create the call cache + if _cache is None: + _cache = set() + + # run the requirements of all dependent objects + for dep in self.get_dependencies(): + if isinstance(dep, TaskArrayFunction): + dep.run_post_init(task, force=force, _cache=_cache) + + # run this instance's post init function + if self not in _cache and callable(self.post_init_func): + _cache.add(self) + if not self._post_init_called or force: + self.post_init_func(task=task) + self._post_init_called = True def run_requires( self, - reqs: dict | None = None, + task: law.Task, + reqs: dict[str, DotDict[str, Any]] | None = None, _cache: set | None = None, - ) -> dict: + ) -> dict[str, DotDict[str, Any]]: """ Recursively runs the :py:meth:`requires_func` of this instance and all dependencies. *reqs* defaults to an empty dictionary which should be filled to store the requirements. """ - # default requirements + # defaults if reqs is None: - reqs = DotDict() + reqs = {} # create the call cache if _cache is None: _cache = set() + # run the requirements of all dependent objects + for dep in self.get_dependencies(): + if isinstance(dep, TaskArrayFunction): + dep.run_requires(task, reqs=reqs, _cache=_cache) + # run this instance's requires function - if callable(self.requires_func): + if self not in _cache and callable(self.requires_func): + _cache.add(self) if self.cls_name not in reqs: reqs[self.cls_name] = DotDict() - self.requires_func(reqs[self.cls_name]) - - # run the requirements of all dependent objects - for dep in self.get_dependencies(): - if isinstance(dep, TaskArrayFunction) and dep not in _cache: - _cache.add(dep) - dep.run_requires(reqs=reqs, _cache=_cache) + self.requires_func(task=task, reqs=reqs[self.cls_name]) return reqs def run_setup( self, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict[str, law.FileSystemFileTarget] | None = None, + task: law.Task, + reqs: dict[str, DotDict[str, Any]] | None = None, + inputs: dict[str, Any] | None = None, + reader_targets: law.util.InsertableDict[str, law.FileSystemFileTarget] | None = None, _cache: set | None = None, - ) -> dict[str, law.FileSystemTarget]: + ) -> law.util.InsertableDict[str, law.FileSystemFileTarget]: """ Recursively runs the :py:meth:`setup_func` of this instance and all dependencies. *reqs* corresponds to the requirements created by :py:func:`run_requires`, and *inputs* are their outputs. *reader_targets* defaults to an empty InsertableDict which should be filled to store targets of columnar data that are to be included in an event chunk loop. """ - # default column targets + # defaults + if reqs is None: + reqs = {} + if inputs is None: + inputs = {} if reader_targets is None: - reader_targets = DotDict() + reader_targets = law.util.InsertableDict() # create the call cache if _cache is None: _cache = set() + # run the setup of all dependent objects + for dep in self.get_dependencies(): + if isinstance(dep, TaskArrayFunction): + dep.run_setup( + task, + reqs=reqs, + inputs=inputs, + reader_targets=reader_targets, + _cache=_cache, + ) + # run this instance's setup function if callable(self.setup_func): + _cache.add(self) if self.cls_name not in reqs: reqs[self.cls_name] = DotDict() if self.cls_name not in inputs: inputs[self.cls_name] = DotDict() - self.setup_func(reqs[self.cls_name], inputs[self.cls_name], reader_targets) + self.setup_func( + task=task, + reqs=reqs[self.cls_name], + inputs=inputs[self.cls_name], + reader_targets=reader_targets, + ) - # run the setup of all dependent objects + return reader_targets + + def run_teardown(self, task: law.Task, _cache: set | None = None) -> None: + """ + Recursively runs the :py:meth:`teardown_func` of this instance and all dependencies. + """ + # create the call cache + if _cache is None: + _cache = set() + + # run the teardown of all dependent objects for dep in self.get_dependencies(): - if isinstance(dep, TaskArrayFunction) and dep not in _cache: - _cache.add(dep) - dep.run_setup(reqs, inputs, reader_targets, _cache=_cache) + if isinstance(dep, TaskArrayFunction): + dep.run_teardown(task, _cache=_cache) - return reader_targets + # run this instance's teardown function + if self not in _cache and callable(self.teardown_func): + _cache.add(self) + self.teardown_func(task=task) - def __call__( - self, - *args, - call_force: bool | None = None, - **kwargs, - ) -> Any: + def __call__(self, *args, call_force: bool | None = None, **kwargs) -> Any: """ Calls the wrapped :py:meth:`call_func` with all *args* and *kwargs*. The latter is updated with :py:attr:`call_kwargs` when set, but giving priority to existing *kwargs*. @@ -2675,10 +2878,11 @@ def get_sandbox(self, raise_on_collision: bool = False) -> str | None: f"multiple sandboxes found while traversing dependencies of {self.cls_name}: " f"{','.sandboxes}" ) - if not raise_on_collision: - logger.warning(f"{msg}; using the first one") - return sandboxes[0] - raise Exception(msg) + if raise_on_collision: + raise Exception(msg) + + logger.warning(f"{msg}; using the first one") + return sandboxes[0] def get_min_chunk_size(self) -> int | None: """ @@ -2702,7 +2906,7 @@ class NoThreadPool(object): class SyncResult(object): - def __init__(self, return_value: Any): + def __init__(self, return_value: Any) -> None: super().__init__() self.return_value = return_value @@ -2712,7 +2916,7 @@ def ready(self) -> bool: def get(self) -> Any: return self.return_value - def __init__(self, processes: int): + def __init__(self, processes: int) -> None: super().__init__() self._processes = processes @@ -2752,12 +2956,12 @@ class TaskQueue(object): # task object Task = namedtuple("Task", ["func", "args", "kwargs"]) - def __init__(self): + def __init__(self) -> None: super().__init__() self._tasks = {} - def __bool__(self): + def __bool__(self) -> bool: return bool(self._tasks) def add( @@ -2818,7 +3022,7 @@ def __init__( path: str, open_options: dict | None = None, materialization_strategy: MaterializationStrategy = MaterializationStrategy.PARTITIONS, - ): + ) -> None: super().__init__() open_options = open_options or {} diff --git a/columnflow/config_util.py b/columnflow/config_util.py index 86d7ff5ce..3a3da34f3 100644 --- a/columnflow/config_util.py +++ b/columnflow/config_util.py @@ -9,19 +9,24 @@ __all__ = [] import re +import dataclasses import itertools -from collections import OrderedDict +from collections import OrderedDict, defaultdict import law import order as od -from columnflow.util import maybe_import +from columnflow.util import maybe_import, get_docs_url +from columnflow.columnar_util import flat_np_view, layout_ak_array from columnflow.types import Callable, Any, Sequence ak = maybe_import("awkward") np = maybe_import("numpy") +logger = law.logger.get_logger(__name__) + + def get_events_from_categories( events: ak.Array, categories: Sequence[str | od.Category], @@ -60,6 +65,39 @@ def get_events_from_categories( return events[mask] +def get_category_name_columns( + category_ids: ak.Array, + config_inst: od.Config, +) -> ak.Array: + """ + Function that transforms column of category ids to column of category names. + + :param category_ids: Awkward array of category ids. + :param config_inst: Config instance from which to load category instances. + :raises ValueError: If any of the category ids is not defined in the *config_inst*. + :return: Awkward array of category names with the same shape as *category_ids* + """ + flat_ids = flat_np_view(category_ids) + + # map all category ids present in *category_ids* to category instances + category_map = { + _id: config_inst.get_category(_id, default=None) + for _id in set(flat_ids) + } + if any(cat is None for cat in category_map.values()): + undefined_ids = {cat_id for cat_id, cat_inst in category_map.items() if cat_inst is None} + raise ValueError(f"undefined category ids: {', '.join(map(str, undefined_ids))}") + + # Create a vectorized function for the mapping + map_to_name = np.vectorize(lambda _id: category_map[_id].name) + + # Apply the mapping and layout to the original shape + flat_names = map_to_name(flat_ids) + category_names = layout_ak_array(flat_names, category_ids) + + return category_names + + def get_root_processes_from_campaign(campaign: od.config.Campaign) -> od.unique.UniqueObjectIndex: """ Extracts all root process objects from datasets contained in an order *campaign* and returns @@ -274,6 +312,25 @@ def add_shift_aliases( shift.x.column_aliases = _aliases +def get_shift_from_configs(configs: list[od.Config], shift: str | od.Shift, silent: bool = False) -> od.Shift | None: + """ + Given a list of *configs* and a *shift* name or instance, returns the corresponding shift instance from the first + config that contains it. If *silent* is *True*, *None* is returned instead of raising an exception in case the shift + is not found. + """ + if isinstance(shift, od.Shift): + shift = shift.name + + for config in configs: + if config.has_shift(shift): + return config.get_shift(shift) + + if silent: + return None + + raise ValueError(f"shift '{shift}' not found in any of the given configs: {configs}") + + def get_shifts_from_sources(config: od.Config, *shift_sources: Sequence[str]) -> list[od.Shift]: """ Takes a *config* object and returns a list of shift instances for both directions given a @@ -288,6 +345,39 @@ def get_shifts_from_sources(config: od.Config, *shift_sources: Sequence[str]) -> ) +def group_shifts( + shifts: od.Shift | Sequence[od.Shift], +) -> tuple[od.Shift | None, dict[str, tuple[od.Shift, od.Shift]]]: + """ + Takes several :py:class:`order.Shift` instances *shifts* and groups them according to their + shift source. The nominal shift, if present, is returned separately. The remaining shifts are + grouped by their source and the corresponding up and down shifts are stored in a dictionary. + Example: + .. code-block:: python + # assuming the following shifts exist + group_shifts([nominal, x_up, y_up, y_down, x_down]) + # -> (nominal, {"x": (x_up, x_down), "y": (y_up, y_down)}) + An exception is raised in case a shift source is represented only by its up or down shift. + """ + nominal = None + grouped = defaultdict(lambda: [None, None]) + + up_sources = set() + down_sources = set() + for shift in law.util.make_list(shifts): + if shift.name == "nominal": + nominal = shift + else: + grouped[shift.source][shift.is_up] = shift + (up_sources if shift.is_up else down_sources).add(shift.source) + + # check completeness of shifts + if (diff := up_sources.symmetric_difference(down_sources)): + raise ValueError(f"shift sources {diff} are not complete and cannot be grouped") + + return nominal, dict(grouped) + + def expand_shift_sources(shifts: Sequence[str] | set[str]) -> list[str]: """ Given a sequence *shifts* containing either shift names (``_``) or shift @@ -371,43 +461,74 @@ def add_category( return parent.add_category(**kwargs) +@dataclasses.dataclass +class CategoryGroup: + """ + Container to store information about a group of categories, mostly used for creating combinations in + :py:func:`create_category_combinations`. + + :param categories: List of :py:class:`order.Category` objects or names that refer to the desired category. + :param is_complete: Should be *True* if the union of category selections covers the full phase space (no gaps). + :param has_overlap: Should be *False* if all categories are pairwise disjoint (no overlap). + :param warn: If *True*, a warning is issued when summing over the group of categories. + """ + + categories: list[od.Category | str] + is_complete: bool + has_overlap: bool + warn: bool = True + + @property + def is_partition(self) -> bool: + """ + Returns *True* if the group of categories is a full partition of the phase space (no overlap, no gaps). + """ + return self.is_complete and not self.has_overlap + + def create_category_combinations( config: od.Config, - categories: dict[str, list[od.Category]], + categories: dict[str, CategoryGroup | list[od.Category]], name_fn: Callable[[Any], str], kwargs_fn: Callable[[Any], dict] | None = None, skip_existing: bool = True, skip_fn: Callable[[dict[str, od.Category], str], bool] | None = None, ) -> int: """ - Given a *config* object and sequences of *categories* in a dict, creates all combinations of - possible leaf categories at different depths, connects them with parent - child relations - (see :py:class:`order.Category`) and returns the number of newly created categories. - - *categories* should be a dictionary that maps string names to sequences of categories that - should be combined. The names are used as keyword arguments in a callable *name_fn* that is - supposed to return the name of newly created categories (see example below). - - Each newly created category is instantiated with this name as well as arbitrary keyword - arguments as returned by *kwargs_fn*. This function is called with the categories (in a - dictionary, mapped to the sequence names as given in *categories*) that contribute to the newly - created category and should return a dictionary. If the fields ``"id"`` and ``"selection"`` are - missing, they are filled with reasonable defaults leading to a auto-generated, deterministic id - and a list of all parent selection statements. - - If the name of a new category is already known to *config* it is skipped unless *skip_existing* - is *False*. In addition, *skip_fn* can be a callable that receives a dictionary mapping group - names to categories that represents the combination of categories to be added. In case *skip_fn* - returns *True*, the combination is skipped. + Given a *config* object and sequences of *categories* in a dict, creates all combinations of possible leaf + categories at different depths, connects them with parent - child relations (see :py:class:`order.Category`) and + returns the number of newly created categories. + + *categories* should be a dictionary that maps string names to :py:class:`CategoryGroup` objects which are thin + wrappers around sequences of categories (objects or names). Group names (dictionary keys) are used as keyword + arguments in a callable *name_fn* that is supposed to return the name of newly created categories (see example + below). + + .. note:: + + The :py:attr:`CategoryGroup.is_complete` and :py:attr:`CategoryGroup.has_overlap` attributes are imperative for + columnflow to determine whether the summation over specific categories is valid or may result in under- or + over-counting when combining leaf categories. These checks may be performed by other functions and tools based + on information derived from groups and stored in auxiliary fields of the newly created categories. + + Each newly created category is instantiated with this name as well as arbitrary keyword arguments as returned by + *kwargs_fn*. This function is called with the categories (in a dictionary, mapped to the sequence names as given in + *categories*) that contribute to the newly created category and should return a dictionary. If the fields ``"id"`` + and ``"selection"`` are missing, they are filled with reasonable defaults leading to a auto-generated, deterministic + id and a list of all parent selection statements. + + If the name of a new category is already known to *config* it is skipped unless *skip_existing* is *False*. In + addition, *skip_fn* can be a callable that receives a dictionary mapping group names to categories that represents + the combination of categories to be added. In case *skip_fn* returns *True*, the combination is skipped. Example: .. code-block:: python categories = { - "lepton": [cfg.get_category("e"), cfg.get_category("mu")], - "n_jets": [cfg.get_category("1j"), cfg.get_category("2j")], - "n_tags": [cfg.get_category("0t"), cfg.get_category("1t")], + "lepton": CategoryGroup(categories=["e", "mu"], is_complete=False, has_overlap=False), + "n_jets": CategoryGroup(categories=["eq0j", "eq1j", "ge2j"], is_complete=True, has_overlap=False), + "n_tags": CategoryGroup(categories=["0t", "1t"], is_complete=False, has_overlap=False), } def name_fn(categories): @@ -423,20 +544,40 @@ def kwargs_fn(categories): create_category_combinations(cfg, categories, name_fn, kwargs_fn) :param config: :py:class:`order.Config` object for which the categories are created. - :param categories: Dictionary that maps group names to sequences of categories. - :param name_fn: Callable that receives a dictionary mapping group names to categories and - returns the name of the newly created category. - :param kwargs_fn: Callable that receives a dictionary mapping group names to categories and - returns a dictionary of keyword arguments that are forwarded to the category constructor. - :param skip_existing: If *True*, skip the creation of a category when it already exists in - *config*. - :param skip_fn: Callable that receives a dictionary mapping group names to categories and - returns *True* if the combination should be skipped. + :param categories: Dictionary that maps group names to :py:class:`CategoryGroup` containers. + :param name_fn: Callable that receives a dictionary mapping group names to categories and returns the name of the + newly created category. + :param kwargs_fn: Callable that receives a dictionary mapping group names to categories and returns a dictionary of + keyword arguments that are forwarded to the category constructor. + :param skip_existing: If *True*, skip the creation of a category when it already exists in *config*. + :param skip_fn: Callable that receives a dictionary mapping group names to categories and returns *True* if the + combination should be skipped. :raises TypeError: If *name_fn* is not a callable. :raises TypeError: If *kwargs_fn* is not a callable when set. :raises ValueError: If a non-unique category id is detected. :return: Number of newly created categories. """ + # cast categories + for name, _categories in categories.items(): + # ensure CategoryGroup is used + if not isinstance(_categories, CategoryGroup): + docs_url = get_docs_url("api", "config_util.html", anchor="columnflow.config_util.CategoryGroup") + logger.warning_once( + "deprecated_category_group_lists", + f"using a list to define a sequence of categories for create_category_combinations() is depcreated " + f"and will be removed in a future version, please use a CategoryGroup instance instead: {docs_url}", + ) + _categories = CategoryGroup( + categories=law.util.make_list(_categories), + is_complete=True, + has_overlap=False, + ) + categories[name] = _categories + # cast string category names to instances + for i, cat in enumerate(_categories.categories): + if isinstance(cat, str): + _categories.categories[i] = config.get_category(cat) + n_created_categories = 0 unique_ids_cache = {cat.id for cat, _, _ in config.walk_categories()} n_groups = len(categories) @@ -459,7 +600,7 @@ def kwargs_fn(categories): for _group_names in itertools.combinations(group_names, _n_groups): # build the product of all categories for the given groups - _categories = [categories[group_name] for group_name in _group_names] + _categories = [categories[group_name].categories for group_name in _group_names] for root_cats in itertools.product(*_categories): # build the name root_cats = dict(zip(_group_names, root_cats)) diff --git a/columnflow/hist_util.py b/columnflow/hist_util.py index 600c67a56..7f16da17a 100644 --- a/columnflow/hist_util.py +++ b/columnflow/hist_util.py @@ -8,6 +8,7 @@ __all__ = [] +import functools import law import order as od @@ -31,11 +32,10 @@ def fill_hist( fill_kwargs: dict[str, Any] | None = None, ) -> None: """ - Fills a histogram *h* with data from an awkward array, numpy array or nested dictionary *data*. - The data is assumed to be structured in the same way as the histogram axes. If - *last_edge_inclusive* is *True*, values that would land exactly on the upper-most bin edge of an - axis are shifted into the last bin. If it is *None*, the behavior is determined automatically - and depends on the variable axis type. In this case, shifting is applied to all continuous, + Fills a histogram *h* with data from an awkward array, numpy array or nested dictionary *data*. The data is assumed + to be structured in the same way as the histogram axes. If *last_edge_inclusive* is *True*, values that would land + exactly on the upper-most bin edge of an axis are shifted into the last bin. If it is *None*, the behavior is + determined automatically and depends on the variable axis type. In this case, shifting is applied to all continuous, non-circular axes. """ if fill_kwargs is None: @@ -59,7 +59,7 @@ def allows_shift(ax) -> bool: # check data if not isinstance(data, dict): if len(axis_names) != 1: - raise ValueError("got multi-dimensional hist but only one dimensional data") + raise ValueError("got multi-dimensional hist but only one-dimensional data") data = {axis_names[0]: data} else: for name in axis_names: @@ -73,15 +73,30 @@ def allows_shift(ax) -> bool: data[ax.name] = ak.copy(data[ax.name]) flat_np_view(data[ax.name])[right_egde_mask] -= ax.widths[-1] * 1e-5 + # check if conversion to records is needed + arr_types = (ak.Array, np.ndarray) + vals = list(data.values()) + convert = ( + # values is a mixture of singular and array types + (any(isinstance(v, arr_types) for v in vals) and not all(isinstance(v, arr_types) for v in vals)) or + # values contain at least one array with more than one dimension + any(isinstance(v, arr_types) and v.ndim != 1 for v in vals) + ) + + # actual conversion + if convert: + arrays = ak.flatten(ak.cartesian(data)) + data = {field: arrays[field] for field in arrays.fields} + del arrays + # fill - arrays = ak.flatten(ak.cartesian(data)) - h.fill(**fill_kwargs, **{field: arrays[field] for field in arrays.fields}) + h.fill(**fill_kwargs, **data) def add_hist_axis(histogram: hist.Hist, variable_inst: od.Variable) -> hist.Hist: """ - Add an axis to a histogram based on a variable instance. The axis_type is chosen - based on the variable instance's "axis_type" auxiliary. + Add an axis to a histogram based on a variable instance. The axis_type is chosen based on the variable instance's + "axis_type" auxiliary. :param histogram: The histogram to add the axis to. :param variable_inst: The variable instance to use for the axis. @@ -102,20 +117,20 @@ def add_hist_axis(histogram: hist.Hist, variable_inst: od.Variable) -> hist.Hist default_axis_type = "integer" if variable_inst.discrete_x else "variable" axis_type = variable_inst.x("axis_type", default_axis_type).lower() - if axis_type in ("variable", "var"): + if axis_type in {"variable", "var"}: return histogram.Var(variable_inst.bin_edges, **axis_kwargs) - if axis_type in ("integer", "int"): + if axis_type in {"integer", "int"}: return histogram.Integer( int(variable_inst.bin_edges[0]), int(variable_inst.bin_edges[-1]), **axis_kwargs, ) - if axis_type in ("boolean", "bool"): + if axis_type in {"boolean", "bool"}: return histogram.Boolean(**axis_kwargs) - if axis_type in ("intcategory", "intcat"): + if axis_type in {"intcategory", "intcat"}: binning = ( [int(b) for b in variable_inst.binning] if isinstance(variable_inst.binning, list) @@ -124,16 +139,13 @@ def add_hist_axis(histogram: hist.Hist, variable_inst: od.Variable) -> hist.Hist axis_kwargs.setdefault("growth", True) return histogram.IntCat(binning, **axis_kwargs) - if axis_type in ("strcategory", "strcat"): + if axis_type in {"strcategory", "strcat"}: axis_kwargs.setdefault("growth", True) return histogram.StrCat([], **axis_kwargs) - if axis_type in ("regular", "reg"): + if axis_type in {"regular", "reg"}: if not variable_inst.even_binning: - logger.warning( - "regular axis with uneven binning is not supported, using first and last bin edge " - "instead", - ) + logger.warning("regular axis with uneven binning is not supported, using first and last bin edge instead") return histogram.Regular( variable_inst.n_bins, variable_inst.bin_edges[0], @@ -144,24 +156,142 @@ def add_hist_axis(histogram: hist.Hist, variable_inst: od.Variable) -> hist.Hist raise ValueError(f"unknown axis type '{axis_type}'") +def get_axis_kwargs(axis: hist.axis.AxesMixin) -> dict[str, Any]: + """ + Extract information from an *axis* instance that would be needed to create a new one. + + :param axis: The axis instance to extract information from. + :return: The extracted information in a dict. + """ + axis_attrs = ["name", "label"] + traits_attrs = [] + kwargs = {} + + if isinstance(axis, hist.axis.Variable): + axis_attrs.append("edges") + traits_attrs = ["underflow", "overflow", "growth", "circular"] + elif isinstance(axis, hist.axis.Regular): + axis_attrs = ["transform"] + traits_attrs = ["underflow", "overflow", "growth", "circular"] + kwargs["bins"] = axis.size + kwargs["start"] = axis.edges[0] + kwargs["stop"] = axis.edges[-1] + elif isinstance(axis, hist.axis.Integer): + traits_attrs = ["underflow", "overflow", "growth", "circular"] + kwargs["start"] = axis.edges[0] + kwargs["stop"] = axis.edges[-1] + elif isinstance(axis, hist.axis.Boolean): + # nothing to add to common attributes + pass + elif isinstance(axis, (hist.axis.IntCategory, hist.axis.StrCategory)): + traits_attrs = ["overflow", "growth"] + kwargs["categories"] = list(axis) + else: + raise NotImplementedError(f"axis type '{type(axis).__name__}' not supported") + + return ( + {attr: getattr(axis, attr) for attr in axis_attrs} | + {attr: getattr(axis.traits, attr) for attr in traits_attrs} | + kwargs + ) + + +def copy_axis(axis: hist.axis.AxesMixin, **kwargs: dict[str, Any]) -> hist.axis.AxesMixin: + """ + Copy an axis with the option to override its attributes. + """ + # create arguments for new axis from overlay with current and requested ones + axis_kwargs = get_axis_kwargs(axis) | kwargs + + # create new instance + return type(axis)(**axis_kwargs) + + def create_hist_from_variables( *variable_insts, - int_cat_axes: tuple[str] | None = None, + categorical_axes: tuple[tuple[str, str]] | None = None, weight: bool = True, + storage: str | None = None, ) -> hist.Hist: histogram = hist.Hist.new - # integer category axes - if int_cat_axes: - for name in int_cat_axes: - histogram = histogram.IntCat([], name=name, growth=True) - - # requested axes + # additional category axes + if categorical_axes: + for name, axis_type in categorical_axes: + if axis_type in ("intcategory", "intcat"): + histogram = histogram.IntCat([], name=name, growth=True) + elif axis_type in ("strcategory", "strcat"): + histogram = histogram.StrCat([], name=name, growth=True) + else: + raise ValueError(f"unknown axis type '{axis_type}' in argument 'categorical_axes'") + + # requested axes from variables for variable_inst in variable_insts: histogram = add_hist_axis(histogram, variable_inst) - # weight storage - if weight: + # add the storage + if storage is None: + # use weight value for backwards compatibility + storage = "weight" if weight else "double" + else: + storage = storage.lower() + if storage == "weight": histogram = histogram.Weight() + elif storage == "double": + histogram = histogram.Double() + else: + raise ValueError(f"unknown storage type '{storage}'") return histogram + + +create_columnflow_hist = functools.partial(create_hist_from_variables, categorical_axes=( + # axes that are used in columnflow tasks per default + # (NOTE: "category" axis is filled as int, but transformed to str afterwards) + ("category", "intcat"), + ("process", "intcat"), + ("shift", "strcat"), +)) + + +def translate_hist_intcat_to_strcat( + h: hist.Hist, + axis_name: str, + id_map: dict[int, str], +) -> hist.Hist: + out_axes = [ + ax if ax.name != axis_name else hist.axis.StrCategory( + [id_map[v] for v in list(ax)], + name=ax.name, + label=ax.label, + growth=ax.traits.growth, + ) + for ax in h.axes + ] + return hist.Hist(*out_axes, storage=h.storage_type(), data=h.view(flow=True)) + + +def add_missing_shifts( + h: hist.Hist, + expected_shifts_bins: set[str], + str_axis: str = "shift", + nominal_bin: str = "nominal", +) -> None: + """ + Adds missing shift bins to a histogram *h*. + """ + # get the set of bins that are missing in the histogram + shift_bins = set(h.axes[str_axis]) + missing_shifts = set(expected_shifts_bins) - shift_bins + if missing_shifts: + nominal = h[{str_axis: hist.loc(nominal_bin)}] + for missing_shift in missing_shifts: + # for each missing shift, create the missing shift bin with an + # empty fill and then copy the nominal histogram into it + dummy_fill = [ + ax[0] if ax.name != str_axis else missing_shift + for ax in h.axes + ] + h.fill(*dummy_fill, weight=0) + # TODO: this might skip overflow and underflow bins + h[{str_axis: hist.loc(missing_shift)}] = nominal.view() diff --git a/columnflow/histogramming/__init__.py b/columnflow/histogramming/__init__.py new file mode 100644 index 000000000..2282f94fb --- /dev/null +++ b/columnflow/histogramming/__init__.py @@ -0,0 +1,261 @@ +# coding: utf-8 + +""" +Tools for producing histograms and event-wise weights. +""" + +from __future__ import annotations + +import inspect + +import law +import order as od + +from columnflow.types import Callable +from columnflow.util import DerivableMeta, maybe_import +from columnflow.columnar_util import TaskArrayFunction +from columnflow.types import Any + + +hist = maybe_import("hist") + + +class HistProducer(TaskArrayFunction): + """ + Base class for all histogram producers, i.e., functions that control the creation of histograms, event weights, and + optional post-processing. + + .. py:attribute:: create_hist_func + + type: callable + + The registered function performing the custom histogram creation. + + .. py:attribute:: fill_hist_func + + type: callable + + The registered function performing the custom histogram filling. + + .. py:attribute:: post_process_hist_func + + type: callable + + The registered function for performing an optional post-processing of histograms before they are saved. + + .. py:attribute:: post_process_merged_hist_func + + type: callable + + The registered function for performing an optional post-processing of histograms after they are merged. + """ + + # class-level attributes as defaults + create_hist_func = None + fill_hist_func = None + post_process_hist_func = None + post_process_merged_hist_func = None + skip_compatibility_check = False + exposed = True + + @classmethod + def hist_producer( + cls, + func: Callable | None = None, + bases: tuple = (), + mc_only: bool = False, + data_only: bool = False, + **kwargs, + ) -> DerivableMeta | Callable: + """ + Decorator for creating a new :py:class:`HistProducer` subclass with additional, optional *bases* and attaching + the decorated function to it as :py:meth:`~HistProducer.call_func`. + + When *mc_only* (*data_only*) is *True*, the hist producer is skipped and not considered by other task array + functions in case they are evaluated on a :py:class:`order.Dataset` (using the :py:attr:`dataset_inst` + attribute) whose ``is_mc`` (``is_data``) attribute is *False*. + + All additional *kwargs* are added as class members of the new subclasses. + + :param func: Function to be wrapped and integrated into new :py:class:`HistProducer` class. + :param bases: Additional bases for the new hist producer. + :param mc_only: Boolean flag indicating that this hist producer should only run on Monte Carlo simulation and + skipped for real data. + :param data_only: Boolean flag indicating that this hist producer should only run on real data and skipped for + Monte Carlo simulation. + :return: New hist producer subclass. + """ + def decorator(func: Callable) -> DerivableMeta: + # create the class dict + cls_dict = { + **kwargs, + "call_func": func, + "mc_only": mc_only, + "data_only": data_only, + } + + # get the module name + frame = inspect.stack()[1] + module = inspect.getmodule(frame[0]) + + # get the producer name + cls_name = cls_dict.pop("cls_name", func.__name__) + + # hook to update the class dict during class derivation + def update_cls_dict(cls_name, cls_dict, get_attr): + mc_only = get_attr("mc_only") + data_only = get_attr("data_only") + + # optionally add skip function + if mc_only and data_only: + raise Exception(f"hist producer {cls_name} received both mc_only and data_only") + + if mc_only or data_only: + if cls_dict.get("skip_func"): + raise Exception( + f"hist producer {cls_name} received custom skip_func, but either mc_only or data_only " + "are set", + ) + + if "skip_func" not in cls_dict: + def skip_func(self, **kwargs) -> bool: + # check mc_only and data_only + if mc_only and not self.dataset_inst.is_mc: + return True + if data_only and not self.dataset_inst.is_data: + return True + + # in all other cases, do not skip + return False + + cls_dict["skip_func"] = skip_func + + return cls_dict + + cls_dict["update_cls_dict"] = update_cls_dict + + # create the subclass + subclass = cls.derive(cls_name, bases=bases, cls_dict=cls_dict, module=module) + + return subclass + + return decorator(func) if func else decorator + + @classmethod + def create_hist(cls, func: Callable[[dict], None]) -> None: + """ + Decorator to wrap a function *func* that should be registered as :py:meth:`create_hist_func`. The function + should accept two arguments: + + - *variables*, a list of :py:class:`order.Variable` instances (usually one). + - *task*, the invoking task instance. + + The return value of the function should be a histogram object or a container with histogram objects. + The decorator does not return the wrapped function. + """ + cls.create_hist_func = func + + @classmethod + def fill_hist(cls, func: Callable[[dict], None]) -> None: + """ + Decorator to wrap a function *func* that should be registered as :py:meth:`fill_hist_func`. The function should + accept three arguments: + + - *h*, the histogram (or a container with histograms) to fill. + - *data*, a dictionary with data to fill. + - *task*, the invoking task instance. + + The decorator does not return the wrapped function. + """ + cls.fill_hist_func = func + + @classmethod + def post_process_hist(cls, func: Callable[[dict], None]) -> None: + """ + Decorator to wrap a function *func* that should be registered as :py:meth:`post_process_hist_func`. The function + should accept two arguments: + + - *h*, the histogram (or a container with histograms) to post process. + - *task*, the invoking task instance. + + The decorator does not return the wrapped function. + """ + cls.post_process_hist_func = func + + @classmethod + def post_process_merged_hist(cls, func: Callable[[dict], None]) -> None: + """ + Decorator to wrap a function *func* that should be registered as :py:meth:`post_process_merged_hist_func`. The + function should accept two arguments: + + - *h*, the histogram (or a container with histograms) to post process. + - *task*, the invoking task instance. + + The return value of the function should be a histogram object. + The decorator does not return the wrapped function. + """ + cls.post_process_merged_hist_func = func + + def __init__( + self, + *args, + create_hist_func: Callable | law.NoValue | None = law.no_value, + fill_hist_func: Callable | law.NoValue | None = law.no_value, + post_process_hist_func: Callable | law.NoValue | None = law.no_value, + post_process_merged_hist_func: Callable | law.NoValue | None = law.no_value, + **kwargs, + ): + super().__init__(*args, **kwargs) + + # add class-level attributes as defaults for unset arguments (no_value) + if create_hist_func == law.no_value: + create_hist_func = self.__class__.create_hist_func + if fill_hist_func == law.no_value: + fill_hist_func = self.__class__.fill_hist_func + if post_process_hist_func == law.no_value: + post_process_hist_func = self.__class__.post_process_hist_func + if post_process_merged_hist_func == law.no_value: + post_process_merged_hist_func = self.__class__.post_process_merged_hist_func + + # when custom funcs are passed, bind them to this instance + if create_hist_func: + self.create_hist_func = create_hist_func.__get__(self, self.__class__) + if fill_hist_func: + self.fill_hist_func = fill_hist_func.__get__(self, self.__class__) + if post_process_hist_func: + self.post_process_hist_func = post_process_hist_func.__get__(self, self.__class__) + if post_process_merged_hist_func: + self.post_process_merged_hist_func = post_process_merged_hist_func.__get__(self, self.__class__) + + def run_create_hist(self, variables: list[od.Variable], task: law.Task) -> Any: + """ + Invokes the :py:meth:`create_hist_func` of this instance and returns its result, forwarding all arguments. + """ + return self.create_hist_func(variables, task=task) + + def run_fill_hist(self, h: Any, data: dict[str, Any], task: law.Task) -> None: + """ + Invokes the :py:meth:`fill_hist_func` of this instance and returns its result, forwarding all arguments. + """ + return self.fill_hist_func(h, data, task=task) + + def run_post_process_hist(self, h: Any, task: law.Task) -> Any: + """ + Invokes the :py:meth:`post_process_hist_func` of this instance and returns its result, forwarding all arguments. + """ + if not callable(self.post_process_hist_func): + return h + return self.post_process_hist_func(h, task=task) + + def run_post_process_merged_hist(self, h: Any, task: law.Task) -> hist.Histogram: + """ + Invokes the :py:meth:`post_process_merged_hist_func` of this instance and returns its result, forwarding all + arguments. + """ + if not callable(self.post_process_merged_hist_func): + return h + return self.post_process_merged_hist_func(h, task=task) + + +# shorthand +hist_producer = HistProducer.hist_producer diff --git a/columnflow/histogramming/default.py b/columnflow/histogramming/default.py new file mode 100644 index 000000000..8171031ef --- /dev/null +++ b/columnflow/histogramming/default.py @@ -0,0 +1,153 @@ +# coding: utf-8 + +""" +Default histogram producers that define columnflow's default behavior. +""" + +from __future__ import annotations + +import law +import order as od + +from columnflow.histogramming import HistProducer, hist_producer +from columnflow.util import maybe_import +from columnflow.hist_util import create_hist_from_variables, fill_hist, translate_hist_intcat_to_strcat +from columnflow.columnar_util import has_ak_column, Route +from columnflow.types import Any + +np = maybe_import("numpy") +ak = maybe_import("awkward") +hist = maybe_import("hist") + + +@hist_producer() +def cf_default(self: HistProducer, events: ak.Array, **kwargs) -> ak.Array: + """ + Default histogram producer that implements all hooks necessary to ensure columnflow's default behavior: + + - create_hist: defines the histogram structure + - __call__: receives an event chunk and updates it, and creates event weights (1's in this case) + - fill: receives the data and fills the histogram + - post_process_hist: post-processes the histogram before it is saved + """ + return events, ak.Array(np.ones(len(events), dtype=np.float32)) + + +@cf_default.create_hist +def cf_default_create_hist( + self: HistProducer, + variables: list[od.Variable], + task: law.Task, + **kwargs, +) -> hist.Histogram: + """ + Define the histogram structure for the default histogram producer. + """ + return create_hist_from_variables( + *variables, + categorical_axes=( + ("category", "intcat"), + ("process", "intcat"), + ("shift", "intcat"), + ), + weight=True, + ) + + +@cf_default.fill_hist +def cf_default_fill_hist(self: HistProducer, h: hist.Histogram, data: dict[str, Any], task: law.Task) -> None: + """ + Fill the histogram with the data. + """ + fill_hist(h, data, last_edge_inclusive=task.last_edge_inclusive) + + +@cf_default.post_process_hist +def cf_default_post_process_hist(self: HistProducer, h: hist.Histogram, task: law.Task) -> hist.Histogram: + """ + Post-process the histogram, converting integer to string axis for consistent lookup across configs where ids might + be different. + """ + axis_names = {ax.name for ax in h.axes} + + # translate axes + if "category" in axis_names: + category_map = {cat.id: cat.name for cat in self.config_inst.get_leaf_categories()} + h = translate_hist_intcat_to_strcat(h, "category", category_map) + if "process" in axis_names: + process_map = {proc_id: self.config_inst.get_process(proc_id).name for proc_id in h.axes["process"]} + h = translate_hist_intcat_to_strcat(h, "process", process_map) + if "shift" in axis_names: + shift_map = {task.global_shift_inst.id: task.global_shift_inst.name} + h = translate_hist_intcat_to_strcat(h, "shift", shift_map) + + return h + + +@cf_default.hist_producer() +def all_weights(self: HistProducer, events: ak.Array, **kwargs) -> ak.Array: + """ + HistProducer that combines all event weights from the *event_weights* aux entry from either the config or the + dataset. The weights are multiplied together to form the full event weight. + + The expected structure of the *event_weights* aux entry is a dictionary with the weight column name as key and a + list of shift sources as values. The shift sources are used to declare the shifts that the produced event weight + depends on. Example: + + .. code-block:: python + + from columnflow.config_util import get_shifts_from_sources + # add weights and their corresponding shifts for all datasets + cfg.x.event_weights = { + "normalization_weight": [], + "muon_weight": get_shifts_from_sources(config, "mu_sf"), + "btag_weight": get_shifts_from_sources(config, "btag_hf", "btag_lf"), + } + for dataset_inst in cfg.datasets: + # add dataset-specific weights and their corresponding shifts + dataset.x.event_weights = {} + if not dataset_inst.has_tag("skip_pdf"): + dataset_inst.x.event_weights["pdf_weight"] = get_shifts_from_sources(config, "pdf") + """ + weight = ak.Array(np.ones(len(events))) + + # build the full event weight + if self.dataset_inst.is_mc and len(events): + # multiply weights from global config `event_weights` aux entry + for column in self.config_inst.x.event_weights: + weight = weight * Route(column).apply(events) + + # multiply weights from dataset-specific `event_weights` aux entry + for column in self.dataset_inst.x("event_weights", []): + if has_ak_column(events, column): + weight = weight * Route(column).apply(events) + else: + self.logger.warning_once( + f"missing_dataset_weight_{column}", + f"weight '{column}' for dataset {self.dataset_inst.name} not found", + ) + + return events, weight + + +@all_weights.init +def all_weights_init(self: HistProducer) -> None: + weight_columns = set() + + if self.dataset_inst.is_data: + return + + # add used weight columns and declare shifts that the produced event weight depends on + if self.config_inst.has_aux("event_weights"): + weight_columns |= {Route(column) for column in self.config_inst.x.event_weights} + for shift_insts in self.config_inst.x.event_weights.values(): + self.shifts |= {shift_inst.name for shift_inst in shift_insts} + + # optionally also for weights defined by a dataset + if self.dataset_inst.has_aux("event_weights"): + weight_columns |= {Route(column) for column in self.dataset_inst.x("event_weights", [])} + for shift_insts in self.dataset_inst.x.event_weights.values(): + self.shifts |= {shift_inst.name for shift_inst in shift_insts} + + # add weight columns to uses + self.uses |= weight_columns diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index 224e7818e..d5c3ab01e 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -14,9 +14,13 @@ import yaml from columnflow.types import Generator, Callable, TextIO, Sequence, Any -from columnflow.util import ( - DerivableMeta, Derivable, DotDict, is_pattern, is_regex, pattern_matcher, -) +from columnflow.util import DerivableMeta, Derivable, DotDict, is_pattern, is_regex, pattern_matcher, get_docs_url + + +logger = law.logger.get_logger(__name__) + + +default_dataset = law.config.get_expanded("analysis", "default_dataset") class ParameterType(enum.Enum): @@ -49,11 +53,11 @@ def is_rate(self: ParameterType) -> bool: :returns: *True* if the parameter type is a rate type, *False* otherwise. """ - return self in ( + return self in { self.rate_gauss, self.rate_uniform, self.rate_unconstrained, - ) + } @property def is_shape(self: ParameterType) -> bool: @@ -62,9 +66,9 @@ def is_shape(self: ParameterType) -> bool: :returns: *True* if the parameter type is a shape type, *False* otherwise. """ - return self in ( + return self in { self.shape, - ) + } class ParameterTransformation(enum.Enum): @@ -105,9 +109,9 @@ def from_shape(self: ParameterTransformation) -> bool: :returns: *True* if the transformation is derived from shape, *False* otherwise. """ - return self in ( + return self in { self.effect_from_shape, - ) + } @property def from_rate(self: ParameterTransformation) -> bool: @@ -116,15 +120,14 @@ def from_rate(self: ParameterTransformation) -> bool: :returns: *True* if the transformation is derived from rate, *False* otherwise. """ - return self in ( + return self in { self.effect_from_rate, - ) + } class ParameterTransformations(tuple): """ - Container around a sequence of :py:class:`ParameterTransformation`'s with a few convenience - methods. + Container around a sequence of :py:class:`ParameterTransformation`'s with a few convenience methods. :param transformations: A sequence of :py:class:`ParameterTransformation` or their string names. """ @@ -139,7 +142,7 @@ def __new__( :param transformations: A sequence of :py:class:`ParameterTransformation` or their string names. :returns: A new instance of :py:class:`ParameterTransformations`. """ - # TODO: at this point one could object / complain in case incompatible transfos are used + # TODO: at this point one could object / complain in case incompatible trafos are used transformations = [ (t if isinstance(t, ParameterTransformation) else ParameterTransformation[t]) for t in transformations @@ -149,7 +152,7 @@ def __new__( return super().__new__(cls, transformations) @property - def any_from_shape(self: ParameterTransformations) -> bool: + def any_from_shape(self) -> bool: """ Checks if any transformation is derived from shape. @@ -158,7 +161,7 @@ def any_from_shape(self: ParameterTransformations) -> bool: return any(t.from_shape for t in self) @property - def any_from_rate(self: ParameterTransformations) -> bool: + def any_from_rate(self) -> bool: """ Checks if any transformation is derived from rate. @@ -170,6 +173,11 @@ def any_from_rate(self: ParameterTransformations) -> bool: class FlowStrategy(enum.Enum): """ Strategy to handle over- and underflow bin contents. + + :cvar ignore: Ignore over- and underflow bins. + :cvar warn: Issue a warning when over- and underflow bins are encountered. + :cvar remove: Remove over- and underflow bins. + :cvar move: Move over- and underflow bins to the first and last bin, respectively. """ ignore = "ignore" @@ -183,53 +191,72 @@ def __str__(self) -> str: class InferenceModel(Derivable): """ - Interface to statistical inference models with connections to config objects (such as - py:class:`order.Config` or :py:class:`order.Dataset`). + Interface to statistical inference models with connections to config objects (such as py:class:`order.Config` or + :py:class:`order.Dataset`). - The internal structure to describe a model looks as follows (in yaml style) and is accessible - through :py:attr:`model` as well as property access to its top-level objects. + The internal structure to describe a model looks as follows (in yaml style) and is accessible through + :py:attr:`model` as well as property access to its top-level objects. .. code-block:: yaml categories: - name: cat1 - config_category: 1e - config_variable: ht - config_data_datasets: [data_mu_a] + postfix: null + config_data: + 22pre_v14: + category: 1e + variable: ht + data_datasets: [data_mu_a] data_from_processes: [] - flow_strategy: warn mc_stats: 10 + flow_strategy: warn + rate_precision: 5 + empty_bin_value: 1e-05 processes: - name: HH - config_process: hh is_signal: True - config_mc_datasets: [hh_ggf] + config_data: + 22pre_v14: + process: hh + mc_datasets: [hh_ggf] scale: 1.0 is_dynamic: False parameters: - name: lumi type: rate_gauss effect: 1.02 - config_shift_source: null + effect_precision: 4 + config_data: {} + transformations: [] - name: pu type: rate_gauss effect: [0.97, 1.02] - config_shift_source: null + effect_precision: 4 + config_data: {} + transformations: [] - name: pileup type: shape effect: 1.0 - config_shift_source: minbias_xs + effect_precision: 4 + config_data: + 22pre_v14: + shift_source: minbias_xs + transformations: [] - name: tt is_signal: False - config_process: ttbar - config_mc_datasets: [tt_sl, tt_dl, tt_fh] + config_data: + 22pre_v14: + process: tt + mc_datasets: [tt_sl, tt_dl, tt_fh] scale: 1.0 is_dynamic: False parameters: - name: lumi type: rate_gauss effect: 1.02 - config_shift_source: null + effect_precision: 4 + config_data: {} + transformations: [] - name: cat2 ... @@ -245,17 +272,11 @@ class InferenceModel(Derivable): The unique name of this model. - .. py:attribute:: config_inst - - type: order.Config, None - - Reference to the :py:class:`order.Config` object. + .. py:attribute:: config_insts - .. py:attribute:: config_callbacks + type: list[order.Config] - type: list - - A list of callables that are invoked after :py:meth:`set_config` was called. + Reference to :py:class:`order.Config` objects. .. py:attribute:: model @@ -275,15 +296,15 @@ class YamlDumper(yaml.SafeDumper): @classmethod def _map_repr(cls, dumper: yaml.Dumper, data: dict) -> str: - return dumper.represent_dict(dict(data)) + return dumper.represent_dict(data if isinstance(data, dict) else dict(data)) @classmethod def _list_repr(cls, dumper: yaml.Dumper, data: list) -> str: - return dumper.represent_list(list(data)) + return dumper.represent_list(data if isinstance(data, list) else list(data)) @classmethod def _str_repr(cls, dumper: yaml.Dumper, data: str) -> str: - return dumper.represent_str(str(data)) + return dumper.represent_str(data if isinstance(data, str) else str(data)) def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -329,6 +350,14 @@ def decorator(func: Callable) -> DerivableMeta: return decorator(func) if func else decorator + @classmethod + def used_datasets(cls, config_inst: od.Config) -> list[str]: + """ + Used datasets for which the `upstream_task_cls.resolve_instances` will be called. + Defaults to the default dataset. + """ + return [default_dataset] + @classmethod def model_spec(cls) -> DotDict: """ @@ -346,38 +375,39 @@ def model_spec(cls) -> DotDict: def category_spec( cls, name: str, - config_category: str | None = None, - config_variable: str | None = None, - config_data_datasets: Sequence[str] | None = None, + config_data: dict[str, DotDict] | None = None, data_from_processes: Sequence[str] | None = None, flow_strategy: FlowStrategy | str = FlowStrategy.warn, mc_stats: int | float | tuple | None = None, + postfix: str | None = None, empty_bin_value: float = 1e-5, + rate_precision: int = 5, ) -> DotDict: """ Returns a dictionary representing a category (interchangeably called bin or channel in other tools), forwarding all arguments. :param name: The name of the category in the model. - :param config_category: The name of the source category in the config to use. - :param config_variable: The name of the variable in the config to use. - :param config_data_datasets: List of names or patterns of datasets in the config to use for - real data. - :param data_from_processes: Optional list of names of :py:meth:`process_spec` objects that, - when *config_data_datasets* is not defined, make up a fake data contribution. - :param flow_strategy: A :py:class:`FlowStrategy` instance describing the strategy to handle - over- and underflow bin contents. - :param mc_stats: Either *None* to disable MC stat uncertainties, or an integer, a float or - a tuple thereof to control the options of MC stat options. + :param config_data: Dictionary mapping names of :py:class:`order.Config` objects to dictionaries following the + :py:meth:`category_config_spec` that wrap settings like category, variable and real datasets in that config. + :param data_from_processes: Optional list of names of :py:meth:`process_spec` objects that, when + *config_data_datasets* is not defined, make up a fake data contribution. + :param flow_strategy: A :py:class:`FlowStrategy` instance describing the strategy to handle over- and underflow + bin contents. + :param mc_stats: Either *None* to disable MC stat uncertainties, or an integer, a float or a tuple thereof to + control the options of MC stat options. + :param postfix: A postfix that is appended to (e.g.) file names created for this model. :param empty_bin_value: When bins have no content, they are filled with this value. + :param rate_precision: The precision of reported rates. :returns: A dictionary representing the category. - """ return DotDict([ ("name", str(name)), - ("config_category", str(config_category) if config_category else None), - ("config_variable", str(config_variable) if config_variable else None), - ("config_data_datasets", list(map(str, config_data_datasets or []))), + ("config_data", ( + {k: cls.category_config_spec(**v) for k, v in config_data.items()} + if config_data + else {} + )), ("data_from_processes", list(map(str, data_from_processes or []))), ("flow_strategy", ( flow_strategy @@ -385,7 +415,9 @@ def category_spec( else FlowStrategy[flow_strategy] )), ("mc_stats", mc_stats), + ("postfix", str(postfix) if postfix else None), ("empty_bin_value", empty_bin_value), + ("rate_precision", rate_precision), ("processes", []), ]) @@ -393,9 +425,8 @@ def category_spec( def process_spec( cls, name: str, - config_process: str | None = None, is_signal: bool = False, - config_mc_datasets: Sequence[str] | None = None, + config_data: dict[str, DotDict] | None = None, scale: float | int = 1.0, is_dynamic: bool = False, ) -> DotDict: @@ -404,18 +435,21 @@ def process_spec( :param name: The name of the process in the model. :param is_signal: A boolean flag deciding whether this process describes signal. - :param config_process: The name of the source process in the config to use. - :param config_mc_datasets: List of names or patterns of MC datasets in the config to use. + :param config_data: Dictionary mapping names of :py:class:`order.Config` objects to dictionaries following the + :py:meth:`process_config_spec` that wrap settings like process and mc datasets in that config. :param scale: A float value to scale the process, defaulting to 1.0. - :param is_dynamic: A boolean flag deciding whether this process is dynamic, i.e., whether it - is created on-the-fly. + :param is_dynamic: A boolean flag deciding whether this process is dynamic, i.e., whether it is created + on-the-fly. :returns: A dictionary representing the process. """ return DotDict([ ("name", str(name)), ("is_signal", bool(is_signal)), - ("config_process", str(config_process) if config_process else None), - ("config_mc_datasets", list(map(str, config_mc_datasets or []))), + ("config_data", ( + {k: cls.process_config_spec(**v) for k, v in config_data.items()} + if config_data + else {} + )), ("scale", float(scale)), ("is_dynamic", bool(is_dynamic)), ("parameters", []), @@ -426,29 +460,36 @@ def parameter_spec( cls, name: str, type: ParameterType | str, - transformations: Sequence[ParameterTransformation | str] = (ParameterTransformation.none,), - config_shift_source: str | None = None, + transformations: Sequence[ParameterTransformation | str] = (), + config_data: dict[str, DotDict] | None = None, effect: Any | None = 1.0, + effect_precision: int = 4, ) -> DotDict: """ Returns a dictionary representing a (nuisance) parameter, forwarding all arguments. :param name: The name of the parameter in the model. :param type: A :py:class:`ParameterType` instance describing the type of this parameter. - :param transformations: A sequence of :py:class:`ParameterTransformation` instances - describing transformations to be applied to the effect of this parameter. - :param config_shift_source: The name of a systematic shift source in the config that this - parameter corresponds to. - :param effect: An arbitrary object describing the effect of the parameter (e.g. float for - symmetric rate effects, 2-tuple for down/up variation, etc). + :param transformations: A sequence of :py:class:`ParameterTransformation` instances describing transformations + to be applied to the effect of this parameter. + :param config_data: Dictionary mapping names of :py:class:`order.Config` objects to dictionaries following the + :py:meth:`parameter_config_spec` that wrap settings like corresponding shift source in that config. + :param effect: An arbitrary object describing the effect of the parameter (e.g. float for symmetric rate + effects, 2-tuple for down/up variation, etc). + :param effect_precision: The precision of reported effects. :returns: A dictionary representing the parameter. """ return DotDict([ ("name", str(name)), ("type", type if isinstance(type, ParameterType) else ParameterType[type]), ("transformations", ParameterTransformations(transformations)), - ("config_shift_source", str(config_shift_source) if config_shift_source else None), + ("config_data", ( + {k: cls.parameter_config_spec(**v) for k, v in config_data.items()} + if config_data + else {} + )), ("effect", effect), + ("effect_precision", effect_precision), ]) @classmethod @@ -469,6 +510,60 @@ def parameter_group_spec( ("parameter_names", list(map(str, parameter_names or []))), ]) + @classmethod + def category_config_spec( + cls, + category: str | None = None, + variable: str | None = None, + data_datasets: Sequence[str] | None = None, + ) -> DotDict: + """ + Returns a dictionary representing configuration specific data, forwarding all arguments. + + :param category: The name of the source category in the config to use. + :param variable: The name of the variable in the config to use. + :param data_datasets: List of names or patterns of datasets in the config to use for real data. + :returns: A dictionary representing category specific settings. + """ + return DotDict([ + ("category", str(category) if category else None), + ("variable", str(variable) if variable else None), + ("data_datasets", list(map(str, data_datasets or []))), + ]) + + @classmethod + def process_config_spec( + cls, + process: str | None = None, + mc_datasets: Sequence[str] | None = None, + ) -> DotDict: + """ + Returns a dictionary representing configuration specific data, forwarding all arguments. + + :param process: The name of the process in the config to use. + :param mc_datasets: List of names or patterns of datasets in the config to use for mc. + :returns: A dictionary representing process specific settings. + """ + return DotDict([ + ("process", str(process) if process else None), + ("mc_datasets", list(map(str, mc_datasets or []))), + ]) + + @classmethod + def parameter_config_spec( + cls, + shift_source: str | None = None, + ) -> DotDict: + """ + Returns a dictionary representing configuration specific data, forwarding all arguments. + + :param shift_source: The name of a systematic shift source in the config. + :returns: A dictionary representing parameter specific settings. + """ + return DotDict([ + ("shift_source", str(shift_source) if shift_source else None), + ]) + @classmethod def require_shapes_for_parameter(self, param_obj: dict) -> bool: """ @@ -498,11 +593,14 @@ def require_shapes_for_parameter(self, param_obj: dict) -> bool: f"'{param_obj.type}' and transformations {param_obj.transformations}", ) - def __init__(self, config_inst: od.Config) -> None: + def __init__(self, config_insts: list[od.Config]) -> None: super().__init__() # store attributes - self.config_inst = config_inst + self.config_insts = config_insts + + # temporary attributes for as long as we issue deprecation warnings + self.__config_inst = None # model info self.model = self.model_spec() @@ -531,6 +629,28 @@ def pprint(self) -> None: # property access to top-level objects # + # !! to be removed in a future release + @property + def config_inst(self) -> od.Config: + if self.__config_inst: + return self.__config_inst + + # trigger a verbose warning in case the deprecated attribute is accessed + docs_url = get_docs_url("user_guide", "02_03_transition.html") + api_url = get_docs_url("api", "inference", "index.html", anchor="columnflow.inference.InferenceModel") + logger.warning_once( + "inference_model_deprected_config_inst", + "access to attribute 'config_inst' in inference model was removed; use 'config_insts' instead; also, make " + "sure to use the new 'config_data' attribute in 'add_category()' for a more fine-grained control over the " + f"composition of your inference model categories; see {docs_url} and {api_url} for more info", + ) + + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute 'config_inst'") + + @config_inst.setter + def config_inst(self, config_inst: od.Config) -> None: + self.__config_inst = config_inst + @property def categories(self) -> DotDict: return self.model.categories @@ -547,6 +667,7 @@ def get_categories( self, category: str | Sequence[str] | None = None, only_names: bool = False, + match_mode: Callable = any, ) -> list[DotDict | str]: """ Returns a list of categories whose name match *category*. *category* can be a string, a @@ -555,12 +676,14 @@ def get_categories( :param category: A string, pattern, or sequence of them to match category names. :param only_names: A boolean flag to return only names of categories if set to *True*. + :param match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). :returns: A list of matching categories or their names. """ # rename arguments to make their meaning explicit category_pattern = category - match = pattern_matcher(category_pattern or "*") + match = pattern_matcher(category_pattern or "*", mode=match_mode) return [ (category.name if only_names else category) for category in self.categories @@ -571,6 +694,7 @@ def get_category( self, category: str | Sequence[str], only_name: bool = False, + match_mode: Callable = any, silent: bool = False, ) -> DotDict | str: """ @@ -580,15 +704,17 @@ def get_category( *True*, only the name of the category is returned rather than a structured dictionary. :param category: A string, pattern, or sequence of them to match category names. - :param silent: A boolean flag to return *None* instead of raising an exception if no or - more than one category is found. :param only_name: A boolean flag to return only the name of the category if set to *True*. + :param match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). + :param silent: A boolean flag to return *None* instead of raising an exception if no or more than one category + is found. :returns: A single matching category or its name. """ # rename arguments to make their meaning explicit category_name = category - categories = self.get_categories(category_name, only_names=only_name) + categories = self.get_categories(category_name, only_names=only_name, match_mode=match_mode) # length checks if not categories: @@ -605,19 +731,22 @@ def get_category( def has_category( self, category: str | Sequence[str], + match_mode: Callable = any, ) -> bool: """ Returns *True* if a category whose name matches *category* is existing, and *False* otherwise. *category* can be a string, a pattern, or sequence of them. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if a matching category exists, *False* otherwise. """ # rename arguments to make their meaning explicit category_pattern = category # simple length check - return len(self.get_categories(category_pattern)) > 0 + return len(self.get_categories(category_pattern, only_names=True, match_mode=match_mode)) > 0 def add_category(self, *args, **kwargs) -> None: """ @@ -640,17 +769,20 @@ def add_category(self, *args, **kwargs) -> None: def remove_category( self, category: str | Sequence[str], + match_mode: Callable = any, ) -> bool: """ Removes one or more categories whose names match *category*. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if at least one category was removed, *False* otherwise. """ # rename arguments to make their meaning explicit category_pattern = category - match = pattern_matcher(category_pattern) + match = pattern_matcher(category_pattern, mode=match_mode) n_before = len(self.categories) self.categories[:] = [ category @@ -675,6 +807,8 @@ def get_processes( category: str | Sequence[str] | None = None, only_names: bool = False, flat: bool = False, + match_mode: Callable = any, + category_match_mode: Callable = any, ) -> dict[str, DotDict | str] | list[str]: """ Returns a dictionary of processes whose names match *process*, mapped to the name of the @@ -687,6 +821,10 @@ def get_processes( :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to filter categories. :param only_names: A boolean flag to return only names of processes if set to *True*. + :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). :param flat: A boolean flag to return a flat, unique list of process names if set to *True*. :returns: A dictionary of processes mapped to the category name, or a list of process names. """ @@ -699,10 +837,10 @@ def get_processes( only_names = True # get matching categories first - categories = self.get_categories(category_pattern) + categories = self.get_categories(category_pattern, match_mode=category_match_mode) # look for the process pattern in each of them - match = pattern_matcher(process_pattern or "*") + match = pattern_matcher(process_pattern or "*", mode=match_mode) pairs = ( (category.name, [ (process.name if only_names else process) @@ -726,6 +864,8 @@ def get_process( process: str | Sequence[str], category: str | Sequence[str] | None = None, only_name: bool = False, + match_mode: Callable = any, + category_match_mode: Callable = any, silent: bool = False, ) -> DotDict | str: """ @@ -739,9 +879,13 @@ def get_process( :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. - :param silent: A boolean flag to return *None* instead of raising an exception if no or - more than one process is found. :param only_name: A boolean flag to return only the name of the process if set to *True*. + :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). + :param silent: A boolean flag to return *None* instead of raising an exception if no or more than one process is + found. :returns: A single matching process or its name. :raises ValueError: If no process or more than one process is found and *silent* is *False*. """ @@ -753,6 +897,8 @@ def get_process( process_name, category=category_pattern, only_names=only_name, + match_mode=match_mode, + category_match_mode=category_match_mode, ) # checks @@ -786,6 +932,8 @@ def has_process( self, process: str | Sequence[str], category: str | Sequence[str] | None = None, + match_mode: Callable = any, + category_match_mode: Callable = any, ) -> bool: """ Returns *True* if a process whose name matches *process*, and optionally whose category's @@ -794,6 +942,10 @@ def has_process( :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if a matching process exists, *False* otherwise. """ # rename arguments to make their meaning explicit @@ -801,12 +953,19 @@ def has_process( category_pattern = category # simple length check - return len(self.get_processes(process_pattern, category=category_pattern)) > 0 + return len(self.get_processes( + process_pattern, + category=category_pattern, + only_names=True, + match_mode=match_mode, + category_match_mode=category_match_mode, + )) > 0 def add_process( self, *args, category: str | Sequence[str] | None = None, + category_match_mode: Callable = any, silent: bool = False, **kwargs, ) -> None: @@ -820,11 +979,12 @@ def add_process( :param args: Positional arguments used to create the process. :param category: A string, pattern, or sequence of them to match category names. - :param silent: A boolean flag to suppress exceptions if a process with the same name - already exists. + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). + :param silent: A boolean flag to suppress exceptions if a process with the same name already exists. :param kwargs: Keyword arguments used to create the process. - :raises ValueError: If a process with the same name already exists in one of the - categories and *silent* is *False*. + :raises ValueError: If a process with the same name already exists in one of the categories and *silent* is + *False*. """ # rename arguments to make their meaning explicit category_pattern = category @@ -832,7 +992,7 @@ def add_process( process = self.process_spec(*args, **kwargs) # get categories the process should be added to - categories = self.get_categories(category_pattern) + categories = self.get_categories(category_pattern, match_mode=category_match_mode) # check for duplicates target_categories = [] @@ -854,6 +1014,8 @@ def remove_process( self, process: str | Sequence[str], category: str | Sequence[str] | None = None, + match_mode: Callable = any, + category_match_mode: Callable = any, ) -> bool: """ Removes one or more processes whose names match *process*, and optionally whose category's @@ -862,6 +1024,10 @@ def remove_process( :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if at least one process was removed, *False* otherwise. """ # rename arguments to make their meaning explicit @@ -869,9 +1035,9 @@ def remove_process( category_pattern = category # get categories the process should be removed from - categories = self.get_categories(category_pattern) + categories = self.get_categories(category_pattern, match_mode=category_match_mode) - match = pattern_matcher(process_pattern) + match = pattern_matcher(process_pattern, mode=match_mode) removed_any = False for category in categories: n_before = len(category.processes) @@ -897,6 +1063,9 @@ def get_parameters( parameter: str | Sequence[str] | None = None, process: str | Sequence[str] | None = None, category: str | Sequence[str] | None = None, + match_mode: Callable = any, + category_match_mode: Callable = any, + process_match_mode: Callable = any, only_names: bool = False, flat: bool = False, ) -> dict[str, dict[str, DotDict | str]] | list[str]: @@ -912,10 +1081,15 @@ def get_parameters( :param parameter: A string, pattern, or sequence of them to match parameter names. :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). + :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). :param only_names: A boolean flag to return only names of parameters if set to *True*. :param flat: A boolean flag to return a flat, unique list of parameter names if set to *True*. - :returns: A dictionary of parameters mapped to category and process names, or a list of - parameter names. + :returns: A dictionary of parameters mapped to category and process names, or a list of parameter names. """ # rename arguments to make their meaning explicit parameter_pattern = parameter @@ -927,10 +1101,15 @@ def get_parameters( only_names = True # get matching processes (mapped to matching categories) - processes = self.get_processes(process=process_pattern, category=category_pattern) + processes = self.get_processes( + process=process_pattern, + category=category_pattern, + match_mode=process_match_mode, + category_match_mode=category_match_mode, + ) # look for the process pattern in each pair - match = pattern_matcher(parameter_pattern or "*") + match = pattern_matcher(parameter_pattern or "*", mode=match_mode) parameters = DotDict() for category_name, _processes in processes.items(): pairs = ( @@ -961,6 +1140,9 @@ def get_parameter( parameter: str | Sequence[str], process: str | Sequence[str] | None = None, category: str | Sequence[str] | None = None, + match_mode: Callable = any, + category_match_mode: Callable = any, + process_match_mode: Callable = any, only_name: bool = False, silent: bool = False, ) -> DotDict | str: @@ -976,9 +1158,15 @@ def get_parameter( :param parameter: A string, pattern, or sequence of them to match parameter names. :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). + :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). :param only_name: A boolean flag to return only the name of the parameter if set to *True*. - :param silent: A boolean flag to return *None* instead of raising an exception if no or more - than one parameter is found. + :param silent: A boolean flag to return *None* instead of raising an exception if no or more than one parameter + is found. :returns: A single matching parameter or its name. :raises ValueError: If no parameter or more than one parameter is found and *silent* is *False*. """ @@ -991,6 +1179,9 @@ def get_parameter( parameter_name, process=process_pattern, category=category_pattern, + match_mode=match_mode, + category_match_mode=category_match_mode, + process_match_mode=process_match_mode, only_names=only_name, ) @@ -1038,6 +1229,9 @@ def has_parameter( parameter: str | Sequence[str], process: str | Sequence[str] | None = None, category: str | Sequence[str] | None = None, + match_mode: Callable = any, + category_match_mode: Callable = any, + process_match_mode: Callable = any, ) -> bool: """ Returns *True* if a parameter whose name matches *parameter*, and optionally whose @@ -1048,6 +1242,12 @@ def has_parameter( :param parameter: A string, pattern, or sequence of them to match parameter names. :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). + :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if a matching parameter exists, *False* otherwise. """ # rename arguments to make their meaning explicit @@ -1060,6 +1260,9 @@ def has_parameter( parameter_pattern, process=process_pattern, category=category_pattern, + match_mode=match_mode, + category_match_mode=category_match_mode, + process_match_mode=process_match_mode, )) > 0 def add_parameter( @@ -1067,29 +1270,34 @@ def add_parameter( *args, process: str | Sequence[str] | None = None, category: str | Sequence[str] | None = None, + category_match_mode: Callable = any, + process_match_mode: Callable = any, group: str | Sequence[str] | None = None, **kwargs, ) -> DotDict: """ - Adds a new parameter to all categories and processes whose names match *category* and - *process*, with all *args* and *kwargs* used to create the structured parameter dictionary - via :py:meth:`parameter_spec`. Both *process* and *category* can be a string, a pattern, or - sequence of them. + Adds a parameter to all categories and processes whose names match *category* and *process*, with all *args* and + *kwargs* used to create the structured parameter dictionary via :py:meth:`parameter_spec`. Both *process* and + *category* can be a string, a pattern, or sequence of them. - When *group* is given, the parameter is added to one or more parameter groups via - :py:meth:`add_parameter_to_group`. + If a parameter with the same name already exists in one of the processes throughout the categories, an exception + is raised. - If a parameter with the same name already exists in one of the processes throughout the - categories, an exception is raised. + When *group* is given and the parameter is new, it is added to one or more parameter groups via + :py:meth:`add_parameter_to_group`. :param args: Positional arguments used to create the parameter. :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). + :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). :param group: A string, pattern, or sequence of them to specify parameter groups. :param kwargs: Keyword arguments used to create the parameter. :returns: The created parameter. - :raises ValueError: If a parameter with the same name already exists in one of the processes - throughout the categories. + :raises ValueError: If a parameter with the same name already exists in one of the processes throughout the + categories. """ # rename arguments to make their meaning explicit process_pattern = process @@ -1098,7 +1306,12 @@ def add_parameter( parameter = self.parameter_spec(*args, **kwargs) # get processes (mapped to categories) the parameter should be added to - processes = self.get_processes(process=process_pattern, category=category_pattern) + processes = self.get_processes( + process=process_pattern, + category=category_pattern, + match_mode=process_match_mode, + category_match_mode=category_match_mode, + ) # check for duplicates for category_name, _processes in processes.items(): @@ -1125,6 +1338,9 @@ def remove_parameter( parameter: str | Sequence[str], process: str | Sequence[str] | None = None, category: str | Sequence[str] | None = None, + match_mode: Callable = any, + category_match_mode: Callable = any, + process_match_mode: Callable = any, ) -> bool: """ Removes one or more parameters whose names match *parameter*, and optionally whose @@ -1134,6 +1350,12 @@ def remove_parameter( :param parameter: A string, pattern, or sequence of them to match parameter names. :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). + :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if at least one parameter was removed, *False* otherwise. """ # rename arguments to make their meaning explicit @@ -1142,9 +1364,14 @@ def remove_parameter( category_pattern = category # get processes (mapped to categories) the parameter should be removed from - processes = self.get_processes(process=process_pattern, category=category_pattern) + processes = self.get_processes( + process=process_pattern, + category=category_pattern, + match_mode=process_match_mode, + category_match_mode=category_match_mode, + ) - match = pattern_matcher(parameter_pattern) + match = pattern_matcher(parameter_pattern, mode=match_mode) removed_any = False for _processes in processes.values(): for process in _processes: @@ -1169,6 +1396,7 @@ def remove_parameter( def get_parameter_groups( self, group: str | Sequence[str] | None = None, + match_mode: Callable = any, only_names: bool = False, ) -> list[DotDict | str]: """ @@ -1179,13 +1407,15 @@ def get_parameter_groups( structured dictionaries. :param group: A string, pattern, or sequence of them to match group names. + :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see + :py:func:`pattern_matcher`). :param only_names: A boolean flag to return only names of parameter groups if set to *True*. :returns: A list of parameter groups or their names. """ # rename arguments to make their meaning explicit group_pattern = group - match = pattern_matcher(group_pattern or "*") + match = pattern_matcher(group_pattern or "*", mode=match_mode) return [ (group.name if only_names else group) for group in self.parameter_groups @@ -1195,6 +1425,7 @@ def get_parameter_groups( def get_parameter_group( self, group: str | Sequence[str], + match_mode: Callable = any, only_name: bool = False, ) -> DotDict | str: """ @@ -1206,6 +1437,8 @@ def get_parameter_group( structured dictionary. :param group: A string, pattern, or sequence of them to match group names. + :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see + :py:func:`pattern_matcher`). :param only_name: A boolean flag to return only the name of the parameter group if set to *True*. :returns: A single matching parameter group or its name. :raises ValueError: If no parameter group or more than one parameter group is found. @@ -1213,7 +1446,7 @@ def get_parameter_group( # rename arguments to make their meaning explicit group_name = group - groups = self.get_parameter_groups(group_name, only_names=only_name) + groups = self.get_parameter_groups(group_name, match_mode=match_mode, only_names=only_name) # checks if not groups: @@ -1226,19 +1459,22 @@ def get_parameter_group( def has_parameter_group( self, group: str | Sequence[str], + match_mode: Callable = any, ) -> bool: """ Returns *True* if a parameter group whose name matches *group* exists, and *False* otherwise. *group* can be a string, a pattern, or sequence of them. :param group: A string, pattern, or sequence of them to match group names. + :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if a matching parameter group exists, *False* otherwise. """ # rename arguments to make their meaning explicit group_pattern = group # simeple length check - return len(self.get_parameter_groups(group_pattern)) > 0 + return len(self.get_parameter_groups(group_pattern, match_mode=match_mode)) > 0 def add_parameter_group(self, *args, **kwargs) -> None: """ @@ -1262,6 +1498,7 @@ def add_parameter_group(self, *args, **kwargs) -> None: def remove_parameter_group( self, group: str | Sequence[str], + match_mode: Callable = any, ) -> bool: """ Removes one or more parameter groups whose names match *group*. *group* can be a string, a @@ -1269,12 +1506,14 @@ def remove_parameter_group( otherwise. :param group: A string, pattern, or sequence of them to match group names. + :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if at least one group was removed, *False* otherwise. """ # rename arguments to make their meaning explicit group_pattern = group - match = pattern_matcher(group_pattern) + match = pattern_matcher(group_pattern, mode=match_mode) n_before = len(self.parameter_groups) self.parameter_groups[:] = [ group @@ -1289,6 +1528,8 @@ def add_parameter_to_group( self, parameter: str | Sequence[str], group: str | Sequence[str], + match_mode: Callable = any, + parameter_match_mode: Callable = any, ) -> bool: """ Adds a parameter named *parameter* to one or multiple parameter groups whose names match @@ -1299,6 +1540,10 @@ def add_parameter_to_group( :param parameter: A string, pattern, or sequence of them to match parameter names. :param group: A string, pattern, or sequence of them to match group names. + :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see + :py:func:`pattern_matcher`). + :param parameter_match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if at least one parameter was added to a group, *False* otherwise. """ # rename arguments to make their meaning explicit @@ -1306,7 +1551,7 @@ def add_parameter_to_group( group_pattern = group # stop when there are no matching groups - groups = self.get_parameter_groups(group_pattern) + groups = self.get_parameter_groups(group_pattern, match_mode=match_mode) if not groups: return False @@ -1314,7 +1559,7 @@ def add_parameter_to_group( _is_pattern = lambda s: is_pattern(s) or is_regex(s) parameter_pattern = law.util.make_list(parameter_pattern) if any(map(_is_pattern, parameter_pattern)): - parameter_names = self.get_parameters(parameter_pattern, flat=True) + parameter_names = self.get_parameters(parameter_pattern, flat=True, match_mode=parameter_match_mode) else: parameter_names = parameter_pattern @@ -1332,6 +1577,8 @@ def remove_parameter_from_groups( self, parameter: str | Sequence[str], group: str | Sequence[str] | None = None, + match_mode: Callable = any, + parameter_match_mode: Callable = any, ) -> bool: """ Removes all parameters matching *parameter* from parameter groups whose names match *group*. @@ -1340,6 +1587,10 @@ def remove_parameter_from_groups( :param parameter: A string, pattern, or sequence of them to match parameter names. :param group: A string, pattern, or sequence of them to match group names. + :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see + :py:func:`pattern_matcher`). + :param parameter_match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if at least one parameter was removed, *False* otherwise. """ # rename arguments to make their meaning explicit @@ -1347,11 +1598,11 @@ def remove_parameter_from_groups( group_pattern = group # stop when there are no matching groups - groups = self.get_parameter_groups(group_pattern) + groups = self.get_parameter_groups(group_pattern, match_mode=match_mode) if not groups: return False - match = pattern_matcher(parameter_pattern) + match = pattern_matcher(parameter_pattern, mode=parameter_match_mode) removed_any = False for group in groups: n_before = len(group.parameter_names) @@ -1371,24 +1622,29 @@ def remove_parameter_from_groups( def get_categories_with_process( self, process: str | Sequence[str], + match_mode: Callable = any, ) -> list[str]: """ Returns a flat list of category names that contain processes matching *process*. *process* can be a string, a pattern, or sequence of them. :param process: A string, pattern, or sequence of them to match process names. + :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). :returns: A list of category names containing matching processes. """ # rename arguments to make their meaning explicit process_pattern = process # plain name lookup - return list(self.get_processes(process=process_pattern, only_names=True).keys()) + return list(self.get_processes(process=process_pattern, match_mode=match_mode, only_names=True).keys()) def get_processes_with_parameter( self, parameter: str | Sequence[str], category: str | Sequence[str] | None = None, + match_mode: Callable = any, + category_match_mode: Callable = any, flat: bool = True, ) -> list[str] | dict[str, list[str]]: """ @@ -1400,6 +1656,10 @@ def get_processes_with_parameter( :param parameter: A string, pattern, or sequence of them to match parameter names. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). :param flat: A boolean flag to return a flat, unique list of process names if set to *True*. :returns: A dictionary of process names mapped to category names, or a flat list of process names. """ @@ -1410,6 +1670,8 @@ def get_processes_with_parameter( processes = self.get_parameters( parameter=parameter_pattern, category=category_pattern, + match_mode=match_mode, + category_match_mode=category_match_mode, only_names=True, ) @@ -1429,6 +1691,8 @@ def get_categories_with_parameter( self, parameter: str | Sequence[str], process: str | Sequence[str] | None = None, + match_mode: Callable = any, + process_match_mode: Callable = any, flat: bool = True, ) -> list[str] | dict[str, list[str]]: """ @@ -1440,6 +1704,10 @@ def get_categories_with_parameter( :param parameter: A string, pattern, or sequence of them to match parameter names. :param process: A string, pattern, or sequence of them to match process names. + :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). + :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). :param flat: A boolean flag to return a flat, unique list of category names if set to *True*. :returns: A dictionary of category names mapped to process names, or a flat list of category names. """ @@ -1450,6 +1718,8 @@ def get_categories_with_parameter( categories = self.get_parameters( parameter=parameter_pattern, process=process_pattern, + match_mode=match_mode, + process_match_mode=process_match_mode, only_names=True, ) @@ -1468,18 +1738,21 @@ def get_categories_with_parameter( def get_groups_with_parameter( self, parameter: str | Sequence[str], + match_mode: Callable = any, ) -> list[str]: """ Returns a list of names of parameter groups that contain a parameter whose name matches *parameter*. *parameter* can be a string, a pattern, or sequence of them. :param parameter: A string, pattern, or sequence of them to match parameter names. + :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). :returns: A list of names of parameter groups containing the matching parameter. """ # rename arguments to make their meaning explicit parameter_pattern = parameter - match = pattern_matcher(parameter_pattern) + match = pattern_matcher(parameter_pattern, mode=match_mode) return [ group.name for group in self.parameter_groups @@ -1518,19 +1791,23 @@ def remove_empty_categories(self) -> None: def remove_dangling_parameters_from_groups( self, keep_parameters: str | Sequence[str] | None = None, + match_mode: Callable = any, ) -> None: """ Removes names of parameters from parameter groups that are not assigned to any process in any category. :param keep_parameters: A string, pattern, or sequence of them to specify parameters to keep. + :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). """ # get a list of all parameters parameter_names = self.get_parameters("*", flat=True) - # get list of parameters to keep + # get set of parameters to keep + _keep_parameters = set() if keep_parameters: - keep_parameters = self.get_parameters(keep_parameters, flat=True) + _keep_parameters = set(self.get_parameters(keep_parameters, match_mode=match_mode, flat=True)) # go through groups and remove dangling parameters for group in self.parameter_groups: @@ -1539,7 +1816,7 @@ def remove_dangling_parameters_from_groups( for parameter_name in group.parameter_names if ( parameter_name in parameter_names or - (keep_parameters and parameter_name in keep_parameters) + (_keep_parameters and parameter_name in _keep_parameters) ) ] @@ -1561,6 +1838,8 @@ def iter_processes( self, process: str | Sequence[str] | None = None, category: str | Sequence[str] | None = None, + match_mode: Callable = any, + category_match_mode: Callable = any, ) -> Generator[tuple[DotDict, DotDict], None, None]: """ Generator that iteratively yields all processes whose names match *process*, optionally @@ -1569,9 +1848,18 @@ def iter_processes( :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). :returns: A generator yielding 2-tuples of category name and process object. """ - processes = self.get_processes(process=process, category=category) + processes = self.get_processes( + process=process, + category=category, + match_mode=match_mode, + category_match_mode=category_match_mode, + ) for category_name, processes in processes.items(): for process in processes: yield (category_name, process) @@ -1581,6 +1869,9 @@ def iter_parameters( parameter: str | Sequence[str] | None = None, process: str | Sequence[str] | None = None, category: str | Sequence[str] | None = None, + match_mode: Callable = any, + category_match_mode: Callable = any, + process_match_mode: Callable = any, ) -> Generator[tuple[DotDict, DotDict, DotDict], None, None]: """ Generator that iteratively yields all parameters whose names match *parameter*, optionally @@ -1590,9 +1881,22 @@ def iter_parameters( :param parameter: A string, pattern, or sequence of them to match parameter names. :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). + :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). :returns: A generator yielding 3-tuples of category name, process name, and parameter object. """ - parameters = self.get_parameters(parameter=parameter, process=process, category=category) + parameters = self.get_parameters( + parameter=parameter, + process=process, + category=category, + match_mode=match_mode, + category_match_mode=category_match_mode, + process_match_mode=process_match_mode, + ) for category_name, parameters in parameters.items(): for process_name, parameters in parameters.items(): for parameter in parameters: @@ -1607,6 +1911,8 @@ def scale_process( scale: int | float, process: str | Sequence[str] | None = None, category: str | Sequence[str] | None = None, + match_mode: Callable = any, + category_match_mode: Callable = any, ) -> bool: """ Sets the scale attribute of all processes whose names match *process*, optionally in all @@ -1615,10 +1921,19 @@ def scale_process( :param scale: The scale value to set for the matching processes. :param process: A string, pattern, or sequence of them to match process names. :param category: A string, pattern, or sequence of them to match category names. + :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see + :py:func:`pattern_matcher`). + :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see + :py:func:`pattern_matcher`). :returns: *True* if at least one process was found and scaled, *False* otherwise. """ found = False - for _, process in self.iter_processes(process=process, category=category): + for _, process in self.iter_processes( + process=process, + category=category, + match_mode=match_mode, + category_match_mode=category_match_mode, + ): process.scale = float(scale) found = True return found diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index aa41180e5..032a988ce 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -12,32 +12,35 @@ import law from columnflow import __version__ as cf_version -from columnflow.types import Sequence, Any -from columnflow.inference import ( - InferenceModel, ParameterType, ParameterTransformation, FlowStrategy, -) +from columnflow.inference import InferenceModel, ParameterType, ParameterTransformation, FlowStrategy from columnflow.util import DotDict, maybe_import, real_path, ensure_dir, safe_div, maybe_int +from columnflow.types import Sequence, Any, Union, Hashable -np = maybe_import("np") hist = maybe_import("hist") -uproot = maybe_import("uproot") logger = law.logger.get_logger(__name__) +# type aliases for nested histogram structs +ShiftHists = dict[Union[str, tuple[str, str]], hist.Hist] # "nominal" or (param_name, "up|down") -> hists +ConfigHists = dict[str, ShiftHists] # config name -> hists +ProcHists = dict[str, ConfigHists] # process name -> hists +DatacardHists = dict[str, ProcHists] # category name -> hists + class DatacardWriter(object): """ - Generic writer for combine datacards using a instance of an :py:class:`InferenceModel` - *inference_model_inst* and a threefold nested dictionary "category -> process -> shift -> hist". + Generic writer for combine datacards using a instance of an :py:class:`InferenceModel` *inference_model_inst* and a + four-fold nested dictionary "category -> process -> config -> shift -> hist". - *rate_precision* and *parameter_precision* control the number of digits of values for measured - rates and parameter effects. + *rate_precision* and *effect_precision* control the number of digits of values for measured rates and parameter + effects. They are used in case the category and parameter objects of the inference model are configured with + non-postive values for *rate_precision* and *effect_precision*, respectively. .. note:: - At the moment, all shapes are written into the same root file and a shape line with - wildcards for both bin and process resolution is created. + At the moment, all shapes are written into the same root file and a shape line with wildcards for both bin and + process resolution is created. """ # minimum separator between columns @@ -46,17 +49,17 @@ class DatacardWriter(object): def __init__( self, inference_model_inst: InferenceModel, - histograms: dict[str, dict[str, dict[str, hist.Hist]]], + histograms: DatacardHists, rate_precision: int = 4, - parameter_precision: int = 4, - ): + effect_precision: int = 4, + ) -> None: super().__init__() # store attributes self.inference_model_inst = inference_model_inst self.histograms = histograms self.rate_precision = rate_precision - self.parameter_precision = parameter_precision + self.effect_precision = effect_precision def write( self, @@ -98,14 +101,20 @@ def write( blocks.shapes = [("shapes", "*", "*", shapes_path_ref, nom_pattern, syst_pattern)] separators.add("shapes") + # store rate precisions per category + rate_precisions = { + cat_obj.name: self.rate_precision if cat_obj.rate_precision <= 0 else cat_obj.rate_precision + for cat_obj in map(self.inference_model_inst.get_category, rates.keys()) + } + # observations blocks.observations = [] if all("data" in _rates for _rates in rates.values()): blocks.observations = [ ("bin", list(rates)), ("observation", [ - maybe_int(round(_rates["data"], self.rate_precision)) - for _rates in rates.values() + maybe_int(round(_rates["data"], rate_precisions[cat_name])) + for cat_name, _rates in rates.items() ]), ] separators.add("observations") @@ -133,13 +142,15 @@ def write( (-s_names.index(proc_name) if proc_name in s_names else b_names.index(proc_name) + 1) for _, proc_name in flat_rates ]), - ("rate", [round(rate, self.rate_precision) for rate in flat_rates.values()]), + ("rate", [ + round(rate, rate_precisions[cat_name]) + for (cat_name, _), rate in flat_rates.items() + ]), ] separators.add("rates") # tabular-style parameters blocks.tabular_parameters = [] - rnd = lambda f: round(f, self.parameter_precision) for param_name in self.inference_model_inst.get_parameters(flat=True): param_obj = None effects = [] @@ -165,13 +176,21 @@ def write( param_obj = _param_obj elif _param_obj.type != param_obj.type: raise ValueError( - f"misconfigured parameter '{param_name}' with type '{_param_obj.type}' " - f"that was previously seen with incompatible type '{param_obj.type}'", + f"misconfigured parameter '{param_name}' with type '{_param_obj.type}' that was previously " + f"seen with incompatible type '{param_obj.type}'", ) # get the effect effect = _param_obj.effect + # rounding helper depending on the effect precision + effect_precision = ( + self.effect_precision + if _param_obj.effect_precision <= 0 + else _param_obj.effect_precision + ) + rnd = lambda f: round(f, effect_precision) + # update and transform effects if _param_obj.type.is_rate: # obtain from shape effects when requested @@ -228,8 +247,8 @@ def write( effects.append(f"{rnd(effect[0])}/{rnd(effect[1])}") else: raise ValueError( - f"effect '{effect}' of parameter '{param_name}' with type {param_obj.type} " - f"on process '{proc_name}' in category '{cat_name}' cannot be encoded", + f"effect '{effect}' of parameter '{param_name}' with type {param_obj.type} on process " + f"'{proc_name}' in category '{cat_name}' cannot be encoded", ) # add the tabular line @@ -343,6 +362,8 @@ def write_shapes( - the datacard pattern for extracting nominal shapes, and - the datacard pattern for extracting systematic shapes. """ + import uproot + # create the directory shapes_path = real_path(shapes_path) shapes_dir = os.path.dirname(shapes_path) @@ -419,39 +440,54 @@ def fill_empty(cat_obj, h): h.view().variance[mask] = cat_obj.empty_bin_value # iterate through shapes - for cat_name, hists in self.histograms.items(): + for cat_name, proc_hists in self.histograms.items(): cat_obj = self.inference_model_inst.get_category(cat_name) _rates = rates[cat_name] = OrderedDict() _effects = effects[cat_name] = OrderedDict() - for proc_name, _hists in hists.items(): - __effects = _effects[proc_name] = OrderedDict() + for proc_name, config_hists in proc_hists.items(): + # skip if process is not known to category + if not self.inference_model_inst.has_process(process=proc_name, category=cat_name): + continue # defer the handling of data to the end if proc_name == "data": continue + # flat list of hists for configs that contribute to this category + hists: list[dict[Hashable, hist.Hist]] = [ + hd for config_name, hd in config_hists.items() + if config_name in cat_obj.config_data + ] + if not hists: + continue + + # helper to sum over them for a given shift key and an optional fallback + def sum_hists(key: Hashable, fallback_key: Hashable | None = None) -> hist.Hist: + def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: + if key in hd: + return hd[key] + if fallback_key and fallback_key in hd: + return hd[fallback_key] + raise Exception( + f"'{key}' shape for process '{proc_name}' in category '{cat_name}' misconfigured: {hd}", + ) + return sum(map(get, hists[1:]), get(hists[0]).copy()) + # get the process scale (usually 1) proc_obj = self.inference_model_inst.get_process(proc_name, category=cat_name) scale = proc_obj.scale # nominal shape - h_nom = _hists["nominal"].copy() * scale - fill_empty(cat_obj, h_nom) + h_nom = sum_hists("nominal") * scale nom_name = nom_pattern.format(category=cat_name, process=proc_name) + fill_empty(cat_obj, h_nom) handle_flow(cat_obj, h_nom, nom_name) out_file[nom_name] = h_nom _rates[proc_name] = h_nom.sum().value - # helper to return the two variations - def get_shapes(param_name): - __hists = _hists[param_name] - if "up" not in __hists or "down" not in __hists: - raise Exception( - f"shapes of parameter '{param_name}' for process '{proc_name}' " - f"in category '{cat_name}' misconfigured: {__hists}", - ) - return __hists["down"] * scale, __hists["up"] * scale + # prepare effects + __effects = _effects[proc_name] = OrderedDict() # go through all parameters and check if varied shapes need to be processed for _, _, param_obj in self.inference_model_inst.iter_parameters(category=cat_name, process=proc_name): @@ -465,19 +501,21 @@ def get_shapes(param_name): f_down, f_up = param_obj.effect else: raise ValueError( - f"cannot interpret effect of parameter '{param_obj.name}' to " - f"create shape: {param_obj.effect}", + f"cannot interpret effect of parameter '{param_obj.name}' to create shape: " + f"{param_obj.effect}", ) h_down = h_nom.copy() * f_down h_up = h_nom.copy() * f_up else: # just extract the shapes - h_down, h_up = get_shapes(param_obj.name) + h_down = sum_hists((param_obj.name, "down"), "nominal") * scale + h_up = sum_hists((param_obj.name, "up"), "nominal") * scale elif param_obj.type.is_rate: if param_obj.transformations.any_from_shape: # just extract the shapes - h_down, h_up = get_shapes(param_obj.name) + h_down = sum_hists((param_obj.name, "down"), "nominal") * scale + h_up = sum_hists((param_obj.name, "up"), "nominal") * scale else: # skip the parameter continue @@ -491,9 +529,8 @@ def get_shapes(param_name): if not (min(d, n) <= n <= max(d, n)): # skip one sided effects logger.info( - f"skipping shape centralization of parameter '{param_obj.name}' " - f"for process '{proc_name}' in category '{cat_name}' as effect " - "is one-sided", + f"skipping shape centralization of parameter '{param_obj.name}' for process " + f"'{proc_name}' in category '{cat_name}' as effect is one-sided", ) continue # find the central point, compute the diff w.r.t. nominal, and shift @@ -544,31 +581,33 @@ def get_shapes(param_name): # fake data from processes h_data = [] for proc_name in cat_obj.data_from_processes: - if proc_name not in hists: - logger.warning( - f"process '{proc_name}' not found in histograms for created fake data, " - "skipping", - ) - continue - h_data.append(hists[proc_name]["nominal"]) + if proc_name in proc_hists: + h_data.extend([hd["nominal"] for hd in proc_hists[proc_name].values()]) + else: + logger.warning(f"process '{proc_name}' not found in histograms for created fake data, skipping") if not h_data: proc_str = ",".join(map(str, cat_obj.data_from_processes)) - raise Exception(f"no requested process '{proc_str}' found to create fake data") + raise Exception(f"none of requested processes '{proc_str}' found to create fake data") h_data = sum(h_data[1:], h_data[0].copy()) data_name = data_pattern.format(category=cat_name) + fill_empty(cat_obj, h_data) + handle_flow(cat_obj, h_data, data_name) out_file[data_name] = h_data _rates["data"] = float(h_data.sum().value) - elif cat_obj.config_data_datasets: - if "data" not in hists: - raise Exception( - f"the inference model '{self.inference_model_inst.name}' is configured to " - f"use real data in category '{cat_name}' but no histogram named 'data' " - "exists", - ) + elif any(cd.data_datasets for cd in cat_obj.config_data.values()): + h_data = [] + for config_name, config_data in cat_obj.config_data.items(): + if "data" not in proc_hists or config_name not in proc_hists["data"]: + raise Exception( + f"the inference model '{self.inference_model_inst.cls_name}' is configured to use real " + f"data for config '{config_name}' in category '{cat_name}' but no histogram received at " + f"entry ['data']['{config_name}']: {proc_hists}", + ) + h_data.append(proc_hists["data"][config_name]["nominal"]) - # simply save the data histogram - h_data = hists["data"]["nominal"].copy() + # simply save the data histogram that was already built from the requested datasets + h_data = sum(h_data[1:], h_data[0].copy()) data_name = data_pattern.format(category=cat_name) handle_flow(cat_obj, h_data, data_name) out_file[data_name] = h_data @@ -589,9 +628,7 @@ def align_lines( lengths = {min(len(line), 1e9 if end < 0 else end) for line in lines} if len(lengths) > 1: - raise Exception( - f"line alignment cannot be performed with lines of varying lengths: {lengths}", - ) + raise Exception(f"line alignment cannot be performed with lines of varying lengths: {lengths}") # convert to columns and get the maximum width per column n_cols = lengths.pop() diff --git a/columnflow/ml/__init__.py b/columnflow/ml/__init__.py index fb6cdaf77..43419ea86 100644 --- a/columnflow/ml/__init__.py +++ b/columnflow/ml/__init__.py @@ -41,9 +41,10 @@ class MLModel(Derivable): assigned (:py:meth:`setup`), a fine-grained configuration of additional training requirements (:py:meth:`requires`), diverging training and evaluation phase spaces (:py:meth:`training_configs`, :py:meth:`training_calibrators`, :py:meth:`training_selector`, - :py:meth:`training_producers`), or how hyper-paramaters are string encoded for output - declarations (:py:meth:`parameter_pairs`). The optional py:meth:`preparation_producer` allows - setting a producer that is run during the initial preparation of ML columns. + :py:meth:`training_producers`, :py:meth:`evaluation_producers`), or how hyper-paramaters are + string encoded for output declarations (:py:meth:`parameter_pairs`). The optional + py:meth:`preparation_producer` allows setting a producer that is run during the initial + preparation of ML columns. .. py:classattribute:: single_config @@ -347,11 +348,11 @@ def training_configs( def training_calibrators( self: MLModel, - config_inst: od.Config, + analysis_inst: od.Analysis, requested_calibrators: Sequence[str], ) -> list[str]: """ - Given a sequence of *requested_calibrators* for a *config_inst*, this method can alter + Given a sequence of *requested_calibrators* for a *analysis_inst*, this method can alter and/or replace them to define a different set of calibrators for the preprocessing and training pipeline. This can be helpful in cases where training and evaluation phase spaces, as well as the required input columns are intended to diverge. @@ -362,18 +363,18 @@ def training_calibrators( :language: python :pyobject: TestModel.training_calibrators - :param config_inst: Config instance to extract the *requested_calibrators* from + :param analysis_inst: Analysis instance to extract the *requested_calibrators* from :returns: Set with str of the *requested_calibrators* """ return list(requested_calibrators) def training_selector( self: MLModel, - config_inst: od.Config, + analysis_inst: od.Analysis, requested_selector: str, ) -> str: """ - Given a *requested_selector* for a *config_inst*, this method can change it to define a + Given a *requested_selector* for a *analysis_inst*, this method can change it to define a different selector for the preprocessing and training pipeline. This can be helpful in cases where training and evaluation phase spaces, as well as the required input columns are intended to diverge. @@ -384,18 +385,18 @@ def training_selector( :language: python :pyobject: TestModel.training_selector - :param config_inst: Config instance to extract the *requested_selector* from + :param analysis_inst: Analysis instance to extract the *requested_selector* from :returns: Set with str of the *requested_selector* """ return requested_selector def training_producers( self: MLModel, - config_inst: od.Config, + analysis_inst: od.Analysis, requested_producers: Sequence[str], ) -> list[str]: """ - Given a sequence of *requested_producers* for a *config_inst*, this method can alter and/or + Given a sequence of *requested_producers* for a *analysis_inst*, this method can alter and/or replace them to define a different set of producers for the preprocessing and training pipeline. This can be helpful in cases where training and evaluation phase spaces, as well as the required input columns are intended to diverge. @@ -406,20 +407,42 @@ def training_producers( :language: python :pyobject: TestModel.training_producers - :param config_inst: Config instance to extract the *requested_producers* from + :param analysis_inst: Analysis instance to extract the *requested_producers* from + :returns: Set with str of the *requested_producers* + """ + return list(requested_producers) + + def evaluation_producers( + self: MLModel, + analysis_inst: od.Analysis, + requested_producers: Sequence[str], + ) -> list[str]: + """ + Given a sequence of *requested_producers* for a *analysis_inst*, this method can alter and/or + replace them to define a different set of producers for the evaluation phase of the ML + pipeline. This can be helpful in cases where the producers in the evaluation phase + and subsequent tasks are intended to diverge. + + Example usage: + + .. literalinclude:: ../../user_guide/examples/ml_code.py + :language: python + :pyobject: TestModel.evaluation_producers + + :param analysis_inst: Analysis instance to extract the *requested_producers* from :returns: Set with str of the *requested_producers* """ return list(requested_producers) def preparation_producer( self: MLModel, - config_inst: od.Config, + analysis_inst: od.Analysis, ) -> str | None: """ This method allows setting a producer that can be called as part of the preparation - of the ML input columns given a *config_inst*. + of the ML input columns given a *analysis_inst*. - :param config_inst: :py:class:`~order.Config` object for which the producer should run. + :param analysis_inst: :py:class:`~order.Analysis` object for which the producer should run. :return: Name of a :py:class:`Producer` class or *None*. """ return None diff --git a/columnflow/plotting/plot_all.py b/columnflow/plotting/plot_all.py index e3bb42773..522995977 100644 --- a/columnflow/plotting/plot_all.py +++ b/columnflow/plotting/plot_all.py @@ -8,12 +8,18 @@ __all__ = [] +import order as od + from columnflow.types import Sequence from columnflow.util import maybe_import, try_float +from columnflow.config_util import group_shifts from columnflow.plotting.plot_util import ( get_position, + apply_ax_kwargs, get_cms_label, remove_label_placeholders, + apply_label_placeholders, + calculate_stat_error, ) hist = maybe_import("hist") @@ -21,18 +27,19 @@ mpl = maybe_import("matplotlib") plt = maybe_import("matplotlib.pyplot") mplhep = maybe_import("mplhep") -od = maybe_import("order") -def draw_error_bands( +def draw_stat_error_bands( ax: plt.Axes, h: hist.Hist, norm: float | Sequence | np.ndarray = 1.0, **kwargs, ) -> None: - # compute relative errors - rel_error = h.variances()**0.5 / h.values() - rel_error[np.isnan(rel_error)] = 0.0 + assert len(h.axes) == 1 + + # compute relative statistical errors + rel_stat_error = h.variances()**0.5 / h.values() + rel_stat_error[np.isnan(rel_stat_error)] = 0.0 # compute the baseline # fill 1 in places where both numerator and denominator are 0, and 0 for remaining nan's @@ -40,19 +47,119 @@ def draw_error_bands( baseline[(h.values() == 0) & (norm == 0)] = 1.0 baseline[np.isnan(baseline)] = 0.0 - defaults = { + bar_kwargs = { "x": h.axes[0].centers, + "bottom": baseline * (1 - rel_stat_error), + "height": baseline * 2 * rel_stat_error, "width": h.axes[0].edges[1:] - h.axes[0].edges[:-1], - "height": baseline * 2 * rel_error, - "bottom": baseline * (1 - rel_error), "hatch": "///", - "facecolor": "none", "linewidth": 0, - "color": "black", + "color": "none", + "edgecolor": "black", "alpha": 1.0, + **kwargs, } - defaults.update(kwargs) - ax.bar(**defaults) + ax.bar(**bar_kwargs) + + +def draw_syst_error_bands( + ax: plt.Axes, + h: hist.Hist, + syst_hists: Sequence[hist.Hist], + shift_insts: Sequence[od.Shift], + norm: float | Sequence | np.ndarray = 1.0, + method: str = "quadratic_sum", + **kwargs, +) -> None: + assert len(h.axes) == 1 + assert method in ("quadratic_sum", "envelope") + + nominal_shift, shift_groups = group_shifts(shift_insts) + if nominal_shift is None: + raise ValueError("no nominal shift found in the list of shift instances") + + # create pairs of shifts mapping from up -> down and vice versa + shift_pairs = {} + for up_shift, down_shift in shift_groups.values(): + shift_pairs[up_shift] = down_shift + shift_pairs[down_shift] = up_shift + + # stack histograms separately per shift, falling back to the nominal one when missing + shift_stacks: dict[od.Shift, hist.Hist] = {} + for shift_inst in sum(shift_groups.values(), []): + for _h in syst_hists: + # when the shift is present, the flipped shift must exist as well + shift_ax = _h.axes["shift"] + if shift_inst.name in shift_ax: + if shift_pairs[shift_inst].name not in shift_ax: + raise RuntimeError( + f"shift {shift_inst} found in histogram but {shift_pairs[shift_inst]} is missing; " + f"existing shifts: {','.join(map(str, list(shift_ax)))}", + ) + shift_name = shift_inst.name + else: + shift_name = nominal_shift.name + # store the slice + _h = _h[{"shift": hist.loc(shift_name)}] + if shift_inst not in shift_stacks: + shift_stacks[shift_inst] = _h + else: + shift_stacks[shift_inst] += _h + + # loop over bins, subtract nominal yields from stacked yields and merge differences into + # a systematic error per bin using the given method (quadratic sum vs. evelope) + # note 1: if the up/down variations of the same shift source point in the same direction, a + # statistical combination is pointless and their minimum/maximum is selected instead + # note 2: relative signs are consumed into the meaning of "up" and "down" here as they already + # are combinations evaluated for a specific direction + syst_error_up = [] + syst_error_down = [] + for b in range(h.axes[0].size): + up_diffs = [] + down_diffs = [] + for source, (up_shift, down_shift) in shift_groups.items(): + # get actual differences resulting from this shift + shift_up_diff = shift_stacks[up_shift].values()[b] - h.values()[b] + shift_down_diff = shift_stacks[down_shift].values()[b] - h.values()[b] + # store them depending on whether they really increase or decrease the yield + up_diffs.append(max(shift_up_diff, shift_down_diff, 0)) + down_diffs.append(min(shift_up_diff, shift_down_diff, 0)) + # combination based on the method + if method == "quadratic_sum": + up_diff = sum(d**2 for d in up_diffs)**0.5 + down_diff = sum(d**2 for d in down_diffs)**0.5 + else: # envelope + up_diff = max(up_diffs) + down_diff = min(down_diffs) + # save values + syst_error_up.append(up_diff) + syst_error_down.append(down_diff) + + # compute relative systematic errors + rel_syst_error_up = np.array(syst_error_up) / h.values() + rel_syst_error_up[np.isnan(rel_syst_error_up)] = 0.0 + rel_syst_error_down = np.array(syst_error_down) / h.values() + rel_syst_error_down[np.isnan(rel_syst_error_down)] = 0.0 + + # compute the baseline + # fill 1 in places where both numerator and denominator are 0, and 0 for remaining nan's + baseline = h.values() / norm + baseline[(h.values() == 0) & (norm == 0)] = 1.0 + baseline[np.isnan(baseline)] = 0.0 + + bar_kwargs = { + "x": h.axes[0].centers, + "bottom": baseline * (1 - rel_syst_error_down), + "height": baseline * (rel_syst_error_up + rel_syst_error_down), + "width": h.axes[0].edges[1:] - h.axes[0].edges[:-1], + "hatch": "\\\\\\", + "linewidth": 0, + "color": "none", + "edgecolor": "#30c300", + "alpha": 1.0, + **kwargs, + } + ax.bar(**bar_kwargs) def draw_stack( @@ -77,6 +184,7 @@ def draw_stack( # solution: transform norm -> [norm]*len(h) h = hist.Stack(*[i / norm for i in h]) + # draw only the stack, no error bars/bands with stack = True defaults = { "ax": ax, "stack": True, @@ -90,19 +198,37 @@ def draw_hist( ax: plt.Axes, h: hist.Hist, norm: float | Sequence | np.ndarray = 1.0, + error_type: str = "variance", **kwargs, ) -> None: + assert error_type in {"variance", "poisson_unweighted", "poisson_weighted"} + if kwargs.get("color", "") is None: # when color is set to None, remove it such that matplotlib automatically chooses a color kwargs.pop("color") - h = h / norm defaults = { "ax": ax, "stack": False, "histtype": "step", } defaults.update(kwargs) + if "yerr" not in defaults: + if h.storage_type.accumulator is not hist.accumulators.WeightedSum: + raise TypeError( + "Error bars calculation only implemented for histograms with storage type WeightedSum " + "either change the Histogram storage_type or set yerr manually", + ) + yerr = calculate_stat_error(h, error_type) + # normalize yerr to the histogram = error propagation on standard deviation + yerr = abs(yerr / norm) + # replace inf with nan for any bin where norm = 0 and calculate_stat_error returns a non zero value + if np.any(np.isinf(yerr)): + yerr[np.isinf(yerr)] = np.nan + defaults["yerr"] = yerr + + h = h / norm + h.plot1d(**defaults) @@ -110,11 +236,14 @@ def draw_profile( ax: plt.Axes, h: hist.Hist, norm: float | Sequence | np.ndarray = 1.0, + error_type: str = "variance", **kwargs, ) -> None: """ Profiled histograms contains the storage type "Mean" and can therefore not be normalized """ + assert error_type in {"variance", "poisson_unweighted", "poisson_weighted"} + if kwargs.get("color", "") is None: # when color is set to None, remove it such that matplotlib automatically chooses a color kwargs.pop("color") @@ -125,6 +254,13 @@ def draw_profile( "histtype": "step", } defaults.update(kwargs) + if "yerr" not in defaults: + if h.storage_type.accumulator is not hist.accumulators.WeightedSum: + raise TypeError( + "Error bars calculation only implemented for histograms with storage type WeightedSum " + "either change the Histogram storage_type or set yerr manually", + ) + defaults["yerr"] = calculate_stat_error(h, error_type) h.plot1d(**defaults) @@ -132,31 +268,37 @@ def draw_errorbars( ax: plt.Axes, h: hist.Hist, norm: float | Sequence | np.ndarray = 1.0, + error_type: str = "poisson_unweighted", **kwargs, ) -> None: + assert error_type in {"variance", "poisson_unweighted", "poisson_weighted"} + values = h.values() / norm - variances = h.variances() / norm**2 - # compute asymmetric poisson errors for data - # TODO: passing the output of poisson_interval as yerr to mpl.plothist leads to - # buggy error bars and the documentation is clearly wrong (mplhep 0.3.12, - # hist 2.4.0), so adjust the output to make up for that, but maybe update or - # remove the next lines if this is fixed to not correct it "twice" - from hist.intervals import poisson_interval - yerr = poisson_interval(values, variances) - yerr[np.isnan(yerr)] = 0 - yerr[0] = values - yerr[0] - yerr[1] -= values - yerr[yerr < 0] = 0 + defaults = { "x": h.axes[0].centers, "y": values, - "yerr": yerr, "color": "k", "linestyle": "none", "marker": "o", "elinewidth": 1, } defaults.update(kwargs) + + if "yerr" not in defaults: + if h.storage_type.accumulator is not hist.accumulators.WeightedSum: + raise TypeError( + "Error bars calculation only implemented for histograms with storage type WeightedSum " + "either change the Histogram storage_type or set yerr manually", + ) + yerr = calculate_stat_error(h, error_type) + # normalize yerr to the histogram = error propagation on standard deviation + yerr = abs(yerr / norm) + # replace inf with nan for any bin where norm = 0 and calculate_stat_error returns a non zero value + if np.any(np.isinf(yerr)): + yerr[np.isinf(yerr)] = np.nan + defaults["yerr"] = yerr + ax.errorbar(**defaults) @@ -169,34 +311,34 @@ def plot_all( whitespace_fraction: float = 0.3, magnitudes: float = 4, **kwargs, -) -> tuple(plt.Figure, tuple(plt.Axes)): +) -> tuple[plt.Figure, tuple[plt.Axes, ...]]: """ - Function that calls multiple plotting methods based on two configuration dictionaries, - *plot_config* and *style_config*. + Function that calls multiple plotting methods based on two configuration dictionaries, *plot_config* and + *style_config*. The *plot_config* expects dictionaries with fields: - "method": str, identical to the name of a function defined above, - "hist": hist.Hist or hist.Stack, - "kwargs": dict (optional), - "ratio_kwargs": dict (optional), + + - "method": str, identical to the name of a function defined above + - "hist": hist.Hist or hist.Stack + - "kwargs": dict (optional) + - "ratio_kwargs": dict (optional) The *style_config* expects fields (all optional): - "gridspec_cfg": dict, - "ax_cfg": dict, - "rax_cfg": dict, - "legend_cfg": dict, - "cms_label_cfg": dict, - - :param plot_config: Dictionary that defines which plot methods will be called with which - key word arguments. + + - "gridspec_cfg": dict + - "ax_cfg": dict + - "rax_cfg": dict + - "legend_cfg": dict + - "cms_label_cfg": dict + + :param plot_config: Dictionary that defines which plot methods will be called with which key word arguments. :param style_config: Dictionary that defines arguments on how to style the overall plot. :param skip_ratio: Optional bool parameter to not display the ratio plot. :param skip_legend: Optional bool parameter to not display the legend. :param cms_label: Optional string parameter to set the CMS label text. - :param whitespace_fraction: Optional float parameter that defines the ratio of which - the plot will consist of whitespace for the legend and labels - :param magnitudes: Optional float parameter that defines the displayed ymin when plotting - with a logarithmic scale. + :param whitespace_fraction: Optional float parameter that defines the ratio of which the plot will consist of + whitespace for the legend and labels + :param magnitudes: Optional float parameter that defines the displayed ymin when plotting with a logarithmic scale. :return: tuple of plot figure and axes """ # general mplhep style @@ -207,17 +349,21 @@ def plot_all( grid_spec = {"left": 0.15, "right": 0.95, "top": 0.95, "bottom": 0.1} grid_spec |= style_config.get("gridspec_cfg", {}) if not skip_ratio: - grid_spec |= {"height_ratios": [3, 1], "hspace": 0} + grid_spec = {"height_ratios": [3, 1], "hspace": 0, **grid_spec} fig, axs = plt.subplots(2, 1, gridspec_kw=grid_spec, sharex=True) (ax, rax) = axs else: + grid_spec.pop("height_ratios", None) fig, ax = plt.subplots(gridspec_kw=grid_spec) axs = (ax,) # invoke all plots methods plot_methods = { func.__name__: func - for func in [draw_error_bands, draw_stack, draw_hist, draw_profile, draw_errorbars] + for func in [ + draw_stat_error_bands, draw_syst_error_bands, draw_stack, draw_hist, draw_profile, + draw_errorbars, + ] } for key, cfg in plot_config.items(): # check if required fields are present @@ -254,24 +400,8 @@ def plot_all( # prioritize style_config ax settings ax_kwargs.update(style_config.get("ax_cfg", {})) - # some settings cannot be handled by ax.set - xminorticks = ax_kwargs.pop("xminorticks", ax_kwargs.pop("minorxticks", None)) - yminorticks = ax_kwargs.pop("yminorticks", ax_kwargs.pop("minoryticks", None)) - xloc = ax_kwargs.pop("xloc", None) - yloc = ax_kwargs.pop("yloc", None) - - # set all values - ax.set(**ax_kwargs) - - # set manual configs - if xminorticks is not None: - ax.set_xticks(xminorticks, minor=True) - if yminorticks is not None: - ax.set_xticks(yminorticks, minor=True) - if xloc is not None: - ax.set_xlabel(ax.get_xlabel(), loc=xloc) - if yloc is not None: - ax.set_ylabel(ax.get_ylabel(), loc=yloc) + # apply axis kwargs + apply_ax_kwargs(ax, ax_kwargs) # ratio plot if not skip_ratio: @@ -285,18 +415,8 @@ def plot_all( } rax_kwargs.update(style_config.get("rax_cfg", {})) - # some settings cannot be handled by ax.set - xloc = rax_kwargs.pop("xloc", None) - yloc = rax_kwargs.pop("yloc", None) - - # set all values - rax.set(**rax_kwargs) - - # set manual configs - if xloc is not None: - rax.set_xlabel(rax.get_xlabel(), loc=xloc) - if yloc is not None: - rax.set_ylabel(rax.get_ylabel(), loc=yloc) + # apply axis kwargs + apply_ax_kwargs(rax, rax_kwargs) # remove x-label from main axis if "xlabel" in rax_kwargs: @@ -322,7 +442,7 @@ def plot_all( # custom argument: entries_per_column n_cols = legend_kwargs.get("ncols", 1) - entries_per_col = legend_kwargs.pop("entries_per_column", None) + entries_per_col = legend_kwargs.pop("cf_entries_per_column", None) if callable(entries_per_col): entries_per_col = entries_per_col(ax, handles, labels, n_cols) if entries_per_col and n_cols > 1: @@ -339,10 +459,21 @@ def plot_all( labels.insert(i * max_entries + n, "") # custom hook to adjust handles and labels - update_handles_labels = legend_kwargs.pop("update_handles_labels", None) + update_handles_labels = legend_kwargs.pop("cf_update_handles_labels", None) if callable(update_handles_labels): update_handles_labels(ax, handles, labels, n_cols) + # interpret placeholders + apply = [] + if legend_kwargs.pop("cf_short_labels", False): + apply.append("SHORT") + if legend_kwargs.pop("cf_line_breaks", False): + apply.append("BREAK") + labels = [apply_label_placeholders(label, apply=apply) for label in labels] + + # drop remaining placeholders + labels = list(map(remove_label_placeholders, labels)) + # make legend using ordered handles/labels ax.legend(handles, labels, **legend_kwargs) diff --git a/columnflow/plotting/plot_functions_1d.py b/columnflow/plotting/plot_functions_1d.py index 61b091b7a..0608075fa 100644 --- a/columnflow/plotting/plot_functions_1d.py +++ b/columnflow/plotting/plot_functions_1d.py @@ -21,13 +21,16 @@ remove_residual_axis, apply_variable_settings, apply_process_settings, - apply_density_to_hists, + apply_process_scaling, + apply_density, hists_merge_cutflow_steps, get_position, get_profile_variations, blind_sensitive_bins, join_labels, ) +from columnflow.hist_util import add_missing_shifts + hist = maybe_import("hist") np = maybe_import("numpy") @@ -37,11 +40,12 @@ od = maybe_import("order") -def plot_variable_per_process( +def plot_variable_stack( hists: OrderedDict, config_inst: od.Config, category_inst: od.Category, variable_insts: list[od.Variable], + shift_insts: list[od.Shift] | None, style_config: dict | None = None, density: bool | None = False, shape_norm: bool | None = False, @@ -50,29 +54,62 @@ def plot_variable_per_process( variable_settings: dict | None = None, **kwargs, ) -> plt.Figure: - """ - TODO: misleading function name, it should somehow contain "stack" and not "per_proceess" - """ - remove_residual_axis(hists, "shift") - variable_inst = variable_insts[0] - blinding_threshold = kwargs.get("blinding_threshold", None) + # process-based settings (styles and attributes) + hists, process_style_config = apply_process_settings(hists, process_settings) + # variable-based settings (rebinning, slicing, flow handling) + hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings) + # process scaling + hists = apply_process_scaling(hists) + # remove data in bins where sensitivity exceeds some threshold + blinding_threshold = kwargs.get("blinding_threshold", None) if blinding_threshold: hists = blind_sensitive_bins(hists, config_inst, blinding_threshold) - hists = apply_variable_settings(hists, variable_insts, variable_settings) - hists = apply_process_settings(hists, process_settings) - hists = apply_density_to_hists(hists, density) - - plot_config = prepare_stack_plot_config(hists, shape_norm=shape_norm, **kwargs) + # density scaling per bin + if density: + hists = apply_density(hists, density) + + if len(shift_insts) == 1: + # when there is exactly one shift bin, we can remove the shift axis + hists = remove_residual_axis(hists, "shift", select_value=shift_insts[0].name) + else: + # remove shift axis of histograms that are not to be stacked + unstacked_hists = { + proc_inst: h + for proc_inst, h in hists.items() + if proc_inst.is_mc and getattr(proc_inst, "unstack", False) + } + hists |= remove_residual_axis(unstacked_hists, "shift", select_value="nominal") + + # prepare the plot config + plot_config = prepare_stack_plot_config( + hists, + shape_norm=shape_norm, + shift_insts=shift_insts, + **kwargs, + ) + # prepare and update the style config default_style_config = prepare_style_config( - config_inst, category_inst, variable_inst, density, shape_norm, yscale, + config_inst, + category_inst, + variable_inst, + density, + shape_norm, + yscale, + ) + style_config = law.util.merge_dicts( + default_style_config, + process_style_config, + variable_style_config[variable_inst], + style_config, + deep=True, ) - style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) + # additional, plot function specific changes if shape_norm: - style_config["ax_cfg"]["ylabel"] = r"$\Delta N/N$" + style_config["ax_cfg"]["ylabel"] = "Normalized entries" return plot_all(plot_config, style_config, **kwargs) @@ -86,18 +123,19 @@ def plot_variable_variants( density: bool | None = False, shape_norm: bool = False, yscale: str | None = None, - hide_errors: bool | None = None, + hide_stat_errors: bool | None = None, variable_settings: dict | None = None, **kwargs, ) -> plt.Figure: """ TODO. """ - remove_residual_axis(hists, "shift") + hists = remove_residual_axis(hists, "shift") variable_inst = variable_insts[0] hists = apply_variable_settings(hists, variable_insts, variable_settings) - hists = apply_density_to_hists(hists, density) + if density: + hists = apply_density(hists, density) plot_config = OrderedDict() @@ -118,14 +156,19 @@ def plot_variable_variants( "norm": hists["Initial"].values(), }, } - if hide_errors: + if hide_stat_errors: for key in ("kwargs", "ratio_kwargs"): if key in plot_cfg: plot_cfg[key]["yerr"] = None # setup style config default_style_config = prepare_style_config( - config_inst, category_inst, variable_inst, density, shape_norm, yscale, + config_inst, + category_inst, + variable_inst, + density, + shape_norm, + yscale, ) # plot-function specific changes default_style_config["rax_cfg"]["ylim"] = (0., 1.1) @@ -133,7 +176,7 @@ def plot_variable_variants( style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) if shape_norm: - style_config["ax_cfg"]["ylabel"] = r"$\Delta N/N$" + style_config["ax_cfg"]["ylabel"] = "Normalized entries" return plot_all(plot_config, style_config, **kwargs) @@ -143,11 +186,12 @@ def plot_shifted_variable( config_inst: od.Config, category_inst: od.Category, variable_insts: list[od.Variable], + shift_insts: list[od.Shift] | None, style_config: dict | None = None, density: bool | None = False, shape_norm: bool = False, yscale: str | None = None, - hide_errors: bool | None = None, + hide_stat_errors: bool | None = None, legend_title: str | None = None, process_settings: dict | None = None, variable_settings: dict | None = None, @@ -157,9 +201,17 @@ def plot_shifted_variable( TODO. """ variable_inst = variable_insts[0] - hists = apply_variable_settings(hists, variable_insts, variable_settings) - hists = apply_process_settings(hists, process_settings) - hists = apply_density_to_hists(hists, density) + + hists, process_style_config = apply_process_settings(hists, process_settings) + hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings) + hists = apply_process_scaling(hists) + if density: + hists = apply_density(hists, density) + + # add missing shifts to all histograms + all_shifts = set.union(*[set(h.axes["shift"]) for h in hists.values()]) + for h in hists.values(): + add_missing_shifts(h, all_shifts, str_axis="shift", nominal_bin="nominal") # create the sum of histograms over all processes h_sum = sum(list(hists.values())[1:], list(hists.values())[0].copy()) @@ -171,12 +223,12 @@ def plot_shifted_variable( "up": "red", "down": "blue", } - for i, shift_id in enumerate(h_sum.axes["shift"]): - shift_inst = config_inst.get_shift(shift_id) + for i, shift_name in enumerate(h_sum.axes["shift"]): + shift_inst = config_inst.get_shift(shift_name) - h = h_sum[{"shift": hist.loc(shift_id)}] + h = h_sum[{"shift": hist.loc(shift_name)}] # assuming `nominal` always has shift id 0 - ratio_norm = h_sum[{"shift": hist.loc(0)}].values() + ratio_norm = h_sum[{"shift": hist.loc("nominal")}].values() diff = sum(h.values()) / sum(ratio_norm) - 1 label = shift_inst.label @@ -196,7 +248,7 @@ def plot_shifted_variable( "color": colors[shift_inst.direction], }, } - if hide_errors: + if hide_stat_errors: for key in ("kwargs", "ratio_kwargs"): if key in plot_cfg: plot_cfg[key]["yerr"] = None @@ -211,16 +263,26 @@ def plot_shifted_variable( yscale = "log" if variable_inst.log_y else "linear" default_style_config = prepare_style_config( - config_inst, category_inst, variable_inst, density, shape_norm, yscale, + config_inst, + category_inst, + variable_inst, + density, + shape_norm, + yscale, ) default_style_config["rax_cfg"]["ylim"] = (0.25, 1.75) default_style_config["rax_cfg"]["ylabel"] = "Ratio" if legend_title: default_style_config["legend_cfg"]["title"] = legend_title - - style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) + style_config = law.util.merge_dicts( + default_style_config, + process_style_config, + variable_style_config[variable_inst], + style_config, + deep=True, + ) if shape_norm: - style_config["ax_cfg"]["ylabel"] = r"$\Delta N/N$" + style_config["ax_cfg"]["ylabel"] = "Normalized entries" return plot_all(plot_config, style_config, **kwargs) @@ -239,10 +301,12 @@ def plot_cutflow( """ TODO. """ - remove_residual_axis(hists, "shift") + hists = remove_residual_axis(hists, "shift") - hists = apply_process_settings(hists, process_settings) - hists = apply_density_to_hists(hists, density) + hists, process_style_config = apply_process_settings(hists, process_settings) + hists = apply_process_scaling(hists) + if density: + hists = apply_density(hists, density) hists = hists_merge_cutflow_steps(hists) # setup plotting config @@ -290,7 +354,7 @@ def plot_cutflow( "com": config_inst.campaign.ecm, }, } - style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) + style_config = law.util.merge_dicts(default_style_config, process_style_config, style_config, deep=True) # ratio plot not used here; set `skip_ratio` to True kwargs["skip_ratio"] = True @@ -310,7 +374,7 @@ def plot_profile( style_config: dict | None = None, density: bool | None = False, yscale: str | None = "", - hide_errors: bool | None = None, + hide_stat_errors: bool | None = None, process_settings: dict | None = None, variable_settings: dict | None = None, skip_base_distribution: bool = False, @@ -340,11 +404,13 @@ def plot_profile( raise Exception("The plot_profile function can only be used for 2-dimensional input histograms.") # remove shift axis from histograms - remove_residual_axis(hists, "shift") + hists = remove_residual_axis(hists, "shift") - hists = apply_variable_settings(hists, variable_insts, variable_settings) - hists = apply_process_settings(hists, process_settings) - hists = apply_density_to_hists(hists, density) + hists, process_style_config = apply_process_settings(hists, process_settings) + hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings) + hists = apply_process_scaling(hists) + if density: + hists = apply_density(hists, density) # process histograms to profiled and reduced histograms profiled_hists, reduced_hists = OrderedDict(), OrderedDict() @@ -388,17 +454,28 @@ def plot_profile( }, } - if hide_errors: + if hide_stat_errors: for key in ("kwargs", "ratio_kwargs"): if key in plot_cfg: plot_cfg[key]["yerr"] = None default_style_config = prepare_style_config( - config_inst, category_inst, variable_insts[0], density=density, yscale=yscale, + config_inst, + category_inst, + variable_insts[0], + density=density, + yscale=yscale, + xtick_rotation=kwargs.get("rotate_xticks", None), ) default_style_config["ax_cfg"]["ylabel"] = f"profiled {variable_insts[1].x_title}" - style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) + style_config = law.util.merge_dicts( + default_style_config, + process_style_config, + variable_style_config[variable_insts[0]], + style_config, + deep=True, + ) # ratio plot not used here; set `skip_ratio` to True kwargs["skip_ratio"] = True @@ -433,7 +510,7 @@ def plot_profile( ) ax1.set( ylim=(ax1_ymin, ax1_ymax), - ylabel=r"$\Delta N/N$", + ylabel="Normalized entries", yscale=base_distribution_yscale, ) diff --git a/columnflow/plotting/plot_functions_2d.py b/columnflow/plotting/plot_functions_2d.py index f2a583b86..c611f13f4 100644 --- a/columnflow/plotting/plot_functions_2d.py +++ b/columnflow/plotting/plot_functions_2d.py @@ -17,7 +17,8 @@ remove_residual_axis, apply_variable_settings, apply_process_settings, - apply_density_to_hists, + apply_process_scaling, + apply_density, get_position, reduce_with, ) @@ -36,6 +37,7 @@ def plot_2d( config_inst: od.Config, category_inst: od.Category, variable_insts: list[od.Variable], + shift_insts: list[od.Shift], style_config: dict | None = None, density: bool | None = False, shape_norm: bool | None = False, @@ -54,13 +56,13 @@ def plot_2d( **kwargs, ) -> plt.Figure: # remove shift axis from histograms - remove_residual_axis(hists, "shift") + hists = remove_residual_axis(hists, "shift") - hists = apply_variable_settings(hists, variable_insts, variable_settings) - - hists = apply_process_settings(hists, process_settings) - - hists = apply_density_to_hists(hists, density) + hists, process_style_config = apply_process_settings(hists, process_settings) + hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings) + hists = apply_process_scaling(hists) + if density: + hists = apply_density(hists, density) # use CMS plotting style plt.style.use(mplhep.style.CMS) @@ -179,7 +181,14 @@ def plot_2d( "text": category_inst.label, }, } - style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) + style_config = law.util.merge_dicts( + default_style_config, + process_style_config, + variable_style_config[variable_insts[0]], + variable_style_config[variable_insts[1]], + style_config, + deep=True, + ) # apply style_config ax.set(**style_config["ax_cfg"]) diff --git a/columnflow/plotting/plot_util.py b/columnflow/plotting/plot_util.py index bb9dccb70..8aa5a0302 100644 --- a/columnflow/plotting/plot_util.py +++ b/columnflow/plotting/plot_util.py @@ -17,8 +17,9 @@ import order as od import scinum as sn -from columnflow.util import maybe_import, try_int, try_complex -from columnflow.types import Iterable, Any, Callable +from columnflow.util import maybe_import, try_int, try_complex, UNSET +from columnflow.hist_util import copy_axis +from columnflow.types import Iterable, Any, Callable, Sequence, Hashable math = maybe_import("math") hist = maybe_import("hist") @@ -68,6 +69,14 @@ def get_cms_label(ax: plt.Axes, llabel: str) -> dict: return cms_label_kwargs +def get_attr_or_aux(proc: od.AuxDataMixin, attr: str, default: Any) -> Any: + if (value := getattr(proc, attr, UNSET)) != UNSET: + return value + if proc.has_aux(attr): + return proc.get_aux(attr) + return default + + def round_dynamic(value: int | float) -> int | float: """ Rounds a *value* at various scales to a subjective, sensible precision. Rounding rules: @@ -93,49 +102,6 @@ def round_dynamic(value: int | float) -> int | float: return int(value) if value >= 1 else value -def inject_label( - label: str, - inject: str | int | float, - *, - placeholder: str | None = None, - before_parentheses: bool = False, -) -> str: - """ - Injects a string *inject* into a *label* at a specific position, determined by different - strategies in the following order: - - - If *placeholder* is defined, *label* should contain a substring ``"__PLACEHOLDER__"`` - which is replaced. - - Otherwise, if *before_parentheses* is set to True, the string is inserted before the last - pair of parentheses. - - Otherwise, the string is appended to the label. - - :param label: The label to inject the string *inject* into. - :param inject: The string to inject. - :param placeholder: The placeholder to replace in the label. - :param before_parentheses: Whether to insert the string before the parentheses in the label. - :return: The updated label. - """ - # replace the placeholder - if placeholder and f"__{placeholder}__" in label: - return label.replace(f"__{placeholder}__", inject) - - # when the label contains trailing parentheses, insert the string before them - if before_parentheses and label.endswith(")"): - in_parentheses = 1 - for i in range(len(label) - 2, -1, -1): - c = label[i] - if c == ")": - in_parentheses += 1 - elif c == "(": - in_parentheses -= 1 - if not in_parentheses: - return f"{label[:i]} {inject} {label[i:]}" - - # otherwise, just append - return f"{label} {inject}" - - def apply_settings( instances: Iterable[od.AuxDataMixin], settings: dict[str, Any] | None, @@ -219,13 +185,15 @@ def hists_merge_cutflow_steps( def apply_process_settings( - hists: dict, + hists: dict[Hashable, hist.Hist], process_settings: dict | None = None, -) -> dict: +) -> tuple[dict[Hashable, hist.Hist], dict[str, Any]]: """ - applies settings from `process_settings` dictionary to the `process_insts`; - the `scale` setting is directly applied to the histograms + applies settings from `process_settings` dictionary to the `process_insts` """ + # store info gathered along application of process settings that can be inserted to the style config + process_style_config = {} + # apply all settings on process insts apply_settings( hists.keys(), @@ -233,6 +201,10 @@ def apply_process_settings( parent_check=(lambda proc, parent_name: proc.has_parent_process(parent_name)), ) + return hists, process_style_config + + +def apply_process_scaling(hists: dict[Hashable, hist.Hist]) -> dict[Hashable, hist.Hist]: # helper to compute the stack integral stack_integral = None @@ -240,18 +212,19 @@ def get_stack_integral() -> float: nonlocal stack_integral if stack_integral is None: stack_integral = sum( - proc_h.sum().value + remove_residual_axis_single(proc_h, "shift", select_value="nominal").sum().value for proc, proc_h in hists.items() - if not hasattr(proc, "unstack") and not proc.is_data + if proc.is_mc and not get_attr_or_aux(proc, "unstack", False) ) return stack_integral for proc_inst, h in hists.items(): # apply "scale" setting directly to the hists - scale_factor = getattr(proc_inst, "scale", None) or proc_inst.x("scale", None) + scale_factor = get_attr_or_aux(proc_inst, "scale", None) if scale_factor == "stack": # compute the scale factor and round - scale_factor = round_dynamic(get_stack_integral() / h.sum().value) + h_no_shift = remove_residual_axis_single(h, "shift", select_value="nominal") + scale_factor = round_dynamic(get_stack_integral() / h_no_shift.sum().value) or 1 if try_int(scale_factor): scale_factor = int(scale_factor) hists[proc_inst] = h * scale_factor @@ -260,39 +233,40 @@ def get_stack_integral() -> float: if scale_factor < 1e5 else re.sub(r"e(\+?)(-?)(0*)", r"e\2", f"{scale_factor:.1e}") ) - proc_inst.label = inject_label( - proc_inst.label, - rf"$\times${scale_factor_str}", - placeholder="SCALE", - before_parentheses=True, - ) + if scale_factor != 1: + proc_inst.label = apply_label_placeholders( + proc_inst.label, + apply="SCALE", + scale=scale_factor_str, + ) - # remove remaining placeholders - proc_inst.label = remove_label_placeholders(proc_inst.label) + # remove remaining scale placeholders + proc_inst.label = remove_label_placeholders(proc_inst.label, drop="SCALE") return hists -def remove_label_placeholders(label: str) -> str: - return re.sub("__[A-Z0-9]+__", "", label) - - def apply_variable_settings( - hists: dict, + hists: dict[Hashable, hist.Hist], variable_insts: list[od.Variable], variable_settings: dict | None = None, -) -> dict: +) -> tuple[dict[Hashable, hist.Hist], dict[od.Variable, dict[str, Any]]]: """ applies settings from *variable_settings* dictionary to the *variable_insts*; the *rebin*, *overflow*, *underflow*, and *slice* settings are directly applied to the histograms """ + # store info gathered along application of variable settings that can be inserted to the style config + variable_style_config = {} + # apply all settings on variable insts apply_settings(variable_insts, variable_settings) # apply certain setting directly to histograms for var_inst in variable_insts: + variable_style_config[var_inst] = {} + # rebinning - rebin_factor = getattr(var_inst, "rebin", None) or var_inst.x("rebin", None) + rebin_factor = get_attr_or_aux(var_inst, "rebin", None) if try_int(rebin_factor): for proc_inst, h in list(hists.items()): rebin_factor = int(rebin_factor) @@ -300,20 +274,15 @@ def apply_variable_settings( hists[proc_inst] = h # overflow and underflow bins - overflow = getattr(var_inst, "overflow", None) - if overflow is None: - overflow = var_inst.x("overflow", False) - underflow = getattr(var_inst, "underflow", None) - if underflow is None: - underflow = var_inst.x("underflow", False) - + overflow = get_attr_or_aux(var_inst, "overflow", False) + underflow = get_attr_or_aux(var_inst, "underflow", False) if overflow or underflow: for proc_inst, h in list(hists.items()): h = use_flow_bins(h, var_inst.name, underflow=underflow, overflow=overflow) hists[proc_inst] = h # slicing - slices = getattr(var_inst, "slice", None) or var_inst.x("slice", None) + slices = get_attr_or_aux(var_inst, "slice", None) if ( slices and isinstance(slices, Iterable) and len(slices) >= 2 and try_complex(slices[0]) and try_complex(slices[1]) @@ -324,7 +293,25 @@ def apply_variable_settings( h = h[{var_inst.name: slice(slice_0, slice_1)}] hists[proc_inst] = h - return hists + # additional x axis transformations + for trafo in law.util.make_list(get_attr_or_aux(var_inst, "x_transformations", None) or []): + # forced representation into equal bins + if trafo in {"equal_distance_with_edges", "equal_distance_with_indices"}: + hists, orig_edges = rebin_equal_width(hists, var_inst.name) + new_edges = list(hists.values())[0].axes[-1].edges + # store edge values as well as ticks if needed + ax_cfg = {"xlim": (new_edges[0], new_edges[-1])} + if trafo == "equal_distance_with_edges": + # optionally round edges + rnd = get_attr_or_aux(var_inst, "x_edge_rounding", (lambda e: e)) + edge_labels = [rnd(e) for e in orig_edges] + ax_cfg |= {"xmajorticks": new_edges, "xmajorticklabels": edge_labels, "xminorticks": []} + variable_style_config[var_inst].setdefault("ax_cfg", {}).update(ax_cfg) + variable_style_config[var_inst].setdefault("rax_cfg", {}).update(ax_cfg) + else: + raise ValueError(f"unknown x transformation '{trafo}'") + + return hists, variable_style_config def use_flow_bins( @@ -379,7 +366,7 @@ def use_flow_bins( return h_out -def apply_density_to_hists(hists: dict, density: bool | None = False) -> dict: +def apply_density(hists: dict, density: bool = True) -> dict: """ Scales number of histogram entries to bin widths. """ @@ -396,22 +383,49 @@ def apply_density_to_hists(hists: dict, density: bool | None = False) -> dict: return hists -def remove_residual_axis(hists: dict, ax_name: str, max_bins: int = 1) -> dict: +def remove_residual_axis_single( + h: hist.Hist, + ax_name: str, + max_bins: int = 1, + select_value: Any = None, +) -> hist.Hist: + # force always returning a copy + h = h.copy() + + # nothing to do if the axis is not present + if ax_name not in h.axes.name: + return h + + # when a selection is given, select the corresponding value + if select_value is not None: + h = h[{ax_name: [hist.loc(select_value)]}] + + # check remaining axis + n_bins = len(h.axes[ax_name]) + if n_bins > max_bins: + raise Exception( + f"axis '{ax_name}' of histogram has {n_bins} bins whereas at most {max_bins} bins are " + f"accepted for removal of residual axis", + ) + + # accumulate remaining axis + return h[{ax_name: sum}] + + +def remove_residual_axis( + hists: dict, + ax_name: str, + max_bins: int = 1, + select_value: Any = None, +) -> dict: """ - removes axis named 'ax_name' if existing and there is only a single bin in the axis; + Removes axis named 'ax_name' if existing and there is only a single bin in the axis; raises Exception otherwise """ - for key, hist in list(hists.items()): - if ax_name in hist.axes.name: - n_bins = len(hist.axes[ax_name]) - if n_bins > max_bins: - raise Exception( - f"{ax_name} axis of histogram for key {key} has {n_bins} values whereas at most " - f"{max_bins} is expected", - ) - hists[key] = hist[{ax_name: sum}] - - return hists + return { + key: remove_residual_axis_single(h, ax_name, max_bins=max_bins, select_value=select_value) + for key, h in hists.items() + } def prepare_style_config( @@ -421,6 +435,7 @@ def prepare_style_config( density: bool | None = False, shape_norm: bool | None = False, yscale: str | None = "", + **kwargs: Any, ) -> dict: """ small helper function that sets up a default style config based on the instances @@ -443,15 +458,16 @@ def prepare_style_config( style_config = { "ax_cfg": { "xlim": xlim, - # TODO: need to make bin width and unit configurable in future "ylabel": variable_inst.get_full_y_title(bin_width=False, unit=False, unit_format=unit_format), "xlabel": variable_inst.get_full_x_title(unit_format=unit_format), "yscale": yscale, "xscale": "log" if variable_inst.log_x else "linear", + "xrotation": variable_inst.x("x_label_rotation", None), }, "rax_cfg": { "ylabel": "Data / MC", "xlabel": variable_inst.get_full_x_title(unit_format=unit_format), + "xrotation": variable_inst.x("x_label_rotation", None), }, "legend_cfg": {}, "annotate_cfg": {"text": cat_label or ""}, @@ -466,9 +482,9 @@ def prepare_style_config( if variable_inst.discrete_x or "int" in axis_type: # remove the "xscale" attribute since it messes up the bin edges style_config["ax_cfg"].pop("xscale") - style_config["ax_cfg"]["minorxticks"] = [] + style_config["ax_cfg"]["xminorticks"] = [] if variable_inst.discrete_y: - style_config["ax_cfg"]["minoryticks"] = [] + style_config["ax_cfg"]["yminorticks"] = [] return style_config @@ -476,7 +492,8 @@ def prepare_style_config( def prepare_stack_plot_config( hists: OrderedDict, shape_norm: bool | None = False, - hide_errors: bool | None = None, + hide_stat_errors: bool | None = None, + shift_insts: Sequence[od.Shift] | None = None, **kwargs, ) -> OrderedDict: """ @@ -486,31 +503,33 @@ def prepare_stack_plot_config( """ # separate histograms into stack, lines and data hists mc_hists, mc_colors, mc_edgecolors, mc_labels = [], [], [], [] - line_hists, line_colors, line_labels, line_hide_errors = [], [], [], [] - data_hists, data_hide_errors = [], [] + mc_syst_hists = [] + line_hists, line_colors, line_labels, line_hide_stat_errors = [], [], [], [] + data_hists, data_hide_stat_errors = [], [] data_label = None + default_shift = shift_insts[0].name if len(shift_insts) == 1 else "nominal" + for process_inst, h in hists.items(): # if given, per-process setting overrides task parameter - proc_hide_errors = hide_errors - if getattr(process_inst, "hide_errors", None) is not None: - proc_hide_errors = process_inst.hide_errors + proc_hide_stat_errors = get_attr_or_aux(process_inst, "hide_stat_errors", hide_stat_errors) if process_inst.is_data: - data_hists.append(h) - data_hide_errors.append(proc_hide_errors) + data_hists.append(remove_residual_axis_single(h, "shift", select_value=default_shift)) + data_hide_stat_errors.append(proc_hide_stat_errors) if data_label is None: data_label = process_inst.label - elif process_inst.is_mc: - if getattr(process_inst, "unstack", False): - line_hists.append(h) - line_colors.append(process_inst.color1) - line_labels.append(process_inst.label) - line_hide_errors.append(proc_hide_errors) - else: - mc_hists.append(h) - mc_colors.append(process_inst.color1) - mc_edgecolors.append(process_inst.color2) - mc_labels.append(process_inst.label) + elif get_attr_or_aux(process_inst, "unstack", False): + line_hists.append(remove_residual_axis_single(h, "shift", select_value=default_shift)) + line_colors.append(process_inst.color1) + line_labels.append(process_inst.label) + line_hide_stat_errors.append(proc_hide_stat_errors) + else: + mc_hists.append(remove_residual_axis_single(h, "shift", select_value=default_shift)) + mc_colors.append(process_inst.color1) + mc_edgecolors.append(process_inst.color2) + mc_labels.append(process_inst.label) + if "shift" in h.axes.name and h.axes["shift"].size > 1: + mc_syst_hists.append(h) h_data, h_mc, h_mc_stack = None, None, None if data_hists: @@ -547,6 +566,7 @@ def prepare_stack_plot_config( "norm": line_norm, "label": line_labels[i], "color": line_colors[i], + "error_type": "variance", }, # "ratio_kwargs": { # "norm": h.values(), @@ -555,21 +575,40 @@ def prepare_stack_plot_config( } # suppress error bars by overriding `yerr` - if line_hide_errors[i]: + if line_hide_stat_errors[i]: for key in ("kwargs", "ratio_kwargs"): if key in plot_cfg: plot_cfg[key]["yerr"] = False - # draw stack error - if h_mc_stack is not None and not hide_errors: + # draw statistical error for stack + if h_mc_stack is not None and not hide_stat_errors: mc_norm = sum(h_mc.values()) if shape_norm else 1 - plot_config["mc_uncert"] = { - "method": "draw_error_bands", + plot_config["mc_stat_unc"] = { + "method": "draw_stat_error_bands", "hist": h_mc, "kwargs": {"norm": mc_norm, "label": "MC stat. unc."}, "ratio_kwargs": {"norm": h_mc.values()}, } + # draw systematic error for stack + if h_mc_stack is not None and mc_syst_hists: + mc_norm = sum(h_mc.values()) if shape_norm else 1 + plot_config["mc_syst_unc"] = { + "method": "draw_syst_error_bands", + "hist": h_mc, + "kwargs": { + "syst_hists": mc_syst_hists, + "shift_insts": shift_insts, + "norm": mc_norm, + "label": "MC syst. unc.", + }, + "ratio_kwargs": { + "syst_hists": mc_syst_hists, + "shift_insts": shift_insts, + "norm": h_mc.values(), + }, + } + # draw data if data_hists: data_norm = sum(h_data.values()) if shape_norm else 1 @@ -579,16 +618,18 @@ def prepare_stack_plot_config( "kwargs": { "norm": data_norm, "label": data_label or "Data", + "error_type": "poisson_unweighted", }, } if h_mc is not None: plot_config["data"]["ratio_kwargs"] = { "norm": h_mc.values() * data_norm / mc_norm, + "error_type": "poisson_unweighted", } # suppress error bars by overriding `yerr` - if any(data_hide_errors): + if any(data_hide_stat_errors): for key in ("kwargs", "ratio_kwargs"): if key in plot_cfg: plot_cfg[key]["yerr"] = False @@ -596,6 +637,55 @@ def prepare_stack_plot_config( return plot_config +def split_ax_kwargs(kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Split the given dictionary into two dictionaries based on the keys that are valid for matplotlib's ``ax.set()`` + function, and all others, potentially accepted by :py:func:`apply_ax_kwargs`. + """ + set_kwargs, other_kwargs = {}, {} + other_keys = { + "xmajorticks", "xminorticks", "xmajorticklabels", "xminorticklabels", "xloc", "xrotation", + "ymajorticks", "yminorticks", "yloc", "yrotation", + } + for key, value in kwargs.items(): + (other_kwargs if key in other_keys else set_kwargs)[key] = value + return set_kwargs, other_kwargs + + +def apply_ax_kwargs(ax: plt.Axes, kwargs: dict[str, Any]) -> None: + """ + Apply the given keyword arguments to the given axis, splitting them into those that are valid for ``ax.set()`` and + those that are not, and applying them separately. + """ + # split + set_kwargs, other_kwargs = split_ax_kwargs(kwargs) + + # apply standard ones + ax.set(**set_kwargs) + + # apply others + if other_kwargs.get("xmajorticks") is not None: + ax.set_xticks(other_kwargs.get("xmajorticks"), minor=False) + if other_kwargs.get("ymajorticks") is not None: + ax.set_yticks(other_kwargs.get("ymajorticks"), minor=False) + if other_kwargs.get("xminorticks") is not None: + ax.set_xticks(other_kwargs.get("xminorticks"), minor=True) + if other_kwargs.get("yminorticks") is not None: + ax.set_yticks(other_kwargs.get("yminorticks"), minor=True) + if other_kwargs.get("xmajorticklabels") is not None: + ax.set_xticklabels(other_kwargs.get("xmajorticklabels"), minor=False) + if other_kwargs.get("xminorticklabels") is not None: + ax.set_xticklabels(other_kwargs.get("xminorticklabels"), minor=True) + if other_kwargs.get("xloc") is not None: + ax.set_xlabel(ax.get_xlabel(), loc=other_kwargs.get("xloc")) + if other_kwargs.get("yloc") is not None: + ax.set_ylabel(ax.get_ylabel(), loc=other_kwargs.get("yloc")) + if other_kwargs.get("xrotation") is not None: + ax.tick_params(axis="x", labelrotation=other_kwargs.get("xrotation")) + if other_kwargs.get("yrotation") is not None: + ax.tick_params(axis="y", labelrotation=other_kwargs.get("yrotation")) + + def get_position(minimum: float, maximum: float, factor: float = 1.4, logscale: bool = False) -> float: """ get a relative position between a min and max value based on the scale """ if logscale: @@ -791,9 +881,9 @@ def blind_sensitive_bins( check_if_signal = lambda proc: any(signal == proc or signal.has_process(proc) for signal in signal_procs) # separate histograms into signals, backgrounds and data hists and calculate sums - signals = {proc: hist for proc, hist in hists.items() if proc.is_mc and check_if_signal(proc)} - data = {proc: hist.copy() for proc, hist in hists.items() if proc.is_data} - backgrounds = {proc: hist for proc, hist in hists.items() if proc.is_mc and proc not in signals} + signals = {proc: h for proc, h in hists.items() if proc.is_mc and check_if_signal(proc)} + data = {proc: h.copy() for proc, h in hists.items() if proc.is_data} + backgrounds = {proc: h for proc, h in hists.items() if proc.is_mc and proc not in signals} # Return hists unchanged in case any of the three dicts is empty. if not signals or not backgrounds or not data: @@ -803,8 +893,9 @@ def blind_sensitive_bins( ) return hists - signals_sum = sum(signals.values()) - backgrounds_sum = sum(backgrounds.values()) + # get nominal signal and background yield sums per bin + signals_sum = sum(remove_residual_axis(signals, "shift", select_value="nominal").values()) + backgrounds_sum = sum(remove_residual_axis(backgrounds, "shift", select_value="nominal").values()) # calculate sensitivity by S / sqrt(S + B) sensitivity = signals_sum.values() / np.sqrt(signals_sum.values() + backgrounds_sum.values()) @@ -816,11 +907,168 @@ def blind_sensitive_bins( mask[first_ind:last_ind] = True # set data points in masked region to zero - for proc, hist in data.items(): - hist.values()[mask] = 0 - hist.variances()[mask] = 0 + for proc, h in data.items(): + h.values()[..., mask] = 0 + h.variances()[..., mask] = 0 # merge all histograms hists = law.util.merge_dicts(signals, backgrounds, data) return hists + + +def rebin_equal_width( + hists: dict[Hashable, hist.Hist], + axis_name: str, +) -> tuple[dict[Hashable, hist.Hist], np.ndarray]: + """ + In a dictionary, rebins an axis named *axis_name* of all histograms to have the same amount of bins but with equal + width. This is achieved by using integer edge values starting at 0. The original edge values are returned as well. + Bin contents are not changed but copied to the rebinned histograms. + + :param hists: Dictionary of histograms to rebin. + :param axis_name: Name of the axis to rebin. + :return: Tuple of the rebinned histograms and the new bin edges. + """ + # get the variable axis from the first histogram + assert hists + for var_index, var_axis in enumerate(list(hists.values())[0].axes): + if var_axis.name == axis_name: + break + else: + raise ValueError(f"axis '{axis_name}' not found in histograms") + assert isinstance(var_axis, (hist.axis.Variable, hist.axis.Regular)) + orig_edges = var_axis.edges + + # prepare arguments for the axis copy + if isinstance(var_axis, hist.axis.Variable): + axis_kwargs = {"edges": list(range(len(orig_edges)))} + else: # hist.axis.Regular + axis_kwargs = {"start": orig_edges[0], "stop": orig_edges[-1]} + + # rebin all histograms + new_hists = type(hists)() + for key, h in hists.items(): + # create a new histogram + new_axes = h.axes[:var_index] + (copy_axis(var_axis, **axis_kwargs),) + h.axes[var_index + 1:] + new_h = hist.Hist(*new_axes, storage=h.storage_type()) + + # copy contents and save + new_h.view()[...] = h.view() + new_hists[key] = new_h + + return new_hists, orig_edges + + +def apply_label_placeholders( + label: str, + apply: str | Sequence[str] | None = None, + skip: str | Sequence[str] | None = None, + **kwargs: Any, +) -> str: + """ + Interprets placeholders in the format "__NAME__" in a label and returns an updated label. + Currently supported placeholders are: + - SHORT: removes everything (and including) the placeholder + - BREAK: inserts a line break + - SCALE: inserts a scale factor, passed as "scale" in kwargs; when "scale_format" is given + as well, the scale factor is formatted accordingly + *apply* and *skip* can be used to de/select certain placeholders. + """ + # handle apply/skip decisions + if apply: + _apply = set(p.upper() for p in law.util.make_list(apply)) + do_apply = lambda p: p in _apply + elif skip: + _skip = set(p.upper() for p in law.util.make_list(skip)) + do_apply = lambda p: p not in _skip + else: + do_apply = lambda p: True + + # shortening + if do_apply("SHORT"): + label = re.sub(r"__SHORT__.*", "", label) + + # lines breaks + if do_apply("BREAK"): + label = label.replace("__BREAK__", "\n") + + # scale factor + if do_apply("SCALE") and "scale" in kwargs: + scale_str = kwargs.get("scale_format", "$\\times${}").format(kwargs["scale"]) + if "__SCALE__" in label: + label = label.replace("__SCALE__", scale_str) + else: + label += scale_str + + return label + + +def remove_label_placeholders( + label: str, + keep: str | Sequence[str] | None = None, + drop: str | Sequence[str] | None = None, +) -> str: + # when placeholders should be kept, determine all existing ones and identify remaining to drop + if keep: + keep = law.util.make_list(keep) + placeholders = re.findall("__([^_]+)__", label) + drop = list(set(placeholders) - set(keep)) + + # drop specific placeholders or all + if drop: + drop = law.util.make_list(drop) + sel = f"({'|'.join(d.upper() for d in drop)})" + else: + sel = "[A-Z0-9]+" + + return re.sub(f"__{sel}__", "", label) + + +def calculate_stat_error( + hist: hist.Hist, + error_type: str, +) -> dict: + """ + Calculate the error to be plotted for the given histogram *hist*. + Supported error types are: + - 'variance': the plotted error is the square root of the variance for each bin + - 'poisson_unweighted': the plotted error is the poisson error for each bin + - 'poisson_weighted': the plotted error is the poisson error for each bin, weighted by the variance + """ + + # determine the error type + if error_type == "variance": + yerr = hist.view().variance ** 0.5 + elif error_type in {"poisson_unweighted", "poisson_weighted"}: + # compute asymmetric poisson confidence interval + from hist.intervals import poisson_interval + + variances = hist.view().variance if error_type == "poisson_weighted" else None + values = hist.view().value + confidence_interval = poisson_interval(values, variances) + + if error_type == "poisson_weighted": + # might happen if some bins are empty, see https://github.com/scikit-hep/hist/blob/5edbc25503f2cb8193cc5ff1eb71e1d8fa877e3e/src/hist/intervals.py#L74 # noqa: E501 + confidence_interval[np.isnan(confidence_interval)] = 0 + elif np.any(np.isnan(confidence_interval)): + raise ValueError("Unweighted Poisson interval calculation returned NaN values, check Hist package") + + # calculate the error + # yerr_lower is the lower error + yerr_lower = values - confidence_interval[0] + # yerr_upper is the upper error + yerr_upper = confidence_interval[1] - values + # yerr is the size of the errorbars to be plotted + yerr = np.array([yerr_lower, yerr_upper]) + + if np.any(yerr < 0): + logger.warning( + "yerr < 0, setting to 0. " + "This should not happen, please check your histogram.", + ) + yerr[yerr < 0] = 0 + else: + raise ValueError(f"unknown error type '{error_type}'") + + return yerr diff --git a/columnflow/production/__init__.py b/columnflow/production/__init__.py index 00190e05b..529191cf3 100644 --- a/columnflow/production/__init__.py +++ b/columnflow/production/__init__.py @@ -8,10 +8,9 @@ import inspect -from columnflow.types import Callable, Sequence +from columnflow.types import Callable from columnflow.util import DerivableMeta from columnflow.columnar_util import TaskArrayFunction -from columnflow.config_util import expand_shift_sources class Producer(TaskArrayFunction): @@ -28,8 +27,6 @@ def producer( bases: tuple = (), mc_only: bool = False, data_only: bool = False, - nominal_only: bool = False, - shifts_only: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ @@ -41,11 +38,6 @@ def producer( :py:class:`order.Dataset` (using the :py:attr:`dataset_inst` attribute) whose ``is_mc`` (``is_data``) attribute is *False*. - When *nominal_only* is *True* or *shifts_only* is set, the producer is skipped and not - considered by other calibrators, selectors and producers in case they are evaluated on a - :py:class:`order.Shift` (using the :py:attr:`global_shift_inst` attribute) whose name does - not match. - All additional *kwargs* are added as class members of the new subclasses. :param func: Function to be wrapped and integrated into new :py:class:`Producer` class. @@ -54,10 +46,6 @@ def producer( Monte Carlo simulation and skipped for real data. :param data_only: Boolean flag indicating that this :py:class:`Producer` should only run on real data and skipped for Monte Carlo simulation. - :param nominal_only: Boolean flag indicating that this :py:class:`Producer` should only run - on the nominal shift and skipped on any other shifts. - :param shifts_only: Shift names that this :py:class:`Producer` should only run on, skipping - all other shifts. :return: New :py:class:`Producer` subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -67,8 +55,6 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, - "nominal_only": nominal_only, - "shifts_only": shifts_only, } # get the module name @@ -82,45 +68,23 @@ def decorator(func: Callable) -> DerivableMeta: def update_cls_dict(cls_name, cls_dict, get_attr): mc_only = get_attr("mc_only") data_only = get_attr("data_only") - nominal_only = get_attr("nominal_only") - shifts_only = get_attr("shifts_only") - - # prepare shifts_only - if shifts_only: - shifts_only_expanded = set(expand_shift_sources(shifts_only)) - if shifts_only_expanded != shifts_only: - shifts_only = shifts_only_expanded - cls_dict["shifts_only"] = shifts_only # optionally add skip function if mc_only and data_only: raise Exception(f"producer {cls_name} received both mc_only and data_only") - if nominal_only and shifts_only: + if (mc_only or data_only) and cls_dict.get("skip_func"): raise Exception( - f"producer {cls_name} received both nominal_only and shifts_only", + f"producer {cls_name} received custom skip_func, but either mc_only or " + "data_only are set", ) - if mc_only or data_only or nominal_only or shifts_only: - if cls_dict.get("skip_func"): - raise Exception( - f"producer {cls_name} received custom skip_func, but either mc_only, " - "data_only, nominal_only or shifts_only are set", - ) if "skip_func" not in cls_dict: - def skip_func(self): + def skip_func(self, **kwargs) -> bool: # check mc_only and data_only - if getattr(self, "dataset_inst", None): - if mc_only and not self.dataset_inst.is_mc: - return True - if data_only and not self.dataset_inst.is_data: - return True - - # check nominal_only and shifts_only - if getattr(self, "global_shift_inst", None): - if nominal_only and not self.global_shift_inst.is_nominal: - return True - if shifts_only and self.global_shift_inst.name not in shifts_only: - return True + if mc_only and not self.dataset_inst.is_mc: + return True + if data_only and not self.dataset_inst.is_data: + return True # in all other cases, do not skip return False diff --git a/columnflow/production/categories.py b/columnflow/production/categories.py index 2d4da610b..862b33273 100644 --- a/columnflow/production/categories.py +++ b/columnflow/production/categories.py @@ -6,8 +6,6 @@ from __future__ import annotations -from collections import defaultdict - import law from columnflow.categorization import Categorizer @@ -25,7 +23,7 @@ @producer( produces={"category_ids"}, # custom function to skip categorizers - skip_category=(lambda self, task, category_inst: False), + skip_category=(lambda self, category_inst: False), ) def category_ids( self: Producer, @@ -63,17 +61,14 @@ def category_ids( @category_ids.init -def category_ids_init(self: Producer) -> None: - if not self.inst_dict.get("task"): - return - +def category_ids_init(self: Producer, **kwargs) -> None: # store a mapping from leaf category to categorizer classes for faster lookup - self.categorizer_map = defaultdict(list) + self.categorizer_map = {} # add all categorizers obtained from leaf category selection expressions to the used columns for cat_inst in self.config_inst.get_leaf_categories(): # check if skipped - if self.skip_category(self.inst_dict["task"], cat_inst): + if self.skip_category(cat_inst): continue # treat all selections as lists of categorizers @@ -99,7 +94,4 @@ def category_ids_init(self: Producer) -> None: self.uses.add(categorizer) self.produces.add(categorizer) - self.categorizer_map[cat_inst].append(categorizer) - - # cast to normal dict to prevent silent failures on KeyError - self.categorizer_map = dict(self.categorizer_map) + self.categorizer_map.setdefault(cat_inst, []).append(categorizer) diff --git a/columnflow/production/cms/btag.py b/columnflow/production/cms/btag.py index 3538483b5..65fd9c290 100644 --- a/columnflow/production/cms/btag.py +++ b/columnflow/production/cms/btag.py @@ -11,8 +11,8 @@ import law from columnflow.production import Producer, producer -from columnflow.util import maybe_import, InsertableDict -from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array +from columnflow.util import maybe_import, load_correction_set +from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array, DotDict from columnflow.types import Any np = maybe_import("numpy") @@ -67,7 +67,7 @@ def new( @producer( - uses={"Jet.{pt,eta,hadronFlavour}"}, + uses={"Jet.{pt,eta,phi,mass,hadronFlavour}"}, # only run on mc mc_only=True, # configurable weight name @@ -80,6 +80,7 @@ def new( def btag_weights( self: Producer, events: ak.Array, + task: law.Task, jet_mask: ak.Array | type(Ellipsis) = Ellipsis, negative_b_score_action: str = "ignore", negative_b_score_log_mode: str = "warning", @@ -241,7 +242,7 @@ def add_weight(syst_name, syst_direction, column_name): # when the requested uncertainty is a known jec shift, obtain the propagated effect and # do not produce additional systematics - shift_inst = self.global_shift_inst + shift_inst = task.global_shift_inst if shift_inst.is_nominal: # nominal weight and those of all method intrinsic uncertainties events = add_weight("central", None, self.weight_name) @@ -275,20 +276,25 @@ def add_weight(syst_name, syst_direction, column_name): return events -@btag_weights.init -def btag_weights_init(self: Producer) -> None: +@btag_weights.post_init +def btag_weights_post_init(self: Producer, task: law.Task, **kwargs) -> None: # depending on the requested shift_inst, there are three cases to handle: # 1. when a JEC uncertainty is requested whose propagation to btag weights is known, the # producer should only produce that specific weight column # 2. when the nominal shift is requested, the central weight and all variations related to the # method-intrinsic shifts are produced # 3. when any other shift is requested, only create the central weight column - self.btag_config: BTagSFConfig = self.get_btag_config() - self.uses.add(f"Jet.{self.btag_config.discriminator}") - shift_inst = getattr(self, "global_shift_inst", None) - if not shift_inst: - return + # NOTE: we currently setup the produced columns only during the post_init. This means + # that the `produces` of this Producer will be empty during task initialization, meaning + # that this Producer would be skipped if one would directly request it on command line + + # gather info + self.btag_config = self.get_btag_config() + shift_inst = task.global_shift_inst + + # use the btag discriminator + self.uses.add(f"Jet.{self.btag_config.discriminator}") # to handle this efficiently in one spot, store jec information self.jec_source = shift_inst.x.jec_source if shift_inst.has_tag("jec") else None @@ -327,27 +333,28 @@ def btag_weights_init(self: Producer) -> None: @btag_weights.requires -def btag_weights_requires(self: Producer, reqs: dict) -> None: +def btag_weights_requires( + self: Producer, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: if "external_files" in reqs: return from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) @btag_weights.setup def btag_weights_setup( self: Producer, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: - bundle = reqs["external_files"] - - # create the btag sf corrector - import correctionlib - correctionlib.highlevel.Correction.__call__ = correctionlib.highlevel.Correction.evaluate - correction_set = correctionlib.CorrectionSet.from_string( - self.get_btag_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) - self.btag_sf_corrector = correction_set[self.btag_config.correction_set] + # load the btag sf corrector + btag_file = self.get_btag_file(reqs["external_files"].files) + self.btag_sf_corrector = load_correction_set(btag_file)[self.btag_config.correction_set] diff --git a/columnflow/production/cms/dy.py b/columnflow/production/cms/dy.py new file mode 100644 index 000000000..52718f801 --- /dev/null +++ b/columnflow/production/cms/dy.py @@ -0,0 +1,446 @@ +# coding: utf-8 + +""" +Column production methods related to Drell-Yan reweighting. +""" + +from __future__ import annotations + +import law + +from dataclasses import dataclass + +from columnflow.production import Producer, producer +from columnflow.util import maybe_import, load_correction_set +from columnflow.columnar_util import set_ak_column + +np = maybe_import("numpy") +ak = maybe_import("awkward") +vector = maybe_import("vector") + +logger = law.logger.get_logger(__name__) + + +@dataclass +class DrellYanConfig: + era: str + order: str + correction: str + unc_correction: str + + def __post_init__(self) -> None: + if ( + not self.era or + not self.order or + not self.correction or + not self.unc_correction + ): + raise ValueError("incomplete dy_weight_config: missing era, order, correction or unc_correction") + + +@producer( + uses={"GenPart.*"}, + produces={"gen_dilepton_{pdgid,pt}", "gen_dilepton_{vis,all}.{pt,eta,phi,mass}"}, +) +def gen_dilepton(self, events: ak.Array, **kwargs) -> ak.Array: + """ + Reconstruct the di-lepton pair from generator level info. This considers only visible final-state particles. + In addition it provides the four-momenta of all leptons (including neutrinos) from the hard process. + """ + # get the absolute pdg id (to account for anti-particles) and status of the particles + pdg_id = abs(events.GenPart.pdgId) + status = events.GenPart.status + + # lepton masks for DY ptll reweighting corrections + # -> https://indico.cern.ch/event/1495537/contributions/6359516/attachments/3014424/5315938/HLepRare_25.02.14.pdf + ele_mu_mask = ( + ((pdg_id == 11) | (pdg_id == 13)) & + (status == 1) & + events.GenPart.hasFlags("fromHardProcess") + ) + # taus need to have status == 2, + tau_mask = ( + (pdg_id == 15) & (status == 2) & events.GenPart.hasFlags("fromHardProcess") + ) + + # lepton masks for recoil corrections + # -> https://indico.cern.ch/event/1495537/contributions/6359516/attachments/3014424/5315938/HLepRare_25.02.14.pdf + lepton_all_mask = ( + # e, mu, taus, neutrinos + ( + (pdg_id >= 11) & + (pdg_id <= 16) & + (status == 1) & + events.GenPart.hasFlags("fromHardProcess") + ) | + # tau decay products + events.GenPart.hasFlags("isDirectHardProcessTauDecayProduct") + ) + lepton_vis_mask = lepton_all_mask & ( + # no e neutrinos, mu neutrinos, or taus neutrinos + (pdg_id != 12) & (pdg_id != 14) | (pdg_id != 16) + ) + + # combine the masks + lepton_mask = ele_mu_mask | tau_mask + lepton_pairs = events.GenPart[lepton_mask] + lepton_pairs_vis = events.GenPart[lepton_vis_mask] + lepton_pairs_all = events.GenPart[lepton_all_mask] + + # some up the four momenta of the leptons + lepton_pair_momenta = lepton_pairs.sum(axis=-1) + lepton_pair_momenta_vis = lepton_pairs_vis.sum(axis=-1) + lepton_pair_momenta_all = lepton_pairs_all.sum(axis=-1) + + # absolute pdg id of one parrticle of the lepton pair + lepton_pair_pdgid = ak.without_parameters(abs(events.GenPart[lepton_mask].pdgId[:, 0])) + + # finally, save generator-level lepton pair variables + events = set_ak_column(events, "gen_dilepton_pdgid", lepton_pair_pdgid) + events = set_ak_column(events, "gen_dilepton_pt", lepton_pair_momenta.pt) + events = set_ak_column(events, "gen_dilepton_vis.pt", lepton_pair_momenta_vis.pt) + events = set_ak_column(events, "gen_dilepton_vis.eta", lepton_pair_momenta_vis.eta) + events = set_ak_column(events, "gen_dilepton_vis.phi", lepton_pair_momenta_vis.phi) + events = set_ak_column(events, "gen_dilepton_vis.mass", lepton_pair_momenta_vis.mass) + events = set_ak_column(events, "gen_dilepton_all.pt", lepton_pair_momenta_all.pt) + events = set_ak_column(events, "gen_dilepton_all.eta", lepton_pair_momenta_all.eta) + events = set_ak_column(events, "gen_dilepton_all.phi", lepton_pair_momenta_all.phi) + events = set_ak_column(events, "gen_dilepton_all.mass", lepton_pair_momenta_all.mass) + + return events + + +@producer( + uses={"gen_dilepton_pt"}, + # weight variations are defined in init + produces={"dy_weight"}, + # only run on mc + mc_only=True, + # function to determine the correction file + get_dy_weight_file=(lambda self, external_files: external_files.dy_weight_sf), + # function to load the config + get_dy_weight_config=(lambda self: self.config_inst.x.dy_weight_config), +) +def dy_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Array: + """ + Creates Drell-Yan weights using the correctionlib. + https://cms-higgs-leprare.docs.cern.ch/htt-common/DY_reweight/#correctionlib-file + + Requires an external file in the config under ``dy_weight_sf``: + + .. code-block:: python + + cfg.x.external_files = DotDict.wrap({ + "dy_weight_sf": "/afs/cern.ch/work/m/mrieger/public/mirrors/external_files/DY_pTll_weights_v2.json.gz", # noqa + }) + + *get_dy_weight_file* can be adapted in a subclass in case it is stored differently in the external files. + + The campaign era and name of the correction set (see link above) should be given as an auxiliary entry in the config: + + .. code-block:: python + + cfg.x.dy_weight_config = DrellYanConfig( + era="2022preEE", + order="NLO", + correction="DY_pTll_reweighting", + unc_correction="DY_pTll_reweighting_N_uncertainty", + ) + + *get_dy_weight_config* can be adapted in a subclass in case it is stored differently in the config. + """ + + # map the input variable names from the corrector to our columns + variable_map = { + "era": self.dy_config.era, + "order": self.dy_config.order, + "ptll": events.gen_dilepton_pt, + } + + # initializing the list of weight variations + weights_list = [("dy_weight", "nom")] + + # appending the respective number of uncertainties to the weight list + for i in range(self.n_unc): + for shift in ("up", "down"): + tmp_tuple = (f"dy_weight{i + 1}_{shift}", f"{shift}{i + 1}") + weights_list.append(tmp_tuple) + + # preparing the input variables for the corrector + for column_name, syst in weights_list: + variable_map_syst = {**variable_map, "syst": syst} + + # evaluating dy weights given a certain era, ptll array and sytematic shift + inputs = [variable_map_syst[inp.name] for inp in self.dy_corrector.inputs] + dy_weight = self.dy_corrector.evaluate(*inputs) + + # save the weights in a new column + events = set_ak_column(events, column_name, dy_weight, value_type=np.float32) + + return events + + +@dy_weights.init +def dy_weights_init(self: Producer) -> None: + # the number of weights in partial run 3 is always 10 + if self.config_inst.campaign.x.year not in {2022, 2023}: + raise NotImplementedError( + f"campaign year {self.config_inst.campaign.x.year} is not yet supported by {self.cls_name}", + ) + self.n_unc = 10 + + # register dynamically produced weight columns + for i in range(self.n_unc): + self.produces.add(f"dy_weight{i + 1}_{{up,down}}") + + +@dy_weights.requires +def dy_weights_requires(self: Producer, task: law.Task, reqs: dict) -> None: + """ + Adds the requirements needed the underlying task to derive the Drell-Yan weights into *reqs*. + """ + if "external_files" in reqs: + return + + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(task) + + +@dy_weights.setup +def dy_weights_setup( + self: Producer, + task: law.Task, + reqs: dict, + inputs: dict, + reader_targets: law.util.InsertableDict, +) -> None: + """ + Loads the Drell-Yan weight calculator from the external files bundle and saves them in the + py:attr:`dy_corrector` attribute for simpler access in the actual callable. The number of uncertainties + is calculated, per era, by another correcter in the external file and is saved in the + py:attr:`dy_unc_corrector` attribute. + """ + bundle = reqs["external_files"] + + # import all correctors from the external file + correction_set = load_correction_set(self.get_dy_weight_file(bundle.files)) + + # check number of fetched correctors + if len(correction_set.keys()) != 2: + raise Exception("Expected exactly two types of Drell-Yan correction") + + # create the weight and uncertainty correctors + self.dy_config: DrellYanConfig = self.get_dy_weight_config() + self.dy_corrector = correction_set[self.dy_config.correction] + self.dy_unc_corrector = correction_set[self.dy_config.unc_correction] + + dy_n_unc = int(self.dy_unc_corrector.evaluate(self.dy_config.order)) + + if dy_n_unc != self.n_unc: + raise ValueError( + f"Expected {self.n_unc} uncertainties, got {dy_n_unc}", + ) + + +@producer( + uses={ + # MET information + # -> only Run 3 (PuppiMET) is supported + "PuppiMET.{pt,phi}", + # Number of jets (as a per-event scalar) + "Jet.{pt,phi,eta,mass}", + # Gen-level boson information (full boson momentum) + # -> gen_dilepton_vis.pt, gen_dilepton_vis.phi, gen_dilepton_all.pt, gen_dilepton_all.phi + gen_dilepton.PRODUCES, + }, + produces={ + "RecoilCorrMET.{pt,phi}", + "RecoilCorrMET.{pt,phi}_{recoilresp,recoilres}_{up,down}", + }, + mc_only=True, + # function to determine the recoil correction file from external files + get_dy_recoil_file=(lambda self, external_files: external_files.dy_recoil_sf), + # function to load the config + get_dy_recoil_config=(lambda self: self.config_inst.x.dy_recoil_config), +) +def recoil_corrected_met(self: Producer, events: ak.Array, **kwargs) -> ak.Array: + """ + Producer for bosonic recoil corrections which are applied to PuppiMET to create a new ``RecoilCorrMET`` collection. + See https://cms-higgs-leprare.docs.cern.ch/htt-common/V_recoil for more info. + + Requires an external file in the config under ``dy_recoil_sf``: + + .. code-block:: python + + cfg.x.external_files = DotDict.wrap({ + "dy_recoil_sf": "/afs/cern.ch/work/m/mrieger/public/mirrors/external_files/Recoil_corrections_v2.json.gz", + }) + + *get_dy_recoil_file* can be adapted in a subclass in case it is stored differently in the external files. + + The campaign era and name of the correction set (see link above) should be given as an auxiliary entry in the config: + + .. code-block:: python + + cfg.x.dy_recoil_config = DrellYanConfig( + era="2022preEE", + order="NLO", + correction="Recoil_correction_Rescaling", + unc_correction="Recoil_correction_Uncertainty", + ) + + *get_dy_recoil_config* can be adapted in a subclass in case it is stored differently in the config. + """ + # steps: + # 1) Build transverse vectors for MET and the generator-level boson (full and visible). + # 2) Compute the recoil vector U = MET + vis - full in the transverse plane. + # 3) Project U along and perpendicular to the full boson direction. + # 4) Apply the nominal recoil correction and reassemble the corrected MET. + # 5) For each systematic variation, apply the uncertainty correction on H components and reconstruct MET. + # Build MET vector (using dummy eta and mass, since only x and y matter) + met = vector.array({ + "pt": events.PuppiMET.pt, + "phi": events.PuppiMET.phi, + "eta": np.zeros_like(events.PuppiMET.pt), + "mass": np.zeros_like(events.PuppiMET.pt), + }) + + # Build full and visible boson vectors from generator-level information + full = vector.array({ + "pt": events.gen_dilepton_all.pt, + "phi": events.gen_dilepton_all.phi, + "eta": np.zeros_like(events.gen_dilepton_all.pt), + "mass": np.zeros_like(events.gen_dilepton_all.pt), + }) + vis = vector.array({ + "pt": events.gen_dilepton_vis.pt, + "phi": events.gen_dilepton_vis.phi, + "eta": np.zeros_like(events.gen_dilepton_vis.pt), + "mass": np.zeros_like(events.gen_dilepton_vis.pt), + }) + + # Compute the recoil vector U = MET + vis - full + u_x = met.x + vis.x - full.x + u_y = met.y + vis.y - full.y + + # Project U onto the full boson direction + full_pt = full.pt + full_unit_x = full.x / full_pt + full_unit_y = full.y / full_pt + upara = u_x * full_unit_x + u_y * full_unit_y + uperp = -u_x * full_unit_y + u_y * full_unit_x + + # Determine jet multiplicity for the event (jet selection as in original) + jet_selection = ( + ((events.Jet.pt > 30) & (np.abs(events.Jet.eta) < 2.5)) | + ((events.Jet.pt > 50) & (np.abs(events.Jet.eta) >= 2.5)) + ) + selected_jets = events.Jet[jet_selection] + njet = np.asarray(ak.num(selected_jets, axis=1), dtype=np.float32) + + # Apply nominal recoil correction on U components + # (see here: https://cms-higgs-leprare.docs.cern.ch/htt-common/V_recoil/#example-snippet) + upara_corr = self.recoil_corrector.evaluate( + self.dy_recoil_config.era, + self.dy_recoil_config.order, + njet, + events.gen_dilepton_all.pt, + "Upara", + upara, + ) + uperp_corr = self.recoil_corrector.evaluate( + self.dy_recoil_config.era, + self.dy_recoil_config.order, + njet, + events.gen_dilepton_all.pt, + "Uperp", + uperp, + ) + + # Reassemble the corrected U vector + ucorr_x = upara_corr * full_unit_x - uperp_corr * full_unit_y + ucorr_y = upara_corr * full_unit_y + uperp_corr * full_unit_x + + # Recompute corrected MET: MET_corr = U_corr - vis + full + met_corr_x = ucorr_x - vis.x + full.x + met_corr_y = ucorr_y - vis.y + full.y + met_corr_pt = np.sqrt(met_corr_x**2 + met_corr_y**2) + met_corr_phi = np.arctan2(met_corr_y, met_corr_x) + + events = set_ak_column(events, "RecoilCorrMET.pt", met_corr_pt, value_type=np.float32) + events = set_ak_column(events, "RecoilCorrMET.phi", met_corr_phi, value_type=np.float32) + + # --- Systematic variations --- + # Derive H from the nominal corrected MET: H = - (MET_corr + vis) + h_x = -met_corr_x - vis.x + h_y = -met_corr_y - vis.y + h_pt = np.sqrt(h_x**2 + h_y**2) + h_phi = np.arctan2(h_y, h_x) + # Project H into the full boson coordinate system + hpara = h_pt * np.cos(h_phi - full.phi) + hperp = h_pt * np.sin(h_phi - full.phi) + + for syst, postfix in [ + ("RespUp", "recoilresp_up"), + ("RespDown", "recoilresp_down"), + ("ResolUp", "recoilres_up"), + ("ResolDown", "recoilres_down"), + ]: + hpara_var = self.recoil_unc_corrector.evaluate( + self.dy_recoil_config.era, + self.dy_recoil_config.order, + njet, + events.gen_dilepton_all.pt, + "Hpara", + hpara, + syst, + ) + hperp_var = self.recoil_unc_corrector.evaluate( + self.dy_recoil_config.era, + self.dy_recoil_config.order, + njet, + events.gen_dilepton_all.pt, + "Hperp", + hperp, + syst, + ) + # Reconstruct the corrected H vector in the full boson frame + hcorr_x = hpara_var * np.cos(full.phi) - hperp_var * np.sin(full.phi) + hcorr_y = hpara_var * np.sin(full.phi) + hperp_var * np.cos(full.phi) + # Reconstruct the MET variation: MET_var = -H_corr - vis + met_var_x = -hcorr_x - vis.x + met_var_y = -hcorr_y - vis.y + met_var_pt = np.sqrt(met_var_x**2 + met_var_y**2) + met_var_phi = np.arctan2(met_var_y, met_var_x) + events = set_ak_column(events, f"RecoilCorrMET.pt_{postfix}", met_var_pt, value_type=np.float32) + events = set_ak_column(events, f"RecoilCorrMET.phi_{postfix}", met_var_phi, value_type=np.float32) + + return events + + +@recoil_corrected_met.requires +def recoil_corrected_met_requires(self: Producer, task: law.Task, reqs: dict) -> None: + # Ensure that external files are bundled. + if "external_files" in reqs: + return + + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(task) + + +@recoil_corrected_met.setup +def recoil_corrected_met_setup( + self: Producer, + task: law.Task, + reqs: dict, + inputs: dict, + reader_targets: law.util.InsertableDict, +) -> None: + # load the correction set + bundle = reqs["external_files"] + correction_set = load_correction_set(self.get_dy_recoil_file(bundle.files)) + + # Retrieve the corrections used for the nominal correction and for uncertainties. + self.dy_recoil_config: DrellYanConfig = self.get_dy_recoil_config() + self.recoil_corrector = correction_set[self.dy_recoil_config.correction] + self.recoil_unc_corrector = correction_set[self.dy_recoil_config.unc_correction] diff --git a/columnflow/production/cms/electron.py b/columnflow/production/cms/electron.py index d8d611b39..e88d115e8 100644 --- a/columnflow/production/cms/electron.py +++ b/columnflow/production/cms/electron.py @@ -8,9 +8,12 @@ from dataclasses import dataclass +import law + from columnflow.production import Producer, producer -from columnflow.util import maybe_import, InsertableDict, load_correction_set +from columnflow.util import maybe_import, load_correction_set, DotDict from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array +from columnflow.types import Any np = maybe_import("numpy") ak = maybe_import("awkward") @@ -30,10 +33,7 @@ def __post_init__(self) -> None: raise ValueError("only one of working_point or hlt_path must be set") @classmethod - def new( - cls, - obj: ElectronSFConfig | tuple[str, str, str], - ) -> ElectronSFConfig: + def new(cls, obj: ElectronSFConfig | tuple[str, str, str]) -> ElectronSFConfig: # purely for backwards compatibility with the old tuple format if isinstance(obj, cls): return obj @@ -143,28 +143,33 @@ def electron_weights_init(self: Producer, **kwargs) -> None: @electron_weights.requires -def electron_weights_requires(self: Producer, reqs: dict) -> None: +def electron_weights_requires( + self: Producer, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: if "external_files" in reqs: return from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) @electron_weights.setup def electron_weights_setup( self: Producer, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: - bundle = reqs["external_files"] + self.electron_config = self.get_electron_config() # load the corrector - correction_set = load_correction_set(self.get_electron_file(bundle.files)) - - self.electron_config: ElectronSFConfig = self.get_electron_config() - self.electron_sf_corrector = correction_set[self.electron_config.correction] + e_file = self.get_electron_file(reqs["external_files"].files) + self.electron_sf_corrector = load_correction_set(e_file)[self.electron_config.correction] # the ValType key accepts different arguments for efficiencies and scale factors if self.electron_config.correction.endswith("Eff"): @@ -187,3 +192,16 @@ def electron_weights_setup( "weight_name": "electron_trigger_weight", }, ) + + +@producer( + uses={"Electron.{pt,phi,eta,deltaEtaSC}"}, + produces={"Electron.superclusterEta"}, +) +def electron_sceta(self, events: ak.Array, **kwargs) -> ak.Array: + """ + Returns the electron super cluster eta. + """ + sc_eta = events.Electron.eta + events.Electron.deltaEtaSC + events = set_ak_column(events, "Electron.superclusterEta", sc_eta, value_type=np.float32) + return events diff --git a/columnflow/production/cms/gen_top_decay.py b/columnflow/production/cms/gen_top_decay.py index 62f9e03c0..8e925aaa0 100644 --- a/columnflow/production/cms/gen_top_decay.py +++ b/columnflow/production/cms/gen_top_decay.py @@ -4,6 +4,8 @@ Producers that determine the generator-level particles related to a top quark decay. """ +from __future__ import annotations + from columnflow.production import Producer, producer from columnflow.util import maybe_import from columnflow.columnar_util import set_ak_column @@ -80,13 +82,9 @@ def gen_top_decay_products(self: Producer, events: ak.Array, **kwargs) -> ak.Arr @gen_top_decay_products.skip -def gen_top_decay_products_skip(self: Producer) -> bool: +def gen_top_decay_products_skip(self: Producer, **kwargs) -> bool: """ Custom skip function that checks whether the dataset is a MC simulation containing top quarks in the first place. """ - # never skip when there is not dataset - if not getattr(self, "dataset_inst", None): - return False - return self.dataset_inst.is_data or not self.dataset_inst.has_tag("has_top") diff --git a/columnflow/production/cms/jet.py b/columnflow/production/cms/jet.py index a99f7709b..b37736df8 100644 --- a/columnflow/production/cms/jet.py +++ b/columnflow/production/cms/jet.py @@ -1,18 +1,174 @@ # coding: utf-8 """ -Jet-related quantities +Jet-related producers. """ from __future__ import annotations +from dataclasses import dataclass + +import law + from columnflow.production import Producer, producer -from columnflow.util import maybe_import -from columnflow.columnar_util import set_ak_column +from columnflow.util import maybe_import, load_correction_set +from columnflow.columnar_util import set_ak_column, layout_ak_array, flat_np_view np = maybe_import("numpy") ak = maybe_import("awkward") -coffea = maybe_import("coffea") + + +@dataclass +class JetIdConfig: + """ + Container object to describe a CMS jet id configuration, consisting of names of correction sets mapped to bit + positions, similar to how the ``jetId`` column is defined in nanoAOD. Example: + + .. code-block:: python + + # configurtion for AK4 puppi jets + # second bit for "tight" id, third bit for "tight + lepton veto" id + JetIdConfig(corrections={ + "AK4PUPPI_Tight": 2, + "AK4PUPPI_TightLeptonVeto": 3, + }) + """ + + corrections: dict[str, int] + + def __post_init__(self) -> None: + # for each correction, check if the bit is set and fits into a uint8 + for cor_name, bit in self.corrections.items(): + if not (1 <= bit <= 8): + raise ValueError(f"jet id bit must be between 1 and 8, got {bit} for {cor_name}") + + +@producer( + # names of used and produced columns are added dynamically in init depending on jet_name + # name of the jet collection + jet_name="Jet", + # function to determine the correction file + get_jet_id_file=(lambda self, external_files: external_files.jet_id), + # function to determine the jet id config + get_jet_id_config=(lambda self: self.config_inst.x.jet_id), +) +def jet_id(self: Producer, events: ak.Array, **kwargs) -> ak.Array: + """ + Recomputes the jet id flag. Requires an external file in the config under ``jet_id``. Example: + + .. code-block:: python + + cfg.x.external_files = DotDict.wrap({ + "jet_id": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-120c4271/POG/JME/2022_Summer22/jetid.json.gz", # noqa + }) + + *get_jet_id_file* can be adapted in a subclass in case it is stored differently in the external files. + + The pairs of correction set names and jet id bits (!) should be configured using the :py:class:`JetIdConfig` as an + auxiliary entry in the config: + + .. code-block:: python + + from columnflow.production.cms.jet import JetIdConfig + cfg.x.jet_id = JetIdConfig( + corrections={ + "AK4PUPPI_Tight": 2, + "AK4PUPPI_TightLeptonVeto": 3, + }, + ) + + *get_jet_id_config* can be adapted in a subclass in case it is stored differently in the config. + + Resources: + + - https://twiki.cern.ch/twiki/bin/view/CMS/JetID13p6TeV?rev=22#nanoAOD_Flags + - https://cms-talk.web.cern.ch/t/bug-in-the-jetid-flag-in-nanov12-and-nanov13/108135 + - https://gitlab.cern.ch/cms-nanoAOD/jsonpog-integration/-/blob/120c4271917f30d67fb64c789eb91f7b52be4845/examples/jetidExample.py + """ # noqa + # compute flat inputs + variable_map = { + col: flat_np_view(events[self.jet_name][col], axis=1) + for col in self.jet_columns + } + # sum of multiplicies might be required + if "multiplicity" not in variable_map and "chMultiplicity" in variable_map and "neMultiplicity" in variable_map: + variable_map["multiplicity"] = variable_map["chMultiplicity"] + variable_map["neMultiplicity"] + + # identify jets for which the evaluation can succeed + valid_mask = self.get_valid_mask(variable_map) + variable_map = {col: value[valid_mask] for col, value in variable_map.items()} + + # prepare the flat jet id array into which evaluated values will be inserted + jet_id_flat = np.zeros(len(valid_mask), dtype=np.uint8) + + # iterate over all correctors + for cor_name, pass_bit in self.cfg.corrections.items(): + inputs = [variable_map[inp.name] for inp in self.jet_id_correctors[cor_name].inputs] + id_flag = self.jet_id_correctors[cor_name].evaluate(*inputs).astype(np.uint8) + # the flag is either 0 or 1, so shift the bit to the correct position + jet_id_flat[valid_mask] |= id_flag << (pass_bit - 1) + + # apply correct layout + jet_id = layout_ak_array(jet_id_flat, events[self.jet_name].eta) + + # store them + events = set_ak_column(events, f"{self.jet_name}.jetId", jet_id, value_type=np.uint8) + + return events + + +@jet_id.init +def jet_id_init(self: Producer, **kwargs) -> None: + """ + Dynamically add the names of the used and produced columns depending on the jet name. + """ + self.jet_columns = ["eta", "chHEF", "neHEF", "chEmEF", "neEmEF", "muEF", "chMultiplicity", "neMultiplicity"] + self.uses.update(f"{self.jet_name}.{col}" for col in self.jet_columns) + self.produces.add(f"{self.jet_name}.jetId") + + +@jet_id.requires +def jet_id_requires(self: Producer, task: law.Task, reqs: dict, **kwargs) -> None: + """ + Adds the requirements needed the underlying task to recompute the jet id into *reqs*. + """ + if "external_files" in reqs: + return + + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(task) + + +@jet_id.setup +def jet_id_setup( + self: Producer, + task: law.Task, + reqs: dict, + inputs: dict, + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: + """ + Sets up the correction sets needed for the jet id using the external files. + """ + bundle = reqs["external_files"] + + # get the jet id config + self.cfg: JetIdConfig = self.get_jet_id_config() + + # create the correctors + correction_set = load_correction_set(self.get_jet_id_file(bundle.files)) + self.jet_id_correctors = {cor_name: correction_set[cor_name] for cor_name in self.cfg.corrections} + + # store a lambda to identify good jets (a value of zero will be stored for others) + self.get_valid_mask = lambda variable_map: variable_map["chMultiplicity"] >= 0 + + +# derive with defaults for fatjets +fatjet_id = jet_id.derive("fatjet_id", cls_dict={ + "jet_name": "FatJet", + "get_jet_id_config": (lambda self: self.config_inst.x.fatjet_id), +}) @producer( @@ -21,11 +177,7 @@ subjet_name="SubJet", output_column="msoftdrop", ) -def msoftdrop( - self: Producer, - events: ak.Array, - **kwargs, -) -> ak.Array: +def msoftdrop(self: Producer, events: ak.Array, **kwargs) -> ak.Array: """ Recalculates the softdrop mass for a given jet collection by computing the four-vector sum of the corresponding subjets. @@ -46,7 +198,11 @@ def msoftdrop( valid_subjet_idxs = ak.mask(subjet_idx, subjet_idx >= 0) # pad list of subjets to prevent index error on lookup - padded_subjet = ak.pad_none(subjet, ak.max(valid_subjet_idxs) + 1) + max_valid_subjets = ak.max(valid_subjet_idxs) + padded_subjet = ak.pad_none( + subjet, + 0 if max_valid_subjets is None else (max_valid_subjets + 1), + ) # retrieve subjets for each jet valid_subjet = padded_subjet[valid_subjet_idxs] @@ -59,7 +215,7 @@ def msoftdrop( valid_subjets = ak.with_name( valid_subjets, "PtEtaPhiMLorentzVector", - behavior=coffea.nanoevents.NanoAODSchema.behavior(), + behavior=self.nano_behavior, ) # recompute softdrop mass from LV sum @@ -85,7 +241,7 @@ def msoftdrop( @msoftdrop.init def msoftdrop_init(self: Producer, **kwargs) -> None: """ - Dynamically add `uses` and `produces` + Dynamically add `uses` and `produces`. """ # input columns self.uses |= { @@ -102,3 +258,10 @@ def msoftdrop_init(self: Producer, **kwargs) -> None: # outputs self.produces = {f"{self.jet_name}.{self.output_column}"} + + +@msoftdrop.setup +def msoftdrop_setup(self: Producer, task: law.Task, reqs: dict, **kwargs) -> None: + import coffea + + self.nano_behavior = coffea.nanoevents.NanoAODSchema.behavior() diff --git a/columnflow/production/cms/muon.py b/columnflow/production/cms/muon.py index 612e61df6..071b3122f 100644 --- a/columnflow/production/cms/muon.py +++ b/columnflow/production/cms/muon.py @@ -6,11 +6,14 @@ from __future__ import annotations +import law + from dataclasses import dataclass from columnflow.production import Producer, producer -from columnflow.util import maybe_import, InsertableDict, load_correction_set +from columnflow.util import maybe_import, load_correction_set, DotDict from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array +from columnflow.types import Any np = maybe_import("numpy") ak = maybe_import("awkward") @@ -22,10 +25,7 @@ class MuonSFConfig: campaign: str = "" @classmethod - def new( - cls, - obj: MuonSFConfig | tuple[str, str], - ) -> MuonSFConfig: + def new(cls, obj: MuonSFConfig | tuple[str, str]) -> MuonSFConfig: # purely for backwards compatibility with the old tuple format if isinstance(obj, cls): return obj @@ -129,27 +129,34 @@ def muon_weights_init(self: Producer, **kwargs) -> None: @muon_weights.requires -def muon_weights_requires(self: Producer, reqs: dict) -> None: +def muon_weights_requires( + self: Producer, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: if "external_files" in reqs: return from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) @muon_weights.setup def muon_weights_setup( self: Producer, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: bundle = reqs["external_files"] # load the corrector correction_set = load_correction_set(self.get_muon_file(bundle.files)) - self.muon_config: MuonSFConfig = self.get_muon_config() + self.muon_config = self.get_muon_config() self.muon_sf_corrector = correction_set[self.muon_config.correction] # check versions diff --git a/columnflow/production/cms/parton_shower.py b/columnflow/production/cms/parton_shower.py new file mode 100644 index 000000000..6ad5e6699 --- /dev/null +++ b/columnflow/production/cms/parton_shower.py @@ -0,0 +1,91 @@ +# coding: utf-8 + +""" +Producers for storing parton shower weights. +""" + +from __future__ import annotations + +from columnflow.production import producer, Producer +from columnflow.columnar_util import set_ak_column, full_like +from columnflow.util import maybe_import, DotDict + +ak = maybe_import("awkward") +np = maybe_import("numpy") + + +@producer( + uses={"PSWeight"}, + produces={"{isr,fsr}_weight{,_up,_down}"}, + # only run on mc + mc_only=True, + # indices where to find weight variations in PSWeight + indices=DotDict( + isr_weight_up=0, + fsr_weight_up=1, + isr_weight_down=2, + fsr_weight_down=3, + ), +) +def ps_weights( + self: Producer, + events: ak.Array, + invalid_weights_action: str = "raise", + **kwargs, +) -> ak.Array: + """ + Producer that reads out parton shower uncertainties on an event-by-event basis. + + The *invalid_weights_action* defines the procedure of how to handle events with missing or an unexpected number of + weights. Supported modes are: + + - ``"raise"``: An exception is raised. + - ``"ignore_one"``: Ignores cases where only a single weight is present and a weight of one is stored for all + variations. + - ``"ignore"``: Stores a weight of one for all missing weight variataions. + + Resources: + - https://cms-nanoaod-integration.web.cern.ch/integration/master/mc94X_doc.html + """ + known_actions = {"raise", "ignore_one", "ignore"} + if invalid_weights_action not in known_actions: + raise ValueError( + f"unknown invalid_weights_action '{invalid_weights_action}', known values are {','.join(known_actions)}", + ) + + # setup nominal weights + ones = np.ones(len(events), dtype=np.float32) + events = set_ak_column(events, "fsr_weight", ones) + events = set_ak_column(events, "isr_weight", ones) + + # check if weight variations are missing and if needed, pad them + indices = self.indices + ps_weights = events.PSWeight + num_weights = ak.num(ps_weights, axis=1) + max_index = max(indices.values()) + if ak.any(bad_mask := num_weights <= max_index): + msg = "" + if invalid_weights_action == "ignore": + # pad weights + ps_weights = ak.fill_none(ak.pad_none(ps_weights, max_index + 1, axis=1), 1.0, axis=1) + elif invalid_weights_action == "ignore_one": + # special treatment if there is only one weight + if ak.all(num_weights == 1): + ps_weights = full_like(ps_weights, 1.0) + indices = {column: 0 for column in indices} + else: + msg = f"at least {max_index + 1} or exactly one" + else: # raise + msg = f"at least {max_index + 1}" + if msg: + bad_values = ",".join(map(str, set(num_weights[bad_mask]))) + raise ValueError( + f"the number of PSWeight values is expected to be {msg}, but also found numbers of '{bad_values}' in " + f"{ak.mean(bad_mask) * 100:.1f}% of events in dataset {self.dataset_inst.name}", + ) + + # now loop through the names and save the respective normalized PSWeights + for column, index in indices.items(): + events = set_ak_column(events, column, ps_weights[:, index]) + + return events diff --git a/columnflow/production/cms/pdf.py b/columnflow/production/cms/pdf.py index 5de474d04..28009c1c3 100644 --- a/columnflow/production/cms/pdf.py +++ b/columnflow/production/cms/pdf.py @@ -200,7 +200,7 @@ def pdf_weights( @pdf_weights.init -def pdf_weight_init(self: Producer) -> None: +def pdf_weight_init(self: Producer, **kwargs) -> None: # add produced columns: nominal+all, or nominal+up+down self.produces.add("pdf_weight{,s}" if self.store_all_weights else "pdf_weight{,_up,_down}") diff --git a/columnflow/production/cms/pileup.py b/columnflow/production/cms/pileup.py index d8f6c2274..5262ba803 100644 --- a/columnflow/production/cms/pileup.py +++ b/columnflow/production/cms/pileup.py @@ -9,8 +9,9 @@ import law from columnflow.production import Producer, producer -from columnflow.util import maybe_import, InsertableDict +from columnflow.util import maybe_import, DotDict from columnflow.columnar_util import set_ak_column +from columnflow.types import Any np = maybe_import("numpy") ak = maybe_import("awkward") @@ -56,7 +57,12 @@ def pu_weight(self: Producer, events: ak.Array, **kwargs) -> ak.Array: @pu_weight.requires -def pu_weight_requires(self: Producer, reqs: dict) -> None: +def pu_weight_requires( + self: Producer, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: """ Adds the requirements needed the underlying task to derive the pileup weights into *reqs*. """ @@ -64,15 +70,17 @@ def pu_weight_requires(self: Producer, reqs: dict) -> None: return from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) @pu_weight.setup def pu_weight_setup( self: Producer, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: """ Loads the pileup calculator from the external files bundle and saves them in the @@ -121,7 +129,12 @@ def pu_weights_from_columnflow(self: Producer, events: ak.Array, **kwargs) -> ak @pu_weights_from_columnflow.requires -def pu_weights_from_columnflow_requires(self: Producer, reqs: dict) -> None: +def pu_weights_from_columnflow_requires( + self: Producer, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: """ Adds the requirements needed the underlying task to derive the pileup weights into *reqs*. """ @@ -129,15 +142,17 @@ def pu_weights_from_columnflow_requires(self: Producer, reqs: dict) -> None: return from columnflow.tasks.cms.external import CreatePileupWeights - reqs["pu_weights"] = CreatePileupWeights.req(self.task) + reqs["pu_weights"] = CreatePileupWeights.req(task) @pu_weights_from_columnflow.setup def pu_weights_from_columnflow_setup( self: Producer, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: """ Loads the pileup weights added through the requirements and saves them in the diff --git a/columnflow/production/cms/scale.py b/columnflow/production/cms/scale.py index 443359cef..caa683566 100644 --- a/columnflow/production/cms/scale.py +++ b/columnflow/production/cms/scale.py @@ -9,9 +9,9 @@ import law from columnflow.production import Producer -from columnflow.util import maybe_import, InsertableDict +from columnflow.util import maybe_import, DotDict from columnflow.columnar_util import set_ak_column -from columnflow.columnar_util import DotDict +from columnflow.types import Any np = maybe_import("numpy") ak = maybe_import("awkward") @@ -28,7 +28,14 @@ class _ScaleWeightBase(Producer): Common base class for the scale weight producers below that join a setup function. """ - def setup_func(self, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: + def setup_func( + self, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, + ) -> None: # named weight indices self.indices_9 = DotDict( mur_down_muf_down=0, @@ -224,12 +231,19 @@ def murmuf_envelope_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Ar @murmuf_envelope_weights.setup def murmuf_envelope_weights_setup( self: Producer, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: - # call the super func - super(murmuf_envelope_weights, self).setup_func(reqs, inputs, reader_targets) + super(murmuf_envelope_weights, self).setup_func( + task=task, + reqs=reqs, + inputs=inputs, + reader_targets=reader_targets, + **kwargs, + ) # create a flat list if indices, skipping those for crossed variations self.envelope_indices_9 = [ diff --git a/columnflow/production/cms/seeds.py b/columnflow/production/cms/seeds.py index 75c684a56..09c84d8a9 100644 --- a/columnflow/production/cms/seeds.py +++ b/columnflow/production/cms/seeds.py @@ -12,8 +12,9 @@ import law from columnflow.production import Producer, producer -from columnflow.util import maybe_import, primes, InsertableDict +from columnflow.util import maybe_import, primes, DotDict from columnflow.columnar_util import Route, set_ak_column, optional_column as optional +from columnflow.types import Any np = maybe_import("numpy") ak = maybe_import("awkward") @@ -137,7 +138,7 @@ def deterministic_event_seeds(self, events: ak.Array, **kwargs) -> ak.Array: @deterministic_event_seeds.init -def deterministic_event_seeds_init(self) -> None: +def deterministic_event_seeds_init(self, **kwargs) -> None: """ Producer initialization that adds columns to the set of *used* columns based on the *event_columns*, *object_count_columns*, and *object_columns* lists. @@ -150,9 +151,11 @@ def deterministic_event_seeds_init(self) -> None: @deterministic_event_seeds.setup def deterministic_event_seeds_setup( self, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: """ Setup function that defines conventions methods needed during the producer function. @@ -227,15 +230,17 @@ def call_func(self, events: ak.Array, **kwargs) -> ak.Array: return events - def init_func(self) -> None: + def init_func(self, **kwargs) -> None: self.uses |= {f"{self.object_field}.pt"} self.produces |= {f"{self.object_field}.deterministic_seed"} def setup_func( self, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: """Setup before entering the event chunk loop. diff --git a/columnflow/production/cms/supercluster_eta.py b/columnflow/production/cms/supercluster_eta.py deleted file mode 100644 index b1285b724..000000000 --- a/columnflow/production/cms/supercluster_eta.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Module to calculate Photon super cluster eta. -Source: https://twiki.cern.ch/twiki/bin/view/CMS/EgammaNanoAOD#How_to_get_photon_supercluster_e -""" - -import law -import functools - -from columnflow.production import producer -from columnflow.util import maybe_import -from columnflow.columnar_util import set_ak_column - -np = maybe_import("numpy") -ak = maybe_import("awkward") - -logger = law.logger.get_logger(__name__) - -set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) - - -@producer( - uses={"Electron.{pt,phi,eta,deltaEtaSC}"}, - produces={"Electron.superclusterEta"}, -) -def electron_sceta(self, events: ak.Array, **kwargs) -> ak.Array: - """ - Returns the electron super cluster eta. - """ - - events = set_ak_column_f32( - events, "Electron.superclusterEta", - events.Electron.eta + events.Electron.deltaEtaSC, - ) - return events diff --git a/columnflow/production/cms/top_pt_weight.py b/columnflow/production/cms/top_pt_weight.py index 46a30b1f2..bb1fb4c4e 100644 --- a/columnflow/production/cms/top_pt_weight.py +++ b/columnflow/production/cms/top_pt_weight.py @@ -4,6 +4,10 @@ Column producers related to top quark pt reweighting. """ +from __future__ import annotations + +from dataclasses import dataclass + import law from columnflow.production import Producer, producer @@ -12,17 +16,31 @@ ak = maybe_import("awkward") np = maybe_import("numpy") -coffea = maybe_import("coffea") -maybe_import("coffea.nanoevents.methods.nanoaod") + logger = law.logger.get_logger(__name__) +@dataclass +class TopPtWeightConfig: + params: dict[str, float] + pt_max: float = 500.0 + + @classmethod + def new(cls, obj: TopPtWeightConfig | dict[str, float]) -> TopPtWeightConfig: + # backward compatibility only + if isinstance(obj, cls): + return obj + return cls(params=obj) + + @producer( uses={"GenPart.{pdgId,statusFlags}"}, # requested GenPartonTop columns, passed to the *uses* and *produces* produced_top_columns={"pt"}, mc_only=True, + # skip the producer unless the datasets has this specified tag (no skip check performed when none) + require_dataset_tag="has_top", ) def gen_parton_top(self: Producer, events: ak.Array, **kwargs) -> ak.Array: """ @@ -49,44 +67,55 @@ def gen_parton_top(self: Producer, events: ak.Array, **kwargs) -> ak.Array: @gen_parton_top.init -def gen_parton_top_init(self: Producer) -> bool: +def gen_parton_top_init(self: Producer, **kwargs) -> bool: for col in self.produced_top_columns: self.uses.add(f"GenPart.{col}") self.produces.add(f"GenPartonTop.{col}") @gen_parton_top.skip -def gen_parton_top_skip(self: Producer) -> bool: +def gen_parton_top_skip(self: Producer, **kwargs) -> bool: """ - Custom skip function that checks whether the dataset is a MC simulation containing top - quarks in the first place. + Custom skip function that checks whether the dataset is a MC simulation containing top quarks in the first place + using the :py:attr:`require_dataset_tag` attribute. """ - # never skip when there is not dataset - if not getattr(self, "dataset_inst", None): + # never skip if the tag is not set + if self.require_dataset_tag is None: return False - return self.dataset_inst.is_data or not self.dataset_inst.has_tag("has_top") + return self.dataset_inst.is_data or not self.dataset_inst.has_tag(self.require_dataset_tag) + + +def get_top_pt_weight_config(self: Producer) -> TopPtWeightConfig: + if self.config_inst.has_aux("top_pt_reweighting_params"): + logger.info_once( + "deprecated_top_pt_weight_config", + "the config aux field 'top_pt_reweighting_params' is deprecated and will be removed in " + "a future release, please use 'top_pt_weight' instead", + ) + params = self.config_inst.x.top_pt_reweighting_params + else: + params = self.config_inst.x.top_pt_weight + + return TopPtWeightConfig.new(params) @producer( - uses={ - "GenPartonTop.pt", - }, - produces={ - "top_pt_weight", "top_pt_weight_up", "top_pt_weight_down", - }, - get_top_pt_config=(lambda self: self.config_inst.x.top_pt_reweighting_params), + uses={"GenPartonTop.pt"}, + produces={"top_pt_weight{,_up,_down}"}, + get_top_pt_weight_config=get_top_pt_weight_config, + # skip the producer unless the datasets has this specified tag (no skip check performed when none) + require_dataset_tag="is_ttbar", ) def top_pt_weight(self: Producer, events: ak.Array, **kwargs) -> ak.Array: """ Compute SF to be used for top pt reweighting. - The *GenPartonTop.pt* column can be produced with the :py:class:`gen_parton_top` Producer. - - The SF should *only be applied in ttbar MC* as an event weight and is computed - based on the gen-level top quark transverse momenta. + See https://twiki.cern.ch/twiki/bin/view/CMS/TopPtReweighting?rev=31 for more information. - The function is skipped when the dataset is data or when it does not have the tag *is_ttbar*. + The *GenPartonTop.pt* column can be produced with the :py:class:`gen_parton_top` Producer. The + SF should *only be applied in ttbar MC* as an event weight and is computed based on the + gen-level top quark transverse momenta. The top pt reweighting parameters should be given as an auxiliary entry in the config: @@ -105,19 +134,21 @@ def top_pt_weight(self: Producer, events: ak.Array, **kwargs) -> ak.Array: :param events: awkward array containing events to process """ - - # get SF function parameters from config - params = self.get_top_pt_config() - # check the number of gen tops - if ak.any(ak.num(events.GenPartonTop, axis=1) != 2): - logger.warning("There are events with != 2 GenPartonTops. This producer should only run for ttbar") + if ak.any((n_tops := ak.num(events.GenPartonTop, axis=1)) != 2): + raise Exception( + f"{self.cls_name} can only run on events with two generator top quarks, but found " + f"counts of {','.join(map(str, sorted(set(n_tops))))}", + ) + + # clamp top pt + top_pt = events.GenPartonTop.pt + if self.cfg.pt_max >= 0.0: + top_pt = ak.where(top_pt > self.cfg.pt_max, self.cfg.pt_max, top_pt) - # clamp top pT < 500 GeV - pt_clamped = ak.where(events.GenPartonTop.pt > 500.0, 500.0, events.GenPartonTop.pt) for variation in ("", "_up", "_down"): # evaluate SF function - sf = np.exp(params[f"a{variation}"] + params[f"b{variation}"] * pt_clamped) + sf = np.exp(self.cfg.params[f"a{variation}"] + self.cfg.params[f"b{variation}"] * top_pt) # compute weight from SF product for top and anti-top weight = np.sqrt(np.prod(sf, axis=1)) @@ -128,13 +159,18 @@ def top_pt_weight(self: Producer, events: ak.Array, **kwargs) -> ak.Array: return events +@top_pt_weight.init +def top_pt_weight_init(self: Producer) -> None: + # store the top pt weight config + self.cfg = self.get_top_pt_weight_config() + + @top_pt_weight.skip -def top_pt_weight_skip(self: Producer) -> bool: +def top_pt_weight_skip(self: Producer, **kwargs) -> bool: """ - Skip if running on anything except ttbar MC simulation. + Skip if running on anything except ttbar MC simulation, evaluated via the :py:attr:`require_dataset_tag` attribute. """ - # never skip when there is no dataset - if not getattr(self, "dataset_inst", None): - return False + if self.require_dataset_tag is None: + return self.dataset_inst.is_data return self.dataset_inst.is_data or not self.dataset_inst.has_tag("is_ttbar") diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 9c2dd296f..56e4c0c82 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -13,8 +13,9 @@ import scinum as sn from columnflow.production import Producer, producer -from columnflow.util import maybe_import, InsertableDict +from columnflow.util import maybe_import, DotDict from columnflow.columnar_util import set_ak_column +from columnflow.types import Any np = maybe_import("numpy") sp = maybe_import("scipy") @@ -27,8 +28,8 @@ def get_inclusive_dataset(self: Producer) -> od.Dataset: """ - Helper function to obtain the inclusive dataset from a list of datasets that are required to - stitch this *dataset_inst*. + Helper function to obtain the inclusive dataset from a list of datasets that are required to stitch this + *dataset_inst*. """ process_map = {d.processes.get_first(): d for d in self.stitching_datasets} @@ -69,9 +70,8 @@ def get_br_from_inclusive_dataset( stats: dict, ) -> dict[int, float]: """ - Helper function to compute the branching ratios from the inclusive sample. - This is done with ratios of event weights isolated per dataset and thus independent of the - overall mc weight normalization. + Helper function to compute the branching ratios from the inclusive sample. This is done with ratios of event weights + isolated per dataset and thus independent of the overall mc weight normalization. """ # define helper variables and mapping between process ids and dataset names proc_ds_map = { @@ -152,9 +152,8 @@ def multiply_branching_ratios(proc_id: int, proc_br: sn.Number) -> None: rel_unc = proc_br(sn.UP, unc=True, factor=True) if rel_unc > 0.05: logger.warning( - "large error on the branching ratio for process " - f"{inclusive_proc.get_process(proc_id).name} with process id {proc_id} " - f"({rel_unc * 100:.2f}%)", + f"large error on the branching ratio for process {inclusive_proc.get_process(proc_id).name} with " + f"process id {proc_id} ({rel_unc * 100:.2f}%)", ) # just store the nominal value @@ -171,6 +170,8 @@ def multiply_branching_ratios(proc_id: int, proc_br: sn.Number) -> None: uses={"process_id", "mc_weight"}, # name of the output column weight_name="normalization_weight", + # which luminosity to apply, uses the value stored in the config when None + luminosity=None, # whether to allow stitching datasets allow_stitching=False, get_xsecs_from_inclusive_dataset=False, @@ -182,19 +183,24 @@ def multiply_branching_ratios(proc_id: int, proc_br: sn.Number) -> None: ) def normalization_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Array: """ - Uses luminosity information of internal py:attr:`config_inst`, the cross section of a process - obtained through :py:class:`category_ids` and the sum of event weights from the - py:attr:`selection_stats` attribute to assign each event a normalization weight. - The normalization weight is stored in a new column named after the py:attr:`weight_name` - attribute. When py:attr`allow_stitching` is set to True, the sum of event weights is computed - for all datasets with a leaf process contained in the leaf processes of the - py:attr:`dataset_inst`. For stitching, the process_id needs to be reconstructed for each leaf - process on a per event basis. Moreover, when stitching is enabled, an additional normalization - weight is computed for the inclusive dataset only and stored in a column named - `_inclusive_only`. This weight resembles the normalization weight for the - inclusive dataset, as if it were unstitched and should therefore only be applied, when using the - inclusive dataset as a standalone dataset. - + Uses luminosity information of internal py:attr:`config_inst`, the cross section of a process obtained through + :py:class:`category_ids` and the sum of event weights from the py:attr:`selection_stats` attribute to assign each + event a normalization weight. The normalization weight is stored in a new column named after the + py:attr:`weight_name` attribute. + + The computation of all weights requires that the selection statistics ("stats" output of :py:class:`SelectEvents`) + contains a field ``"sum_mc_weight_per_process"`` which itself is a dictionary mapping process ids to the sum of + event weights for that process. + + *luminosity* is used to scale the yield of the simulation. When *None*, the ``luminosity`` auxiliary field of the + config is used. + + When py:attr`allow_stitching` is set to True, the sum of event weights is computed for all datasets with a leaf + process contained in the leaf processes of the py:attr:`dataset_inst`. For stitching, the process_id needs to be + reconstructed for each leaf process on a per event basis. Moreover, when stitching is enabled, an additional + normalization weight is computed for the inclusive dataset only and stored in a column named + `_inclusive_only`. This weight resembles the normalization weight for the inclusive dataset, as if it + were unstitched and should therefore only be applied, when using the inclusive dataset as a standalone dataset. """ # read the process id column process_id = np.asarray(events.process_id) @@ -204,8 +210,8 @@ def normalization_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Arra invalid_ids = unique_process_ids - self.xs_process_ids if invalid_ids: raise Exception( - f"process_id field contains id(s) {invalid_ids} for which no cross sections were " - f"found; process ids with cross sections: {self.xs_process_ids}", + f"process_id field contains id(s) {invalid_ids} for which no cross sections were found; process ids with " + f"cross sections: {self.xs_process_ids}", ) # read the weight per process (defined as lumi * xsec / sum_weights) from the lookup table @@ -228,24 +234,26 @@ def normalization_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Arra @normalization_weights.requires -def normalization_weights_requires(self: Producer, reqs: dict) -> None: +def normalization_weights_requires( + self: Producer, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: """ - Adds the requirements needed by the underlying py:attr:`task` to access selection stats into - *reqs*. + Adds the requirements needed by the underlying py:attr:`task` to access selection stats into *reqs*. """ # check that all datasets are known for dataset in self.stitching_datasets: if not self.config_inst.has_dataset(dataset): - raise Exception( - f"unknown dataset '{dataset}' required for normalization weights computation", - ) + raise Exception(f"unknown dataset '{dataset}' required for normalization weights computation") from columnflow.tasks.selection import MergeSelectionStats reqs["selection_stats"] = { dataset.name: MergeSelectionStats.req_different_branching( - self.task, + task, dataset=dataset.name, - branch=-1 if self.task.is_workflow() else 0, + branch=-1 if task.is_workflow() else 0, ) for dataset in self.stitching_datasets } @@ -256,21 +264,22 @@ def normalization_weights_requires(self: Producer, reqs: dict) -> None: @normalization_weights.setup def normalization_weights_setup( self: Producer, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: """ - Sets up objects required by the computation of normalization weights and stores them as instance - attributes: + Sets up objects required by the computation of normalization weights and stores them as instance attributes: - - py: attr: `process_weight_table`: A sparse array serving as a lookup table for the - calculated process weights. This weight is defined as the product of the luminosity, the - cross section, divided by the sum of event weights per process. + - py: attr: `process_weight_table`: A sparse array serving as a lookup table for the calculated process weights. + This weight is defined as the product of the luminosity, the cross section, divided by the sum of event + weights per process. """ # load the selection stats selection_stats = { - dataset: self.task.cached_value( + dataset: task.cached_value( key=f"selection_stats_{dataset}", func=lambda: inp["stats"].load(formatter="json"), ) @@ -299,22 +308,23 @@ def normalization_weights_setup( unknown_process_ids = allowed_ids - {p.id for p in process_insts} if unknown_process_ids: raise Exception( - f"selection stats contain ids of processes that were not previously registered to the " - f"config '{self.config_inst.name}': {', '.join(map(str, unknown_process_ids))}", + f"selection stats contain ids of processes that were not previously registered to the config " + f"'{self.config_inst.name}': {', '.join(map(str, unknown_process_ids))}", ) # likewise, drop processes that were not seen during selection process_insts = {p for p in process_insts if p.id in allowed_ids} max_id = max(process_inst.id for process_inst in process_insts) - # get the luminosity from the config - lumi = self.config_inst.x.luminosity.nominal + # get the luminosity + lumi = self.config_inst.x.luminosity if self.luminosity is None else self.luminosity + lumi = lumi.nominal if isinstance(lumi, sn.Number) else float(lumi) # create a event weight lookup table process_weight_table = sp.sparse.lil_matrix((1, max_id + 1), dtype=np.float32) if self.allow_stitching and self.get_xsecs_from_inclusive_dataset: inclusive_dataset = self.inclusive_dataset - logger.info(f"using inclusive dataset {inclusive_dataset.name} for cross section lookup") + logger.debug(f"using inclusive dataset {inclusive_dataset.name} for cross section lookup") # extract branching ratios from inclusive dataset(s) inclusive_proc = inclusive_dataset.processes.get_first() @@ -327,8 +337,7 @@ def normalization_weights_setup( ) if not branching_ratios: raise Exception( - "no branching ratios could be computed based on the inclusive dataset " - f"{inclusive_dataset}", + f"no branching ratios could be computed based on the inclusive dataset {inclusive_dataset}", ) # compute the weight the inclusive dataset would have on its own without stitching @@ -348,8 +357,8 @@ def normalization_weights_setup( for process_inst in process_insts: if self.config_inst.campaign.ecm not in process_inst.xsecs.keys(): raise KeyError( - f"no cross section registered for process {process_inst} for center-of-mass " - f"energy of {self.config_inst.campaign.ecm}", + f"no cross section registered for process {process_inst} for center-of-mass energy of " + f"{self.config_inst.campaign.ecm}", ) sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(process_inst.id)] xsec = process_inst.get_xsec(self.config_inst.campaign.ecm).nominal @@ -360,13 +369,10 @@ def normalization_weights_setup( @normalization_weights.init -def normalization_weights_init(self: Producer) -> None: +def normalization_weights_init(self: Producer, **kwargs) -> None: """ Initializes the normalization weights producer by setting up the normalization weight column. """ - if getattr(self, "dataset_inst", None) is None: - return - self.produces.add(self.weight_name) if self.allow_stitching: self.stitching_datasets = self.get_stitching_datasets() diff --git a/columnflow/reduction/__init__.py b/columnflow/reduction/__init__.py new file mode 100644 index 000000000..c35975c2e --- /dev/null +++ b/columnflow/reduction/__init__.py @@ -0,0 +1,105 @@ +# coding: utf-8 + +""" +Event and collection reduction objects. +""" + +from __future__ import annotations + +import inspect + +from columnflow.types import Callable +from columnflow.util import DerivableMeta +from columnflow.columnar_util import TaskArrayFunction + + +class Reducer(TaskArrayFunction): + """ + Base class for all reducers. + """ + + exposed = True + + @classmethod + def reducer( + cls, + func: Callable | None = None, + bases: tuple = (), + mc_only: bool = False, + data_only: bool = False, + **kwargs, + ) -> DerivableMeta | Callable: + """ + Decorator for creating a new :py:class:`~.Reducer` subclass with additional, optional *bases* and attaching the + decorated function to it as ``call_func``. + + When *mc_only* (*data_only*) is *True*, the reducer is skipped and not considered by other task array functions + in case they are evalauted on a :py:class:`order.Dataset` (using the :py:attr:`dataset_inst` attribute) whose + ``is_mc`` (``is_data``) attribute is *False*. + + All additional *kwargs* are added as class members of the new subclasses. + + :param func: Function to be wrapped and integrated into new :py:class:`Reducer` class. + :param bases: Additional bases for the new reducer. + :param mc_only: Boolean flag indicating that this reducer should only run on Monte Carlo simulation and skipped + for real data. + :param data_only: Boolean flag indicating that this reducer should only run on real data and skipped for Monte + Carlo simulation. + :return: New reducer subclass. + """ + def decorator(func: Callable) -> DerivableMeta: + # create the class dict + cls_dict = { + **kwargs, + "call_func": func, + "mc_only": mc_only, + "data_only": data_only, + } + + # get the module name + frame = inspect.stack()[1] + module = inspect.getmodule(frame[0]) + + # get the reducer name + cls_name = cls_dict.pop("cls_name", func.__name__) + + # hook to update the class dict during class derivation + def update_cls_dict(cls_name, cls_dict, get_attr): + mc_only = get_attr("mc_only") + data_only = get_attr("data_only") + + # optionally add skip function + if mc_only and data_only: + raise Exception(f"reducer {cls_name} received both mc_only and data_only") + if (mc_only or data_only) and cls_dict.get("skip_func"): + raise Exception( + f"reducer {cls_name} received custom skip_func, but either mc_only or data_only are set", + ) + + if "skip_func" not in cls_dict: + def skip_func(self, **kwargs) -> bool: + # check mc_only and data_only + if mc_only and not self.dataset_inst.is_mc: + return True + if data_only and not self.dataset_inst.is_data: + return True + + # in all other cases, do not skip + return False + + cls_dict["skip_func"] = skip_func + + return cls_dict + + cls_dict["update_cls_dict"] = update_cls_dict + + # create the subclass + subclass = cls.derive(cls_name, bases=bases, cls_dict=cls_dict, module=module) + + return subclass + + return decorator(func) if func else decorator + + +# shorthand +reducer = Reducer.reducer diff --git a/columnflow/reduction/default.py b/columnflow/reduction/default.py new file mode 100644 index 000000000..21c9da51c --- /dev/null +++ b/columnflow/reduction/default.py @@ -0,0 +1,92 @@ +# coding: utf-8 + +""" +Reducer definition for achieving columnsflow's default reduction behavior in three steps: + - remove unwanted events (using "event" mask of selection results) + - create new collections (using "objects" mapping of selection results) + - only keep certain columns after the reduction +""" + +from __future__ import annotations + +from collections import defaultdict + +import law + +from columnflow.reduction import Reducer, reducer +from columnflow.reduction.util import create_event_mask, create_collections_from_masks +from columnflow.util import maybe_import + +ak = maybe_import("awkward") + + +@reducer() +def cf_default_keep_columns(self: Reducer, events: ak.Array, selection: ak.Array, **kwargs) -> ak.Array: + """ + Reducer that does nothing but to define the columns to keep after the reduction in a backwards-compatible way using + the "keep_columns" auxiliary config field as was the default in previous columnflow versions. + """ + return events + + +@cf_default_keep_columns.post_init +def cf_default_keep_columns_post_init(self: Reducer, task: law.Task, **kwargs) -> None: + for c in self.config_inst.x.keep_columns.get(task.task_family, ["*"]): + self.produces.update(task._expand_keep_column(c)) + + +@reducer( + # disable the check for used columns + check_used_columns=False, + # whether to add cf_default_keep_columns as a dependency to achieve backwards compatibility + add_keep_columns=True, +) +def cf_default(self: Reducer, events: ak.Array, selection: ak.Array, task: law.Task, **kwargs) -> ak.Array: + # build the event mask + event_mask = create_event_mask(selection, task.selector_steps) + + # apply it + events = events[event_mask] + + # add collections + if "objects" in selection.fields: + events = create_collections_from_masks(events, selection.objects[event_mask]) + + return events + + +@cf_default.init +def cf_default_init(self: Reducer, **kwargs) -> None: + if self.add_keep_columns: + self.uses.add(cf_default_keep_columns.PRODUCES) + self.produces.add(cf_default_keep_columns.PRODUCES) + + +@cf_default.post_init +def cf_default_post_init(self: Reducer, task: law.Task, **kwargs) -> None: + # the updates to used columns are only necessary if the task invokes the reducer + if not task.invokes_reducer: + return + + # add used columns pointing to the selection steps + # (all starting with "steps." for ReduceEvents to decide to load them from selection result data) + for step in task.selector_steps: + self.uses.add(f"steps.{step}") + + # based on the columns to write, determine which collections need to be read to produce new collections + # (masks must start with "objects." for ReduceEvents to decide to load them from selection result data) + output_collection_fields = defaultdict(set) + for route in self.produced_columns: + if len(route) > 1: + output_collection_fields[route[0]].add(route) + + # iterate through collections and update used colums + for src_col, dst_cols in task.collection_map.items(): + for dst_col in dst_cols: + # skip if the collection does not need to be loaded at all + if not law.util.multi_match(dst_col, output_collection_fields.keys()): + continue + # read the object mask + self.uses.add(f"objects.{src_col}.{dst_col}") + # make sure that the corresponding columns of the source collection are loaded + self.uses.update(src_col + route[1:] for route in output_collection_fields[dst_col]) diff --git a/columnflow/selection/util.py b/columnflow/reduction/util.py similarity index 55% rename from columnflow/selection/util.py rename to columnflow/reduction/util.py index ea96831a1..e9c4ab826 100644 --- a/columnflow/selection/util.py +++ b/columnflow/reduction/util.py @@ -1,32 +1,55 @@ # coding: utf-8 """ -Helpful utilities often used in selections. +Helpful reduction utilities. """ from __future__ import annotations __all__ = [] +import functools + import law from columnflow.util import maybe_import -from columnflow.columnar_util import set_ak_column, sorted_indices_from_mask as _sorted_indices_from_mask +from columnflow.columnar_util import set_ak_column ak = maybe_import("awkward") logger = law.logger.get_logger(__name__) +full_slice = slice(None, None) + -def sorted_indices_from_mask(*args, **kwargs) -> ak.Array: - # deprecated - logger.warning_once( - "sorted_indices_from_mask_deprecated", - "columnflow.selection.util.sorted_indices_from_mask() is deprecated and will be removed in " - "April 2025; use columnflow.columnar_util.sorted_indices_from_mask() instead", - ) - return _sorted_indices_from_mask(*args, **kwargs) +def create_event_mask(selection: ak.Array, requested_steps: tuple[str]) -> ak.Array | slice: + """ + Creates and returns an event mask based on a *selection* results array and a tuple of *requested_steps* according to + the following checks (in that order): + + - When not empty, *requested_steps* are considered fields of the *selection.steps* array and subsequently + concatenated with a logical AND operation. + - Otherwise, if the *event* field is present in the *selection* array, it is used instead. + - Otherwise, a empty slice object is returned. + """ + # build the event mask from requested steps + if requested_steps: + # check if all steps are present + missing_steps = set(requested_steps) - set(selection.steps.fields) + if missing_steps: + raise Exception(f"selector steps {','.join(missing_steps)} requested but missing in {selection.steps}") + return functools.reduce( + (lambda a, b: a & b), + (selection["steps", step] for step in requested_steps), + ) + + # use the event field if present + if "event" in selection.fields: + return selection.event + + # fallback to an empty slice + return full_slice def create_collections_from_masks( diff --git a/columnflow/selection/__init__.py b/columnflow/selection/__init__.py index 214b17729..ba17a5a7d 100644 --- a/columnflow/selection/__init__.py +++ b/columnflow/selection/__init__.py @@ -12,10 +12,9 @@ import law import order as od -from columnflow.types import Callable, Sequence, T +from columnflow.types import Callable, T from columnflow.util import maybe_import, DotDict, DerivableMeta from columnflow.columnar_util import TaskArrayFunction -from columnflow.config_util import expand_shift_sources ak = maybe_import("awkward") @@ -42,8 +41,6 @@ def selector( bases=(), mc_only: bool = False, data_only: bool = False, - nominal_only: bool = False, - shifts_only: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ @@ -55,11 +52,6 @@ def selector( :py:class:`order.Dataset` (using the :py:attr:`dataset_inst` attribute) whose ``is_mc`` (``is_data``) attribute is *False*. - When *nominal_only* is *True* or *shifts_only* is set, the selector is skipped and not - considered by other calibrators, selectors and producers in case they are evaluated on a - :py:class:`order.Shift` (using the :py:attr:`global_shift_inst` attribute) whose name does - not match. - All additional *kwargs* are added as class members of the new subclasses. :param func: Function to be wrapped and integrated into new :py:class:`Selector` class. @@ -68,10 +60,6 @@ def selector( Monte Carlo simulation and skipped for real data. :param data_only: Boolean flag indicating that this :py:class:`Selector` should only run on real data and skipped for Monte Carlo simulation. - :param nominal_only: Boolean flag indicating that this :py:class:`Selector` should only run - on the nominal shift and skipped on any other shifts. - :param shifts_only: Shift names that this :py:class:`Selector` should only run on, - skipping all other shifts. :return: New :py:class:`Selector` subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -81,8 +69,6 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, - "nominal_only": nominal_only, - "shifts_only": shifts_only, } # get the module name @@ -96,45 +82,23 @@ def decorator(func: Callable) -> DerivableMeta: def update_cls_dict(cls_name, cls_dict, get_attr): mc_only = get_attr("mc_only") data_only = get_attr("data_only") - nominal_only = get_attr("nominal_only") - shifts_only = get_attr("shifts_only") - - # prepare shifts_only - if shifts_only: - shifts_only_expanded = set(expand_shift_sources(shifts_only)) - if shifts_only_expanded != shifts_only: - shifts_only = shifts_only_expanded - cls_dict["shifts_only"] = shifts_only # optionally add skip function if mc_only and data_only: raise Exception(f"selector {cls_name} received both mc_only and data_only") - if nominal_only and shifts_only: + if (mc_only or data_only) and cls_dict.get("skip_func"): raise Exception( - f"selector {cls_name} received both nominal_only and shifts_only", + f"selector {cls_name} received custom skip_func, but either mc_only or " + "data_only are set", ) - if mc_only or data_only or nominal_only or shifts_only: - if cls_dict.get("skip_func"): - raise Exception( - f"selector {cls_name} received custom skip_func, but either mc_only, " - "data_only, nominal_only or shifts_only are set", - ) if "skip_func" not in cls_dict: - def skip_func(self): + def skip_func(self, **kwargs) -> bool: # check mc_only and data_only - if getattr(self, "dataset_inst", None): - if mc_only and not self.dataset_inst.is_mc: - return True - if data_only and not self.dataset_inst.is_data: - return True - - # check nominal_only and shifts_only - if getattr(self, "global_shift_inst", None): - if nominal_only and not self.global_shift_inst.is_nominal: - return True - if shifts_only and self.global_shift_inst.name not in shifts_only: - return True + if mc_only and not self.dataset_inst.is_mc: + return True + if data_only and not self.dataset_inst.is_data: + return True # in all other cases, do not skip return False diff --git a/columnflow/selection/cms/jets.py b/columnflow/selection/cms/jets.py index 945ed1b1e..20c008778 100644 --- a/columnflow/selection/cms/jets.py +++ b/columnflow/selection/cms/jets.py @@ -10,8 +10,9 @@ import math from columnflow.selection import Selector, SelectionResult, selector -from columnflow.util import maybe_import, InsertableDict +from columnflow.util import maybe_import, load_correction_set, DotDict from columnflow.columnar_util import set_ak_column, flat_np_view, optional_column as optional +from columnflow.types import Any np = maybe_import("numpy") ak = maybe_import("awkward") @@ -137,29 +138,31 @@ def jet_veto_map( @jet_veto_map.requires -def jet_veto_map_requires(self: Selector, reqs: dict) -> None: +def jet_veto_map_requires( + self: Selector, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: if "external_files" in reqs: return from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) @jet_veto_map.setup def jet_veto_map_setup( self: Selector, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: - bundle = reqs["external_files"] - # create the corrector - import correctionlib - correctionlib.highlevel.Correction.__call__ = correctionlib.highlevel.Correction.evaluate - correction_set = correctionlib.CorrectionSet.from_string( - self.get_veto_map_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) + map_file = self.get_veto_map_file(reqs["external_files"].files) + correction_set = load_correction_set(map_file) keys = list(correction_set.keys()) if len(keys) != 1: raise ValueError(f"Expected exactly one correction in the file, got {len(keys)}") diff --git a/columnflow/selection/cms/json_filter.py b/columnflow/selection/cms/json_filter.py index 0806ce2a7..2b750a563 100644 --- a/columnflow/selection/cms/json_filter.py +++ b/columnflow/selection/cms/json_filter.py @@ -6,8 +6,11 @@ from __future__ import annotations +import law + from columnflow.selection import Selector, selector, SelectionResult -from columnflow.util import maybe_import, InsertableDict, DotDict +from columnflow.util import maybe_import, DotDict +from columnflow.types import Any ak = maybe_import("awkward") np = maybe_import("numpy") @@ -91,20 +94,27 @@ def json_filter( @json_filter.requires -def json_filter_requires(self: Selector, reqs: dict) -> None: +def json_filter_requires( + self: Selector, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: if "external_files" in reqs: return from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) + reqs["external_files"] = BundleExternalFiles.req(task) @json_filter.setup def json_filter_setup( self: Selector, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: """ Setup function for :py:class:`json_filter`. Load golden JSON and set up run/luminosity block diff --git a/columnflow/selection/cms/met_filters.py b/columnflow/selection/cms/met_filters.py index 5a2ac54a8..3033b8641 100644 --- a/columnflow/selection/cms/met_filters.py +++ b/columnflow/selection/cms/met_filters.py @@ -87,12 +87,9 @@ def met_filters( @met_filters.init -def met_filters_init(self: Selector) -> None: +def met_filters_init(self: Selector, **kwargs) -> None: met_filters = self.get_met_filters() if isinstance(met_filters, dict): - # do nothing when no dataset_inst is known - if not getattr(self, "dataset_inst", None): - return met_filters = met_filters[self.dataset_inst.data_source] # store filters as an attribute for faster lookup diff --git a/columnflow/selection/empty.py b/columnflow/selection/empty.py index 2fed7be13..0be227402 100644 --- a/columnflow/selection/empty.py +++ b/columnflow/selection/empty.py @@ -91,7 +91,7 @@ def empty( @empty.init -def empty_init(self: Selector) -> None: +def empty_init(self: Selector, **kwargs) -> None: """ Initializes the selector by finding the id of the inclusive category if no hard-coded category ids are given on class-level. diff --git a/columnflow/selection/matching.py b/columnflow/selection/matching.py deleted file mode 100644 index 7debfd5cc..000000000 --- a/columnflow/selection/matching.py +++ /dev/null @@ -1,169 +0,0 @@ -# coding: utf-8 - -""" -Distance-based methods. -""" - -from __future__ import annotations - -from columnflow.types import Callable, Union -from columnflow.selection import Selector, SelectionResult, selector -from columnflow.util import maybe_import - -np = maybe_import("numpy") -ak = maybe_import("awkward") - - -def cleaning_factory( - selector_name: str, - to_clean: str, - clean_against: list[str], - metric: Union[Callable, None] = None, -) -> Selector: - """ - Factory to generate a function with name *selector_name* that cleans the field *to_clean* in an - array following the :external+coffea:py:class:`~coffea.nanoevents.NanoAODSchema` against the - field(s) *clean_against*. First, the necessary column names to construct four-momenta for the - different object fields are constructed, i.e. ``pt``, ``eta``, ``phi`` and ``e`` for the - different objects. Finally, the actual selector function is generated, which uses these columns. - - :param selector_name: Name of the :py:class:`~columnflow.selection.Selector` class to be - initialized. - :param to_clean: Name of the field to be cleaned (e.g. ``"Jet"``). - :param clean_against: Names of the fields of object to clean field *to_clean* against - (e.g. ``["Muon"]``). - :param metric: Function to use for the cleaning. If None, use - :external+coffea:py:meth:`~coffea.nanoevents.methods.vector.LorentzVector.delta_r`. - :return: Instance of :py:class:`~columnflow.selection.Selector`. - """ - # default of the metric function is the delta_r function - # of the coffea LorentzVectors - if metric is None: - metric = lambda a, b: a.delta_r(b) - - # compile the list of variables that are necessary for the four momenta - # this list is always the same - variables_for_lorentzvec = ["pt", "eta", "phi", "e"] - - # sum up all fields aht are to be considered, i.e. the field *to_clean* - # and all fields in *clean_against* - all_fields = clean_against + [to_clean] - - # construct the set of columns that is necessary for the four momenta in - # the different fields (and thus also for the current implementation of - # the cleaning itself) by looping through the fields and variables. - - uses = { - f"{x}.{var}" for x in all_fields for var in variables_for_lorentzvec - } - - # additionally, also load the lengths of the different fields - uses |= {f"n{x}" for x in all_fields} - - # finally, construct selector function itself - @selector(uses=uses, name=selector_name) - def func( - self: Selector, - events: ak.Array, - to_clean: str, - clean_against: list[str], - metric: Union[Callable, None] = metric, - threshold: float = 0.4, - ) -> ak.Array: - """ - Abstract function to perform a cleaning of field *to_clean* against a (list of) field(s) - *clean_against* based on an abitrary metric *metric* (e.g. - :external+coffea:py:meth:`~coffea.nanoevents.methods.vector.LorentzVector.delta_r`). First - concatenate all fields in *clean_against*, which thus includes all fields that are to be - used for the comparison of the metric. Then construct the metric for all permutations of the - different objects using the :external+coffea:doc:`index` - :external+coffea:py:meth:`~coffea.nanoevents.methods.vector.LorentzVector.nearest` - implementation. All objects in field *to_clean* are removed if the metric is below the - *threshold*. - - :param self: :py:class:`columnflow.selection.Selector` instance into which this function is - embedded. - :param events: array containing events in the NanoAOD format - param to_clean: Name of the field to be cleaned (e.g. ``"Jet"``) - :param clean_against: Names of the fields of object to clean field *to_clean* against (e.g. - ``["Muon"]``) - :param metric: Function to use for the cleaning. If None, the - :external+coffea:py:meth:`~coffea.nanoevents.methods.vector.LorentzVector.delta_r`, - defaults to None. - :param threshold: Threshold value for decision which objects to keep and which to reject, - defaults to ``0.4``. - :return: array of indices of cleaned objects, ordered according to the ``pt`` of the - objects. - """ - # concatenate the fields that are to be used in the construction - # of the metric table - summed_clean_against = ak.concatenate( - [events[x] for x in clean_against], - axis=1, - ) - - # load actual NanoEventArray that is to be cleaned - to_clean_field = events[to_clean] - - # construct metric table for these objects. The metric table contains the minimal value of - # the metric *metric* for each object in field *to_clean* w.r.t. all objects in - # *summed_clean_against*. Thus, it has the dimensions nevents x nto_clean, where *nevents* - # is the number of events in the current chunk of data and *nto_clean* is the length of the - # field *to_clean*. Note that the argument *threshold* in the *nearest* function must be set - # to None since the function will perform a selection itself to extract the nearest objects - # (i.e. applies the selection we want here in reverse) - _, metric = to_clean_field.nearest( - summed_clean_against, - metric=metric, - return_metric=True, - threshold=None, - ) - # build a binary mask based on the selection threshold provided by the - # user - mask = metric > threshold - - # construct final result. Currently, this is the list of indices for - # clean jets, sorted for pt - # WARNING: this still contains the bug with the application of the mask - # which will be adressed in a PR in the very near future - # TODO: return the mask itself instead of the list of indices - sorted_list = ak.argsort(to_clean_field.pt, axis=-1, ascending=False)[mask] - return sorted_list - - return func - - -delta_r_jet_lepton = cleaning_factory( - selector_name="delta_r_jet_lepton", - to_clean="Jet", - clean_against=["Muon", "Electron"], - metric=lambda a, b: a.delta_r(b), -) - - -@selector(uses={delta_r_jet_lepton}) -def jet_lepton_delta_r_cleaning( - self: Selector, - events: ak.Array, - stats: dict[str, Union[int, float]], - threshold: float = 0.4, - **kwargs, -) -> tuple[ak.Array, SelectionResult]: - """ - Function to apply the selection requirements necessary for a cleaning of jets against leptons. - - The function calls the requirements to clean the field *Jet* against the concatination of the - fields *[Muon, Electron]*, i.e. all leptons and passes the desired threshold for the selection - - :param events: Array containing events in the NanoAOD format - :param stats: :py:class:`dictionary ` containing selection stats (not used here). - :param threshold: Threshold value for decision which objects to keep and which to reject. - - :return: Tuple containing the events array and a - :py:class:`~columnflow.selection.SelectionResult` with indices of cleaned jets in the - "Jet" object field. - """ - clean_jet_indices = self[delta_r_jet_lepton](events, "Jet", ["Muon", "Electron"], threshold=threshold) - - # TODO: should not return a new object collection but an array with masks - return events, SelectionResult(objects={"Jet": clean_jet_indices}) diff --git a/columnflow/selection/stats.py b/columnflow/selection/stats.py index 5038a6a03..fdd84a5e1 100644 --- a/columnflow/selection/stats.py +++ b/columnflow/selection/stats.py @@ -12,9 +12,11 @@ from collections import defaultdict from operator import and_, getitem as getitem_ -from columnflow.types import Sequence, Callable +import law + from columnflow.selection import Selector, SelectionResult, selector -from columnflow.util import maybe_import, InsertableDict +from columnflow.util import maybe_import, DotDict +from columnflow.types import Sequence, Callable, Any np = maybe_import("numpy") ak = maybe_import("awkward") @@ -115,7 +117,13 @@ def increment_stats( feature, a *skip_func* can be defined that receives the weight name and the names of the groups of an entry. If the function returns *True*, the entry will be skipped. """ - # default skip func + # defaults + if weight_map is None: + weight_map = {} + if group_map is None: + group_map = {} + if group_combinations is None: + group_combinations = [] if skip_func is None: skip_func = lambda weight_name, group_names: False @@ -126,7 +134,6 @@ def increment_stats( } # treat groups as combinations of a single group - group_combinations = list(group_combinations or []) for group_name, group_data in list(group_map.items())[::-1]: if group_data.get("combinations_only", False) or (group_name,) in group_combinations: continue @@ -209,9 +216,11 @@ def increment_stats( @increment_stats.setup def increment_stats_setup( self: Selector, - reqs: dict, - inputs: dict, - reader_targets: InsertableDict, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, ) -> None: # flags to descibe "number" and "sum" fields self.NUM, self.SUM = range(2) diff --git a/columnflow/tasks/calibration.py b/columnflow/tasks/calibration.py index 36db18b27..cc097c6b2 100644 --- a/columnflow/tasks/calibration.py +++ b/columnflow/tasks/calibration.py @@ -7,22 +7,28 @@ import luigi import law -from columnflow.tasks.framework.base import Requirements, AnalysisTask, DatasetTask, wrapper_factory +from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory from columnflow.tasks.framework.mixins import CalibratorMixin, ChunkedIOMixin from columnflow.tasks.framework.remote import RemoteWorkflow +from columnflow.tasks.framework.decorators import on_failure from columnflow.tasks.external import GetDatasetLFNs from columnflow.util import maybe_import, ensure_proxy, dev_sandbox ak = maybe_import("awkward") -class CalibrateEvents( +class _CalibrateEvents( CalibratorMixin, ChunkedIOMixin, - DatasetTask, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`CalibrateEvents`. + """ + + +class CalibrateEvents(_CalibrateEvents): """ Task to apply calibrations to objects, e.g. leptons and jets. @@ -40,9 +46,7 @@ class CalibrateEvents( GetDatasetLFNs=GetDatasetLFNs, ) - # register sandbox and shifts found in the chosen calibrator to this task - register_calibrator_sandbox = True - register_calibrator_shifts = True + invokes_calibrator = True def workflow_requires(self) -> dict: """ @@ -56,7 +60,9 @@ def workflow_requires(self) -> dict: reqs["lfns"] = self.reqs.GetDatasetLFNs.req(self) # add calibrator dependent requirements - reqs["calibrator"] = law.util.make_unique(law.util.flatten(self.calibrator_inst.run_requires())) + reqs["calibrator"] = law.util.make_unique(law.util.flatten( + self.calibrator_inst.run_requires(task=self), + )) return reqs @@ -67,7 +73,9 @@ def requires(self) -> dict: reqs = {"lfns": self.reqs.GetDatasetLFNs.req(self)} # add calibrator dependent requirements - reqs["calibrator"] = law.util.make_unique(law.util.flatten(self.calibrator_inst.run_requires())) + reqs["calibrator"] = law.util.make_unique(law.util.flatten( + self.calibrator_inst.run_requires(task=self), + )) return reqs @@ -88,6 +96,7 @@ def output(self): @ensure_proxy @law.decorator.localize(input=False) @law.decorator.safe_output + @on_failure(callback=lambda task: task.teardown_calibrator_inst()) def run(self): """ Run method of this task. @@ -102,8 +111,13 @@ def run(self): output_chunks = {} # run the calibrator setup - calibrator_reqs = self.calibrator_inst.run_requires() - reader_targets = self.calibrator_inst.run_setup(calibrator_reqs, luigi.task.getpaths(calibrator_reqs)) + self._array_function_post_init() + calibrator_reqs = self.calibrator_inst.run_requires(task=self) + reader_targets = self.calibrator_inst.run_setup( + task=self, + reqs=calibrator_reqs, + inputs=luigi.task.getpaths(calibrator_reqs), + ) # create a temp dir for saving intermediate files tmp_dir = law.LocalDirectoryTarget(is_tmp=True) @@ -115,7 +129,7 @@ def run(self): # define columns that will be written write_columns = self.calibrator_inst.produced_columns - route_filter = RouteFilter(write_columns) + route_filter = RouteFilter(keep=write_columns) # let the lfn_task prepare the nano file (basically determine a good pfn) [(lfn_index, input_file)] = lfn_task.iter_nano_files(self) @@ -140,7 +154,7 @@ def run(self): events = update_ak_array(events, *cols) # just invoke the calibration function - events = self.calibrator_inst(events) + events = self.calibrator_inst(events, task=self) # remove columns events = route_filter(events) @@ -154,6 +168,9 @@ def run(self): output_chunks[(lfn_index, pos.index)] = chunk self.chunked_io.queue(sorted_ak_to_parquet, (events, chunk.abspath)) + # teardown the calibrator + self.teardown_calibrator_inst() + # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( diff --git a/columnflow/tasks/cms/external.py b/columnflow/tasks/cms/external.py index 7f360988e..03eb98220 100644 --- a/columnflow/tasks/cms/external.py +++ b/columnflow/tasks/cms/external.py @@ -20,7 +20,7 @@ class CreatePileupWeights(ConfigTask): - sandbox = "bash::$CF_BASE/sandboxes/cmssw_default.sh" + single_config = True data_mode = luigi.ChoiceParameter( default="hist", @@ -30,6 +30,8 @@ class CreatePileupWeights(ConfigTask): ) version = None + sandbox = "bash::$CF_BASE/sandboxes/cmssw_default.sh" + # upstream requirements reqs = Requirements( BundleExternalFiles=BundleExternalFiles, diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index 41e436a2c..f0bfae242 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -4,180 +4,30 @@ Tasks related to the creation of datacards for inference purposes. """ -from collections import OrderedDict, defaultdict +from __future__ import annotations import law import order as od -from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory -from columnflow.tasks.framework.mixins import ( - CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin, InferenceModelMixin, - HistHookMixin, WeightProducerMixin, -) -from columnflow.tasks.framework.remote import RemoteWorkflow -from columnflow.tasks.histograms import MergeHistograms, MergeShiftedHistograms -from columnflow.util import dev_sandbox, DotDict -from columnflow.config_util import get_datasets_from_process - - -class CreateDatacards( - HistHookMixin, - InferenceModelMixin, - WeightProducerMixin, - MLModelsMixin, - ProducersMixin, - SelectorStepsMixin, - CalibratorsMixin, - law.LocalWorkflow, - RemoteWorkflow, -): - sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) - - # upstream requirements - reqs = Requirements( - RemoteWorkflow.reqs, - MergeHistograms=MergeHistograms, - MergeShiftedHistograms=MergeShiftedHistograms, - ) - - def create_branch_map(self): - return list(self.inference_model_inst.categories) - - def get_mc_datasets(self, proc_obj: dict) -> list[str]: - """ - Helper to find mc datasets. - - :param proc_obj: process object from an InferenceModel - :return: List of dataset names corresponding to the process *proc_obj*. - """ - # when datasets are defined on the process object itself, interpret them as patterns - if proc_obj.config_mc_datasets: - return [ - dataset.name - for dataset in self.config_inst.datasets - if ( - dataset.is_mc and - law.util.multi_match(dataset.name, proc_obj.config_mc_datasets, mode=any) - ) - ] - - # if the proc object is dynamic, it is calculated and the fly (e.g. via a hist hook) - # and doesn't have any additional requirements - if proc_obj.is_dynamic: - return [] - - # otherwise, check the config - return [ - dataset_inst.name - for dataset_inst in get_datasets_from_process(self.config_inst, proc_obj.config_process) - ] - - def get_data_datasets(self, cat_obj: dict) -> list[str]: - """ - Helper to find data datasets. +from columnflow.tasks.framework.base import AnalysisTask, wrapper_factory +from columnflow.tasks.framework.inference import SerializeInferenceModelBase +from columnflow.tasks.histograms import MergeHistograms - :param cat_obj: category object from an InferenceModel - :return: List of dataset names corresponding to the category *cat_obj*. - """ - if not cat_obj.config_data_datasets: - return [] - return [ - dataset.name - for dataset in self.config_inst.datasets - if ( - dataset.is_data and - law.util.multi_match(dataset.name, cat_obj.config_data_datasets, mode=any) - ) - ] +class CreateDatacards(SerializeInferenceModelBase): - def workflow_requires(self): - reqs = super().workflow_requires() - - # initialize defaultdict, mapping datasets to variables + shift_sources - mc_dataset_params = defaultdict(lambda: {"variables": set(), "shift_sources": set()}) - data_dataset_params = defaultdict(lambda: {"variables": set()}) - - for cat_obj in self.branch_map.values(): - for proc_obj in cat_obj.processes: - for dataset in self.get_mc_datasets(proc_obj): - # add all required variables and shifts per dataset - mc_dataset_params[dataset]["variables"].add(cat_obj.config_variable) - mc_dataset_params[dataset]["shift_sources"].update( - param_obj.config_shift_source - for param_obj in proc_obj.parameters - if self.inference_model_inst.require_shapes_for_parameter(param_obj) - ) - - for dataset in self.get_data_datasets(cat_obj): - data_dataset_params[dataset]["variables"].add(cat_obj.config_variable) - - # set workflow requirements per mc dataset - reqs["merged_hists"] = set( - self.reqs.MergeShiftedHistograms.req_different_branching( - self, - dataset=dataset, - shift_sources=tuple(params["shift_sources"]), - variables=tuple(params["variables"]), - ) - for dataset, params in mc_dataset_params.items() - ) - - # add workflow requirements per data dataset - for dataset, params in data_dataset_params.items(): - reqs["merged_hists"].add( - self.reqs.MergeHistograms.req_different_branching( - self, - dataset=dataset, - variables=tuple(params["variables"]), - ), - ) - - return reqs - - def requires(self): - cat_obj = self.branch_data - reqs = { - proc_obj.name: { - dataset: self.reqs.MergeShiftedHistograms.req_different_branching( - self, - dataset=dataset, - shift_sources=tuple( - param_obj.config_shift_source - for param_obj in proc_obj.parameters - if self.inference_model_inst.require_shapes_for_parameter(param_obj) - ), - variables=(cat_obj.config_variable,), - branch=-1, - workflow="local", - ) - for dataset in self.get_mc_datasets(proc_obj) - } - for proc_obj in cat_obj.processes - if not proc_obj.is_dynamic - } - if cat_obj.config_data_datasets: - reqs["data"] = { - dataset: self.reqs.MergeHistograms.req_different_branching( - self, - dataset=dataset, - variables=(cat_obj.config_variable,), - branch=-1, - workflow="local", - ) - for dataset in self.get_data_datasets(cat_obj) - } - - return reqs + resolution_task_cls = MergeHistograms def output(self): hooks_repr = self.hist_hooks_repr cat_obj = self.branch_data def basename(name: str, ext: str) -> str: - parts = [name, cat_obj.name, cat_obj.config_variable] + parts = [name, cat_obj.name] if hooks_repr: parts.append(f"hooks_{hooks_repr}") + if cat_obj.postfix is not None: + parts.append(cat_obj.postfix) return f"{'__'.join(map(str, parts))}.{ext}" return { @@ -189,113 +39,75 @@ def basename(name: str, ext: str) -> str: @law.decorator.safe_output def run(self): import hist - from columnflow.inference.cms.datacard import DatacardWriter + from columnflow.inference.cms.datacard import DatacardHists, ShiftHists, DatacardWriter # prepare inputs inputs = self.input() - # prepare config objects + # loop over all configs required by the datacard category and gather histograms cat_obj = self.branch_data - category_inst = self.config_inst.get_category(cat_obj.config_category) - variable_inst = self.config_inst.get_variable(cat_obj.config_variable) - leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] - - # histogram data per process - hists: dict[od.Process, hist.Hist] = dict() - - with self.publish_step(f"extracting {variable_inst.name} in {category_inst.name} ..."): - # loop over processes and forward them to any possible hist hooks - for proc_obj_name, inp in inputs.items(): - if proc_obj_name == "data": - # there is not process object for data - proc_obj = None - process_inst = self.config_inst.get_process("data") + datacard_hists: DatacardHists = {cat_obj.name: {}} + + # step 1: gather histograms per process for each config + input_hists: dict[od.Config, dict[od.Process, hist.Hist]] = {} + for config_inst in self.config_insts: + # skip configs that are not required + if not cat_obj.config_data.get(config_inst.name): + continue + # load them + input_hists[config_inst] = self.load_process_hists(inputs, cat_obj, config_inst) + + # step 2: apply hist hooks + input_hists = self.invoke_hist_hooks(input_hists) + + # step 3: transform to nested histogram as expected by the datacard writer + for config_inst in input_hists.keys(): + config_data = cat_obj.config_data.get(config_inst.name) + + # determine leaf categories to gather + category_inst = config_inst.get_category(config_data.category) + leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] + + # start the transformation + proc_objs = list(cat_obj.processes) + if config_data.data_datasets and not cat_obj.data_from_processes: + proc_objs.append(self.inference_model_inst.process_spec(name="data")) + for proc_obj in proc_objs: + # get the corresponding process instance + if proc_obj.name == "data": + process_inst = config_inst.get_process("data") + elif config_inst.name in proc_obj.config_data: + process_inst = config_inst.get_process(proc_obj.config_data[config_inst.name].process) else: - proc_obj = self.inference_model_inst.get_process(proc_obj_name, category=cat_obj.name) - process_inst = self.config_inst.get_process(proc_obj.config_process) - sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] - - h_proc = None - for dataset, _inp in inp.items(): - dataset_inst = self.config_inst.get_dataset(dataset) - - # skip when the dataset is already known to not contain any sub process - if not any(map(dataset_inst.has_process, sub_process_insts)): - self.logger.warning( - f"dataset '{dataset}' does not contain process '{process_inst.name}' " - "or any of its subprocesses which indicates a misconfiguration in the " - f"inference model '{self.inference_model}'", - ) - continue - - # open the histogram and work on a copy - h = _inp["collection"][0]["hists"][variable_inst.name].load(formatter="pickle").copy() - - # axis selections - h = h[{ - "process": [ - hist.loc(p.id) - for p in sub_process_insts - if p.id in h.axes["process"] - ], - }] - - # axis reductions - h = h[{"process": sum}] - - # add the histogram for this dataset - if h_proc is None: - h_proc = h - else: - h_proc += h - - # there must be a histogram - if h_proc is None: - raise Exception(f"no histograms found for process '{process_inst.name}'") - - # save histograms in hist_hook format - hists[process_inst] = h_proc - - # apply hist hooks - hists = self.invoke_hist_hooks(hists) - - # define datacard processes to loop over - cat_processes = list(cat_obj.processes) - if cat_obj.config_data_datasets and not cat_obj.data_from_processes: - cat_processes.append(DotDict({"name": "data"})) - - # after application of hist hooks, we can proceed with the datacard creation - datacard_hists: OrderedDict[str, OrderedDict[str, hist.Hist]] = OrderedDict() - for proc_obj in cat_processes: - # obtain process information from inference model and config again - proc_name = "data" if proc_obj.name == "data" else proc_obj.config_process - process_inst = self.config_inst.get_process(proc_name) + # skip process objects that rely on data from a different config + continue - h_proc = hists.get(process_inst, None) - if h_proc is None: + # extract the histogram for the process + if not (h_proc := input_hists[config_inst].get(process_inst, None)): self.logger.warning( - f"found no histogram for process '{proc_obj.name}', please check your " + f"found no histogram to model datacard process '{proc_obj.name}', please check your " f"inference model '{self.inference_model}'", ) continue - # select relevant category + # select relevant categories h_proc = h_proc[{ "category": [ - hist.loc(c.id) + hist.loc(c.name) for c in leaf_category_insts - if c.id in h_proc.axes["category"] + if c.name in h_proc.axes["category"] ], }][{"category": sum}] # create the nominal hist - datacard_hists[proc_obj.name] = OrderedDict() - nominal_shift_inst = self.config_inst.get_shift("nominal") - datacard_hists[proc_obj.name]["nominal"] = h_proc[ - {"shift": hist.loc(nominal_shift_inst.id)} - ] + datacard_hists[cat_obj.name].setdefault(proc_obj.name, {}).setdefault(config_inst.name, {}) + shift_hists: ShiftHists = datacard_hists[cat_obj.name][proc_obj.name][config_inst.name] + shift_hists["nominal"] = h_proc[{ - # stop here for data + "shift": hist.loc(config_inst.get_shift("nominal").name), + }] + + # no additional shifts need to be created for data if proc_obj.name == "data": continue @@ -305,18 +117,17 @@ def run(self): if not self.inference_model_inst.require_shapes_for_parameter(param_obj): continue # store the varied hists - datacard_hists[proc_obj.name][param_obj.name] = {} + shift_source = param_obj.config_data[config_inst.name].shift_source for d in ["up", "down"]: - shift_inst = self.config_inst.get_shift(f"{param_obj.config_shift_source}_{d}") - datacard_hists[proc_obj.name][param_obj.name][d] = h_proc[ - {"shift": hist.loc(shift_inst.id)} - ] - - # forward objects to the datacard writer - outputs = self.output() - writer = DatacardWriter(self.inference_model_inst, {cat_obj.name: datacard_hists}) - with outputs["card"].localize("w") as tmp_card, outputs["shapes"].localize("w") as tmp_shapes: - writer.write(tmp_card.abspath, tmp_shapes.abspath, shapes_path_ref=outputs["shapes"].basename) + shift_hists[(param_obj.name, d)] = h_proc[{ + "shift": hist.loc(config_inst.get_shift(f"{shift_source}_{d}").name), + }] + + # forward objects to the datacard writer + outputs = self.output() + writer = DatacardWriter(self.inference_model_inst, datacard_hists) + with outputs["card"].localize("w") as tmp_card, outputs["shapes"].localize("w") as tmp_shapes: + writer.write(tmp_card.abspath, tmp_shapes.abspath, shapes_path_ref=outputs["shapes"].basename) CreateDatacardsWrapper = wrapper_factory( diff --git a/columnflow/tasks/cutflow.py b/columnflow/tasks/cutflow.py index 89c522493..2aaf0aa53 100644 --- a/columnflow/tasks/cutflow.py +++ b/columnflow/tasks/cutflow.py @@ -12,10 +12,12 @@ import order as od from columnflow.tasks.framework.base import ( - Requirements, AnalysisTask, DatasetTask, ShiftTask, wrapper_factory, RESOLVE_DEFAULT, + Requirements, AnalysisTask, ShiftTask, wrapper_factory, RESOLVE_DEFAULT, ) from columnflow.tasks.framework.mixins import ( - CalibratorsMixin, SelectorStepsMixin, VariablesMixin, CategoriesMixin, ChunkedIOMixin, + CalibratorsMixin, SelectorMixin, VariablesMixin, CategoriesMixin, ChunkedIOMixin, + DatasetsProcessesMixin, + CalibratorClassesMixin, SelectorClassMixin, ) from columnflow.tasks.framework.plotting import ( PlotBase, PlotBase1D, PlotBase2D, ProcessPlotSettingMixin, VariablePlotSettingMixin, @@ -25,18 +27,24 @@ from columnflow.tasks.framework.parameters import last_edge_inclusive_inst from columnflow.tasks.selection import MergeSelectionMasks from columnflow.util import DotDict, dev_sandbox -from columnflow.hist_util import create_hist_from_variables +from columnflow.hist_util import create_columnflow_hist, translate_hist_intcat_to_strcat -class CreateCutflowHistograms( - VariablesMixin, - SelectorStepsMixin, +class _CreateCutflowHistograms( CalibratorsMixin, + SelectorMixin, ChunkedIOMixin, - DatasetTask, + VariablesMixin, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`CreateCutflowHistograms`. + """ + + +class CreateCutflowHistograms(_CreateCutflowHistograms): + # overwrite selector steps to use default resolution selector_steps = law.CSVParameter( default=(RESOLVE_DEFAULT,), @@ -45,21 +53,25 @@ class CreateCutflowHistograms( brace_expand=True, parse_empty=True, ) - - steps_variable = od.Variable( - name="step", - aux={"axis_type": "strcategory"}, + missing_selector_step_strategy = luigi.ChoiceParameter( + significant=False, + default=law.config.get_default("analysis", "missing_selector_step_strategy", "raise"), + choices=("raise", "ignore", "dummy"), + description="how to handle selector steps that are not defined by the selector; if " + "'raise', an exception will be thrown; if 'ignore', the selector step will be ignored; if " + "'dummy' the output histogram will contain an entry for the step identical to the previous " + "one; the default can be configured via the law config entry " + "*missing_selector_step_strategy* in the *analysis* section; if no default is specified " + "there, 'raise' is assumed", ) + steps_variable = od.Variable(name="step", aux={"axis_type": "strcategory"}) last_edge_inclusive = last_edge_inclusive_inst - sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) - selector_steps_order_sensitive = True - initial_step = "Initial" - default_variables = ("event", "cf_*") + missing_column_alias_strategy = "original" # upstream requirements reqs = Requirements( @@ -67,31 +79,13 @@ class CreateCutflowHistograms( MergeSelectionMasks=MergeSelectionMasks, ) - # strategy for handling missing source columns when adding aliases on event chunks - missing_column_alias_strategy = "original" - - # strategy for handling selector steps not defined by selectors - missing_selector_step_strategy = luigi.ChoiceParameter( - significant=False, - default=law.config.get_default("analysis", "missing_selector_step_strategy", "raise"), - choices=("raise", "ignore", "dummy"), - description="how to handle selector steps that are not defined by the selector; if " - "'raise', an exception will be thrown; if 'ignore', the selector step will be ignored; if " - "'dummy' the output histogram will contain an entry for the step identical to the previous " - "one; the default can be configured via the law config entry " - "*missing_selector_step_strategy* in the *analysis* section; if no default is specified " - "there, 'raise' is assumed", - ) - def create_branch_map(self): # dummy branch map return [None] def workflow_requires(self): reqs = super().workflow_requires() - reqs["selection"] = self.reqs.MergeSelectionMasks.req(self, tree_index=0, _exclude={"branches"}) - return reqs def requires(self): @@ -118,6 +112,12 @@ def run(self): # prepare inputs and outputs inputs = self.input() + # get IDs and names of all leaf categories + leaf_category_map = { + cat.id: cat.name + for cat in self.config_inst.get_leaf_categories() + } + # create a temp dir for saving intermediate files tmp_dir = law.LocalDirectoryTarget(is_tmp=True) tmp_dir.touch() @@ -154,7 +154,7 @@ def expr(events): read_columns.add(route) else: # for variable_inst with custom expressions, read columns declared via aux key - read_columns |= set(variable_inst.x("inputs", [])) + read_columns |= {Route(inp) for inp in variable_inst.x("inputs", [])} expressions[variable_inst.name] = expr # prepare columns to load @@ -168,10 +168,9 @@ def prepare_hists(steps): # create histogram of not already existing if var_key not in histograms: - histograms[var_key] = create_hist_from_variables( + histograms[var_key] = create_columnflow_hist( self.steps_variable, *variable_insts, - int_cat_axes=("category", "process", "shift"), ) for arr, pos in self.iter_chunked_io( @@ -199,6 +198,14 @@ def prepare_hists(steps): # pad the category_ids when the event is not categorized at all category_ids = ak.fill_none(ak.pad_none(events.category_ids, 1, axis=-1), -1) + unique_category_ids = np.unique(ak.flatten(category_ids)) + if any(cat_id not in leaf_category_map for cat_id in unique_category_ids): + undefined_category_ids = list(map(str, set(unique_category_ids) - set(leaf_category_map))) + raise ValueError( + f"category_ids column contains ids {','.join(undefined_category_ids)} that are either not known to " + "the config at all, or not as leaf categories (i.e., they have child categories); please ensure " + "that category_ids only contains ids of known leaf categories", + ) for var_key, var_names in self.variable_tuples.items(): # helper to build the point for filling, except for the step which does @@ -210,7 +217,6 @@ def get_point(mask=Ellipsis): point = { "process": events.process_id[mask], "category": category_ids[mask], - "shift": np.ones(n_events, dtype=np.int32) * self.global_shift_inst.id, "weight": ( events.normalization_weight[mask] if self.dataset_inst.is_mc @@ -226,7 +232,10 @@ def get_point(mask=Ellipsis): fill_hist( histograms[var_key], fill_data, - fill_kwargs={"step": self.initial_step}, + fill_kwargs={ + "shift": self.global_shift_inst.name, + "step": self.initial_step, + }, last_edge_inclusive=self.last_edge_inclusive, ) @@ -249,10 +258,24 @@ def get_point(mask=Ellipsis): fill_hist( histograms[var_key], fill_data, - fill_kwargs={"step": step}, + fill_kwargs={ + "shift": self.global_shift_inst.name, + "step": step, + }, last_edge_inclusive=self.last_edge_inclusive, ) + # change some axes from int to str + for var_key in self.variable_tuples.keys(): + # category + histograms[var_key] = translate_hist_intcat_to_strcat(histograms[var_key], "category", leaf_category_map) + # process + process_map = { + proc_id: self.config_inst.get_process(proc_id).name + for proc_id in histograms[var_key].axes["process"] + } + histograms[var_key] = translate_hist_intcat_to_strcat(histograms[var_key], "process", process_map) + # dump the histograms for var_key in histograms.keys(): self.output()[var_key].dump(histograms[var_key], formatter="pickle") @@ -265,21 +288,27 @@ def get_point(mask=Ellipsis): ) -class PlotCutflowBase( - SelectorStepsMixin, +class _PlotCutflowBase( + ShiftTask, + CalibratorClassesMixin, + SelectorClassMixin, CategoriesMixin, - CalibratorsMixin, PlotBase, - ShiftTask, law.LocalWorkflow, RemoteWorkflow, +): + resolution_task_cls = CreateCutflowHistograms + single_config = True + + +class PlotCutflowBase( + _PlotCutflowBase, ): selector_steps = CreateCutflowHistograms.selector_steps sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) exclude_index = True - selector_steps_order_sensitive = True # upstream requirements @@ -288,19 +317,29 @@ class PlotCutflowBase( CreateCutflowHistograms=CreateCutflowHistograms, ) - def store_parts(self): + def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() - parts.insert_before("version", "plot", f"datasets_{self.datasets_repr}") + parts.insert_after(self.config_store_anchor, "plot", f"datasets_{self.datasets_repr}") return parts + def get_plot_shifts(self): + return [self.config_inst.get_shift(self.shift)] + -class PlotCutflow( +class _PlotCutflow( PlotCutflowBase, PlotBase1D, ProcessPlotSettingMixin, + DatasetsProcessesMixin, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`PlotCutflow`. + """ + + +class PlotCutflow(_PlotCutflow): plot_function = PlotBase.plot_function.copy( default="columnflow.plotting.plot_functions_1d.plot_cutflow", add_default_to_description=True, @@ -400,7 +439,7 @@ def run(self): # select shift if (n_shifts := len(h_in.axes["shift"])) != 1: raise Exception(f"shift axis is supposed to only contain 1 bin, found {n_shifts}") - h_in = h_in[{"shift": hist.loc(self.global_shift_inst.id)}] + h_in = h_in[{"shift": hist.loc(self.global_shift_inst.name)}] # loop and extract one histogram per process for process_inst in process_insts: @@ -415,9 +454,9 @@ def run(self): h = h_in.copy() h = h[{ "process": [ - hist.loc(p.id) + hist.loc(p.name) for p in sub_process_insts[process_inst] - if p.id in h.axes["process"] + if p.name in h.axes["process"] ], }] h = h[{"process": sum}] @@ -439,9 +478,9 @@ def run(self): # selections h = h[{ "category": [ - hist.loc(c.id) + hist.loc(c.name) for c in leaf_category_insts - if c.id in h.axes["category"] + if c.name in h.axes["category"] ], }] # reductions @@ -456,6 +495,7 @@ def run(self): hists=hists, config_inst=self.config_inst, category_inst=category_inst.copy_shallow(), + shift_insts=self.get_plot_shifts(), **self.get_plot_parameters(), ) @@ -472,9 +512,10 @@ def run(self): class PlotCutflowVariablesBase( + PlotCutflowBase, VariablePlotSettingMixin, ProcessPlotSettingMixin, - PlotCutflowBase, + DatasetsProcessesMixin, law.LocalWorkflow, RemoteWorkflow, ): @@ -584,7 +625,7 @@ def run(self): # select shift if (n_shifts := len(h_in.axes["shift"])) != 1: raise Exception(f"shift axis is supposed to only contain 1 bin, found {n_shifts}") - h_in = h_in[{"shift": hist.loc(self.global_shift_inst.id)}] + h_in = h_in[{"shift": hist.loc(self.global_shift_inst.name)}] # loop and extract one histogram per process for process_inst in process_insts: @@ -599,9 +640,9 @@ def run(self): h = h_in.copy() h = h[{ "process": [ - hist.loc(p.id) + hist.loc(p.name) for p in sub_process_insts[process_inst] - if p.id in h.axes["process"] + if p.name in h.axes["process"] ], }] h = h[{"process": sum}] @@ -623,9 +664,9 @@ def run(self): # selections h = h[{ "category": [ - hist.loc(c.id) + hist.loc(c.name) for c in leaf_category_insts - if c.id in h.axes["category"] + if c.name in h.axes["category"] ], }] # reductions @@ -642,10 +683,17 @@ def run(self): ) -class PlotCutflowVariables1D( +class _PlotCutflowVariables1D( PlotCutflowVariablesBase, PlotBase1D, ): + """ + Base classes for :py:class:`PlotCutflowVariables1D`. + """ + + +class PlotCutflowVariables1D(_PlotCutflowVariables1D): + plot_function = PlotBase.plot_function.copy( default=law.NO_STR, description=PlotBase.plot_function.description + "; the default is resolved based on the " @@ -721,6 +769,7 @@ def run_postprocess(self, hists, category_inst, variable_insts): config_inst=self.config_inst, category_inst=category_inst.copy_shallow(), variable_insts=[var_inst.copy_shallow() for var_inst in variable_insts], + shift_insts=self.get_plot_shifts(), style_config={"legend_cfg": {"title": f"Step '{step}'"}}, **self.get_plot_parameters(), ) @@ -745,6 +794,7 @@ def run_postprocess(self, hists, category_inst, variable_insts): config_inst=self.config_inst, category_inst=category_inst.copy_shallow(), variable_insts=[var_inst.copy_shallow() for var_inst in variable_insts], + shift_insts=self.get_plot_shifts(), style_config={"legend_cfg": {"title": process_inst.label}}, **self.get_plot_parameters(), ) @@ -754,10 +804,17 @@ def run_postprocess(self, hists, category_inst, variable_insts): outp.dump(fig, formatter="mpl") -class PlotCutflowVariables2D( +class _PlotCutflowVariables2D( PlotCutflowVariablesBase, PlotBase2D, ): + """ + Base classes for :py:class:`PlotCutflowVariables2D`. + """ + + +class PlotCutflowVariables2D(_PlotCutflowVariables2D): + plot_function = PlotBase.plot_function.copy( default="columnflow.plotting.plot_functions_2d.plot_2d", add_default_to_description=True, @@ -796,6 +853,7 @@ def run_postprocess(self, hists, category_inst, variable_insts): config_inst=self.config_inst, category_inst=category_inst.copy_shallow(), variable_insts=[var_inst.copy_shallow() for var_inst in variable_insts], + shift_insts=self.get_plot_shifts(), style_config={"legend_cfg": {"title": f"Step '{step}'"}}, **self.get_plot_parameters(), ) @@ -806,8 +864,8 @@ def run_postprocess(self, hists, category_inst, variable_insts): class PlotCutflowVariablesPerProcess2D( - law.WrapperTask, PlotCutflowVariables2D, + law.WrapperTask, ): # force this one to be a local workflow workflow = "local" diff --git a/columnflow/tasks/external.py b/columnflow/tasks/external.py index ee2183c1c..a0ba1948c 100644 --- a/columnflow/tasks/external.py +++ b/columnflow/tasks/external.py @@ -10,15 +10,17 @@ import time import shutil import subprocess +from dataclasses import dataclass, field import luigi import law import order as od -from columnflow.types import Sequence from columnflow.tasks.framework.base import AnalysisTask, ConfigTask, DatasetTask, wrapper_factory from columnflow.tasks.framework.parameters import user_parameter_inst +from columnflow.tasks.framework.decorators import only_local_env from columnflow.util import wget, DotDict +from columnflow.types import Sequence logger = law.logger.get_logger(__name__) @@ -86,6 +88,7 @@ def single_output(self) -> law.target.file.FileSystemFileTarget: h = law.util.create_hash(list(sorted(self.dataset_info_inst.keys))) return self.target(f"lfns_{h}.json") + @only_local_env @law.decorator.notify @law.decorator.log def run(self): @@ -365,38 +368,75 @@ def _fetch_lfn_fallback( ) +@dataclass +class ExternalFile: + """ + Container object to define an external file resource that is understood by (e.g.) + :py:class:`tasks.external.BundleExternalFiles`. Example: + + .. code-block:: python + + # refer to a simple file location + ExternalFile(location="path/to/file", version="v1") + + # refer to a directory or archive that contains multiple files + ExternalFile(location="some/archive.tgz", subpaths={"file_name": "file/in/archive"}, version="v1") + """ + + location: str + subpaths: dict[str, str] = field(default_factory=str) + version: str = "v1" + + def __str__(self) -> str: + sub = (" / " + ",".join(f"{n}={p}" for n, p in self.subpaths.items())) if self.subpaths else "" + return f"{self.location}{sub} ({self.version})" + + @classmethod + def new(cls, resource: ExternalFile | str | tuple[str] | tuple[str, str]) -> ExternalFile: + """ + Factory method to create a new instance of :py:class:`ExternalFile` with backwards-compatible parsing of simple + strings and tuples. + """ + if isinstance(resource, cls): + return resource + if isinstance(resource, str): + return cls(location=resource) + if isinstance(resource, (list, tuple)): + if len(resource) == 1: + return cls(location=resource[0]) + if len(resource) == 2: + return cls(location=resource[0], version=resource[1]) + raise ValueError(f"invalid resource type and format: {resource}") + + class BundleExternalFiles(ConfigTask, law.tasks.TransferLocalFile): """ Task to collect external files. - This task is intended to download source files for other tasks, such as files containing - corrections for objects, the "golden" json files, source files for the calculation of pileup - weights, and others. + This task is intended to download source files for other tasks, such as files containing corrections for objects, + the "golden" json files, source files for the calculation of pileup weights, and others. - All information about the relevant external files is extracted from the given ``config_inst``, - which must contain the keyword ``external_files`` in the auxiliary information. This can look - like this: + All information about the relevant external files is extracted from the given ``config_inst``, which must contain + the keyword ``external_files`` in the auxiliary information. This can look like this: .. code-block:: python # cfg is the current config instance cfg.x.external_files = DotDict.wrap({ - # The following assumes that the zip files are reachable under the - # url ``SOURCE_URL`` - # jet energy correction - "jet_jerc": (f"{SOURCE_URL}/POG/JME/{year}{corr_postfix}_UL/jet_jerc.json.gz", "v1"), + "jet_jerc": ExternalFile(f"{SOURCE_URL}/POG/JME/{year}{corr_postfix}_UL/jet_jerc.json.gz", version="v1"), - # tau energy correction and scale factors - "tau_sf": (f"{SOURCE_URL}/POG/TAU/{year}{corr_postfix}_UL/tau.json.gz", "v1"), + # tau energy correction and scale factors + "tau_sf": ExternalFile(f"{SOURCE_URL}/POG/TAU/{year}{corr_postfix}_UL/tau.json.gz", version="v1"), - # electron scale factors - "electron_sf": (f"{SOURCE_URL}/POG/EGM/{year}{corr_postfix}_UL/electron.json.gz", "v1"), + # electron scale factors + "electron_sf": ExternalFile(f"{SOURCE_URL}/POG/EGM/{year}{corr_postfix}_UL/electron.json.gz", version="v1"), + }) - The entries in this DotDict can either be simply the path to the source files or can be a tuple - of the format ``(path/or/url/to/source/file, VERSION)`` to introduce a versioning mechanism for - external files. + The entries in this DotDict should be :py:class:`ExternalFile` instances. """ + single_config = True + replicas = luigi.IntParameter( default=5, description="number of replicas to generate; default: 5", @@ -407,6 +447,9 @@ class BundleExternalFiles(ConfigTask, law.tasks.TransferLocalFile): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # external files, casted to ExternalFile instances once + self.ext_files = law.util.map_struct(ExternalFile.new, self.config_inst.x.external_files) + # cached hash self._files_hash = None @@ -418,16 +461,25 @@ def __init__(self, *args, **kwargs): self._files = None @classmethod - def create_unique_basename(cls, path: tuple[str] | str) -> str: + def create_unique_basename(cls, path: str | ExternalFile) -> str | dict[str, str]: """ - Create a unique basename. + Create a unique basename for a given path. When *path* is an :py:class:`ExternalFile` with one or more subpaths + defined, a dictionary mapping subpaths to unique basenames is returned. - :param path: path to create a unique basename for - :return: Unique basename + :param path: path or external file object. + :return: Unique basename(s). """ - h = law.util.create_hash(path) - basename = os.path.basename(path[0] if isinstance(path, tuple) else path) - return f"{h}_{basename}" + if isinstance(path, str): + return f"{law.util.create_hash(path)}_{os.path.basename(path)}" + + # path must be an ExternalFile + if path.subpaths: + return type(path.subpaths)( + (name, cls.create_unique_basename(os.path.join(path.location, subpath))) + for name, subpath in path.subpaths.items() + ) + + return cls.create_unique_basename(path.location) @property def files_hash(self) -> str: @@ -443,7 +495,7 @@ def deterministic_flatten(d): (key, (deterministic_flatten(d[key]) if isinstance(d[key], dict) else d[key])) for key in sorted(d) ] - flat_files = deterministic_flatten(self.config_inst.x.external_files) + flat_files = deterministic_flatten(self.ext_files) self._files_hash = law.util.create_hash(flat_files) return self._files_hash @@ -456,10 +508,7 @@ def file_names(self) -> DotDict: :return: DotDict of same shape as ``external_files`` DotDict with unique basenames. """ if self._file_names is None: - self._file_names = law.util.map_struct( - self.create_unique_basename, - self.config_inst.x.external_files, - ) + self._file_names = law.util.map_struct(self.create_unique_basename, self.ext_files) return self._file_names @@ -478,9 +527,9 @@ def get_files(self, output=None): self.files_dir = law.LocalDirectoryTarget(is_tmp=True) output.load(self.files_dir, formatter="tar") - # resolve basenames in the bundle directory and map to file targets + # resolve basenames in the bundle directory and map to local targets def resolve_basename(unique_basename): - return self.files_dir.child(unique_basename, type="f") + return self.files_dir.child(unique_basename) self._files = law.util.map_struct(resolve_basename, self.file_names) @@ -494,6 +543,7 @@ def single_output(self): # required by law.tasks.TransferLocalFile return self.target(f"externals_{self.files_hash}.tgz") + @only_local_env @law.decorator.notify @law.decorator.log @law.decorator.safe_output @@ -502,27 +552,59 @@ def run(self): tmp_dir = law.LocalDirectoryTarget(is_tmp=True) tmp_dir.touch() + # create a scratch directory for temporary downloads that will not be bundled + scratch_dir = tmp_dir.child("scratch", type="d") + scratch_dir.touch() + # progress callback - n_files = len(law.util.flatten(self.config_inst.x.external_files)) - progress = self.create_progress_callback(n_files) + progress = self.create_progress_callback(len(law.util.flatten(self.ext_files))) - # helper function to fetch generic files - def fetch_file(src, counter=[0]): - dst = os.path.join(tmp_dir.abspath, self.create_unique_basename(src)) - src = src[0] if isinstance(src, tuple) else src + # helper to fetch a single src to dst + def fetch(src, dst): if src.startswith(("http://", "https://")): # download via wget wget(src, dst) - else: - # must be a local file + elif os.path.isfile(src): + # copy local file shutil.copy2(src, dst) + elif os.path.isdir(src): + # copy local dir + shutil.copytree(src, dst) + else: + raise NotImplementedError(f"fetching {src} is not supported") + + # helper function to fetch generic files + def fetch_file(ext_file, counter=[0]): + if ext_file.subpaths: + # copy to scratch dir first in case a subpath is requested + basename = self.create_unique_basename(ext_file.location) + scratch_dst = os.path.join(scratch_dir.abspath, basename) + fetch(ext_file.location, scratch_dst) + # when not a directory, assume the file is an archive and unpack it + if not os.path.isdir(scratch_dst): + arc_dir = scratch_dir.child(basename.split(".")[0] + "_unpacked", type="d") + self.publish_message(f"unpacking {scratch_dst}") + law.LocalFileTarget(scratch_dst).load(arc_dir) + scratch_src = arc_dir.abspath + else: + scratch_src = scratch_dst + # copy all subpaths + basenames = self.create_unique_basename(ext_file) + for name, subpath in ext_file.subpaths.items(): + fetch(os.path.join(scratch_src, subpath), os.path.join(tmp_dir.abspath, basenames[name])) + else: + # copy directly to the bundle dir + src = ext_file.location + dst = os.path.join(tmp_dir.abspath, self.create_unique_basename(ext_file.location)) + fetch(src, dst) # log - self.publish_message(f"fetched {src}") + self.publish_message(f"fetched {ext_file}") progress(counter[0]) counter[0] += 1 - # fetch all files - law.util.map_struct(fetch_file, self.config_inst.x.external_files) + # fetch all files and cleanup scratch dir + law.util.map_struct(fetch_file, self.ext_files) + scratch_dir.remove() # create the bundle tmp = law.LocalFileTarget(is_tmp="tgz") diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index f5d352ea7..b6128196f 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -7,6 +7,7 @@ from __future__ import annotations import os +import abc import enum import importlib import itertools @@ -14,36 +15,42 @@ import functools import collections import copy +import subprocess +from dataclasses import dataclass, field import luigi import law import order as od from columnflow.columnar_util import mandatory_coffea_columns, Route, ColumnCollection -from columnflow.util import is_regex, DotDict +from columnflow.util import is_regex, prettify, DotDict from columnflow.types import Sequence, Callable, Any, T logger = law.logger.get_logger(__name__) +logger_dev = law.logger.get_logger(f"{__name__}-dev") # default analysis and config related objects default_analysis = law.config.get_expanded("analysis", "default_analysis") default_config = law.config.get_expanded("analysis", "default_config") default_dataset = law.config.get_expanded("analysis", "default_dataset") +default_repr_max_len = law.config.get_expanded_int("analysis", "repr_max_len") +default_repr_max_count = law.config.get_expanded_int("analysis", "repr_max_count") +default_repr_hash_len = law.config.get_expanded_int("analysis", "repr_hash_len") # placeholder to denote a default value that is resolved dynamically RESOLVE_DEFAULT = "DEFAULT" class Requirements(DotDict): - """General class for requirements of different tasks. + """ + Container for task-level requirements of different tasks. - Can be initialized with other :py:class:`~columnflow.util.DotDict` - instances and additional keyword arguments ``kwargs``, which are - added. - """ - def __init__(self, *others, **kwargs): + Can be initialized with other :py:class:`Requirement` instances and additional keyword arguments ``kwargs``, + which are added. + """ + def __init__(self, *others, **kwargs) -> None: super().__init__() # add others and kwargs @@ -51,14 +58,6 @@ def __init__(self, *others, **kwargs): self.update(reqs) -class BaseTask(law.Task): - - task_namespace = law.config.get_expanded("analysis", "cf_task_namespace", "cf") - - # container for upstream requirements for convenience - reqs = Requirements() - - class OutputLocation(enum.Enum): """ Output location flag. @@ -70,6 +69,33 @@ class OutputLocation(enum.Enum): wlcg_mirrored = "wlcg_mirrored" +@dataclass +class TaskShifts: + """ + Container for *local* and *upstream* shifts at a point in the task graph. + """ + # NOTE: maybe these should be a dict of sets (one set per config) to allow for different shifts + # per config + + local: set[str] = field(default_factory=set) + upstream: set[str] = field(default_factory=set) + + +class BaseTask(law.Task): + + task_namespace = law.config.get_expanded("analysis", "cf_task_namespace", "cf") + + # container for upstream requirements for convenience + reqs = Requirements() + + def get_params_dict(self) -> dict[str, Any]: + return { + attr: getattr(self, attr) + for attr, param in self.get_params() + if isinstance(param, luigi.Parameter) + } + + class AnalysisTask(BaseTask, law.SandboxTask): analysis = luigi.Parameter( @@ -107,13 +133,13 @@ class AnalysisTask(BaseTask, law.SandboxTask): _cfg_resources_dict = None @classmethod - def modify_param_values(cls, params: dict) -> dict: + def modify_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params = super().modify_param_values(params) params = cls.resolve_param_values(params) return params @classmethod - def resolve_param_values(cls, params: dict) -> dict: + def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: # store a reference to the analysis inst if "analysis_inst" not in params and "analysis" in params: params["analysis_inst"] = cls.get_analysis_inst(params["analysis"]) @@ -141,7 +167,7 @@ def get_analysis_inst(cls, analysis: str) -> od.Analysis: return analysis_inst @classmethod - def req_params(cls, inst: AnalysisTask, **kwargs) -> dict: + def req_params(cls, inst: AnalysisTask, **kwargs) -> dict[str, Any]: """ Returns parameters that are jointly defined in this class and another task instance of some other class. The parameters are used when calling ``Task.req(self)``. @@ -163,7 +189,7 @@ def req_params(cls, inst: AnalysisTask, **kwargs) -> dict: isinstance(getattr(cls, "version", None), luigi.Parameter) and "version" not in kwargs and not law.parser.global_cmdline_values().get(f"{cls.task_family}_version") and - cls.task_family != law.parser.root_task().task_family + cls.task_family != law.parser.root_task_cls().task_family ): default_version = cls.get_default_version(inst, params) if default_version and default_version != law.NO_STR: @@ -172,7 +198,7 @@ def req_params(cls, inst: AnalysisTask, **kwargs) -> dict: return params @classmethod - def _structure_cfg_items(cls, items: list[tuple[str, Any]]) -> dict: + def _structure_cfg_items(cls, items: list[tuple[str, Any]]) -> dict[str, Any]: if not items: return {} @@ -207,7 +233,7 @@ def _structure_cfg_items(cls, items: list[tuple[str, Any]]) -> dict: return items_dict @classmethod - def _get_cfg_outputs_dict(cls): + def _get_cfg_outputs_dict(cls) -> dict[str, Any]: if cls._cfg_outputs_dict is None and law.config.has_section("outputs"): # collect config item pairs skip_keys = {"wlcg_file_systems", "lfn_sources"} @@ -221,7 +247,7 @@ def _get_cfg_outputs_dict(cls): return cls._cfg_outputs_dict @classmethod - def _get_cfg_versions_dict(cls): + def _get_cfg_versions_dict(cls) -> dict[str, Any]: if cls._cfg_versions_dict is None and law.config.has_section("versions"): # collect config item pairs items = [ @@ -234,7 +260,7 @@ def _get_cfg_versions_dict(cls): return cls._cfg_versions_dict @classmethod - def _get_cfg_resources_dict(cls): + def _get_cfg_resources_dict(cls) -> dict[str, Any]: if cls._cfg_resources_dict is None and law.config.has_section("resources"): # helper to split resource values into key-value pairs themselves def parse(key: str, value: str) -> tuple[str, list[tuple[str, Any]]]: @@ -257,7 +283,7 @@ def parse(key: str, value: str) -> tuple[str, list[tuple[str, Any]]]: items = [ parse(key, value) for key, value in law.config.items("resources") - if value + if value and not key.startswith("_") ] cls._cfg_resources_dict = cls._structure_cfg_items(items) @@ -266,8 +292,8 @@ def parse(key: str, value: str) -> tuple[str, list[tuple[str, Any]]]: @classmethod def get_default_version(cls, inst: AnalysisTask, params: dict[str, Any]) -> str | None: """ - Determines the default version for instances of *this* task class when created through - :py:meth:`req` from another task *inst* given parameters *params*. + Determines the default version for instances of *this* task class when created through :py:meth:`req` from + another task *inst* given parameters *params*. :param inst: The task instance from which *this* task should be created via :py:meth:`req`. :param params: The parameters that are passed to the task instance. @@ -290,7 +316,7 @@ def _get_default_version( keys: law.util.InsertableDict, ) -> str | None: # try to lookup the version in the analysis's auxiliary data - analysis_inst = params.get("analysis_inst") or getattr(inst, "analysis_inst", None) + analysis_inst = getattr(inst, "analysis_inst", None) if analysis_inst: version = cls._dfs_key_lookup(keys, analysis_inst.x("versions", {})) if version: @@ -313,22 +339,20 @@ def get_config_lookup_keys( inst_or_params: AnalysisTask | dict[str, Any], ) -> law.util.InsertiableDict: """ - Returns a dictionary with keys that can be used to lookup state specific values in a config - or dictionary, such as default task versions or output locations. + Returns a dictionary with keys that can be used to lookup state specific values in a config or dictionary, such + as default task versions or output locations. :param inst_or_params: The tasks instance or its parameters. :return: A dictionary with keys that can be used for nested lookup. """ keys = law.util.InsertableDict() - get = ( - inst_or_params.get + # add the analysis name + analysis = ( + inst_or_params.get("analysis") if isinstance(inst_or_params, dict) - else lambda attr: (getattr(inst_or_params, attr, None)) + else getattr(inst_or_params, "analysis", None) ) - - # add the analysis name - analysis = get("analysis") if analysis not in {law.NO_STR, None, ""}: keys["analysis"] = analysis @@ -340,7 +364,7 @@ def get_config_lookup_keys( @classmethod def _dfs_key_lookup( cls, - keys: law.util.InsertableDict[str, str] | Sequence[str], + keys: law.util.InsertableDict[str, str | Sequence[str]] | Sequence[str | Sequence[str]], nested_dict: dict[str, Any], empty_value: Any = None, ) -> str | Callable | None: @@ -350,12 +374,12 @@ def _dfs_key_lookup( if not nested_dict: return empty_value - # the keys to use for the lookup are the values of the keys dict - keys = collections.deque(keys.values() if isinstance(keys, dict) else keys) + # the keys to use for the lookup are the flattened values of the keys dict + flat_keys = collections.deque(law.util.flatten(keys.values() if isinstance(keys, dict) else keys)) # start tree traversal using a queue lookup consisting of names and values of tree nodes, # as well as the remaining keys (as a deferred function) to compare for that particular path - lookup = collections.deque([tpl + ((lambda: keys.copy()),) for tpl in nested_dict.items()]) + lookup = collections.deque([tpl + ((lambda: flat_keys.copy()),) for tpl in nested_dict.items()]) while lookup: pattern, obj, keys_func = lookup.popleft() @@ -380,105 +404,107 @@ def _dfs_key_lookup( return empty_value @classmethod - def get_known_shifts(cls, config_inst: od.Config, params: dict) -> tuple[set[str], set[str]]: - """ - Returns two sets of shifts in a tuple: shifts implemented by _this_ task, and dependent - shifts that are implemented by upstream tasks. - """ - # get shifts from upstream dependencies, consider both their own and upstream shifts as one - upstream_shifts = set() - for req in cls.reqs.values(): - upstream_shifts |= set.union(*(req.get_known_shifts(config_inst, params) or (set(),))) - - return set(), upstream_shifts - - @classmethod - def get_array_function_kwargs( - cls, - task: AnalysisTask | None = None, - **params, - ) -> dict[str, Any]: - if task: - analysis_inst = task.analysis_inst - elif "analysis_inst" in params: + def get_array_function_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + if "analysis_inst" in params: analysis_inst = params["analysis_inst"] else: analysis_inst = cls.get_analysis_inst(params["analysis"]) - return { - "task": task, - "analysis_inst": analysis_inst, - } - - @classmethod - def get_calibrator_kwargs(cls, *args, **kwargs) -> dict[str, Any]: - # implemented here only for simplified mro control - return cls.get_array_function_kwargs(*args, **kwargs) - - @classmethod - def get_selector_kwargs(cls, *args, **kwargs) -> dict[str, Any]: - # implemented here only for simplified mro control - return cls.get_array_function_kwargs(*args, **kwargs) - - @classmethod - def get_producer_kwargs(cls, *args, **kwargs) -> dict[str, Any]: - # implemented here only for simplified mro control - return cls.get_array_function_kwargs(*args, **kwargs) - - @classmethod - def get_weight_producer_kwargs(cls, *args, **kwargs) -> dict[str, Any]: - # implemented here only for simplified mro control - return cls.get_array_function_kwargs(*args, **kwargs) + return {"analysis_inst": analysis_inst} @classmethod def find_config_objects( cls, names: str | Sequence[str] | set[str], - container: od.UniqueObject, + container: od.UniqueObject | Sequence[od.UniqueObject], object_cls: od.UniqueObjectMeta, - object_groups: dict[str, list] | None = None, + groups_str: str | None = None, accept_patterns: bool = True, deep: bool = False, strict: bool = False, - ) -> list[str]: + multi_strategy: str = "first", + ) -> list[str] | dict[od.UniqueObject, list[str]]: """ - Returns all names of objects of type *object_cls* known to a *container* (e.g. - :py:class:`od.Analysis` or :py:class:`od.Config`) that match *names*. A name can also be a - pattern to match if *accept_patterns* is *True*, or, when given, the key of a mapping - *object_group* that matches group names to object names. When *deep* is *True* the lookup of - objects in the *container* is recursive. When *strict* is *True*, an error is raised if no - matches are found for any of the *names*. Example: + Returns all names of objects of type *object_cls* known to a *container* (e.g. :py:class:`od.Analysis` or + :py:class:`od.Config`) that match *names*. A name can also be a pattern to match if *accept_patterns* is *True*, + or, when given, the key of a mapping named *group_str* in the container auxiliary data that matches group names + to object names. + + When *deep* is *True* the lookup of objects in the *container* is recursive. When *strict* is *True*, an error + is raised if no matches are found for any of the *names*. + + *container* can also refer to a sequence of container objects. If this is the case, the default object retrieval + is performed for all of them and the resulting values can be handled with five different strategies, controlled + via *multi_strategy*: + + - ``"first"``: The first resolved name is returned. + - ``"same"``: The resolved names are forced to be identical and an exception is raised if they differ. The + first resolved value is returned. + - ``"union"``: The set union of all resolved names is returned in a list. + - ``"intersection"``: The set intersection of all resolved names is returned in a list. + - ``"all"``: The resolved values are returned in a dictionary mapped to their respective container. + + Example: .. code-block:: python - find_config_objects(["st_tchannel_*"], config_inst, od.Dataset) + find_config_objects(names=["st_tchannel_*"], container=config_inst, object_cls=od.Dataset) # -> ["st_tchannel_t", "st_tchannel_tbar"] """ + # when the container is a sequence, find objects per container and apply the multi_strategy + if isinstance(container, (list, tuple)): + if multi_strategy not in (strategies := {"first", "same", "union", "intersection", "all"}): + raise ValueError(f"invalid multi_strategy: {multi_strategy}, must be one of {','.join(strategies)}") + + all_object_names = { + _container: cls.find_config_objects( + names=names, + container=_container, + object_cls=object_cls, + groups_str=groups_str, + accept_patterns=accept_patterns, + deep=deep, + strict=strict, + ) + for _container in container + } + + if multi_strategy == "all": + return all_object_names + if multi_strategy == "first": + return all_object_names[container[0]] + if multi_strategy == "union": + return list(set.union(*map(set, all_object_names.values()))) + if multi_strategy == "intersection": + return list(set.intersection(*map(set, all_object_names.values()))) + # "same", so check that values are identical + first = all_object_names[container[0]] + if not all(all_object_names[c] == first for c in container[1:]): + raise ValueError( + f"different objects found across containers looking for '{object_cls}' objects '{names}':\n" + f"{prettify(all_object_names)}", + ) + return first + + # prepare value caching singular = object_cls.cls_name_singular plural = object_cls.cls_name_plural - _cache = {} + _cache: dict[str, set[str]] = {} - def get_all_object_names(): + def get_all_object_names() -> set[str]: if "all_object_names" not in _cache: if deep: - _cache["all_object_names"] = { - obj.name - for obj, _, _ in - getattr(container, f"walk_{plural}")() - } + _cache["all_object_names"] = {obj.name for obj, _, _ in getattr(container, f"walk_{plural}")()} else: _cache["all_object_names"] = set(getattr(container, plural).names()) return _cache["all_object_names"] - def has_obj(name): + def has_obj(name: str) -> bool: if "has_obj_func" not in _cache: kwargs = {} if object_cls in container._deep_child_classes: kwargs["deep"] = deep - _cache["has_obj_func"] = functools.partial( - getattr(container, f"has_{singular}"), - **kwargs, - ) + _cache["has_obj_func"] = functools.partial(getattr(container, f"has_{singular}"), **kwargs) return _cache["has_obj_func"](name) object_names = [] @@ -489,7 +515,7 @@ def has_obj(name): if has_obj(name): # known object object_names.append(name) - elif object_groups and name in object_groups: + elif groups_str and name in (object_groups := container.x(groups_str, {})): # a key in the object group dict lookup.extend(list(object_groups[name])) elif accept_patterns: @@ -511,223 +537,326 @@ def has_obj(name): @classmethod def resolve_config_default( cls, + *, + param: Any, task_params: dict[str, Any], - param: str | tuple[str] | None, - container: str | od.AuxDataMixin = "config_inst", + container: str | od.AuxDataMixin | Sequence[od.AuxDataMixin], default_str: str | None = None, - multiple: bool = False, - ) -> str | tuple | Any | None: + multi_strategy: str = "first", + ) -> Any | list[Any] | dict[od.AuxDataMixin, Any]: """ - Resolves a given parameter value *param*, checks if it should be placed with a default value - when empty, and in this case, does the actual default value resolution. - - This resolution is triggered only in case *param* refers to :py:attr:`RESOLVE_DEFAULT`, a - 1-tuple containing this attribute, or *None*, If so, the default is identified via the - *default_str* from an :py:class:`order.AuxDataMixin` *container* and points to an auxiliary - that can be either a string or a function. In the latter case, it is called with the task - class, the container instance, and all task parameters. Note that when no *container* is - given, *param* is returned unchanged. - - When *multiple* is *True*, a tuple is returned. If *multiple* is *False* and the resolved - parameter is an iterable, the first entry is returned. + Resolves a given parameter value *param*, checks if it should be placed with a default value when empty, and in + this case, does the actual default value resolution. + + This resolution is triggered only in case *param* refers to :py:attr:`RESOLVE_DEFAULT`, a 1-tuple containing + this attribute, or *None*. If so, the default is identified via the *default_str* from an + :py:class:`order.AuxDataMixin` *container* and points to an auxiliary that can be either a string or a function. + In the latter case, it is called with the task class, the container instance, and all task parameters. Note that + when no *container* is given, *param* is returned unchanged. + + *container* can also refer to a sequence of :py:class:`order.AuxDataMixin` objects. If this is the case, the + default resolution is performed for all of them and the resulting values can be handled with five different + strategies, controlled via *multi_strategy*: + + - ``"first"``: The first resolved value is returned. + - ``"same"``: The resolved values are forced to be identical and an exception is raised if they differ. The + first resolved value is returned. + - ``"union"``: The set union of all resolved values is returned in a list. + - ``"intersection"``: The set intersection of all resolved values is returned in a list. + - ``"all"``: The resolved values are returned in a dictionary mapped to their respective container. Example: .. code-block:: python - def resolve_param_values(params): - params["producer"] = AnalysisTask.resolve_config_default( - params, - params.get("producer"), - container=params["config_inst"] - default_str="default_producer", - multiple=True, - ) - + # assuming this is your config config_inst = od.Config( - id=0, + id=1, name="my_config", - aux={"default_producer": ["my_producer_1", "my_producer_2"]}, + aux={ + "default_selector": "my_selector", + }, ) + # and these are the task parameters params = { "config_inst": config_inst, - "producer": RESOLVE_DEFAULT, } - resolve_param_values(params) # sets params["producer"] to ("my_producer_1", "my_producer_2") - params = { - "config_inst": config_inst, - "producer": "some_other_producer", - } - resolve_param_values(params) # sets params["producer"] to "some_other_producer" + AnalysisTask.resolve_config_default( + param=RESOLVE_DEFAULT, + task_params=params, + container=config_inst, # <-- same as passing the "config_inst" key of params + default_str="default_selector", + ) + # -> "my_selector" Example where the default points to a function: .. code-block:: python - def resolve_param_values(params): - params["ml_model"] = AnalysisTask.resolve_config_default( - params, - params.get("ml_model"), - container=params["config_inst"] - default_str="default_ml_model", - multiple=True, - ) - - # a function that chooses the ml_model based on an attibute that is set in an inference_model - def default_ml_model(task_cls, container, task_params): - default_ml_model = None - - # check if task is using an inference model - if "inference_model" in task_params.keys(): - inference_model = task_params.get("inference_model", None) - - # if inference model is not set, assume it's the container default - if inference_model in (None, "NO_STR"): - inference_model = container.x.default_inference_model - - # get the default_ml_model from the inference_model_inst - inference_model_inst = columnflow.inference.InferenceModel._subclasses[inference_model] - default_ml_model = getattr(inference_model_inst, "ml_model_name", default_ml_model) - - return default_ml_model - - return default_ml_model + def default_selector(task_cls, config_inst, task_params) -> str: + # determine the selector based on dynamic conditions + return "my_other_selector config_inst = od.Config( - id=0, + id=1, name="my_config", - aux={"default_ml_model": default_ml_model}, + aux={ + "default_selector": default_selector, # <-- function + }, ) - @inference_model(ml_model_name="default_ml_model") - def my_inference_model(self): - # some inference model implementation - ... - - params = {"config_inst": config_inst, "ml_model": None, "inference_model": "my_inference_model"} - resolve_param_values(params) # sets params["ml_model"] to "my_ml_model" - - params = {"config_inst": config_inst, "ml_model": "some_ml_model", "inference_model": "my_inference_model"} - resolve_param_values(params) # sets params["ml_model"] to "some_ml_model" + AnalysisTask.resolve_config_default( + param=RESOLVE_DEFAULT, + task_params=params, + container=config_inst, + default_str="default_selector", + ) + # -> "my_other_selector" """ + if multi_strategy not in (strategies := {"first", "same", "union", "intersection", "all"}): + raise ValueError(f"invalid multi_strategy: {multi_strategy}, must be one of {','.join(strategies)}") + # check if the parameter value is to be resolved resolve_default = param in (None, RESOLVE_DEFAULT, (RESOLVE_DEFAULT,)) + return_single_value = True if param is None or isinstance(param, str) else False + # interpret missing parameters (e.g. NO_STR) as None # (special case: an empty string is usually an active decision, but counts as missing too) - if law.is_no_param(param) or resolve_default or param == "" or param == (): + if law.is_no_param(param) or resolve_default or param == "": param = None + # get the container inst (typically a config_inst or analysis_inst) + if isinstance(container, str): + container = task_params.get(container) + if not container: + return param + # actual resolution + params: dict[od.AuxDataMixin, Any] if resolve_default: - # get the container inst (mostly a config_inst or analysis_inst) - if isinstance(container, str): - container = task_params.get(container) - - # expand default when container is set - if container and default_str: - param = container.x(default_str, None) if default_str else None - - # allow default to be a function, taking task parameters as input - if isinstance(param, Callable): - param = param(cls, container, task_params) - - # when still empty, return an empty value - if param is None: - return () if multiple else None + params = {} + for _container in law.util.make_list(container): + _param = param + # expand default when container is set + if _container and default_str: + _param = _container.x(default_str, None) + # allow default to be a function, taking task parameters as input + if isinstance(_param, Callable): + _param = _param(cls, _container, task_params) + # handle empty values and return type + if not return_single_value: + _param = () if _param is None else law.util.make_tuple(_param) + elif isinstance(_param, (list, tuple)): + _param = _param[0] if _param else None + + params[_container] = _param + else: + params = {_container: param for _container in law.util.make_list(container)} - # return either a tuple or the first param, based on the *multiple* - param = law.util.make_tuple(param) - return param if multiple else (param[0] if param else None) + # handle values + if not isinstance(container, (list, tuple)): + return params[container] + if multi_strategy == "all": + return params + if multi_strategy == "first": + return params[container[0]] + # NOTE: in there two strategies, we loose all order information + if multi_strategy == "union": + return list(set.union(*map(set, params.values()))) + if multi_strategy == "intersection": + return list(set.intersection(*map(set, params.values()))) + # "same", so check that values are identical + first = params[container[0]] + if not all(params[c] == first for c in container[1:]): + default_str_repr = f" for '{default_str}'" if default_str else "" + raise ValueError(f"multiple default values found{default_str_repr} in {container}: {params}") + return first @classmethod def resolve_config_default_and_groups( cls, + *, + param: Any, task_params: dict[str, Any], - param: str | tuple[str] | None, - container: str | od.AuxDataMixin = "config_inst", + container: str | od.AuxDataMixin | Sequence[od.AuxDataMixin], + groups_str: str, default_str: str | None = None, - groups_str: str | None = None, - ) -> tuple[str]: + multi_strategy: str = "first", + debug=False, + ) -> Any | list[Any] | dict[od.AuxDataMixin, Any]: """ - This method is similar to :py:meth:`~.resolve_config_default` in that it checks if a - parameter value *param* is empty and should be replaced with a default value. See the - referenced method for documentation on *task_params*, *param*, *container* and - *default_str*. + This method is similar to :py:meth:`~.resolve_config_default` in that it checks if a parameter value *param* is + empty and should be replaced with a default value. All arguments except for *groups_str* are forwarded to this + method. - What this method does in addition is that it checks if the values contained in *param* - (after default value resolution) refers to a group of values identified via the *groups_str* - from the :py:class:`order.AuxDataMixin` *container* that maps a string to a tuple of - strings. If it does, each value in *param* that refers to a group is expanded by the actual - group values. + What this method does in addition is that it checks if the values contained in *param* (after default value + resolution) refers to a group of values identified via the *groups_str* from the :py:class:`order.AuxDataMixin` + *container* that maps a string to a tuple of strings. If it does, each value in *param* that refers to a group + is expanded by the actual group values. Example: .. code-block:: python + # assuming this is your config config_inst = od.Config( - id=0, + id=1, name="my_config", aux={ - "default_producer": ["features_1", "my_producer_group"], - "producer_groups": {"my_producer_group": ["features_2", "features_3"]}, + "default_producer": "my_producers", + "producer_groups": { + "my_producers": ["producer_1", "producer_2"], + "my_other_producers": ["my_producers", "producer_3", "producer_4"], + }, }, ) - params = {"producer": RESOLVE_DEFAULT} + # and these are the task parameters + params = { + "config_inst": config_inst, + } AnalysisTask.resolve_config_default_and_groups( - params, - params.get("producer"), - container=config_inst, + param=RESOLVE_DEFAULT, + task_params=params, + container=config_inst, # <-- same as passing the "config_inst" key of params default_str="default_producer", groups_str="producer_groups", ) - # -> ("features_1", "features_2", "features_3") + # -> ["producer_1", "producer_2"] + + Example showing recursive group expansion: + + .. code-block:: python + + # assuming config_inst and params are the same as above + + AnalysisTask.resolve_config_default_and_groups( + param="my_other_producers", # <-- points to a group that contains another group + task_params=params, + container=config_inst, + default_str="default_producer", # <-- not used as param is set explicitly + groups_str="producer_groups", + ) + # -> ["producer_1", "producer_2", "producer_3", "producer_4"] """ - # resolve the parameter - param = cls.resolve_config_default( - task_params=task_params, - param=param, - container=container, - default_str=default_str, - multiple=True, - ) - if not param: - return param + if multi_strategy not in (strategies := {"first", "same", "union", "intersection", "all"}): + raise ValueError(f"invalid multi_strategy: {multi_strategy}, must be one of {','.join(strategies)}") - # get the container inst and return if it's not set + # get the container if isinstance(container, str): container = task_params.get(container, None) - if not container: return param + containers = law.util.make_list(container) + + # resolve the parameter + params: dict[od.AuxDataMixin, Any] = cls.resolve_config_default( + param=param, + task_params=task_params, + container=containers, + default_str=default_str, + multi_strategy="all", + ) + if not params: + return param # expand groups recursively - if groups_str and container.x(groups_str, {}): - param_groups = container.x(groups_str) - values = [] - lookup = law.util.make_list(param) + values = {} + for _container, _param in params.items(): + if not (param_groups := _container.x(groups_str, {})): + values[_container] = law.util.make_tuple(_param) + continue + lookup = collections.deque(law.util.make_list(_param)) handled_groups = set() + _values = [] while lookup: - value = lookup.pop(0) + value = lookup.popleft() if value in param_groups: if value in handled_groups: raise Exception( - f"definition of '{groups_str}' contains circular references involving " - f"group '{value}'", + f"definition of '{groups_str}' contains circular references involving group '{value}'", ) - lookup = law.util.make_list(param_groups[value]) + lookup + lookup.extendleft(law.util.make_list(param_groups[value])) handled_groups.add(value) else: - values.append(value) - param = values + _values.append(value) + values[_container] = tuple(_values) + + # handle values + if not isinstance(container, (list, tuple)): + return values[container] + if multi_strategy == "all": + return values + if multi_strategy == "first": + return values[container[0]] + if multi_strategy == "union": + return list(set.union(*map(set, values.values()))) + if multi_strategy == "intersection": + return list(set.intersection(*map(set, values.values()))) + # "same", so check that values are identical + first = values[container[0]] + if not all(values[c] == first for c in container[1:]): + default_str_repr = f" for '{default_str}'" if default_str else "" + raise ValueError( + f"multiple default values found{default_str_repr} after expanding groups '{groups_str}' in " + f"{containers}: {values}", + ) + return first - return law.util.make_tuple(param) + @classmethod + def build_repr( + cls, + objects: Any | Sequence[Any], + *, + sep: str = "__", + prepend_count: bool = False, + max_len: int = default_repr_max_len, + max_count: int = default_repr_max_count, + hash_len: int = default_repr_hash_len, + ) -> str: + """ + Generic method to construct a string representation given a single or a sequece of *objects*. + + :param objects: The object or objects to be represented. + :param sep: The separator used to join the objects. + :param prepend_count: When *True*, the number of objects is prepended to the string, followed by *sep*. + :param max_len: The maximum length of the string. If exceeded, the string is truncated and hashed. + :param max_count: The maximum number of objects to include in the string. Additional objects are hashed, but + only if the resulting representation length does not exceed *max_len*. If so, the overall truncation and + hashing is applied instead. + :param hash_len: The length of the hash that is appended to the string when it is truncated. + :return: The string representation. + """ + if 0 < max_len < hash_len: + raise ValueError(f"max_len must be greater than hash_len: {max_len} <= {hash_len}") + + # join objects when a sequence is given + if isinstance(objects, (list, tuple)): + r = f"{len(objects)}{sep}" if prepend_count else "" + # truncate when requested and the expected length will not exceed max_len, in which case the overall + # truncation applies the hashing + if ( + 0 < max_count < len(objects) and + not (0 < max_len < (len(r) + sum(map(len, objects[:max_count])) + len(sep) * max_count + hash_len)) + ): + r += sep.join(objects[:max_count]) + r += f"{sep}{law.util.create_hash(objects[max_count:], l=hash_len)}" + else: + r += sep.join(objects) + else: + r = str(objects) - def __init__(self, *args, **kwargs): + # handle overall truncation + if max_len > 0 and len(r) > max_len: + r = f"{r[:max_len - hash_len - len(sep)]}{sep}{law.util.create_hash(r, l=hash_len)}" + + return r + + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # store the analysis instance @@ -738,9 +867,8 @@ def __init__(self, *args, **kwargs): def cached_value(self, key: str, func: Callable[[], T]) -> T: """ - Upon first invocation, the function *func* is called and its return value is stored under - *key* in :py:attr:`_cached_values`. Subsequent calls with the same *key* return the cached - value. + Upon first invocation, the function *func* is called and its return value is stored under *key* in + :py:attr:`_cached_values`. Subsequent calls with the same *key* return the cached value. :param key: The key under which the value is stored. :param func: The function that is called to generate the value. @@ -750,13 +878,26 @@ def cached_value(self, key: str, func: Callable[[], T]) -> T: self._cached_values[key] = func() return self._cached_values[key] + def reset_sandbox(self, sandbox: str) -> None: + """ + Resets the sandbox to a new *sandbox* value. + """ + # do nothing if the value actualy does not change + if self.sandbox == sandbox: + return + + # change it and rebuild the sandbox inst when already initialized + self.sandbox = sandbox + if self._sandbox_initialized: + self._initialize_sandbox(force=True) + def store_parts(self) -> law.util.InsertableDict: """ - Returns a :py:class:`law.util.InsertableDict` whose values are used to create a store path. - For instance, the parts ``{"keyA": "a", "keyB": "b", 2: "c"}`` lead to the path "a/b/c". The - keys can be used by subclassing tasks to overwrite values. + Returns a :py:class:`law.util.InsertableDict` whose values are used to create a store path. For instance, the + parts ``{"keyA": "a", "keyB": "b", 2: "c"}`` lead to the path "a/b/c". The keys can be used by subclassing tasks + to overwrite values. - :return: Dictionary with parts to create a path to store intermediary results. + :return: Dictionary with parts that will be translated into an output directory path. """ parts = law.util.InsertableDict() @@ -789,11 +930,10 @@ def local_path( **kwargs, ) -> str: """ local_path(*path, store=None, fs=None, store_parts_modifier=None) - Joins path fragments from *store* (defaulting to :py:attr:`default_store`), - :py:meth:`store_parts` and *path* and returns the joined path. In case a *fs* is defined, - it should refer to the config section of a local file system, and consequently, *store* is - not prepended to the returned path as the resolution of absolute paths is handled by that - file system. + Joins path fragments from *store* (defaulting to :py:attr:`default_store`), :py:meth:`store_parts` and *path* + and returns the joined path. In case a *fs* is defined, it should refer to the config section of a local file + system, and consequently, *store* is not prepended to the returned path as the resolution of absolute paths is + handled by that file system. """ # if no fs is set, determine the main store directory parts = () @@ -818,11 +958,10 @@ def local_target( *path, store_parts_modifier: str | Callable[[AnalysisTask, dict], dict] | None = None, **kwargs, - ): + ) -> law.LocalTarget: """ local_target(*path, dir=False, store=None, fs=None, store_parts_modifier=None, **kwargs) - Creates either a local file or directory target, depending on *dir*, forwarding all *path* - fragments, *store* and *fs* to :py:meth:`local_path` and all *kwargs* the respective target - class. + Creates either a local file or directory target, depending on *dir*, forwarding all *path* fragments, *store* + and *fs* to :py:meth:`local_path` and all *kwargs* the respective target class. """ _dir = kwargs.pop("dir", False) store = kwargs.pop("store", None) @@ -845,8 +984,8 @@ def wlcg_path( """ Joins path fragments from *store_parts()* and *path* and returns the joined path. - The full URI to the target is not considered as it is usually defined in ``[wlcg_fs]`` - sections in the law config and hence subject to :py:func:`wlcg_target`. + The full URI to the target is not considered as it is usually defined in ``[wlcg_fs]`` sections in the law + config and hence subject to :py:func:`wlcg_target`. """ # get and optional modify the store parts store_parts = self.store_parts() @@ -865,11 +1004,11 @@ def wlcg_target( *path, store_parts_modifier: str | Callable[[AnalysisTask, dict], dict] | None = None, **kwargs, - ): + ) -> law.wclg.WLCGTarget: """ wlcg_target(*path, dir=False, fs=default_wlcg_fs, store_parts_modifier=None, **kwargs) - Creates either a remote WLCG file or directory target, depending on *dir*, forwarding all - *path* fragments to :py:meth:`wlcg_path` and all *kwargs* the respective target class. When - *None*, *fs* defaults to the *default_wlcg_fs* class level attribute. + Creates either a remote WLCG file or directory target, depending on *dir*, forwarding all *path* fragments to + :py:meth:`wlcg_path` and all *kwargs* the respective target class. When *None*, *fs* defaults to the + *default_wlcg_fs* class level attribute. """ _dir = kwargs.pop("dir", False) if not kwargs.get("fs"): @@ -884,7 +1023,7 @@ def wlcg_target( # create the target instance and return it return cls(path, **kwargs) - def target(self, *path, **kwargs): + def target(self, *path, **kwargs) -> law.LocalTarget | law.wlcg.WLCGTarget | law.MirroredTarget: """ target(*path, location=None, **kwargs) """ # get the default location @@ -956,15 +1095,14 @@ def target(self, *path, **kwargs): def get_parquet_writer_opts(self, repeating_values: bool = False) -> dict[str, Any]: """ - Returns an option dictionary that can be passed as *writer_opts* to - :py:meth:`~law.pyarrow.merge_parquet_task`, for instance, at the end of chunked processing - steps that produce a single parquet file. See :py:class:`~pyarrow.parquet.ParquetWriter` for - valid options. + Returns an option dictionary that can be passed as *writer_opts* to :py:meth:`~law.pyarrow.merge_parquet_task`, + for instance, at the end of chunked processing steps that produce a single parquet file. See + :py:class:`~pyarrow.parquet.ParquetWriter` for valid options. This method can be overwritten in subclasses to customize the exact behavior. - :param repeating_values: Whether the values to be written have predominantly repeating - values, in which case differnt compression and encoding strategies are followed. + :param repeating_values: Whether the values to be written have predominantly repeating values, in which case + differnt compression and encoding strategies are followed. :return: A dictionary with options that can be passed to parquet writer objects. """ # use dict encoding if values are repeating @@ -989,17 +1127,299 @@ class ConfigTask(AnalysisTask): default=default_config, description=f"name of the analysis config to use; default: '{default_config}'", ) + configs = law.CSVParameter( + default=(default_config,), + description=f"comma-separated names of analysis configs to use; default: '{default_config}'", + brace_expand=True, + ) + known_shifts = luigi.Parameter( + default=None, + visibility=luigi.parameter.ParameterVisibility.PRIVATE, + ) + + exclude_params_req = {"known_shifts"} + exclude_params_sandbox = {"known_shifts"} + exclude_params_remote_workflow = {"known_shifts"} + exclude_params_index = {"known_shifts"} + exclude_params_repr = {"known_shifts"} + + # the field in the store parts behind which the new part is inserted + # added here for subclasses that typically refer to the store part added by _this_ class + config_store_anchor = "config" @classmethod - def resolve_param_values(cls, params: dict) -> dict: + def modify_task_attributes(cls) -> None: + """ + Hook that is called by law's task register meta class right after subclass creation to update class-level + attributes. + """ + super().modify_task_attributes() + + # single/multi config adjustments in case the switch has been specified + if isinstance(cls.single_config, bool): + remove_attr = "configs" if cls.has_single_config() else "config" + if getattr(cls, remove_attr, law.no_value) != law.no_value: + setattr(cls, remove_attr, None) + + @abc.abstractproperty + def single_config(cls) -> bool: + # flag that should be set to True or False by classes that should be instantiated + # (this is wrapped into an abstract instance property as a safe-guard against instantiation of a misconfigured + # subclass, but when actually specified, this is to be realized as a boolean class attribute or property) + ... + + @classmethod + def has_single_config(cls) -> bool: + """ + Returns whether the class is configured to use a single config. + + :raises AttributeError: When the class does not specify the *single_config* attribute. + :return: *True* if the class uses a single config, *False* otherwise. + """ + single_config = cls.single_config + if not isinstance(single_config, bool): + raise AttributeError(f"unspecified 'single_config' attribute in {cls}: {single_config}") + return single_config + + @classmethod + def ensure_single_config(cls, value: bool, *, attr: str | None = None) -> None: + """ + Ensures that the :py:attr:`single_config` flag of this task is set to *value* by raising an exception if it is + not. This method is typically used to guard the access to attributes. If so, *attr* is used in the exception + message to reflect this. + + :param value: The value to compare the flag with. + :param attr: The attribute that triggered the check. + """ + single_config = cls.has_single_config() + if single_config != value: + if attr: + s = "multiple configs" if single_config else "a single config" + msg = f"cannot access attribute '{attr}' when task '{cls}' has {s}" + else: + s = "multiple configs" if value else "a single config" + msg = f"task '{cls}' expected to use {s}" + raise Exception(msg) + + @classmethod + def config_mode(cls) -> str: + """ + Returns a string representation of this task's config mode. + + :return: "single" if the task has a single config, "multi" otherwise. + """ + return "single" if cls.has_single_config() else "multi" + + @classmethod + def _get_config_container(cls, params: dict[str, Any]) -> od.Config | list[od.Config] | None: + """ + Extracts the single or multiple config instances from task parameters *params*, or *None* if neither is found. + + :param params: Dictionary of task parameters. + :return: The config instance(s) or *None*. + """ + if cls.has_single_config(): + if (config_inst := params.get("config_inst")): + return config_inst + elif (config_insts := params.get("config_insts")): + return config_insts + return None + + @classmethod + def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params = super().resolve_param_values(params) - # store a reference to the config inst - if "config_inst" not in params and "analysis_inst" in params and "config" in params: - params["config_inst"] = params["analysis_inst"].get_config(params["config"]) + if (analysis_inst := params.get("analysis_inst")): + # store a reference to the config inst(s) + if cls.has_single_config(): + if "config_inst" not in params and "config" in params: + params["config_inst"] = analysis_inst.get_config(params["config"]) + params["config_insts"] = [params["config_inst"]] + else: + if "config_insts" not in params and "configs" in params: + params["config_insts"] = list(map(analysis_inst.get_config, params["configs"])) + + # resolving of parameters that is required before ArrayFunctions etc. can be initialized + params = cls.resolve_param_values_pre_init(params) + + # check if shifts are already known + if params.get("known_shifts", None) is None: + logger_dev.debug(f"{cls.task_family}: shifts unknown") + + # initialize ArrayFunctions etc. and collect known shifts + shifts = params["known_shifts"] = TaskShifts() + params = cls.resolve_instances(params, shifts) + params["known_shifts"] = shifts + + # resolving of parameters that can only be performed after ArrayFunction initialization + params = cls.resolve_param_values_post_init(params) + + # resolving of shifts + params = cls.resolve_shifts(params) + + return params + + @classmethod + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + """ + Build the array function instances. + For single-config/dataset tasks, resolve_instances is implemented by mixin classes such as the ProducersMixin. + For multi-config tasks, resolve_instances from the upstream task is called for each config instance. If the + resolve_instances function needs to be called for other combinations of parameters (e.g. per dataset), it can be + overwritten by the task class. + + :param params: Dictionary of task parameters. + :param shifts: Collection of local and global shifts. + :return: Updated dictionary of task parameters. + """ + cls.get_known_shifts(params, shifts) + + if not cls.resolution_task_cls: + params["known_shifts"] = shifts + return params + + logger_dev.debug( + f"{cls.task_family}: uses ConfigTask.resolve_instances base implementation; " + f"upsteam_task_cls was defined as {cls.resolution_task_cls}; ", + ) + # base implementation for ConfigTasks that do not define any datasets. + # Needed for e.g. MergeShiftedHistograms + if cls.has_single_config(): + _params = params.copy() + _params = cls.resolution_task_cls.resolve_instances(params, shifts) + cls.resolution_task_cls.get_known_shifts(_params, shifts) + else: + for config_inst in params["config_insts"]: + _params = { + **params, + "config_inst": config_inst, + "config": config_inst.name, + } + _params = cls.resolution_task_cls.resolve_instances(_params, shifts) + cls.resolution_task_cls.get_known_shifts(_params, shifts) + + params["known_shifts"] = shifts + + return params + + @classmethod + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + """ + Resolve parameters before the array function instances have been initialized. + + :param params: Dictionary of task parameters. + :return: Updated dictionary of task parameters. + """ + return params + + @classmethod + def resolve_param_values_post_init(cls, params: dict[str, Any]) -> dict[str, Any]: + """ + Resolve parameters after the array function instances have been initialized. + + :param params: Dictionary of task parameters. + :return: Updated dictionary of task parameters. + """ + return params + + @classmethod + def resolve_shifts(cls, params: dict[str, Any]) -> dict[str, Any]: + """ + Resolve shifts + + :param params: Dictionary of task parameters. + :return: Updated dictionary of task parameters. + """ + # called within modify_param_values to resolve shifts after all other parameters have been resolved + return params + + @classmethod + def get_known_shifts( + cls, + params: dict[str, Any], + shifts: TaskShifts, + ) -> None: + """ + Adjusts the local and upstream fields of the *shifts* object to include shifts implemented + by _this_ task, and dependent shifts that are implemented by upstream tasks. + + :param params: Dictionary of task parameters. + :param shifts: TaskShifts object to adjust. + """ + return params + + resolution_task_cls = None + + @classmethod + def req_params(cls, inst: law.Task, *args, **kwargs) -> dict[str, Any]: + params = super().req_params(inst, *args, **kwargs) + + # manually add known shifts between workflows and branches + if isinstance(inst, law.BaseWorkflow) and inst.__class__ == cls and getattr(inst, "known_shifts", None): + params["known_shifts"] = inst.known_shifts return params + @classmethod + def _multi_sequence_repr( + cls, + values: Sequence[str] | Sequence[Sequence[str]], + sort: bool = False, + ) -> str: + """ + Returns a string representation of a singly (for single config) or doubly (for multi config) nested sequence of + string *values*. In the former case, the values are sorted if *sort* is *True* and formed into a representation. + The behavior of the latter case depends on whether values are identical between configs. If they are, handle + them as a single sequence. Otherwise, the representation consists of the number of values per config and a hash + of the combined, flat values. + + :param values: Nested values. + :param sort: Whether to sort the values. + :return: A string representation. + """ + # empty case + if not values: + return "none" + + # optional sorting helper + maybe_sort = (lambda vals: sorted(vals)) if sort else (lambda vals: vals) + + # helper to perform the single representation, assuming already sorted values + def single_repr(values: Sequence[str]) -> str: + if not values: + return None + if len(values) == 1: + return values[0] + return f"{len(values)}_{law.util.create_hash(values)}" + + # single case + if not isinstance(values[0], (list, tuple)): + return single_repr(maybe_sort(values)) + # multi case with a single sequence + if len(values) == 1: + return single_repr(maybe_sort(values[0])) + # multi case with identical sequences + values = [maybe_sort(_values) for _values in values] + if all(_values == values[0] for _values in values[1:]): + return single_repr(values[0]) + # build full representation + _repr = "_".join(map(str, map(len, values))) + all_values = sum(values, []) + return _repr + f"_{law.util.create_hash(all_values)}" + + @classmethod + def broadcast_to_configs(cls, value: Any, name: str, n_config_insts: int) -> tuple[Any]: + if not isinstance(value, tuple) or not value: + value = (value,) + if len(value) == 1: + value *= n_config_insts + elif len(value) != n_config_insts: + raise ValueError( + f"number of {name} sequences ({len(value)}) does not match number of configs " + f"({n_config_insts})", + ) + return value + @classmethod def _get_default_version( cls, @@ -1008,9 +1428,8 @@ def _get_default_version( keys: law.util.InsertableDict, ) -> str | None: # try to lookup the version in the config's auxiliary data - config_inst = params.get("config_inst") or getattr(inst, "config_inst", None) - if config_inst: - version = cls._dfs_key_lookup(keys, config_inst.x("versions", {})) + if isinstance(inst, ConfigTask) and inst.has_single_config(): + version = cls._dfs_key_lookup(keys, inst.config_inst.x("versions", {})) if version: return version @@ -1023,47 +1442,54 @@ def get_config_lookup_keys( ) -> law.util.InsertiableDict: keys = super().get_config_lookup_keys(inst_or_params) - get = ( - inst_or_params.get + # add the config name in front of the task family + config = ( + inst_or_params.get("config") if isinstance(inst_or_params, dict) - else lambda attr: (getattr(inst_or_params, attr, None)) + else getattr(inst_or_params, "config", None) ) - - # add the config name in front of the task family - config = get("config") if config not in {law.NO_STR, None, ""}: keys.insert_before("task_family", "config", config) return keys @classmethod - def get_array_function_kwargs(cls, task=None, **params): - kwargs = super().get_array_function_kwargs(task=task, **params) + def get_array_function_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + cls.ensure_single_config(True, attr="get_array_function_dict") + + kwargs = super().get_array_function_dict(params) - if task: - kwargs["config_inst"] = task.config_inst - elif "config_inst" in params: + if "config_inst" in params: kwargs["config_inst"] = params["config_inst"] elif "config" in params and "analysis_inst" in kwargs: kwargs["config_inst"] = kwargs["analysis_inst"].get_config(params["config"]) return kwargs - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - # store a reference to the config instance - self.config_inst = self.analysis_inst.get_config(self.config) + # store a reference to the config instances + self.config_insts = [ + self.analysis_inst.get_config(config) + for config in ([self.config] if self.has_single_config() else self.configs) + ] + if self.has_single_config(): + self.config_inst = self.config_insts[0] - def store_parts(self): + @property + def config_repr(self) -> str: + return "__".join(config_inst.name for config_inst in self.config_insts) + + def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() # add the config name - parts.insert_after("task_family", "config", self.config_inst.name) + parts.insert_after("task_family", "config", self.config_repr) return parts - def find_keep_columns(self: ConfigTask, collection: ColumnCollection) -> set[Route]: + def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: """ Returns a set of :py:class:`Route` objects describing columns that should be kept given a type of column *collection*. @@ -1079,15 +1505,15 @@ def find_keep_columns(self: ConfigTask, collection: ColumnCollection) -> set[Rou return columns def _expand_keep_column( - self: ConfigTask, + self, column: ColumnCollection | Route | str | Sequence[str | int | slice | type(Ellipsis) | list | tuple], ) -> set[Route]: """ - Expands a *column* into a set of :py:class:`Route` objects. *column* can be a - :py:class:`ColumnCollection`, a string, or any type that is accepted by :py:class:`Route`. - Collections are expanded through :py:meth:`find_keep_columns`. + Expands a *column* into a set of :py:class:`Route` objects. *column* can be a :py:class:`ColumnCollection`, a + string, or any type that is accepted by :py:class:`Route`. Collections are expanded through + :py:meth:`find_keep_columns`. :param column: The column to expand. :return: A set of :py:class:`Route` objects. @@ -1122,83 +1548,72 @@ class ShiftTask(ConfigTask): allow_empty_shift = False @classmethod - def modify_param_values(cls, params): - """ - When "config" and "shift" are set, this method evaluates them to set the global shift. - For that, it takes the shifts stored in the config instance and compares it with those - defined by this class. - """ - params = super().modify_param_values(params) + def resolve_shifts(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_shifts(params) - # get params - config_inst = params.get("config_inst") - requested_shift = params.get("shift") - requested_local_shift = params.get("local_shift") + if "known_shifts" not in params: + raise Exception(f"{cls.task_family}: known shifts should be resolved before calling 'resolve_shifts'") + known_shifts = params["known_shifts"] - # require that the config is set - if config_inst in (None, law.NO_STR): - return params + # get configs + config_insts = params.get("config_insts") # require that the shift is set and known - if requested_shift in (None, law.NO_STR): - if cls.allow_empty_shift: - params["shift"] = law.NO_STR - params["local_shift"] = law.NO_STR - return params - raise Exception(f"no shift found in params: {params}") - if requested_shift not in config_inst.shifts: - raise ValueError(f"shift {requested_shift} unknown to {config_inst}") - - # determine the known shifts for this class - shifts, upstream_shifts = cls.get_known_shifts(config_inst, params) - - # actual shift resolution: compare the requested shift to known ones - # local_shift -> the requested shift if implemented by the task itself, else nominal - # shift -> the requested shift if implemented by this task - # or an upsteam task (== global shift), else nominal - if requested_local_shift in (None, law.NO_STR): - if requested_shift in shifts: - params["shift"] = requested_shift - params["local_shift"] = requested_shift - elif requested_shift in upstream_shifts: - params["shift"] = requested_shift - params["local_shift"] = "nominal" - else: - params["shift"] = "nominal" - params["local_shift"] = "nominal" - - # store references - params["global_shift_inst"] = config_inst.get_shift(params["shift"]) - params["local_shift_inst"] = config_inst.get_shift(params["local_shift"]) + if (requested_shift := params.get("shift")) in (None, law.NO_STR): + if not cls.allow_empty_shift: + raise Exception(f"no shift found in params: {params}") + global_shift = local_shift = law.NO_STR + else: + # check if the shift is known to one of the configs + shift_defined_in_config = False + for config_inst in config_insts: + if requested_shift not in config_inst.shifts: + logger_dev.debug(f"shift {requested_shift} unknown to config {config_inst}") + else: + shift_defined_in_config = True + if not shift_defined_in_config: + raise ValueError(f"shift {requested_shift} unknown to all configs") + + # actual shift resolution: compare the requested shift to known ones + # local_shift -> the requested shift if implemented by the task itself, else nominal + # shift -> the requested shift if implemented by this task + # or an upsteam task (== global shift), else nominal + global_shift = requested_shift + if (local_shift := params.get("local_shift")) in {None, law.NO_STR}: + # check cases + if requested_shift in known_shifts.local: + local_shift = requested_shift + elif requested_shift in known_shifts.upstream: + local_shift = "nominal" + else: + global_shift = "nominal" + local_shift = "nominal" - return params + # store parameters + params["shift"] = global_shift + params["local_shift"] = local_shift - @classmethod - def resolve_param_values(cls, params: dict) -> dict: - params = super().resolve_param_values(params) + # store references to shift instances + if ( + params["shift"] != law.NO_STR and + params["local_shift"] != law.NO_STR and + (not params.get("global_shift_insts") or not params.get("local_shift_insts")) + ): + params["global_shift_insts"] = {} + params["local_shift_insts"] = {} - # set default shift - if params.get("shift") in (None, law.NO_STR): - params["shift"] = "nominal" + get_shift_or_nominal = lambda config, shift: config.get_shift(shift, default=config.get_shift("nominal")) - return params + for config_inst in config_insts: + params["global_shift_insts"][config_inst] = get_shift_or_nominal(config_inst, params["shift"]) + params["local_shift_insts"][config_inst] = get_shift_or_nominal(config_inst, params["local_shift"]) - @classmethod - def get_array_function_kwargs(cls, task=None, **params): - kwargs = super().get_array_function_kwargs(task=task, **params) - - if task: - if task.local_shift_inst: - kwargs["local_shift_inst"] = task.local_shift_inst - if task.global_shift_inst: - kwargs["global_shift_inst"] = task.global_shift_inst - else: - if "local_shift_inst" in params: - kwargs["local_shift_inst"] = params["local_shift_inst"] - if "global_shift_inst" in params: - kwargs["global_shift_inst"] = params["global_shift_inst"] + if cls.has_single_config(): + config_inst = params["config_inst"] + params["global_shift_inst"] = params["global_shift_insts"][config_inst] + params["local_shift_inst"] = params["local_shift_insts"][config_inst] - return kwargs + return params @classmethod def get_config_lookup_keys( @@ -1207,41 +1622,55 @@ def get_config_lookup_keys( ) -> law.util.InsertiableDict: keys = super().get_config_lookup_keys(inst_or_params) - get = ( - inst_or_params.get + # add the (global) shift name + shift = ( + inst_or_params.get("shift") if isinstance(inst_or_params, dict) - else lambda attr: (getattr(inst_or_params, attr, None)) + else getattr(inst_or_params, "shift", None) ) - - # add the (global) shift name - shift = get("shift") - if shift not in {law.NO_STR, None, ""}: + if shift not in (law.NO_STR, None, ""): keys["shift"] = shift return keys - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # store references to the shift instances - self.local_shift_inst = None - self.global_shift_inst = None + self.global_shift_insts = None + self.local_shift_insts = None if self.shift not in (None, law.NO_STR) and self.local_shift not in (None, law.NO_STR): - self.global_shift_inst = self.config_inst.get_shift(self.shift) - self.local_shift_inst = self.config_inst.get_shift(self.local_shift) + get = lambda c, s: c.get_shift(s if s in c.shifts else "nominal") + self.global_shift_insts = { + config_inst: get(config_inst, self.shift) + for config_inst in self.config_insts + } + self.local_shift_insts = { + config_inst: get(config_inst, self.local_shift) + for config_inst in self.config_insts + } + if self.has_single_config(): + self.global_shift_inst = None + self.local_shift_inst = None + if self.global_shift_insts: + self.global_shift_inst = self.global_shift_insts[self.config_inst] + self.local_shift_inst = self.local_shift_insts[self.config_inst] - def store_parts(self): + def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() # add the shift name - if self.global_shift_inst: - parts.insert_after("config", "shift", self.global_shift_inst.name) + if self.global_shift_insts: + parts.insert_after(self.config_store_anchor, "shift", self.shift) return parts class DatasetTask(ShiftTask): + # all dataset tasks are meant to work for a single config + single_config = True + dataset = luigi.Parameter( default=default_dataset, description=f"name of the dataset to process; default: '{default_dataset}'", @@ -1250,8 +1679,8 @@ class DatasetTask(ShiftTask): file_merging = None @classmethod - def resolve_param_values(cls, params): - params = super().resolve_param_values(params) + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_pre_init(params) # store a reference to the dataset inst if "dataset_inst" not in params and "config_inst" in params and "dataset" in params: @@ -1260,21 +1689,22 @@ def resolve_param_values(cls, params): return params @classmethod - def get_known_shifts(cls, config_inst: od.Config, params: dict) -> tuple[set[str], set[str]]: + def get_known_shifts( + cls, + params: dict[str, Any], + shifts: TaskShifts, + ) -> None: # dataset can have shifts, that are considered as upstream shifts - shifts, upstream_shifts = super().get_known_shifts(config_inst, params) + super().get_known_shifts(params, shifts) - dataset_inst = params.get("dataset_inst") - if dataset_inst: + if (dataset_inst := params.get("dataset_inst")): if dataset_inst.is_data: # clear all shifts for data - shifts.clear() - upstream_shifts.clear() + shifts.local.clear() + shifts.upstream.clear() else: # extend with dataset variations for mc - upstream_shifts |= set(dataset_inst.info.keys()) - - return shifts, upstream_shifts + shifts.upstream |= set(dataset_inst.info.keys()) @classmethod def get_config_lookup_keys( @@ -1283,33 +1713,29 @@ def get_config_lookup_keys( ) -> law.util.InsertiableDict: keys = super().get_config_lookup_keys(inst_or_params) - get = ( - inst_or_params.get + # add the dataset name before the shift name + dataset = ( + inst_or_params.get("dataset") if isinstance(inst_or_params, dict) - else lambda attr: (getattr(inst_or_params, attr, None)) + else getattr(inst_or_params, "dataset", None) ) - - # add the dataset name before the shift name - dataset = get("dataset") if dataset not in {law.NO_STR, None, ""}: keys.insert_before("shift", "dataset", dataset) return keys @classmethod - def get_array_function_kwargs(cls, task=None, **params): - kwargs = super().get_array_function_kwargs(task=task, **params) + def get_array_function_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + kwargs = super().get_array_function_dict(params) - if task: - kwargs["dataset_inst"] = task.dataset_inst - elif "dataset_inst" in params: + if "dataset_inst" in params: kwargs["dataset_inst"] = params["dataset_inst"] elif "dataset" in params and "config_inst" in kwargs: kwargs["dataset_inst"] = kwargs["config_inst"].get_dataset(params["dataset"]) return kwargs - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # store references to the dataset instance @@ -1323,11 +1749,11 @@ def __init__(self, *args, **kwargs): ) self.dataset_info_inst = self.dataset_inst.get_info(key) - def store_parts(self): + def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() # insert the dataset - parts.insert_after("config", "dataset", self.dataset_inst.name) + parts.insert_after(self.config_store_anchor, "dataset", self.dataset_inst.name) return parts @@ -1396,7 +1822,7 @@ class CommandTask(AnalysisTask): run_command_in_tmp = False - def _print_command(self, args): + def _print_command(self, args) -> None: max_depth = int(args[0]) print(f"print task commands with max_depth {max_depth}") @@ -1426,11 +1852,11 @@ def _print_command(self, args): else: print(offset + law.util.colored("not a CommandTask", "yellow")) - def build_command(self): + def build_command(self) -> str | list[str]: # this method should build and return the command to run raise NotImplementedError - def touch_output_dirs(self): + def touch_output_dirs(self) -> None: # keep track of created uris so we can avoid creating them twice handled_parent_uris = set() @@ -1447,7 +1873,7 @@ def touch_output_dirs(self): parent.touch() handled_parent_uris.add(parent.uri()) - def run_command(self, cmd, optional=False, **kwargs): + def run_command(self, cmd: str | list[str], optional: bool = False, **kwargs) -> subprocess.Popen: # proper command encoding cmd = (law.util.quote_cmd(cmd) if isinstance(cmd, (list, tuple)) else cmd).strip() @@ -1488,10 +1914,10 @@ def run(self, **kwargs): self.post_run_command() - def pre_run_command(self): + def pre_run_command(self) -> None: return - def post_run_command(self): + def post_run_command(self) -> None: return @@ -1503,22 +1929,23 @@ def wrapper_factory( attributes: dict | None = None, docs: str | None = None, ) -> law.task.base.Register: - """Factory function creating wrapper task classes, inheriting from *base_cls* and - :py:class:`~law.task.base.WrapperTask`, that do nothing but require multiple instances of *require_cls*. - Unless *cls_name* is defined, the name of the created class defaults to the name of - *require_cls* plus "Wrapper". Additional *attributes* are added as class-level members when - given. + """ + Factory function creating wrapper task classes, inheriting from *base_cls* and + :py:class:`~law.task.base.WrapperTask`, that do nothing but require multiple instances of *require_cls*. Unless + *cls_name* is defined, the name of the created class defaults to the name of *require_cls* plus "Wrapper". + Additional *attributes* are added as class-level members when given. - The instances of *require_cls* to be required in the - :py:meth:`~.wrapper_factory.Wrapper.requires()` method can be controlled by task parameters. - These parameters can be enabled through the string sequence *enable*, which currently accepts: + The instances of *require_cls* to be required in the :py:meth:`~.wrapper_factory.Wrapper.requires()` method can be + controlled by task parameters. These parameters can be enabled through the string sequence *enable*, which currently + accepts: - ``configs``, ``skip_configs`` - ``shifts``, ``skip_shifts`` - ``datasets``, ``skip_datasets`` - This allows to easily build wrapper tasks that loop over (combinations of) parameters that are - either defined in the analysis or config, which would otherwise lead to mostly redundant code. + This allows to easily build wrapper tasks that loop over (combinations of) parameters that are either defined in the + analysis or config, which would otherwise lead to mostly redundant code. + Example: .. code-block:: python @@ -1535,38 +1962,30 @@ class MyTask(DatasetTask): # this allows to run (e.g.) # law run MyTaskWrapper --datasets st_* --skip-datasets *_tbar - When building the requirements, the full combinatorics of parameters is considered. However, - certain conditions apply depending on enabled features. For instance, in order to use the - "configs" feature (adding a parameter "--configs" to the created class, allowing to loop over a - list of config instances known to an analysis), *require_cls* must be at least a - :py:class:`ConfigTask` accepting "--config" (mind the singular form), whereas *base_cls* must - explicitly not. + When building the requirements, the full combinatorics of parameters is considered. However, certain conditions + apply depending on enabled features. For instance, in order to use the "configs" feature (adding a parameter + "--configs" to the created class, allowing to loop over a list of config instances known to an analysis), + *require_cls* must be at least a :py:class:`ConfigTask` accepting "--config" (mind the singular form), whereas + *base_cls* must explicitly not. :param base_cls: Base class for this wrapper :param require_cls: :py:class:`~law.task.base.Task` class to be wrapped - :param enable: Enable these parameters to control the wrapped - :py:class:`~law.task.base.Task` class instance. - Currently allowed parameters are: "configs", "skip_configs", - "shifts", "skip_shifts", "datasets", "skip_datasets" - :param cls_name: Name of the wrapper instance. If :py:attr:`None`, defaults to the - name of the :py:class:`~law.task.base.WrapperTask` class + `"Wrapper"` - :param attributes: Add these attributes as class-level members of the - new :py:class:`~law.task.base.WrapperTask` class - :param docs: Manually set the documentation string `__doc__` of the new - :py:class:`~law.task.base.WrapperTask` class instance - :raises ValueError: If a parameter provided with `enable` is not in the list - of known parameters - :raises TypeError: If any parameter in `enable` is incompatible with the - :py:class:`~law.task.base.WrapperTask` class instance or the inheritance - structure of corresponding classes - :raises ValueError: when `configs` are enabled but not found in the analysis - config instance - :raises ValueError: when `shifts` are enabled but not found in the analysis - config instance - :raises ValueError: when `datasets` are enabled but not found in the analysis - config instance - :return: The new :py:class:`~law.task.base.WrapperTask` for the - :py:class:`~law.task.base.Task` class `required_cls` + :param enable: Enable these parameters to control the wrapped :py:class:`~law.task.base.Task` class instance. + Currently allowed parameters are: "configs", "skip_configs", "shifts", "skip_shifts", "datasets", + "skip_datasets" + :param cls_name: Name of the wrapper instance. If :py:attr:`None`, defaults to the name of the + :py:class:`~law.task.base.WrapperTask` class + `"Wrapper"` + :param attributes: Add these attributes as class-level members of the new :py:class:`~law.task.base.WrapperTask` + class + :param docs: Manually set the documentation string `__doc__` of the new :py:class:`~law.task.base.WrapperTask` class + instance + :raises ValueError: If a parameter provided with `enable` is not in the list of known parameters + :raises TypeError: If any parameter in `enable` is incompatible with the :py:class:`~law.task.base.WrapperTask` + class instance or the inheritance structure of corresponding classes + :raises ValueError: when `configs` are enabled but not found in the analysis config instance + :raises ValueError: when `shifts` are enabled but not found in the analysis config instance + :raises ValueError: when `datasets` are enabled but not found in the analysis config instance + :return: The new :py:class:`~law.task.base.WrapperTask` for the :py:class:`~law.task.base.Task` class `required_cls` """ # check known features known_features = [ @@ -1597,18 +2016,18 @@ class MyTask(DatasetTask): def check_class_compatibility(name, min_require_cls, max_base_cls): if not issubclass(require_cls, min_require_cls): raise TypeError( - f"when the '{name}' feature is enabled, require_cls must inherit from " - f"{min_require_cls}, but {require_cls} does not", + f"when the '{name}' feature is enabled, require_cls must inherit from {min_require_cls}, but " + f"{require_cls} does not", ) if issubclass(base_cls, min_require_cls): raise TypeError( - f"when the '{name}' feature is enabled, base_cls must not inherit from " - f"{min_require_cls}, but {base_cls} does", + f"when the '{name}' feature is enabled, base_cls must not inherit from {min_require_cls}, but " + f"{base_cls} does", ) if not issubclass(max_base_cls, base_cls): raise TypeError( - f"when the '{name}' feature is enabled, base_cls must be a super class of " - f"{max_base_cls}, but {base_cls} is not", + f"when the '{name}' feature is enabled, base_cls must be a super class of {max_base_cls}, but " + f"{base_cls} is not", ) # check classes @@ -1623,60 +2042,62 @@ def check_class_compatibility(name, min_require_cls, max_base_cls): class Wrapper(*base_classes, law.WrapperTask): exclude_params_repr_empty = set() + exclude_params_req_set = set() if has_configs: configs = law.CSVParameter( default=(default_config,), - description="names or name patterns of configs to use; can also be the key of a " - "mapping defined in the 'config_groups' auxiliary data of the analysis; " - f"default: {default_config}", + description="names or name patterns of configs to use; can also be the key of a mapping defined in the " + f"'config_groups' auxiliary data of the analysis; default: {default_config}", brace_expand=True, ) + exclude_params_req_set.add("configs") if has_skip_configs: skip_configs = law.CSVParameter( default=(), - description="names or name patterns of configs to skip after evaluating --configs; " - "can also be the key of a mapping defined in the 'config_groups' auxiliary data " - "of the analysis; empty default", + description="names or name patterns of configs to skip after evaluating --configs; can also be the key " + "of a mapping defined in the 'config_groups' auxiliary data of the analysis; empty default", brace_expand=True, ) exclude_params_repr_empty.add("skip_configs") + exclude_params_req_set.add("skip_configs") if has_datasets: datasets = law.CSVParameter( default=("*",), - description="names or name patterns of datasets to use; can also be the key of a " - "mapping defined in the 'dataset_groups' auxiliary data of the corresponding " - "config; default: ('*',)", + description="names or name patterns of datasets to use; can also be the key of a mapping defined in " + "the 'dataset_groups' auxiliary data of the corresponding config; default: ('*',)", brace_expand=True, ) + exclude_params_req_set.add("datasets") if has_skip_datasets: skip_datasets = law.CSVParameter( default=(), - description="names or name patterns of datasets to skip after evaluating " - "--datasets; can also be the key of a mapping defined in the 'dataset_groups' " - "auxiliary data of the corresponding config; empty default", + description="names or name patterns of datasets to skip after evaluating --datasets; can also be the " + "key of a mapping defined in the 'dataset_groups' auxiliary data of the corresponding config; empty " + "default", brace_expand=True, ) exclude_params_repr_empty.add("skip_datasets") + exclude_params_req_set.add("skip_datasets") if has_shifts: shifts = law.CSVParameter( default=("nominal",), - description="names or name patterns of shifts to use; can also be the key of a " - "mapping defined in the 'shift_groups' auxiliary data of the corresponding " - "config; default: ('nominal',)", + description="names or name patterns of shifts to use; can also be the key of a mapping defined in the " + "'shift_groups' auxiliary data of the corresponding config; default: ('nominal',)", brace_expand=True, ) + exclude_params_req_set.add("shifts") if has_skip_shifts: skip_shifts = law.CSVParameter( default=(), - description="names or name patterns of shifts to skip after evaluating --shifts; " - "can also be the key of a mapping defined in the 'shift_groups' auxiliary data " - "of the corresponding config; empty default", + description="names or name patterns of shifts to skip after evaluating --shifts; can also be the key " + "of a mapping defined in the 'shift_groups' auxiliary data of the corresponding config; empty default", brace_expand=True, ) exclude_params_repr_empty.add("skip_shifts") + exclude_params_req_set.add("skip_shifts") - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # store wrapper flags @@ -1704,27 +2125,24 @@ def _build_wrapper_parameters(self): # get the target config instances if self.wrapper_has_configs: configs = self.find_config_objects( - self.configs, - self.analysis_inst, - od.Config, - self.analysis_inst.x("config_groups", {}), + names=self.configs, + container=self.analysis_inst, + object_cls=od.Config, + groups_str="config_groups", ) if not configs: - raise ValueError( - f"no configs found in analysis {self.analysis_inst} matching {self.configs}", - ) + raise ValueError(f"no configs found in analysis {self.analysis_inst} matching {self.configs}") if self.wrapper_has_skip_configs: skip_configs = self.find_config_objects( - self.skip_configs, - self.analysis_inst, - od.Config, - self.analysis_inst.x("config_groups", {}), + names=self.skip_configs, + container=self.analysis_inst, + object_cls=od.Config, + groups_str="config_groups", ) configs = [c for c in configs if c not in skip_configs] if not configs: raise ValueError( - f"no configs found in analysis {self.analysis_inst} after skipping " - f"{self.skip_configs}", + f"no configs found in analysis {self.analysis_inst} after skipping {self.skip_configs}", ) config_insts = list(map(self.analysis_inst.get_config, sorted(configs))) else: @@ -1740,28 +2158,23 @@ def _build_wrapper_parameters(self): # find all shifts if self.wrapper_has_shifts: shifts = self.find_config_objects( - self.shifts, - config_inst, - od.Shift, - config_inst.x("shift_groups", {}), + names=self.shifts, + container=config_inst, + object_cls=od.Shift, + groups_str="shift_groups", ) if not shifts: - raise ValueError( - f"no shifts found in config {config_inst} matching {self.shifts}", - ) + raise ValueError(f"no shifts found in config {config_inst} matching {self.shifts}") if self.wrapper_has_skip_shifts: skip_shifts = self.find_config_objects( - self.skip_shifts, - config_inst, - od.Shift, - config_inst.x("shift_groups", {}), + names=self.skip_shifts, + container=config_inst, + object_cls=od.Shift, + groups_str="shift_groups", ) shifts = [s for s in shifts if s not in skip_shifts] if not shifts: - raise ValueError( - f"no shifts found in config {config_inst} after skipping " - f"{self.skip_shifts}", - ) + raise ValueError(f"no shifts found in config {config_inst} after skipping {self.skip_shifts}") # move "nominal" to the front if present shifts = sorted(shifts) if "nominal" in shifts: @@ -1771,28 +2184,24 @@ def _build_wrapper_parameters(self): # find all datasets if self.wrapper_has_datasets: datasets = self.find_config_objects( - self.datasets, - config_inst, - od.Dataset, - config_inst.x("dataset_groups", {}), + names=self.datasets, + container=config_inst, + object_cls=od.Dataset, + groups_str="dataset_groups", ) if not datasets: - raise ValueError( - f"no datasets found in config {config_inst} matching " - f"{self.datasets}", - ) + raise ValueError(f"no datasets found in config {config_inst} matching {self.datasets}") if self.wrapper_has_skip_datasets: skip_datasets = self.find_config_objects( - self.skip_datasets, - config_inst, - od.Dataset, - config_inst.x("dataset_groups", {}), + names=self.skip_datasets, + container=config_inst, + object_cls=od.Dataset, + groups_str="dataset_groups", ) datasets = [d for d in datasets if d not in skip_datasets] if not datasets: raise ValueError( - f"no datasets found in config {config_inst} after skipping " - f"{self.skip_datasets}", + f"no datasets found in config {config_inst} after skipping {self.skip_datasets}", ) prod_sequences.append(sorted(datasets)) @@ -1801,14 +2210,7 @@ def _build_wrapper_parameters(self): return params - def requires(self) -> Requirements: - """Collect requirements defined by the underlying ``require_cls`` - of the :py:class:`~law.task.base.WrapperTask` depending on optional - additional parameters. - - :return: Requirements for the :py:class:`~law.task.base.WrapperTask` - instance. - """ + def requires(self) -> dict: # build all requirements based on the parameter space reqs = {} @@ -1840,7 +2242,7 @@ def update_wrapper_params(self, params): Wrapper.__module__ = module.__name__ # overwrite __name__ - Wrapper.__name__ = cls_name or require_cls.__name__ + "Wrapper" + Wrapper.__name__ = cls_name or f"{require_cls.__name__}Wrapper" # set docs if docs: diff --git a/columnflow/tasks/framework/decorators.py b/columnflow/tasks/framework/decorators.py index 91df3edd1..6c3ced53c 100644 --- a/columnflow/tasks/framework/decorators.py +++ b/columnflow/tasks/framework/decorators.py @@ -3,55 +3,109 @@ """ import law -from typing import Any, Callable + +from columnflow import env_is_local +from columnflow.types import Any, Callable @law.decorator.factory(accept_generator=True) -def view_output_plots( +def only_local_env( fn: Callable, opts: Any, task: law.Task, *args: Any, **kwargs: Any, ) -> tuple[Callable, Callable, Callable]: + """ only_local_env() + A decorator that ensures that the task's decorated method is only executed in the local environment, and not by + (e.g.) remote jobs. + + :param fn: The decorated function. + :param opts: Options for the decorator. + :param task: The task instance. + :param args: Arguments to be passed to the function call. + :param kwargs: Keyword arguments to be passed to the function call. + :return: A tuple containing the before_call, call, and after_call functions. """ - Decorator to view output plots. + def before_call() -> None: + return None + + def call(state: Any) -> Any: + if not env_is_local: + raise RuntimeError(f"{task.task_family}.{fn.__name__}() can only be executed locally") + return fn(task, *args, **kwargs) - This decorator is used to view the output plots of a task. It checks if the task has a view command, - collects all the paths of the output files, and then opens each file using the view command. + def after_call(state: Any) -> None: + return None - :param fn: The function to be decorated. + return before_call, call, after_call + + +@law.decorator.factory(callback=None, accept_generator=True) +def on_failure( + fn: Callable, + opts: Any, + task: law.Task, + *args: Any, + **kwargs: Any, +) -> tuple[Callable, Callable, Callable]: + """ callback(callback=None) + A decorator that is configured with a callback function that is invoked if the decorated method raises an exception. + The task instances is passed to the callback function as an argument. + + :param fn: The decorated function. :param opts: Options for the decorator. :param task: The task instance. - :param args: Variable length argument list. - :param kwargs: Arbitrary keyword arguments. + :param args: Arguments to be passed to the function call. + :param kwargs: Keyword arguments to be passed to the function call. :return: A tuple containing the before_call, call, and after_call functions. """ - def before_call() -> None: - """ - Function to be called before the decorated function. - - :return: None - """ return None def call(state: Any) -> Any: - """ - The decorated function. + try: + return fn(task, *args, **kwargs) + except: + if callable(opts["callback"]): + try: + opts["callback"](task) + except Exception as e: + task.logger.error(f"failure callback raised an exception: {e}") + raise - :param state: The state of the task. - :return: The result of the decorated function. - """ + def after_call(state: Any) -> None: + return None + + return before_call, call, after_call + + +@law.decorator.factory(accept_generator=True) +def view_output_plots( + fn: Callable, + opts: Any, + task: law.Task, + *args: Any, + **kwargs: Any, +) -> tuple[Callable, Callable, Callable]: + """ view_output_plots() + This decorator is used to view the output plots of a task. It checks if the task has a view command, collects all + the paths of the output files, and then opens each file using the view command. + + :param fn: The decorated function. + :param opts: Options for the decorator. + :param task: The task instance. + :param args: Arguments to be passed to the function call. + :param kwargs: Keyword arguments to be passed to the function call. + :return: A tuple containing the before_call, call, and after_call functions. + """ + def before_call() -> None: + return None + + def call(state: Any) -> Any: return fn(task, *args, **kwargs) def after_call(state: Any) -> None: - """ - Function to be called after the decorated function. - - :param state: The state of the task. - :return: None - """ view_cmd = getattr(task, "view_cmd", None) if not view_cmd or view_cmd == law.NO_STR: return diff --git a/columnflow/tasks/framework/histograms.py b/columnflow/tasks/framework/histograms.py index 86bb50673..a0ecdaa2e 100644 --- a/columnflow/tasks/framework/histograms.py +++ b/columnflow/tasks/framework/histograms.py @@ -11,30 +11,31 @@ from columnflow.tasks.framework.base import Requirements, ShiftTask from columnflow.tasks.framework.mixins import ( - CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin, WeightProducerMixin, - VariablesMixin, DatasetsProcessesMixin, CategoriesMixin, - ShiftSourcesMixin, + CalibratorClassesMixin, SelectorClassMixin, ReducerClassMixin, ProducerClassesMixin, MLModelsMixin, + HistProducerClassMixin, VariablesMixin, DatasetsProcessesMixin, CategoriesMixin, ShiftSourcesMixin, ) from columnflow.tasks.histograms import MergeHistograms, MergeShiftedHistograms from columnflow.util import dev_sandbox, maybe_import -ak = maybe_import("awkward") hist = maybe_import("hist") class HistogramsUserBase( + CalibratorClassesMixin, + SelectorClassMixin, + ReducerClassMixin, + ProducerClassesMixin, + HistProducerClassMixin, + MLModelsMixin, DatasetsProcessesMixin, CategoriesMixin, VariablesMixin, - MLModelsMixin, - WeightProducerMixin, - ProducersMixin, - SelectorStepsMixin, - CalibratorsMixin, ): + single_config = True + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) - def store_parts(self): + def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() parts.insert_before("version", "datasets", f"datasets_{self.datasets_repr}") return parts @@ -112,19 +113,19 @@ def flatten_nested_list(nested_list): # axis selections h = h[{ "process": [ - hist.loc(p.id) + hist.loc(p.name) for p in sub_process_insts - if p.id in h.axes["process"] + if p.name in h.axes["process"] ], "category": [ - hist.loc(c.id) + hist.loc(c.name) for c in leaf_category_insts - if c.id in h.axes["category"] + if c.name in h.axes["category"] ], "shift": [ - hist.loc(s.id) + hist.loc(s.name) for s in shift_insts - if s.id in h.axes["shift"] + if s.name in h.axes["shift"] ], }] @@ -136,9 +137,11 @@ def flatten_nested_list(nested_list): class HistogramsUserSingleShiftBase( - HistogramsUserBase, ShiftTask, + HistogramsUserBase, ): + # use the MergeHistograms task to trigger upstream TaskArrayFunction initialization + resolution_task_cls = MergeHistograms # upstream requirements reqs = Requirements( @@ -163,9 +166,12 @@ def requires(self): class HistogramsUserMultiShiftBase( - HistogramsUserBase, ShiftSourcesMixin, + HistogramsUserBase, ): + # use the MergeHistograms task to trigger upstream TaskArrayFunction initialization + resolution_task_cls = MergeHistograms + # upstream requirements reqs = Requirements( MergeShiftedHistograms=MergeShiftedHistograms, diff --git a/columnflow/tasks/framework/inference.py b/columnflow/tasks/framework/inference.py new file mode 100644 index 000000000..e742ad0d0 --- /dev/null +++ b/columnflow/tasks/framework/inference.py @@ -0,0 +1,244 @@ +# coding: utf-8 + +""" +Base tasks for writing serialized statistical inference models. +""" + +from __future__ import annotations + +import law +import order as od + +from columnflow.tasks.framework.base import Requirements +from columnflow.tasks.framework.mixins import ( + CalibratorClassesMixin, SelectorClassMixin, ReducerClassMixin, ProducerClassesMixin, HistProducerClassMixin, + InferenceModelMixin, HistHookMixin, MLModelsMixin, +) +from columnflow.tasks.framework.remote import RemoteWorkflow +from columnflow.tasks.histograms import MergeHistograms, MergeShiftedHistograms +from columnflow.util import dev_sandbox, DotDict, maybe_import +from columnflow.config_util import get_datasets_from_process + +hist = maybe_import("hist") + + +class SerializeInferenceModelBase( + CalibratorClassesMixin, + SelectorClassMixin, + ReducerClassMixin, + ProducerClassesMixin, + MLModelsMixin, + HistProducerClassMixin, + InferenceModelMixin, + HistHookMixin, + law.LocalWorkflow, + RemoteWorkflow, +): + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + + # support multiple configs + single_config = False + + # upstream requirements + reqs = Requirements( + RemoteWorkflow.reqs, + MergeHistograms=MergeHistograms, + MergeShiftedHistograms=MergeShiftedHistograms, + ) + + @classmethod + def get_mc_datasets(cls, config_inst: od.Config, proc_obj: DotDict) -> list[str]: + """ + Helper to find mc datasets. + + :param config_inst: The config instance. + :param proc_obj: process object from an InferenceModel + :return: List of dataset names corresponding to the process *proc_obj*. + """ + # the config instance should be specified in the config data of the proc_obj + if not (config_data := proc_obj.config_data.get(config_inst.name)): + return [] + + # when datasets are defined on the process object itself, interpret them as patterns + if config_data.mc_datasets: + return [ + dataset.name + for dataset in config_inst.datasets + if ( + dataset.is_mc and + law.util.multi_match(dataset.name, config_data.mc_datasets, mode=any) + ) + ] + + # if the proc object is dynamic, it is calculated and the fly (e.g. via a hist hook) + # and doesn't have any additional requirements + if proc_obj.is_dynamic: + return [] + + # otherwise, check the config + return [ + dataset_inst.name + for dataset_inst in get_datasets_from_process(config_inst, config_data.process) + ] + + @classmethod + def get_data_datasets(cls, config_inst: od.Config, cat_obj: DotDict) -> list[str]: + """ + Helper to find data datasets. + + :param config_inst: The config instance. + :param cat_obj: category object from an InferenceModel + :return: List of dataset names corresponding to the category *cat_obj*. + """ + # the config instance should be specified in the config data of the proc_obj + if not (config_data := cat_obj.config_data.get(config_inst.name)): + return [] + + if not config_data.data_datasets: + return [] + + return [ + dataset.name + for dataset in config_inst.datasets + if ( + dataset.is_data and + law.util.multi_match(dataset.name, config_data.data_datasets, mode=any) + ) + ] + + def create_branch_map(self): + return list(self.inference_model_inst.categories) + + def _requires_cat_obj(self, cat_obj: DotDict, **req_kwargs): + reqs = {} + for config_inst in self.config_insts: + if not (config_data := cat_obj.config_data.get(config_inst.name)): + continue + + # add merged shifted histograms for mc + reqs[config_inst.name] = { + proc_obj.name: { + dataset: self.reqs.MergeShiftedHistograms.req_different_branching( + self, + config=config_inst.name, + dataset=dataset, + shift_sources=tuple( + param_obj.config_data[config_inst.name].shift_source + for param_obj in proc_obj.parameters + if ( + config_inst.name in param_obj.config_data and + self.inference_model_inst.require_shapes_for_parameter(param_obj) + ) + ), + variables=(config_data.variable,), + **req_kwargs, + ) + for dataset in self.get_mc_datasets(config_inst, proc_obj) + } + for proc_obj in cat_obj.processes + if config_inst.name in proc_obj.config_data and not proc_obj.is_dynamic + } + # add merged histograms for data, but only if + # - data in that category is not faked from mc, or + # - at least one process object is dynamic (that usually means data-driven) + if ( + (not cat_obj.data_from_processes or any(proc_obj.is_dynamic for proc_obj in cat_obj.processes)) and + (data_datasets := self.get_data_datasets(config_inst, cat_obj)) + ): + reqs[config_inst.name]["data"] = { + dataset: self.reqs.MergeHistograms.req_different_branching( + self, + config=config_inst.name, + dataset=dataset, + variables=(config_data.variable,), + **req_kwargs, + ) + for dataset in data_datasets + } + + return reqs + + def workflow_requires(self): + reqs = super().workflow_requires() + + reqs["merged_hists"] = hist_reqs = {} + for cat_obj in self.branch_map.values(): + cat_reqs = self._requires_cat_obj(cat_obj) + for config_name, proc_reqs in cat_reqs.items(): + hist_reqs.setdefault(config_name, {}) + for proc_name, dataset_reqs in proc_reqs.items(): + hist_reqs[config_name].setdefault(proc_name, {}) + for dataset_name, task in dataset_reqs.items(): + hist_reqs[config_name][proc_name].setdefault(dataset_name, set()).add(task) + + return reqs + + def requires(self): + cat_obj = self.branch_data + return self._requires_cat_obj(cat_obj, branch=-1, workflow="local") + + def load_process_hists( + self, + inputs: dict, + cat_obj: DotDict, + config_inst: od.Config, + ) -> dict[od.Process, hist.Hist]: + # loop over all configs required by the datacard category and gather histograms + config_data = cat_obj.config_data.get(config_inst.name) + + # collect histograms per config process + hists: dict[od.Process, hist.Hist] = {} + with self.publish_step( + f"extracting {config_data.variable} in {config_data.category} for config {config_inst.name}...", + ): + for proc_obj_name, inp in inputs[config_inst.name].items(): + if proc_obj_name == "data": + process_inst = config_inst.get_process("data") + else: + proc_obj = self.inference_model_inst.get_process(proc_obj_name, category=cat_obj.name) + process_inst = config_inst.get_process(proc_obj.config_data[config_inst.name].process) + sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] + + # loop over per-dataset inputs and extract histograms containing the process + h_proc = None + for dataset_name, _inp in inp.items(): + dataset_inst = config_inst.get_dataset(dataset_name) + + # skip when the dataset is already known to not contain any sub process + if not any(map(dataset_inst.has_process, sub_process_insts)): + self.logger.warning( + f"dataset '{dataset_name}' does not contain process '{process_inst.name}' or any of " + "its subprocesses which indicates a misconfiguration in the inference model " + f"'{self.inference_model}'", + ) + continue + + # open the histogram and work on a copy + h = _inp["collection"][0]["hists"][config_data.variable].load(formatter="pickle").copy() + + # axis selections + h = h[{ + "process": [ + hist.loc(p.name) + for p in sub_process_insts + if p.name in h.axes["process"] + ], + }] + + # axis reductions + h = h[{"process": sum}] + + # add the histogram for this dataset + if h_proc is None: + h_proc = h + else: + h_proc += h + + # there must be a histogram + if h_proc is None: + raise Exception(f"no histograms found for process '{process_inst.name}'") + + # save histograms mapped to processes + hists[process_inst] = h_proc + + return hists diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index 0de908b80..37535b17a 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -14,214 +14,215 @@ import law import order as od -from columnflow.types import Sequence, Any, Iterable, Union -from columnflow.tasks.framework.base import AnalysisTask, ConfigTask, RESOLVE_DEFAULT -from columnflow.tasks.framework.parameters import SettingsParameter +from columnflow.types import Any, Iterable, Sequence +from columnflow.tasks.framework.base import ConfigTask, DatasetTask, TaskShifts, RESOLVE_DEFAULT +from columnflow.tasks.framework.parameters import SettingsParameter, DerivableInstParameter, DerivableInstsParameter from columnflow.calibration import Calibrator from columnflow.selection import Selector +from columnflow.reduction import Reducer from columnflow.production import Producer -from columnflow.weight import WeightProducer +from columnflow.histogramming import HistProducer from columnflow.ml import MLModel from columnflow.inference import InferenceModel -from columnflow.columnar_util import Route, ColumnCollection, ChunkedIOHandler -from columnflow.util import maybe_import, DotDict +from columnflow.columnar_util import Route, ColumnCollection, ChunkedIOHandler, TaskArrayFunction +from columnflow.util import maybe_import, DotDict, get_docs_url, get_code_url +from columnflow.types import Callable ak = maybe_import("awkward") logger = law.logger.get_logger(__name__) +logger_dev = law.logger.get_logger(f"{__name__}-dev") -class CalibratorMixin(ConfigTask): - """ - Mixin to include a single :py:class:`~columnflow.calibration.Calibrator` into tasks. +class ArrayFunctionClassMixin(ConfigTask): - Inheriting from this mixin will give access to instantiate and access a - :py:class:`~columnflow.calibration.Calibrator` instance with name *calibrator*, - which is an input parameter for this task. - """ - calibrator = luigi.Parameter( - default=RESOLVE_DEFAULT, - description="the name of the calibrator to be applied; default: value of the " - "'default_calibrator' config", - ) - calibrator.__annotations__ = " ".join(""" - the name of the calibrator to be applied; default: value of the - 'default_calibrator' config""".split()) - - # decides whether the task itself runs the calibrator and implements its shifts - register_calibrator_sandbox = False - register_calibrator_shifts = False + def array_function_cls_repr(self, array_function_name: str) -> str: + """ + Central definition of how to obtain representation of array function from the name. - @classmethod - def get_calibrator_inst(cls, calibrator: str, kwargs=None) -> Calibrator: + :param array_function: name of the array function (NOTE: change to class?) + :return: sring representation of the array function """ - Initialize :py:class:`~columnflow.calibration.Calibrator` instance. + return str(array_function_name) + - Extracts relevant *kwargs* for this calibrator instance using the - :py:meth:`~columnflow.tasks.framework.base.AnalaysisTask.get_calibrator_kwargs` - method. - After this process, the previously initialized instance of a - :py:class:`~columnflow.calibration.Calibrator` with the name - *calibrator* is initialized using the - :py:meth:`~columnflow.util.DerivableMeta.get_cls` method with the - relevant keyword arguments. +class ArrayFunctionInstanceMixin(DatasetTask): - :param calibrator: Name of the calibrator instance - :param kwargs: Any set keyword argument that is potentially relevant for - this :py:class:`~columnflow.calibration.Calibrator` instance - :raises RuntimeError: if requested :py:class:`~columnflow.calibration.Calibrator` instance - is not :py:attr:`~columnflow.calibration.Calibrator.exposed` - :return: The initialized :py:class:`~columnflow.calibration.Calibrator` - instance. + def _array_function_post_init(self, **kwargs) -> None: """ - calibrator_cls: Calibrator = Calibrator.get_cls(calibrator) - if not calibrator_cls.exposed: - raise RuntimeError(f"cannot use unexposed calibrator '{calibrator}' in {cls.__name__}") + Post-initialization method for all known task array functions. + """ + return None - inst_dict = cls.get_calibrator_kwargs(**kwargs) if kwargs else None - return calibrator_cls(inst_dict=inst_dict) + def array_function_inst_repr(self, array_function_inst: TaskArrayFunction) -> None: + return str(array_function_inst) - @classmethod - def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: - """ - Resolve parameter values *params* relevant for the - :py:class:`CalibratorMixin` and all classes it inherits from. - Loads the ``config_inst`` and loads the parameter ``"calibrator"``. - In case the parameter is not found, defaults to ``"default_calibrator"``. - Finally, this function adds the keyword ``"calibrator_inst"``, which - contains the :py:class:`~columnflow.calibration.Calibrator` instance - obtained using :py:meth:`~.CalibratorMixin.get_calibrator_inst` method. +class CalibratorClassMixin(ArrayFunctionClassMixin): + """ + Mixin to include and access single :py:class:`~columnflow.calibration.Calibrator` class. + """ - :param params: Dictionary with parameters provided by the user at - commandline level. - :return: Dictionary of parameters that now includes new value for - ``"calibrator_inst"``. - """ - params = super().resolve_param_values(params) + calibrator = luigi.Parameter( + default=RESOLVE_DEFAULT, + description="the name of the calibrator to be applied; default: value of the 'default_calibrator' analysis aux", + ) + + @classmethod + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_pre_init(params) - config_inst = params.get("config_inst") - if config_inst: - # add the default calibrator when empty + # resolve the default class if necessary + if (container := cls._get_config_container(params)): params["calibrator"] = cls.resolve_config_default( - params, - params.get("calibrator"), - container=config_inst, + param=params.get("calibrator"), + task_params=params, + container=container, default_str="default_calibrator", - multiple=False, + multi_strategy="same", ) - params["calibrator_inst"] = cls.get_calibrator_inst(params["calibrator"], params) return params @classmethod - def get_known_shifts(cls, config_inst: od.Config, params: dict[str, Any]) -> tuple[set[str], set[str]]: - """ - Adds set of shifts that the current ``calibrator_inst`` registers to the - set of known ``shifts`` and ``upstream_shifts``. + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: + # prefer --calibrator set on task-level via cli + kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"calibrator"} + return super().req_params(inst, **kwargs) - First, the set of ``shifts`` and ``upstream_shifts`` are obtained from - the *config_inst* and the current set of parameters *params* using the - ``get_known_shifts`` methods of all classes that :py:class:`CalibratorMixin` - inherits from. - Afterwards, check if the current ``calibrator_inst`` registers shifts. - If :py:attr:`~CalibratorMixin.register_calibrator_shifts` is ``True``, - add them to the current set of ``shifts``. Otherwise, add the - shifts obtained from the ``calibrator_inst`` to ``upstream_shifts``. + @property + def calibrator_repr(self) -> str: + """ + Return a string representation of the calibrator class. + """ + return self.build_repr(self.array_function_cls_repr(self.calibrator)) - :param config_inst: Config instance for the current task. - :param params: Dictionary containing the current set of parameters provided - by the user at commandline level - :return: Tuple with updated sets of ``shifts`` and ``upstream_shifts``. + def store_parts(self) -> law.util.InsertableDict: """ - shifts, upstream_shifts = super().get_known_shifts(config_inst, params) + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "calibrator", f"calib__{self.calibrator_repr}") + return parts - # get the calibrator, update it and add its shifts - calibrator_inst = params.get("calibrator_inst") - if calibrator_inst: - if cls.register_calibrator_shifts: - shifts |= calibrator_inst.all_shifts - else: - upstream_shifts |= calibrator_inst.all_shifts + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: CalibratorClassMixin | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) - return shifts, upstream_shifts + # add the calibrator name + calibrator = ( + inst_or_params.get("calibrator") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "calibrator", None) + ) + if calibrator not in (law.NO_STR, None, ""): + keys["calibrator"] = f"calib_{calibrator}" + + return keys + + +class CalibratorMixin(ArrayFunctionInstanceMixin, CalibratorClassMixin): + """ + Mixin to include and access a single :py:class:`~columnflow.calibration.Calibrator` instance. + """ + + calibrator_inst = DerivableInstParameter( + default=None, + visibility=luigi.parameter.ParameterVisibility.PRIVATE, + ) + + exclude_params_index = {"calibrator_inst"} + exclude_params_repr = {"calibrator_inst"} + exclude_params_sandbox = {"calibrator_inst"} + exclude_params_remote_workflow = {"calibrator_inst"} + + # decides whether the task itself invokes the calibrator + invokes_calibrator = False @classmethod - def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: + def get_calibrator_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + return cls.get_array_function_dict(params) + + @classmethod + def build_calibrator_inst( + cls, + calibrator: str, + params: dict[str, Any] | None = None, + ) -> Calibrator: """ - Returns the required parameters for the task. - It prefers `--calibrator` set on task-level via command line. + Instantiate and return the :py:class:`~columnflow.calibration.Calibrator` instance. - :param inst: The current task instance. - :param kwargs: Additional keyword arguments. - :return: Dictionary of required parameters. + :param calibrator: Name of the calibrator class to instantiate. + :param params: Arguments forwarded to the calibrator constructor. + :raises RuntimeError: If the calibrator class is not :py:attr:`~columnflow.calibration.Calibrator.exposed`. + :return: The calibrator instance. """ - # prefer --calibrator set on task-level via cli - kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"calibrator"} + calibrator_cls = Calibrator.get_cls(calibrator) + if not calibrator_cls.exposed: + raise RuntimeError(f"cannot use unexposed calibrator '{calibrator}' in {cls.__name__}") - return super().req_params(inst, **kwargs) + inst_dict = cls.get_calibrator_dict(params) if params else None + return calibrator_cls(inst_dict=inst_dict) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + @classmethod + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + # add the calibrator instance + if not params.get("calibrator_inst"): + params["calibrator_inst"] = cls.build_calibrator_inst(params["calibrator"], params) - # cache for calibrator inst - self._calibrator_inst = None + params = super().resolve_instances(params, shifts) - @property - def calibrator_inst(self) -> Calibrator: - """ - Access current :py:class:`~columnflow.calibration.Calibrator` instance. + return params - This method loads the current :py:class:`~columnflow.calibration.Calibrator` - *calibrator_inst* from the cache or initializes it. - If the calibrator requests a specific ``sandbox``, set this sandbox as - the environment for the current :py:class:`~law.task.base.Task`. + @classmethod + def get_known_shifts( + cls, + params: dict[str, Any], + shifts: TaskShifts, + ) -> None: + """ + Updates the set of known *shifts* implemented by *this* and upstream tasks. - :return: Current :py:class:`~columnflow.calibration.Calibrator` instance + :param config_inst: Config instance. + :param params: Dictionary of task parameters. + :param shifts: TaskShifts object to adjust. """ - if self._calibrator_inst is None: - self._calibrator_inst = self.get_calibrator_inst(self.calibrator, {"task": self}) + # get the calibrator, update it and add its shifts + calibrator_shifts = params["calibrator_inst"].all_shifts + (shifts.local if cls.invokes_calibrator else shifts.upstream).update(calibrator_shifts) - # overwrite the sandbox when set - if self.register_calibrator_sandbox: - sandbox = self._calibrator_inst.get_sandbox() - if sandbox: - self.sandbox = sandbox - # rebuild the sandbox inst when already initialized - if self._sandbox_initialized: - self._initialize_sandbox(force=True) + super().get_known_shifts(params, shifts) - return self._calibrator_inst + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) - @property - def calibrator_repr(self): - """ - Return a string representation of the calibrator. - """ - return str(self.calibrator_inst) + # overwrite the sandbox when set + if self.invokes_calibrator and (sandbox := self.calibrator_inst.get_sandbox()): + self.reset_sandbox(sandbox) - def store_parts(self) -> law.util.InsertableDict[str, str]: - """ - Create parts to create the output path to store intermediary results - for the current :py:class:`~law.task.base.Task`. + def _array_function_post_init(self, **kwargs) -> None: + self.calibrator_inst.run_post_init(task=self, **kwargs) + super()._array_function_post_init(**kwargs) - This method calls :py:meth:`store_parts` of the ``super`` class and inserts - `{"calibrator": "calib__{self.calibrator}"}` before keyword ``version``. - For more information, see e.g. :py:meth:`~columnflow.tasks.framework.base.ConfigTask.store_parts`. + def teardown_calibrator_inst(self) -> None: + if self.calibrator_inst: + self.calibrator_inst.run_teardown(task=self) - :return: Updated parts to create output path to store intermediary results. + @property + def calibrator_repr(self) -> str: """ - parts = super().store_parts() - parts.insert_before("version", "calibrator", f"calib__{self.calibrator_repr}") - return parts + Return a string representation of the calibrator instance. + """ + return self.build_repr(self.array_function_inst_repr(self.calibrator_inst)) def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: """ Finds the columns to keep based on the *collection*. - If the collection is `ALL_FROM_CALIBRATOR`, it includes the columns produced by the calibrator. - :param collection: The collection of columns. :return: Set of columns to keep. """ @@ -232,223 +233,176 @@ def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: return columns + +class CalibratorClassesMixin(ArrayFunctionClassMixin): + """ + Mixin to include and access multiple :py:class:`~columnflow.calibration.Calibrator` classes. + """ + + calibrators = law.CSVParameter( + default=(RESOLVE_DEFAULT,), + description="comma-separated names of calibrators to be applied; default: value of the 'default_calibrator' " + "analysis aux", + brace_expand=True, + parse_empty=True, + ) + + @classmethod + def resolve_param_values_pre_init( + cls, + params: law.util.InsertableDict[str, Any], + ) -> law.util.InsertableDict[str, Any]: + params = super().resolve_param_values_pre_init(params) + + # resolve the default classes if necessary + if (container := cls._get_config_container(params)): + params["calibrators"] = cls.resolve_config_default_and_groups( + param=params.get("calibrators"), + task_params=params, + container=container, + default_str="default_calibrator", + groups_str="calibrator_groups", + multi_strategy="same", + ) + + return params + + @classmethod + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: + # prefer --calibrators set on task-level via cli + kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"calibrators"} + return super().req_params(inst, **kwargs) + + @property + def calibrators_repr(self) -> str: + """ + Return a string representation of the calibrators. + """ + if not self.calibrators: + return "none" + return self.build_repr(list(map(self.array_function_cls_repr, self.calibrators))) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "calibrators", f"calib__{self.calibrators_repr}") + return parts + @classmethod def get_config_lookup_keys( cls, - inst_or_params: CalibratorMixin | dict[str, Any], + inst_or_params: CalibratorClassesMixin | dict[str, Any], ) -> law.util.InsertiableDict: keys = super().get_config_lookup_keys(inst_or_params) - get = ( - inst_or_params.get + # add the calibrator names + calibrators = ( + inst_or_params.get("calibrators") if isinstance(inst_or_params, dict) - else lambda attr: (getattr(inst_or_params, attr, None)) + else getattr(inst_or_params, "calibrators", None) ) - - # add the calibrator name - calibrator = get("calibrator") - if calibrator not in {law.NO_STR, None, ""}: - keys["calibrator"] = f"calib_{calibrator}" + if calibrators not in {law.NO_STR, None, "", ()}: + keys["calibrators"] = [f"calib_{calibrator}" for calibrator in calibrators] return keys -class CalibratorsMixin(ConfigTask): +class CalibratorsMixin(ArrayFunctionInstanceMixin, CalibratorClassesMixin): """ Mixin to include multiple :py:class:`~columnflow.calibration.Calibrator` instances into tasks. - - Inheriting from this mixin will allow a task to instantiate and access a set of - :py:class:`~columnflow.calibration.Calibrator` instances with names *calibrators*, - which is a comma-separated list of calibrator names and is an input parameter for this task. """ - calibrators = law.CSVParameter( - default=(RESOLVE_DEFAULT,), - description="comma-separated names of calibrators to be applied; default: value of the " - "'default_calibrator' config", - brace_expand=True, - parse_empty=True, + calibrator_insts = DerivableInstsParameter( + default=(), + visibility=luigi.parameter.ParameterVisibility.PRIVATE, ) - # decides whether the task itself runs the calibrators and implements their shifts - register_calibrators_shifts = False + exclude_params_index = {"calibrator_insts"} + exclude_params_repr = {"calibrator_insts"} + exclude_params_sandbox = {"calibrator_insts"} + exclude_params_remote_workflow = {"calibrator_insts"} @classmethod - def get_calibrator_insts(cls, calibrators: Iterable[str], kwargs=None) -> list[Calibrator]: - """ - Get all requested *calibrators*. + def get_calibrator_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + return cls.get_array_function_dict(params) - :py:class:`~columnflow.calibration.Calibrator` instances are either - initalized or loaded from cache. + @classmethod + def build_calibrator_insts( + cls, + calibrators: Iterable[str], + params: dict[str, Any] | None = None, + ) -> list[Calibrator]: + """ + Instantiate and return multiple :py:class:`~columnflow.calibration.Calibrator` instances. - :param calibrators: Names of Calibrators to load - :param kwargs: Additional keyword arguments to forward to individual - :py:class:`~columnflow.calibration.Calibrator` instances - :raises RuntimeError: if requested calibrators are not - :py:attr:`~columnflow.calibration.Calibrator.exposed` - :return: List of :py:class:`~columnflow.calibration.Calibrator` instances. + :param calibrators: Name of the calibrator class to instantiate. + :param params: Arguments forwarded to the calibrator constructors. + :raises RuntimeError: If any calibrator class is not :py:attr:`~columnflow.calibration.Calibrator.exposed`. + :return: The list of calibrator instances. """ - inst_dict = cls.get_calibrator_kwargs(**kwargs) if kwargs else None + inst_dict = cls.get_calibrator_dict(params) if params else None insts = [] for calibrator in calibrators: calibrator_cls = Calibrator.get_cls(calibrator) if not calibrator_cls.exposed: - raise RuntimeError( - f"cannot use unexposed calibrator '{calibrator}' in {cls.__name__}", - ) + raise RuntimeError(f"cannot use unexposed calibrator '{calibrator}' in {cls.__name__}") insts.append(calibrator_cls(inst_dict=inst_dict)) return insts @classmethod - def resolve_param_values( - cls, - params: law.util.InsertableDict[str, Any], - ) -> law.util.InsertableDict[str, Any]: - """ - Resolve values *params* and check against possible default values and - calibrator groups. - - Check the values in *params* against the default value ``"default_calibrator"`` - and possible group definitions ``"calibrator_groups"`` in the current config inst. - For more information, see - :py:meth:`~columnflow.tasks.framework.base.ConfigTask.resolve_config_default_and_groups`. + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + # add the calibrator instances + if not params.get("calibrator_insts"): + params["calibrator_insts"] = cls.build_calibrator_insts(params["calibrators"], params) - :param params: Parameter values to resolve - :return: Dictionary of parameters that contains the list requested - :py:class:`~columnflow.calibration.Calibrator` instances under the - keyword ``"calibrator_insts"``. See :py:meth:`~.CalibratorsMixin.get_calibrator_insts` - for more information. - """ - params = super().resolve_param_values(params) - - config_inst = params.get("config_inst") - if config_inst: - params["calibrators"] = cls.resolve_config_default_and_groups( - params, - params.get("calibrators"), - container=config_inst, - default_str="default_calibrator", - groups_str="calibrator_groups", - ) - params["calibrator_insts"] = cls.get_calibrator_insts(params["calibrators"], params) + params = super().resolve_instances(params, shifts) return params @classmethod def get_known_shifts( cls, - config_inst: od.Config, params: dict[str, Any], - ) -> tuple[set[str], set[str]]: + shifts: TaskShifts, + ) -> None: """ - Adds set of all shifts that the list of ``calibrator_insts`` register to the - set of known ``shifts`` and ``upstream_shifts``. + Updates the set of known *shifts* implemented by *this* and upstream tasks. - First, the set of ``shifts`` and ``upstream_shifts`` are obtained from - the *config_inst* and the current set of parameters *params* using the - ``get_known_shifts`` methods of all classes that :py:class:`CalibratorsMixin` - inherits from. - Afterwards, loop through the list of :py:class:`~columnflow.calibration.Calibrator` - and check if they register shifts. - If :py:attr:`~CalibratorsMixin.register_calibrators_shifts` is ``True``, - add them to the current set of ``shifts``. Otherwise, add the - shifts to ``upstream_shifts``. - - :param config_inst: Config instance for the current task. - :param params: Dictionary containing the current set of parameters provided - by the user at commandline level - :return: Tuple with updated sets of ``shifts`` and ``upstream_shifts``. + :param params: Dictionary of task parameters. + :param shifts: TaskShifts object to adjust. """ - shifts, upstream_shifts = super().get_known_shifts(config_inst, params) - # get the calibrators, update them and add their shifts - for calibrator_inst in params.get("calibrator_insts") or []: - if cls.register_calibrators_shifts: - shifts |= calibrator_inst.all_shifts - else: - upstream_shifts |= calibrator_inst.all_shifts - - return shifts, upstream_shifts - - @classmethod - def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: - """ - Returns the required parameters for the task. - - It prefers ``--calibrators`` set on task-level via command line. + for calibrator_inst in params["calibrator_insts"]: + shifts.upstream |= calibrator_inst.all_shifts - :param inst: The current task instance. - :param kwargs: Additional keyword arguments. - :return: Dictionary of required parameters. - """ + super().get_known_shifts(params, shifts) - # prefer --calibrators set on task-level via cli - kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"calibrators"} - - return super().req_params(inst, **kwargs) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # cache for calibrator insts - self._calibrator_insts = None - - @property - def calibrator_insts(self) -> list[Calibrator]: - """ - Access current list of :py:class:`~columnflow.calibration.Calibrator` instances. - - Loads the current :py:class:`~columnflow.calibration.Calibrator` *calibrator_insts* from - the cache or initializes it. - - :return: Current list :py:class:`~columnflow.calibration.Calibrator` instances - """ - if self._calibrator_insts is None: - self._calibrator_insts = self.get_calibrator_insts(self.calibrators, {"task": self}) - return self._calibrator_insts + def _array_function_post_init(self, **kwargs) -> None: + for calibrator_inst in self.calibrator_insts or []: + calibrator_inst.run_post_init(task=self, **kwargs) + super()._array_function_post_init(**kwargs) @property def calibrators_repr(self) -> str: """ Return a string representation of the calibrators. """ - calibs_repr = "none" - if self.calibrators: - calibs_repr = "__".join([str(calib) for calib in self.calibrator_insts[:5]]) - if len(self.calibrators) > 5: - calibs_repr += f"__{law.util.create_hash([str(calib) for calib in self.calibrator_insts[5:]])}" - return calibs_repr - - def store_parts(self): - """ - Create parts to create the output path to store intermediary results - for the current :py:class:`~law.task.base.Task`. - - Calls :py:meth:`store_parts` of the ``super`` class and inserts - `{"calibrator": "calib__{HASH}"}` before keyword ``version``. - Here, ``HASH`` is the joint string of the first five calibrator names - + a hash created with :py:meth:`law.util.create_hash` based on - the list of calibrators, starting at its 5th element (i.e. ``self.calibrators[5:]``) - For more information, see e.g. :py:meth:`~columnflow.tasks.framework.base.ConfigTask.store_parts`. - - :return: Updated parts to create output path to store intermediary results. - """ - parts = super().store_parts() - parts.insert_before("version", "calibrators", f"calib__{self.calibrators_repr}") - return parts + if not self.calibrators: + return "none" + return self.build_repr(list(map(self.array_function_inst_repr, self.calibrator_insts))) def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: """ Finds the columns to keep based on the *collection*. - If the collection is ``ALL_FROM_CALIBRATORS``, it includes the columns produced by the calibrators. - :param collection: The collection of columns. :return: Set of columns to keep. """ - columns: set[Route] = super().find_keep_columns(collection) + columns = super().find_keep_columns(collection) if collection == ColumnCollection.ALL_FROM_CALIBRATORS: columns |= set.union(*( @@ -459,578 +413,693 @@ def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: return columns -class SelectorMixin(ConfigTask): +class SelectorClassMixin(ArrayFunctionClassMixin): """ - Mixin to include a single :py:class:`~columnflow.selection.Selector` - instances into tasks. - - Inheriting from this mixin will allow a task to instantiate and access a - :py:class:`~columnflow.selection.Selector` instance with name *selector*, - which is an input parameter for this task. + Mixin to include and access single :py:class:`~columnflow.selection.Selector` class. """ + selector = luigi.Parameter( default=RESOLVE_DEFAULT, description="the name of the selector to be applied; default: value of the " - "'default_selector' config", + "'default_selector' analysis aux", + ) + selector_steps = law.CSVParameter( + default=(RESOLVE_DEFAULT,), + description="a subset of steps of the selector to apply; uses all steps when empty; " + "default: empty", + brace_expand=True, + parse_empty=True, ) - # decides whether the task itself runs the selector and implements its shifts - register_selector_sandbox = False - register_selector_shifts = False + selector_steps_order_sensitive = False + + exclude_params_repr_empty = {"selector_steps"} @classmethod - def get_selector_inst( - cls, - selector: str, - kwargs=None, - ) -> Selector: + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_pre_init(params) + + if (container := cls._get_config_container(params)): + # resolve the default class if necessary + params["selector"] = cls.resolve_config_default( + param=params.get("selector"), + task_params=params, + container=container, + default_str="default_selector", + multi_strategy="same", + ) + + # apply selector_steps_groups and default_selector_steps from config + if "selector_steps" in params: + params["selector_steps"] = cls.resolve_config_default_and_groups( + param=params.get("selector_steps"), + task_params=params, + container=container, + default_str="default_selector_steps", + groups_str="selector_step_groups", + multi_strategy="same", + ) + + # sort selector steps when the order does not matter + if params.get("selector_steps") and not cls.selector_steps_order_sensitive: + params["selector_steps"] = tuple(sorted(params["selector_steps"])) + + return params + + @classmethod + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: + # prefer --selector and --selector-steps set on task-level via cli + kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | { + "selector", + "selector_steps", + } + return super().req_params(inst, **kwargs) + + @property + def selector_repr(self) -> str: + """ + Return a string representation of the selector class. + """ + sel_repr = self.build_repr(self.array_function_cls_repr(self.selector)) + steps = self.selector_steps + if steps and not self.selector_steps_order_sensitive: + steps = sorted(steps) + if steps: + sel_repr += "__steps_" + self.build_repr(steps, sep="_") + return sel_repr + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. """ - Get requested *selector*. + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "selector", f"sel__{self.selector_repr}") + return parts + + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: SelectorClassMixin | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) + + # add the selector name + selector = ( + inst_or_params.get("selector") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "selector", None) + ) + if selector not in (law.NO_STR, None, ""): + keys["selector"] = f"sel_{selector}" - :py:class:`~columnflow.selection.Selector` instance is either - initalized or loaded from cache. + return keys + + +class SelectorMixin(ArrayFunctionInstanceMixin, SelectorClassMixin): + """ + Mixin to include and access a single :py:class:`~columnflow.selection.Selector` instance. + """ + + selector_inst = DerivableInstParameter( + default=None, + visibility=luigi.parameter.ParameterVisibility.PRIVATE, + ) - :param selector: Name of :py:class:`~columnflow.selection.Selector` to load - :param kwargs: Additional keyword arguments to forward to the - :py:class:`~columnflow.selection.Selector` instance - :return: :py:class:`~columnflow.selection.Selector` instance. + exclude_params_index = {"selector_inst"} + exclude_params_repr = {"selector_inst"} + exclude_params_sandbox = {"selector_inst"} + exclude_params_remote_workflow = {"selector_inst"} + + # decides whether the task itself invokes the selector + invokes_selector = False + + @classmethod + def get_selector_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + return cls.get_array_function_dict(params) + + @classmethod + def build_selector_inst(cls, selector: str, params: dict[str, Any] | None = None) -> Selector: + """ + Instantiate and return the :py:class:`~columnflow.selection.Selector` instance. + + :param selector: Name of the selector class to instantiate. + :param params: Arguments forwarded to the selector constructor. + :raises RuntimeError: If the selector class is not :py:attr:`~columnflow.selection.Selector.exposed`. + :return: The selector instance. """ selector_cls = Selector.get_cls(selector) if not selector_cls.exposed: raise RuntimeError(f"cannot use unexposed selector '{selector}' in {cls.__name__}") - inst_dict = cls.get_selector_kwargs(**kwargs) if kwargs else None + inst_dict = cls.get_selector_dict(params) if params else None return selector_cls(inst_dict=inst_dict) @classmethod - def resolve_param_values(cls, params: dict[str, Any]) -> dict: - """ - Resolve values *params* and check against possible default values and - selector groups. + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + # add the selector instance + if not params.get("selector_inst"): + params["selector_inst"] = cls.build_selector_inst(params["selector"], params) - Check the values in *params* against the default value ``"default_selector"`` - in the current config inst. For more information, see - :py:meth:`~columnflow.tasks.framework.base.AnalysisTask.resolve_config_default`. - - :param params: Parameter values to resolve - :return: Dictionary of parameters that contains the requested - :py:class:`~columnflow.selection.Selector` instance under the - keyword ``"selector_inst"``. - """ - params = super().resolve_param_values(params) - - # add the default selector when empty - config_inst = params.get("config_inst") - if config_inst: - params["selector"] = cls.resolve_config_default( - params, - params.get("selector"), - container=config_inst, - default_str="default_selector", - multiple=False, - ) - params["selector_inst"] = cls.get_selector_inst(params["selector"], params) + params = super().resolve_instances(params, shifts) return params @classmethod def get_known_shifts( cls, - config_inst: od.Config, params: dict[str, Any], - ) -> tuple[set[str], set[str]]: + shifts: TaskShifts, + ) -> None: """ - Adds set of shifts that the current ``selector_inst`` registers to the - set of known ``shifts`` and ``upstream_shifts``. - - First, the set of ``shifts`` and ``upstream_shifts`` are obtained from - the *config_inst* and the current set of parameters *params* using the - ``get_known_shifts`` methods of all classes that :py:class:`SelectorMixin` - inherits from. - Afterwards, check if the current ``selector_inst`` registers shifts. - If :py:attr:`~SelectorMixin.register_selector_shifts` is ``True``, - add them to the current set of ``shifts``. Otherwise, add the - shifts obtained from the ``selector_inst`` to ``upstream_shifts``. + Updates the set of known *shifts* implemented by *this* and upstream tasks. - :param config_inst: Config instance for the current task. - :param params: Dictionary containing the current set of parameters provided - by the user at commandline level - :return: Tuple with updated sets of ``shifts`` and ``upstream_shifts``. + :param params: Dictionary of task parameters. + :param shifts: TaskShifts object to adjust. """ - shifts, upstream_shifts = super().get_known_shifts(config_inst, params) - # get the selector, update it and add its shifts - selector_inst = params.get("selector_inst") - if selector_inst: - if cls.register_selector_shifts: - shifts |= selector_inst.all_shifts - else: - upstream_shifts |= selector_inst.all_shifts + selector_shifts = params["selector_inst"].all_shifts + (shifts.local if cls.invokes_selector else shifts.upstream).update(selector_shifts) - return shifts, upstream_shifts + super().get_known_shifts(params, shifts) - @classmethod - def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # overwrite the sandbox when set + if self.invokes_selector and (sandbox := self.selector_inst.get_sandbox()): + self.reset_sandbox(sandbox) + + def _array_function_post_init(self, **kwargs) -> None: + self.selector_inst.run_post_init(task=self, **kwargs) + super()._array_function_post_init(**kwargs) + + def teardown_selector_inst(self) -> None: + if self.selector_inst: + self.selector_inst.run_teardown(task=self) + + @property + def selector_repr(self) -> str: + """ + Return a string representation of the selector instance. """ - Get the required parameters for the task, preferring the ``--selector`` set on task-level via CLI. + sel_repr = self.build_repr(self.array_function_inst_repr(self.selector_inst)) + # add representation of steps only if this class does not invoke the selector itself + if not self.invokes_selector: + steps = self.selector_steps + if steps and not self.selector_steps_order_sensitive: + steps = sorted(steps) + if steps: + sel_repr += "__steps_" + self.build_repr(steps, sep="_") - This method first checks if the --selector parameter is set at the task-level via the command line. - If it is, this parameter is preferred and added to the '_prefer_cli' key in the kwargs dictionary. - The method then calls the 'req_params' method of the superclass with the updated kwargs. + return sel_repr - :param inst: The current task instance. - :param kwargs: Additional keyword arguments that may contain parameters for the task. - :return: A dictionary of parameters required for the task. + def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: """ - # prefer --selector set on task-level via cli - kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"selector"} + Finds the columns to keep based on the *collection*. - return super().req_params(inst, **kwargs) + :param collection: The collection of columns. + :return: Set of columns to keep. + """ + columns = super().find_keep_columns(collection) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + if collection == ColumnCollection.ALL_FROM_SELECTOR: + columns |= self.selector_inst.produced_columns - # cache for selector inst - self._selector_inst = None + return columns - @property - def selector_inst(self): - """ - Access current :py:class:`~columnflow.selection.Selector` instance. - Loads the current :py:class:`~columnflow.selection.Selector` *selector_inst* from - the cache or initializes it. - If the selector requests a specific ``sandbox``, set this sandbox as - the environment for the current :py:class:`~law.task.base.Task`. +class ReducerClassMixin(ArrayFunctionClassMixin): + """ + Mixin to include and access single :py:class:`~columnflow.reduction.Reducer` class. + """ - :return: Current :py:class:`~columnflow.selection.Selector` instance - """ - if self._selector_inst is None: - self._selector_inst = self.get_selector_inst(self.selector, {"task": self}) + reducer = luigi.Parameter( + default=RESOLVE_DEFAULT, + description="the name of the reducer to be applied; default: value of the 'default_reducer' analysis aux", + ) + + @classmethod + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_pre_init(params) + + # resolve the default class if necessary + if (container := cls._get_config_container(params)): + params["reducer"] = cls.resolve_config_default( + param=params.get("reducer"), + task_params=params, + container=container, + default_str="default_reducer", + multi_strategy="same", + ) - # overwrite the sandbox when set - if self.register_selector_sandbox: - sandbox = self._selector_inst.get_sandbox() - if sandbox: - self.sandbox = sandbox - # rebuild the sandbox inst when already initialized - if self._sandbox_initialized: - self._initialize_sandbox(force=True) + # !! to be removed in a future release + if not params["reducer"]: + # fallback to cf's default and trigger a verbose warning + params["reducer"] = "cf_default" + docs_url = get_docs_url("user_guide", "02_03_transition.html") + code_url = get_code_url("columnflow", "reduction", "default.py") + logger.warning_once( + "reducer_undefined", + "the resolution of the '--reducer' parameter resulted in an empty value, most likely caused by a " + f"missing auxiliary field 'default_reducer' in your configuration; see {docs_url} for more " + f"information; using '{params['reducer']}' ({code_url}) as a fallback", + ) - return self._selector_inst + return params + + @classmethod + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: + # prefer --reducer set on task-level via cli + kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"reducer"} + return super().req_params(inst, **kwargs) @property - def selector_repr(self): + def reducer_repr(self) -> str: """ - Return a string representation of the selector. + Return a string representation of the reducer class. """ - return str(self.selector_inst) + return self.build_repr(self.array_function_cls_repr(self.reducer)) - def store_parts(self): + def store_parts(self) -> law.util.InsertableDict: """ - Create parts to create the output path to store intermediary results - for the current :py:class:`~law.task.base.Task`. - - Calls :py:meth:`store_parts` of the ``super`` class and inserts - `{"selector": "sel__{SELECTOR_NAME}"}` before keyword ``version``. - Here, ``SELECTOR_NAME`` is the name of the current ``selector_inst``. - - :return: Updated parts to create output path to store intermediary results. + :return: Dictionary with parts that will be translated into an output directory path. """ parts = super().store_parts() - parts.insert_before("version", "selector", f"sel__{self.selector_repr}") + parts.insert_after(self.config_store_anchor, "reducer", f"red__{self.reducer_repr}") return parts - def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: - columns = super().find_keep_columns(collection) - - if collection == ColumnCollection.ALL_FROM_SELECTOR: - columns |= self.selector_inst.produced_columns - - return columns - @classmethod def get_config_lookup_keys( cls, - inst_or_params: SelectorMixin | dict[str, Any], + inst_or_params: ReducerClassMixin | dict[str, Any], ) -> law.util.InsertiableDict: keys = super().get_config_lookup_keys(inst_or_params) - get = ( - inst_or_params.get + # add the reducer name + reducer = ( + inst_or_params.get("reducer") if isinstance(inst_or_params, dict) - else lambda attr: (getattr(inst_or_params, attr, None)) + else getattr(inst_or_params, "reducer", None) ) - - # add the selector name - selector = get("selector") - if selector not in {law.NO_STR, None, ""}: - keys["selector"] = f"sel_{selector}" + if reducer not in (law.NO_STR, None, ""): + keys["reducer"] = f"red_{reducer}" return keys -class SelectorStepsMixin(SelectorMixin): +class ReducerMixin(ArrayFunctionInstanceMixin, ReducerClassMixin): """ - Mixin to include multiple selector steps into tasks. - - Inheriting from this mixin will allow a task to access selector steps, which can be a - comma-separated list of selector step names and is an input parameter for this task. + Mixin to include and access a single :py:class:`~columnflow.reduction.Reducer` instance. """ - selector_steps = law.CSVParameter( - default=(), - description="a subset of steps of the selector to apply; uses all steps when empty; " - "default: empty", - brace_expand=True, - parse_empty=True, + reducer_inst = DerivableInstParameter( + default=None, + visibility=luigi.parameter.ParameterVisibility.PRIVATE, ) - exclude_params_repr_empty = {"selector_steps"} + exclude_params_index = {"reducer_inst"} + exclude_params_repr = {"reducer_inst"} + exclude_params_sandbox = {"reducer_inst"} + exclude_params_remote_workflow = {"reducer_inst"} - selector_steps_order_sensitive = False + # decides whether the task itself invokes the reducer + invokes_reducer = False @classmethod - def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: - """ - Resolve values *params* and check against possible default values and - selector step groups. + def get_reducer_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + return cls.get_array_function_dict(params) - Check the values in *params* against the default value ``"default_selector_steps"`` - and the group ``"selector_step_groups"`` in the current config inst. - For more information, see - :py:meth:`~columnflow.tasks.framework.base.AnalysisTask.resolve_config_default`. - If :py:attr:`SelectorStepsMixin.selector_steps_order_sensitive` is ``True``, - :py:func:`sort ` the selector steps. + @classmethod + def build_reducer_inst( + cls, + reducer: str, + params: dict[str, Any] | None = None, + ) -> Reducer: + """ + Instantiate and return the :py:class:`~columnflow.reduction.Reducer` instance. - :param params: Parameter values to resolve - :return: Dictionary of parameters that contains the requested - selector steps under the keyword ``"selector_steps"``. + :param reducer: Name of the reducer class to instantiate. + :param params: Arguments forwarded to the reducer constructor. + :raises RuntimeError: If the reducer class is not :py:attr:`~columnflow.reduction.Reducer.exposed`. + :return: The reducer instance. """ - params = super().resolve_param_values(params) + reducer_cls = Reducer.get_cls(reducer) + if not reducer_cls.exposed: + raise RuntimeError(f"cannot use unexposed reducer '{reducer}' in {cls.__name__}") - # apply selector_steps_groups and default_selector_steps from config - config_inst = params.get("config_inst") - if config_inst: - params["selector_steps"] = cls.resolve_config_default_and_groups( - params, - params.get("selector_steps"), - container=config_inst, - default_str="default_selector_steps", - groups_str="selector_step_groups", - ) + inst_dict = cls.get_reducer_dict(params) if params else None + return reducer_cls(inst_dict=inst_dict) - # sort selector steps when the order does not matter - if not cls.selector_steps_order_sensitive and "selector_steps" in params: - params["selector_steps"] = tuple(sorted(params["selector_steps"])) + @classmethod + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + # add the reducer instance + if not params.get("reducer_inst"): + params["reducer_inst"] = cls.build_reducer_inst(params["reducer"], params) + + params = super().resolve_instances(params, shifts) return params @classmethod - def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: + def get_known_shifts( + cls, + params: dict[str, Any], + shifts: TaskShifts, + ) -> None: """ - Get the required parameters for the task, preferring the --selector-steps set on task-level via CLI. + Updates the set of known *shifts* implemented by *this* and upstream tasks. - This method first checks if the --selector-steps parameter is set at the task-level via the command line. - If it is, this parameter is preferred and added to the '_prefer_cli' key in the kwargs dictionary. - The method then calls the 'req_params' method of the superclass with the updated kwargs. + :param config_inst: Config instance. + :param params: Dictionary of task parameters. + :param shifts: TaskShifts object to adjust. + """ + # get the reducer, update it and add its shifts + reducer_shifts = params["reducer_inst"].all_shifts + (shifts.local if cls.invokes_reducer else shifts.upstream).update(reducer_shifts) - :param inst: The current task instance. - :param kwargs: Additional keyword arguments that may contain parameters for the task. - :return: A dictionary of parameters required for the task. + super().get_known_shifts(params, shifts) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # overwrite the sandbox when set + if self.invokes_reducer and (sandbox := self.reducer_inst.get_sandbox()): + self.reset_sandbox(sandbox) + + def _array_function_post_init(self, **kwargs) -> None: + self.reducer_inst.run_post_init(task=self, **kwargs) + super()._array_function_post_init(**kwargs) + + def teardown_reducer_inst(self) -> None: + if self.reducer_inst: + self.reducer_inst.run_teardown(task=self) + + @property + def reducer_repr(self) -> str: """ - # prefer --selector-steps set on task-level via cli - kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"selector_steps"} + Return a string representation of the reducer instance. + """ + return self.build_repr(self.array_function_inst_repr(self.reducer_inst)) - return super().req_params(inst, **kwargs) + def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: + """ + Finds the columns to keep based on the *collection*. - def store_parts(self) -> law.util.InsertableDict: + :param collection: The collection of columns. + :return: Set of columns to keep. """ - Create parts to create the output path to store intermediary results - for the current :py:class:`~law.task.base.Task`. + columns = super().find_keep_columns(collection) + + if collection == ColumnCollection.ALL_FROM_REDUCER: + columns |= self.reducer_inst.produced_columns + + return columns + - Calls :py:meth:`store_parts` of the ``super`` class and inserts - `{"selector": "__steps__LIST_OF_STEPS"}`, where ``LIST_OF_STEPS`` is the - sorted list of selector steps. - For more information, see e.g. - :py:meth:`~columnflow.tasks.framework.base.ConfigTask.store_parts`. +class ProducerClassMixin(ArrayFunctionClassMixin): + """ + Mixin to include and access single :py:class:`~columnflow.production.Producer` class. + """ - :return: Updated parts to create output path to store intermediary results. + producer = luigi.Parameter( + default=RESOLVE_DEFAULT, + description="the name of the producer to be applied; default: value of the 'default_producer' analysis aux", + ) + + @classmethod + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_pre_init(params) + + # resolve the default class if necessary + if (container := cls._get_config_container(params)): + params["producer"] = cls.resolve_config_default( + param=params.get("producer"), + task_params=params, + container=container, + default_str="default_producer", + multi_strategy="same", + ) + + return params + + @classmethod + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: + # prefer --producer set on task-level via cli + kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"producer"} + return super().req_params(inst, **kwargs) + + @property + def producer_repr(self) -> str: + """ + Return a string representation of the producer class. + """ + return self.build_repr(self.array_function_cls_repr(self.producer)) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. """ parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "producer", f"prod__{self.producer_repr}") + return parts - steps = self.selector_steps - if not self.selector_steps_order_sensitive: - steps = sorted(steps) - if steps: - parts["selector"] += "__steps_" + "_".join(steps) + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: ProducerClassMixin | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) + + # add the producer name + producer = ( + inst_or_params.get("producer") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "producer", None) + ) + if producer not in (law.NO_STR, None, ""): + keys["producer"] = f"prod_{producer}" - return parts + return keys -class ProducerMixin(ConfigTask): +class ProducerMixin(ArrayFunctionInstanceMixin, ProducerClassMixin): """ - Mixin to include a single :py:class:`~columnflow.production.Producer` into tasks. - - Inheriting from this mixin will give access to instantiate and access a - :py:class:`~columnflow.production.Producer` instance with name *producer*, - which is an input parameter for this task. + Mixin to include and access a single :py:class:`~columnflow.production.Producer` instance. """ - producer = luigi.Parameter( - default=RESOLVE_DEFAULT, - description="the name of the producer to be applied; default: value of the " - "'default_producer' config", + producer_inst = DerivableInstParameter( + default=None, + visibility=luigi.parameter.ParameterVisibility.PRIVATE, ) - # decides whether the task itself runs the producer and implements its shifts - register_producer_sandbox = False - register_producer_shifts = False + exclude_params_index = {"producer_inst"} + exclude_params_repr = {"producer_inst"} + exclude_params_sandbox = {"producer_inst"} + exclude_params_remote_workflow = {"producer_inst"} + + # decides whether the task itself invokes the producer + invokes_producer = False @classmethod - def get_producer_inst(cls, producer: str, kwargs=None) -> Producer: - """ - Initialize :py:class:`~columnflow.production.Producer` instance. + def get_producer_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + return cls.get_array_function_dict(params) - Extracts relevant *kwargs* for this producer instance using the - :py:meth:`~columnflow.tasks.framework.base.AnalaysisTask.get_producer_kwargs` - method. - After this process, the previously initialized instance of a - :py:class:`~columnflow.production.Producer` with the name - *producer* is initialized using the - :py:meth:`~columnflow.util.DerivableMeta.get_cls` method with the - relevant keyword arguments. + @classmethod + def build_producer_inst( + cls, + producer: str, + params: dict[str, Any] | None = None, + ) -> Producer: + """ + Instantiate and return the :py:class:`~columnflow.production.Producer` instance. - :param producer: Name of the :py:class:`~columnflow.production.Producer` - instance - :param kwargs: Any set keyword argument that is potentially relevant for - this :py:class:`~columnflow.production.Producer` instance - :raises RuntimeError: if requested :py:class:`~columnflow.production.Producer` instance - is not :py:attr:`~columnflow.production.Producer.exposed` - :return: The initialized :py:class:`~columnflow.production.Producer` - instance. + :param producer: Name of the producer class to instantiate. + :param params: Arguments forwarded to the producer constructor. + :raises RuntimeError: If the producer class is not + :py:attr:`~columnflow.production.Producer.exposed`. + :return: The producer instance. """ - producer_cls: Producer = Producer.get_cls(producer) + producer_cls = Producer.get_cls(producer) if not producer_cls.exposed: raise RuntimeError(f"cannot use unexposed producer '{producer}' in {cls.__name__}") - inst_dict = cls.get_producer_kwargs(**kwargs) if kwargs else None + inst_dict = cls.get_producer_dict(params) if params else None return producer_cls(inst_dict=inst_dict) @classmethod - def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: - """ - Resolve parameter values *params* relevant for the - :py:class:`ProducerMixin` and all classes it inherits from. - - Loads the ``config_inst`` and loads the parameter ``"producer"``. - In case the parameter is not found, defaults to ``"default_producer"``. - Finally, this function adds the keyword ``"producer_inst"``, which - contains the :py:class:`~columnflow.production.Producer` instance - obtained using :py:meth:`~.ProducerMixin.get_producer_inst` method. - - :param params: Dictionary with parameters provided by the user at - commandline level. - :return: Dictionary of parameters that now includes new value for - ``"producer_inst"``. - """ - params = super().resolve_param_values(params) + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + # add the producer instance + if not params.get("producer_inst"): + params["producer_inst"] = cls.build_producer_inst(params["producer"], params) - # add the default producer when empty - config_inst = params.get("config_inst") - if config_inst: - params["producer"] = cls.resolve_config_default( - params, - params.get("producer"), - container=config_inst, - default_str="default_producer", - multiple=False, - ) - params["producer_inst"] = cls.get_producer_inst(params["producer"], params) + params = super().resolve_instances(params, shifts) return params @classmethod - def get_known_shifts(cls, config_inst: od.Config, params: dict[str, Any]) -> tuple[set[str], set[str]]: + def get_known_shifts( + cls, + params: dict[str, Any], + shifts: TaskShifts, + ) -> None: """ - Adds set of shifts that the current ``producer_inst`` registers to the - set of known ``shifts`` and ``upstream_shifts``. + Updates the set of known *shifts* implemented by *this* and upstream tasks. - First, the set of ``shifts`` and ``upstream_shifts`` are obtained from - the *config_inst* and the current set of parameters *params* using the - ``get_known_shifts`` methods of all classes that :py:class:`ProducerMixin` - inherits from. - Afterwards, check if the current ``producer_inst`` registers shifts. - If :py:attr:`~ProducerMixin.register_producer_shifts` is ``True``, - add them to the current set of ``shifts``. Otherwise, add the - shifts obtained from the ``producer_inst`` to ``upstream_shifts``. - - :param config_inst: Config instance for the current task. - :param params: Dictionary containing the current set of parameters provided - by the user at commandline level - :return: Tuple with updated sets of ``shifts`` and ``upstream_shifts``. + :param params: Dictionary of task parameters. + :param shifts: TaskShifts object to adjust. """ - shifts, upstream_shifts = super().get_known_shifts(config_inst, params) - # get the producer, update it and add its shifts - producer_inst = params.get("producer_inst") - if producer_inst: - if cls.register_producer_shifts: - shifts |= producer_inst.all_shifts - else: - upstream_shifts |= producer_inst.all_shifts - - return shifts, upstream_shifts - - @classmethod - def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: - """ - Get the required parameters for the task, preferring the ``--producer`` set on task-level via CLI. + producer_shifts = params["producer_inst"].all_shifts + (shifts.local if cls.invokes_producer else shifts.upstream).update(producer_shifts) - This method first checks if the ``--producer`` parameter is set at the task-level via the command line. - If it is, this parameter is preferred and added to the '_prefer_cli' key in the kwargs dictionary. - The method then calls the 'req_params' method of the superclass with the updated kwargs. + super().get_known_shifts(params, shifts) - :param inst: The current task instance. - :param kwargs: Additional keyword arguments that may contain parameters for the task. - :return: A dictionary of parameters required for the task. - """ - # prefer --producer set on task-level via cli - kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"producer"} - - return super().req_params(inst, **kwargs) - - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - # cache for producer inst - self._producer_inst = None - - @property - def producer_inst(self) -> Producer: - """ - Access current :py:class:`~columnflow.production.Producer` instance. - - Loads the current :py:class:`~columnflow.production.Producer` *producer_inst* from - the cache or initializes it. - If the producer requests a specific ``sandbox``, set this sandbox as - the environment for the current :py:class:`~law.task.base.Task`. - - :return: Current :py:class:`~columnflow.production.Producer` instance - """ - if self._producer_inst is None: - self._producer_inst = self.get_producer_inst(self.producer, {"task": self}) + # overwrite the sandbox when set + if self.invokes_producer and (sandbox := self.producer_inst.get_sandbox()): + self.reset_sandbox(sandbox) - # overwrite the sandbox when set - if self.register_producer_sandbox: - sandbox = self._producer_inst.get_sandbox() - if sandbox: - self.sandbox = sandbox - # rebuild the sandbox inst when already initialized - if self._sandbox_initialized: - self._initialize_sandbox(force=True) + def _array_function_post_init(self, **kwargs) -> None: + self.producer_inst.run_post_init(task=self, **kwargs) + super()._array_function_post_init(**kwargs) - return self._producer_inst + def teardown_producer_inst(self) -> None: + if self.producer_inst: + self.producer_inst.run_teardown(task=self) @property def producer_repr(self) -> str: """ - Return a string representation of the producer. - """ - return str(self.producer_inst) if self.producer != law.NO_STR else "none" - - def store_parts(self) -> law.util.InsertableDict[str, str]: - """ - Create parts to create the output path to store intermediary results - for the current :py:class:`~law.task.base.Task`. - - Calls :py:meth:`store_parts` of the ``super`` class and inserts - `{"producer": "prod__{self.producer}"}` before keyword ``version``. - For more information, see e.g. :py:meth:`~columnflow.tasks.framework.base.ConfigTask.store_parts`. - - :return: Updated parts to create output path to store intermediary results. + Return a string representation of the producer instance. """ - parts = super().store_parts() - producer = f"prod__{self.producer_repr}" - parts.insert_before("version", "producer", producer) - return parts + return self.build_repr(self.array_function_inst_repr(self.producer_inst)) def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: """ Finds the columns to keep based on the *collection*. - This method first calls the 'find_keep_columns' method of the superclass with the given *collection*. - If the *collection* is equal to ``ALL_FROM_PRODUCER``, it adds the - columns produced by the producer instance to the set of columns. - :param collection: The collection of columns. - :return: A set of columns to keep. + :return: Set of columns to keep. """ columns = super().find_keep_columns(collection) - if collection == ColumnCollection.ALL_FROM_PRODUCER: + if collection == ColumnCollection.ALL_FROM_CALIBRATOR: columns |= self.producer_inst.produced_columns return columns + +class ProducerClassesMixin(ArrayFunctionClassMixin): + """ + Mixin to include and access multiple :py:class:`~columnflow.production.Producer` classes. + """ + + producers = law.CSVParameter( + default=(RESOLVE_DEFAULT,), + description="comma-separated names of producers to be applied; default: value of the 'default_producer' " + "analysis aux", + brace_expand=True, + parse_empty=True, + ) + + @classmethod + def resolve_param_values_pre_init( + cls, + params: law.util.InsertableDict[str, Any], + ) -> law.util.InsertableDict[str, Any]: + params = super().resolve_param_values_pre_init(params) + + # resolve the default classes if necessary + if (container := cls._get_config_container(params)): + params["producers"] = cls.resolve_config_default_and_groups( + param=params.get("producers"), + task_params=params, + container=container, + default_str="default_producer", + groups_str="producer_groups", + multi_strategy="same", + ) + + return params + + @classmethod + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: + # prefer --producers set on task-level via cli + kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"producers"} + return super().req_params(inst, **kwargs) + + @property + def producers_repr(self) -> str: + """ + Return a string representation of the producers. + """ + if not self.producers: + return "none" + return self.build_repr(list(map(self.array_function_cls_repr, self.producers))) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "producers", f"prod__{self.producers_repr}") + return parts + @classmethod def get_config_lookup_keys( cls, - inst_or_params: ProducerMixin | dict[str, Any], + inst_or_params: ProducerClassesMixin | dict[str, Any], ) -> law.util.InsertiableDict: keys = super().get_config_lookup_keys(inst_or_params) - get = ( - inst_or_params.get + # add the producer names + producers = ( + inst_or_params.get("producers") if isinstance(inst_or_params, dict) - else lambda attr: (getattr(inst_or_params, attr, None)) + else getattr(inst_or_params, "producers", None) ) - - # add the producer name - producer = get("producer") - if producer not in {law.NO_STR, None, ""}: - keys["producer"] = f"prod_{producer}" + if producers not in {law.NO_STR, None, "", ()}: + keys["producers"] = [f"prod_{producer}" for producer in producers] return keys -class ProducersMixin(ConfigTask): +class ProducersMixin(ArrayFunctionInstanceMixin, ProducerClassesMixin): """ Mixin to include multiple :py:class:`~columnflow.production.Producer` instances into tasks. - - Inheriting from this mixin will allow a task to instantiate and access a set of - :py:class:`~columnflow.production.Producer` instances with names *producers*, - which is a comma-separated list of producer names and is an input parameter for this task. """ - producers = law.CSVParameter( - default=(RESOLVE_DEFAULT,), - description="comma-separated names of producers to be applied; default: value of the " - "'default_producer' config", - brace_expand=True, - parse_empty=True, + producer_insts = DerivableInstsParameter( + default=(), + visibility=luigi.parameter.ParameterVisibility.PRIVATE, ) - # decides whether the task itself runs the producers and implements their shifts - register_producers_shifts = False + exclude_params_index = {"producer_insts"} + exclude_params_repr = {"producer_insts"} + exclude_params_sandbox = {"producer_insts"} + exclude_params_remote_workflow = {"producer_insts"} @classmethod - def get_producer_insts(cls, producers: Iterable[str], kwargs=None) -> list[Producer]: - """ - Get all requested *producers*. + def get_producer_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + return cls.get_array_function_dict(params) - :py:class:`~columnflow.production.Producer` instances are either - initalized or loaded from cache. + @classmethod + def build_producer_insts( + cls, + producers: Iterable[str], + params: dict[str, Any] | None = None, + ) -> list[Producer]: + """ + Instantiate and return multiple :py:class:`~columnflow.production.Producer` instances. - :param producers: Names of :py:class:`~columnflow.production.Producer` - instances to load - :param kwargs: Additional keyword arguments to forward to individual - :py:class:`~columnflow.production.Producer` instances - :raises RuntimeError: if requested producers are not - :py:attr:`~columnflow.production.Producer.exposed` - :return: List of :py:class:`~columnflow.production.Producer` instances. + :param producers: Name of the producer class to instantiate. + :param params: Arguments forwarded to the producer constructors. + :raises RuntimeError: If any producer class is not :py:attr:`~columnflow.production.Producer.exposed`. + :return: The list of producer instances. """ - inst_dict = cls.get_producer_kwargs(**kwargs) if kwargs else None + inst_dict = cls.get_producer_dict(params) if params else None insts = [] for producer in producers: @@ -1042,153 +1111,57 @@ def get_producer_insts(cls, producers: Iterable[str], kwargs=None) -> list[Produ return insts @classmethod - def resolve_param_values( - cls, - params: law.util.InsertableDict[str, Any], - ) -> law.util.InsertableDict[str, Any]: - """ - Resolve values *params* and check against possible default values and - producer groups. + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + # add the producer instances + if not params.get("producer_insts"): + params["producer_insts"] = cls.build_producer_insts(params["producers"], params) - Check the values in *params* against the default value ``"default_producer"`` - and possible group definitions ``"producer_groups"`` in the current config inst. - For more information, see - :py:meth:`~columnflow.tasks.framework.base.ConfigTask.resolve_config_default_and_groups`. - - :param params: Parameter values to resolve - :return: Dictionary of parameters that contains the list requested - :py:class:`~columnflow.production.Producer` instances under the - keyword ``"producer_insts"``. See :py:meth:`~.ProducersMixin.get_producer_insts` - for more information. - """ - params = super().resolve_param_values(params) - - config_inst = params.get("config_inst") - if config_inst: - params["producers"] = cls.resolve_config_default_and_groups( - params, - params.get("producers"), - container=config_inst, - default_str="default_producer", - groups_str="producer_groups", - ) - params["producer_insts"] = cls.get_producer_insts(params["producers"], params) + params = super().resolve_instances(params, shifts) return params @classmethod - def get_known_shifts(cls, config_inst: od.Config, params: dict[str, Any]) -> tuple[set[str], set[str]]: + def get_known_shifts( + cls, + params: dict[str, Any], + shifts: TaskShifts, + ) -> None: """ - Adds set of all shifts that the list of ``producer_insts`` register to the - set of known ``shifts`` and ``upstream_shifts``. - - First, the set of ``shifts`` and ``upstream_shifts`` are obtained from - the *config_inst* and the current set of parameters *params* using the - ``get_known_shifts`` methods of all classes that :py:class:`ProducersMixin` - inherits from. - Afterwards, loop through the list of :py:class:`~columnflow.production.Producer` - and check if they register shifts. - If :py:attr:`~ProducersMixin.register_producers_shifts` is ``True``, - add them to the current set of ``shifts``. Otherwise, add the - shifts to ``upstream_shifts``. + Updates the set of known *shifts* implemented by *this* and upstream tasks. - :param config_inst: Config instance for the current task. - :param params: Dictionary containing the current set of parameters provided - by the user at commandline level - :return: Tuple with updated sets of ``shifts`` and ``upstream_shifts``. + :param params: Dictionary of task parameters. + :param shifts: TaskShifts object to adjust. """ - shifts, upstream_shifts = super().get_known_shifts(config_inst, params) - # get the producers, update them and add their shifts - for producer_inst in params.get("producer_insts") or []: - if cls.register_producers_shifts: - shifts |= producer_inst.all_shifts - else: - upstream_shifts |= producer_inst.all_shifts - - return shifts, upstream_shifts - - @classmethod - def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: - """ - Get the required parameters for the task, preferring the --producers set on task-level via CLI. - - This method first checks if the --producers parameter is set at the task-level via the command line. - If it is, this parameter is preferred and added to the '_prefer_cli' key in the kwargs dictionary. - The method then calls the 'req_params' method of the superclass with the updated kwargs. - - :param inst: The current task instance. - :param kwargs: Additional keyword arguments that may contain parameters for the task. - :return: A dictionary of parameters required for the task. - """ - # prefer --producers set on task-level via cli - kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"producers"} - - return super().req_params(inst, **kwargs) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # cache for producer insts - self._producer_insts = None - - @property - def producer_insts(self) -> list[Producer]: - """ - Access current list of :py:class:`~columnflow.production.Producer` instances. + for producer_inst in params["producer_insts"]: + shifts.upstream |= producer_inst.all_shifts - Loads the current :py:class:`~columnflow.production.Producer` *producer_insts* from - the cache or initializes it. + super().get_known_shifts(params, shifts) - :return: Current list :py:class:`~columnflow.production.Producer` instances - """ - if self._producer_insts is None: - self._producer_insts = self.get_producer_insts(self.producers, {"task": self}) - return self._producer_insts + def _array_function_post_init(self, **kwargs) -> None: + for producer_inst in self.producer_insts or []: + producer_inst.run_post_init(task=self, **kwargs) + super()._array_function_post_init(**kwargs) @property def producers_repr(self) -> str: - """Return a string representation of the producers.""" - prods_repr = "none" - if self.producers: - prods_repr = "__".join([str(prod) for prod in self.producer_insts[:5]]) - if len(self.producers) > 5: - prods_repr += f"__{law.util.create_hash([str(prod) for prod in self.producer_insts[5:]])}" - return prods_repr - - def store_parts(self): """ - Create parts to create the output path to store intermediary results - for the current :py:class:`~law.task.base.Task`. - - Calls :py:meth:`store_parts` of the ``super`` class and inserts - `{"producers": "prod__{HASH}"}` before keyword ``version``. - Here, ``HASH`` is the joint string of the first five producer names - + a hash created with :py:meth:`law.util.create_hash` based on - the list of producers, starting at its 5th element (i.e. ``self.producers[5:]``) - For more information, see e.g. :py:meth:`~columnflow.tasks.framework.base.ConfigTask.store_parts`. - - :return: Updated parts to create output path to store intermediary results. + Return a string representation of the producers. """ - parts = super().store_parts() - parts.insert_before("version", "producers", f"prod__{self.producers_repr}") - - return parts + if not self.producers: + return "none" + return self.build_repr(list(map(self.array_function_inst_repr, self.producer_insts))) def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: """ Finds the columns to keep based on the *collection*. - This method first calls the 'find_keep_columns' method of the superclass with the given *collection*. - If the *collection* is equal to ``ALL_FROM_PRODUCERS``, it adds the - columns produced by all producer instances to the set of columns. - :param collection: The collection of columns. - :return: A set of columns to keep. + :return: Set of columns to keep. """ columns = super().find_keep_columns(collection) - if collection == ColumnCollection.ALL_FROM_PRODUCERS: + if collection == ColumnCollection.ALL_FROM_CALIBRATORS: columns |= set.union(*( producer_inst.produced_columns for producer_inst in self.producer_insts @@ -1197,13 +1170,12 @@ def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: return columns -class MLModelMixinBase(AnalysisTask): +class MLModelMixinBase(ConfigTask): """ Base mixin to include a machine learning application into tasks. - Inheriting from this mixin will allow a task to instantiate and access a - :py:class:`~columnflow.ml.MLModel` instance with name *ml_model*, which is an input parameter - for this task. + Inheriting from this mixin will allow a task to instantiate and access a :py:class:`~columnflow.ml.MLModel` instance + with name *ml_model*, which is an input parameter for this task. """ ml_model = luigi.Parameter( @@ -1213,15 +1185,23 @@ class MLModelMixinBase(AnalysisTask): default=DotDict(), description="settings passed to the init function of the ML model", ) + ml_model_inst = DerivableInstParameter( + default=None, + visibility=luigi.parameter.ParameterVisibility.PRIVATE, + ) + exclude_params_index = {"ml_model_inst"} + exclude_params_repr = {"ml_model_inst"} + exclude_params_sandbox = {"ml_model_inst"} + exclude_params_remote_workflow = {"ml_model_inst"} exclude_params_repr_empty = {"ml_model"} @property - def ml_model_repr(self): + def ml_model_repr(self) -> str: """ Returns a string representation of the ML model instance. """ - return str(self.ml_model_inst) + return self.build_repr(str(self.ml_model_inst)) @classmethod def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: @@ -1229,9 +1209,9 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: Get the required parameters for the task, preferring the ``--ml-model`` set on task-level via CLI. - This method first checks if the ``--ml-model`` parameter is set at the task-level via the command line. - If it is, this parameter is preferred and added to the '_prefer_cli' key in the kwargs dictionary. - The method then calls the 'req_params' method of the superclass with the updated kwargs. + This method first checks if the ``--ml-model`` parameter is set at the task-level via the command line. If it + is, this parameter is preferred and added to the '_prefer_cli' key in the kwargs dictionary. The method then + calls the 'req_params' method of the superclass with the updated kwargs. :param inst: The current task instance. :param kwargs: Additional keyword arguments that may contain parameters for the task. @@ -1253,9 +1233,8 @@ def get_ml_model_inst( """ Get requested *ml_model* instance. - This method retrieves the requested *ml_model* instance. - If *requested_configs* are provided, they are used for the training of - the ML application. + This method retrieves the requested *ml_model* instance. If *requested_configs* are provided, they are used for + the training of the ML application. :param ml_model: Name of :py:class:`~columnflow.ml.MLModel` to load. :param analysis_inst: Forward this analysis inst to the init function of new MLModel sub class. @@ -1264,7 +1243,6 @@ def get_ml_model_inst( :return: :py:class:`~columnflow.ml.MLModel` instance. """ ml_model_inst: MLModel = MLModel.get_cls(ml_model)(analysis_inst, **kwargs) - if requested_configs: configs = ml_model_inst.training_configs(list(requested_configs)) if configs: @@ -1300,410 +1278,157 @@ def events_used_in_training( ) -class MLModelTrainingMixin(MLModelMixinBase): +class MLModelTrainingMixin( + MLModelMixinBase, + CalibratorClassesMixin, + SelectorClassMixin, + ReducerClassMixin, + ProducerClassesMixin, +): """ A mixin class for training machine learning models. - - This class provides parameters for configuring the training of machine learning models. """ - configs = law.CSVParameter( - default=(), - description="comma-separated names of analysis config to use; should only contain a single " - "name in case the ml model is bound to a single config; when empty, the ml model is " - "expected to fully define the configs it uses; empty default", - brace_expand=True, - parse_empty=True, - ) - calibrators = law.MultiCSVParameter( - default=(), - description="multiple comma-separated sequences of names of calibrators to apply, " - "separated by ':'; each sequence corresponds to a config in --configs; when empty, the " - "'default_calibrator' setting of each config is used if set, or the model is expected to " - "fully define the calibrators it requires upstream; empty default", - brace_expand=True, - parse_empty=True, - ) - selectors = law.CSVParameter( - default=(), - description="comma-separated names of selectors to apply; each selector corresponds to a " - "config in --configs; when empty, the 'default_selector' setting of each config is used if " - "set, or the ml model is expected to fully define the selector it uses requires upstream; " - "empty default", - brace_expand=True, - parse_empty=True, - ) - producers = law.MultiCSVParameter( - default=(), - description="multiple comma-separated sequences of names of producers to apply, " - "separated by ':'; each sequence corresponds to a config in --configs; when empty, the " - "'default_producer' setting of each config is used if set, or ml model is expected to " - "fully define the producers it requires upstream; empty default", - brace_expand=True, - parse_empty=True, - ) - - @classmethod - def resolve_calibrators( - cls, - ml_model_inst: MLModel, - params: dict[str, Any], - ) -> tuple[tuple[str]]: - """ - Resolve the calibrators for the given ML model instance. - - This method retrieves the calibrators from the parameters *params* and - broadcasts them to the configs if necessary. - It also resolves `calibrator_groups` and `default_calibrator` from the config(s) associated - with this ML model instance, and validates the number of sequences. - Finally, it checks the retrieved calibrators against - the training calibrators of the model using - :py:meth:`~columnflow.ml.MLModel.training_calibrators` and instantiates them if necessary. - - :param ml_model_inst: The ML model instance. - :param params: A dictionary of parameters that may contain the calibrators. - :return: A tuple of tuples containing the resolved calibrators. - :raises Exception: If the number of calibrator sequences does not match - the number of configs used by the ML model. - """ - calibrators: Union[tuple[str], tuple[tuple[str]]] = params.get("calibrators") or ((),) - - # broadcast to configs - n_configs = len(ml_model_inst.config_insts) - if len(calibrators) == 1 and n_configs != 1: - calibrators = tuple(calibrators * n_configs) - - # apply calibrators_groups and default_calibrator from the config - calibrators = tuple( - ConfigTask.resolve_config_default_and_groups( - params, - calibrators[i], - container=config_inst, - default_str="default_calibrator", - groups_str="calibrator_groups", - ) - for i, config_inst in enumerate(ml_model_inst.config_insts) - ) - - # validate number of sequences - if len(calibrators) != n_configs: - raise Exception( - f"MLModel '{ml_model_inst.cls_name}' uses {n_configs} configs but received " - f"{len(calibrators)} calibrator sequences", - ) - - # final check by model - calibrators = tuple( - tuple(ml_model_inst.training_calibrators(config_inst, list(_calibrators))) - for config_inst, _calibrators in zip(ml_model_inst.config_insts, calibrators) - ) - - # instantiate them once - for config_inst, _calibrators in zip(ml_model_inst.config_insts, calibrators): - init_kwargs = law.util.merge_dicts(params, {"config_inst": config_inst}) - for calibrator in _calibrators: - CalibratorMixin.get_calibrator_inst(calibrator, kwargs=init_kwargs) - - return calibrators - - @classmethod - def resolve_selectors( - cls, - ml_model_inst: MLModel, - params: dict[str, Any], - ) -> tuple[str]: - """ - Resolve the selectors for the given ML model instance. - - This method retrieves the selectors from the parameters *params* and - broadcasts them to the configs if necessary. - It also resolves `default_selector` from the config(s) associated - with this ML model instance, validates the number of sequences. - Finally, it checks the retrieved selectors against the training selectors - of the model, using - :py:meth:`~columnflow.ml.MLModel.training_selector`, and instantiates them. - - :param ml_model_inst: The ML model instance. - :param params: A dictionary of parameters that may contain the selectors. - :return: A tuple containing the resolved selectors. - :raises Exception: If the number of selector sequences does not match - the number of configs used by the ML model. - """ - selectors = params.get("selectors") or (None,) - - # broadcast to configs - n_configs = len(ml_model_inst.config_insts) - if len(selectors) == 1 and n_configs != 1: - selectors = tuple(selectors * n_configs) - - # use config defaults - selectors = tuple( - ConfigTask.resolve_config_default( - params, - selectors[i], - container=config_inst, - default_str="default_selector", - multiple=False, - ) - for i, config_inst in enumerate(ml_model_inst.config_insts) - ) - - # validate sequence length - if len(selectors) != n_configs: - raise Exception( - f"MLModel '{ml_model_inst.cls_name}' uses {n_configs} configs but received " - f"{len(selectors)} selectors", - ) - - # final check by model - selectors = tuple( - ml_model_inst.training_selector(config_inst, selector) - for config_inst, selector in zip(ml_model_inst.config_insts, selectors) - ) - - # instantiate them once - for config_inst, selector in zip(ml_model_inst.config_insts, selectors): - init_kwargs = law.util.merge_dicts(params, {"config_inst": config_inst}) - SelectorMixin.get_selector_inst(selector, kwargs=init_kwargs) - - return selectors + single_config = False @classmethod - def resolve_producers( - cls, - ml_model_inst: MLModel, - params: dict[str, Any], - ) -> tuple[tuple[str]]: - """ - Resolve the producers for the given ML model instance. - - This method retrieves the producers from the parameters *params* and - broadcasts them to the configs if necessary. - It also resolves `producer_groups` and `default_producer` from the config(s) associated - with this ML model instance, validates the number of sequences. - Finally, it checks the retrieved producers against the training producers - of the model, using - :py:meth:`~columnflow.ml.MLModel.training_producers`, and instantiates them. - - :param ml_model_inst: The ML model instance. - :param params: A dictionary of parameters that may contain the producers. - :return: A tuple of tuples containing the resolved producers. - :raises Exception: If the number of producer sequences does not match - the number of configs used by the ML model. - """ - producers = params.get("producers") or ((),) - - # broadcast to configs - n_configs = len(ml_model_inst.config_insts) - if len(producers) == 1 and n_configs != 1: - producers = tuple(producers * n_configs) - - # apply producers_groups and default_producer from the config - producers = tuple( - ConfigTask.resolve_config_default_and_groups( - params, - producers[i], - container=config_inst, - default_str="default_producer", - groups_str="producer_groups", - ) - for i, config_inst in enumerate(ml_model_inst.config_insts) - ) - - # validate number of sequences - if len(producers) != n_configs: - raise Exception( - f"MLModel '{ml_model_inst.cls_name}' uses {n_configs} configs but received " - f"{len(producers)} producer sequences", - ) - - # final check by model - producers = tuple( - tuple(ml_model_inst.training_producers(config_inst, list(_producers))) - for config_inst, _producers in zip(ml_model_inst.config_insts, producers) - ) + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + # NOTE: we can only build TAF insts from the MLModel after ml_model_inst is set + if not cls.resolution_task_cls: + raise ValueError(f"resolution_task_cls must be set for multi-config task {cls.task_family}") + + cls.get_known_shifts(params, shifts) + + ml_model_inst = params["ml_model_inst"] + for config_inst, dataset_insts in ml_model_inst.used_datasets.items(): + for dataset_inst in dataset_insts: + # NOTE: we need to copy here, because otherwise taf inits will only be triggered once + _params = { + **params, + "config_inst": config_inst, + "config": config_inst.name, + "dataset": dataset_inst.name, + } + logger_dev.debug( + f"building taf insts for {ml_model_inst.cls_name} {config_inst.name}, {dataset_inst.name}", + ) + cls.resolution_task_cls.resolve_instances(_params, shifts) + cls.resolution_task_cls.get_known_shifts(_params, shifts) - # instantiate them once - for config_inst, _producers in zip(ml_model_inst.config_insts, producers): - init_kwargs = law.util.merge_dicts(params, {"config_inst": config_inst}) - for producer in _producers: - ProducerMixin.get_producer_inst(producer, kwargs=init_kwargs) + params["known_shifts"] = shifts - return producers + return params @classmethod - def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: """ Resolve the parameter values for the given parameters. - This method retrieves the parameters and resolves the ML model instance, configs, - calibrators, selectors, and producers. It also calls the model's setup hook. + This method retrieves the parameters and resolves the ML model instance and the configs. It also calls the + model's setup hook. :param params: A dictionary of parameters that may contain the analysis instance and ML model. :return: A dictionary containing the resolved parameters. - :raises Exception: If the ML model instance received configs to define training configs, - but did not define any. + :raises Exception: If the ML model instance received configs to define training configs, but did not define any. """ - params = super().resolve_param_values(params) + # NOTE: we need to resolve ml_model_inst before CSPs because the ml_model_inst itself defines + # used CSPs and datasets + params = super().resolve_param_values_pre_init(params) - if "analysis_inst" in params and "ml_model" in params: - analysis_inst = params["analysis_inst"] + if "analysis_inst" not in params or "ml_model" not in params: + raise ValueError("analysis_inst and ml_model need to be set to resolve the ml_model_inst") - # NOTE: we could try to implement resolving the default ml_model here - ml_model_inst = cls.get_ml_model_inst( - params["ml_model"], - analysis_inst, - parameters=params["ml_model_settings"], - ) - params["ml_model_inst"] = ml_model_inst - - # resolve configs - _configs = params.get("configs", ()) - params["configs"] = tuple(ml_model_inst.training_configs(list(_configs))) - if not params["configs"]: - raise Exception( - f"MLModel '{ml_model_inst.cls_name}' received configs '{_configs}' to define " - "training configs, but did not define any", - ) - ml_model_inst._set_configs(params["configs"]) + analysis_inst = params["analysis_inst"] - # resolve calibrators - params["calibrators"] = cls.resolve_calibrators(ml_model_inst, params) + # NOTE: we could try to implement resolving the default ml_model here + # NOTE: why not implement the config resoluting in get_ml_model_inst instead? + ml_model_inst = cls.get_ml_model_inst( + params["ml_model"], + analysis_inst, + parameters=params["ml_model_settings"], + ) + params["ml_model_inst"] = ml_model_inst - # resolve selectors - params["selectors"] = cls.resolve_selectors(ml_model_inst, params) + # resolve configs + _configs = params.get("configs", ()) + params["configs"] = tuple(ml_model_inst.training_configs(list(_configs))) + if not params["configs"]: + raise Exception( + f"MLModel '{ml_model_inst.cls_name}' received configs '{_configs}' to define training configs, but did " + "not define any", + ) + ml_model_inst._set_configs(params["configs"]) - # resolve producers - params["producers"] = cls.resolve_producers(ml_model_inst, params) + # call the model's setup hook + ml_model_inst._setup() - # call the model's setup hook - ml_model_inst._setup() + # resolve CSPs based on the MLModel + params["calibrators"] = law.util.make_tuple( + ml_model_inst.training_calibrators(analysis_inst, params["calibrators"]), + ) + params["selector"] = ml_model_inst.training_selector(analysis_inst, params["selector"]) + params["producers"] = law.util.make_tuple( + ml_model_inst.training_producers(analysis_inst, params["producers"]), + ) return params - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # get the ML model instance - self.ml_model_inst = self.get_ml_model_inst( - self.ml_model, - self.analysis_inst, - configs=list(self.configs), - parameters=self.ml_model_settings, - ) - def store_parts(self) -> law.util.InsertableDict[str, str]: """ - Generate a dictionary of store parts for the current instance. - - This method extends the base method to include additional parts related to machine learning - model configurations, calibrators, selectors, producers (CSP), and the ML model instance itself. - If the list of either of the CSPs is empty, the corresponding part is set to ``"none"``, - otherwise, the first two elements of the list are joined with ``"__"``. - If the list of either of the CSPs contains more than two elements, the part is extended - with the number of elements and a hash of the remaining elements, which is - created with :py:meth:`law.util.create_hash`. - The parts are represented as strings and are used to create unique identifiers for the - instance's output. + Generate a dictionary of store parts for the current instance. This method extends the base method to include + the ML model parameter. :return: An InsertableDict containing the store parts. """ parts = super().store_parts() - # since MLTraining is no CalibratorsMixin, SelectorMixin, ProducerMixin, ConfigTask, - # all these parts are missing in the `store_parts` - - configs_repr = "__".join(self.configs[:5]) - - if len(self.configs) > 5: - configs_repr += f"_{law.util.create_hash(self.configs[5:])}" - - parts.insert_after("task_family", "configs", configs_repr) - - for label, fct_names in [ - ("calib", self.calibrators), - ("sel", tuple((sel,) for sel in self.selectors)), - ("prod", self.producers), - ]: - if not fct_names or not any(fct_names): - fct_names = ["none"] - elif len(set(fct_names)) == 1: - # when functions are the same per config, only use them once - fct_names = fct_names[0] - n_fct_per_config = str(len(fct_names)) - else: - # when functions differ between configs, flatten - n_fct_per_config = "".join(str(len(x)) for x in fct_names) - fct_names = tuple(fct_name for fct_names_cfg in fct_names for fct_name in fct_names_cfg) - - part = "__".join(fct_names[:2]) - - if len(fct_names) > 2: - part += f"_{n_fct_per_config}_{law.util.create_hash(fct_names[2:])}" - - parts.insert_before("version", label, f"{label}__{part}") - if self.ml_model_inst: parts.insert_before("version", "ml_model", f"ml__{self.ml_model_repr}") return parts -class MLModelMixin(ConfigTask, MLModelMixinBase): +class MLModelMixin(MLModelMixinBase): + """ + A mixin for tasks that require a single machine learning model, e.g. for evaluation. + """ ml_model = luigi.Parameter( default=RESOLVE_DEFAULT, description="the name of the ML model to be applied; default: value of the " - "'default_ml_model' config", + "'default_ml_model' analysis aux", ) allow_empty_ml_model = True exclude_params_repr_empty = {"ml_model"} - @classmethod - def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: - params = super().resolve_param_values(params) - - # add the default ml model when empty - if "analysis_inst" in params and "config_inst" in params: - analysis_inst = params["analysis_inst"] - config_inst = params["config_inst"] - - params["ml_model"] = cls.resolve_config_default( - params, - params.get("ml_model"), - container=config_inst, - default_str="default_ml_model", - multiple=False, - ) - - # initialize it once to trigger its set_config hook which might, in turn, - # add objects to the config itself - if params.get("ml_model") not in (None, law.NO_STR): + @classmethod + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_pre_init(params) + + # # add the default ml model when empty + params["ml_model"] = cls.resolve_config_default( + param=params.get("ml_model"), + task_params=params, + container=params["analysis_inst"], + default_str="default_ml_model", + multi_strategy="same", + ) + + # when both config_inst and ml_model are set, initialize the ml_model_inst + if all(params.get(x) not in {None, law.NO_STR} for x in ("config_inst", "ml_model")): + if not params.get("ml_model_inst"): params["ml_model_inst"] = cls.get_ml_model_inst( params["ml_model"], - analysis_inst, - requested_configs=[config_inst], - parameters=params["ml_model_settings"], + params["analysis_inst"], + requested_configs=[params["config_inst"]], ) - elif not cls.allow_empty_ml_model: - raise Exception(f"no ml_model configured for {cls.task_family}") + elif not cls.allow_empty_ml_model: + raise Exception(f"no ml_model configured for {cls.task_family}") return params - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # get the ML model instance - self.ml_model_inst = None - if self.ml_model != law.NO_STR: - self.ml_model_inst = self.get_ml_model_inst( - self.ml_model, - self.analysis_inst, - requested_configs=[self.config_inst], - parameters=self.ml_model_settings, - ) - def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() @@ -1721,8 +1446,55 @@ def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: return columns -class MLModelDataMixin(MLModelMixin): +class PreparationProducerMixin(ArrayFunctionInstanceMixin, MLModelMixin): + + preparation_producer_inst = DerivableInstParameter( + default=None, + visibility=luigi.parameter.ParameterVisibility.PRIVATE, + ) + + exclude_params_index = {"preparation_producer_inst"} + exclude_params_repr = {"preparation_producer_inst"} + exclude_params_sandbox = {"preparation_producer_inst"} + exclude_params_remote_workflow = {"preparation_producer_inst"} + + @classmethod + def invokes_preparation_producer(cls, params) -> bool: + return False + + @classmethod + def get_producer_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + return cls.get_array_function_dict(params) + + build_producer_inst = ProducerMixin.build_producer_inst + + def _array_function_post_init(self, **kwargs) -> None: + if self.preparation_producer_inst: + self.preparation_producer_inst.run_post_init(task=self, **kwargs) + super()._array_function_post_init(**kwargs) + + def teardown_preparation_producer_inst(self) -> None: + if self.preparation_producer_inst: + self.preparation_producer_inst.run_teardown(task=self) + + @classmethod + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + ml_model_inst = params["ml_model_inst"] + + if cls.invokes_preparation_producer(params): + preparation_producer = ml_model_inst.preparation_producer(params["analysis_inst"]) + # add the producer instance + if preparation_producer and not params.get("preparation_producer_inst"): + params["preparation_producer_inst"] = cls.build_producer_inst(preparation_producer, params) + + params = super().resolve_instances(params, shifts) + + return params + + +class MLModelDataMixin(PreparationProducerMixin): + single_config = True allow_empty_ml_model = False def store_parts(self) -> law.util.InsertableDict: @@ -1740,36 +1512,36 @@ class MLModelsMixin(ConfigTask): ml_models = law.CSVParameter( default=(RESOLVE_DEFAULT,), - description="comma-separated names of ML models to be applied; default: value of the " - "'default_ml_model' config", + description="comma-separated names of ML models to be applied; default: value of the 'default_ml_model' config", brace_expand=True, parse_empty=True, ) - allow_empty_ml_models = True - exclude_params_repr_empty = {"ml_models"} + allow_empty_ml_models = True + @property - def ml_models_repr(self): - """Returns a string representation of the ML models.""" - ml_models_repr = "__".join([str(model_inst) for model_inst in self.ml_model_insts]) - return ml_models_repr + def ml_models_repr(self) -> str: + """ + Returns a string representation of the ML models. + """ + return self.build_repr(tuple(map(str, self.ml_model_insts))) @classmethod - def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: - params = super().resolve_param_values(params) + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + # NOTE: at the moment, the ml_models will be initialized before CSPs are initialized + params = super().resolve_param_values_pre_init(params) - analysis_inst = params.get("analysis_inst") - config_inst = params.get("config_inst") - if analysis_inst and config_inst: + if (container := cls._get_config_container(params)): # apply ml_model_groups and default_ml_model from the config params["ml_models"] = cls.resolve_config_default_and_groups( - params, - params.get("ml_models"), - container=config_inst, + param=params.get("ml_models"), + task_params=params, + container=container, default_str="default_ml_model", groups_str="ml_model_groups", + multi_strategy="same", ) # special case: initialize them once to trigger their set_config hook @@ -1777,8 +1549,8 @@ def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params["ml_model_insts"] = [ MLModelMixinBase.get_ml_model_inst( ml_model, - analysis_inst, - requested_configs=[config_inst], + params["analysis_inst"], + requested_configs=[params["config"]] if cls.has_single_config() else params["configs"], ) for ml_model in params["ml_models"] ] @@ -1794,18 +1566,24 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict: return super().req_params(inst, **kwargs) + @property + def ml_model_insts(self) -> list[MLModel]: + if self._ml_model_insts is None: + self._ml_model_insts = [ + MLModelMixinBase.get_ml_model_inst( + ml_model, + self.analysis_inst, + requested_configs=[self.config] if self.single_config else self.config, + ) + for ml_model in self.ml_models + ] + return self._ml_model_insts + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # get the ML model instances - self.ml_model_insts = [ - MLModelMixinBase.get_ml_model_inst( - ml_model, - self.analysis_inst, - requested_configs=[self.config_inst], - ) - for ml_model in self.ml_models - ] + # cache for ml model insts + self._ml_model_insts = None def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() @@ -1827,35 +1605,207 @@ def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: return columns -class InferenceModelMixin(ConfigTask): +class HistProducerClassMixin(ArrayFunctionClassMixin): + """ + Mixin to include and access single :py:class:`~columnflow.histogramming.HistProducer` class. + """ + + hist_producer = luigi.Parameter( + default=RESOLVE_DEFAULT, + description="the name of the hist producer to be applied; default: value of the 'default_hist_producer' config", + ) + + @classmethod + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_pre_init(params) + + # resolve the default class if necessary + if (container := cls._get_config_container(params)): + params["hist_producer"] = cls.resolve_config_default( + param=params.get("hist_producer"), + task_params=params, + container=container, + default_str="default_hist_producer", + multi_strategy="same", + ) + + # !! to be removed in a future release + if not params["hist_producer"]: + # fallback to cf's default and trigger a verbose warning + params["hist_producer"] = "cf_default" + docs_url = get_docs_url("user_guide", "02_03_transition.html") + code_url = get_code_url("columnflow", "histogramming", "default.py") + logger.warning_once( + "hist_producer_undefined", + "the resolution of the '--hist-producer' parameter resulted in an empty value, most likely caused " + f"by a missing auxiliary field 'default_hist_producer' in your configuration; see {docs_url} for " + f"more information; using '{params['hist_producer']}' ({code_url}) as a fallback", + ) + + return params + + @classmethod + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: + # prefer --hist-producer set on task-level via cli + kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"hist_producer"} + return super().req_params(inst, **kwargs) + + @property + def hist_producer_repr(self) -> str: + """ + Return a string representation of the hist producer class. + """ + return self.build_repr(self.array_function_cls_repr(self.hist_producer)) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "hist_producer", f"hist__{self.hist_producer_repr}") + return parts + + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: HistProducerClassMixin | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) + + # add the hist producer name + producer = ( + inst_or_params.get("hist_producer") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "hist_producer", None) + ) + if producer not in (law.NO_STR, None, ""): + keys["hist_producer"] = f"hist_{producer}" + + return keys + + +class HistProducerMixin(ArrayFunctionInstanceMixin, HistProducerClassMixin): + """ + Mixin to include and access a single :py:class:`~columnflow.histogramming.HistProducer` instance. + """ + + hist_producer_inst = DerivableInstParameter( + default=None, + visibility=luigi.parameter.ParameterVisibility.PRIVATE, + ) + + exclude_params_index = {"hist_producer_inst"} + exclude_params_repr = {"hist_producer_inst"} + exclude_params_sandbox = {"hist_producer_inst"} + exclude_params_remote_workflow = {"hist_producer_inst"} + + # decides whether the task itself invokes the hist_producer + invokes_hist_producer = False + + @classmethod + def get_hist_producer_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + return cls.get_array_function_dict(params) + + @classmethod + def build_hist_producer_inst( + cls, + hist_producer: str, + params: dict[str, Any] | None = None, + ) -> Producer: + """ + Instantiate and return the :py:class:`~columnflow.histogramming.HistProducer` instance. + + :param producer: Name of the hist producer class to instantiate. + :param params: Arguments forwarded to the hist producer constructor. + :raises RuntimeError: If the hist producer class is not + :py:attr:`~columnflow.histogramming.HistProducer.exposed`. + :return: The hist producer instance. + """ + hist_producer_cls = HistProducer.get_cls(hist_producer) + if not hist_producer_cls.exposed: + raise RuntimeError(f"cannot use unexposed hist_producer '{hist_producer}' in {cls.__name__}") + + inst_dict = cls.get_hist_producer_dict(params) if params else None + return hist_producer_cls(inst_dict=inst_dict) + + @classmethod + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + # add the hist producer instance + if not params.get("hist_producer_inst"): + params["hist_producer_inst"] = cls.build_hist_producer_inst( + params["hist_producer"], + params, + ) + + params = super().resolve_instances(params, shifts) + + return params + + @classmethod + def get_known_shifts( + cls, + params: dict[str, Any], + shifts: TaskShifts, + ) -> None: + """ + Updates the set of known *shifts* implemented by *this* and upstream tasks. + + :param params: Dictionary of task parameters. + :param shifts: TaskShifts object to adjust. + """ + # get the hist producer, update it and add its shifts + hist_producer_shifts = params["hist_producer_inst"].all_shifts + (shifts.local if cls.invokes_hist_producer else shifts.upstream).update(hist_producer_shifts) + + super().get_known_shifts(params, shifts) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # overwrite the sandbox when set + if self.invokes_hist_producer and (sandbox := self.hist_producer_inst.get_sandbox()): + self.reset_sandbox(sandbox) + + def _array_function_post_init(self, **kwargs) -> None: + self.hist_producer_inst.run_post_init(task=self, **kwargs) + super()._array_function_post_init(**kwargs) + + def teardown_hist_producer_inst(self) -> None: + if self.hist_producer_inst: + self.hist_producer_inst.run_teardown(task=self) + + @property + def hist_producer_repr(self) -> str: + """ + Return a string representation of the hist producer instance. + """ + return self.build_repr(self.array_function_inst_repr(self.hist_producer_inst)) + + +class InferenceModelClassMixin(ConfigTask): inference_model = luigi.Parameter( default=RESOLVE_DEFAULT, - description="the name of the inference model to be used; default: value of the " - "'default_inference_model' config", + description="the name of the inference model to be used; default: value of the 'default_inference_model' " + "config", ) @classmethod - def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: - params = super().resolve_param_values(params) + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_pre_init(params) # add the default inference model when empty - config_inst = params.get("config_inst") - if config_inst: + if (container := cls._get_config_container(params)): params["inference_model"] = cls.resolve_config_default( - params, - params.get("inference_model"), - container=config_inst, + param=params.get("inference_model"), + task_params=params, + container=container, default_str="default_inference_model", - multiple=False, + multi_strategy="same", ) return params - @classmethod - def get_inference_model_inst(cls, inference_model: str, config_inst: od.Config) -> InferenceModel: - return InferenceModel.get_cls(inference_model)(config_inst) - @classmethod def req_params(cls, inst: law.Task, **kwargs) -> dict: # prefer --inference-model set on task-level via cli @@ -1863,30 +1813,106 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict: return super().req_params(inst, **kwargs) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # get the inference model instance - self.inference_model_inst = self.get_inference_model_inst(self.inference_model, self.config_inst) - @property def inference_model_repr(self): return str(self.inference_model) def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ parts = super().store_parts() if self.inference_model != law.NO_STR: - parts.insert_before("version", "inf_model", f"inf__{self.inference_model_repr}") + parts.insert_after(self.config_store_anchor, "inf_model", f"inf__{self.inference_model_repr}") return parts +class InferenceModelMixin(InferenceModelClassMixin): + + inference_model_inst = DerivableInstParameter( + default=None, + visibility=luigi.parameter.ParameterVisibility.PRIVATE, + ) + + exclude_params_index = {"inference_model_inst"} + exclude_params_repr = {"inference_model_inst"} + exclude_params_sandbox = {"inference_model_inst"} + exclude_params_remote_workflow = {"inference_model_inst"} + + @classmethod + def build_inference_model_inst( + cls, + inference_model: str, + config_insts: list[od.Config], + **kwargs, + ) -> InferenceModel: + """ + Instantiate and return the :py:class:`~columnflow.inference.InferenceModel` instance. + + :param inference_model: Name of the inference model class to instantiate. + :param config_insts: List of configuration objects that are passed to the inference model constructor. + :param kwargs: Additional keywork arguments forwarded to the inference model constructor. + :return: The inference model instance. + """ + inference_model_cls = InferenceModel.get_cls(inference_model) + return inference_model_cls(config_insts, **kwargs) + + @classmethod + def resolve_param_values_post_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_post_init(params) + + # add the inference model instance + if not params.get("inference_model_inst") and params.get("inference_model"): + if cls.has_single_config(): + if (config_inst := params.get("config_inst")): + params["inference_model_inst"] = cls.build_inference_model_inst( + params["inference_model"], + [config_inst], + ) + elif (config_insts := params.get("config_insts")): + params["inference_model_inst"] = cls.build_inference_model_inst( + params["inference_model"], + config_insts, + ) + + return params + + @classmethod + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + if not cls.resolution_task_cls: + raise ValueError(f"resolution_task_cls must be set for multi-config task {cls.task_family}") + + cls.get_known_shifts(params, shifts) + + # we loop over all configs/datasets, but return initial params + inference_model_cls = InferenceModel.get_cls(params["inference_model"]) + for i, config_inst in enumerate(params["config_insts"]): + datasets = inference_model_cls.used_datasets(config_inst) + + for dataset in datasets: + # NOTE: we need to copy here, because otherwise taf inits will only be triggered once + _params = { + **params, + "config_inst": config_inst, + "config": config_inst.name, + "dataset": dataset, + } + logger_dev.debug(f"building taf insts for {config_inst.name}, {dataset}") + cls.resolution_task_cls.resolve_instances(_params, shifts) + cls.resolution_task_cls.get_known_shifts(_params, shifts) + + params["known_shifts"] = shifts + + return params + + class CategoriesMixin(ConfigTask): categories = law.CSVParameter( - default=(), - description="comma-separated category names or patterns to select; can also be the key of " - "a mapping defined in 'category_groups' auxiliary data of the config; when empty, uses the " - "auxiliary data enty 'default_categories' when set; empty default", + default=(RESOLVE_DEFAULT,), + description="comma-separated category names or patterns to select; can also be the key of a mapping defined in " + "'category_groups' auxiliary data of the config; when empty, uses the auxiliary data enty 'default_categories' " + "when set; empty default", brace_expand=True, parse_empty=True, ) @@ -1895,31 +1921,36 @@ class CategoriesMixin(ConfigTask): allow_empty_categories = False @classmethod - def resolve_param_values(cls, params): - params = super().resolve_param_values(params) - - if "config_inst" not in params: + def resolve_param_values_post_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_post_init(params) + if "analysis_inst" not in params or "config_insts" not in params: return params - config_inst = params["config_inst"] # resolve categories - if "categories" in params: - # when empty, use the config default - if not params["categories"] and config_inst.x("default_categories", ()): - params["categories"] = tuple(config_inst.x.default_categories) - - # when still empty and default categories are defined, use them instead - if not params["categories"] and cls.default_categories: - params["categories"] = tuple(cls.default_categories) - - # resolve them - categories = cls.find_config_objects( - params["categories"], - config_inst, - od.Category, - config_inst.x("category_groups", {}), - deep=True, - ) + if (categories := params.get("categories", law.no_value)) != law.no_value: + # when empty, use the ones defined on class level + if categories in ((), (RESOLVE_DEFAULT,)) and cls.default_categories: + categories = tuple(cls.default_categories) + + # additional resolution and expansion requires a config + if (container := cls._get_config_container(params)): + # when still empty, get the config default + categories = cls.resolve_config_default( + param=params.get("categories"), + task_params=params, + container=container, + default_str="default_categories", + multi_strategy="union", + ) + # resolve them + categories = cls.find_config_objects( + names=categories, + container=container, + object_cls=od.Category, + groups_str="category_groups", + deep=True, + multi_strategy="intersection", + ) # complain when no categories were found if not categories and not cls.allow_empty_categories: @@ -1930,20 +1961,19 @@ def resolve_param_values(cls, params): return params @property - def categories_repr(self): + def categories_repr(self) -> str: if len(self.categories) == 1: - return self.categories[0] - - return f"{len(self.categories)}_{law.util.create_hash(sorted(self.categories))}" + return self.build_repr(self.categories[0]) + return self.build_repr(self.categories, prepend_count=True) class VariablesMixin(ConfigTask): variables = law.CSVParameter( - default=(), - description="comma-separated variable names or patterns to select; can also be the key of " - "a mapping defined in the 'variable_group' auxiliary data of the config; when empty, uses " - "all variables of the config; empty default", + default=(RESOLVE_DEFAULT,), + description="comma-separated variable names or patterns to select; can also be the key of a mapping defined in " + "the 'variable_group' auxiliary data of the config; when empty, uses all variables of the config; empty " + "default", brace_expand=True, parse_empty=True, ) @@ -1953,64 +1983,49 @@ class VariablesMixin(ConfigTask): allow_missing_variables = False @classmethod - def resolve_param_values(cls, params): - params = super().resolve_param_values(params) + def resolve_param_values_post_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_post_init(params) - if "config_inst" not in params: + if "analysis_inst" not in params or "config_insts" not in params: return params - config_inst = params["config_inst"] # resolve variables - if "variables" in params: - # when empty, use the config default - if not params["variables"] and config_inst.x("default_variables", ()): - params["variables"] = tuple(config_inst.x.default_variables) - - # when still empty and default variables are defined, use them instead - if not params["variables"] and cls.default_variables: - params["variables"] = tuple(cls.default_variables) - - # resolve them - if params["variables"]: - # first, split into single- and multi-dimensional variables - single_vars = [] - multi_var_parts = [] - for variable in params["variables"]: - parts = cls.split_multi_variable(variable) - if len(parts) == 1: - single_vars.append(variable) - else: - multi_var_parts.append(parts) - - # resolve single variables - variables = cls.find_config_objects( - single_vars, - config_inst, - od.Variable, - config_inst.x("variable_groups", {}), - strict=not cls.allow_missing_variables, + if (variables := params.get("variables", law.no_value)) != law.no_value: + # when empty, use the ones defined on class level + if variables in {(), (RESOLVE_DEFAULT,)} and cls.default_variables: + variables = tuple(cls.default_variables) + + # additional resolution and expansion requires a config + if (container := cls._get_config_container(params)): + # when still empty, get the config default + variables = cls.resolve_config_default_and_groups( + param=params.get("variables"), + task_params=params, + container=container, + default_str="default_variables", + groups_str="variable_groups", + multi_strategy="union", ) - - # for each multi-variable, resolve each part separately and create the full - # combinatorics of all possibly pattern-resolved parts - for parts in multi_var_parts: + # since there can be multi-dimensional variables, resolve each part separately + resolved_variables = set() + for variable in variables: resolved_parts = [ cls.find_config_objects( - part, - config_inst, - od.Variable, - config_inst.x("variable_groups", {}), - strict=not cls.allow_missing_variables, + names=part, + container=container, + object_cls=od.Variable, + groups_str="variable_groups", + multi_strategy="intersection", ) - for part in parts + for part in cls.split_multi_variable(variable) ] - variables.extend([ - cls.join_multi_variable(_parts) - for _parts in itertools.product(*resolved_parts) - ]) - else: - # fallback to using all known variables - variables = config_inst.variables.names() + # build combinatrics + resolved_variables.update(map(cls.join_multi_variable, itertools.product(*resolved_parts))) + variables = resolved_variables + + # when still empty, fallback to using all known variables + if not variables: + variables = sorted(set.intersection(*(set(c.variables.names()) for c in law.util.make_list(container)))) # complain when no variables were found if not variables and not cls.allow_empty_variables: @@ -2036,7 +2051,7 @@ def join_multi_variable(cls, variables: Sequence[str]) -> str: """ return "-".join(map(str, variables)) - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # if enabled, split names of multi-dimensional parameters into tuples @@ -2046,29 +2061,45 @@ def __init__(self, *args, **kwargs): } @property - def variables_repr(self): + def variables_repr(self) -> str: if len(self.variables) == 1: - return self.variables[0] - - return f"{len(self.variables)}_{law.util.create_hash(sorted(self.variables))}" + return self.build_repr(self.variables[0]) + return self.build_repr(sorted(self.variables), prepend_count=True) class DatasetsProcessesMixin(ConfigTask): datasets = law.CSVParameter( default=(), - description="comma-separated dataset names or patters to select; can also be the key of a " - "mapping defined in the 'dataset_groups' auxiliary data of the config; when empty, uses " - "all datasets registered in the config that contain any of the selected --processes; empty " - "default", + description="comma-separated dataset names or patters to select; can also be the key of a mapping defined in " + "the 'dataset_groups' auxiliary data of the config; when empty, uses all datasets registered in the config " + "that contain any of the selected --processes; empty default", + brace_expand=True, + parse_empty=True, + ) + datasets_multi = law.MultiCSVParameter( + default=(), + description="multiple comma-separated dataset names or patters to select per config object, each separated by " + "a colon; when only one sequence is passed, it is applied to all configs; values can also be the key of a " + "mapping defined in " "the 'dataset_groups' auxiliary data of the specific config; when empty, uses all " + "datasets registered in the config that contain any of the selected --processes; empty default", brace_expand=True, parse_empty=True, ) processes = law.CSVParameter( default=(), - description="comma-separated process names or patterns for filtering processes; can also " - "be the key of a mapping defined in the 'process_groups' auxiliary data of the config; " - "uses all processes of the config when empty; empty default", + description="comma-separated process names or patterns for filtering processes; can also be the key of a " + "mapping defined in the 'process_groups' auxiliary data of the config; uses all processes of the config when " + "empty; empty default", + brace_expand=True, + parse_empty=True, + ) + processes_multi = law.MultiCSVParameter( + default=(), + description="multiple comma-separated process names or patters for filtering processing per config object, " + "each separated by a colon; when only one sequence is passed, it is applied to all configs; values can also be " + "the key of a mapping defined in the 'process_groups' auxiliary data of the specific config; uses all " + "processes of the config when empty; empty default", brace_expand=True, parse_empty=True, ) @@ -2077,96 +2108,158 @@ class DatasetsProcessesMixin(ConfigTask): allow_empty_processes = False @classmethod - def resolve_param_values(cls, params): - params = super().resolve_param_values(params) + def modify_task_attributes(cls) -> None: + super().modify_task_attributes() + # single/multi config adjustments in case the switch has been specified + if isinstance(cls.single_config, bool) and getattr(cls, "datasets_multi", None) is not None: + if not cls.has_single_config(): + cls.datasets = cls.datasets_multi + cls.processes = cls.processes_multi + cls.datasets_multi = None + cls.processes_multi = None - if "config_inst" not in params: - return params - config_inst = params["config_inst"] - - # resolve processes - if "processes" in params: - if params["processes"]: - processes = cls.find_config_objects( - params["processes"], - config_inst, - od.Process, - config_inst.x("process_groups", {}), - deep=True, - ) - else: - processes = config_inst.processes.names() - - # complain when no processes were found - if not processes and not cls.allow_empty_processes: - raise ValueError(f"no processes found matching {params['processes']}") - - params["processes"] = tuple(processes) - params["process_insts"] = [config_inst.get_process(p) for p in params["processes"]] - - # resolve datasets - if "datasets" in params: - if params["datasets"]: - datasets = cls.find_config_objects( - params["datasets"], - config_inst, - od.Dataset, - config_inst.x("dataset_groups", {}), - ) - elif "processes" in params: - # pick all datasets that contain any of the requested (sub) processes - sub_process_insts = sum(( - [proc for proc, _, _ in process_inst.walk_processes(include_self=True)] - for process_inst in map(config_inst.get_process, params["processes"]) - ), []) - datasets = [ - dataset_inst.name - for dataset_inst in config_inst.datasets - if any(map(dataset_inst.has_process, sub_process_insts)) - ] + @classmethod + def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_pre_init(params) + + # helper to resolve processes and datasets for one config + def resolve(config_inst: od.Config, processes: Any, datasets: Any) -> tuple[list[str], list[str]]: + if processes != law.no_value: + processes_orig = processes + if processes: + processes = cls.find_config_objects( + names=processes, + container=config_inst, + object_cls=od.Process, + groups_str="process_groups", + deep=True, + ) + else: + processes = config_inst.processes.names() + if not processes and not cls.allow_empty_processes: + raise ValueError(f"no processes found matching {processes_orig}") + if datasets != law.no_value: + datasets_orig = datasets + if datasets: + datasets = cls.find_config_objects( + names=datasets, + container=config_inst, + object_cls=od.Dataset, + groups_str="dataset_groups", + ) + elif processes and processes != law.no_value: + # pick all datasets that contain any of the requested (sub)processes + sub_process_insts = sum(( + [proc for proc, _, _ in process_inst.walk_processes(include_self=True)] + for process_inst in map(config_inst.get_process, processes) + ), []) + datasets = [ + dataset_inst.name for dataset_inst in config_inst.datasets + if any(map(dataset_inst.has_process, sub_process_insts)) + ] + if not datasets and not cls.allow_empty_datasets: + raise ValueError(f"no datasets found matching {datasets_orig}") + + return (processes, datasets) + + # get processes and datasets + single_config = cls.has_single_config() + processes = (params.get("processes", law.no_value),) if single_config else params.get("processes", ()) + datasets = (params.get("datasets", law.no_value),) if single_config else params.get("datasets", ()) + + # "broadcast" to match number of configs + config_insts = params.get("config_insts") + processes = cls.broadcast_to_configs(processes, "processes", len(config_insts)) + datasets = cls.broadcast_to_configs(datasets, "datasets", len(config_insts)) + + # perform resolution per config + multi_processes = [] + multi_datasets = [] + for config_inst, _processes, _datasets in zip(config_insts, processes, datasets): + _processes, _datasets = resolve(config_inst, _processes, _datasets) + multi_processes.append(tuple(_processes) if _processes != law.no_value else None) + multi_datasets.append(tuple(_datasets) if _datasets != law.no_value else None) + + # store params + params["processes"] = multi_processes[0] if single_config else tuple(multi_processes) + params["datasets"] = multi_datasets[0] if single_config else tuple(multi_datasets) + + # store instances + params["process_insts"] = { + config_inst: list(map(config_inst.get_process, processes)) + for config_inst, processes in zip(config_insts, multi_processes) + } + params["dataset_insts"] = { + config_inst: list(map(config_inst.get_dataset, datasets)) + for config_inst, datasets in zip(config_insts, multi_datasets) + } + return params + + @classmethod + def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: + if not cls.resolution_task_cls: + raise ValueError(f"resolution_task_cls must be set for multi-config task {cls.task_family}") - # complain when no datasets were found - if not datasets and not cls.allow_empty_datasets: - raise ValueError(f"no datasets found matching {params['datasets']}") + cls.get_known_shifts(params, shifts) - params["datasets"] = tuple(datasets) - params["dataset_insts"] = [config_inst.get_dataset(d) for d in params["datasets"]] + # we loop over all configs/datasets, but return initial params + for i, config_inst in enumerate(params["config_insts"]): + if cls.has_single_config(): + datasets = params["datasets"] + else: + datasets = params["datasets"][i] + + for dataset in datasets: + # NOTE: we need to copy here, because otherwise taf inits will only be triggered once + _params = { + **params, + "config_inst": config_inst, + "config": config_inst.name, + "dataset": dataset, + } + logger_dev.debug(f"building taf insts for {config_inst.name}, {dataset}") + cls.resolution_task_cls.resolve_instances(_params, shifts) + cls.resolution_task_cls.get_known_shifts(_params, shifts) + + params["known_shifts"] = shifts return params @classmethod - def get_known_shifts(cls, config_inst, params): - shifts, upstream_shifts = super().get_known_shifts(config_inst, params) + def get_known_shifts( + cls, + params: dict[str, Any], + shifts: TaskShifts, + ) -> None: + """ + Updates the set of known *shifts* implemented by *this* and upstream tasks. + :param params: Dictionary of task parameters. + :param shifts: TaskShifts object to adjust. + """ # add shifts of all datasets to upstream ones - for dataset_inst in params.get("dataset_insts") or []: - if dataset_inst.is_mc: - upstream_shifts |= set(dataset_inst.info.keys()) + for config_inst, dataset_insts in params["dataset_insts"].items(): + for dataset_inst in dataset_insts: + if dataset_inst.is_mc: + shifts.upstream |= set(dataset_inst.info.keys()) - return shifts, upstream_shifts + super().get_known_shifts(params, shifts) @property - def datasets_repr(self): - if len(self.datasets) == 1: - return self.datasets[0] - - return f"{len(self.datasets)}_{law.util.create_hash(sorted(self.datasets))}" + def datasets_repr(self) -> str: + return self._multi_sequence_repr(self.datasets, sort=True) @property - def processes_repr(self): - if len(self.processes) == 1: - return self.processes[0] - - return f"{len(self.processes)}_{law.util.create_hash(self.processes)}" + def processes_repr(self) -> str: + return self._multi_sequence_repr(self.processes, sort=True) class ShiftSourcesMixin(ConfigTask): shift_sources = law.CSVParameter( default=(), - description="comma-separated shift source names (without direction) or patterns to select; " - "can also be the key of a mapping defined in the 'shift_group' auxiliary data of the " - "config; default: ()", + description="comma-separated shift source names (without direction) or patterns to select; can also be the key " + "of a mapping defined in the 'shift_group' auxiliary data of the config; default: ()", brace_expand=True, parse_empty=True, ) @@ -2174,29 +2267,41 @@ class ShiftSourcesMixin(ConfigTask): allow_empty_shift_sources = False @classmethod - def resolve_param_values(cls, params): - params = super().resolve_param_values(params) - - if "config_inst" not in params: - return params - config_inst = params["config_inst"] + def resolve_param_values_post_init(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values_post_init(params) # resolve shift sources - if "shift_sources" in params: - # convert to full shift first to do the object finding + if (container := cls._get_config_container(params)) and "shift_sources" in params: shifts = cls.find_config_objects( - cls.expand_shift_sources(params["shift_sources"]), - config_inst, - od.Shift, - config_inst.x("shift_groups", {}), + names=cls.expand_shift_sources(params["shift_sources"]), + container=container, + object_cls=od.Shift, + groups_str="shift_groups", + multi_strategy="union", # or "intersection"? ) - # complain when no shifts were found - if not shifts and not cls.allow_empty_shift_sources: + # convert back to sources and validate + sources = [] + if shifts: + sources = cls.reduce_shifts(shifts) + + # # reduce shifts based on known shifts + if "known_shifts" not in params: + raise ValueError("known_shifts must be set before resolving shift sources") + sources = [ + source for source in sources + if ( + f"{source}_up" in params["known_shifts"].upstream and + f"{source}_down" in params["known_shifts"].upstream + ) + ] + + # complain when no sources were found + if not sources and not cls.allow_empty_shift_sources: raise ValueError(f"no shifts found matching {params['shift_sources']}") - # convert back to sources - params["shift_sources"] = tuple(cls.reduce_shifts(shifts)) + # store them + params["shift_sources"] = tuple(sources) return params @@ -2208,123 +2313,37 @@ def expand_shift_sources(cls, sources: Sequence[str] | set[str]) -> list[str]: def reduce_shifts(cls, shifts: Sequence[str] | set[str]) -> list[str]: return list(set(od.Shift.split_name(shift)[0] for shift in shifts)) - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.shifts = self.expand_shift_sources(self.shift_sources) @property - def shift_sources_repr(self): + def shift_sources_repr(self) -> str: + if not self.shift_sources: + return "none" if len(self.shift_sources) == 1: - return self.shift_sources[0] - - return f"{len(self.shift_sources)}_{law.util.create_hash(sorted(self.shift_sources))}" - - -class WeightProducerMixin(ConfigTask): - - weight_producer = luigi.Parameter( - default=RESOLVE_DEFAULT, - description="the name of the weight producer to be used; default: value of the " - "'default_weight_producer' config", - ) - - # decides whether the task itself runs the weight producer and implements its shifts - register_weight_producer_sandbox = False - register_weight_producer_shifts = False - - @classmethod - def get_weight_producer_inst( - cls, - weight_producer: str, - kwargs: dict | None = None, - ) -> WeightProducer: - weight_producer_cls = WeightProducer.get_cls(weight_producer) - if not weight_producer_cls.exposed: - raise RuntimeError( - f"cannot use unexposed weight producer '{weight_producer}' in {cls.__name__}", - ) - - inst_dict = cls.get_weight_producer_kwargs(**kwargs) if kwargs else None - return weight_producer_cls(inst_dict=inst_dict) - - @classmethod - def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: - params = super().resolve_param_values(params) - - config_inst = params.get("config_inst") - if config_inst: - # add the default weight producer when empty - params["weight_producer"] = cls.resolve_config_default( - params, - params.get("weight_producer"), - container=config_inst, - default_str="default_weight_producer", - multiple=False, - ) - if params["weight_producer"] is None: - raise Exception(f"no weight producer configured for task {cls.task_family}") - params["weight_producer_inst"] = cls.get_weight_producer_inst( - params["weight_producer"], - params, - ) - - return params - - @classmethod - def get_known_shifts( - cls, - config_inst: od.Config, - params: dict[str, Any], - ) -> tuple[set[str], set[str]]: - shifts, upstream_shifts = super().get_known_shifts(config_inst, params) - - # get the weight producer, update it and add its shifts - weight_producer_inst = params.get("weight_producer_inst") - if weight_producer_inst: - if cls.register_weight_producer_shifts: - shifts |= weight_producer_inst.all_shifts - else: - upstream_shifts |= weight_producer_inst.all_shifts - - return shifts, upstream_shifts - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + return self.build_repr(self.shift_sources[0]) + return self.build_repr(sorted(self.shift_sources), prepend_count=True) - # cache for weight producer inst - self._weight_producer_inst = None - - @property - def weight_producer_inst(self) -> WeightProducer: - if self._weight_producer_inst is None: - self._weight_producer_inst = self.get_weight_producer_inst( - self.weight_producer, - {"task": self}, - ) + def store_parts(self) -> law.util.InsertableDict: + parts = super().store_parts() + parts.insert_before("calibrators", "shift_sources", f"shifts__{self.shift_sources_repr}") + return parts - # overwrite the sandbox when set - if self.register_weight_producer_sandbox: - sandbox = self._weight_producer_inst.get_sandbox() - if sandbox: - self.sandbox = sandbox - # rebuild the sandbox inst when already initialized - if self._sandbox_initialized: - self._initialize_sandbox(force=True) - return self._weight_producer_inst +class DatasetShiftSourcesMixin(ShiftSourcesMixin, DatasetTask): - @property - def weight_producer_repr(self) -> str: - return str(self.weight_producer_inst) + # disable the shift parameter + shift = None + effective_shift = None + allow_empty_shift = True - def store_parts(self: WeightProducerMixin) -> law.util.InsertableDict[str, str]: - parts = super().store_parts() - parts.insert_before("version", "weightprod", f"weight__{self.weight_producer_repr}") - return parts + # allow empty sources, i.e., using only nominal + allow_empty_shift_sources = True -class ChunkedIOMixin(AnalysisTask): +class ChunkedIOMixin(ConfigTask): check_finite_output = luigi.BoolParameter( default=False, @@ -2359,10 +2378,7 @@ def raise_if_not_finite(cls, ak_array: ak.Array) -> None: for route in get_ak_routes(ak_array): if ak.any(~np.isfinite(ak.flatten(route.apply(ak_array), axis=None))): - raise ValueError( - f"found one or more non-finite values in column '{route.column}' " - f"of array {ak_array}", - ) + raise ValueError(f"found one or more non-finite values in column '{route.column}' of array {ak_array}") @classmethod def raise_if_overlapping(cls, ak_arrays: Sequence[ak.Array]) -> None: @@ -2442,29 +2458,45 @@ class HistHookMixin(ConfigTask): hist_hooks = law.CSVParameter( default=(), - description="names of functions in the config's auxiliary dictionary 'hist_hooks' that are " - "invoked before plotting to update a potentially nested dictionary of histograms; " - "default: empty", + description="names of functions in the config's auxiliary dictionary 'hist_hooks' that are invoked before " + "plotting to update a potentially nested dictionary of histograms; default: empty", ) - def invoke_hist_hooks(self, hists: dict) -> dict: + def _get_hist_hook(self, name: str) -> Callable: + func = None + if not self.has_single_config(): + # only check the analysis + func = self.analysis_inst.x("hist_hooks", {}).get(name) + elif not (func := self.config_inst.x("hist_hooks", {}).get(name)): + # check the config, fallback to the analysis + func = self.analysis_inst.x("hist_hooks", {}).get(name) + + if not func: + raise KeyError( + f"hist hook '{name}' not found in 'hist_hooks' for {self.config_mode()} config task {self!r}", + ) + + return func + + def invoke_hist_hooks( + self, + hists: dict[od.Config, dict[od.Process, Any]], + ) -> dict[od.Config, dict[od.Process, Any]]: """ - Invoke hooks to update histograms before plotting. + Invoke hooks to modify histograms before further processing such as plotting. """ if not self.hist_hooks: return hists + # apply hooks in order for hook in self.hist_hooks: - if hook in (None, "", law.NO_STR): + if hook in {None, "", law.NO_STR}: continue - # get the hook from the config instance - hooks = self.config_inst.x("hist_hooks", {}) - if hook not in hooks: - raise KeyError( - f"hist hook '{hook}' not found in 'hist_hooks' auxiliary entry of config", - ) - func = hooks[hook] + # get the hook + func = self._get_hist_hook(hook) + + # validate it if not callable(func): raise TypeError(f"hist hook '{hook}' is not callable: {func}") @@ -2479,10 +2511,13 @@ def hist_hooks_repr(self) -> str: """ Return a string representation of the hist hooks. """ - hooks = [hook for hook in self.hist_hooks if hook not in (None, "", law.NO_STR)] + # prepare names + names = [name for name in self.hist_hooks if name not in {None, "", law.NO_STR}] - hooks_repr = "__".join(hooks[:5]) - if len(hooks) > 5: - hooks_repr += f"__{law.util.create_hash(hooks[5:])}" + # lookup the functions for an alternative store_name + names = [ + getattr(self._get_hist_hook(name), "store_name", name) + for name in names + ] - return hooks_repr + return self.build_repr(names) diff --git a/columnflow/tasks/framework/parameters.py b/columnflow/tasks/framework/parameters.py index e3a0da7ee..b12d7ba20 100644 --- a/columnflow/tasks/framework/parameters.py +++ b/columnflow/tasks/framework/parameters.py @@ -11,8 +11,8 @@ import luigi import law -from columnflow.util import try_float, try_complex, DotDict -from columnflow.types import Iterable +from columnflow.util import try_float, try_complex, DotDict, Derivable +from columnflow.types import Iterable, Any user_parameter_inst = luigi.Parameter( @@ -37,6 +37,37 @@ ) +class DerivableInstParameter(luigi.Parameter): + """ + Parameter that can be used to pass the instance of a :py:class:`Derivable` subclass. + + This class does not implement parameter value parsing. + """ + + @classmethod + def _serialize(cls, x: Any) -> str: + if isinstance(x, Derivable): + return x.cls_name + return str(x) + + def serialize(self, x: Any) -> str: + return self._serialize(x) + + +class DerivableInstsParameter(luigi.Parameter): + """ + Parameter that can be used to pass multiple instances of a :py:class:`Derivable` subclass. + + This class does not implement parameter value parsing. + """ + + def serialize(self, x: Any) -> str: + """""" + if isinstance(x, (list, tuple)): + return ",".join(DerivableInstParameter._serialize(v) for v in x) + return str(x) + + class SettingsParameter(law.CSVParameter): """ Parameter that parses the input of a CSVParameter into a dictionary diff --git a/columnflow/tasks/framework/plotting.py b/columnflow/tasks/framework/plotting.py index a63fef391..f3b3c43c4 100644 --- a/columnflow/tasks/framework/plotting.py +++ b/columnflow/tasks/framework/plotting.py @@ -11,7 +11,7 @@ from columnflow.types import Any, Callable from columnflow.tasks.framework.base import ConfigTask, RESOLVE_DEFAULT -from columnflow.tasks.framework.mixins import DatasetsProcessesMixin, VariablesMixin +from columnflow.tasks.framework.mixins import VariablesMixin, DatasetsProcessesMixin from columnflow.tasks.framework.parameters import SettingsParameter, MultiSettingsParameter from columnflow.util import DotDict, dict_add_strict, ipython_shell @@ -85,11 +85,13 @@ class PlotBase(ConfigTask): def resolve_param_values(cls, params): params = super().resolve_param_values(params) - if "config_inst" not in params: + if "config_insts" not in params: return params - config_inst = params["config_inst"] + config_inst = params["config_insts"][0] # resolve general_settings + # NOTE: we currently assume that general_settings defaults and groups are the same for all + # config instances if "general_settings" in params: settings = params["general_settings"] # when empty and default general_settings are defined, use them instead @@ -254,6 +256,8 @@ def update_plot_kwargs(self, kwargs: dict) -> dict: if value is None: kwargs.pop(key) + config_inst = self.config_insts[0] + # set items of general_settings in kwargs if corresponding key is not yet present general_settings = kwargs.get("general_settings", {}) for key, value in general_settings.items(): @@ -262,9 +266,9 @@ def update_plot_kwargs(self, kwargs: dict) -> dict: # resolve custom_style_config custom_style_config = kwargs.get("custom_style_config", None) if custom_style_config == RESOLVE_DEFAULT: - custom_style_config = self.config_inst.x("default_custom_style_config", RESOLVE_DEFAULT) + custom_style_config = config_inst.x("default_custom_style_config", RESOLVE_DEFAULT) - groups = self.config_inst.x("custom_style_config_groups", {}) + groups = config_inst.x("custom_style_config_groups", {}) if isinstance(custom_style_config, str) and custom_style_config in groups.keys(): custom_style_config = groups[custom_style_config] @@ -280,7 +284,7 @@ def update_plot_kwargs(self, kwargs: dict) -> dict: # resolve blinding_threshold blinding_threshold = kwargs.get("blinding_threshold", None) if blinding_threshold is None: - blinding_threshold = self.config_inst.x("default_blinding_threshold", None) + blinding_threshold = config_inst.x("default_blinding_threshold", None) kwargs["blinding_threshold"] = blinding_threshold return kwargs @@ -315,10 +319,10 @@ class PlotBase1D(PlotBase): description="when True, each process is normalized on it's integral in the upper panel; " "default: None", ) - hide_errors = law.OptionalBoolParameter( + hide_stat_errors = law.OptionalBoolParameter( default=None, significant=False, - description="when True, no error bars/bands on histograms are drawn; default: None", + description="when True, no error bands for statistical uncertainty histograms are drawn; default: None", ) def get_plot_parameters(self) -> DotDict: @@ -328,7 +332,7 @@ def get_plot_parameters(self) -> DotDict: dict_add_strict(params, "density", self.density) dict_add_strict(params, "yscale", None if self.yscale == law.NO_STR else self.yscale) dict_add_strict(params, "shape_norm", self.shape_norm) - dict_add_strict(params, "hide_errors", self.hide_errors) + dict_add_strict(params, "hide_stat_errors", self.hide_stat_errors) return params @@ -413,8 +417,9 @@ def get_plot_parameters(self) -> DotDict: class ProcessPlotSettingMixin( - DatasetsProcessesMixin, + # TODO: could add back DatasetsProcessesMixin PlotBase, + DatasetsProcessesMixin, ): """ Mixin class for tasks creating plots where contributions of different processes are shown. @@ -434,11 +439,13 @@ class ProcessPlotSettingMixin( def resolve_param_values(cls, params): params = super().resolve_param_values(params) - if "config_inst" not in params: + if "config_insts" not in params: return params - config_inst = params["config_inst"] + config_inst = params["config_insts"][0] # resolve process_settings + # NOTE: we currently assume that process_settings defaults and groups are the same for all + # config instances if "process_settings" in params: settings = params["process_settings"] # when empty and default process_settings are defined, use them instead @@ -467,8 +474,8 @@ def get_plot_parameters(self) -> DotDict: class VariablePlotSettingMixin( - VariablesMixin, - PlotBase, + PlotBase, + VariablesMixin, ): """ Mixin class for tasks creating plots for multiple variables. @@ -488,11 +495,13 @@ class VariablePlotSettingMixin( def resolve_param_values(cls, params): params = super().resolve_param_values(params) - if "config_inst" not in params: + if "config_insts" not in params: return params - config_inst = params["config_inst"] + config_inst = params["config_insts"][0] # resolve variable_settings + # NOTE: we currently assume that variable_settings defaults and groups are the same for all + # config instances if "variable_settings" in params: settings = params["variable_settings"] # when empty and default variable_settings are defined, use them instead diff --git a/columnflow/tasks/framework/remote.py b/columnflow/tasks/framework/remote.py index 47a720ed0..bac3affdb 100644 --- a/columnflow/tasks/framework/remote.py +++ b/columnflow/tasks/framework/remote.py @@ -17,7 +17,9 @@ from columnflow import flavor as cf_flavor from columnflow.tasks.framework.base import Requirements, AnalysisTask from columnflow.tasks.framework.parameters import user_parameter_inst +from columnflow.tasks.framework.decorators import only_local_env from columnflow.util import UNSET, real_path +from columnflow.types import Any class BundleRepo(AnalysisTask, law.git.BundleGitRepository, law.tasks.TransferLocalFile): @@ -61,6 +63,7 @@ def get_file_pattern(self): def output(self): return law.tasks.TransferLocalFile.output(self) + @only_local_env @law.decorator.notify @law.decorator.log @law.decorator.safe_output @@ -92,6 +95,7 @@ def get_file_pattern(self): path = os.path.expandvars(os.path.expanduser(self.single_output().path)) return self.get_replicated_path(path, i=None if self.replicas <= 0 else r"[^\.]+") + @only_local_env @law.decorator.notify @law.decorator.log @law.decorator.safe_output @@ -171,6 +175,7 @@ def output(self): # note: invoking self.env will already trigger installing the sandbox return law.LocalFileTarget(self.env["CF_SANDBOX_FLAG_FILE"]) + @only_local_env def run(self): # no need to run anything as the sandboxing mechanism handles the installation return @@ -232,6 +237,7 @@ def get_file_pattern(self): path = os.path.expandvars(os.path.expanduser(self.single_output().path)) return self.get_replicated_path(path, i=None if self.replicas <= 0 else r"[^\.]+") + @only_local_env @law.decorator.notify @law.decorator.log @law.decorator.safe_output @@ -308,6 +314,7 @@ def get_file_pattern(self): path = os.path.expandvars(os.path.expanduser(self.single_output().path)) return self.get_replicated_path(path, i=None if self.replicas <= 0 else r"[^\.]+") + @only_local_env @law.decorator.notify @law.decorator.log def run(self): @@ -364,6 +371,24 @@ def __init__(self, *args, **kwargs) -> None: # container to store scheduler message handlers self._scheduler_message_handlers: dict[str, SchedulerMessageHandler] = {} + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: RemoteWorkflowMixin | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) + + # add the pilot flag + pilot = ( + inst_or_params.get("pilot") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "pilot", None) + ) + if pilot not in (law.NO_STR, None, ""): + keys["pilot"] = f"pilot_{pilot}" + + return keys + def add_bundle_requirements( self, reqs: dict[str, AnalysisTask], diff --git a/columnflow/tasks/framework/remote_bootstrap.sh b/columnflow/tasks/framework/remote_bootstrap.sh index 0e271765c..4d17e4c8f 100755 --- a/columnflow/tasks/framework/remote_bootstrap.sh +++ b/columnflow/tasks/framework/remote_bootstrap.sh @@ -35,10 +35,12 @@ bootstrap_htcondor_standalone() { # fix for missing voms/x509 variables in the lcg setup of the naf if [[ "${CF_HTCONDOR_FLAVOR}" = naf* ]]; then + echo "setting up X509 variables for ${CF_HTCONDOR_FLAVOR}" export X509_CERT_DIR="/cvmfs/grid.cern.ch/etc/grid-security/certificates" export X509_VOMS_DIR="/cvmfs/grid.cern.ch/etc/grid-security/vomsdir" export X509_VOMSES="/cvmfs/grid.cern.ch/etc/grid-security/vomses" export VOMS_USERCONF="/cvmfs/grid.cern.ch/etc/grid-security/vomses" + export CAPATH="/cvmfs/grid.cern.ch/etc/grid-security/certificates" fi # fallback to a default path when the externally given software base is empty or inaccessible diff --git a/columnflow/tasks/histograms.py b/columnflow/tasks/histograms.py index 9552719d9..57c3654b6 100644 --- a/columnflow/tasks/histograms.py +++ b/columnflow/tasks/histograms.py @@ -9,30 +9,38 @@ import luigi import law -from columnflow.tasks.framework.base import Requirements, AnalysisTask, DatasetTask, wrapper_factory +from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory from columnflow.tasks.framework.mixins import ( - CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin, VariablesMixin, - ShiftSourcesMixin, WeightProducerMixin, ChunkedIOMixin, + CalibratorClassesMixin, CalibratorsMixin, SelectorClassMixin, SelectorMixin, ReducerClassMixin, ReducerMixin, + ProducerClassesMixin, ProducersMixin, VariablesMixin, DatasetShiftSourcesMixin, HistProducerClassMixin, + HistProducerMixin, ChunkedIOMixin, MLModelsMixin, ) from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.tasks.framework.parameters import last_edge_inclusive_inst +from columnflow.tasks.framework.decorators import on_failure from columnflow.tasks.reduction import ReducedEventsUser from columnflow.tasks.production import ProduceColumns from columnflow.tasks.ml import MLEvaluation from columnflow.util import dev_sandbox -from columnflow.hist_util import create_hist_from_variables -class CreateHistograms( - VariablesMixin, - WeightProducerMixin, - MLModelsMixin, - ProducersMixin, +class _CreateHistograms( ReducedEventsUser, + ProducersMixin, + MLModelsMixin, + HistProducerMixin, ChunkedIOMixin, + VariablesMixin, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`CreateHistograms`. + """ + + +class CreateHistograms(_CreateHistograms): + last_edge_inclusive = last_edge_inclusive_inst sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) @@ -45,21 +53,30 @@ class CreateHistograms( MLEvaluation=MLEvaluation, ) - # strategy for handling missing source columns when adding aliases on event chunks missing_column_alias_strategy = "original" - - # names of columns that contain category ids - # (might become a parameter at some point) category_id_columns = {"category_ids"} - - # register sandbox and shifts found in the chosen weight producer to this task - register_weight_producer_sandbox = True - register_weight_producer_shifts = True + invokes_hist_producer = True @law.util.classproperty def mandatory_columns(cls) -> set[str]: return set(cls.category_id_columns) | {"process_id"} + @classmethod + def check_histogram_compatibility(cls, h) -> None: + # expected axis names and types + import hist + expected = { + "category": hist.axis.StrCategory, + "shift": hist.axis.StrCategory, + "process": hist.axis.StrCategory, + } + axes = {ax.name: ax for ax in h.axes} + for axis_name, axis_type in expected.items(): + if not (ax := axes.get(axis_name)): + raise Exception(f"missing axis '{axis_name}' in histogram: {h}") + if not isinstance(ax, axis_type): + raise ValueError(f"axis '{axis_name}' must have type '{axis_type}', found '{type(ax)}'") + def workflow_requires(self): reqs = super().workflow_requires() @@ -69,7 +86,11 @@ def workflow_requires(self): if not self.pilot: if self.producer_insts: reqs["producers"] = [ - self.reqs.ProduceColumns.req(self, producer=producer_inst.cls_name) + self.reqs.ProduceColumns.req( + self, + producer=producer_inst.cls_name, + producer_inst=producer_inst, + ) for producer_inst in self.producer_insts if producer_inst.produced_columns ] @@ -79,8 +100,10 @@ def workflow_requires(self): for ml_model_inst in self.ml_model_insts ] - # add weight_producer dependent requirements - reqs["weight_producer"] = law.util.make_unique(law.util.flatten(self.weight_producer_inst.run_requires())) + # add hist_producer dependent requirements + reqs["hist_producer"] = law.util.make_unique(law.util.flatten( + self.hist_producer_inst.run_requires(task=self), + )) return reqs @@ -89,7 +112,11 @@ def requires(self): if self.producer_insts: reqs["producers"] = [ - self.reqs.ProduceColumns.req(self, producer=producer_inst.cls_name) + self.reqs.ProduceColumns.req( + self, + producer=producer_inst.cls_name, + producer_inst=producer_inst, + ) for producer_inst in self.producer_insts if producer_inst.produced_columns ] @@ -99,8 +126,10 @@ def requires(self): for ml_model_inst in self.ml_model_insts ] - # add weight_producer dependent requirements - reqs["weight_producer"] = law.util.make_unique(law.util.flatten(self.weight_producer_inst.run_requires())) + # add hist_producer dependent requirements + reqs["hist_producer"] = law.util.make_unique(law.util.flatten( + self.hist_producer_inst.run_requires(task=self), + )) return reqs @@ -114,23 +143,34 @@ def output(self): @law.decorator.log @law.decorator.localize(input=True, output=False) @law.decorator.safe_output + @on_failure(callback=lambda task: task.teardown_hist_producer_inst()) def run(self): import numpy as np import awkward as ak from columnflow.columnar_util import ( Route, update_ak_array, add_ak_aliases, has_ak_column, attach_coffea_behavior, ) - from columnflow.hist_util import fill_hist # prepare inputs inputs = self.input() + # get IDs and names of all leaf categories + leaf_category_map = { + cat.id: cat.name + for cat in self.config_inst.get_leaf_categories() + } + # declare output: dict of histograms histograms = {} - # run the weight_producer setup - producer_reqs = self.weight_producer_inst.run_requires() - reader_targets = self.weight_producer_inst.run_setup(producer_reqs, luigi.task.getpaths(producer_reqs)) + # run the hist_producer setup + self._array_function_post_init() + hist_producer_reqs = self.hist_producer_inst.run_requires(task=self) + reader_targets = self.hist_producer_inst.run_setup( + task=self, + reqs=hist_producer_reqs, + inputs=luigi.task.getpaths(hist_producer_reqs), + ) # create a temp dir for saving intermediate files tmp_dir = law.LocalDirectoryTarget(is_tmp=True) @@ -142,7 +182,7 @@ def run(self): # define columns that need to be read read_columns = {Route("process_id")} read_columns |= set(map(Route, self.category_id_columns)) - read_columns |= set(self.weight_producer_inst.used_columns) + read_columns |= set(self.hist_producer_inst.used_columns) read_columns |= set(map(Route, aliases.values())) read_columns |= { Route(inp) @@ -153,17 +193,15 @@ def run(self): for inp in (( {variable_inst.expression} if isinstance(variable_inst.expression, str) - # for variable_inst with custom expressions, read columns declared via aux key - else set(variable_inst.x("inputs", [])) - ) | ( - # for variable_inst with selection, read columns declared via aux key - set(variable_inst.x("inputs", [])) - if variable_inst.selection != "1" else set() + ) | set( + # read requested input columns if defined + variable_inst.x("inputs", []), )) } - # empty float array to use when input files have no entries + # empty arrays to use when input files have no entries + empty_i32 = ak.Array(np.array([], dtype=np.int32)) empty_f32 = ak.Array(np.array([], dtype=np.float32)) # iterate over chunks of events and diffs @@ -174,15 +212,12 @@ def run(self): file_targets.extend([inp["mlcolumns"] for inp in inputs["ml"]]) # prepare inputs for localization - with law.localize_file_targets( - [*file_targets, *reader_targets.values()], - mode="r", - ) as inps: + with law.localize_file_targets([*file_targets, *reader_targets.values()], mode="r") as inps: for (events, *columns), pos in self.iter_chunked_io( [inp.abspath for inp in inps], source_type=len(file_targets) * ["awkward_parquet"] + [None] * len(reader_targets), read_columns=(len(file_targets) + len(reader_targets)) * [read_columns], - chunk_size=self.weight_producer_inst.get_min_chunk_size(), + chunk_size=self.hist_producer_inst.get_min_chunk_size(), ): # optional check for overlapping inputs if self.check_overlapping_inputs: @@ -199,14 +234,23 @@ def run(self): missing_strategy=self.missing_column_alias_strategy, ) - # attach coffea behavior aiding functional variable expressions + # invoke the hist producer, potentially updating columns and creating the event weight events = attach_coffea_behavior(events) + events, weight = self.hist_producer_inst(events, task=self) - # build the full event weight - if hasattr(self.weight_producer_inst, "skip_func") and not self.weight_producer_inst.skip_func(): - events, weight = self.weight_producer_inst(events) - else: - weight = ak.Array(np.ones(len(events), dtype=np.float32)) + # merge category ids and check that they are defined as leaf categories + category_ids = ak.concatenate( + [Route(c).apply(events) for c in self.category_id_columns], + axis=-1, + ) + unique_category_ids = np.unique(ak.flatten(category_ids)) + if any(cat_id not in leaf_category_map for cat_id in unique_category_ids): + undefined_category_ids = list(map(str, set(unique_category_ids) - set(leaf_category_map))) + raise ValueError( + f"category_ids column contains ids {','.join(undefined_category_ids)} that are either not " + "known to the config at all, or not as leaf categories (i.e., they have child categories); " + "please ensure that category_ids only contains ids of known leaf categories", + ) # define and fill histograms, taking into account multiple axes for var_key, var_names in self.variable_tuples.items(): @@ -215,37 +259,28 @@ def run(self): if var_key not in histograms: # create the histogram in the first chunk - histograms[var_key] = create_hist_from_variables( - *variable_insts, - int_cat_axes=("category", "process", "shift"), - ) + histograms[var_key] = self.hist_producer_inst.run_create_hist(variable_insts, task=self) # mask events and weights when selection expressions are found masked_events = events masked_weights = weight + masked_category_ids = category_ids for variable_inst in variable_insts: sel = variable_inst.selection if sel == "1": continue if not callable(sel): - raise ValueError( - f"invalid selection '{sel}', for now only callables are supported", - ) + raise ValueError(f"invalid selection '{sel}', for now only callables are supported") mask = sel(masked_events) masked_events = masked_events[mask] masked_weights = masked_weights[mask] - - # merge category ids - category_ids = ak.concatenate( - [Route(c).apply(masked_events) for c in self.category_id_columns], - axis=-1, - ) + masked_category_ids = masked_category_ids[mask] # broadcast arrays so that each event can be filled for all its categories fill_data = { - "category": category_ids, + "category": masked_category_ids, "process": masked_events.process_id, - "shift": np.ones(len(masked_events), dtype=np.int32) * self.global_shift_inst.id, + "shift": self.global_shift_inst.id, "weight": masked_weights, } for variable_inst in variable_insts: @@ -255,17 +290,27 @@ def run(self): route = Route(expr) def expr(events, *args, **kwargs): if len(events) == 0 and not has_ak_column(events, route): - return empty_f32 + return empty_i32 if variable_inst.discrete_x else empty_f32 return route.apply(events, null_value=variable_inst.null_value) # apply it fill_data[variable_inst.name] = expr(masked_events) - # fill it - fill_hist( - histograms[var_key], - fill_data, - last_edge_inclusive=self.last_edge_inclusive, - ) + # let the hist producer fill it + self.hist_producer_inst.run_fill_hist(histograms[var_key], fill_data, task=self) + + # post-process the histograms + for var_key in self.variable_tuples.keys(): + histograms[var_key] = self.hist_producer_inst.run_post_process_hist(histograms[var_key], task=self) + + # check the format after post-processing if no merged preprocessing will take place + if ( + not self.hist_producer_inst.skip_compatibility_check and + not callable(self.hist_producer_inst.post_process_merged_hist_func) + ): + self.check_histogram_compatibility(histograms[var_key]) + + # teardown the hist producer + self.teardown_hist_producer_inst() # merge output files self.output()["hists"].dump(histograms, formatter="pickle") @@ -278,7 +323,6 @@ def expr(events, *args, **kwargs): add_default_to_description=True, ) - CreateHistogramsWrapper = wrapper_factory( base_cls=AnalysisTask, require_cls=CreateHistograms, @@ -286,17 +330,24 @@ def expr(events, *args, **kwargs): ) -class MergeHistograms( - VariablesMixin, - WeightProducerMixin, - MLModelsMixin, - ProducersMixin, - SelectorStepsMixin, +class _MergeHistograms( CalibratorsMixin, - DatasetTask, + SelectorMixin, + ReducerMixin, + ProducersMixin, + MLModelsMixin, + HistProducerMixin, + VariablesMixin, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`MergeHistograms`. + """ + + +class MergeHistograms(_MergeHistograms): + only_missing = luigi.BoolParameter( default=False, description="when True, identify missing variables first and only require histograms of " @@ -389,8 +440,18 @@ def run(self): for variable_name in self.iter_progress(variable_names, len(variable_names), reach=(50, 100)): self.publish_message(f"merging histograms for '{variable_name}'") + # merge them variable_hists = [h[variable_name] for h in hists] merged = sum(variable_hists[1:], variable_hists[0].copy()) + + # post-process the merged histogram + merged = self.hist_producer_inst.run_post_process_merged_hist(merged, task=self) + + # ensure the format is compatible + if not self.hist_producer_inst.skip_compatibility_check: + CreateHistograms.check_histogram_compatibility(merged) + + # write the output outputs["hists"][variable_name].dump(merged, formatter="pickle") # optionally remove inputs @@ -405,27 +466,29 @@ def run(self): ) -class MergeShiftedHistograms( - VariablesMixin, - ShiftSourcesMixin, - WeightProducerMixin, +class _MergeShiftedHistograms( + DatasetShiftSourcesMixin, + CalibratorClassesMixin, + SelectorClassMixin, + ReducerClassMixin, + ProducerClassesMixin, MLModelsMixin, - ProducersMixin, - SelectorStepsMixin, - CalibratorsMixin, - DatasetTask, + HistProducerClassMixin, + VariablesMixin, law.LocalWorkflow, RemoteWorkflow, ): - sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + """ + Base classes for :py:class:`MergeShiftedHistograms`. + """ + + +class MergeShiftedHistograms(_MergeShiftedHistograms): - # disable the shift parameter - shift = None - effective_shift = None - allow_empty_shift = True + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) - # allow only running on nominal - allow_empty_shift_sources = True + # use the MergeHistograms task to trigger upstream TaskArrayFunction initialization + resolution_task_cls = MergeHistograms # upstream requirements reqs = Requirements( @@ -434,7 +497,7 @@ class MergeShiftedHistograms( ) def create_branch_map(self): - # create a dummy branch map so that this task could as a job + # create a dummy branch map so that this task can run as a job return {0: None} def workflow_requires(self): @@ -452,16 +515,13 @@ def requires(self): for shift in ["nominal"] + self.shifts } - def store_parts(self) -> law.util.InsertableDict: - parts = super().store_parts() - parts.insert_after("dataset", "shift_sources", f"shifts_{self.shift_sources_repr}") - return parts - def output(self): - return {"hists": law.SiblingFileCollection({ - variable_name: self.target(f"shifted_hist__{variable_name}.pickle") - for variable_name in self.variables - })} + return { + "hists": law.SiblingFileCollection({ + variable_name: self.target(f"hists__{variable_name}.pickle") + for variable_name in self.variables + }), + } @law.decorator.notify @law.decorator.log diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 0d2f9ce74..4a39abe3b 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -10,21 +10,26 @@ import law import luigi +from columnflow.types import Any from columnflow.tasks.framework.base import Requirements, AnalysisTask, DatasetTask, wrapper_factory from columnflow.tasks.framework.mixins import ( CalibratorsMixin, SelectorMixin, + ReducerMixin, ProducersMixin, + CalibratorClassesMixin, + SelectorClassMixin, + ProducerClassesMixin, MLModelDataMixin, MLModelTrainingMixin, MLModelMixin, + PreparationProducerMixin, ChunkedIOMixin, CategoriesMixin, - SelectorStepsMixin, ) from columnflow.tasks.framework.plotting import ProcessPlotSettingMixin, PlotBase from columnflow.tasks.framework.remote import RemoteWorkflow -from columnflow.tasks.framework.decorators import view_output_plots +from columnflow.tasks.framework.decorators import view_output_plots, on_failure from columnflow.tasks.reduction import ReducedEventsUser from columnflow.tasks.production import ProduceColumns from columnflow.util import dev_sandbox, safe_div, DotDict, maybe_import @@ -35,10 +40,10 @@ class PrepareMLEvents( - MLModelDataMixin, + ReducedEventsUser, ProducersMixin, + MLModelDataMixin, ChunkedIOMixin, - ReducedEventsUser, law.LocalWorkflow, RemoteWorkflow, ): @@ -46,6 +51,10 @@ class PrepareMLEvents( allow_empty_ml_model = False + @classmethod + def invokes_preparation_producer(cls, params) -> bool: + return True + # upstream requirements reqs = Requirements( ReducedEventsUser.reqs, @@ -59,9 +68,6 @@ class PrepareMLEvents( def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # cache for producer inst - self._preparation_producer_inst = law.no_value - # complain when this task is run for events that are not needed for training if not self.events_used_in_training( self.config_inst, @@ -75,30 +81,6 @@ def __init__(self, *args, **kwargs): f"{self.__class__.__name__}", ) - @property - def preparation_producer_inst(self): - if self._preparation_producer_inst is not law.no_value: - # producer has already been cached - return self._preparation_producer_inst - - producer = self.ml_model_inst.preparation_producer(self.config_inst) - - if not producer: - # set producer inst to None when no producer is requested - self._preparation_producer_inst = None - return self._preparation_producer_inst - self._preparation_producer_inst = self.get_producer_insts([producer], {"task": self})[0] - - # overwrite the sandbox when set - sandbox = self._preparation_producer_inst.get_sandbox() - if sandbox: - self.sandbox = sandbox - # rebuild the sandbox inst when already initialized - if self._sandbox_initialized: - self._initialize_sandbox(force=True) - - return self._preparation_producer_inst - def workflow_requires(self): reqs = super().workflow_requires() @@ -107,12 +89,16 @@ def workflow_requires(self): # add producer dependent requirements if self.preparation_producer_inst: - reqs["preparation_producer"] = self.preparation_producer_inst.run_requires() + reqs["preparation_producer"] = self.preparation_producer_inst.run_requires(task=self) # add producers to requirements if not self.pilot and self.producer_insts: reqs["producers"] = [ - self.reqs.ProduceColumns.req(self, producer=producer_inst.cls_name) + self.reqs.ProduceColumns.req( + self, + producer=producer_inst.cls_name, + producer_inst=producer_inst, + ) for producer_inst in self.producer_insts if producer_inst.produced_columns ] @@ -123,11 +109,15 @@ def requires(self): reqs = {"events": self.reqs.ProvideReducedEvents.req(self)} if self.preparation_producer_inst: - reqs["preparation_producer"] = self.preparation_producer_inst.run_requires() + reqs["preparation_producer"] = self.preparation_producer_inst.run_requires(task=self) if self.producer_insts: reqs["producers"] = [ - self.reqs.ProduceColumns.req(self, producer=producer_inst.cls_name) + self.reqs.ProduceColumns.req( + self, + producer=producer_inst.cls_name, + producer_inst=producer_inst, + ) for producer_inst in self.producer_insts if producer_inst.produced_columns ] @@ -152,6 +142,7 @@ def output(self): @law.decorator.log @law.decorator.localize @law.decorator.safe_output + @on_failure(callback=lambda task: task.teardown_preaparation_producer_inst()) def run(self): from columnflow.columnar_util import ( Route, RouteFilter, sorted_ak_to_parquet, update_ak_array, add_ak_aliases, @@ -168,8 +159,9 @@ def run(self): reader_targets = {} if self.preparation_producer_inst: reader_targets = self.preparation_producer_inst.run_setup( - reqs["preparation_producer"], - inputs["preparation_producer"], + task=self, + reqs=reqs["preparation_producer"], + inputs=inputs["preparation_producer"], ) # create a temp dir for saving intermediate files @@ -181,7 +173,7 @@ def run(self): # define columns that will to be written write_columns = set.union(*self.ml_model_inst.used_columns.values()) - route_filter = RouteFilter(write_columns) + route_filter = RouteFilter(keep=write_columns) # define columns that need to be read read_columns = {Route("deterministic_seed")} @@ -231,6 +223,7 @@ def run(self): if len(events) and self.preparation_producer_inst: events = self.preparation_producer_inst( events, + task=self, stats=stats, fold_indices=events.fold_indices, ml_model_inst=self.ml_model_inst, @@ -256,23 +249,26 @@ def run(self): output_chunks[f][pos.index] = chunk self.chunked_io.queue(sorted_ak_to_parquet, (fold_events, chunk.abspath)) - # merge output files of all folds - for _output_chunks, output in zip(output_chunks, outputs["mlevents"].targets): - sorted_chunks = [_output_chunks[key] for key in sorted(_output_chunks)] - law.pyarrow.merge_parquet_task( - self, sorted_chunks, output, local=True, writer_opts=self.get_parquet_writer_opts(), - ) + # teardown the optional producer + self.teardown_preparation_producer_inst() + + # merge output files of all folds + for _output_chunks, output in zip(output_chunks, outputs["mlevents"].targets): + sorted_chunks = [_output_chunks[key] for key in sorted(_output_chunks)] + law.pyarrow.merge_parquet_task( + self, sorted_chunks, output, local=True, writer_opts=self.get_parquet_writer_opts(), + ) - # save stats - if not getattr(stats, "num_fold_events", None): - stats["num_fold_events"] = num_fold_events - outputs["stats"].dump(stats, indent=4, formatter="json") + # save stats + if not getattr(stats, "num_fold_events", None): + stats["num_fold_events"] = num_fold_events + outputs["stats"].dump(stats, indent=4, formatter="json") - # some logs - self.publish_message(f"total events: {n_events}") - for f, n in num_fold_events.items(): - r = 100 * safe_div(n, n_events) - self.publish_message(f"fold {' ' if f < 10 else ''}{f}: {n} ({r:.2f}%)") + # some logs + self.publish_message(f"total events: {n_events}") + for f, n in num_fold_events.items(): + r = 100 * safe_div(n, n_events) + self.publish_message(f"fold {' ' if f < 10 else ''}{f}: {n} ({r:.2f}%)") # overwrite class defaults @@ -296,17 +292,18 @@ def run(self): class MergeMLStats( - MLModelDataMixin, - ProducersMixin, - SelectorMixin, CalibratorsMixin, - DatasetTask, + SelectorMixin, + ReducerMixin, + ProducersMixin, + MLModelDataMixin, law.LocalWorkflow, RemoteWorkflow, ): # upstream requirements reqs = Requirements( + RemoteWorkflow.reqs, PrepareMLEvents=PrepareMLEvents, ) @@ -365,11 +362,11 @@ def merge_counts(cls, dst: dict, src: dict) -> dict: class MergeMLEvents( - MLModelDataMixin, - ProducersMixin, - SelectorMixin, CalibratorsMixin, - DatasetTask, + SelectorMixin, + ReducerMixin, + ProducersMixin, + MLModelDataMixin, law.tasks.ForestMerge, RemoteWorkflow, ): @@ -461,7 +458,10 @@ class MLTraining( law.LocalWorkflow, RemoteWorkflow, ): + # use the MergeMLEvents task to trigger upstream TaskArrayFunction initialization + resolution_task_cls = MergeMLEvents + single_config = False allow_empty_ml_model = False # upstream requirements @@ -498,21 +498,13 @@ def workflow_requires(self): self, config=config_inst.name, dataset=dataset_inst.name, - calibrators=_calibrators, - selector=_selector, - producers=_producers, fold=fold, tree_index=-1) for fold in range(self.ml_model_inst.folds) ] for dataset_inst in dataset_insts } - for (config_inst, dataset_insts), _calibrators, _selector, _producers in zip( - self.ml_model_inst.used_datasets.items(), - self.calibrators, - self.selectors, - self.producers, - ) + for config_inst, dataset_insts in self.ml_model_inst.used_datasets.items() } reqs["stats"] = { config_inst.name: { @@ -520,18 +512,10 @@ def workflow_requires(self): self, config=config_inst.name, dataset=dataset_inst.name, - calibrators=_calibrators, - selector=_selector, - producers=_producers, ) for dataset_inst in dataset_insts } - for (config_inst, dataset_insts), _calibrators, _selector, _producers in zip( - self.ml_model_inst.used_datasets.items(), - self.calibrators, - self.selectors, - self.producers, - ) + for config_inst, dataset_insts in self.ml_model_inst.used_datasets.items() } # ml model requirements @@ -550,9 +534,6 @@ def requires(self): self, config=config_inst.name, dataset=dataset_inst.name, - calibrators=_calibrators, - selector=_selector, - producers=_producers, fold=f, ) for f in range(self.ml_model_inst.folds) @@ -560,13 +541,9 @@ def requires(self): ] for dataset_inst in dataset_insts } - for (config_inst, dataset_insts), _calibrators, _selector, _producers in zip( - self.ml_model_inst.used_datasets.items(), - self.calibrators, - self.selectors, - self.producers, - ) + for config_inst, dataset_insts in self.ml_model_inst.used_datasets.items() } + # TODO: stats reqs missing here # ml model requirements reqs["model"] = self.ml_model_inst.requires(self) @@ -589,13 +566,14 @@ def run(self): class MLEvaluation( - MLModelMixin, + ReducedEventsUser, ProducersMixin, + PreparationProducerMixin, ChunkedIOMixin, - ReducedEventsUser, law.LocalWorkflow, RemoteWorkflow, ): + sandbox = None allow_empty_ml_model = False @@ -614,41 +592,32 @@ class MLEvaluation( def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # cache for producer inst - self._preparation_producer_inst = law.no_value - # set the sandbox self.sandbox = self.ml_model_inst.sandbox(self) + # TODO: potentially reset - @property - def preparation_producer_inst(self): - if self._preparation_producer_inst is not law.no_value: - # producer has already been cached - return self._preparation_producer_inst - - producer = None - if self.ml_model_inst.preparation_producer_in_ml_evaluation: - # only consider preparation_producer in MLEvaluation if requested by model - producer = self.ml_model_inst.preparation_producer(self.config_inst) - - if not producer: - # set producer inst to None when no producer is requested - self._preparation_producer_inst = None - return self._preparation_producer_inst - - self._preparation_producer_inst = self.get_producer_insts([producer], {"task": self})[0] - - # check that preparation_producer does not clash with ml_model_inst sandbox - if ( - self._preparation_producer_inst.sandbox and - self.sandbox != self._preparation_producer_inst.sandbox - ): - raise Exception( - f"Task {self.__class__.__name__} got different sandboxes from the MLModel ({self.sandbox}) " - f"than from the preparation_producer ({self._preparation_producer_inst.sandbox})", - ) + @classmethod + def invokes_preparation_producer(cls, params) -> bool: + # check if the preparation producer is used in the ML model + return bool(params["ml_model_inst"].preparation_producer_in_ml_evaluation) - return self._preparation_producer_inst + @classmethod + def resolve_param_values_pre_init( + cls, + params: law.util.InsertableDict[str, Any], + ) -> law.util.InsertableDict[str, Any]: + # resolve producers used in MLEvaluation based on the MLModel instance + params = super().resolve_param_values_pre_init(params) + params["producers"] = law.util.make_tuple(params["ml_model_inst"].evaluation_producers( + params["analysis_inst"], + params["producers"], + )) + if "producer_insts" in params: + params["producer_insts"] = law.util.make_tuple( + producer_inst for producer_inst in params["producer_insts"] + if producer_inst.cls_name in params["producers"] + ) + return params def workflow_requires(self): reqs = super().workflow_requires() @@ -656,20 +625,21 @@ def workflow_requires(self): reqs["models"] = self.reqs.MLTraining.req_different_branching( self, configs=(self.config_inst.name,), - calibrators=(self.calibrators,), - selectors=(self.selector,), - producers=(self.producers,), ) reqs["events"] = self.reqs.ProvideReducedEvents.req(self) # add producer dependent requirements if self.preparation_producer_inst: - reqs["preparation_producer"] = self.preparation_producer_inst.run_requires() + reqs["preparation_producer"] = self.preparation_producer_inst.run_requires(task=self) if not self.pilot and self.producer_insts: reqs["producers"] = [ - self.reqs.ProduceColumns.req(self, producer=producer_inst.cls_name) + self.reqs.ProduceColumns.req( + self, + producer=producer_inst.cls_name, + producer_inst=producer_inst, + ) for producer_inst in self.producer_insts if producer_inst.produced_columns ] @@ -681,19 +651,20 @@ def requires(self): "models": self.reqs.MLTraining.req_different_branching( self, configs=(self.config_inst.name,), - calibrators=(self.calibrators,), - selectors=(self.selector,), - producers=(self.producers,), branch=-1, ), "events": self.reqs.ProvideReducedEvents.req(self, _exclude=self.exclude_params_branch), } if self.preparation_producer_inst: - reqs["preparation_producer"] = self.preparation_producer_inst.run_requires() + reqs["preparation_producer"] = self.preparation_producer_inst.run_requires(task=self) if self.producer_insts: reqs["producers"] = [ - self.reqs.ProduceColumns.req(self, producer=producer_inst.cls_name) + self.reqs.ProduceColumns.req( + self, + producer=producer_inst.cls_name, + producer_inst=producer_inst, + ) for producer_inst in self.producer_insts if producer_inst.produced_columns ] @@ -710,6 +681,7 @@ def output(self): @law.decorator.log @law.decorator.localize @law.decorator.safe_output + @on_failure(callback=lambda task: task.teardown_preparation_producer_inst()) def run(self): from columnflow.columnar_util import ( Route, RouteFilter, sorted_ak_to_parquet, update_ak_array, add_ak_aliases, @@ -730,8 +702,9 @@ def run(self): reader_targets = {} if self.preparation_producer_inst: reader_targets = self.preparation_producer_inst.run_setup( - reqs["preparation_producer"], - inputs["preparation_producer"], + task=self, + reqs=reqs["preparation_producer"], + inputs=inputs["preparation_producer"], ) # open all model files @@ -759,7 +732,7 @@ def run(self): # define columns that will be written write_columns = set.union(*self.ml_model_inst.produced_columns.values()) - route_filter = RouteFilter(write_columns) + route_filter = RouteFilter(keep=write_columns) # iterate over chunks of events and columns file_targets = [inputs["events"]["events"]] @@ -800,6 +773,7 @@ def run(self): if len(events) and self.preparation_producer_inst: events = self.preparation_producer_inst( events, + task=self, stats=stats, fold_indices=events.fold_indices, ml_model_inst=self.ml_model_inst, @@ -826,6 +800,9 @@ def run(self): output_chunks[pos.index] = chunk self.chunked_io.queue(sorted_ak_to_parquet, (events, chunk.abspath)) + # teardown the optional producer + self.teardown_preparation_producer_inst() + # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( @@ -854,11 +831,11 @@ def run(self): class MergeMLEvaluation( - MLModelMixin, - ProducersMixin, - SelectorMixin, - CalibratorsMixin, DatasetTask, + CalibratorClassesMixin, + SelectorClassMixin, + ProducerClassesMixin, + MLModelMixin, law.tasks.ForestMerge, RemoteWorkflow, ): @@ -914,11 +891,11 @@ def merge(self, inputs, output): class PlotMLResultsBase( ProcessPlotSettingMixin, - CategoriesMixin, + CalibratorClassesMixin, + SelectorClassMixin, + ProducerClassesMixin, MLModelMixin, - ProducersMixin, - SelectorStepsMixin, - CalibratorsMixin, + CategoriesMixin, law.LocalWorkflow, RemoteWorkflow, ): diff --git a/columnflow/tasks/plotting.py b/columnflow/tasks/plotting.py index b922684a8..0c4cb18f9 100644 --- a/columnflow/tasks/plotting.py +++ b/columnflow/tasks/plotting.py @@ -4,17 +4,20 @@ Tasks to plot different types of histograms. """ -from collections import OrderedDict +import itertools +from collections import OrderedDict, defaultdict from abc import abstractmethod +from columnflow.types import Any + import law import luigi import order as od from columnflow.tasks.framework.base import Requirements, ShiftTask from columnflow.tasks.framework.mixins import ( - CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin, WeightProducerMixin, - CategoriesMixin, ShiftSourcesMixin, HistHookMixin, + CalibratorClassesMixin, SelectorClassMixin, ReducerClassMixin, ProducerClassesMixin, HistProducerClassMixin, + CategoriesMixin, ShiftSourcesMixin, HistHookMixin, MLModelsMixin, ) from columnflow.tasks.framework.plotting import ( PlotBase, PlotBase1D, PlotBase2D, ProcessPlotSettingMixin, VariablePlotSettingMixin, @@ -23,32 +26,37 @@ from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.tasks.histograms import MergeHistograms, MergeShiftedHistograms from columnflow.util import DotDict, dev_sandbox, dict_add_strict +from columnflow.hist_util import add_missing_shifts +from columnflow.config_util import get_shift_from_configs -class PlotVariablesBase( - HistHookMixin, - VariablePlotSettingMixin, - ProcessPlotSettingMixin, - CategoriesMixin, +class _PlotVariablesBase( + CalibratorClassesMixin, + SelectorClassMixin, + ReducerClassMixin, + ProducerClassesMixin, MLModelsMixin, - WeightProducerMixin, - ProducersMixin, - SelectorStepsMixin, - CalibratorsMixin, + HistProducerClassMixin, + CategoriesMixin, + ProcessPlotSettingMixin, + VariablePlotSettingMixin, + HistHookMixin, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`PlotVariablesBase`. + """ + + +class PlotVariablesBase(_PlotVariablesBase): + single_config = False + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) exclude_index = True - # upstream requirements - reqs = Requirements( - RemoteWorkflow.reqs, - MergeHistograms=MergeHistograms, - ) - - def store_parts(self): + def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() parts.insert_before("version", "datasets", f"datasets_{self.datasets_repr}") return parts @@ -62,113 +70,198 @@ def create_branch_map(self): def workflow_requires(self): reqs = super().workflow_requires() - reqs["merged_hists"] = self.requires_from_branch() - return reqs @abstractmethod def get_plot_shifts(self): return - @law.decorator.notify + @property + def config_inst(self): + return self.config_insts[0] + + def get_config_process_map(self) -> tuple[dict[od.Config, dict[od.Process, dict[str, Any]]], dict[str, set[str]]]: + """ + Function that maps the config and process instances to the datasets and shifts they are supposed to be plotted + with. The mapping from processes to datasets is done by checking the dataset instances for the presence of the + process instances. The mapping from processes to shifts is done by checking the upstream requirements for the + presence of a shift in the requires method of the task. + + :return: A 2-tuple with a dictionary mapping config instances to dictionaries mapping process instances to + dictionaries containing the dataset-process mapping and the shifts to be considered, and a dictionary + mapping process names to the shifts to be considered. + """ + reqs = self.requires() + + config_process_map = {config_inst: {} for config_inst in self.config_insts} + process_shift_map = defaultdict(set) + + for i, config_inst in enumerate(self.config_insts): + process_insts = [config_inst.get_process(p) for p in self.processes[i]] + dataset_insts = [config_inst.get_dataset(d) for d in self.datasets[i]] + + requested_shifts_per_dataset: dict[od.Dataset, list[od.Shift]] = {} + for dataset_inst in dataset_insts: + _req = reqs[config_inst.name][dataset_inst.name] + if hasattr(_req, "shift") and _req.shift: + # when a shift is found, use it + requested_shifts = [_req.shift] + else: + # when no shift is found, check upstream requirements + requested_shifts = [sub_req.shift for sub_req in _req.requires().values()] + + requested_shifts_per_dataset[dataset_inst] = requested_shifts + + for process_inst in process_insts: + sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] + dataset_proc_name_map = {} + for dataset_inst in dataset_insts: + matched_proc_names = [p.name for p in sub_process_insts if dataset_inst.has_process(p.name)] + if matched_proc_names: + dataset_proc_name_map[dataset_inst] = matched_proc_names + + if not dataset_proc_name_map: + # no datasets found for this process + continue + + process_info = { + "dataset_proc_name_map": dataset_proc_name_map, + "config_shifts": { + shift + for dataset_inst in dataset_proc_name_map.keys() + for shift in requested_shifts_per_dataset[dataset_inst] + }, + } + process_shift_map[process_inst.name].update(process_info["config_shifts"]) + config_process_map[config_inst][process_inst] = process_info + + # assign the combination of all shifts to each config-process pair + for config_inst, process_info_dict in config_process_map.items(): + for process_inst, process_info in process_info_dict.items(): + if process_inst.name in process_shift_map: + config_process_map[config_inst][process_inst]["shifts"] = process_shift_map[process_inst.name] + + return config_process_map, process_shift_map + @law.decorator.log @view_output_plots def run(self): import hist - # get the shifts to extract and plot - plot_shifts = law.util.make_list(self.get_plot_shifts()) - - # copy process instances once so that their auxiliary data fields can be used as a storage - # for process-specific plot parameters later on in plot scripts without affecting the - # original instances - fake_root = od.Process( - name=f"{hex(id(object()))[2:]}", - id="+", - processes=list(map(self.config_inst.get_process, self.processes)), - ).copy() - process_insts = list(fake_root.processes) - fake_root.processes.clear() - # prepare other config objects variable_tuple = self.variable_tuples[self.branch_data.variable] variable_insts = [ self.config_inst.get_variable(var_name) for var_name in variable_tuple ] - category_inst = self.config_inst.get_category(self.branch_data.category) - leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] - sub_process_insts = { - process_inst: [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] - for process_inst in process_insts - } + plot_shifts = self.get_plot_shifts() + plot_shift_names = set(shift_inst.name for shift_inst in plot_shifts) + + # get assignment of processes to datasets and shifts + config_process_map, process_shift_map = self.get_config_process_map() # histogram data per process copy - hists = {} - - with self.publish_step(f"plotting {self.branch_data.variable} in {category_inst.name}"): - for dataset, inp in self.input().items(): - dataset_inst = self.config_inst.get_dataset(dataset) - h_in = inp["collection"][0]["hists"].targets[self.branch_data.variable].load(formatter="pickle") - - # loop and extract one histogram per process - for process_inst in process_insts: - # skip when the dataset is already known to not contain any sub process - if not any( - dataset_inst.has_process(sub_process_inst.name) - for sub_process_inst in sub_process_insts[process_inst] - ): - continue - - # select processes and reduce axis - h = h_in.copy() - h = h[{ - "process": [ - hist.loc(p.id) - for p in sub_process_insts[process_inst] - if p.id in h.axes["process"] - ], - }] - h = h[{"process": sum}] - - # add the histogram - if process_inst in hists: - hists[process_inst] += h - else: - hists[process_inst] = h - - # there should be hists to plot - if not hists: - raise Exception( - "no histograms found to plot; possible reasons:\n" - " - requested variable requires columns that were missing during histogramming\n" - " - selected --processes did not match any value on the process axis of the input histogram", - ) + hists: dict[od.Config, dict[od.Process, hist.Hist]] = {} + with self.publish_step(f"plotting {self.branch_data.variable} in {self.branch_data.category}"): + for i, (config, dataset_dict) in enumerate(self.input().items()): + config_inst = self.config_insts[i] + category_inst = config_inst.get_category(self.branch_data.category) + leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] + + hists_config = {} + + for dataset, inp in dataset_dict.items(): + dataset_inst = config_inst.get_dataset(dataset) + h_in = inp["collection"][0]["hists"].targets[self.branch_data.variable].load(formatter="pickle") + + # loop and extract one histogram per process + for process_inst, process_info in config_process_map[config_inst].items(): + if dataset_inst not in process_info["dataset_proc_name_map"].keys(): + continue + + # select processes and reduce axis + h = h_in.copy() + h = h[{ + "process": [ + hist.loc(proc_name) + for proc_name in process_info["dataset_proc_name_map"][dataset_inst] + if proc_name in h.axes["process"] + ], + }] + h = h[{"process": sum}] + + # create expected shift bins and fill them with the nominal histogram + expected_shifts = plot_shift_names & process_shift_map[process_inst.name] + add_missing_shifts(h, expected_shifts, str_axis="shift", nominal_bin="nominal") + + # add the histogram + if process_inst in hists_config: + hists_config[process_inst] += h + else: + hists_config[process_inst] = h + + # after merging all processes, sort the histograms by process order and store them + hists[config_inst] = { + proc_inst: hists_config[proc_inst] + for proc_inst in sorted( + hists_config.keys(), key=list(config_process_map[config_inst].keys()).index, + ) + } + + # there should be hists to plot + if not hists: + raise Exception( + "no histograms found to plot; possible reasons:\n" + " - requested variable requires columns that were missing during histogramming\n" + " - selected --processes did not match any value on the process axis of the input histogram", + ) # update histograms using custom hooks hists = self.invoke_hist_hooks(hists) - # add new processes to the end of the list - for process_inst in hists: - if process_inst not in process_insts: - process_insts.append(process_inst) - - # axis selections and reductions, including sorting by process order + # merge configs + if len(self.config_insts) != 1: + process_memory = {} + merged_hists = {} + for _hists in hists.values(): + for process_inst, h in _hists.items(): + if process_inst.id in merged_hists: + merged_hists[process_inst.id] += h + else: + merged_hists[process_inst.id] = h + process_memory[process_inst.id] = process_inst + + process_insts = list(process_memory.values()) + hists = {process_memory[process_id]: h for process_id, h in merged_hists.items()} + else: + hists = hists[self.config_inst] + process_insts = list(hists.keys()) + + # axis selections and reductions _hists = OrderedDict() - for process_inst in sorted(hists, key=process_insts.index): + for process_inst in hists.keys(): h = hists[process_inst] + # determine expected shifts from the intersection of requested shifts and those known for the process + process_shifts = ( + process_shift_map[process_inst.name] + if process_inst.name in process_shift_map + else {"nominal"} + ) + expected_shifts = plot_shift_names & process_shifts + if not expected_shifts: + raise Exception(f"no shifts to plot found for process {process_inst.name}") # selections h = h[{ "category": [ - hist.loc(c.id) + hist.loc(c.name) for c in leaf_category_insts - if c.id in h.axes["category"] + if c.name in h.axes["category"] ], "shift": [ - hist.loc(s.id) - for s in plot_shifts - if s.id in h.axes["shift"] + hist.loc(s_name) + for s_name in expected_shifts + if s_name in h.axes["shift"] ], }] # reductions @@ -177,15 +270,32 @@ def run(self): _hists[process_inst] = h hists = _hists - # call the plot function - fig, _ = self.call_plot_func( - self.plot_function, - hists=hists, - config_inst=self.config_inst, - category_inst=category_inst.copy_shallow(), - variable_insts=[var_inst.copy_shallow() for var_inst in variable_insts], - **self.get_plot_parameters(), - ) + # copy process instances once so that their auxiliary data fields can be used as a storage + # for process-specific plot parameters later on in plot scripts without affecting the + # original instances + fake_root = od.Process( + name=f"{hex(id(object()))[2:]}", + id="+", + processes=list(hists.keys()), + ).copy() + process_insts = list(fake_root.processes) + fake_root.processes.clear() + hists = dict(zip(process_insts, hists.values())) + + # temporarily use a merged luminostiy value, assigned to the first config + config_inst = self.config_insts[0] + lumi = sum([_config_inst.x.luminosity for _config_inst in self.config_insts]) + with law.util.patch_object(config_inst.x, "luminosity", lumi): + # call the plot function + fig, _ = self.call_plot_func( + self.plot_function, + hists=hists, + config_inst=config_inst, + category_inst=category_inst.copy_shallow(), + variable_insts=[var_inst.copy_shallow() for var_inst in variable_insts], + shift_insts=plot_shifts, + **self.get_plot_parameters(), + ) # save the plot for outp in self.output()["plots"]: @@ -193,12 +303,13 @@ def run(self): class PlotVariablesBaseSingleShift( - PlotVariablesBase, ShiftTask, + PlotVariablesBase, ): + # use the MergeHistograms task to trigger upstream TaskArrayFunction initialization + resolution_task_cls = MergeHistograms exclude_index = True - # upstream requirements reqs = Requirements( PlotVariablesBase.reqs, MergeHistograms=MergeHistograms, @@ -216,16 +327,23 @@ def workflow_requires(self): return reqs def requires(self): - return { - d: self.reqs.MergeHistograms.req( - self, - dataset=d, - branch=-1, - _exclude={"branches"}, - _prefer_cli={"variables"}, - ) - for d in self.datasets - } + req = {} + + for i, config_inst in enumerate(self.config_insts): + sub_datasets = self.datasets[i] + req[config_inst.name] = {} + for d in sub_datasets: + if d in config_inst.datasets.names(): + req[config_inst.name][d] = self.reqs.MergeHistograms.req( + self, + config=config_inst.name, + shift=self.global_shift_insts[config_inst].name, + dataset=d, + branch=-1, + _exclude={"branches"}, + _prefer_cli={"variables"}, + ) + return req def plot_parts(self) -> law.util.InsertableDict: parts = super().plot_parts() @@ -245,14 +363,14 @@ def output(self): "plots": [self.target(name) for name in self.get_plot_names("plot")], } - def store_parts(self): + def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() if "shift" in parts: parts.insert_before("datasets", "shift", parts.pop("shift")) return parts def get_plot_shifts(self): - return [self.global_shift_inst] + return [get_shift_from_configs(self.config_insts, self.shift)] class PlotVariables1D( @@ -260,11 +378,31 @@ class PlotVariables1D( PlotBase1D, ): plot_function = PlotBase.plot_function.copy( - default="columnflow.plotting.plot_functions_1d.plot_variable_per_process", + default="columnflow.plotting.plot_functions_1d.plot_variable_stack", add_default_to_description=True, ) +class PlotVariablesPerConfig1D( + PlotVariables1D, + law.WrapperTask, +): + # force this one to be a local workflow + workflow = "local" + output_collection_cls = law.NestedSiblingFileCollection + + def requires(self): + return { + config: PlotVariables1D.req( + self, + datasets=(self.datasets[i],), + processes=(self.processes[i],), + configs=(config,), + ) + for i, config in enumerate(self.configs) + } + + class PlotVariables2D( PlotVariablesBaseSingleShift, PlotBase2D, @@ -275,9 +413,29 @@ class PlotVariables2D( ) -class PlotVariablesPerProcess2D( +class PlotVariablesPerConfig2D( + PlotVariables1D, law.WrapperTask, +): + # force this one to be a local workflow + workflow = "local" + output_collection_cls = law.NestedSiblingFileCollection + + def requires(self): + return { + config: PlotVariablesPerConfig2D.req( + self, + datasets=(self.datasets[i],), + processes=(self.processes[i],), + configs=(config,), + ) + for i, config in enumerate(self.configs) + } + + +class PlotVariablesPerProcess2D( PlotVariables2D, + law.WrapperTask, ): # force this one to be a local workflow workflow = "local" @@ -290,8 +448,8 @@ def requires(self): class PlotVariablesBaseMultiShifts( - PlotVariablesBase, ShiftSourcesMixin, + PlotVariablesBase, ): legend_title = luigi.Parameter( default=law.NO_STR, @@ -300,45 +458,67 @@ class PlotVariablesBaseMultiShifts( "the plot, the process_inst label is used; empty default", ) + # whether this task creates a single plot combining all shifts or one plot per shift + combine_shifts = True + + # use the MergeHistograms task to trigger upstream TaskArrayFunction initialization + resolution_task_cls = MergeHistograms + exclude_index = True # upstream requirements reqs = Requirements( PlotVariablesBase.reqs, + MergeHistograms=MergeHistograms, MergeShiftedHistograms=MergeShiftedHistograms, ) - def create_branch_map(self): - return [ - DotDict({"category": cat_name, "variable": var_name, "shift_source": source}) - for var_name in sorted(self.variables) - for cat_name in sorted(self.categories) - for source in sorted(self.shift_sources) - ] + def create_branch_map(self) -> list[DotDict]: + seqs = [self.categories, self.variables] + keys = ["category", "variable"] + if not self.combine_shifts: + seqs.append(self.shift_sources) + keys.append("shift_source") + return [DotDict(zip(keys, vals)) for vals in itertools.product(*seqs)] def requires(self): - # TODO: for data, request MergeHistograms - return { - d: self.reqs.MergeShiftedHistograms.req( - self, - dataset=d, - branch=-1, - _exclude={"branches"}, - _prefer_cli={"variables"}, - ) - for d in self.datasets - } + req_cls = lambda dataset_name: ( + self.reqs.MergeShiftedHistograms + if self.config_inst.get_dataset(dataset_name).is_mc + else self.reqs.MergeHistograms + ) + + req = {} + for i, config_inst in enumerate(self.config_insts): + req[config_inst.name] = {} + for dataset_name in self.datasets[i]: + if dataset_name in config_inst.datasets: + req[config_inst.name][dataset_name] = req_cls(dataset_name).req( + self, + config=config_inst.name, + dataset=dataset_name, + branch=-1, + _exclude={"branches"}, + _prefer_cli={"variables"}, + ) + return req def plot_parts(self) -> law.util.InsertableDict: parts = super().plot_parts() parts["processes"] = f"proc_{self.processes_repr}" - parts["shift_source"] = f"unc_{self.branch_data.shift_source}" parts["category"] = f"cat_{self.branch_data.category}" parts["variable"] = f"var_{self.branch_data.variable}" - hooks_repr = self.hist_hooks_repr - if hooks_repr: + # shift source or sources + parts["shift_source"] = ( + f"shifts_{self.shift_sources_repr}" + if self.combine_shifts + else f"shift_{self.branch_data.shift_source}" + ) + + # hooks + if (hooks_repr := self.hist_hooks_repr): parts["hook"] = f"hooks_{hooks_repr}" return parts @@ -348,19 +528,20 @@ def output(self): "plots": [self.target(name) for name in self.get_plot_names("plot")], } - def store_parts(self): - parts = super().store_parts() - parts.insert_before("datasets", "shifts", f"shifts_{self.shift_sources_repr}") - return parts - def get_plot_shifts(self): - return [ - self.config_inst.get_shift(s) for s in [ - "nominal", - f"{self.branch_data.shift_source}_up", - f"{self.branch_data.shift_source}_down", - ] - ] + # only to be called by branch tasks + if self.is_workflow(): + raise Exception("calls to get_plots_shifts are forbidden for workflow tasks") + + # gather sources, and expand to up/down shifts + sources = self.shift_sources if self.combine_shifts else [self.branch_data.shift_source] + shifts = [] + for source in sources: + shifts.append(get_shift_from_configs(self.config_insts, f"{source}_{od.Shift.UP}")) + shifts.append(get_shift_from_configs(self.config_insts, f"{source}_{od.Shift.DOWN}")) + + # add nominal + return [self.config_inst.get_shift("nominal"), *shifts] def get_plot_parameters(self): # convert parameters to usable values during plotting @@ -373,22 +554,54 @@ class PlotShiftedVariables1D( PlotBase1D, PlotVariablesBaseMultiShifts, ): + plot_function = PlotBase.plot_function.copy( + default="columnflow.plotting.plot_functions_1d.plot_variable_stack", + add_default_to_description=True, + ) + + +class PlotShiftedVariablesPerShift1D( + PlotBase1D, + PlotVariablesBaseMultiShifts, +): + # this tasks creates one plot per shift + combine_shifts = False plot_function = PlotBase.plot_function.copy( default="columnflow.plotting.plot_functions_1d.plot_shifted_variable", add_default_to_description=True, ) -class PlotShiftedVariablesPerProcess1D(law.WrapperTask): +class PlotShiftedVariablesPerConfig1D( + law.WrapperTask, + PlotShiftedVariables1D, +): + # force this one to be a local workflow + workflow = "local" + output_collection_cls = law.NestedSiblingFileCollection + + def requires(self): + return { + config: PlotShiftedVariables1D.req( + self, + datasets=(self.datasets[i],), + processes=(self.processes[i],), + configs=(config,), + ) + for i, config in enumerate(self.configs) + } + + +class PlotShiftedVariablesPerShiftAndProcess1D(law.WrapperTask): # upstream requirements reqs = Requirements( - PlotShiftedVariables1D.reqs, - PlotShiftedVariables1D=PlotShiftedVariables1D, + PlotShiftedVariablesPerShift1D.reqs, + PlotShiftedVariablesPerShift1D=PlotShiftedVariablesPerShift1D, ) def requires(self): return { - process: self.reqs.PlotShiftedVariables1D.req(self, processes=(process,)) + process: self.reqs.PlotShiftedVariablesPerShift1D.req(self, processes=(process,)) for process in self.processes } diff --git a/columnflow/tasks/production.py b/columnflow/tasks/production.py index 18707f8d1..cee2d1311 100644 --- a/columnflow/tasks/production.py +++ b/columnflow/tasks/production.py @@ -12,17 +12,25 @@ from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory from columnflow.tasks.framework.mixins import ProducerMixin, ChunkedIOMixin from columnflow.tasks.framework.remote import RemoteWorkflow +from columnflow.tasks.framework.decorators import on_failure from columnflow.tasks.reduction import ReducedEventsUser from columnflow.util import dev_sandbox -class ProduceColumns( +class _ProduceColumns( + ReducedEventsUser, ProducerMixin, ChunkedIOMixin, - ReducedEventsUser, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`ProduceColumns`. + """ + + +class ProduceColumns(_ProduceColumns): + # default sandbox, might be overwritten by producer function sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) @@ -32,11 +40,7 @@ class ProduceColumns( RemoteWorkflow.reqs, ) - # register sandbox and shifts found in the chosen producer to this task - register_producer_sandbox = True - register_producer_shifts = True - - # strategy for handling missing source columns when adding aliases on event chunks + invokes_producer = True missing_column_alias_strategy = "original" def workflow_requires(self): @@ -46,14 +50,18 @@ def workflow_requires(self): reqs["events"] = self.reqs.ProvideReducedEvents.req(self) # add producer dependent requirements - reqs["producer"] = law.util.make_unique(law.util.flatten(self.producer_inst.run_requires())) + reqs["producer"] = law.util.make_unique(law.util.flatten( + self.producer_inst.run_requires(task=self), + )) return reqs def requires(self): return { "events": self.reqs.ProvideReducedEvents.req(self), - "producer": law.util.make_unique(law.util.flatten(self.producer_inst.run_requires())), + "producer": law.util.make_unique(law.util.flatten( + self.producer_inst.run_requires(task=self), + )), } workflow_condition = ReducedEventsUser.workflow_condition.copy() @@ -72,10 +80,11 @@ def output(self): @law.decorator.log @law.decorator.localize(input=False) @law.decorator.safe_output + @on_failure(callback=lambda task: task.teardown_producer_inst()) def run(self): from columnflow.columnar_util import ( Route, RouteFilter, mandatory_coffea_columns, update_ak_array, add_ak_aliases, - sorted_ak_to_parquet, + sorted_ak_to_parquet, attach_coffea_behavior, ) # prepare inputs and outputs @@ -84,8 +93,13 @@ def run(self): output_chunks = {} # run the producer setup - producer_reqs = self.producer_inst.run_requires() - reader_targets = self.producer_inst.run_setup(producer_reqs, luigi.task.getpaths(producer_reqs)) + self._array_function_post_init() + producer_reqs = self.producer_inst.run_requires(task=self) + reader_targets = self.producer_inst.run_setup( + task=self, + reqs=producer_reqs, + inputs=luigi.task.getpaths(producer_reqs), + ) n_ext = len(reader_targets) # create a temp dir for saving intermediate files @@ -102,7 +116,7 @@ def run(self): # define columns that will be written write_columns = self.producer_inst.produced_columns - route_filter = RouteFilter(write_columns) + route_filter = RouteFilter(keep=write_columns) # prepare inputs for localization with law.localize_file_targets( @@ -133,7 +147,8 @@ def run(self): # invoke the producer if len(events): - events = self.producer_inst(events) + events = attach_coffea_behavior(events) + events = self.producer_inst(events, task=self) # remove columns events = route_filter(events) @@ -147,6 +162,9 @@ def run(self): output_chunks[pos.index] = chunk self.chunked_io.queue(sorted_ak_to_parquet, (events, chunk.abspath)) + # teardown the producer + self.teardown_producer_inst() + # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index 0ae088788..4e6c8d87a 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -7,20 +7,19 @@ from __future__ import annotations import math -import functools -from collections import OrderedDict, defaultdict +from collections import OrderedDict import law import luigi -from columnflow.tasks.framework.base import Requirements, AnalysisTask, DatasetTask, wrapper_factory -from columnflow.tasks.framework.mixins import ( - CalibratorsMixin, SelectorStepsMixin, ChunkedIOMixin, -) +from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory +from columnflow.tasks.framework.mixins import CalibratorsMixin, SelectorMixin, ReducerMixin, ChunkedIOMixin from columnflow.tasks.framework.remote import RemoteWorkflow +from columnflow.tasks.framework.decorators import on_failure from columnflow.tasks.external import GetDatasetLFNs from columnflow.tasks.selection import CalibrateEvents, SelectEvents from columnflow.util import maybe_import, ensure_proxy, dev_sandbox, safe_div +from columnflow.types import Any ak = maybe_import("awkward") @@ -29,14 +28,21 @@ default_keep_reduced_events = law.config.get_expanded("analysis", "default_keep_reduced_events") -class ReduceEvents( - SelectorStepsMixin, +class _ReduceEvents( CalibratorsMixin, + SelectorMixin, + ReducerMixin, ChunkedIOMixin, - DatasetTask, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`ReduceEvents`. + """ + + +class ReduceEvents(_ReduceEvents): + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) # upstream requirements @@ -50,8 +56,7 @@ class ReduceEvents( # strategy for handling missing source columns when adding aliases on event chunks missing_column_alias_strategy = "original" - # register shifts found in the chosen selector upstream - register_selector_shifts = True + invokes_reducer = True def workflow_requires(self): reqs = super().workflow_requires() @@ -60,11 +65,19 @@ def workflow_requires(self): if not self.pilot: reqs["calibrations"] = [ - self.reqs.CalibrateEvents.req(self, calibrator=calibrator_inst.cls_name) + self.reqs.CalibrateEvents.req( + self, + calibrator=calibrator_inst.cls_name, + calibrator_inst=calibrator_inst, + ) for calibrator_inst in self.calibrator_insts if calibrator_inst.produced_columns ] reqs["selection"] = self.reqs.SelectEvents.req(self) + # reducer dependent requirements + reqs["reducer"] = law.util.make_unique(law.util.flatten( + self.reducer_inst.run_requires(task=self), + )) else: # pass-through pilot workflow requirements of upstream task t = self.reqs.SelectEvents.req(self) @@ -76,11 +89,18 @@ def requires(self): return { "lfns": self.reqs.GetDatasetLFNs.req(self), "calibrations": [ - self.reqs.CalibrateEvents.req(self, calibrator=calibrator_inst.cls_name) + self.reqs.CalibrateEvents.req( + self, + calibrator=calibrator_inst.cls_name, + calibrator_inst=calibrator_inst, + ) for calibrator_inst in self.calibrator_insts if calibrator_inst.produced_columns ], "selection": self.reqs.SelectEvents.req(self), + "reducer": law.util.make_unique(law.util.flatten( + self.reducer_inst.run_requires(task=self), + )), } def output(self): @@ -91,12 +111,12 @@ def output(self): @ensure_proxy @law.decorator.localize(input=False) @law.decorator.safe_output + @on_failure(callback=lambda task: task.teardown_reducer_inst()) def run(self): from columnflow.columnar_util import ( Route, RouteFilter, mandatory_coffea_columns, update_ak_array, add_ak_aliases, - sorted_ak_to_parquet, + sorted_ak_to_parquet, attach_coffea_behavior, ) - from columnflow.selection.util import create_collections_from_masks # prepare inputs and outputs inputs = self.input() @@ -104,6 +124,24 @@ def run(self): output = self.output() output_chunks = {} + # for evaluating new object collections to write based on the "objects" field of the selection result data, + # create a mapping of src_col to dst_col's using only file meta data + self.collection_map: dict[str, list[str]] = {} + sel_meta = inputs["selection"]["results"].load(formatter="dask_awkward") + if "objects" in sel_meta.fields: + for src_col in sel_meta.objects.fields: + self.collection_map[src_col] = list(sel_meta.objects[src_col].fields) + del sel_meta + + # run the reducer setup + self._array_function_post_init() + reducer_reqs = self.reducer_inst.run_requires(task=self) + reader_targets = self.reducer_inst.run_setup( + task=self, + reqs=reducer_reqs, + inputs=luigi.task.getpaths(reducer_reqs), + ) + # create a temp dir for saving intermediate files tmp_dir = law.LocalDirectoryTarget(is_tmp=True) tmp_dir.touch() @@ -111,54 +149,29 @@ def run(self): # get shift dependent aliases aliases = self.local_shift_inst.x("column_aliases", {}) - # define columns that will be written + # define columns that will be written based on the reducer's produced columns, + # but taking into account those that should be skipped (e.g. if not all routes added by a collection are needed) write_columns: set[Route] = set() - skip_columns: set[str] = set() - for c in self.config_inst.x.keep_columns.get(self.task_family, ["*"]): + skip_columns: set[Route] = set() + for c in self.reducer_inst.produced_columns: for r in self._expand_keep_column(c): if r.has_tag("skip"): - skip_columns.add(r.column) + skip_columns.add(r) else: write_columns.add(r) - write_columns = { - r for r in write_columns - if not law.util.multi_match(r.column, skip_columns, mode=any) - } - route_filter = RouteFilter(write_columns) - - # map routes to write to their top level column - write_columns_groups = defaultdict(set) - for route in write_columns: - if len(route) > 1: - write_columns_groups[route[0]].add(route) + route_filter = RouteFilter(keep=write_columns, remove=skip_columns) # define columns that need to be read - read_columns = write_columns | set(mandatory_coffea_columns) | set(aliases.values()) - read_columns = {Route(c) for c in read_columns} - - # define columns to read for the differently structured selection masks - read_sel_columns = set() - # open either selector steps of the full event selection mask - read_sel_columns.add(Route("steps.*" if self.selector_steps else "event")) - # add object masks, depending on the columns to write - # (as object masks are dynamic and deeply nested, preload the meta info to access fields) - sel_results = inputs["selection"]["results"].load(formatter="dask_awkward") - if "objects" in sel_results.fields: - for src_field in sel_results.objects.fields: - for dst_field in sel_results.objects[src_field].fields: - # nothing to do in case the top level column does not need to be loaded - if not law.util.multi_match(dst_field, write_columns_groups.keys()): - continue - # register the object masks - read_sel_columns.add(Route(f"objects.{src_field}.{dst_field}")) - # in case new collections are created and configured to be written, make sure - # that the corresponding columns of the source collection are loaded - if src_field != dst_field: - read_columns |= { - src_field + route[1:] - for route in write_columns_groups[dst_field] - } - del sel_results + read_columns = set(map(Route, mandatory_coffea_columns)) + read_columns |= self.reducer_inst.used_columns + read_columns |= set(map(Route, set(aliases.values()))) + + # columns starting with "steps." and "objects." are implicitly treated as pointing to the selection result data + read_sel_columns = {Route("event")} + for r in list(read_columns): + if r.column.startswith(("steps.", "objects.")): + read_sel_columns.add(r) + read_columns.remove(r) # event counters n_all = 0 @@ -173,6 +186,7 @@ def run(self): input_targets.extend([inp["columns"] for inp in inputs["calibrations"]]) if self.selector_inst.produced_columns: input_targets.append(inputs["selection"]["columns"]) + input_targets.extend(reader_targets.values()) # prepare inputs for localization with law.localize_file_targets(input_targets, mode="r") as inps: @@ -197,45 +211,30 @@ def run(self): missing_strategy=self.missing_column_alias_strategy, ) - # build the event mask - if self.selector_steps: - # check if all steps are present - missing_steps = set(self.selector_steps) - set(sel.steps.fields) - if missing_steps: - raise Exception( - f"selector steps {','.join(missing_steps)} are not produced by " - f"selector '{self.selector}'", - ) - event_mask = functools.reduce( - (lambda a, b: a & b), - (sel["steps", step] for step in self.selector_steps), - ) - else: - event_mask = sel.event if "event" in sel.fields else Ellipsis - - # apply the mask - n_all += len(events) - events = events[event_mask] - n_reduced += len(events) - - # loop through all object selection, go through their masks - # and create new collections if required - if "objects" in sel.fields: - # apply the event mask - events = create_collections_from_masks(events, sel.objects[event_mask]) + # invoke the reducer + if len(events): + n_all += len(events) + events = attach_coffea_behavior(events) + events = self.reducer_inst(events, selection=sel, task=self) + n_reduced += len(events) # remove columns events = route_filter(events) + # optional check for finite values + if self.check_finite_output: + self.raise_if_not_finite(events) + # save as parquet via a thread in the same pool chunk = tmp_dir.child(f"file_{lfn_index}_{pos.index}.parquet", type="f") output_chunks[pos.index] = chunk self.chunked_io.queue(sorted_ak_to_parquet, (ak.to_packed(events), chunk.abspath)) + # teardown the reducer + self.teardown_reducer_inst() + # some logs - self.publish_message( - f"reduced {n_all:_} to {n_reduced:_} events ({safe_div(n_reduced, n_all) * 100:.2f}%)", - ) + self.publish_message(f"reduced {n_all:_} to {n_reduced:_} events ({safe_div(n_reduced, n_all) * 100:.2f}%)") # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] @@ -245,6 +244,12 @@ def run(self): # overwrite class defaults +check_finite_tasks = law.config.get_expanded("analysis", "check_finite_output", [], split_csv=True) +ReduceEvents.check_finite_output = ChunkedIOMixin.check_finite_output.copy( + default=ReduceEvents.task_family in check_finite_tasks, + add_default_to_description=True, +) + check_overlap_tasks = law.config.get_expanded("analysis", "check_overlapping_inputs", [], split_csv=True) ReduceEvents.check_overlapping_inputs = ChunkedIOMixin.check_overlapping_inputs.copy( default=ReduceEvents.task_family in check_overlap_tasks, @@ -258,13 +263,19 @@ def run(self): ) -class MergeReductionStats( - SelectorStepsMixin, +class _MergeReductionStats( CalibratorsMixin, - DatasetTask, + SelectorMixin, + ReducerMixin, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`MergeReductionStats`. + """ + + +class MergeReductionStats(_MergeReductionStats): n_inputs = luigi.IntParameter( default=10, @@ -288,12 +299,12 @@ class MergeReductionStats( ) @classmethod - def resolve_param_values(cls, params): + def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params = super().resolve_param_values(params) # check for the default merged size if "merged_size" in params: - if params["merged_size"] in (None, law.NO_FLOAT): + if params["merged_size"] in {None, law.NO_FLOAT}: merged_size = 512.0 if "config_inst" in params: merged_size = params["config_inst"].x("reduced_file_size", merged_size) @@ -412,13 +423,19 @@ def get_avg_std(values): ) -class MergeReducedEvents( - SelectorStepsMixin, +class _MergeReducedEvents( CalibratorsMixin, - DatasetTask, + SelectorMixin, + ReducerMixin, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`MergeReducedEvents`. + """ + + +class MergeReducedEvents(_MergeReducedEvents): keep_reduced_events = luigi.BoolParameter( default=default_keep_reduced_events, @@ -499,13 +516,19 @@ def run(self): ) -class ProvideReducedEvents( - SelectorStepsMixin, +class _ProvideReducedEvents( CalibratorsMixin, - DatasetTask, + SelectorMixin, + ReducerMixin, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`ProvideReducedEvents`. + """ + + +class ProvideReducedEvents(_ProvideReducedEvents): skip_merging = luigi.BoolParameter( default=False, @@ -638,9 +661,9 @@ def run(self): class ReducedEventsUser( - SelectorStepsMixin, CalibratorsMixin, - DatasetTask, + SelectorMixin, + ReducerMixin, law.BaseWorkflow, ): # upstream requirements diff --git a/columnflow/tasks/selection.py b/columnflow/tasks/selection.py index 2b5c719f3..314f1f3aa 100644 --- a/columnflow/tasks/selection.py +++ b/columnflow/tasks/selection.py @@ -10,13 +10,17 @@ import luigi import law -from columnflow.tasks.framework.base import Requirements, AnalysisTask, DatasetTask, wrapper_factory -from columnflow.tasks.framework.mixins import CalibratorsMixin, SelectorMixin, ChunkedIOMixin +from columnflow.types import Any + +from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory +from columnflow.tasks.framework.mixins import CalibratorsMixin, SelectorMixin, ChunkedIOMixin, ProducerMixin from columnflow.tasks.framework.remote import RemoteWorkflow +from columnflow.tasks.framework.decorators import on_failure from columnflow.tasks.external import GetDatasetLFNs from columnflow.tasks.calibration import CalibrateEvents -from columnflow.production import Producer from columnflow.util import maybe_import, ensure_proxy, dev_sandbox, safe_div, DotDict +from columnflow.tasks.framework.parameters import DerivableInstParameter + np = maybe_import("numpy") ak = maybe_import("awkward") @@ -31,14 +35,23 @@ ) -class SelectEvents( - SelectorMixin, +class _SelectEvents( CalibratorsMixin, + SelectorMixin, ChunkedIOMixin, - DatasetTask, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`SelectEvents`. + """ + + +class SelectEvents(_SelectEvents): + + # disable selector steps + selector_steps = None + # default sandbox, might be overwritten by selector function sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) @@ -49,14 +62,8 @@ class SelectEvents( CalibrateEvents=CalibrateEvents, ) - # register sandbox and shifts found in the chosen selector to this task - register_selector_sandbox = True - register_selector_shifts = True - - # strategy for handling missing source columns when adding aliases on event chunks + invokes_selector = True missing_column_alias_strategy = "original" - - # whether histogram outputs should be created create_selection_hists = default_create_selection_hists def workflow_requires(self): @@ -66,7 +73,11 @@ def workflow_requires(self): if not self.pilot: reqs["calibrations"] = [ - self.reqs.CalibrateEvents.req(self, calibrator=calibrator_inst.cls_name) + self.reqs.CalibrateEvents.req( + self, + calibrator=calibrator_inst.cls_name, + calibrator_inst=calibrator_inst, + ) for calibrator_inst in self.calibrator_insts if calibrator_inst.produced_columns ] @@ -76,7 +87,9 @@ def workflow_requires(self): reqs = law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) # add selector dependent requirements - reqs["selector"] = law.util.make_unique(law.util.flatten(self.selector_inst.run_requires())) + reqs["selector"] = law.util.make_unique(law.util.flatten( + self.selector_inst.run_requires(task=self), + )) return reqs @@ -84,14 +97,20 @@ def requires(self): reqs = { "lfns": self.reqs.GetDatasetLFNs.req(self), "calibrations": [ - self.reqs.CalibrateEvents.req(self, calibrator=calibrator_inst.cls_name) + self.reqs.CalibrateEvents.req( + self, + calibrator=calibrator_inst.cls_name, + calibrator_inst=calibrator_inst, + ) for calibrator_inst in self.calibrator_insts if calibrator_inst.produced_columns ], } # add selector dependent requirements - reqs["selector"] = law.util.make_unique(law.util.flatten(self.selector_inst.run_requires())) + reqs["selector"] = law.util.make_unique(law.util.flatten( + self.selector_inst.run_requires(task=self), + )) return reqs @@ -116,6 +135,7 @@ def output(self): @ensure_proxy @law.decorator.localize(input=False) @law.decorator.safe_output + @on_failure(callback=lambda task: task.teardown_selector_inst()) def run(self): from columnflow.tasks.histograms import CreateHistograms from columnflow.columnar_util import ( @@ -133,8 +153,13 @@ def run(self): hists = DotDict() # run the selector setup - selector_reqs = self.selector_inst.run_requires() - reader_targets = self.selector_inst.run_setup(selector_reqs, luigi.task.getpaths(selector_reqs)) + self._array_function_post_init() + selector_reqs = self.selector_inst.run_requires(task=self) + reader_targets = self.selector_inst.run_setup( + task=self, + reqs=selector_reqs, + inputs=luigi.task.getpaths(selector_reqs), + ) n_ext = len(reader_targets) # show an early warning in case the selector does not produce some mandatory columns @@ -161,7 +186,7 @@ def run(self): # define columns that will be written write_columns = set(map(Route, mandatory_coffea_columns)) write_columns |= self.selector_inst.produced_columns - route_filter = RouteFilter(write_columns) + route_filter = RouteFilter(keep=write_columns) # let the lfn_task prepare the nano file (basically determine a good pfn) [(lfn_index, input_file)] = lfn_task.iter_nano_files(self) @@ -199,7 +224,7 @@ def run(self): ) # invoke the selection function - events, results = self.selector_inst(events, stats, hists=hists) + events, results = self.selector_inst(events, task=self, stats=stats, hists=hists) # complain when there is no event mask if results.event is None: @@ -233,6 +258,9 @@ def run(self): column_chunks[(lfn_index, pos.index)] = chunk self.chunked_io.queue(sorted_ak_to_parquet, (events, chunk.abspath)) + # teardown the selector + self.teardown_selector_inst() + # merge the result files sorted_chunks = [result_chunks[key] for key in sorted(result_chunks)] writer_opts_masks = self.get_parquet_writer_opts(repeating_values=True) @@ -286,13 +314,19 @@ def run(self): ) -class MergeSelectionStats( - SelectorMixin, +class _MergeSelectionStats( CalibratorsMixin, - DatasetTask, + SelectorMixin, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`MergeSelectionStats`. + """ + + +class MergeSelectionStats(_MergeSelectionStats): + # default sandbox, might be overwritten by selector function (needed to load hist objects) sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) @@ -365,13 +399,19 @@ def merge_counts(cls, dst: dict, src: dict) -> dict: ) -class MergeSelectionMasks( - SelectorMixin, +class _MergeSelectionMasks( CalibratorsMixin, - DatasetTask, + SelectorMixin, law.tasks.ForestMerge, RemoteWorkflow, ): + """ + Base classes for :py:class:`MergeSelectionMasks`. + """ + + +class MergeSelectionMasks(_MergeSelectionMasks): + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) # recursively merge 8 files into one @@ -383,15 +423,31 @@ class MergeSelectionMasks( SelectEvents=SelectEvents, ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + norm_weights_producer = "normalization_weights" + norm_weight_producer_inst = DerivableInstParameter( + default=None, + visibility=luigi.parameter.ParameterVisibility.PRIVATE, + ) - # store the normalization weight producer for MC - self.norm_weight_producer = None - if self.dataset_inst.is_mc: - self.norm_weight_producer = Producer.get_cls("normalization_weights")( - inst_dict=self.get_producer_kwargs(self), - ) + exclude_params_index = {"norm_weight_producer_inst"} + exclude_params_repr = {"norm_weight_producer_inst"} + exclude_params_sandbox = {"norm_weight_producer_inst"} + exclude_params_remote_workflow = {"norm_weight_producer_inst"} + + @classmethod + def get_producer_dict(cls, params: dict[str, Any]) -> dict[str, Any]: + return cls.get_array_function_dict(params) + + build_producer_inst = ProducerMixin.build_producer_inst + + @classmethod + def resolve_instances(cls, params: dict[str, Any], shifts) -> dict[str, Any]: + if not params.get("norm_weight_producer_inst"): + params["norm_weight_producer_inst"] = cls.build_producer_inst(cls.norm_weights_producer, params) + + params = super().resolve_instances(params, shifts) + + return params def create_branch_map(self): # DatasetTask implements a custom branch map, but we want to use the one in ForestMerge @@ -401,7 +457,7 @@ def merge_workflow_requires(self): reqs = {"selection": self.reqs.SelectEvents.req_different_branching(self)} if self.dataset_inst.is_mc: - reqs["normalization"] = self.norm_weight_producer.run_requires() + reqs["normalization"] = self.norm_weight_producer_inst.run_requires(task=self) return reqs @@ -414,7 +470,7 @@ def merge_requires(self, start_branch, end_branch): } if self.dataset_inst.is_mc: - reqs["normalization"] = self.norm_weight_producer.run_requires() + reqs["normalization"] = self.norm_weight_producer_inst.run_requires(task=self) return reqs @@ -446,13 +502,14 @@ def zip_results_and_columns(self, inputs, tmp_dir): Route, RouteFilter, sorted_ak_to_parquet, mandatory_coffea_columns, ) - chunks = [] - # setup the normalization weights producer if self.dataset_inst.is_mc: - self.norm_weight_producer.run_setup( - self.requires()["forest_merge"]["normalization"], - self.input()["forest_merge"]["normalization"], + self._array_function_post_init() + self.norm_weight_producer_inst.run_post_init(task=self) + self.norm_weight_producer_inst.run_setup( + task=self, + reqs=self.requires()["forest_merge"]["normalization"], + inputs=self.input()["forest_merge"]["normalization"], ) # define columns that will be written @@ -471,15 +528,16 @@ def zip_results_and_columns(self, inputs, tmp_dir): # add some mandatory columns write_columns |= set(map(Route, mandatory_coffea_columns)) write_columns |= set(map(Route, {"category_ids", "process_id", "normalization_weight"})) - route_filter = RouteFilter(write_columns) + route_filter = RouteFilter(keep=write_columns) + chunks = [] for inp in inputs: events = inp["columns"].load(formatter="awkward", cache=False) steps = inp["results"].load(formatter="awkward", cache=False).steps # add normalization weight if self.dataset_inst.is_mc: - events = self.norm_weight_producer(events) + events = self.norm_weight_producer_inst(events, task=self) # remove columns events = route_filter(events) @@ -491,6 +549,10 @@ def zip_results_and_columns(self, inputs, tmp_dir): chunks.append(chunk) sorted_ak_to_parquet(out, chunk.abspath) + # teardown the normalization weights producer + if self.dataset_inst.is_mc: + self.norm_weight_producer_inst.run_teardown(task=self) + return chunks diff --git a/columnflow/tasks/union.py b/columnflow/tasks/union.py index e0fece50b..5b52d22b3 100644 --- a/columnflow/tasks/union.py +++ b/columnflow/tasks/union.py @@ -16,14 +16,21 @@ from columnflow.util import dev_sandbox -class UniteColumns( - MLModelsMixin, +class _UniteColumns( + ReducedEventsUser, ProducersMixin, + MLModelsMixin, ChunkedIOMixin, - ReducedEventsUser, law.LocalWorkflow, RemoteWorkflow, ): + """ + Base classes for :py:class:`UniteColumns`. + """ + + +class UniteColumns(_UniteColumns): + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) file_type = luigi.ChoiceParameter( @@ -49,7 +56,11 @@ def workflow_requires(self): if not self.pilot: if self.producer_insts: reqs["producers"] = [ - self.reqs.ProduceColumns.req(self, producer=producer_inst.cls_name) + self.reqs.ProduceColumns.req( + self, + producer=producer_inst.cls_name, + producer_inst=producer_inst, + ) for producer_inst in self.producer_insts if producer_inst.produced_columns ] @@ -66,7 +77,11 @@ def requires(self): if self.producer_insts: reqs["producers"] = [ - self.reqs.ProduceColumns.req(self, producer=producer_inst.cls_name) + self.reqs.ProduceColumns.req( + self, + producer=producer_inst.cls_name, + producer_inst=producer_inst, + ) for producer_inst in self.producer_insts if producer_inst.produced_columns ] @@ -105,18 +120,14 @@ def run(self): # define columns that will be written write_columns: set[Route] = set() - skip_columns: set[str] = set() + skip_columns: set[Route] = set() for c in self.config_inst.x.keep_columns.get(self.task_family, ["*"]): for r in self._expand_keep_column(c): if r.has_tag("skip"): - skip_columns.add(r.column) + skip_columns.add(r) else: write_columns.add(r) - write_columns = { - r for r in write_columns - if not law.util.multi_match(r.column, skip_columns, mode=any) - } - route_filter = RouteFilter(write_columns) + route_filter = RouteFilter(keep=write_columns, remove=skip_columns) # define columns that need to be read read_columns = write_columns | set(mandatory_coffea_columns) @@ -128,6 +139,7 @@ def run(self): files.extend([inp["columns"].abspath for inp in inputs["producers"]]) if self.ml_model_insts: files.extend([inp["mlcolumns"].abspath for inp in inputs["ml"]]) + for (events, *columns), pos in self.iter_chunked_io( files, source_type=len(files) * ["awkward_parquet"], @@ -178,7 +190,6 @@ def run(self): add_default_to_description=True, ) - UniteColumnsWrapper = wrapper_factory( base_cls=AnalysisTask, require_cls=UniteColumns, diff --git a/columnflow/tasks/yields.py b/columnflow/tasks/yields.py index e7d26ca57..55aec3c98 100644 --- a/columnflow/tasks/yields.py +++ b/columnflow/tasks/yields.py @@ -13,25 +13,34 @@ from columnflow.tasks.framework.base import Requirements from columnflow.tasks.framework.mixins import ( - CalibratorsMixin, SelectorStepsMixin, ProducersMixin, - DatasetsProcessesMixin, CategoriesMixin, WeightProducerMixin, + CalibratorClassesMixin, SelectorClassMixin, ReducerClassMixin, ProducerClassesMixin, HistProducerClassMixin, + DatasetsProcessesMixin, CategoriesMixin, ) from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.tasks.histograms import MergeHistograms from columnflow.util import dev_sandbox, try_int -class CreateYieldTable( +class _CreateYieldTable( + CalibratorClassesMixin, + SelectorClassMixin, + ReducerClassMixin, + ProducerClassesMixin, + HistProducerClassMixin, DatasetsProcessesMixin, CategoriesMixin, - WeightProducerMixin, - ProducersMixin, - SelectorStepsMixin, - CalibratorsMixin, law.LocalWorkflow, RemoteWorkflow, ): - sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + """ + Base classes for :py:class:`CreateYieldTable`. + """ + + single_config = True + resolution_task_cls = MergeHistograms + + +class CreateYieldTable(_CreateYieldTable): table_format = luigi.Parameter( default="fancy_grid", @@ -63,15 +72,17 @@ class CreateYieldTable( description="Adds a suffix to the output name of the yields table; empty default", ) + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + # upstream requirements reqs = Requirements( RemoteWorkflow.reqs, MergeHistograms=MergeHistograms, ) - # dummy branch map def create_branch_map(self): - return [0] + # dummy branch map + return {0: None} def requires(self): return { @@ -157,9 +168,9 @@ def run(self): # axis selections h = h[{ "process": [ - hist.loc(p.id) + hist.loc(p.name) for p in sub_process_insts[process_inst] - if p.id in h.axes["process"] + if p.name in h.axes["process"] ], }] @@ -192,9 +203,9 @@ def run(self): leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] h_cat = h[{"category": [ - hist.loc(c.id) + hist.loc(c.name) for c in leaf_category_insts - if c.id in h.axes["category"] + if c.name in h.axes["category"] ]}] h_cat = h_cat[{"category": sum}] diff --git a/columnflow/types.py b/columnflow/types.py index fc3714139..cfe437207 100644 --- a/columnflow/types.py +++ b/columnflow/types.py @@ -26,7 +26,7 @@ Type, ) -from typing_extensions import Annotated, _AnnotatedAlias as AnnotatedType # noqa +from typing_extensions import Annotated, _AnnotatedAlias as AnnotatedType, TypeAlias # noqa #: Generic type variable, more stringent than Any. diff --git a/columnflow/util.py b/columnflow/util.py index 1df09623d..ef796eb1a 100644 --- a/columnflow/util.py +++ b/columnflow/util.py @@ -9,6 +9,8 @@ __all__ = [] import os +import io +import re import abc import uuid import queue @@ -16,20 +18,20 @@ import subprocess import importlib import fnmatch -import re import inspect +import pprint import multiprocessing import multiprocessing.pool from functools import wraps from collections import OrderedDict import law -from law.util import InsertableDict # noqa import luigi -from columnflow import env_is_dev, env_is_remote +from columnflow import env_is_dev, env_is_remote, docs_url, github_url from columnflow.types import Callable, Any, Sequence, Union, ModuleType + #: Placeholder for an unset value. UNSET = object() @@ -161,6 +163,53 @@ def ipython_shell( return InteractiveShellEmbed.instance(config=config, display_banner=banner) +def prettify(obj: Any, **kwargs) -> str: + """ + Prettifies the string repserentation of an object *obj* and returns it. + + :param obj: Object to prettify. + :param kwargs: Optional arguments passed to :py:meth:`pprint.pprint`. + :return: Prettified string representation. + """ + s = io.StringIO() + pprint.pprint(obj, stream=s, **kwargs) + s.seek(0) + return s.read() + + +def get_docs_url(*parts: str, anchor: str | None = None) -> str: + """ + Returns a URL pointing to the documentation of a particular page defined by *parts*. When an *anchor* is defined, + it is appended to the URL. + """ + url = "/".join([docs_url, *(str(part).strip("/") for part in parts)]) + if anchor: + url += f"#{anchor}" + return url + + +def get_github_url(*parts: str) -> str: + """ + Returns a URL pointing to the repository on github including additional URL fragments *parts*. + """ + url = "/".join([github_url, *(str(part).strip("/") for part in parts)]) + return url + + +def get_release_url(tag: str) -> str: + """ + Returns a URL pointing to the release notes of a particular tag. + """ + return get_github_url("releases", "tag", f"v{tag.lstrip('/v')}") + + +def get_code_url(*parts: str, branch: str = "master") -> str: + """ + Returns a URL pointing to specific code on the github repository, defined by *parts* and the corresponding *branch*. + """ + return get_github_url("blob", branch, *parts) + + def create_random_name() -> str: """ Returns a random string based on UUID v4. @@ -465,10 +514,9 @@ def maybe_int(i: Any) -> Any: def is_pattern(s: str) -> bool: """ - Returns *True* if a string *s* contains pattern characters such as "*" or "?", and *False* - otherwise. + Returns *True* if a string *s* contains pattern characters such as "*" or "?", and *False* otherwise. """ - return "*" in s or "?" in s + return "*" in s or "?" in s or s.startswith("!") def is_regex(s: str) -> bool: @@ -485,6 +533,9 @@ def pattern_matcher(pattern: Sequence[str] | str, mode: Callable = any) -> Calla or just a plain string and returns a function that can be used to test of a string matches that pattern. + Patterns starting with "^" and ending with "$" are considered regular expressions, and otherwise fnmatch patterns. + In the latter case, when the pattern starts with a "!", the match is inverted. + When *pattern* is a sequence, all its patterns are compared the same way and the result is the combination given a *mode* which typically should be *any* or *all*. @@ -500,6 +551,10 @@ def pattern_matcher(pattern: Sequence[str] | str, mode: Callable = any) -> Calla matcher("foox") # -> False matcher("foo1") # -> True + matcher = pattern_matcher("!foo*") + matcher("foo123") # -> False + matcher("bar123") # -> True + matcher = pattern_matcher(("foo*", "*bar"), mode=any) matcher("foo123") # -> True matcher("123bar") # -> True @@ -524,6 +579,9 @@ def pattern_matcher(pattern: Sequence[str] | str, mode: Callable = any) -> Calla # identify fnmatch patterns if is_pattern(pattern): + negate = pattern.startswith("!") + if negate: + return lambda s: not fnmatch.fnmatch(s, pattern[1:]) return lambda s: fnmatch.fnmatch(s, pattern) # fallback to string comparison diff --git a/columnflow/weight/__init__.py b/columnflow/weight/__init__.py deleted file mode 100644 index 4d3a9af83..000000000 --- a/columnflow/weight/__init__.py +++ /dev/null @@ -1,119 +0,0 @@ -# coding: utf-8 - -""" -Tools for producing new columns to be used as event or object weights. -""" - -from __future__ import annotations - -import inspect - -from columnflow.types import Callable -from columnflow.util import DerivableMeta -from columnflow.columnar_util import TaskArrayFunction - - -class WeightProducer(TaskArrayFunction): - """ - Base class for all weight producers, i.e., functions that produce and return a single column - that is meant to be used as a per-event or per-object weight. - """ - - exposed = True - - @classmethod - def weight_producer( - cls, - func: Callable | None = None, - bases: tuple = (), - mc_only: bool = False, - data_only: bool = False, - **kwargs, - ) -> DerivableMeta | Callable: - """ - Decorator for creating a new :py:class:`WeightProducer` subclass with additional, optional - *bases* and attaching the decorated function to it as :py:meth:`~WeightProducer.call_func`. - - When *mc_only* (*data_only*) is *True*, the weight producer is skipped and not considered by - other calibrators, selectors and producers in case they are evaluated on a - :py:class:`order.Dataset` (using the :py:attr:`dataset_inst` attribute) whose ``is_mc`` - (``is_data``) attribute is *False*. - - When *nominal_only* is *True* or *shifts_only* is set, the producer is skipped and not - considered by other calibrators, selectors and producers in case they are evaluated on a - :py:class:`order.Shift` (using the :py:attr:`global_shift_inst` attribute) whose name does - not match. - - All additional *kwargs* are added as class members of the new subclasses. - - :param func: Function to be wrapped and integrated into new :py:class:`WeightProducer` - class. - :param bases: Additional bases for the new :py:class:`WeightProducer`. - :param mc_only: Boolean flag indicating that this :py:class:`WeightProducer` should only run - on Monte Carlo simulation and skipped for real data. - :param data_only: Boolean flag indicating that this :py:class:`WeightProducer` should only - run on real data and skipped for Monte Carlo simulation. - :return: New :py:class:`WeightProducer` subclass. - """ - def decorator(func: Callable) -> DerivableMeta: - # create the class dict - cls_dict = { - **kwargs, - "call_func": func, - "mc_only": mc_only, - "data_only": data_only, - } - - # get the module name - frame = inspect.stack()[1] - module = inspect.getmodule(frame[0]) - - # get the producer name - cls_name = cls_dict.pop("cls_name", func.__name__) - - # hook to update the class dict during class derivation - def update_cls_dict(cls_name, cls_dict, get_attr): - mc_only = get_attr("mc_only") - data_only = get_attr("data_only") - - # optionally add skip function - if mc_only and data_only: - raise Exception( - f"weight producer {cls_name} received both mc_only and data_only", - ) - - if mc_only or data_only: - if cls_dict.get("skip_func"): - raise Exception( - f"weight producer {cls_name} received custom skip_func, but either " - "mc_only or data_only are set", - ) - - if "skip_func" not in cls_dict: - def skip_func(self): - # check mc_only and data_only - if getattr(self, "dataset_inst", None): - if mc_only and not self.dataset_inst.is_mc: - return True - if data_only and not self.dataset_inst.is_data: - return True - - # in all other cases, do not skip - return False - - cls_dict["skip_func"] = skip_func - - return cls_dict - - cls_dict["update_cls_dict"] = update_cls_dict - - # create the subclass - subclass = cls.derive(cls_name, bases=bases, cls_dict=cls_dict, module=module) - - return subclass - - return decorator(func) if func else decorator - - -# shorthand -weight_producer = WeightProducer.weight_producer diff --git a/columnflow/weight/all_weights.py b/columnflow/weight/all_weights.py deleted file mode 100644 index 64c3c9bbd..000000000 --- a/columnflow/weight/all_weights.py +++ /dev/null @@ -1,82 +0,0 @@ -# coding: utf-8 - -""" -Exemplary event weight producer. -""" - -from columnflow.weight import WeightProducer, weight_producer -from columnflow.util import maybe_import -from columnflow.columnar_util import has_ak_column, Route - -np = maybe_import("numpy") -ak = maybe_import("awkward") - - -@weight_producer( - # only run on mc - mc_only=True, -) -def all_weights(self: WeightProducer, events: ak.Array, **kwargs) -> ak.Array: - """ - WeightProducer that combines all event weights from the *event_weights* aux entry from either - the config or the dataset. The weights are multiplied together to form the full event weight. - - The expected structure of the *event_weights* aux entry is a dictionary with the weight column - name as key and a list of shift sources as values. The shift sources are used to declare the - shifts that the produced event weight depends on. Example: - - .. code-block:: python - - from columnflow.config_util import get_shifts_from_sources - # add weights and their corresponding shifts for all datasets - cfg.x.event_weights = { - "normalization_weight": [], - "muon_weight": get_shifts_from_sources(config, "mu_sf"), - "btag_weight": get_shifts_from_sources(config, "btag_hf", "btag_lf"), - } - for dataset_inst in cfg.datasets: - # add dataset-specific weights and their corresponding shifts - dataset.x.event_weights = {} - if not dataset_inst.has_tag("skip_pdf"): - dataset_inst.x.event_weights["pdf_weight"] = get_shifts_from_sources(config, "pdf") - """ - # build the full event weight - weight = ak.Array(np.ones(len(events))) - if self.dataset_inst.is_mc and len(events): - # multiply weights from global config `event_weights` aux entry - for column in self.config_inst.x.event_weights: - weight = weight * Route(column).apply(events) - - # multiply weights from dataset-specific `event_weights` aux entry - for column in self.dataset_inst.x("event_weights", []): - if has_ak_column(events, column): - weight = weight * Route(column).apply(events) - else: - self.logger.warning_once( - f"missing_dataset_weight_{column}", - f"weight '{column}' for dataset {self.dataset_inst.name} not found", - ) - return events, weight - - -@all_weights.init -def all_weights_init(self: WeightProducer) -> None: - if not getattr(self, "dataset_inst", None): - return - - weight_columns = set() - - # add used weight columns and declare shifts that the produced event weight depends on - if self.config_inst.has_aux("event_weights"): - weight_columns |= {Route(column) for column in self.config_inst.x.event_weights} - for shift_insts in self.config_inst.x.event_weights.values(): - self.shifts |= {shift_inst.name for shift_inst in shift_insts} - - # optionally also for weights defined by a dataset - if self.dataset_inst.has_aux("event_weights"): - weight_columns |= {Route(column) for column in self.dataset_inst.x("event_weights", [])} - for shift_insts in self.dataset_inst.x.event_weights.values(): - self.shifts |= {shift_inst.name for shift_inst in shift_insts} - - # add weight columns to uses - self.uses |= weight_columns diff --git a/columnflow/weight/empty.py b/columnflow/weight/empty.py deleted file mode 100644 index 178b063fb..000000000 --- a/columnflow/weight/empty.py +++ /dev/null @@ -1,17 +0,0 @@ -# coding: utf-8 - -""" -Empty event weight producer. -""" - -from columnflow.weight import WeightProducer, weight_producer -from columnflow.util import maybe_import - -np = maybe_import("numpy") -ak = maybe_import("awkward") - - -@weight_producer -def empty(self: WeightProducer, events: ak.Array, **kwargs) -> ak.Array: - # simply return ones - return events, ak.Array(np.ones(len(events), dtype=np.float32)) diff --git a/docs/Makefile b/docs/Makefile index 7da3bb088..f0d518bea 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -1,6 +1,6 @@ # Makefile for Sphinx documentation -SPHINXOPTS = +SPHINXOPTS = -v SPHINXBUILD = sphinx-build BUILDDIR = _build diff --git a/docs/api/calibration/cms/index.rst b/docs/api/calibration/cms/index.rst index c91151f5e..6ef82426b 100644 --- a/docs/api/calibration/cms/index.rst +++ b/docs/api/calibration/cms/index.rst @@ -1,16 +1,17 @@ ``cms`` ======= -.. automodule:: columnflow.calibration.cms + .. currentmodule:: columnflow.calibration.cms - + +.. automodule:: columnflow.calibration.cms + Summary ------- .. autosummary:: -.. toctree:: +.. toctree:: :maxdepth: 1 - - jets_coffea + jets met diff --git a/docs/api/calibration/cms/jets_coffea.rst b/docs/api/calibration/cms/jets_coffea.rst deleted file mode 100644 index ad47f2414..000000000 --- a/docs/api/calibration/cms/jets_coffea.rst +++ /dev/null @@ -1,44 +0,0 @@ -``jets_coffea`` -========================================== -.. automodule:: columnflow.calibration.cms.jets_coffea -.. currentmodule:: columnflow.calibration.cms.jets_coffea - -Summary -------- - -.. autosummary:: - get_basenames - get_lookup_provider - jets_coffea - jets_coffea_init - jec_coffea - jec_coffea_init - jec_coffea_requires - jec_coffea_setup - jer_coffea - jer_coffea_init - jer_coffea_requires - jer_coffea_setup - - -Functions ---------- - -.. autofunction:: get_basenames - -.. autofunction:: get_lookup_provider - -Calibrators ------------ - -.. autoclass:: jets_coffea - :members: - :undoc-members: - -.. autoclass:: jec_coffea - :members: - :undoc-members: - -.. autoclass:: jer_coffea - :members: - :undoc-members: \ No newline at end of file diff --git a/docs/api/histogramming/default.rst b/docs/api/histogramming/default.rst new file mode 100644 index 000000000..4d7f81812 --- /dev/null +++ b/docs/api/histogramming/default.rst @@ -0,0 +1,9 @@ +``default`` +=========== + +.. currentmodule:: columnflow.histogramming.default + +.. automodule:: columnflow.histogramming.default + :autosummary: + :members: + :undoc-members: diff --git a/docs/api/histogramming/index.rst b/docs/api/histogramming/index.rst new file mode 100644 index 000000000..99d516f57 --- /dev/null +++ b/docs/api/histogramming/index.rst @@ -0,0 +1,14 @@ +``columnflow.histogramming`` +============================ + +.. currentmodule:: columnflow.histogramming + +.. automodule:: columnflow.histogramming + :autosummary: + :members: + :undoc-members: + +.. toctree:: + :maxdepth: 1 + + default diff --git a/docs/api/index.rst b/docs/api/index.rst index c97c79191..5a25b050a 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -6,16 +6,15 @@ API Reference calibration/index categorization/index - categorization/index + histogramming/index inference/index ml/index plotting/index production/index + reduction/index selection/index tasks/index - weights/index columnar_util config_util - types util - + types diff --git a/docs/api/reduction/default.rst b/docs/api/reduction/default.rst new file mode 100644 index 000000000..5f37ebf42 --- /dev/null +++ b/docs/api/reduction/default.rst @@ -0,0 +1,9 @@ +``default`` +=========== + +.. currentmodule:: columnflow.reduction.default + +.. automodule:: columnflow.reduction.default + :autosummary: + :members: + :undoc-members: diff --git a/docs/api/reduction/index.rst b/docs/api/reduction/index.rst new file mode 100644 index 000000000..5eca70d9f --- /dev/null +++ b/docs/api/reduction/index.rst @@ -0,0 +1,14 @@ +``columnflow.reduction`` +======================== + +.. currentmodule:: columnflow.reduction + +.. automodule:: columnflow.reduction + :autosummary: + :members: + :undoc-members: + +.. toctree:: + :maxdepth: 1 + + default diff --git a/docs/api/selection/index.rst b/docs/api/selection/index.rst index 4c4be0f04..55e7a03ee 100644 --- a/docs/api/selection/index.rst +++ b/docs/api/selection/index.rst @@ -11,7 +11,5 @@ :maxdepth: 1 empty - matching stats - util - cms/index \ No newline at end of file + cms/index diff --git a/docs/api/selection/matching.rst b/docs/api/selection/matching.rst deleted file mode 100644 index 30071fa7d..000000000 --- a/docs/api/selection/matching.rst +++ /dev/null @@ -1,9 +0,0 @@ -``matching`` -============ - -.. currentmodule:: columnflow.selection.matching -.. automodule:: columnflow.selection.matching - :autosummary: - :members: - :undoc-members: - diff --git a/docs/api/selection/util.rst b/docs/api/selection/util.rst deleted file mode 100644 index 12b1066c1..000000000 --- a/docs/api/selection/util.rst +++ /dev/null @@ -1,9 +0,0 @@ -``util`` -======== - -.. currentmodule:: columnflow.selection.util -.. automodule:: columnflow.selection.util - :autosummary: - :members: - :undoc-members: - diff --git a/docs/api/types.rst b/docs/api/types.rst index 4bf7f71f9..5d5209041 100644 --- a/docs/api/types.rst +++ b/docs/api/types.rst @@ -1,7 +1,8 @@ ``columnflow.types`` -========================== +==================== .. currentmodule:: columnflow.types + .. automodule:: columnflow.types :members: :undoc-members: diff --git a/docs/api/weights/all_weights.rst b/docs/api/weights/all_weights.rst deleted file mode 100644 index 6423e0617..000000000 --- a/docs/api/weights/all_weights.rst +++ /dev/null @@ -1,9 +0,0 @@ -``all_weights`` -=============== - -.. currentmodule:: columnflow.weight.all_weights -.. automodule:: columnflow.weight.all_weights - :autosummary: - :members: - :undoc-members: - diff --git a/docs/api/weights/empty.rst b/docs/api/weights/empty.rst deleted file mode 100644 index c92b16765..000000000 --- a/docs/api/weights/empty.rst +++ /dev/null @@ -1,9 +0,0 @@ -``empty`` -========= - -.. currentmodule:: columnflow.weight.empty -.. automodule:: columnflow.weight.empty - :autosummary: - :members: - :undoc-members: - diff --git a/docs/api/weights/index.rst b/docs/api/weights/index.rst deleted file mode 100644 index 60fca994e..000000000 --- a/docs/api/weights/index.rst +++ /dev/null @@ -1,14 +0,0 @@ -``columnflow.weight`` -===================== - -.. currentmodule:: columnflow.weight -.. automodule:: columnflow.weight - :autosummary: - :members: - :undoc-members: - -.. toctree:: - :maxdepth: 1 - - all_weights - empty diff --git a/docs/conf.py b/docs/conf.py index b958286ba..315ba4623 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,9 +11,10 @@ sys.path.insert(0, os.path.join(projdir, "modules", "law")) sys.path.insert(0, os.path.join(projdir, "modules", "order")) sys.path.insert(0, projdir) -# os.environ["LAW_CONFIG_FILE"] = os.path.join(projdir, "docs", "law.cfg") + os.environ["LAW_CONFIG_FILE"] = os.path.join(projdir, "law.cfg") +import luigi import columnflow as cf project = "columnflow" @@ -66,15 +67,15 @@ }) extensions = [ - "sphinx_design", - "sphinx_copybutton", "sphinx.ext.intersphinx", "sphinx.ext.autodoc", - "sphinx_autodoc_typehints", "sphinx.ext.viewcode", "sphinx.ext.autosectionlabel", - "sphinxcontrib.mermaid", + "sphinx_design", + "sphinx_copybutton", + "sphinx_autodoc_typehints", "sphinx_lfs_content", + "sphinxcontrib.mermaid", "autodocsumm", "myst_parser", "pydomain_patch", @@ -94,8 +95,7 @@ intersphinx_aliases = { # alias for a class that was imported at its package level - ("py:class", "awkward.highlevel.Array"): - ("py:class", "ak.Array"), + ("py:class", "awkward.highlevel.Array"): ("py:class", "ak.Array"), } intersphinx_mapping = { @@ -112,8 +112,6 @@ "scinum": ("https://scinum.readthedocs.io/en/stable/", None), } -import luigi - def process_docstring(app, what, name, obj, options, lines): if isinstance(obj, luigi.parameter.Parameter): diff --git a/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_st__cat_incl__var_cf_jet1_pt.svg b/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_st__cat_incl__var_cf_jet1_pt.svg new file mode 100644 index 000000000..cfde1f8d5 --- /dev/null +++ b/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_st__cat_incl__var_cf_jet1_pt.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73fb556f8ad709dac829b80129f91dc5dd51119549d0d632a6e4fd7a8432cd25 +size 210242 diff --git a/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_tt__cat_incl__var_cf_jet1_pt.svg b/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_tt__cat_incl__var_cf_jet1_pt.svg new file mode 100644 index 000000000..bdfb2eb43 --- /dev/null +++ b/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_tt__cat_incl__var_cf_jet1_pt.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86c5c00f4a094ed50fdd31d58b90b3d896058a1f4419014f3993131bbff00cfb +size 202776 diff --git a/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step0_Initial__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg b/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step0_Initial__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg new file mode 100644 index 000000000..14b3f4ed2 --- /dev/null +++ b/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step0_Initial__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a7e2fcff6a0613c161648c8c6516cf0795d03d81a91f73c5cd7ca254abc421b +size 248365 diff --git a/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step1_jet__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg b/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step1_jet__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg new file mode 100644 index 000000000..04a6e273a --- /dev/null +++ b/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step1_jet__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d382de76aec4c8a67a4461e40fed9e9539c0563f2a0a1a9a3643adfeb9d6387 +size 242995 diff --git a/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step2_muon__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg b/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step2_muon__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg new file mode 100644 index 000000000..e82842c4d --- /dev/null +++ b/docs/plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step2_muon__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cbd6d7cf3f4fa21c9966c6ea50ee177c9701ff2e912cbd7af63a5f2f7434ffaa +size 250007 diff --git a/docs/plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_2j.svg b/docs/plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_2j.svg new file mode 100644 index 000000000..d6e3688f3 --- /dev/null +++ b/docs/plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_2j.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34dafc8075593ff87f6c6bf0405f2ad61cbbd3c803e10794aaed6c8a9189b471 +size 81445 diff --git a/docs/plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_incl.svg b/docs/plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_incl.svg new file mode 100644 index 000000000..bcfdea27f --- /dev/null +++ b/docs/plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_incl.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b5bc97efe206a5cb701c47084fd2516c85ec3ad0300f920793e64da1b448e5f +size 82094 diff --git a/docs/plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_jet1_pt.svg b/docs/plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_jet1_pt.svg new file mode 100644 index 000000000..18b0b3971 --- /dev/null +++ b/docs/plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_jet1_pt.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:633307589b4d591bcc61187b5151681a6dcbd61bfc69c9e8882e423ebbda391c +size 214961 diff --git a/docs/plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_n_jet.svg b/docs/plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_n_jet.svg new file mode 100644 index 000000000..00dfc718a --- /dev/null +++ b/docs/plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_n_jet.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cdcf0f2fce69c55b40fad8043e900ff55fbec5a078da8be4dcebe54515a8b920 +size 135064 diff --git a/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c1.svg b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c1.svg new file mode 100644 index 000000000..de73735b7 --- /dev/null +++ b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c1.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a08862597f9cbc4b5d2d856a22dc446c5af087d1624a4455d40ac2c7a433562 +size 109288 diff --git a/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c1.svg b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c1.svg new file mode 100644 index 000000000..85f01b2e3 --- /dev/null +++ b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c1.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe5e738db8a28cd5b5bf7bef6d3ccac05d6969b5240c1b2c4feb192f3833c0d3 +size 81395 diff --git a/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_2j__var_n_jet.svg b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_2j__var_n_jet.svg new file mode 100644 index 000000000..3fc072225 --- /dev/null +++ b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_2j__var_n_jet.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18b913c59ed746a0fa44e4d757d1bafa70a15310505487f28aa039d9bed815b7 +size 155256 diff --git a/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_incl__var_n_jet.svg b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_incl__var_n_jet.svg new file mode 100644 index 000000000..37c30658a --- /dev/null +++ b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_incl__var_n_jet.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b001d593c28235d89235c34377ab6ec7bf321db66a0b29d70459df22c4de30d2 +size 157627 diff --git a/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__4601e8554b__plot__proc_3_7727a49dc2__cat_incl__var_n_jet.svg b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__4601e8554b__plot__proc_3_7727a49dc2__cat_incl__var_n_jet.svg new file mode 100644 index 000000000..37c30658a --- /dev/null +++ b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__4601e8554b__plot__proc_3_7727a49dc2__cat_incl__var_n_jet.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b001d593c28235d89235c34377ab6ec7bf321db66a0b29d70459df22c4de30d2 +size 157627 diff --git a/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c3.svg b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c3.svg new file mode 100644 index 000000000..02c6ca693 --- /dev/null +++ b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c3.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9ead543626fa2f47a7d61d9db2fae878c911d4dd89ba8dbc79d57dfdcc8058f +size 105864 diff --git a/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c3.svg b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c3.svg new file mode 100644 index 000000000..2b7d29e28 --- /dev/null +++ b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c3.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1032fa21375905110bb671ed5be9c1c7c544bc494a859a5a71ced93b1c0f6a52 +size 94052 diff --git a/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c2.svg b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c2.svg new file mode 100644 index 000000000..f002829fc --- /dev/null +++ b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c2.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b434012dade58ee791ce641dbeda592e66344021ba1d43748d34fae9d2538607 +size 108692 diff --git a/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c2.svg b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c2.svg new file mode 100644 index 000000000..90ceb657e --- /dev/null +++ b/docs/plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c2.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b10b921b172e8bd05b970bb41a53bac452ea0b68324e4367cd7c2a063dd0a694 +size 103256 diff --git a/docs/plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt-n_jet.svg b/docs/plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt-n_jet.svg new file mode 100644 index 000000000..2e4ff35c2 --- /dev/null +++ b/docs/plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt-n_jet.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cadccf88f3e91950ed180d3b40dd5b2893a10086e054a4eb22c87ae895dcbcb6 +size 178280 diff --git a/docs/plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_n_jet-jet1_pt.svg b/docs/plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_n_jet-jet1_pt.svg new file mode 100644 index 000000000..ff0478bca --- /dev/null +++ b/docs/plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_n_jet-jet1_pt.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:418d0c0f667ee18d2cf1f1ebd67b0ecab2c926c9edffe886121274f9dce301bd +size 178610 diff --git a/docs/requirements.txt b/docs/requirements.txt index 284c87f4f..a35fdf012 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,15 +1,15 @@ -# version 7 +# version 8 # documentation packages -sphinx~=6.2.1 -sphinx-design~=0.5.0 +sphinx~=7.4.7 +sphinx-design~=0.6.1 sphinx-copybutton~=0.5.2 -sphinx-autodoc-typehints~=1.22,<1.23 -sphinx-book-theme~=1.0.1 -sphinx-lfs-content~=1.1.3,!=1.1.5 -autodocsumm~=0.2.11 -myst-parser~=2.0.0 -sphinxcontrib-mermaid~=0.9.2 +sphinx-autodoc-typehints~=2.3.0 +sphinx-book-theme~=1.1.4 +sphinx-lfs-content~=1.1.8 +sphinxcontrib-mermaid~=1.0.0 +autodocsumm~=0.2.14 +myst-parser~=3.0.1 # prod packages -r ../sandboxes/cf.txt diff --git a/docs/user_guide/02_03_transition.md b/docs/user_guide/02_03_transition.md new file mode 100644 index 000000000..87bd843fa --- /dev/null +++ b/docs/user_guide/02_03_transition.md @@ -0,0 +1,291 @@ +# v0.2 → v0.3 Transition + +This document describes changes on columnflow introduced in version 0.3.0 that may affect existing code as well as already created output files. +These changes were made in a refactoring campaign (see [release v0.3](https://github.com/columnflow/columnflow/releases/tag/v0.3.0)) that was necessary to generalize some decisions made in an earlier stage of the project, and to ultimately support more analysis use cases that require a high degree of flexibility in many aspects of the framework. + +The changes are grouped into the following categories: + +- [Restructured Task Array Functions](#restructured-task-array-functions) +- [Multi-config Tasks](#multi-config-tasks) +- [Reducers](#reducers) +- [Histogram Producers](#histogram-producers) +- [Inference Model Updates](#inference-model-updates) +- [Changed Plotting Task Names](#changed-plotting-task-names) + +## Restructured Task Array Functions + +The internals of task array functions (TAF) like calibrators, selectors and producers received a major overhaul. +Not all changes affect user code but some might. + +Most notably, TAFs **no longer** have the attributes `task`, `global_shift_inst`, `local_shift_inst`. +Instead, some of the configurable functions now receive a `task` argument through which task information and attributes like shifts can be accessed. +In turn, the attributes `analysis_inst`, `config_inst` and `dataset_inst` are guarenteed to **always be available**, and there is no longer the need to dynamically check their existence. + +This change reflects the new state separation imposed by the order in which underlying, customizable functions (or *hooks*) are called. +A full overview of these hooks and arguments they received are listed in the [task array functions documentation](./task_array_functions.md). +In short, there are three types of hooks: + +1. `pre_init`, `init`, `post_init`: Initialization hooks meant to dynamically update used and produced columns and TAF dependencies. `post_init` is the first hook to receive the `task` argument. +2. `requires`, `setup`, `teardown`: Methods to define custom task requirements, setting up attributes of the task array function before event processing, and to clean up and free resources afterwards. +3. `__call__`: The main callable that is invoked for each event chunk. + +`pre_init`, `post_init` and `teardown` have been newly introduced. +See the [task array function interface](./task_array_functions.md#taf-interface) for a full descrption of all hooks and the arguments they receive. + +(Note that, as before, while the hooks to register custom functions are named as shown above, the functions stored internally have an additional suffix and are named `_func`.) + +### Example + +The example below shows a simple producer that calculates the invariant mass of the two leading jets per event. +The `task` argument is now passed to the function, and the `task.logger` can be used to log messages in the scope of the task. + +```python +import law +import awkward as ak +from columnflow.production import Producer, producer +from columnflow.columnar_util import set_ak_column + +@producer( + uses={"Jet.{pt,eta,phi,mass}"}, + produces={"di_jet_mass"}, +) +def di_jet_mass(self: Producer, events: ak.Array, task: law.Task) -> ak.Array: + # issue a warning in case less than 2 jets are present + if ak.any(ak.num(events.Jet, axis=1) < 2): + task.logger.warning("encountered event with less than 2 jets") + + di_jet = events.Jet[:, :2].sum(axis=1) + events = set_ak_column(events, "di_jet_mass", di_jet.mass, value_type="float32") + + return events +``` + +### Update Instructions + +1. Checkout the [TAF interface](./task_array_functions.md#taf-interface) to learn about the arguments that the hooks receive. In particular, the `task` argument is now passed to all hooks after (and including) `post_init`. +2. Make sure to no longer use the TAF attribites `self.task`, `self.global_shift_inst`, and `self.local_shift_inst`. Access them through `task` argument instead. +3. Depending on whether your custom TAF required access to these attributes, for instance in the `init` hook, you need to move your code to a different hook such as `post_init`. +4. If your TAF blocked specific resources, such as a large object, ML model, etc. loaded during `setup`, think about releasing these resources in the `teardown` hook. +5. Also, all TAF instances are chached from now on, given the combination of `self.analysis_inst`, `self.config_inst` and `self.dataset_inst`. + +## Multi-config Tasks + +Most of the tasks provided by columnflow operate on a single analysis configuration (usually representing self-contained data taking periods or *eras*). +Examples are `cf.CalibrateEvents` and `cf.SelectEvents`, or `cf.ProduceColumns` and `cf.CreateHistograms` which do the heavy lifting in terms of event processing. + +However, some tasks require access to data of multiple eras at a time, and therefore, access to multiple analysis configurations. +We refer to these tasks as **multi-config tasks**. + +In version 0.3, the following tasks are multi-config tasks: + +- Most plotting tasks: tasks like `cf.PlotVariables1D` need to be able to draw events simulated for / recorded in multiple eras into the same plot. +- `cf.MLTraining`: For many ML training applications it is reasonable to train on data from multiple eras, given that detector conditions are not too different. It is now possible to request data from multiple eras to be loaded for a single training. +- `cf.CreateDatacards` (CMS-specific): The inference model interface as well as the datacard export routines now support entries for multiple configurations. See the [changes to the inference model interface](#inference-model-updates) below for details. + +### Update Instructions + +All instructions only apply to the CLI usage of tasks. + +1. Tasks listed above no longer have a `--config` parameter. However, they now have a `--configs` parameter that accepts multiple configuration names as a comma-separate sequece. In order to achieve the single-config behavior, just pass the name of a single configuration here. +2. Specific other parameters of multi-config tasks changed as well. Most notably, the `--datasets` and `--processes` parameters, which previously allowed for defining sequences of dataset and process names on the command line, now accept muliple comma-separated sequences. The number of sequences should be exactly one (applies to all configurations) or match the number of configurations given in `--configs` (one-to-one assignment). Sequences should be separater by colons. + - Example: `law.run cf.PlotVariables1D --configs 22pre,22post --datasets tt_sl,st_tw:tt_sl,st_s` + +## Reducers + +Reducers are a new type of task array function that are invoked by the `cf.ReduceEvents` task. +They control how results of the event selection - event and object masks - are applied to the full event data. +See the [types of task array functions](./task_array_functions.md#taf-types) and the detailed [documentation on reducers](./building_blocks/reducers.md) for details. + +The reducer's job is + +- to apply the event selection mask (booleans) to select only a subset of events, +- to apply object selection masks (booleans or integer indices) to create new collections of objects (e.g. specific jets, or leptons), and +- to drop columns are not needed by any of the downstream tasks. + +These three steps were previously part of the default implementation of the `cf.ReduceEvents` tasks but are now fully configurable though custom reducers. +For compatibility with existing analyses, a default reducer called `cf_default` is provided by columnflow that implements exactly the previous behavior. +In doing so, it even relies on the auxiliary entry `keep_columns` in the configuration to determine which columns should be kept after reduction. + +### Example + +The following example creates a custom reducer that invokes columnflow's default reduction behavior and additionally creates a new column. + +```python +from columnflow.reduction import Reducer, reducer +from columnflow.reduction.default import cf_default +from columnflow.util import maybe_import +from columnflow.columnar_util import set_ak_column + +ak = maybe_import("awkward") + +@reducer( + uses={cf_default, "Jet.hadronFlavour"}, + produces={cf_default, "Jet.from_b_hadron"}, +) +def example(self: Reducer, events: ak.Array, selection: ak.Array, **kwargs) -> ak.Array: + # run cf's default reduction which handles event selection and collection creation + events = self[cf_default](events, selection, **kwargs) + + # compute and store additional columns after the default reduction + # (so only on a subset of the events and objects which might be computationally lighter) + col = abs(events.Jet.hadronFlavour) == 5 + events = set_ak_column(events, "Jet.from_b_hadron", col, value_type=bool) + + return events +``` + +### Update Instructions + +1. In general, there is no need to update your code. However, you will notice that output paths of all tasks downstream of (and including) `cf.ReduceEvents` will have an additional fragment like `.../red__cf_default/...` to reflect the choice of the reducer. +2. The reduction behavior that was previously part of the `cf.ReduceEvents` task is now encapsulated by a [default reducer](https://github.com/columnflow/columnflow/blob/refactor/taf_init/columnflow/reduction/default.py) called `cf_default`. To extend or alter its behavior, create your own implementation either from scratch or by inheriting from it and only overwriting some of its hooks. +3. Invoke your reducer by adding `--reducer MY_REDUCER_CLASS` on the command line or by adding an auxiliary entry `default_reducer` to your configuration. +4. If you decide to control the set of columns that should be available after reduction solely through your reducer, and no longer through the `keep_columns` auxiliary entry in your configuration, you can do so by redefining the `produces` set of your reducer. + +## Histogram Producers + +In release v0.2 and before, the amount of control users had over the creation of histograms within `cf.CreateHistograms` was limited to the selection of variables to use (through the `--variables` parameter) and the definition of event weights to be used during histogram filling. +The latter was configured by specifying a so-called weight producer (through the `--weight-producer` parameter), was referred to the name of a task array function. + +As of v0.3, we generalized this concept and renamed it to **histogram producers**. +Use `--hist-producer` in the command line to specify the histogram producer you intend to use. +See the full [histogram producer documentation](./building_blocks/hist_producers.md) for more info. + +In short, histogram producers [continue to be task array functions](./task_array_functions.md#histogram-producers), however, they provide additional hooks to control different aspects of the histogramming process: + +- `create_hist(self, variables: list[od.Variable], task: law.Task) -> hist.Histogram`: Given a list of variables, creates and returns a new histogram, with arbitrary axes, binning and weight storage. +- `fill_hist(self, h: hist.Histogram, data: dict[str, Any], task: law.Task) -> None`: Provided columnar data to fill (with fields `"category"`, `"process"`, `"shift"` (a string) and `"weight"`), controls the way this data is filled into the histogram. +- `post_process_hist(self, h: hist.Histogram, task: law.Task) -> hist.Histogram`: After all data was filled in `cf.CreateHistogram`, allows to change the histogram before it is saved to disk. +- `post_process_merged_hist(self, h: hist.Histogram, task: law.Task) -> hist.Histogram`: Invoked by `cf.MergeHistograms`, allows to change the merged histogram before it is saved for subsequent processing. + +The only requirement that columnflow imposes on histograms for plotting and export as part of statistical models is the existence of categorical (string) axes `"category"`, `"process"` and `"shift"` **after** merging. + +The main callable of a histogram producer continues to be responsible for returning (and potentially preprocessing) the event chunk to histogram, as well as a float array representing event weights in a 2-tuple, consistent with the previous behavior of weight producers. + +**Note** that, unlike for most other task array functions, columnflow provides a default histogram producer named {py:class}`~columnflow.histogramming.default.cf_default`. +It handles the histogram definition and filling in a backwards-compatible way, as well as a post-processing step that converts the category and shift axes from categorical integer to string types (for consistency across configuration objects when used in multi-config tasks). +It is recommended to extend this default histogram producer in case you only need to change a single aspect of the histogramming process with respect to the default behavior. +See the example below for how to do this. + +### Example + +```python +from columnflow.histogramming import HistProducer +from columnflow.histogramming.default import cf_default +from columnflow.util import maybe_import + +ak = maybe_import("awkward") + +@cf_default.hist_producer( + uses={"{normalization,pileup,btag}_weight"} +) +def example(self: HistProducer, events: ak.Array, **kwargs) -> ak.Array: + """ + Example histogram producer that inherits from columnflow's default and + changes the event weight only. + """ + # compute the event weight + weight = events.normalization_weight * events.pileup_weight * events.btag_weight + + return events, weight +``` + +### Update Instructions + +1. In case you used a weight producer before, convert it to a {py:meth}`~columnflow.histogramming.HistProducer`. There should be no change necessary for the main event callable. +2. On the command line, use `--hist-producer` instead of `--weight-producer`. +3. Note that the `weight__` prefix in the weight producer related fragment of output paths of all tasks downstream of (and including) `cf.CreateHistograms` were changed to `hist__` accordingly. +4. If you do not intend to alter the default histogram definition, filling and post-processing, make sure to inherit from {py:class}`~columnflow.histogramming.default.cf_default` as shown in the example above. + +## Inference Model Updates + +As stated [above](#multi-config-tasks), multi-config tasks allow for the inclusion of multiple analysis configurations in a single task to be able to access event data that spans multiple eras. +This is particularly useful for tasks that export statistical models like `cf.CreateDatacards` (CMS-specific), and all other tasks that inherit from the generalized `SerializeInferenceModelBase` task. + +To support this new feature, the underlying {py:class}`~columnflow.inference.InferenceModel`, i.e., the container object able to configure statistical models for your analysis, was updated. +Pointers to analysis-specific objects in category and process defintions are now to be stored per configuration (see example below). +This info is picked up by (e.g.) `cf.CreateDatacards` to pull in information and data from multiple data taking eras to potentially fill their event data into the same inference category. + +As for all multi-config tasks, pass a sequence of configuration names to the `--configs` parameter on the command line. + +### Example + +The following example demonstrates how to define an inference model that ... + +```python +from columnflow.inference import InferenceModel, inference_model + +@inference_model +def example_model(self: InferenceModel) -> None: + """ + Initialization method for the inference model. + Use instance methods to define categories, processes and parameters. + """ + # add a category + self.add_category( + "example_category", + # add config dependent settings + config_data={ + config_inst.name: self.category_config_spec( + # name of the analysis category in the config + category=f"{ch}__{cat}__os__iso", + # name of the variable + variable="jet1_pt", + # names (or patterns) of datasets with real data in the config + data_datasets=["data_*"], + ) + for config_inst in self.config_insts + }, + # additional category settings + mc_stats=10.0, + flow_strategy=FlowStrategy.move, + ) + + # add processes + self.add_process( + name="TT", + # add config dependent settings + config_data={ + config_inst.name: self.process_config_spec( + # name of the (parent) process in the config + process="tt", + # names of MC datasets in the config + mc_datasets=["tt_sl_powheg", "tt_dl_...", ...], + ), + }, + # additional process settings + is_signal=False, + ) + # more processes here + ... +``` + +### Update Instructions + +1. In definitions of categories, processes and parameters within your inference model, make sure that all pointers that refer for analysis-specific objects are stored in a dictionary with keys being configuration names. +2. These dictionaries are stored in fields named `config_data`. +3. Use the provided factory functions to create these dictionary structures to invoke some additional value validation: + - for categories: {py:meth}`~columnflow.inference.InferenceModel.category_config_spec` + - for processes: {py:meth}`~columnflow.inference.InferenceModel.process_config_spec` + - for parameters: {py:meth}`~columnflow.inference.InferenceModel.parameter_config_spec` + +## Changed Plotting Task Names + +The visualization of systematic uncertainties is updated as of v0.3. +A new plot method was introduced to show not only the effect of the statistical uncertainty (due to the limited amount of simulated events) as a grey, hatched area, but also that of systematic uncertainties as a differently colored band. + +The task that invokes this plot method by default is `cf.PlotShiftedVariables1D`. +See the full task graph in [our wiki](https://github.com/columnflow/columnflow/wiki#default-task-graph) to see its dependencies to other tasks. + +**Note** that this task is not new, but it has been changed to include the systematic uncertainty bands. +In version v0.2 and below, this task was used to plot the effect of a single up or down variation of a single shift. +This behavior is now covered by a task called `cf.PlotShiftedVariablesPerShift1D`. + +### Update Instructions + +1. If you are interested in creating plots showing the effect of one **or multiple** shifts in the same graph, use the `cf.PlotShiftedVariables1D` task. +2. If you want to plot the effect of a single up or down variation of a single shift, use the `cf.PlotShiftedVariablesPerShift1D` task (formerly known as `cf.PlotShiftedVariables1D`) + +## Miscellaneous smaller updates + +- The `SelectorStepsMixin` was removed and its functionality was moved into the standard {py:class}`~columnflow.tasks.framework.mixins.SelectorClassMixin` and {py:class}`~columnflow.tasks.framework.mixins.SelectorMixin` classes. +- `columnflow.util.InsertableDict` was removed in favor of `law.util.InsertableDict`. diff --git a/docs/user_guide/building_blocks/hist_producers.md b/docs/user_guide/building_blocks/hist_producers.md new file mode 100644 index 000000000..aee9eed08 --- /dev/null +++ b/docs/user_guide/building_blocks/hist_producers.md @@ -0,0 +1,5 @@ +# Histogram Producers + +## Introduction + +## Usage diff --git a/docs/user_guide/building_blocks/index.rst b/docs/user_guide/building_blocks/index.rst index c8da16b1e..bcb545032 100644 --- a/docs/user_guide/building_blocks/index.rst +++ b/docs/user_guide/building_blocks/index.rst @@ -6,6 +6,8 @@ Building Blocks calibrators.md selectors.md + reducers.md producers.md categories.md + hist_producers.md config_objects.md diff --git a/docs/user_guide/building_blocks/reducers.md b/docs/user_guide/building_blocks/reducers.md new file mode 100644 index 000000000..eeb270431 --- /dev/null +++ b/docs/user_guide/building_blocks/reducers.md @@ -0,0 +1,5 @@ +# Event and object reduction + +## Introduction + +## Usage diff --git a/docs/user_guide/debugging.md b/docs/user_guide/debugging.md index 7c18e23e1..ad8f58d3a 100644 --- a/docs/user_guide/debugging.md +++ b/docs/user_guide/debugging.md @@ -13,14 +13,14 @@ In this section, debugging tools already implemented in columnflow to inspect th ### Debugging outputs of supported extensions (ROOT, Parquet, JSON and Pickle) -columnflow comes equipped with the command ```cf_inspect```, which is available in the columnflow environment after sourcing the ```setup.sh``` file. +columnflow comes equipped with the command `cf_inspect`, which is available in the columnflow environment after sourcing the `setup.sh` file. The command takes the pathes of one or more files (space seperated) you want to inspect as arguments and enters an IPython shell in a development sandbox, where analysis tools and packages (awkward, coffea etc.) as well as the columnflow API are available. -The input files are loaded into memory and can be accessed via the variables ```objetcs```. -Make sure to avoid pathes, which are accesed via file protocol (e.g. ```file://```, ```davs://```, etc.), as the file loader may not support this. +The input files are loaded into memory and can be accessed via the variables `objetcs`. +Make sure to avoid pathes, which are accesed via file protocol (e.g. `file://`, `davs://`, etc.), as the file loader may not support this. :::{dropdown} Where to find the path of my outputs? -After running a task, the path of the output can be yielded by appending ```--print-output ``` to the command, where `````` is the task index in the workflow tree. +After running a task, the path of the output can be yielded by appending `--print-output ` to the command, where `` is the task index in the workflow tree. For the main task of the command the index is 0. For example, for the output path of the Selection task, the command would be @@ -32,34 +32,34 @@ law run cf.SelectEvents --version v1 --{other options} --print-output 0 ### Debugging histograms -Histograms can be inspected with the ```cf_inspect``` command as well. +Histograms can be inspected with the `cf_inspect` command as well. However columnflow provides a more specialized tool for this purpose in the form of task. This can be the better choice if you want to inspect multiple histograms of multiple variables, since this can offer a more structured way to do so. -The task can be accessed by calling ```law run cf.InspectHistograms --{options}``` for the nominal shift or ```law run cf.InspectShiftedHistograms --shift-sources {source} --{options}``` for nominal and shifted histograms respectively. -In both cases the ```{options}``` are the same parameters, which would be passed to create the histogram in the task ```law run cf.CreateHistograms --{options}```. +The task can be accessed by calling `law run cf.InspectHistograms --{options}` for the nominal shift or `law run cf.InspectShiftedHistograms --shift-sources {source} --{options}` for nominal and shifted histograms respectively. +In both cases the `{options}` are the same parameters, which would be passed to create the histogram in the task `law run cf.CreateHistograms --{options}`. The task also enters an IPython shell in a development sandbox, where the following variables are available: -- ```self```: the task instance, from which the config, analysis and datasets instances can be accessed -- ```hists```: a dictionary of the histograms, where the keys are the variables names and the values are the histograms -- ```dataset```: name of the dataset -- ```variable```: name of the last histogrammed variable in the ```--variables``` option -- ```h_in```: histogram of ```variable``` +- `self`: the task instance, from which the config, analysis and datasets instances can be accessed +- `hists`: a dictionary of the histograms, where the keys are the variables names and the values are the histograms +- `dataset`: name of the dataset +- `variable`: name of the last histogrammed variable in the `--variables` option +- `h_in`: histogram of `variable` -An advatage of this debugging methode comapred to the ```cf_inspect``` command is that the histogram files are not required to exist, since the task (like all other tasks) will set up the workflow to produce missing requirements. +An advatage of this debugging methode comapred to the `cf_inspect` command is that the histogram files are not required to exist, since the task (like all other tasks) will set up the workflow to produce missing requirements. ## FAQ ### Troubleshooting: -- "I have changed something in the code and called the corresponding ```law run``` bash command, but the task isn't starting/the task started is further down the task tree." +- "I have changed something in the code and called the corresponding `law run` bash command, but the task isn't starting/the task started is further down the task tree." -A: Do not forget to remove the corresponding intermediate output(s), for example with ```--remove-output``` (see {doc}`law`), or start a new version with ```--version``` if you do explicitely want to conserve the previous output before the change in the code. +A: Do not forget to remove the corresponding intermediate output(s), for example with `--remove-output` (see {doc}`law`), or start a new version with `--version` if you do explicitely want to conserve the previous output before the change in the code. - "Where do I find the outputs of my tasks?" A: When you run "source setup.sh {name_of_the_setup}" in columnflow for the first time, you choose the storage locations. You can find the storage locations again by opening the ".setups/{name_of_the_setup}.sh" file in the analysis repository. -You may also use law functions to ease the search, namely with the ```--print-output``` command, see in the {doc}`law` section. +You may also use law functions to ease the search, namely with the `--print-output` command, see in the {doc}`law` section. - "I get an error telling me that some columns could not be found/produced. What can I do?" diff --git a/docs/user_guide/index.rst b/docs/user_guide/index.rst index f07bf11b1..12dabe2c9 100644 --- a/docs/user_guide/index.rst +++ b/docs/user_guide/index.rst @@ -5,6 +5,8 @@ User Guide :maxdepth: 3 structure.md + 02_03_transition.md + task_array_functions.md building_blocks/index sandbox.md ml.md @@ -12,7 +14,7 @@ User Guide plotting.md examples.md debugging.md - law.md best_practices.md + law.md special_usecases.md cms_specializations.md diff --git a/docs/user_guide/plotting.md b/docs/user_guide/plotting.md index d6cda37de..930a8be71 100644 --- a/docs/user_guide/plotting.md +++ b/docs/user_guide/plotting.md @@ -19,11 +19,11 @@ law run cf.PlotVariables1D --version v1 \ This will run the full analysis chain for the given processes (data, tt, st) and should create plots looking like this: ::::{grid} 1 1 2 2 -:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_incl__var_n_jet.pdf +:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_incl__var_n_jet.* :width: 100% ::: -:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_2j__var_n_jet.pdf +:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__12dfac316a__plot__proc_3_7727a49dc2__cat_2j__var_n_jet.* :width: 100% ::: :::: @@ -31,7 +31,7 @@ This will run the full analysis chain for the given processes (data, tt, st) and :::{dropdown} Where do I find that plot? You can add ```--print-output 0``` to every task call, which will print the full filename of all outputs of the requested task. Alternatively, you can add ```--fetch-output 0,a``` to directly copy all outputs of this task into the directory you are currently in. -Finally, there is the ```--view-cmd``` parameter you can add to directly display the plot during the runtime of the task, e.g. via ```--view-cmd evince-previewer``` or ```--view-cmd imgcat```. +Finally, there is the ```--view-cmd``` parameter you can add to directly display the plot during the runtime of the task, e.g. via ```--view-cmd evince``` or ```--view-cmd imgcat```. ::: The ```PlotVariables1D``` task is located at the bottom of our [task graph](https://github.com/columnflow/columnflow/wiki#default-task-graph), which means that all tasks leading to ```PlotVariables1D``` will be run for all datasets corresponding to the ```--processes``` we requested using the {py:class}`~columnflow.calibration.Calibrator`s, {py:class}`~columnflow.selection.Selector`, and {py:class}`~columnflow.production.Producer`s (often referred to as CSPs) as requested. @@ -90,11 +90,11 @@ law run cf.PlotVariables1D --version v1 --processes tt,st --variables n_jet,jet1 to produce the following plot: ::::{grid} 1 1 2 2 -:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c1.pdf +:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c1.* :width: 100% ::: -:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c1.pdf +:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__0191de868f__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c1.* :width: 100% ::: :::: @@ -130,11 +130,11 @@ law run cf.PlotVariables1D --version v1 --processes tt,st --variables n_jet,jet1 ``` ::::{grid} 1 1 2 2 -:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c2.pdf +:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c2.* :width: 100% ::: -:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c2.pdf +:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__c80529af83__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c2.* :width: 100% ::: :::: @@ -186,11 +186,11 @@ law run cf.PlotVariables1D --version v1 --processes tt,st --variables n_jet,jet1 ``` ::::{grid} 1 1 2 2 -:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c3.pdf +:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt__c3.* :width: 100% ::: -:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c3.pdf +:::{figure} ../plots/cf.PlotVariables1D_tpl_config_analy__1__be60d3bca7__plot__proc_2_a2211e799f__cat_incl__var_n_jet__c3.* :width: 100% ::: :::: @@ -206,11 +206,11 @@ law run cf.PlotVariables2D --version v1 \ ``` ::::{grid} 1 1 2 2 -:::{figure} ../plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt-n_jet.pdf +:::{figure} ../plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_jet1_pt-n_jet.* :width: 100% ::: -:::{figure} ../plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_n_jet-jet1_pt.pdf +:::{figure} ../plots/cf.PlotVariables2D_tpl_config_analy__1__b27b994979__plot__proc_2_a2211e799f__cat_incl__var_n_jet-jet1_pt.* :width: 100% ::: :::: @@ -233,11 +233,11 @@ law run cf.PlotCutflow --version v1 \ ``` ::::{grid} 1 1 2 2 -:::{figure} ../plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_incl.pdf +:::{figure} ../plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_incl.* :width: 100% ::: -:::{figure} ../plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_2j.pdf +:::{figure} ../plots/cf.PlotCutflow_tpl_config_analy__1__12a17bf79c__cutflow__cat_2j.* :width: 100% ::: :::: @@ -269,15 +269,15 @@ law run cf.PlotCutflowVariables1D --version v1 \ ``` ::::{grid} 1 1 3 3 -:::{figure} ../plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step0_Initial__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.pdf +:::{figure} ../plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step0_Initial__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.* :width: 100% ::: -:::{figure} ../plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step1_jet__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.pdf +:::{figure} ../plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step1_jet__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.* :width: 100% ::: -:::{figure} ../plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step2_muon__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.pdf +:::{figure} ../plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__d8a37d3da9__plot__step2_muon__proc_2_a2211e799f__cat_incl__var_cf_jet1_pt.* :width: 100% ::: :::: @@ -293,11 +293,11 @@ law run cf.PlotCutflowVariables1D --version v1 \ ``` ::::{grid} 1 1 2 2 -:::{figure} ../plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_st__cat_incl__var_cf_jet1_pt.pdf +:::{figure} ../plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_st__cat_incl__var_cf_jet1_pt.* :width: 100% ::: -:::{figure} ../plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_tt__cat_incl__var_cf_jet1_pt.pdf +:::{figure} ../plots/cf.PlotCutflowVariables1D_tpl_config_analy__1__c3947accbb__plot__proc_tt__cat_incl__var_cf_jet1_pt.* :width: 100% ::: :::: @@ -328,11 +328,11 @@ law run cf.PlotShiftedVariables1D --version v1 \ and produces the following plot: ::::{grid} 1 1 2 2 -:::{figure} ../plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_jet1_pt.pdf +:::{figure} ../plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_jet1_pt.* :width: 100% ::: -:::{figure} ../plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_n_jet.pdf +:::{figure} ../plots/cf.PlotShiftedVariables1D_tpl_config_analy__1__42b45aba89__plot__proc_2_a2211e799f__unc_mu__cat_incl__var_n_jet.* :width: 100% ::: :::: @@ -346,7 +346,7 @@ All plotting tasks also include a ```--view-cmd``` parameter that allows directl ```shell law run cf.PlotVariables1D --version v1 \ - --processes tt,st --variables n_jet --view-cmd evince-previewer + --processes tt,st --variables n_jet --view-cmd evince ``` (custom_plot_function)= @@ -367,7 +367,7 @@ An example on how to implement such a plotting function is shown in the followin ## Applying a selection to a variable In some cases, you might want to apply a selection to a variable before plotting it. -Instead of creating a new column with the selection applied, columnflow provides the possibility to apply a selection to a variable directly when histograming it. +Instead of creating a new column with the selection applied, columnflow provides the possibility to apply a selection to a variable directly when histogramming it. For this purpose, the `selection` parameter can be added in the variable definition in the config. This may look as follows: diff --git a/docs/user_guide/structure.md b/docs/user_guide/structure.md index 402f8cf0f..4037316c6 100644 --- a/docs/user_guide/structure.md +++ b/docs/user_guide/structure.md @@ -15,10 +15,10 @@ Fields like "Jet" exist too, they contain columns (the parameters of the field, As most of the information is conserved in the form of columns, it would be very inefficient (and might not even fit in the memory) to use all columns and all events from a dataset at once for each task. Therefore, in order to reduce the impact on the memory: -- a chunking of the datasets is implemented using [dask](https://www.dask.org/): +- A chunking of the datasets is implemented using [dask](https://www.dask.org/): not all events from a dataset are inputed in a task at once, but only chunked in groups of events. (100 000 events max per group is default as of 05.2023, default is set in the law.cfg file). -- the user needs to define for each {py:class}`~columnflow.production.Producer`, {py:class}`~columnflow.calibration.Calibrator` and {py:class}`~columnflow.selection.Selector` which columns are to be loaded (this happens by defining the ```uses``` set in the header of the decorator of the class) and which new columns/fields are to be saved in parquet files after the respective task (this happens by defining the ```produces``` set in the header of the decorator of the class). +- The user needs to define for each {py:class}`~columnflow.calibration.Calibrator`, {py:class}`~columnflow.selection.Selector`, {py:class}`~columnflow.reduction.Reducer`, {py:class}`~columnflow.production.Producer`, and {py:class}`~columnflow.histogramming.HistProducer` which columns are to be loaded (this happens by defining the `uses` set in the header of the decorator of the class) and which new columns/fields are to be saved in parquet files after the respective task (this happens by defining the `produces` set in the header of the decorator of the class). The exact implementation for this feature is further detailed in {doc}`building_blocks/selectors` and {doc}`building_blocks/producers`. ## Tasks in columnflow @@ -37,10 +37,10 @@ For an overview of the tasks that are available with columnflow, please see the ## Important note on required parameters It should also be added that there are additional parameters specific for the tasks in columnflow, required by the fact that columnflow's purpose is for HEP analysis. -These are the ```--analysis``` and ```--config``` parameters, which defaults can be set in the law.cfg. +These are the `--analysis` and `--config` parameters, which defaults can be set in the law.cfg. These two parameters respectively define the config file for the different analyses to be used (where the different analyses and their parameters should be defined) and the name of the config file for the specific analysis to be used. -Similarly, the ```--version``` parameter, which purpose is explained in the {doc}`law` section of this documentation, is required to start a task. +Similarly, the `--version` parameter, which purpose is explained in the {doc}`law` section of this documentation, is required to start a task. ## Important modules and configs diff --git a/docs/user_guide/task_array_functions.md b/docs/user_guide/task_array_functions.md new file mode 100644 index 000000000..ebfa3b89a --- /dev/null +++ b/docs/user_guide/task_array_functions.md @@ -0,0 +1,230 @@ +# Task Array Functions + +Besides [configuration files and objects](./building_blocks/config_objects.md), task array functions constitute the core of columnflow. +They connect array functions - consider them small-ish, reusable code snippets that perform vectorized operations of chunks of events data encoded in awkward arrays - with columnflow's default task structure. +Throughout the documentation, you will sometimes see them abbreviated as `TAFs`. + +## Introduction + +A streamlined view of this task structure and the dependencies between them can be seen in the figure below (more complex parts of the graph, e.g. those related to cutflows or machine learning, are hidden for the purpose of clarity here; see [our wiki](https://github.com/columnflow/columnflow/wiki#default-task-graph) for an update-to-date version of the current graph). + +- Each box denotes a specific task. +- Arrows indicate dependencies between them, usually in the form of persistently stored data. +- Orange boxes at the bottom are placeholders for tasks in the later stage of the graph, usually in the context of creating results. +- **Purple boxes highlight tasks that allow for the inclusion of user-defined code within task array functions**. The command-line parameters to control these functions are added in parentheses below task names. + +The five major task array functions are (from top to bottom): calibrators, selectors, reducers, producers, and histogram producers. +They purpose and behavior is explained below, and in more detail in the [columnflow building blocks](./building_blocks/index.rst). + +```{mermaid} +graph TD + classDef PH stroke: #fe8e01, stroke-width: 3px, fill: #ffc78f + classDef TA stroke: #8833bb, stroke-width: 3px + + GetDatasetLFNs(GetDatasetLFNs) + CalibrateEvents("CalibrateEvents
(--calibrators)") + SelectEvents("SelectEvents
(--selector)") + ReduceEvents("ReduceEvents
(--reducer)") + MergeReductionStats(MergeReductionStats) + MergeReducedEvents(MergedReducedEvents) + ProduceColumns("ProduceColumns
(--producers)") + CreateHistograms("CreateHistograms
(--hist-producer)") + MergeHistograms(MergeHistograms) + MergeShiftedHistograms(MergeShiftedHistograms) + Inference(Inference models ...) + Plots(Variable plots ...) + ShiftedPlots(Shifted variable plots ...) + + class CalibrateEvents TA + class SelectEvents TA + class ReduceEvents TA + class ProduceColumns TA + class CreateHistograms TA + class Plots PH + class ShiftedPlots PH + class Inference PH + + %% top part + GetDatasetLFNs -- lfns --> SelectEvents + GetDatasetLFNs -- lfns --> CalibrateEvents + CalibrateEvents -. cols .-> SelectEvents + SelectEvents -- masks --> ReduceEvents + SelectEvents -. cols .-> ReduceEvents + CalibrateEvents -. cols .-> ReduceEvents + GetDatasetLFNs -- lfns --> ReduceEvents + + %% merging 1 right + ReduceEvents -- sizes --> MergeReductionStats + MergeReductionStats -- factors --> MergeReducedEvents + ReduceEvents -- events --> MergeReducedEvents + + %% additional columns + MergeReducedEvents -- events --> CreateHistograms + MergeReducedEvents -- events --> ProduceColumns + ProduceColumns -. cols .-> CreateHistograms + + %% merging and results + CreateHistograms -- hists --> MergeHistograms + MergeHistograms -- hists ---> Plots + MergeHistograms -- mc hists --> MergeShiftedHistograms + MergeShiftedHistograms -- mc hists --> ShiftedPlots + MergeHistograms -- data hists --> ShiftedPlots + MergeShiftedHistograms -- mc hists --> Inference + MergeHistograms -- data hists --> Inference + + subgraph "        Column merging" + MergeReductionStats + MergeReducedEvents + end + + subgraph "Histogram merging    " + MergeHistograms + MergeShiftedHistograms + end +``` + +## TAF Types + +### Calibrator + +- Meant to apply calibrations to event data that are stored as additional columns; these columns can then be used in subsequent tasks as if they were part of the original data +- *Examples*: energy calibration (jets, taus) or object corrections (MET) +- *Quantity*: zero, one or more +- *Output length*: same as input +- *Parameter*: `--calibrators [NAME,[NAME,...]]` + +### Selector + +- Meant to perform both event and object selection; they must produce event and per-object masks that can be used downstream, as well as event and selection statistics (either in `json` or `hist` format) to be used for normalization later; they can also produce additional columns for use in subsequent tasks +- *Examples*: the usual event selection +- *Quantity*: exactly one +- *Output length*: same as input +- *Parameter*: `--selector NAME` + +### Reducer + +- It receives all event data, as well as columns produced during calbration and selection plus the event and object selection masks to perform the reduction step; a default implementation exists that should be sufficient for most use cases; if additional columns should be produced, the default reducer can be extended +- *Examples*: `cf_default`, i.e., columnflow's default event and object reduction, as well as collection creation +- *Quantity*: exactly one +- *Output length*: Any length, but obviously usually shorter than the input +- *Parameter*: `--reducer NAME(cf_default)` + +### Producer + +- The go-to mechanism for creating and storing additional variables needed by the analysis, after events were selected and reduced +- *Examples*: creation of additional variables needed by the analysis +- *Quantity*: zero, one or more +- *Output length*: same as input +- *Parameter*: `--producers [NAME,[NAME,...]]` + +### Histogram producer + +- More versatile than other TAFs as it allows defining some late event data adjustments and event weight, as well as controls the creation, filling and post-processing of histograms before they are saved to disk +- *Examples*: calcalation of event weights, plus the usual histogram filling procedure +- *Quantity*: exactly one +- *Output length*: does not apply, since histograms are created +- *Parameter*: `--hist-producer NAME(cf_default)` + +## Simple Example + +This simple example shows a producer that adds a new column to the events chunk. +Here, it calculates the supercluster eta of electrons, based in the original eta and the delta to the supercluster. +It is obviously a contrived example, but it shows the basic concept of + +- declaring which columns need to be read from disk via the `uses` set, +- declaring which columns are produced by the TAF and should be saved to disk via the `produces` set, and +- that the events chunk is never modified in place but potentially copied (**without** copying the underlying data though!). + +```python +from columnflow.production import producer +from columnflow.util import maybe_import +from columnflow.columnar_util import set_ak_column + +ak = maybe_import("awkward") + +@producer( + uses={"Electron.{pt,phi,eta,deltaEtaSC}"}, + produces={"Electron.superclusterEta"}, +) +def electron_sc_eta(self, events: ak.Array, **kwargs) -> ak.Array: + sc_eta = events.Electron.eta + events.Electron.deltaEtaSC + events = set_ak_column(events, "Electron.superclusterEta", sc_eta) + return events +``` + +TAF can be nested to reuse common functionality, i.e., one TAF can call other TAFs with all information about used and produced columns passed along. +For more information, see [columnflow building blocks](./building_blocks/index.rst). + +## TAF Interface + +The full interface can be described as a collection of functions that are invoked in specific places by tasks in columnflow, and that can be implemented by the user with very high granularity. + +These functions (or *hooks*) are registered using the decorator syntax below. +However, as they are classes under the hood, you can also define them as such. + +Upon creation, the `analysis_inst`, `config_inst` and `dataset_inst` objects of a task are passed as members to each TAF instance and form **their state**. +They can be accessed as usual to retrieve information about the context in which they are called (e.g. for a specific config, MC or real data, etc.). + +Hooks are called thereafter in various places: + +- `pre_init(self)`: Called before dependency creation, can be used to control `deps_kwargs` that are passed to dependent TAFs. +- `init(self)`: Initialization of the TAF, can control dynamic registration of used and produced columns or dependencies, as well as systemtic shifts. +- `skip(self)`: Whether this TAF should be skipped altogether. +- `post_init(self, task)`: Can control dynamic registration of used and produced columns, but no additional TAF dependencies. +- `requires(self, task, reqs)`: Allows adding extra task requirements to `reqs` that will be resolved before the tasks commences. +- `setup(self, task, reqs, inputs, reader_targets)`: Allows setting up objects needed for actual function calls, receiving requirements defined in `requires` as well as their produced outputs via `inputs`; +- `call(self, events, task, **kwargs)`: Actual events chunk processing, can be called multiple times for different chunks. +- `teardown(self, task)`: Called after processing, but potentially before chunk merging, allows reducing memory footprint by eagerly freeing up resources. + +## Full example + +The following example is not implementing a fake TAF, but extends the example above to show how different hooks can be registered and used. +Note that the decorators miss the `_func` suffix, but they register and bind methods internally that **contain** this suffix. + +```python +# same as above +@producer( + uses={"Electron.{eta,deltaEtaSC}"}, + produces={"Electron.superclusterEta"}, +) +def electron_sc_eta(self, events: ak.Array, **kwargs) -> ak.Array: + ... + return events + +# custom pre-init +@electron_sc_eta.pre_init +def electron_sc_eta_pre_init(self: Producer) -> None: + ... + +# custom init +@electron_sc_eta.init +def electron_sc_eta_init(self: Producer) -> None: + # e.g. update uses/produces + self.uses.add(...) + +# custom post-init +@electron_sc_eta.post_init +def electron_sc_eta_post_init(self: Producer, task: law.Task) -> None: + # first hook to access task for some late init steps after dependency tree was built + ... + +# custom requires +@electron_sc_eta.requires +def electron_sc_eta_requires(self: Producer, task: law.Task, reqs: dict[str, Any]) -> None: + # add extra requirements to reqs + ... + +# custom setup +@electron_sc_eta.setup +def electron_sc_eta_setup(self: Producer, task: law.Task, reqs: dict[str, Any], inputs: dict[str, Any], reader_targets: dict[str, Any]) -> None: + # setup objects needed for actual function calls and store them as members on *self* + # reqs refer to requirements declared above, and inputs point to their outputs + # reader_targets can optionally be updated to declare additional columnar input files + ... + +# custom teardown +@electron_sc_eta.teardown +def electron_sc_eta_teardown(self: Producer, task: law.Task) -> None: + # free up resources, usually by removing members from *self* + ... +``` diff --git a/law.cfg b/law.cfg index 07be64a74..183bad887 100644 --- a/law.cfg +++ b/law.cfg @@ -6,7 +6,7 @@ columnflow.tasks.calibration columnflow.tasks.selection columnflow.tasks.reduction columnflow.tasks.production -columnflow.tasks.ml +# columnflow.tasks.ml columnflow.tasks.union columnflow.tasks.histograms columnflow.tasks.plotting @@ -23,9 +23,10 @@ default_dataset: st_tchannel_t production_modules: columnflow.production.{categories,processes,normalization} calibration_modules: columnflow.calibration -selection_modules: columnflow.selection.{empty} +selection_modules: columnflow.selection.empty +reduction_modules: columnflow.reduction.default categorization_modules: columnflow.categorization -weight_production_modules: columnflow.weight.{empty,all_weights} +hist_production_modules: columnflow.histogramming.default ml_modules: columnflow.ml inference_modules: columnflow.inference @@ -65,8 +66,8 @@ chunked_io_debug: False # csv list of task families that inherit from ChunkedReaderMixin and whose output arrays should be # checked (raising an exception) for non-finite values before saving them to disk -# supported tasks are: cf.CalibrateEvents, cf.SelectEvents, cf.ProduceColumns, cf.PrepareMLEvents, -# cf.MLEvaluation, cf.UniteColumns +# supported tasks are: cf.CalibrateEvents, cf.SelectEvents, cf.ReduceEvents, cf.ProduceColumns, +# cf.PrepareMLEvents, cf.MLEvaluation, cf.UniteColumns check_finite_output: None # how to treat inexistent selector steps passed to cf.CreateCutflowHistograms: throw an error, @@ -82,6 +83,14 @@ check_overlapping_inputs: None # whether to log runtimes of array functions by default log_array_function_runtime: False +# settings to control string representation of objects that are usually encoded into output paths +; the maximum length of the string representation (a hash is added when longer) +repr_max_len: -1 +; the maximum number of objects to include in the string representation +repr_max_count: 3 +; lengths of hashes that are added to representations for determinism +repr_hash_len: 10 + [outputs] diff --git a/modules/law b/modules/law index e26f045a0..a02aeb3c2 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit e26f045a0ccb1d4ef91af3153b9a26ef8063aaff +Subproject commit a02aeb3c2cf7cb460e52f67490f10c50055c6606 diff --git a/sandboxes/_setup_cmssw.sh b/sandboxes/_setup_cmssw.sh index 373c67c0e..8e874b60c 100644 --- a/sandboxes/_setup_cmssw.sh +++ b/sandboxes/_setup_cmssw.sh @@ -173,7 +173,7 @@ setup_cmssw() { fi # create the pending_flag to express that the venv state might be changing - touch "${pending_flag_file}" + [ ! -f "${pending_flag_file}" ] && touch "${pending_flag_file}" clear_pending() { rm -f "${pending_flag_file}" } diff --git a/sandboxes/_setup_venv.sh b/sandboxes/_setup_venv.sh index f746551ed..31e8e8fd4 100644 --- a/sandboxes/_setup_venv.sh +++ b/sandboxes/_setup_venv.sh @@ -203,7 +203,7 @@ setup_venv() { fi # create the pending_flag to express that the venv state might be changing - touch "${pending_flag_file}" + [ ! -f "${pending_flag_file}" ] && touch "${pending_flag_file}" clear_pending() { rm -f "${pending_flag_file}" } diff --git a/sandboxes/cf.txt b/sandboxes/cf.txt index 7c7d7ff2b..46861d8c1 100644 --- a/sandboxes/cf.txt +++ b/sandboxes/cf.txt @@ -1,8 +1,8 @@ -# version 13 +# version 14 -#tenacity!=8.4.0 -luigi~=3.5.2 +luigi~=3.6.0 scinum~=2.2.0 -six~=1.16.0 +six~=1.17.0 pyyaml~=6.0.2 -typing_extensions~=4.12.2 +typing_extensions~=4.13.0 +tabulate~=0.9.0 diff --git a/sandboxes/columnar.txt b/sandboxes/columnar.txt index 8ad104e47..8d6293ecc 100644 --- a/sandboxes/columnar.txt +++ b/sandboxes/columnar.txt @@ -1,15 +1,14 @@ -# version 16 +# version 17 # exact versions for core array packages -awkward==2.7.1 -uproot==5.5.1 -pyarrow==18.0.0 -dask-awkward==2024.9.0 +awkward==2.8.1 +uproot==5.6.0 +pyarrow==19.0.1 +dask-awkward==2025.3.0 correctionlib==2.6.4 coffea==2024.11.0 # minimum versions for general packages -tabulate~=0.9.0 zstandard~=0.23.0 -lz4~=4.3.3 +lz4~=4.4.3 xxhash~=3.5.0 diff --git a/sandboxes/dev.txt b/sandboxes/dev.txt index cc5455448..73ab68008 100644 --- a/sandboxes/dev.txt +++ b/sandboxes/dev.txt @@ -1,12 +1,12 @@ -# version 10 +# version 11 # last version to support python 3.9 ipython~=8.18.1 -pytest~=8.3.3 +pytest~=8.3.5 pytest-cov~=6.0.0 -flake8~=7.1.1 +flake8~=7.1.2 flake8-commas~=4.0.0 flake8-quotes~=3.4.0 -pipdeptree~=2.23.4 -pymarkdownlnt~=0.9.25 -uniplot~=0.15.1 +pipdeptree~=2.26.0 +pymarkdownlnt~=0.9.29 +uniplot~=0.17.1 diff --git a/sandboxes/ml_tf.txt b/sandboxes/ml_tf.txt index 5454f6860..382f89151 100644 --- a/sandboxes/ml_tf.txt +++ b/sandboxes/ml_tf.txt @@ -1,8 +1,8 @@ -# version 10 +# version 11 # use packages from columnar sandbox as baseline -r columnar.txt # add numpy and tensorflow with exact version requirement numpy==1.26.4 -tensorflow==2.11.0 +tensorflow==2.19.0 diff --git a/setup.sh b/setup.sh index 0c1dd23b0..3e4a08ccd 100644 --- a/setup.sh +++ b/setup.sh @@ -581,7 +581,7 @@ cf_setup_software_stack() { local setup_name="${1}" local setup_is_default="false" [ "${setup_name}" = "default" ] && setup_is_default="true" - local pyv="3.9" + local pyv="${CF_PYTHON_VERSION:-3.9}" local conda_arch="${CF_CONDA_ARCH:-linux-64}" local ret @@ -770,11 +770,11 @@ cf_setup_git_hooks() { # Initializes lfs and custom githooks in the local checkout for both the columnflow # (sub)repository, as well as the analysis repository in case a directory bin/githooks is found. # - # Optional environments variables: - # CF_REMOTE_ENV - # When "1", no hooks are setup. - # CF_CI_ENV - # When "1", no hooks are setup. + # Required environments variables: + # CF_REMOTE_ENV (bool) + # When true, no hooks are setup. + # CF_CI_ENV (bool) + # When true, no hooks are setup. # do nothing when not local if ${CF_REMOTE_ENV} || ${CF_CI_ENV}; then diff --git a/tests/run_tests b/tests/run_tests index 963f3398d..4b0486a82 100755 --- a/tests/run_tests +++ b/tests/run_tests @@ -35,6 +35,9 @@ action() { # test_inference echo bash "${this_dir}/run_test" test_inference "${cf_dir}/sandboxes/venv_columnar${dev}.sh" + ret="$?" + [ "${gret}" = "0" ] && gret="${ret}" + # test_hist_util echo bash "${this_dir}/run_test" test_hist_util "${cf_dir}/sandboxes/venv_columnar${dev}.sh" @@ -47,6 +50,12 @@ action() { ret="$?" [ "${gret}" = "0" ] && gret="${ret}" + # test_base_tasks + echo + bash "${this_dir}/run_test" test_base_tasks + ret="$?" + [ "${gret}" = "0" ] && gret="${ret}" + # test_plotting echo bash "${this_dir}/run_test" test_plotting "${cf_dir}/sandboxes/venv_columnar${dev}.sh" diff --git a/tests/test_base_tasks.py b/tests/test_base_tasks.py new file mode 100644 index 000000000..030bb538c --- /dev/null +++ b/tests/test_base_tasks.py @@ -0,0 +1,391 @@ +# coding: utf-8 + + +__all__ = ["AnalysisTaskTests"] + +import unittest + +import order as od + +from columnflow.tasks.framework.base import AnalysisTask, RESOLVE_DEFAULT, ShiftTask, DatasetTask +from columnflow.tasks.framework.mixins import ( + VariablesMixin, CategoriesMixin, DatasetsProcessesMixin, +) + + +class AnalysisTaskTests(unittest.TestCase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.analysis_inst = ana = od.Analysis("analysis", 1) + self.config_inst1 = cfg1 = ana.add_config(name="config", id=1) + self.config_inst2 = cfg2 = ana.add_config(name="config2", id=2) + self.base_params = { + "analysis_inst": ana, + "config_inst": cfg1, + "config_insts": tuple(ana.configs), + } + # setup calibrators, selectors, producers + ana.x.default_calibrator = "calib" + ana.x.default_selector = "sel" + ana.x.default_producer = ["A", "B", "C"] + ana.x.producer_groups = { + "A": ["a", "b"], + "C": ["c", "d"], + } + + # setup categories and variables + cfg1.add_shift("nominal", id=0) + for i in range(5): + cfg1.add_category(name=f"cat{i}", id=i) + cfg1.add_variable(name=f"var{i}", id=i) + p = cfg1.add_process(name=f"proc{i}", id=i) + cfg1.add_shift(name=f"shift{i}_up", id=2 * i + 1) + cfg1.add_shift(name=f"shift{i}_down", id=2 * i + 2) + cfg1.add_dataset(name=f"ds{i}", processes=[p], id=i, info={ + "nominal": [], f"shift{i}_up": [], f"shift{i}_down": []}) + + cat1 = cfg1.get_category("cat1") + cat1.add_category(name="cat1_1", id=11) + cat1.add_category(name="cat1_2", id=12) + + cfg1.x.default_categories = ("cg1", "cat1_2", "not_existing") + cfg1.x.category_groups = { + "cg1": ["cat0", "cat1_1"], + "cg2": ["cat2", "cat3"], + } + + cfg1.x.default_variables = ("vg1", "vg_2d", "var4", "not_existing") + cfg1.x.variable_groups = { + "vg1": ["var0", "var1"], + "vg2": ["var2", "var3", "not_existing"], + "vg_2d": ["var0-var1", "var0-var2", "var1-var2", "var3"], + } + + # setup for MultiConfig + cfg2.add_shift("nominal", id=0) + for i in range(3, 7): + cfg2.add_category(name=f"cat{i}", id=i) + cfg2.add_variable(name=f"var{i}", id=i) + cfg2.add_shift(name=f"shift{i}_up", id=2 * i + 1) + cfg2.add_shift(name=f"shift{i}_down", id=2 * i + 2) + p = cfg2.add_process(name=f"proc{i}", id=i) + cfg2.add_dataset(name=f"ds{i}", processes=[p], id=i, info={ + "nominal": [], f"shift{i}_up": [], f"shift{i}_down": []}) + + # same proess group, different processes + cfg1.x.process_groups = { + "pg1": ("proc1", "proc2"), + } + cfg2.x.process_groups = { + "pg1": ("proc3", "proc4"), + } + + # same calibrator, different producer + cfg1.x.default_calibrator = ("calib",) + cfg2.x.default_calibrator = ("calib",) + cfg1.x.default_producer = ("A", "B", "C") + cfg2.x.default_producer = ("B", "C", "D") + + def test_resolve_config_default(self): + # single config + resolved_calibrator = AnalysisTask.resolve_config_default( + param=(RESOLVE_DEFAULT,), + task_params=self.base_params, + container=self.analysis_inst, + default_str="default_calibrator", + multi_strategy="first", + ) + self.assertEqual(resolved_calibrator, ("calib",)) + + resolved_selector = AnalysisTask.resolve_config_default( + param=RESOLVE_DEFAULT, + task_params=self.base_params, + container=self.analysis_inst, + default_str="default_selector", + multi_strategy="first", + ) + self.assertEqual(resolved_selector, "sel") + + resolved_selector_steps = AnalysisTask.resolve_config_default( + param=(RESOLVE_DEFAULT,), + task_params=self.base_params, + container=self.analysis_inst, + default_str="default_selector_steps", # does note exist --> should resolve to empty tuple + multi_strategy="first", + ) + self.assertEqual(resolved_selector_steps, ()) + + resolved_producer = AnalysisTask.resolve_config_default( + param=RESOLVE_DEFAULT, + task_params=self.base_params, + container=self.analysis_inst, + default_str="default_producer", + multi_strategy="first", + ) + self.assertEqual(resolved_producer, "A") + + resolved_producers = AnalysisTask.resolve_config_default( + param=(RESOLVE_DEFAULT,), + task_params=self.base_params, + container=self.analysis_inst, + default_str="default_producer", + multi_strategy="first", + ) + self.assertEqual(resolved_producers, ("A", "B", "C")) + + resolved_producer_groups = AnalysisTask.resolve_config_default_and_groups( + param=(RESOLVE_DEFAULT,), + task_params=self.base_params, + container=self.analysis_inst, + default_str="default_producer", + groups_str="producer_groups", + multi_strategy="first", + ) + self.assertEqual(resolved_producer_groups, ("b", "a", "B", "d", "c")) # TODO: order reversed + + # multi config + for multi_strategy, expected_producer in ( + ("all", {self.config_inst1: ("A", "B", "C"), self.config_inst2: ("B", "C", "D")}), + ("first", ("A", "B", "C")), + ("union", ("A", "B", "C", "D")), + ("intersection", ("B", "C")), + ): + resolved_producer = AnalysisTask.resolve_config_default( + param=(RESOLVE_DEFAULT,), + task_params=self.base_params, + container=tuple(self.analysis_inst.configs), + default_str="default_producer", + multi_strategy=multi_strategy, + ) + # TODO: remove set() when order is fixed + self.assertEqual(set(resolved_producer), set(expected_producer)) + + # "same" strategy + resolved_calibrator = AnalysisTask.resolve_config_default( + param=(RESOLVE_DEFAULT,), + task_params=self.base_params, + container=tuple(self.analysis_inst.configs), + default_str="default_calibrator", + multi_strategy="same", + ) + self.assertEqual(resolved_calibrator, ("calib",)) + with self.assertRaises(ValueError): + AnalysisTask.resolve_config_default( + param=(RESOLVE_DEFAULT,), + task_params=self.base_params, + container=tuple(self.analysis_inst.configs), + default_str="default_producer", + multi_strategy="same", + ) + + def test_find_config_objects(self): + config = AnalysisTask.find_config_objects( + names=self.config_inst1.name, + container=self.analysis_inst, + object_cls=od.Config, + ) + self.assertEqual(config, [self.config_inst1.name]) + configs = AnalysisTask.find_config_objects( + names=(*self.analysis_inst.configs.names(), "not_existing"), + container=self.analysis_inst, + object_cls=od.Config, + ) + self.assertEqual(configs, list(self.analysis_inst.configs.names())) + + variables = AnalysisTask.find_config_objects( + names=("var1", "var2", "var3", "not_existing"), + container=self.config_inst1, + object_cls=od.Variable, + ) + self.assertEqual(variables, ["var1", "var2", "var3"]) + + categories = AnalysisTask.find_config_objects( + names=("cat1", "cat1_1", "cat1_2", "cat2", "cat3", "not_existing"), + container=self.config_inst1, + object_cls=od.Category, + deep=True, + ) + self.assertEqual(categories, ["cat1", "cat1_1", "cat1_2", "cat2", "cat3"]) + + categories = AnalysisTask.find_config_objects( + names=("cat1", "cat1_1", "cat1_2", "cat2", "cat3", "not_existing"), + container=self.config_inst1, + object_cls=od.Category, + deep=False, + ) + self.assertEqual(categories, ["cat1", "cat2", "cat3"]) + + def test_resolve_categories(self): + # TODO: order of resolved categories is still messed up + # testing with single config + CategoriesMixin.single_config = True + + for input_categories, expected_categories in ( + ((RESOLVE_DEFAULT,), ("cat0", "cat1_1", "cat1_2")), + (("cg1", "cg2", "cat4", "not_existing"), ("cat0", "cat1_1", "cat2", "cat3", "cat4")), + ): + input_params = { + **self.base_params, + "categories": input_categories, + } + resolved_params = CategoriesMixin.modify_param_values(params=input_params) + # TODO: remove set() when order is fixed + self.assertEqual(set(resolved_params["categories"]), set(expected_categories)) + + def test_resolve_variables(self): + # testing with single config + VariablesMixin.single_config = True + + for input_variables, expected_variables in ( + ((RESOLVE_DEFAULT,), ("var0", "var1", "var0-var1", "var0-var2", "var1-var2", "var3", "var4")), + (("vg1", "vg2", "var4-var1", "var4-missing"), ("var0", "var1", "var2", "var3", "var4-var1")), + ): + input_params = { + **self.base_params, + "variables": input_variables, + } + resolved_params = VariablesMixin.modify_param_values(params=input_params) + # TODO: remove set() when order is fixed + self.assertEqual(set(resolved_params["variables"]), set(expected_variables)) + + def test_resolve_datasets_processes(self): + DatasetsProcessesMixin.single_config = False + DatasetsProcessesMixin.resolution_task_cls = DatasetTask + for input_processes, expected_processes in ( + ((("proc4",),), (("proc4",), ("proc4",))), + ((("proc4",), ("proc5")), (("proc4",), ("proc5",))), + ((("proc4", "proc5", "proc6"),), (("proc4",), ("proc4", "proc5", "proc6"))), + ((("proc1", "proc2"), ("proc5", "proc6")), (("proc1", "proc2"), ("proc5", "proc6"))), + ((("pg1"),), (("proc1", "proc2"), ("proc3", "proc4"))), + # default (empty tuple) is resolved to all processes + ((), tuple(tuple(proc.name for proc in cfg.processes) for cfg in self.analysis_inst.configs)), + ): + input_params = { + **self.base_params, + "processes": input_processes, + "datasets": (), + } + resolved_params = DatasetsProcessesMixin.modify_param_values(params=input_params) + self.assertEqual(resolved_params["processes"], expected_processes) + + # since there is a 1-to-1 mapping between processes and datasets, we can infer the datasets as well + self.assertEqual( + resolved_params["datasets"], + tuple(tuple(proc_name.replace("proc", "ds") for proc_name in inner) for inner in expected_processes), + ) + + def test_resolve_dataset(self): + DatasetTask.single_config = True + + base_params = { + **self.base_params, + "shift": "nominal", + } + + resolved_params = DatasetTask.modify_param_values(params={**base_params, "dataset": "ds0"}) + self.assertEqual(resolved_params["dataset"], "ds0") + + with self.assertRaises(ValueError): + DatasetTask.modify_param_values({**base_params, "dataset": "not_existing"}) + + def test_resolve_shifts(self): + # single config + + for input_shift, input_dataset, expected_shift in ( + ("nominal", "ds0", "nominal"), + ("shift0_up", "ds0", "shift0_up"), # implemented upstream from dataset "ds0" + ("shift1_up", "ds0", "nominal"), # not implemented upstream --> fallback to "nominal" + ("shift1_up", "ds1", "shift1_up"), # implemented upstream from dataset "ds0" + ): + input_params = { + **self.base_params, + "dataset": input_dataset, + "shift": input_shift, + } + resolved_params = DatasetTask.modify_param_values(params=input_params) + self.assertEqual(resolved_params["shift"], expected_shift) + + with self.assertRaises(ValueError): + DatasetTask.modify_param_values({ + **self.base_params, + "dataset": "ds0", + "shift": "not_existing", + }) + + def test_modify_shifts_multi_config(self): + # multi config + class ShiftTaskAllUpstream(ShiftTask): + """ + Exemplary shift declaration task that collects all known shifts + from all config instances as upstream shifts. + """ + single_config = False + @classmethod + def get_known_shifts( + cls, + params: dict, + shifts, + ) -> None: + super().get_known_shifts(params, shifts) + for config_inst in params.get("config_insts", {}): + shifts.upstream.update(config_inst.shifts.names()) + + class ShiftTaskAllLocal(ShiftTask): + """ + Exemplary shift declaration task that collects all known shifts + from all config instances as local shifts. + """ + single_config = False + @classmethod + def get_known_shifts( + cls, + params: dict, + shifts, + ) -> None: + super().get_known_shifts(params, shifts) + for config_inst in params.get("config_insts", {}): + shifts.local.update(config_inst.shifts.names()) + + for input_shift, expected_shift, expected_shift_cfg1, expected_shift_cfg2 in ( + ("nominal", "nominal", "nominal", "nominal"), + ("shift1_up", "shift1_up", "shift1_up", "nominal"), # known to cfg1 + ("shift3_up", "shift3_up", "shift3_up", "shift3_up"), # known to cfg1 and cfg2 + ("shift6_up", "shift6_up", "nominal", "shift6_up"), # known to cfg2 + ): + # upstream shifts (local shifts should always resolve to "nominal") + input_params = { + **self.base_params, + "shift": input_shift, + } + + expected_shift_insts = { + self.config_inst1: self.config_inst1.get_shift(expected_shift_cfg1), + self.config_inst2: self.config_inst2.get_shift(expected_shift_cfg2), + } + + resolved_params_upstream = ShiftTaskAllUpstream.modify_param_values(params=input_params) + self.assertEqual(resolved_params_upstream["local_shift"], "nominal") + self.assertEqual( + resolved_params_upstream["local_shift_insts"], + {cfg: cfg.get_shift("nominal") for cfg in self.analysis_inst.configs}, + ) + self.assertEqual(resolved_params_upstream["shift"], expected_shift) + self.assertEqual(resolved_params_upstream["global_shift_insts"], expected_shift_insts) + + # local shifts (upstream shifts should be identical to local shifts) + input_params = { + **self.base_params, + "shift": input_shift, + } + resolved_params_local = ShiftTaskAllLocal.modify_param_values(params=input_params) + self.assertEqual(resolved_params_local["local_shift"], expected_shift) + self.assertEqual(resolved_params_local["local_shift_insts"], expected_shift_insts) + self.assertEqual(resolved_params_local["shift"], expected_shift) + self.assertEqual(resolved_params_local["global_shift_insts"], expected_shift_insts) + + # resolving non-existing shifts should raise an error + for task in (ShiftTaskAllUpstream, ShiftTaskAllLocal): + with self.assertRaises(ValueError): + task.modify_param_values({**self.base_params, "shift": "not_existing"}) diff --git a/tests/test_hist_util.py b/tests/test_hist_util.py index 61e9c060f..9049a0d03 100644 --- a/tests/test_hist_util.py +++ b/tests/test_hist_util.py @@ -6,8 +6,10 @@ import unittest from columnflow.util import maybe_import -from columnflow.hist_util import add_hist_axis, create_hist_from_variables - +from columnflow.hist_util import ( + add_hist_axis, create_hist_from_variables, create_columnflow_hist, + translate_hist_intcat_to_strcat, +) import order as od np = maybe_import("numpy") @@ -93,14 +95,48 @@ def test_create_hist_from_variables(self): self.assertEqual(histogram, histogram_manually) - # test with default axes - histogram = create_hist_from_variables( - *self.variable_examples, - int_cat_axes=("category", "process", "shift"), - ) + # test with default categorical axes + histogram = create_columnflow_hist(*self.variable_examples) - expected_default_axes = ("category", "process", "shift") - for axis in expected_default_axes: + expected_default_axes = { + "category": hist.axis.IntCategory, + "process": hist.axis.IntCategory, + "shift": hist.axis.StrCategory, + } + for axis, axis_type in expected_default_axes.items(): self.assertIn(axis, histogram.axes.name) self.assertEqual(histogram.axes[axis].name, axis) - self.assertEqual(type(histogram.axes[axis]), hist.axis.IntCategory) + self.assertEqual(type(histogram.axes[axis]), axis_type) + + def test_translate_hist_intcat_to_strcat(self): + # Create a histogram with an integer category axis + h = hist.Hist( + hist.axis.IntCategory([1, 2, 3], name="category", label="Category Axis"), + storage=hist.storage.Double(), + ) + + # Fill the histogram with some data + h.fill(category=[1, 2, 2, 3, 3, 3]) + + # Define the mapping from integer to string categories + id_map = {1: "one", 2: "two", 3: "three"} + + # Call the function to translate the axis + translated_h = translate_hist_intcat_to_strcat(h, "category", id_map) + + # Validate the new histogram + # Check that the axis has been correctly translated + self.assertEqual(len(translated_h.axes), len(h.axes), "The number of axes should remain the same.") + str_axis = translated_h.axes[0] + self.assertIsInstance(str_axis, hist.axis.StrCategory, "The axis should be of type StrCategory.") + self.assertEqual(str_axis.name, "category", "The axis name should remain unchanged.") + self.assertEqual(str_axis.label, "Category Axis", "The axis label should remain unchanged.") + + # Check the string categories + expected_categories = ["one", "two", "three"] + self.assertEqual(list(str_axis), expected_categories, "The categories should match the expected string values.") + + # Check the data + self.assertEqual(translated_h.sum(), h.sum(), "The total sum of data should remain unchanged.") + for original, translated in zip(h.values(flow=True), translated_h.values(flow=True)): + self.assertEqual(original, translated, "The data values should remain unchanged.") diff --git a/tests/test_inference.py b/tests/test_inference.py index f4c3c016a..5f1bde4b4 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,7 +1,9 @@ +# coding: utf-8 + import unittest + from columnflow.inference import ( - InferenceModel, ParameterType, ParameterTransformation, ParameterTransformations, - FlowStrategy, + InferenceModel, ParameterType, ParameterTransformation, ParameterTransformations, FlowStrategy, ) from columnflow.util import DotDict @@ -11,6 +13,7 @@ class TestInferenceModel(unittest.TestCase): def test_process_spec(self): # Test data name = "test_process" + config_name = "test_config" config_process = "test_config_process" is_signal = True config_mc_datasets = ["dataset1", "dataset2"] @@ -18,153 +21,125 @@ def test_process_spec(self): is_dynamic = True # Expected result - expected_result = DotDict([ - ("name", "test_process"), - ("is_signal", True), - ("config_process", "test_config_process"), - ("config_mc_datasets", ["dataset1", "dataset2"]), - ("scale", 2.0), - ("parameters", []), - ("is_dynamic", True), - ]) + expected_result = DotDict( + name=name, + is_signal=is_signal, + config_data={ + config_name: DotDict( + process=config_process, + mc_datasets=config_mc_datasets, + ), + }, + scale=scale, + is_dynamic=is_dynamic, + parameters=[], + ) # Call the method result = InferenceModel.process_spec( name=name, - config_process=config_process, is_signal=is_signal, - config_mc_datasets=config_mc_datasets, + config_data={ + config_name: InferenceModel.process_config_spec( + process=config_process, + mc_datasets=config_mc_datasets, + ), + }, scale=scale, is_dynamic=is_dynamic, ) - # Assert the result - self.assertEqual(result, expected_result) + self.assertDictEqual(result, expected_result) def test_category_spec(self): # Test data name = "test_category" + config_name = "test_config" config_category = "test_config_category" config_variable = "test_config_variable" config_data_datasets = ["dataset1", "dataset2"] data_from_processes = ["process1", "process2"] mc_stats = (10, 0.1) empty_bin_value = 1e-4 + postfix = None + rate_precision = 5 # Expected result - expected_result = DotDict([ - ("name", "test_category"), - ("config_category", "test_config_category"), - ("config_variable", "test_config_variable"), - ("config_data_datasets", ["dataset1", "dataset2"]), - ("data_from_processes", ["process1", "process2"]), - ("flow_strategy", FlowStrategy.warn), - ("mc_stats", (10, 0.1)), - ("empty_bin_value", 1e-4), - ("processes", []), - ]) + expected_result = DotDict( + name=name, + config_data={ + config_name: DotDict( + category=config_category, + variable=config_variable, + data_datasets=config_data_datasets, + ), + }, + data_from_processes=data_from_processes, + flow_strategy=FlowStrategy.warn, + mc_stats=mc_stats, + postfix=postfix, + empty_bin_value=empty_bin_value, + rate_precision=rate_precision, + processes=[], + ) # Call the method result = InferenceModel.category_spec( name=name, - config_category=config_category, - config_variable=config_variable, - config_data_datasets=config_data_datasets, + config_data={ + config_name: InferenceModel.category_config_spec( + category=config_category, + variable=config_variable, + data_datasets=config_data_datasets, + ), + }, data_from_processes=data_from_processes, mc_stats=mc_stats, empty_bin_value=empty_bin_value, + postfix=postfix, + rate_precision=rate_precision, ) - # Assert the result - self.assertEqual(result, expected_result) + self.assertDictEqual(result, expected_result) def test_parameter_spec(self): # Test data name = "test_parameter" type = ParameterType.rate_gauss transformations = [ParameterTransformation.centralize, ParameterTransformation.symmetrize] + config_name = "test_config" config_shift_source = "test_shift_source" effect = 1.5 + effect_precision = 4 # Expected result - expected_result = DotDict([ - ("name", "test_parameter"), - ("type", ParameterType.rate_gauss), - ("transformations", ParameterTransformations(transformations)), - ("config_shift_source", "test_shift_source"), - ("effect", 1.5), - ]) - - # Call the method - result = InferenceModel.parameter_spec( - name=name, - type=type, - transformations=transformations, - config_shift_source=config_shift_source, - effect=effect, - ) - - # Assert the result - self.assertEqual(result, expected_result) - - def test_parameter_spec_with_default_transformations(self): - # Test data - name = "test_parameter" - type = ParameterType.rate_gauss - config_shift_source = "test_shift_source" - effect = 1.5 - - # Expected result - expected_result = DotDict([ - ("name", "test_parameter"), - ("type", ParameterType.rate_gauss), - ("transformations", ParameterTransformations([ParameterTransformation.none])), - ("config_shift_source", "test_shift_source"), - ("effect", 1.5), - ]) - - # Call the method - result = InferenceModel.parameter_spec( + expected_result = DotDict( name=name, - type=type, - config_shift_source=config_shift_source, + type=ParameterType.rate_gauss, + transformations=ParameterTransformations(transformations), + config_data={ + config_name: DotDict( + shift_source=config_shift_source, + ), + }, effect=effect, + effect_precision=effect_precision, ) - # Assert the result - self.assertEqual(result, expected_result) - - def test_parameter_spec_with_string_type_and_transformations(self): - # Test data - name = "test_parameter" - type = "rate_gauss" - transformations = ["centralize", "symmetrize"] - config_shift_source = "test_shift_source" - effect = 1.5 - - # Expected result - expected_result = DotDict([ - ("name", "test_parameter"), - ("type", ParameterType.rate_gauss), - ("transformations", ParameterTransformations([ - ParameterTransformation.centralize, - ParameterTransformation.symmetrize, - ])), - ("config_shift_source", "test_shift_source"), - ("effect", 1.5), - ]) - # Call the method result = InferenceModel.parameter_spec( name=name, type=type, transformations=transformations, - config_shift_source=config_shift_source, + config_data={ + config_name: InferenceModel.parameter_config_spec( + shift_source=config_shift_source, + ), + }, effect=effect, ) - # Assert the result - self.assertEqual(result, expected_result) + self.assertDictEqual(result, expected_result) def test_parameter_group_spec(self): # Test data @@ -172,10 +147,10 @@ def test_parameter_group_spec(self): parameter_names = ["param1", "param2", "param3"] # Expected result - expected_result = DotDict([ - ("name", "test_group"), - ("parameter_names", ["param1", "param2", "param3"]), - ]) + expected_result = DotDict( + name="test_group", + parameter_names=["param1", "param2", "param3"], + ) # Call the method result = InferenceModel.parameter_group_spec( @@ -184,36 +159,35 @@ def test_parameter_group_spec(self): ) # Assert the result - self.assertEqual(result, expected_result) + self.assertDictEqual(result, expected_result) def test_parameter_group_spec_with_no_parameter_names(self): # Test data name = "test_group" # Expected result - expected_result = DotDict([ - ("name", "test_group"), - ("parameter_names", []), - ]) + expected_result = DotDict( + name="test_group", + parameter_names=[], + ) # Call the method result = InferenceModel.parameter_group_spec( name=name, ) - # Assert the result - self.assertEqual(result, expected_result) + self.assertDictEqual(result, expected_result) def test_require_shapes_for_parameter_shape(self): # No shape is required if the parameter type is a rate types = [ParameterType.rate_gauss, ParameterType.rate_uniform, ParameterType.rate_unconstrained] for t in types: with self.subTest(t=t): - param_obj = DotDict.wrap({ - "type": t, - "transformations": ParameterTransformations([ParameterTransformation.effect_from_rate]), - "name": "test_param", - }) + param_obj = DotDict( + type=t, + transformations=ParameterTransformations([ParameterTransformation.effect_from_rate]), + name="test_param", + ) result = InferenceModel.require_shapes_for_parameter(param_obj) self.assertFalse(result) @@ -223,18 +197,14 @@ def test_require_shapes_for_parameter_shape(self): self.assertTrue(result) # No shape is required if the transformation is from a rate - param_obj = DotDict.wrap({ - "type": ParameterType.shape, - "transformations": ParameterTransformations([ParameterTransformation.effect_from_rate]), - "name": "test_param", - }) + param_obj = DotDict( + type=ParameterType.shape, + transformations=ParameterTransformations([ParameterTransformation.effect_from_rate]), + name="test_param", + ) result = InferenceModel.require_shapes_for_parameter(param_obj) self.assertFalse(result) param_obj.transformations = ParameterTransformations([ParameterTransformation.effect_from_shape]) result = InferenceModel.require_shapes_for_parameter(param_obj) self.assertTrue(result) - - -if __name__ == "__main__": - unittest.main() From 578d8b7fdd9a229e120d29300303d36181cc2302 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 28 May 2025 11:29:39 +0200 Subject: [PATCH 003/123] Add tmp dir checks, add cf_setup_post_install hook. --- analysis_templates/cms_minimal/setup.sh | 24 +---- setup.sh | 126 +++++++++++++++++++----- 2 files changed, 102 insertions(+), 48 deletions(-) diff --git a/analysis_templates/cms_minimal/setup.sh b/analysis_templates/cms_minimal/setup.sh index 8f55fc685..4e226914d 100644 --- a/analysis_templates/cms_minimal/setup.sh +++ b/analysis_templates/cms_minimal/setup.sh @@ -104,7 +104,6 @@ setup___cf_short_name_lc__() { cf_setup_common_variables || return "$?" - # # minimal local software setup # @@ -124,32 +123,15 @@ setup___cf_short_name_lc__() { fi # - # git hooks + # additional common cf setup steps # - if ! ${CF_REMOTE_ENV}; then - cf_setup_git_hooks || return "$?" - fi + cf_setup_post_install || return "$?" # - # law setup + # finalize # - export LAW_HOME="${LAW_HOME:-${__cf_short_name_uc___BASE}/.law}" - export LAW_CONFIG_FILE="${LAW_CONFIG_FILE:-${__cf_short_name_uc___BASE}/law.cfg}" - - if ! ${CF_REMOTE_ENV} && which law &> /dev/null; then - # source law's bash completion scipt - source "$( law completion )" "" - - # add completion to the claw command - complete -o bashdefault -o default -F _law_complete claw - - # silently index - law index -q - fi - - # finalize export __cf_short_name_uc___SETUP="true" } diff --git a/setup.sh b/setup.sh index 3e4a08ccd..4d314b4d3 100644 --- a/setup.sh +++ b/setup.sh @@ -161,7 +161,6 @@ setup_columnflow() { return "1" fi - # # prepare local variables # @@ -180,7 +179,6 @@ setup_columnflow() { setopt globdots fi - # # global variables # (CF = columnflow) @@ -210,50 +208,28 @@ setup_columnflow() { export CF_ORIG_PYTHON3PATH="${PYTHON3PATH}" export CF_ORIG_LD_LIBRARY_PATH="${LD_LIBRARY_PATH}" - # # common variables # cf_setup_common_variables || return "$?" - # # minimal local software setup # cf_setup_software_stack "${CF_SETUP_NAME}" || return "$?" - # - # git hooks + # additional common cf setup steps # - # only in local env - if ${CF_LOCAL_ENV}; then - cf_setup_git_hooks || return "$?" - fi - + cf_setup_post_install || return "$?" # - # law setup + # finalize # - export LAW_HOME="${LAW_HOME:-${CF_BASE}/.law}" - export LAW_CONFIG_FILE="${LAW_CONFIG_FILE:-${CF_BASE}/law.cfg}" - - if ${CF_LOCAL_ENV} && which law &> /dev/null; then - # source law's bash completion scipt - source "$( law completion )" "" - - # add completion to the claw command - complete -o bashdefault -o default -F _law_complete claw - - # silently index - law index -q - fi - - # finalize export CF_SETUP="true" } @@ -766,6 +742,102 @@ EOF fi } +cf_setup_post_install() { + # Performs additional, central setup steps after variables are set and software is installed. These steps are meant + # to be common to all setups and that can be adjusted later on through updates of the columnflow module. + # + # Current steps: + # - setup git hooks + # - setup law + # - check size of the target tmp dir + # + # Required environment variables: + # CF_LOCAL_ENV + # Should be true or false, indicating if the setup is run in a local environment. + # CF_REPO_BASE + # The base directory of the analysis repository, which is used to determine the law home and config file. + + # + # git hooks + # + + # only in local env + if ${CF_LOCAL_ENV}; then + cf_setup_git_hooks || return "$?" + fi + + # + # law setup + # + + if [ ! -z "${CF_REPO_BASE}" ]; then + export LAW_HOME="${LAW_HOME:-${CF_REPO_BASE}/.law}" + export LAW_CONFIG_FILE="${LAW_CONFIG_FILE:-${CF_REPO_BASE}/law.cfg}" + + if ${CF_LOCAL_ENV} && which law &> /dev/null; then + # source law's bash completion scipt + source "$( law completion )" "" + + # add completion to the claw command + complete -o bashdefault -o default -F _law_complete claw + + # silently index + law index -q + fi + fi + + # + # check the tmp directory size + # + + if ${CF_LOCAL_ENV} && which law &> /dev/null; then + cf_check_tmp_dir + fi + + return "0" +} + +cf_check_tmp_dir() { + # Computes the size of all user-owned files in the target tmp directory and issues a warning when the size exceeds + # certain thresholds. If a variable CF_SKIP_TMP_CHECK is set to true, the check is skipped. + + # check if skipping + if [ ! -z "${CF_SKIP_TMP_CHECK}" ] && ${CF_SKIP_TMP_CHECK}; then + return "0" + fi + + # determine the tmp directory + local tmp_dir="$( law config target.tmp_dir )" + local ret="$?" + if [ "${ret}" != "0" ]; then + >&2 cf_color "red" "cf_check_tmp_dir: 'law config target.tmp_dir' failed with error code ${ret}" + return "${ret}" + elif [ -z "${tmp_dir}" ]; then + >&2 cf_color "red" "cf_check_tmp_dir: 'law config target.tmp_dir' must not be empty" + return "2" + elif [ ! -d "${tmp_dir}" ]; then + >&2 cf_color "red" "cf_check_tmp_dir: 'law config target.tmp_dir' is not a directory" + return "3" + fi + + # compute the size + local tmp_size="$( find "${tmp_dir}" -maxdepth 1 -name "*" -user "$( id -u )" -exec du -cb {} + | grep 'total$' | cut -d $'\t' -f 1 )" + + # warn above 1GB with color changing when above 2GB + local thresh1="1073741824" + local thresh2="2147483648" + if [ "${tmp_size}" -gt "${thresh1}" ]; then + local hsize="$( python -c "import law; print(law.util.human_bytes(${tmp_size}, fmt=True))" )" + local color="$( [ "${tmp_size}" -lt "${thresh2}" ] && echo "yellow" || echo "red" )" + echo + cf_color "${color}" "the combined size of your files in the directory ${tmp_dir} is $( cf_color "${color}_bright" "${hsize}" )" + cf_color "${color}" "please consider cleaning up using $( cf_color "${color}_bright" "'cf_remove_tmp [all]'" )" + echo + fi + + return "0" +} + cf_setup_git_hooks() { # Initializes lfs and custom githooks in the local checkout for both the columnflow # (sub)repository, as well as the analysis repository in case a directory bin/githooks is found. From 67f98daae88b283b4d955bf7704a4928297669ce Mon Sep 17 00:00:00 2001 From: maadcoen Date: Wed, 28 May 2025 14:52:57 +0200 Subject: [PATCH 004/123] don't consider empty axis as "missing" --- columnflow/tasks/histograms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/tasks/histograms.py b/columnflow/tasks/histograms.py index 57c3654b6..883661e05 100644 --- a/columnflow/tasks/histograms.py +++ b/columnflow/tasks/histograms.py @@ -72,7 +72,7 @@ def check_histogram_compatibility(cls, h) -> None: } axes = {ax.name: ax for ax in h.axes} for axis_name, axis_type in expected.items(): - if not (ax := axes.get(axis_name)): + if (ax := axes.get(axis_name)) is None: raise Exception(f"missing axis '{axis_name}' in histogram: {h}") if not isinstance(ax, axis_type): raise ValueError(f"axis '{axis_name}' must have type '{axis_type}', found '{type(ax)}'") From db5f46bf9d8fe6338e3d8d42fd5ff84f0ea5a625 Mon Sep 17 00:00:00 2001 From: maadcoen Date: Wed, 28 May 2025 15:04:19 +0200 Subject: [PATCH 005/123] correct json file extension --- columnflow/tasks/ml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 4a39abe3b..406746ff6 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -134,7 +134,7 @@ def output(self): self.target(f"mlevents_fold{f}of{k}_{self.branch}.parquet") for f in range(k) ]), - "stats": self.target(f"stats_{self.branch}.parquet"), + "stats": self.target(f"stats_{self.branch}.json"), } return outputs From 54191a5291f357cade878bf73247f7d00cd3ba04 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 28 May 2025 15:37:24 +0200 Subject: [PATCH 006/123] Hotfix category flattening. --- columnflow/production/categories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/production/categories.py b/columnflow/production/categories.py index 862b33273..e6dee2e92 100644 --- a/columnflow/production/categories.py +++ b/columnflow/production/categories.py @@ -72,7 +72,7 @@ def category_ids_init(self: Producer, **kwargs) -> None: continue # treat all selections as lists of categorizers - for sel in law.util.make_list(cat_inst.selection): + for sel in law.util.flatten(cat_inst.selection): if Categorizer.derived_by(sel): categorizer = sel elif Categorizer.has_cls(sel): From 5ab1fdc1ca4070ac3c220c122b2d4f2f0de2b8a5 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Fri, 30 May 2025 10:29:37 +0200 Subject: [PATCH 007/123] Improve tmp file check. --- setup.sh | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/setup.sh b/setup.sh index 4d314b4d3..f8966732e 100644 --- a/setup.sh +++ b/setup.sh @@ -816,12 +816,16 @@ cf_check_tmp_dir() { >&2 cf_color "red" "cf_check_tmp_dir: 'law config target.tmp_dir' must not be empty" return "2" elif [ ! -d "${tmp_dir}" ]; then - >&2 cf_color "red" "cf_check_tmp_dir: 'law config target.tmp_dir' is not a directory" - return "3" + # nothing to do + return "0" fi - # compute the size - local tmp_size="$( find "${tmp_dir}" -maxdepth 1 -name "*" -user "$( id -u )" -exec du -cb {} + | grep 'total$' | cut -d $'\t' -f 1 )" + # compute the size, with a notification shown if it takes too long + ( sleep 5 && cf_color yellow "computing the size of your files in ${tmp_dir} ..." ) & + local msg_pid="$!" + local tmp_size="$( find "${tmp_dir}" -maxdepth 1 -user "$( id -u )" -exec du -cb {} + | grep 'total$' | cut -d $'\t' -f 1 | sort | head -n 1 )" + kill "${msg_pid}" 2> /dev/null + wait "${msg_pid}" 2> /dev/null # warn above 1GB with color changing when above 2GB local thresh1="1073741824" From e0d13576f03c0b36bef4dab94ed6bd451d129790 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 12 Jun 2025 10:22:10 +0200 Subject: [PATCH 008/123] Update law. --- modules/law | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/law b/modules/law index a02aeb3c2..fb21a9c28 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit a02aeb3c2cf7cb460e52f67490f10c50055c6606 +Subproject commit fb21a9c28337bbf24ad83d11d96a71b6f0a20d06 From 05416b8a25255f35330cd468cb5b20e850dc17de Mon Sep 17 00:00:00 2001 From: Lara Date: Thu, 12 Jun 2025 15:41:10 +0200 Subject: [PATCH 009/123] Make jet collection used in DY weights more flexible --- columnflow/production/cms/dy.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/columnflow/production/cms/dy.py b/columnflow/production/cms/dy.py index 52718f801..391cd4664 100644 --- a/columnflow/production/cms/dy.py +++ b/columnflow/production/cms/dy.py @@ -247,8 +247,6 @@ def dy_weights_setup( # MET information # -> only Run 3 (PuppiMET) is supported "PuppiMET.{pt,phi}", - # Number of jets (as a per-event scalar) - "Jet.{pt,phi,eta,mass}", # Gen-level boson information (full boson momentum) # -> gen_dilepton_vis.pt, gen_dilepton_vis.phi, gen_dilepton_all.pt, gen_dilepton_all.phi gen_dilepton.PRODUCES, @@ -257,6 +255,8 @@ def dy_weights_setup( "RecoilCorrMET.{pt,phi}", "RecoilCorrMET.{pt,phi}_{recoilresp,recoilres}_{up,down}", }, + # custom njet column to be used to derive corrections + njet_column=None, mc_only=True, # function to determine the recoil correction file from external files get_dy_recoil_file=(lambda self, external_files: external_files.dy_recoil_sf), @@ -331,12 +331,15 @@ def recoil_corrected_met(self: Producer, events: ak.Array, **kwargs) -> ak.Array uperp = -u_x * full_unit_y + u_y * full_unit_x # Determine jet multiplicity for the event (jet selection as in original) - jet_selection = ( - ((events.Jet.pt > 30) & (np.abs(events.Jet.eta) < 2.5)) | - ((events.Jet.pt > 50) & (np.abs(events.Jet.eta) >= 2.5)) - ) - selected_jets = events.Jet[jet_selection] - njet = np.asarray(ak.num(selected_jets, axis=1), dtype=np.float32) + if self.njet_column: + njet = np.asarry(events[self.njet_column], dtype=np.float32) + else: + jet_selection = ( + ((events.Jet.pt > 30) & (np.abs(events.Jet.eta) < 2.5)) | + ((events.Jet.pt > 50) & (np.abs(events.Jet.eta) >= 2.5)) + ) + selected_jets = events.Jet[jet_selection] + njet = np.asarray(ak.num(selected_jets, axis=1), dtype=np.float32) # Apply nominal recoil correction on U components # (see here: https://cms-higgs-leprare.docs.cern.ch/htt-common/V_recoil/#example-snippet) @@ -418,6 +421,14 @@ def recoil_corrected_met(self: Producer, events: ak.Array, **kwargs) -> ak.Array return events +@recoil_corrected_met.init +def recoil_corrected_met_init(self: Producer) -> None: + if self.njet_column: + self.uses.add(f"{self.njet_column}") + else: + self.uses.add("Jet.{pt,eta,phi,mass}") + + @recoil_corrected_met.requires def recoil_corrected_met_requires(self: Producer, task: law.Task, reqs: dict) -> None: # Ensure that external files are bundled. From b9044c4364344607fb044af0c69c2f8ea0f59463 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Fri, 13 Jun 2025 13:07:02 +0200 Subject: [PATCH 010/123] Hotfix missing xsecs for stitched weight producer. --- columnflow/production/normalization.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 56e4c0c82..25b492918 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -352,6 +352,23 @@ def normalization_weights_setup( for proc_id, br in branching_ratios.items(): sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(proc_id)] process_weight_table[0, proc_id] = lumi * inclusive_xsec * br / sum_weights + + # fill in cross sections of missing leaf processes + missing_proc_ids = set(proc.id for proc in inclusive_proc.get_leaf_processes()) - set(branching_ratios.keys()) + for proc_id in missing_proc_ids: + process_inst = inclusive_proc.get_process(proc_id) + if ( + self.config_inst.campaign.ecm in process_inst.xsecs and + str(proc_id) in merged_selection_stats["sum_mc_weight_per_process"] + ): + xsec = process_inst.get_xsec(self.config_inst.campaign.ecm).nominal + sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(proc_id)] + process_weight_table[0, process_inst.id] = lumi * xsec / sum_weights + logger.warning( + f"added cross section for missing leaf process {process_inst.name} ({proc_id}) from xsec entry", + ) + else: + logger.warning(f"no cross section found for missing leaf process {process_inst.name} ({proc_id})") else: # fill the process weight table with per-process cross sections for process_inst in process_insts: From 796b5db902e4d478ff078bb2f526830430d3961a Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Fri, 13 Jun 2025 15:55:31 +0200 Subject: [PATCH 011/123] Add dfs lookup pattern negation. --- columnflow/production/normalization.py | 2 -- columnflow/tasks/framework/base.py | 7 ++++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 25b492918..273b156fc 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -367,8 +367,6 @@ def normalization_weights_setup( logger.warning( f"added cross section for missing leaf process {process_inst.name} ({proc_id}) from xsec entry", ) - else: - logger.warning(f"no cross section found for missing leaf process {process_inst.name} ({proc_id})") else: # fill the process weight table with per-process cross sections for process_inst in process_insts: diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index b6128196f..de267eebf 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -383,6 +383,11 @@ def _dfs_key_lookup( while lookup: pattern, obj, keys_func = lookup.popleft() + # when pattern starts with a "!", it is a negation + negate = pattern.startswith("!") + if negate: + pattern = pattern[1:] + # create the copy of comparison keys on demand # (the original sequence is living once on the previous stack until now) _keys = keys_func() @@ -391,7 +396,7 @@ def _dfs_key_lookup( regex = is_regex(pattern) while _keys: key = _keys.popleft() - if law.util.multi_match(key, pattern, regex=regex): + if law.util.multi_match(key, pattern, regex=regex) != negate: # when obj is not a dict, we found the value if not isinstance(obj, dict): return obj From 9ad7489fb6ddfde0b2689c6fef28f30aa9b7c459 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 16 Jun 2025 07:47:09 +0200 Subject: [PATCH 012/123] Allow skipping parts of post setup. --- setup.sh | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/setup.sh b/setup.sh index f8966732e..4bc024729 100644 --- a/setup.sh +++ b/setup.sh @@ -756,13 +756,19 @@ cf_setup_post_install() { # Should be true or false, indicating if the setup is run in a local environment. # CF_REPO_BASE # The base directory of the analysis repository, which is used to determine the law home and config file. + # + # Optional environment variables: + # CF_SKIP_SETUP_GIT_HOOKS + # When set to true, the setup of git hooks is skipped. + # CF_SKIP_CHECK_TMP_DIR + # When set to true, the check of the size of the target tmp directory is skipped. # # git hooks # # only in local env - if ${CF_LOCAL_ENV}; then + if ! ${CF_SKIP_SETUP_GIT_HOOKS} && ${CF_LOCAL_ENV}; then cf_setup_git_hooks || return "$?" fi @@ -790,7 +796,7 @@ cf_setup_post_install() { # check the tmp directory size # - if ${CF_LOCAL_ENV} && which law &> /dev/null; then + if ! ${CF_SKIP_CHECK_TMP_DIR} && ${CF_LOCAL_ENV} && which law &> /dev/null; then cf_check_tmp_dir fi @@ -1106,6 +1112,8 @@ for flag_name in \ CF_REINSTALL_SOFTWARE \ CF_REINSTALL_HOOKS \ CF_SKIP_BANNER \ + CF_SKIP_SETUP_GIT_HOOKS \ + CF_SKIP_CHECK_TMP_DIR \ CF_ON_HTCONDOR \ CF_ON_SLURM \ CF_ON_GRID \ From 59dd5743494e37d75b3c4901c8d28456e7af8a51 Mon Sep 17 00:00:00 2001 From: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Date: Mon, 16 Jun 2025 08:33:11 +0200 Subject: [PATCH 013/123] allow evaluating multiple working points with single electron_weights Producer (#694) * allow evaluating multiple working points with single electron_weights Producer * Move import. --------- Co-authored-by: Marcel R. Co-authored-by: Marcel Rieger --- columnflow/production/cms/electron.py | 45 ++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/columnflow/production/cms/electron.py b/columnflow/production/cms/electron.py index e88d115e8..38e628a62 100644 --- a/columnflow/production/cms/electron.py +++ b/columnflow/production/cms/electron.py @@ -12,8 +12,8 @@ from columnflow.production import Producer, producer from columnflow.util import maybe_import, load_correction_set, DotDict -from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array -from columnflow.types import Any +from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array, EMPTY_FLOAT +from columnflow.types import Any, Callable np = maybe_import("numpy") ak = maybe_import("awkward") @@ -23,7 +23,7 @@ class ElectronSFConfig: correction: str campaign: str - working_point: str = "" + working_point: str | dict[str, Callable] = "" hlt_path: str = "" def __post_init__(self) -> None: @@ -88,6 +88,21 @@ def electron_weights( working_point="wp80iso", # for trigger weights use hlt_path instead ) + The *working_point* can also be a dictionary mapping working point names to functions + that return a boolean mask for the electrons. This is useful to compute scale factors for + multiple working points at once, e.g. for the electron reconstruction scale factors: + + .. code-block:: python + cfg.x.electron_sf_names = ElectronSFConfig( + correction="Electron-ID-SF", + campaign="2022Re-recoE+PromptFG", + working_point={ + "RecoBelow20": lambda variable_map: variable_map["pt"] < 20.0, + "Reco20to75": lambda variable_map: (variable_map["pt"] >= 20.0) & (variable_map["pt"] < 75.0), + "RecoAbove75": lambda variable_map: variable_map["pt"] >= 75.0, + }, + ) + *get_electron_config* can be adapted in a subclass in case it is stored differently in the config. @@ -121,8 +136,28 @@ def electron_weights( **variable_map, "ValType": syst, } - inputs = [variable_map_syst[inp.name] for inp in self.electron_sf_corrector.inputs] - sf_flat = self.electron_sf_corrector(*inputs) + if isinstance(variable_map["WorkingPoint"], str): + inputs = [variable_map_syst[inp.name] for inp in self.electron_sf_corrector.inputs] + sf_flat = self.electron_sf_corrector(*inputs) + elif isinstance(variable_map["WorkingPoint"], dict): + sf_flat = np.ones_like(pt, dtype=np.float32) * EMPTY_FLOAT + for working_point, mask_fn in variable_map_syst["WorkingPoint"].items(): + mask = mask_fn(variable_map) + variable_map_syst_wp = { + **variable_map_syst, + "WorkingPoint": working_point, + } + for key, value in variable_map_syst_wp.items(): + # apply mask to array-like values + if isinstance(value, np.ndarray) or isinstance(value, ak.Array): + variable_map_syst_wp[key] = value[mask] + # call the corrector with the masked inputs + inputs = [variable_map_syst_wp[inp.name] for inp in self.electron_sf_corrector.inputs] + sf_flat[mask] = self.electron_sf_corrector(*inputs) + if np.any(sf_flat == EMPTY_FLOAT): + raise ValueError("some electrons did not have a valid scale factor, check your inputs") + else: + raise ValueError(f"unsupported working point type {type(variable_map['WorkingPoint'])}") # add the correct layout to it sf = layout_ak_array(sf_flat, events.Electron.pt[electron_mask]) From d9df51735ce1a6232d7ecfb3616b33bcf0aea99d Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Mon, 16 Jun 2025 09:57:03 +0200 Subject: [PATCH 014/123] Fix task key lookup. (#697) --- analysis_templates/cms_minimal/law.cfg | 20 +++++----- columnflow/tasks/framework/base.py | 52 ++++++++++++++++++-------- columnflow/tasks/framework/mixins.py | 21 +++++++---- docs/user_guide/best_practices.md | 34 +++++++++-------- docs/user_guide/examples.md | 7 ++-- 5 files changed, 82 insertions(+), 52 deletions(-) diff --git a/analysis_templates/cms_minimal/law.cfg b/analysis_templates/cms_minimal/law.cfg index 8ce50143d..304c8eb65 100644 --- a/analysis_templates/cms_minimal/law.cfg +++ b/analysis_templates/cms_minimal/law.cfg @@ -98,8 +98,8 @@ lfn_sources: wlcg_fs_infn_redirector, wlcg_fs_global_redirector # output locations per task family # the key can consist of multple underscore-separated parts, that can each be patterns or regexes # these parts are used for the lookup from within tasks and can contain (e.g.) the analysis name, -# the config name, the task family, the dataset name, or the shift name -# (see AnalysisTask.get_config_lookup_keys() - and subclasses - for the exact order) +# the config name, the task family, the dataset name, or the shift name, for more info, see +# https://columnflow.readthedocs.io/en/latest/user_guide/best_practices.html#selecting-output-locations # values can have the following format: # for local targets : "local[, LOCAL_FS_NAME or STORE_PATH][, store_parts_modifier]" # for remote targets : "wlcg[, WLCG_FS_NAME][, store_parts_modifier]" @@ -108,8 +108,8 @@ lfn_sources: wlcg_fs_infn_redirector, wlcg_fs_global_redirector # the "store_parts_modifiers" can be the name of a function in the "store_parts_modifiers" aux dict # of the analysis instance, which is called with an output's store parts of an output to modify them # example: -; run3_2023__cf.CalibrateEvents__nomin*: local -; cf.CalibrateEvents: wlcg +; cfg_run3_2023__task_cf.CalibrateEvents__shift_nomin*: local +; task_cf.CalibrateEvents: wlcg [versions] @@ -117,13 +117,13 @@ lfn_sources: wlcg_fs_infn_redirector, wlcg_fs_global_redirector # default versions of specific tasks to pin # the key can consist of multple underscore-separated parts, that can each be patterns or regexes # these parts are used for the lookup from within tasks and can contain (e.g.) the analysis name, -# the config name, the task family, the dataset name, or the shift name -# (see AnalysisTask.get_config_lookup_keys() - and subclasses - for the exact order) +# the config name, the task family, the dataset name, or the shift name, for more info, see +# https://columnflow.readthedocs.io/en/latest/user_guide/best_practices.html#pinned-versions-in-the-analysis-config-or-law-cfg-file # note: # this lookup is skipped if the lookup based on the config instance's auxiliary data succeeded # example: -; run3_2023__cf.CalibrateEvents__nomin*: prod1 -; cf.CalibrateEvents: prod2 +; cfg_run3_2023__task_cf.CalibrateEvents__shift_nomin*: prod1 +; task_cf.CalibrateEvents: prod2 [resources] @@ -135,8 +135,8 @@ lfn_sources: wlcg_fs_infn_redirector, wlcg_fs_global_redirector # by the respective parameter instance at runtime # same as for [versions], the order of options is important as it defines the resolution order # example: -; run3_2023__cf.CalibrateEvents__nomin*: htcondor_memory=5GB -; run3_2023__cf.CalibrateEvents: htcondor_memory=2GB +; cfg_run3_2023__task_cf.CalibrateEvents__shift_nomin*: htcondor_memory=5GB +; cfg_run3_2023__task_cf.CalibrateEvents: htcondor_memory=2GB [job] diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index de267eebf..40f85b104 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -23,7 +23,7 @@ import order as od from columnflow.columnar_util import mandatory_coffea_columns, Route, ColumnCollection -from columnflow.util import is_regex, prettify, DotDict +from columnflow.util import get_docs_url, is_regex, prettify, DotDict from columnflow.types import Sequence, Callable, Any, T @@ -354,10 +354,16 @@ def get_config_lookup_keys( else getattr(inst_or_params, "analysis", None) ) if analysis not in {law.NO_STR, None, ""}: - keys["analysis"] = analysis + prefix = "ana" + keys[prefix] = f"{prefix}_{analysis}" # add the task family - keys["task_family"] = cls.task_family + prefix = "task" + keys[prefix] = f"{prefix}_{cls.task_family}" + + # for backwards compatibility, add the task family again without the prefix + # (TODO: this should be removed in the future) + keys[f"{prefix}_compat"] = cls.task_family return keys @@ -375,7 +381,7 @@ def _dfs_key_lookup( return empty_value # the keys to use for the lookup are the flattened values of the keys dict - flat_keys = collections.deque(law.util.flatten(keys.values() if isinstance(keys, dict) else keys)) + flat_keys = law.util.flatten(keys.values() if isinstance(keys, dict) else keys) # start tree traversal using a queue lookup consisting of names and values of tree nodes, # as well as the remaining keys (as a deferred function) to compare for that particular path @@ -383,20 +389,33 @@ def _dfs_key_lookup( while lookup: pattern, obj, keys_func = lookup.popleft() - # when pattern starts with a "!", it is a negation - negate = pattern.startswith("!") - if negate: - pattern = pattern[1:] - # create the copy of comparison keys on demand # (the original sequence is living once on the previous stack until now) _keys = keys_func() # check if the pattern matches any key regex = is_regex(pattern) - while _keys: - key = _keys.popleft() - if law.util.multi_match(key, pattern, regex=regex) != negate: + for i, key in enumerate(_keys): + if law.util.multi_match(key, pattern, regex=regex): + # for a limited time, show a deprecation warning when the old task family key was matched + # (old = no "task_" prefix) + # TODO: remove once deprecated + if "task_compat" in keys and key == keys["task_compat"]: + docs_url = get_docs_url( + "user_guide", + "best_practices.html", + anchor="selecting-output-locations", + ) + logger.warning_once( + "dfs_lookup_old_task_key", + f"during the lookup of a pinned location, version or resource value of a '{cls.__name__}' " + f"task, an entry matched based on the task family '{key}' that misses the new 'task_' " + "prefix; please update the pinned entries in your law.cfg file by adding the 'task_' " + f"prefix to entries that contain the task family, e.g. 'task_{key}: VALUE'; support for " + f"missing prefixes will be removed in a future version; see {docs_url} for more info", + ) + # remove the matched key from remaining lookup keys + _keys.pop(i) # when obj is not a dict, we found the value if not isinstance(obj, dict): return obj @@ -1454,7 +1473,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "config", None) ) if config not in {law.NO_STR, None, ""}: - keys.insert_before("task_family", "config", config) + prefix = "cfg" + keys.insert_before("task", prefix, f"{prefix}_{config}") return keys @@ -1634,7 +1654,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "shift", None) ) if shift not in (law.NO_STR, None, ""): - keys["shift"] = shift + prefix = "shift" + keys[prefix] = f"{prefix}_{shift}" return keys @@ -1725,7 +1746,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "dataset", None) ) if dataset not in {law.NO_STR, None, ""}: - keys.insert_before("shift", "dataset", dataset) + prefix = "dataset" + keys.insert_before("shift", prefix, f"{prefix}_{dataset}") return keys diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index 37535b17a..d871bd7be 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -120,7 +120,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "calibrator", None) ) if calibrator not in (law.NO_STR, None, ""): - keys["calibrator"] = f"calib_{calibrator}" + prefix = "calib" + keys[prefix] = f"{prefix}_{calibrator}" return keys @@ -304,7 +305,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "calibrators", None) ) if calibrators not in {law.NO_STR, None, "", ()}: - keys["calibrators"] = [f"calib_{calibrator}" for calibrator in calibrators] + prefix = "calib" + keys[prefix] = [f"{prefix}_{calibrator}" for calibrator in calibrators] return keys @@ -510,7 +512,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "selector", None) ) if selector not in (law.NO_STR, None, ""): - keys["selector"] = f"sel_{selector}" + prefix = "sel" + keys[prefix] = f"{prefix}_{selector}" return keys @@ -702,7 +705,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "reducer", None) ) if reducer not in (law.NO_STR, None, ""): - keys["reducer"] = f"red_{reducer}" + prefix = "red" + keys[prefix] = f"{prefix}_{reducer}" return keys @@ -877,7 +881,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "producer", None) ) if producer not in (law.NO_STR, None, ""): - keys["producer"] = f"prod_{producer}" + prefix = "prod" + keys[prefix] = f"{prefix}_{producer}" return keys @@ -1061,7 +1066,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "producers", None) ) if producers not in {law.NO_STR, None, "", ()}: - keys["producers"] = [f"prod_{producer}" for producer in producers] + prefix = "prod" + keys[prefix] = [f"{prefix}_{producer}" for producer in producers] return keys @@ -1679,7 +1685,8 @@ def get_config_lookup_keys( else getattr(inst_or_params, "hist_producer", None) ) if producer not in (law.NO_STR, None, ""): - keys["hist_producer"] = f"hist_{producer}" + prefix = "hist" + keys[prefix] = f"{prefix}_{producer}" return keys diff --git a/docs/user_guide/best_practices.md b/docs/user_guide/best_practices.md index 4390319ab..0a72fbc0c 100644 --- a/docs/user_guide/best_practices.md +++ b/docs/user_guide/best_practices.md @@ -22,38 +22,40 @@ For convenience, if no file system with that name was defined, `LOCAL_FS_NAME` i - `wlcg, WLCG_FS_NAME` refers to a specific remote storage system named `WLCG_FS_NAME` that should be defined in the `law.cfg` file. `TASK_IDENTIFIER` identifies the task the location should apply to. -It can be a simple task family such as `cf.CalibrateEvents`, but for larger analyses a more fine grained selection is required. +It can be a simple task family such as `task_cf.CalibrateEvents` (see the format below), but for larger analyses a more fine grained selection is required. For this purpose, `TASK_IDENTIFIER` can be a `__`-separated sequence of so-called lookup keys, e.g. ```ini [outputs] -run3_23__cf.CalibrateEvents__nominal: wlcg, wlcg_fs_run3_23 +cfg_run3_23__task_cf.CalibrateEvents__shift_nominal: wlcg, wlcg_fs_run3_23 ``` Here, three keys are defined, making use of the config name, the task family, and the name of a systematic shift. The exact selection of possible keys and their resolution order is defined by the task itself in {py:meth}:`~columnflow.tasks.framework.base.AnalysisTask.get_config_lookup_keys` (and subclasses). Most tasks, however, define their lookup keys as: -1. analysis name -2. config name -3. task family -4. dataset name -5. shift name +1. analysis name, prefixed by `ana_` +2. config name, prefixed by `cfg_` +3. task family, prefixed by `task_` +4. dataset name, prefixed by `dataset_` +5. shift name, prefixed by `shift_` 6. calibrator name, prefixed by `calib_` 7. selector name, prefixed by `sel_` -8. producer name, prefixed by `prod_` +8. reducer name, prefixed by `red_` +9. producer name, prefixed by `prod_` +10. hist producer name, prefixed by `hist_` When defining `TASK_IDENTIFIER`'s, not all keys need to be specified, and patterns or regular expressions (`^EXPR$`) can be used. -The definition order is **important** as the first matching definition is used. +The definition order in the config file is **important** as the first matching definition is used. This way, output locations are highly customizable. ```ini [outputs] # store all run3 outputs on a specific fs, and all other outputs locally -run3_*__cf.CalibrateEvents: wlcg, wlcg_fs_run3 -cf.CalibrateEvents: local +cfg_run3_*__task_cf.CalibrateEvents: wlcg, wlcg_fs_run3 +task_cf.CalibrateEvents: local ``` ## Controlling versions of upstream tasks @@ -90,18 +92,18 @@ Consider the following two examples for defining versions, one via auxiliary con ```python cfg.x.versions = { - "run3_*": { - "cf.CalibrateEvents": "v2", + "cfg_run3_*": { + "task_cf.CalibrateEvents": "v2", }, - "cf.CalibrateEvents": "v1", + "task_cf.CalibrateEvents": "v1", } ``` ```ini [versions] -run3_*__cf.CalibrateEvents: v2 -cf.CalibrateEvents: v1 +cfg_run3_*__task_cf.CalibrateEvents: v2 +task_cf.CalibrateEvents: v1 ``` They are **equivalent** since the `__`-separated `TASK_IDENTIFIER`'s in the `law.cfg` are internallly converted to the same nested dictionary structure. diff --git a/docs/user_guide/examples.md b/docs/user_guide/examples.md index 7d9ea5022..09849eeab 100644 --- a/docs/user_guide/examples.md +++ b/docs/user_guide/examples.md @@ -267,11 +267,10 @@ lfn_sources: local_dcache # output locations per task family # for local targets : "local[, STORE_PATH]" # for remote targets: "wlcg[, WLCG_FS_NAME]" -cf.Task1: local -cf.Task2: local, /shared/path/to/store/output -cf.Task3: /shared/path/to/store/output +task_cf.Task1: local +task_cf.Task2: local, /shared/path/to/store/output +task_cf.Task3: /shared/path/to/store/output ... - ``` It is important to redirect the setup to the custom config file by setting the ```LAW_CONFIG_FILE``` environment variable in the `setup.sh` file to the path of the custom config file as follows: From 79c1bade5a962edb1bff4f7274d0f208f6b91b74 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 18 Jun 2025 10:51:25 +0200 Subject: [PATCH 015/123] Fix scope issue in seed producer. --- columnflow/production/cms/seeds.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/columnflow/production/cms/seeds.py b/columnflow/production/cms/seeds.py index 09c84d8a9..f8834b8db 100644 --- a/columnflow/production/cms/seeds.py +++ b/columnflow/production/cms/seeds.py @@ -30,10 +30,6 @@ def create_seed(val: int, n_hex: int = 16) -> int: return int(hashlib.sha256(bytes(str(val), "utf-8")).hexdigest()[:-(n_hex + 1):-1], base=16) -# store a vectorized version (only interface, not actually simd'ing) -create_seed_vec = np.vectorize(create_seed, otypes=[np.uint64]) - - @producer( uses={ # global columns for event seed @@ -74,7 +70,7 @@ def deterministic_event_seeds(self, events: ak.Array, **kwargs) -> ak.Array: before invoking this producer. """ # started from an already hashed seed based on event, run and lumi info multiplied with primes - seed = create_seed_vec( + seed = self.create_seed_vec( np.asarray( self.primes[7] * ak.values_astype(events.event, np.uint64) + self.primes[5] * ak.values_astype(events.run, np.uint64) + @@ -125,7 +121,7 @@ def deterministic_event_seeds(self, events: ak.Array, **kwargs) -> ak.Array: seed = seed + primes * ak.values_astype(hashed, np.uint64) # create and store them - seed = ak.Array(create_seed_vec(np.asarray(seed))) + seed = ak.Array(self.create_seed_vec(np.asarray(seed))) events = set_ak_column(events, "deterministic_seed", seed, value_type=np.uint64) # uniqueness test across the chunk for debugging @@ -178,6 +174,9 @@ def apply_route(ak_array: ak.Array, route: Route) -> ak.Array | None: self.apply_route = apply_route + # store a vectorized version of the create_seed function (only interface, not actually simd'ing) + self.create_seed_vec = np.vectorize(create_seed, otypes=[np.uint64]) + class deterministic_object_seeds(Producer): @@ -217,7 +216,7 @@ def call_func(self, events: ak.Array, **kwargs) -> ak.Array: ) ) np_object_seed = np.asarray(ak.flatten(object_seed)) - np_object_seed[:] = create_seed_vec(np_object_seed) + np_object_seed[:] = self.create_seed_vec(np_object_seed) # store them events = set_ak_column(events, f"{self.object_field}.deterministic_seed", object_seed, value_type=np.uint64) @@ -253,6 +252,9 @@ def setup_func( # store primes in array self.primes = np.array(primes, dtype=np.uint64) + # store a vectorized version of the create_seed function (only interface, not actually simd'ing) + self.create_seed_vec = np.vectorize(create_seed, otypes=[np.uint64]) + deterministic_jet_seeds = deterministic_object_seeds.derive( "deterministic_jet_seeds", From 00d595d45f500572f45afa980dfd23f7f45863ff Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Wed, 18 Jun 2025 13:01:05 +0200 Subject: [PATCH 016/123] typo. --- columnflow/production/cms/dy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/production/cms/dy.py b/columnflow/production/cms/dy.py index 391cd4664..590605d74 100644 --- a/columnflow/production/cms/dy.py +++ b/columnflow/production/cms/dy.py @@ -332,7 +332,7 @@ def recoil_corrected_met(self: Producer, events: ak.Array, **kwargs) -> ak.Array # Determine jet multiplicity for the event (jet selection as in original) if self.njet_column: - njet = np.asarry(events[self.njet_column], dtype=np.float32) + njet = np.asarray(events[self.njet_column], dtype=np.float32) else: jet_selection = ( ((events.Jet.pt > 30) & (np.abs(events.Jet.eta) < 2.5)) | From 97770d458c334871ddc1e6b25fbd1ad2a51ff7f5 Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Thu, 19 Jun 2025 09:32:32 +0200 Subject: [PATCH 017/123] add run to inputs and remove double underscore --- columnflow/calibration/cms/jets.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index 24462732f..806b19c10 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -228,6 +228,7 @@ def get_jec_config_default(self: Calibrator) -> DotDict: @calibrator( uses={ + "run", optional("fixedGridRhoFastjetAll"), optional("Rho.fixedGridRhoFastjetAll"), attach_coffea_behavior, @@ -328,7 +329,7 @@ def jec( events = set_ak_column_f32(events, f"{jet_name}.pt_raw", events[jet_name].pt * (1 - events[jet_name].rawFactor)) events = set_ak_column_f32(events, f"{jet_name}.mass_raw", events[jet_name].mass * (1 - events[jet_name].rawFactor)) - def correct_jets(*, pt, eta, phi, area, rho, evaluator_key="jec"): + def correct_jets(*, pt, eta, phi, area, rho, run, evaluator_key="jec"): # variable naming convention variable_map = { "JetA": area, @@ -336,6 +337,7 @@ def correct_jets(*, pt, eta, phi, area, rho, evaluator_key="jec"): "JetPt": pt, "JetPhi": phi, "Rho": ak.values_astype(rho, np.float32), + "run": run, } # apply all correctors sequentially, updating the pt each time @@ -370,6 +372,7 @@ def correct_jets(*, pt, eta, phi, area, rho, evaluator_key="jec"): phi=events[jet_name].phi, area=events[jet_name].area, rho=rho, + run=events.run, evaluator_key="jec_subset_type1_met", ) @@ -392,6 +395,7 @@ def correct_jets(*, pt, eta, phi, area, rho, evaluator_key="jec"): phi=events[jet_name].phi, area=events[jet_name].area, rho=rho, + run=events.run, evaluator_key="jec", ) @@ -600,7 +604,7 @@ def make_jme_keys(names, jec=jec_cfg, is_data=self.dataset_inst.is_data): jec_era = "Run" + self.dataset_inst.get_aux("era") return [ - f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{name}_{jec.jet_type}" + f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{name}_{jec.jet_type}".replace("__", "_") if is_data else f"{jec.campaign}_{jec.version}_MC_{name}_{jec.jet_type}" for name in names From 35bf540668e52081c67aff3580d7dc78f462cc84 Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Thu, 19 Jun 2025 11:04:21 +0200 Subject: [PATCH 018/123] improve readability of make_jme_keys --- columnflow/calibration/cms/jets.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index 806b19c10..f0e8dbea1 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -603,12 +603,15 @@ def make_jme_keys(names, jec=jec_cfg, is_data=self.dataset_inst.is_data): if jec_era is None: jec_era = "Run" + self.dataset_inst.get_aux("era") - return [ - f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{name}_{jec.jet_type}".replace("__", "_") - if is_data else - f"{jec.campaign}_{jec.version}_MC_{name}_{jec.jet_type}" - for name in names - ] + # if JEC era is specified as empty string, infer that the Run part is not included in the key + if jec_era == "": + jme_key = f"{jec.campaign}_{jec.version}_DATA_{{name}}_{jec.jet_type}" + else: + jme_key = f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{{name}}_{jec.jet_type}" + else: + jme_key = f"{jec.campaign}_{jec.version}_MC_{{name}}_{jec.jet_type}" + + return [jme_key.format(name=name) for name in names] jec_keys = make_jme_keys(jec_cfg.levels) jec_keys_subset_type1_met = make_jme_keys(jec_cfg.levels_for_type1_met) From 96d391967687b764e3b18c28ec0f7a4ec2f2e4fe Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Thu, 19 Jun 2025 14:57:58 +0200 Subject: [PATCH 019/123] implement data_per_era tag to jec config --- columnflow/calibration/cms/jets.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index f0e8dbea1..e42d0094a 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -581,6 +581,8 @@ def jec_setup( "CorrelationGroupFlavor", "CorrelationGroupUncorrelated", ], + # whether the JECs for data should be era-specific + "data_per_era": True, }, }) @@ -597,18 +599,27 @@ def jec_setup( jec_cfg = self.get_jec_config() def make_jme_keys(names, jec=jec_cfg, is_data=self.dataset_inst.is_data): - if is_data: + if is_data and jec.get("data_per_era", True): + if "data_per_era" not in jec: + logger.warning_once( + f"{id(self)}_depr_jec_config_data_per_era", + "config aux 'jec' does not contain key 'data_per_era'. " + "This may be due to an outdated config. Continuing under the assumption that " + "JEC keys for data are era-specific. " + "This assumption will be removed in future versions of " + "columnflow, so please adapt the config according to the " + "documentation to remove this warning and ensure future " + "compatibility of the code.", + ) jec_era = self.dataset_inst.get_aux("jec_era", None) # if no special JEC era is specified, infer based on 'era' if jec_era is None: - jec_era = "Run" + self.dataset_inst.get_aux("era") + jec_era = "Run" + self.dataset_inst.get_aux("era", None) - # if JEC era is specified as empty string, infer that the Run part is not included in the key - if jec_era == "": - jme_key = f"{jec.campaign}_{jec.version}_DATA_{{name}}_{jec.jet_type}" - else: - jme_key = f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{{name}}_{jec.jet_type}" - else: + jme_key = f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{{name}}_{jec.jet_type}" + elif is_data: + jme_key = f"{jec.campaign}_{jec.version}_DATA_{{name}}_{jec.jet_type}" + else: # MC jme_key = f"{jec.campaign}_{jec.version}_MC_{{name}}_{jec.jet_type}" return [jme_key.format(name=name) for name in names] From c2e35faccdc739d607230e179054eaad84df7c36 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Fri, 20 Jun 2025 13:45:50 +0200 Subject: [PATCH 020/123] docs: add LuSchaller as a contributor for code (#701) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] --------- Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 12 +++++++++++- README.md | 3 ++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index d8c51a20d..09e5078cf 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -155,7 +155,8 @@ "contributions": [ "code" ] - }, { + }, + { "login": "philippgadow", "name": "philippgadow", "avatar_url": "https://avatars.githubusercontent.com/u/6804366?v=4", @@ -163,6 +164,15 @@ "contributions": [ "code" ] + }, + { + "login": "LuSchaller", + "name": "Lukas Schaller", + "avatar_url": "https://avatars.githubusercontent.com/u/30951523?v=4", + "profile": "https://github.com/LuSchaller", + "contributions": [ + "code" + ] } ], "commitType": "docs" diff --git a/README.md b/README.md index a7921d8e8..043f72609 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,7 @@ For a better overview of the tasks that are triggered by the commands below, che
Marcel Rieger
Marcel Rieger

💻 👀 📖 ⚠️
Marcel Rieger
Marcel Rieger

💻 👀
Mathis Frahm
Mathis Frahm

💻 👀
Daniel Savoiu
Daniel Savoiu

💻 👀
pkausw
pkausw

💻 👀
- + @@ -156,6 +156,7 @@ For a better overview of the tasks that are triggered by the commands below, che +
Marcel Rieger
Marcel Rieger

💻 👀
Marcel Rieger
Marcel Rieger

💻 👀 📖 ⚠️
Mathis Frahm
Mathis Frahm

💻 👀
Daniel Savoiu
Daniel Savoiu

💻 👀
pkausw
pkausw

💻 👀
Ana Andrade
Ana Andrade

💻
philippgadow
philippgadow

💻
Lukas Schaller
Lukas Schaller

💻
From 86765d16e55f17efe9fd27e91f03734b810f1057 Mon Sep 17 00:00:00 2001 From: Lukas Schaller <30951523+LuSchaller@users.noreply.github.com> Date: Fri, 20 Jun 2025 13:56:09 +0200 Subject: [PATCH 021/123] added sourcing of cms setup to ensure scram is available (#699) Co-authored-by: Marcel Rieger --- sandboxes/_setup_cmssw.sh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sandboxes/_setup_cmssw.sh b/sandboxes/_setup_cmssw.sh index 8e874b60c..f872a5e34 100644 --- a/sandboxes/_setup_cmssw.sh +++ b/sandboxes/_setup_cmssw.sh @@ -234,12 +234,14 @@ setup_cmssw() { if command -v cf_cmssw_custom_install &> /dev/null; then echo -e "\nrunning cf_cmssw_custom_install" cf_cmssw_custom_install && - cd "${install_path}/src" && + source "/cvmfs/cms.cern.ch/cmsset_default.sh" "" && + cd "${install_path}/src" && scram b elif [ ! -z "${cf_cmssw_custom_install}" ] && [ -f "${cf_cmssw_custom_install}" ]; then echo -e "\nsourcing cf_cmssw_custom_install file" source "${cf_cmssw_custom_install}" "" && - cd "${install_path}/src" && + source "/cvmfs/cms.cern.ch/cmsset_default.sh" "" && + cd "${install_path}/src" && scram b fi ) From c86e46e053393f785ad482d4a9b81fd80267dc82 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 23 Jun 2025 08:44:26 +0200 Subject: [PATCH 022/123] Fix typo in cf_inspect. --- bin/cf_inspect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/cf_inspect.py b/bin/cf_inspect.py index fb080b8e9..5c2ecdc4a 100644 --- a/bin/cf_inspect.py +++ b/bin/cf_inspect.py @@ -101,7 +101,7 @@ def list_content(data: Any) -> None: print("file content loaded into variable 'objects'") # interpret data - intepreted = objects + interpreted = objects if args.events: # preload common packages import awkward as ak # noqa From e1ebc81cf8e8fbfc453f05a58c9570150621ae11 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 25 Jun 2025 10:37:27 +0200 Subject: [PATCH 023/123] Hotfix validation check in stitched normalization weight production. --- columnflow/production/normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 273b156fc..4deea5e2c 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -305,7 +305,7 @@ def normalization_weights_setup( allowed_ids = set(map(int, merged_selection_stats["sum_mc_weight_per_process"])) # complain if there are processes seen/id'ed during selection that are not part of the datasets - unknown_process_ids = allowed_ids - {p.id for p in process_insts} + unknown_process_ids = {p.id for p in process_insts} - allowed_ids if unknown_process_ids: raise Exception( f"selection stats contain ids of processes that were not previously registered to the config " From 14a3444be9abc954ec867833793b97838b47c8ce Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Wed, 25 Jun 2025 13:01:37 +0200 Subject: [PATCH 024/123] Improve categorizer calls. (#702) --- columnflow/production/categories.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/columnflow/production/categories.py b/columnflow/production/categories.py index e6dee2e92..1d6d63dda 100644 --- a/columnflow/production/categories.py +++ b/columnflow/production/categories.py @@ -6,6 +6,9 @@ from __future__ import annotations +import functools +import operator + import law from columnflow.categorization import Categorizer @@ -34,16 +37,20 @@ def category_ids( """ Assigns each event an array of category ids. """ - category_ids = [] + # evaluate all unique categorizers, storing their returned masks + cat_masks = {} + for categorizer in self.unique_categorizers: + events, mask = self[categorizer](events, **kwargs) + cat_masks[categorizer] = mask + # loop through categories and construct mask over all categorizers + category_ids = [] for cat_inst, categorizers in self.categorizer_map.items(): - # start with a true mask - cat_mask = np.ones(len(events), dtype=bool) - - # loop through selectors - for categorizer in categorizers: - events, mask = self[categorizer](events, **kwargs) - cat_mask = cat_mask & mask + cat_mask = functools.reduce( + operator.and_, + (cat_masks[c] for c in categorizers), + np.ones(len(events), dtype=bool), + ) # covert to nullable array with the category ids or none, then apply ak.singletons ids = ak.where(cat_mask, np.float64(cat_inst.id), np.float64(np.nan)) @@ -95,3 +102,6 @@ def category_ids_init(self: Producer, **kwargs) -> None: self.produces.add(categorizer) self.categorizer_map.setdefault(cat_inst, []).append(categorizer) + + # store a list of unique categorizers + self.unique_categorizers = law.util.make_unique(sum(self.categorizer_map.values(), [])) From 2599015d5c97b8b9e3b0edda315f4a49d480005e Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Wed, 25 Jun 2025 14:28:00 +0200 Subject: [PATCH 025/123] Add ml model task pinning. (#703) --- columnflow/tasks/framework/mixins.py | 38 ++++++++++++++++++++++++++++ docs/user_guide/best_practices.md | 3 ++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index d871bd7be..0960b0e72 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -1283,6 +1283,25 @@ def events_used_in_training( not shift_inst.has_tag("disjoint_from_nominal") ) + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: MLModelMixinBase | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) + + # add the ml model name + ml_model = ( + inst_or_params.get("ml_model") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "ml_model", None) + ) + if ml_model not in (law.NO_STR, None, ""): + prefix = "ml" + keys[prefix] = f"{prefix}_{ml_model}" + + return keys + class MLModelTrainingMixin( MLModelMixinBase, @@ -1610,6 +1629,25 @@ def find_keep_columns(self, collection: ColumnCollection) -> set[Route]: return columns + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: MLModelsMixin | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) + + # add the ml model names + ml_models = ( + inst_or_params.get("ml_models") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "ml_models", None) + ) + if ml_models not in {law.NO_STR, None, "", ()}: + prefix = "ml" + keys[prefix] = [f"{prefix}_{ml_model}" for ml_model in ml_models] + + return keys + class HistProducerClassMixin(ArrayFunctionClassMixin): """ diff --git a/docs/user_guide/best_practices.md b/docs/user_guide/best_practices.md index 0a72fbc0c..150c09bc5 100644 --- a/docs/user_guide/best_practices.md +++ b/docs/user_guide/best_practices.md @@ -44,7 +44,8 @@ Most tasks, however, define their lookup keys as: 7. selector name, prefixed by `sel_` 8. reducer name, prefixed by `red_` 9. producer name, prefixed by `prod_` -10. hist producer name, prefixed by `hist_` +10. ml model name, prefixed by `ml_` +11. hist producer name, prefixed by `hist_` When defining `TASK_IDENTIFIER`'s, not all keys need to be specified, and patterns or regular expressions (`^EXPR$`) can be used. The definition order in the config file is **important** as the first matching definition is used. From 150e910d4b3fd9f81abf0132547cc303bf5f780f Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Wed, 25 Jun 2025 17:08:38 +0200 Subject: [PATCH 026/123] add exception when era aux is missing --- columnflow/calibration/cms/jets.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index e42d0094a..5ad98852f 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -614,7 +614,13 @@ def make_jme_keys(names, jec=jec_cfg, is_data=self.dataset_inst.is_data): jec_era = self.dataset_inst.get_aux("jec_era", None) # if no special JEC era is specified, infer based on 'era' if jec_era is None: - jec_era = "Run" + self.dataset_inst.get_aux("era", None) + era = self.dataset_inst.get_aux("era", None) + if era is None: + raise ValueError( + "JEC data key is requested to be era dependent, but neither jec_era or era " + f"auxiliary is set for dataset {self.dataset_inst.name}.", + ) + jec_era = "Run" + era jme_key = f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{{name}}_{jec.jet_type}" elif is_data: From e9575f948e51d921acc293919904460da1a8b88b Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Thu, 26 Jun 2025 07:48:21 +0200 Subject: [PATCH 027/123] apply blinding threshold before process scaling --- columnflow/plotting/plot_functions_1d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/columnflow/plotting/plot_functions_1d.py b/columnflow/plotting/plot_functions_1d.py index 0608075fa..86d8ee1c5 100644 --- a/columnflow/plotting/plot_functions_1d.py +++ b/columnflow/plotting/plot_functions_1d.py @@ -60,12 +60,12 @@ def plot_variable_stack( hists, process_style_config = apply_process_settings(hists, process_settings) # variable-based settings (rebinning, slicing, flow handling) hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings) - # process scaling - hists = apply_process_scaling(hists) # remove data in bins where sensitivity exceeds some threshold blinding_threshold = kwargs.get("blinding_threshold", None) if blinding_threshold: hists = blind_sensitive_bins(hists, config_inst, blinding_threshold) + # process scaling + hists = apply_process_scaling(hists) # density scaling per bin if density: hists = apply_density(hists, density) From 1700472aaa7e164dafa2f4a2725faea8f90e524a Mon Sep 17 00:00:00 2001 From: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Date: Thu, 26 Jun 2025 12:59:57 +0200 Subject: [PATCH 028/123] merge workflow reqs of different variables in CreateDatacards (#689) --- columnflow/tasks/framework/inference.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/columnflow/tasks/framework/inference.py b/columnflow/tasks/framework/inference.py index e742ad0d0..6c202aa31 100644 --- a/columnflow/tasks/framework/inference.py +++ b/columnflow/tasks/framework/inference.py @@ -109,12 +109,27 @@ def get_data_datasets(cls, config_inst: od.Config, cat_obj: DotDict) -> list[str def create_branch_map(self): return list(self.inference_model_inst.categories) - def _requires_cat_obj(self, cat_obj: DotDict, **req_kwargs): + def _requires_cat_obj(self, cat_obj: DotDict, merge_variables: bool = False, **req_kwargs): + """ + Helper to create the requirements for a single category object. + + :param cat_obj: category object from an InferenceModel + :param merge_variables: whether to merge the variables from all requested category objects + :return: requirements for the category object + """ reqs = {} for config_inst in self.config_insts: if not (config_data := cat_obj.config_data.get(config_inst.name)): continue + if merge_variables: + variables = tuple( + _cat_obj.config_data.get(config_inst.name).variable + for _cat_obj in self.branch_map.values() + ) + else: + variables = (config_data.variable,) + # add merged shifted histograms for mc reqs[config_inst.name] = { proc_obj.name: { @@ -130,7 +145,7 @@ def _requires_cat_obj(self, cat_obj: DotDict, **req_kwargs): self.inference_model_inst.require_shapes_for_parameter(param_obj) ) ), - variables=(config_data.variable,), + variables=variables, **req_kwargs, ) for dataset in self.get_mc_datasets(config_inst, proc_obj) @@ -150,7 +165,7 @@ def _requires_cat_obj(self, cat_obj: DotDict, **req_kwargs): self, config=config_inst.name, dataset=dataset, - variables=(config_data.variable,), + variables=variables, **req_kwargs, ) for dataset in data_datasets @@ -163,14 +178,13 @@ def workflow_requires(self): reqs["merged_hists"] = hist_reqs = {} for cat_obj in self.branch_map.values(): - cat_reqs = self._requires_cat_obj(cat_obj) + cat_reqs = self._requires_cat_obj(cat_obj, merge_variables=True) for config_name, proc_reqs in cat_reqs.items(): hist_reqs.setdefault(config_name, {}) for proc_name, dataset_reqs in proc_reqs.items(): hist_reqs[config_name].setdefault(proc_name, {}) for dataset_name, task in dataset_reqs.items(): hist_reqs[config_name][proc_name].setdefault(dataset_name, set()).add(task) - return reqs def requires(self): From 21c55a58cdd713a126a8fef10fc7e868459ecc91 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 26 Jun 2025 13:01:28 +0200 Subject: [PATCH 029/123] Revert process id check in normalization producer. --- columnflow/production/normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 4deea5e2c..273b156fc 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -305,7 +305,7 @@ def normalization_weights_setup( allowed_ids = set(map(int, merged_selection_stats["sum_mc_weight_per_process"])) # complain if there are processes seen/id'ed during selection that are not part of the datasets - unknown_process_ids = {p.id for p in process_insts} - allowed_ids + unknown_process_ids = allowed_ids - {p.id for p in process_insts} if unknown_process_ids: raise Exception( f"selection stats contain ids of processes that were not previously registered to the config " From 3fe0f0c4154afc7e0c760fa2a9ad4fcfe7cfeb6e Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 26 Jun 2025 13:06:34 +0200 Subject: [PATCH 030/123] [cms] Add note on TEC-to-MET propagation. --- columnflow/calibration/cms/tau.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/columnflow/calibration/cms/tau.py b/columnflow/calibration/cms/tau.py index b9ada41ef..4cd4e7081 100644 --- a/columnflow/calibration/cms/tau.py +++ b/columnflow/calibration/cms/tau.py @@ -94,6 +94,11 @@ def tec( *get_tau_file* and *get_tec_config* can be adapted in a subclass in case they are stored differently in the config. + .. note:: + + In case you also perform the propagation from jet energy calibrations to MET, please check if the propagation of + tau energy calibrations to MET is required in your analysis! + Resources: https://twiki.cern.ch/twiki/bin/view/CMS/TauIDRecommendationForRun2?rev=113 https://gitlab.cern.ch/cms-nanoAOD/jsonpog-integration/-/blob/849c6a6efef907f4033715d52290d1a661b7e8f9/POG/TAU From 7a7359c383db575a9662bb8766cabe4d58f0c48a Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 26 Jun 2025 16:40:33 +0200 Subject: [PATCH 031/123] Control row group merging in MergeReducedEvents. --- columnflow/tasks/reduction.py | 9 ++++++++- modules/law | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index 4e6c8d87a..d249c8c37 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -453,6 +453,9 @@ class MergeReducedEvents(_MergeReducedEvents): ReduceEvents=ReduceEvents, ) + # approximate number of events per row group in the merged file + target_row_group_size = 50_000 + @law.workflow_property(setter=True, cache=True, empty_value=0) def file_merging(self): # check if the merging stats are present @@ -499,7 +502,11 @@ def run(self): # merge law.pyarrow.merge_parquet_task( - self, inputs, output, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=inputs, + output=output, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.target_row_group_size, ) # optionally remove initial inputs diff --git a/modules/law b/modules/law index fb21a9c28..b76c17319 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit fb21a9c28337bbf24ad83d11d96a71b6f0a20d06 +Subproject commit b76c1731981b643f0c4a6e9a6e3040dfbd7bca12 From 054edd9fd414710e75b57a08449563aeaa377918 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 26 Jun 2025 17:13:37 +0200 Subject: [PATCH 032/123] Update law. --- modules/law | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/law b/modules/law index b76c17319..2e6133857 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit b76c1731981b643f0c4a6e9a6e3040dfbd7bca12 +Subproject commit 2e6133857d4f5923796ebb7bdde50e35c1ff51c0 From 32a041cb874e2db0fb0ba5581efab0e18e172df4 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Fri, 27 Jun 2025 07:13:15 +0200 Subject: [PATCH 033/123] Improve reduced events merging. --- columnflow/tasks/reduction.py | 1 + modules/law | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index d249c8c37..ad0dc867e 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -505,6 +505,7 @@ def run(self): task=self, inputs=inputs, output=output, + callback=self.create_progress_callback(len(inputs)), writer_opts=self.get_parquet_writer_opts(), target_row_group_size=self.target_row_group_size, ) diff --git a/modules/law b/modules/law index 2e6133857..1d667e43d 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit 2e6133857d4f5923796ebb7bdde50e35c1ff51c0 +Subproject commit 1d667e43de3ef6f5f5f154b1c0d319220a3b52ab From 7240932fa8e25d05ec59380c34e98c228cb7cf59 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Fri, 27 Jun 2025 23:18:30 +0200 Subject: [PATCH 034/123] Updata law. --- modules/law | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/law b/modules/law index 1d667e43d..bbeaa4ad3 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit 1d667e43de3ef6f5f5f154b1c0d319220a3b52ab +Subproject commit bbeaa4ad3213c9a01417061fd4c262947849e952 From 189d4d5fe1ec0c7cc2bb7cae9ebba463b7ac323f Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 30 Jun 2025 13:25:00 +0200 Subject: [PATCH 035/123] Store dataset_info_inst in params, cap files for reduction stats. --- columnflow/tasks/framework/base.py | 15 +++++++++++++++ columnflow/tasks/reduction.py | 17 +++++++++++------ modules/law | 2 +- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index 40f85b104..d0c9dbfa4 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -1714,6 +1714,21 @@ def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any] return params + @classmethod + def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: + params = super().resolve_param_values(params) + + # also add a reference to the info instance when a global shift is defined + if "dataset_inst" in params and "global_shift_inst" in params: + shift_name = params["global_shift_inst"].name + params["dataset_info_inst"] = ( + params["dataset_inst"].get_info(shift_name) + if shift_name in params["dataset_inst"].info + else params["dataset_inst"].get_info("nominal") + ) + + return params + @classmethod def get_known_shifts( cls, diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index ad0dc867e..36b828c64 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -280,16 +280,15 @@ class MergeReductionStats(_MergeReductionStats): n_inputs = luigi.IntParameter( default=10, significant=True, - description="minimal number of input files for sufficient statistics to infer merging " - "factors; default: 10", + description="minimal number of input files to infer merging factors with sufficient statistics; default: 10", ) merged_size = law.BytesParameter( default=law.NO_FLOAT, unit="MB", significant=False, - description="the maximum file size of merged files; default unit is MB; when 0, the " - "merging factor is not actually calculated from input files, but it is assumed to be 1 " - "(= no merging); default: config value 'reduced_file_size' or 512MB'", + description="the maximum file size of merged files; default unit is MB; when 0, the merging factor is not " + "actually calculated from input files, but it is assumed to be 1 (= no merging); default: config value " + "'reduced_file_size' or 512MB", ) # upstream requirements @@ -302,6 +301,12 @@ class MergeReductionStats(_MergeReductionStats): def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params = super().resolve_param_values(params) + # cap n_inputs + if "n_inputs" in params and (dataset_info_inst := params.get("dataset_info_inst")): + n_files = dataset_info_inst.n_files + if params["n_inputs"] < 0 or params["n_inputs"] > n_files: + params["n_inputs"] = n_files + # check for the default merged size if "merged_size" in params: if params["merged_size"] in {None, law.NO_FLOAT}: @@ -407,7 +412,7 @@ def get_avg_std(values): self.publish_message(f" stats of {n} input files ".center(40, "-")) self.publish_message(f"average size: {law.util.human_bytes(stats['avg_size'], fmt=True)}") deviation = stats["std_size"] / stats["avg_size"] - self.publish_message(f"deviation : {deviation * 100:.2f} % (std / avg)") + self.publish_message(f"deviation : {deviation * 100:.2f}% (std/avg)") self.publish_message(" merging info ".center(40, "-")) self.publish_message(f"target size : {self.merged_size} MB") self.publish_message(f"merging : {stats['merge_factor']} into 1") diff --git a/modules/law b/modules/law index bbeaa4ad3..d930140f5 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit bbeaa4ad3213c9a01417061fd4c262947849e952 +Subproject commit d930140f566d84ca92e1d63e45d69a923dc6457d From 7903bb97024ad9b711e443e244e263433bc3c1e3 Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Thu, 3 Jul 2025 10:56:56 +0200 Subject: [PATCH 036/123] Make merging chunk size configurable. (#707) --- analysis_templates/cms_minimal/law.cfg | 3 +++ columnflow/columnar_util.py | 6 +++++- columnflow/tasks/calibration.py | 7 ++++++- columnflow/tasks/framework/mixins.py | 3 +++ columnflow/tasks/ml.py | 7 ++++++- columnflow/tasks/production.py | 7 ++++++- columnflow/tasks/reduction.py | 13 +++++++++---- columnflow/tasks/selection.py | 14 ++++++++++++-- law.cfg | 3 +++ 9 files changed, 53 insertions(+), 10 deletions(-) diff --git a/analysis_templates/cms_minimal/law.cfg b/analysis_templates/cms_minimal/law.cfg index 304c8eb65..aee2c061b 100644 --- a/analysis_templates/cms_minimal/law.cfg +++ b/analysis_templates/cms_minimal/law.cfg @@ -70,6 +70,9 @@ chunked_io_chunk_size: 100000 chunked_io_pool_size: 2 chunked_io_debug: False +# settings for merging parquet files in several locations +merging_row_group_size: 50000 + # csv list of task families that inherit from ChunkedReaderMixin and whose output arrays should be # checked (raising an exception) for non-finite values before saving them to disk check_finite_output: cf.CalibrateEvents, cf.SelectEvents, cf.ReduceEvents, cf.ProduceColumns diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index f9155b7be..2a4701268 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -3032,7 +3032,11 @@ def __init__( # case nested nodes separated by "*.list.element.*" (rather than "*.list.item.*") are found # (to be removed in the future) if open_options.get("split_row_groups"): - nodes = ak.ak_from_parquet.metadata(path)[0] + try: + nodes = ak.ak_from_parquet.metadata(path)[0] + except: + logger.error(f"unable to read {path}") + raise cre = re.compile(r"^.+\.list\.element(|\..+)$") if any(map(cre.match, nodes)): logger.warning( diff --git a/columnflow/tasks/calibration.py b/columnflow/tasks/calibration.py index cc097c6b2..2158487de 100644 --- a/columnflow/tasks/calibration.py +++ b/columnflow/tasks/calibration.py @@ -174,7 +174,12 @@ def run(self): # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( - self, sorted_chunks, output["columns"], local=True, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=sorted_chunks, + output=output["columns"], + local=True, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index 0960b0e72..c1229cfc7 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -2402,6 +2402,9 @@ class ChunkedIOMixin(ConfigTask): description="when True, checks whether columns if input arrays overlap in at least one field", ) + # number of events per row group in the merged file + merging_row_group_size = law.config.get_expanded_int("analysis", "merging_row_group_size", 50_000) + exclude_params_req = {"check_finite_output", "check_overlapping_inputs"} # define default chunk and pool sizes that can be adjusted per inheriting task diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 406746ff6..679a62575 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -806,7 +806,12 @@ def run(self): # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( - self, sorted_chunks, output["mlcolumns"], local=True, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=sorted_chunks, + output=output["mlcolumns"], + local=True, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) diff --git a/columnflow/tasks/production.py b/columnflow/tasks/production.py index cee2d1311..ceb5c736d 100644 --- a/columnflow/tasks/production.py +++ b/columnflow/tasks/production.py @@ -168,7 +168,12 @@ def run(self): # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( - self, sorted_chunks, output["columns"], local=True, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=sorted_chunks, + output=output["columns"], + local=True, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index 36b828c64..88a101d38 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -239,7 +239,12 @@ def run(self): # merge output files sorted_chunks = [output_chunks[key] for key in sorted(output_chunks)] law.pyarrow.merge_parquet_task( - self, sorted_chunks, output["events"], local=True, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=sorted_chunks, + output=output["events"], + local=True, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) @@ -458,8 +463,8 @@ class MergeReducedEvents(_MergeReducedEvents): ReduceEvents=ReduceEvents, ) - # approximate number of events per row group in the merged file - target_row_group_size = 50_000 + # number of events per row group in the merged file + merging_row_group_size = law.config.get_expanded_int("analysis", "merging_row_group_size", 50_000) @law.workflow_property(setter=True, cache=True, empty_value=0) def file_merging(self): @@ -512,7 +517,7 @@ def run(self): output=output, callback=self.create_progress_callback(len(inputs)), writer_opts=self.get_parquet_writer_opts(), - target_row_group_size=self.target_row_group_size, + target_row_group_size=self.merging_row_group_size, ) # optionally remove initial inputs diff --git a/columnflow/tasks/selection.py b/columnflow/tasks/selection.py index 314f1f3aa..0081fee36 100644 --- a/columnflow/tasks/selection.py +++ b/columnflow/tasks/selection.py @@ -265,14 +265,24 @@ def run(self): sorted_chunks = [result_chunks[key] for key in sorted(result_chunks)] writer_opts_masks = self.get_parquet_writer_opts(repeating_values=True) law.pyarrow.merge_parquet_task( - self, sorted_chunks, outputs["results"], local=True, writer_opts=writer_opts_masks, + task=self, + inputs=sorted_chunks, + output=outputs["results"], + local=True, + writer_opts=writer_opts_masks, + target_row_group_size=self.merging_row_group_size, ) # merge the column files if write_columns: sorted_chunks = [column_chunks[key] for key in sorted(column_chunks)] law.pyarrow.merge_parquet_task( - self, sorted_chunks, outputs["columns"], local=True, writer_opts=self.get_parquet_writer_opts(), + task=self, + inputs=sorted_chunks, + output=outputs["columns"], + local=True, + writer_opts=self.get_parquet_writer_opts(), + target_row_group_size=self.merging_row_group_size, ) # save stats diff --git a/law.cfg b/law.cfg index 183bad887..763eaf103 100644 --- a/law.cfg +++ b/law.cfg @@ -64,6 +64,9 @@ chunked_io_chunk_size: 100000 chunked_io_pool_size: 2 chunked_io_debug: False +# settings for merging parquet files in several locations +merging_row_group_size: 50000 + # csv list of task families that inherit from ChunkedReaderMixin and whose output arrays should be # checked (raising an exception) for non-finite values before saving them to disk # supported tasks are: cf.CalibrateEvents, cf.SelectEvents, cf.ReduceEvents, cf.ProduceColumns, From f6113d3734b0a3e58c3303387cd1f1ad3f475ed6 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 3 Jul 2025 17:12:52 +0200 Subject: [PATCH 037/123] Bump version in __version__ file. --- columnflow/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/__version__.py b/columnflow/__version__.py index d3a0772fb..e09464834 100644 --- a/columnflow/__version__.py +++ b/columnflow/__version__.py @@ -24,4 +24,4 @@ __contact__ = "https://github.com/columnflow/columnflow" __license__ = "BSD-3-Clause" __status__ = "Development" -__version__ = "0.2.4" +__version__ = "0.3.0" From 80d7c386e1d8f6e3f2eac3dc87eb465517de6f7a Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Fri, 4 Jul 2025 10:30:43 +0200 Subject: [PATCH 038/123] add confirm message to cf_remove_tmp --- bin/cf_remove_tmp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/bin/cf_remove_tmp b/bin/cf_remove_tmp index 0d9ba39f3..d14c7a617 100755 --- a/bin/cf_remove_tmp +++ b/bin/cf_remove_tmp @@ -37,10 +37,22 @@ cf_remove_tmp() { return "3" fi - # remove all files and directories in tmp_dir owned by the user local pattern="luigi-tmp-*" [ "${mode}" = "all" ] && pattern="*" - find "${tmp_dir}" -maxdepth 1 -name "${pattern}" -user "$( id -u )" -exec rm -r "{}" \; + prompt="Are you sure you want to delete all files in path \"${tmp_dir}\" matching \"${pattern}\"? (y/n) " + read -rp "$prompt" confirm + + case "$confirm" in + [Yy]) + # remove all files and directories in tmp_dir owned by the user + echo "Deleting files..." + find "${tmp_dir}" -maxdepth 1 -name "${pattern}" -user "$( id -u )" -print -exec rm -r "{}" \; + ;; + *) + echo "Canceled." + exit 1 + ;; + esac } cf_remove_tmp "$@" From 0100f25d6170b4fc985712823e702679b46e2fb7 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Fri, 4 Jul 2025 13:36:47 +0200 Subject: [PATCH 039/123] Use return code in tmp removal. --- bin/cf_remove_tmp | 8 ++++---- columnflow/production/cms/dy.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bin/cf_remove_tmp b/bin/cf_remove_tmp index d14c7a617..0d9f454f5 100755 --- a/bin/cf_remove_tmp +++ b/bin/cf_remove_tmp @@ -39,18 +39,18 @@ cf_remove_tmp() { local pattern="luigi-tmp-*" [ "${mode}" = "all" ] && pattern="*" + prompt="Are you sure you want to delete all files in path \"${tmp_dir}\" matching \"${pattern}\"? (y/n) " read -rp "$prompt" confirm - case "$confirm" in [Yy]) # remove all files and directories in tmp_dir owned by the user - echo "Deleting files..." + echo "deleting files..." find "${tmp_dir}" -maxdepth 1 -name "${pattern}" -user "$( id -u )" -print -exec rm -r "{}" \; ;; *) - echo "Canceled." - exit 1 + >&2 echo "canceled" + return "4" ;; esac } diff --git a/columnflow/production/cms/dy.py b/columnflow/production/cms/dy.py index 590605d74..f38d57ba3 100644 --- a/columnflow/production/cms/dy.py +++ b/columnflow/production/cms/dy.py @@ -58,7 +58,7 @@ def gen_dilepton(self, events: ak.Array, **kwargs) -> ak.Array: (status == 1) & events.GenPart.hasFlags("fromHardProcess") ) - # taus need to have status == 2, + # taus need to have status == 2 tau_mask = ( (pdg_id == 15) & (status == 2) & events.GenPart.hasFlags("fromHardProcess") ) From 9510f2025922d99197a09842dc818cb79eed573f Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Fri, 11 Jul 2025 10:29:56 +0200 Subject: [PATCH 040/123] hotfix: constistent branches reqs between MergeReducedEvents and MergeSelectionStats --- columnflow/tasks/reduction.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index 88a101d38..6e160c1d3 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -485,7 +485,10 @@ def create_branch_map(self): def workflow_requires(self): reqs = super().workflow_requires() reqs["stats"] = self.reqs.MergeReductionStats.req_different_branching(self) - reqs["events"] = self.reqs.ReduceEvents.req_different_branching(self, branches=((0, -1),)) + reqs["events"] = self.reqs.ReduceEvents.req_different_branching( + self, + branches=((0, self.dataset_info_inst.n_files),) + ) return reqs def requires(self): From bfa218092ac9020535f09d283bdb7de27f67f586 Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Fri, 11 Jul 2025 10:31:56 +0200 Subject: [PATCH 041/123] lint --- columnflow/tasks/reduction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index 6e160c1d3..f440d8244 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -487,7 +487,7 @@ def workflow_requires(self): reqs["stats"] = self.reqs.MergeReductionStats.req_different_branching(self) reqs["events"] = self.reqs.ReduceEvents.req_different_branching( self, - branches=((0, self.dataset_info_inst.n_files),) + branches=((0, self.dataset_info_inst.n_files),), ) return reqs From b45c2d82add84edb1edd39fa6809f5f8777dea0f Mon Sep 17 00:00:00 2001 From: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:17:53 +0200 Subject: [PATCH 042/123] fix handling of non_zero_mask in murf_envelope (#704) Co-authored-by: Marcel Rieger --- columnflow/production/cms/scale.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/columnflow/production/cms/scale.py b/columnflow/production/cms/scale.py index caa683566..0a44fefa2 100644 --- a/columnflow/production/cms/scale.py +++ b/columnflow/production/cms/scale.py @@ -220,10 +220,14 @@ def murmuf_envelope_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Ar # take the max/min value of all considered variations murf_weights = (events.LHEScaleWeight[non_zero_mask] / murf_nominal)[:, envelope_indices] + weights_up = np.ones(len(events), dtype=np.float32) + weights_down = np.ones(len(events), dtype=np.float32) + weights_up[non_zero_mask] = ak.max(murf_weights, axis=1) + weights_down[non_zero_mask] = ak.min(murf_weights, axis=1) # store columns events = set_ak_column_f32(events, "murmuf_envelope_weight", ones) - events = set_ak_column_f32(events, "murmuf_envelope_weight_down", ak.min(murf_weights, axis=1)) - events = set_ak_column_f32(events, "murmuf_envelope_weight_up", ak.max(murf_weights, axis=1)) + events = set_ak_column_f32(events, "murmuf_envelope_weight_down", weights_down) + events = set_ak_column_f32(events, "murmuf_envelope_weight_up", weights_up) return events From 7e33b31912e2eb59f04b4bd06a1e3bafd2501e44 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 14 Jul 2025 14:46:00 +0200 Subject: [PATCH 043/123] Fix reduction chunk size control. --- columnflow/tasks/reduction.py | 1 + 1 file changed, 1 insertion(+) diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index f440d8244..08aadca45 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -195,6 +195,7 @@ def run(self): [inp.abspath for inp in inps], source_type=["coffea_root"] + (len(inps) - 1) * ["awkward_parquet"], read_columns=[read_columns, read_sel_columns] + (len(inps) - 2) * [read_columns], + chunk_size=self.reducer_inst.get_min_chunk_size(), ): # optional check for overlapping inputs within diffs if self.check_overlapping_inputs: From 35361c5eef5262fc745f44175eee616dc65c3172 Mon Sep 17 00:00:00 2001 From: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Date: Tue, 15 Jul 2025 16:32:31 +0200 Subject: [PATCH 044/123] Fix/asymmetric syst unc (#710) * consistent shift diffs * consistent cms label configs --------- Co-authored-by: Marcel Rieger --- columnflow/plotting/plot_all.py | 7 ++++--- columnflow/plotting/plot_functions_1d.py | 2 +- columnflow/plotting/plot_functions_2d.py | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/columnflow/plotting/plot_all.py b/columnflow/plotting/plot_all.py index 522995977..14ab9ab55 100644 --- a/columnflow/plotting/plot_all.py +++ b/columnflow/plotting/plot_all.py @@ -80,13 +80,14 @@ def draw_syst_error_bands( # create pairs of shifts mapping from up -> down and vice versa shift_pairs = {} + shift_pairs[nominal_shift] = nominal_shift # nominal shift maps to itself for up_shift, down_shift in shift_groups.values(): shift_pairs[up_shift] = down_shift shift_pairs[down_shift] = up_shift # stack histograms separately per shift, falling back to the nominal one when missing shift_stacks: dict[od.Shift, hist.Hist] = {} - for shift_inst in sum(shift_groups.values(), []): + for shift_inst in sum(shift_groups.values(), [nominal_shift]): for _h in syst_hists: # when the shift is present, the flipped shift must exist as well shift_ax = _h.axes["shift"] @@ -119,8 +120,8 @@ def draw_syst_error_bands( down_diffs = [] for source, (up_shift, down_shift) in shift_groups.items(): # get actual differences resulting from this shift - shift_up_diff = shift_stacks[up_shift].values()[b] - h.values()[b] - shift_down_diff = shift_stacks[down_shift].values()[b] - h.values()[b] + shift_up_diff = shift_stacks[up_shift].values()[b] - shift_stacks[nominal_shift].values()[b] + shift_down_diff = shift_stacks[down_shift].values()[b] - shift_stacks[nominal_shift].values()[b] # store them depending on whether they really increase or decrease the yield up_diffs.append(max(shift_up_diff, shift_down_diff, 0)) down_diffs.append(min(shift_up_diff, shift_down_diff, 0)) diff --git a/columnflow/plotting/plot_functions_1d.py b/columnflow/plotting/plot_functions_1d.py index 86d8ee1c5..13dc40014 100644 --- a/columnflow/plotting/plot_functions_1d.py +++ b/columnflow/plotting/plot_functions_1d.py @@ -350,7 +350,7 @@ def plot_cutflow( }, "annotate_cfg": {"text": cat_label or ""}, "cms_label_cfg": { - "lumi": round(0.001 * config_inst.x.luminosity.get("nominal"), 2), # /pb -> /fb + "lumi": round(0.001 * config_inst.x.luminosity.get("nominal"), 1), # /pb -> /fb "com": config_inst.campaign.ecm, }, } diff --git a/columnflow/plotting/plot_functions_2d.py b/columnflow/plotting/plot_functions_2d.py index c611f13f4..d8f58ae01 100644 --- a/columnflow/plotting/plot_functions_2d.py +++ b/columnflow/plotting/plot_functions_2d.py @@ -168,7 +168,8 @@ def plot_2d( "loc": "upper right", }, "cms_label_cfg": { - "lumi": round(0.001 * config_inst.x.luminosity.get("nominal"), 2), # /pb -> /fb + "lumi": round(0.001 * config_inst.x.luminosity.get("nominal"), 1), # /pb -> /fb + "com": config_inst.campaign.ecm, }, "plot2d_cfg": { "norm": cbar_norm, From 02af6b78919fdeca15b605616af11506a340034e Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 16 Jul 2025 09:32:15 +0200 Subject: [PATCH 045/123] Update law. --- modules/law | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/law b/modules/law index d930140f5..6a239437a 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit d930140f566d84ca92e1d63e45d69a923dc6457d +Subproject commit 6a239437a81921ce7a24bb838408e904a3dbf87f From ce87e4ff16019faf5375933ea4d858c02db24f5a Mon Sep 17 00:00:00 2001 From: Ana Andrade <99343616+aalvesan@users.noreply.github.com> Date: Wed, 16 Jul 2025 17:53:13 +0200 Subject: [PATCH 046/123] check and remove overlaping processes (#712) --- columnflow/tasks/framework/mixins.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index c1229cfc7..abde3afb0 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -8,7 +8,7 @@ import time import itertools -from collections import Counter +from collections import Counter, defaultdict import luigi import law @@ -2180,7 +2180,24 @@ def resolve(config_inst: od.Config, processes: Any, datasets: Any) -> tuple[list deep=True, ) else: - processes = config_inst.processes.names() + processes = list(config_inst.processes.names()) + # protect against overlap between top-level processes + to_remove = defaultdict(set) + for process_name in processes: + process = config_inst.get_process(process_name) + # check any remaining process for overlap + for child_process_name in processes: + if child_process_name == process_name: + continue + if process.has_process(child_process_name, deep=True): + to_remove[child_process_name].add(process_name) + if to_remove: + processes = [process_name for process_name in processes if process_name not in to_remove] + for removed, reasons in to_remove.items(): + reasons = ", ".join(map("'{}'".format, reasons)) + logger.warning( + f"removed '{removed}' from selected processes due to overlap with {reasons}", + ) if not processes and not cls.allow_empty_processes: raise ValueError(f"no processes found matching {processes_orig}") if datasets != law.no_value: From 9d4ce32cd67f9f309dfbdc7d62acfd02708276dc Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 18 Jul 2025 13:29:31 +0200 Subject: [PATCH 047/123] sort configs by ids for the multi config representation --- columnflow/tasks/framework/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index d0c9dbfa4..04c84f2f5 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -1504,7 +1504,7 @@ def __init__(self, *args, **kwargs) -> None: @property def config_repr(self) -> str: - return "__".join(config_inst.name for config_inst in self.config_insts) + return "__".join(config_inst.name for config_inst in sorted(self.config_insts, key=lambda c: c.id)) def store_parts(self) -> law.util.InsertableDict: parts = super().store_parts() From 09603014ad2c5e52c416acb327ce8314756dfa74 Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Mon, 21 Jul 2025 16:14:45 +0200 Subject: [PATCH 048/123] Optionally bypass branch-level plot requirements. (#716) --- columnflow/tasks/plotting.py | 98 ++++++++++++++++++++++-------------- modules/law | 2 +- 2 files changed, 62 insertions(+), 38 deletions(-) diff --git a/columnflow/tasks/plotting.py b/columnflow/tasks/plotting.py index 0c4cb18f9..73f4e36ca 100644 --- a/columnflow/tasks/plotting.py +++ b/columnflow/tasks/plotting.py @@ -50,10 +50,20 @@ class _PlotVariablesBase( class PlotVariablesBase(_PlotVariablesBase): + + bypass_branch_requirements = luigi.BoolParameter( + default=False, + description="whether to skip branch requirements and only use that of the workflow; default: False", + ) + single_config = False sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + exclude_params_repr = {"bypass_branch_requirements"} + exclude_params_index = {"bypass_branch_requirements"} + exclude_params_repr = {"bypass_branch_requirements"} + exclude_index = True def store_parts(self) -> law.util.InsertableDict: @@ -73,6 +83,13 @@ def workflow_requires(self): reqs["merged_hists"] = self.requires_from_branch() return reqs + def local_workflow_pre_run(self): + # when branches are cached, reinitiate the branch tasks with dropped branch level requirements since this + # method is called from a context where the identical workflow level requirements are already resolved + if self.cache_branch_map: + self._branch_tasks = None + self.get_branch_tasks(bypass_branch_requirements=True) + @abstractmethod def get_plot_shifts(self): return @@ -92,7 +109,7 @@ def get_config_process_map(self) -> tuple[dict[od.Config, dict[od.Process, dict[ dictionaries containing the dataset-process mapping and the shifts to be considered, and a dictionary mapping process names to the shifts to be considered. """ - reqs = self.requires() + reqs = self.requires() or self.as_workflow().requires().merged_hists config_process_map = {config_inst: {} for config_inst in self.config_insts} process_shift_map = defaultdict(set) @@ -164,7 +181,8 @@ def run(self): # histogram data per process copy hists: dict[od.Config, dict[od.Process, hist.Hist]] = {} with self.publish_step(f"plotting {self.branch_data.variable} in {self.branch_data.category}"): - for i, (config, dataset_dict) in enumerate(self.input().items()): + inputs = self.input() or self.workflow_input().merged_hists + for i, (config, dataset_dict) in enumerate(inputs.items()): config_inst = self.config_insts[i] category_inst = config_inst.get_category(self.branch_data.category) leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] @@ -308,6 +326,7 @@ class PlotVariablesBaseSingleShift( ): # use the MergeHistograms task to trigger upstream TaskArrayFunction initialization resolution_task_cls = MergeHistograms + exclude_index = True reqs = Requirements( @@ -322,28 +341,27 @@ def create_branch_map(self): for cat_name in sorted(self.categories) ] - def workflow_requires(self): - reqs = super().workflow_requires() - return reqs - def requires(self): - req = {} + reqs = {} - for i, config_inst in enumerate(self.config_insts): - sub_datasets = self.datasets[i] - req[config_inst.name] = {} - for d in sub_datasets: - if d in config_inst.datasets.names(): - req[config_inst.name][d] = self.reqs.MergeHistograms.req( - self, - config=config_inst.name, - shift=self.global_shift_insts[config_inst].name, - dataset=d, - branch=-1, - _exclude={"branches"}, - _prefer_cli={"variables"}, - ) - return req + if self.is_branch() and self.bypass_branch_requirements: + return reqs + + for config_inst, datasets in zip(self.config_insts, self.datasets): + reqs[config_inst.name] = {} + for d in datasets: + if d not in config_inst.datasets: + continue + reqs[config_inst.name][d] = self.reqs.MergeHistograms.req_different_branching( + self, + config=config_inst.name, + shift=self.global_shift_insts[config_inst].name, + dataset=d, + branch=-1, + _prefer_cli={"variables"}, + ) + + return reqs def plot_parts(self) -> law.util.InsertableDict: parts = super().plot_parts() @@ -482,26 +500,32 @@ def create_branch_map(self) -> list[DotDict]: return [DotDict(zip(keys, vals)) for vals in itertools.product(*seqs)] def requires(self): + reqs = {} + + if self.is_branch() and self.bypass_branch_requirements: + return reqs + req_cls = lambda dataset_name: ( self.reqs.MergeShiftedHistograms if self.config_inst.get_dataset(dataset_name).is_mc else self.reqs.MergeHistograms ) - req = {} - for i, config_inst in enumerate(self.config_insts): - req[config_inst.name] = {} - for dataset_name in self.datasets[i]: - if dataset_name in config_inst.datasets: - req[config_inst.name][dataset_name] = req_cls(dataset_name).req( - self, - config=config_inst.name, - dataset=dataset_name, - branch=-1, - _exclude={"branches"}, - _prefer_cli={"variables"}, - ) - return req + for config_inst, datasets in zip(self.config_insts, self.datasets): + reqs[config_inst.name] = {} + for d in datasets: + if d not in config_inst.datasets: + continue + reqs[config_inst.name][d] = req_cls(d).req( + self, + config=config_inst.name, + dataset=d, + branch=-1, + _exclude={"branches"}, + _prefer_cli={"variables"}, + ) + + return reqs def plot_parts(self) -> law.util.InsertableDict: parts = super().plot_parts() @@ -573,8 +597,8 @@ class PlotShiftedVariablesPerShift1D( class PlotShiftedVariablesPerConfig1D( - law.WrapperTask, PlotShiftedVariables1D, + law.WrapperTask, ): # force this one to be a local workflow workflow = "local" diff --git a/modules/law b/modules/law index 6a239437a..9f6ccffdd 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit 6a239437a81921ce7a24bb838408e904a3dbf87f +Subproject commit 9f6ccffddc219265aeb70fbc31735f26a600c154 From 5f5411e754e49fbffa173748e2a46b17e8066108 Mon Sep 17 00:00:00 2001 From: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Date: Mon, 21 Jul 2025 16:17:54 +0200 Subject: [PATCH 049/123] inference model caching (#714) * inference model caching * Generalize caching of derivables. --------- Co-authored-by: Marcel R. Co-authored-by: Marcel Rieger --- columnflow/columnar_util.py | 26 +++++--------------------- columnflow/inference/__init__.py | 22 +++++++++++++++------- columnflow/util.py | 27 ++++++++++++++++++++++++++- 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index 2a4701268..625209b00 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -24,9 +24,9 @@ import law import order as od -from columnflow.types import Sequence, Callable, Any, T, Generator +from columnflow.types import Sequence, Callable, Any, T, Generator, Hashable from columnflow.util import ( - UNSET, maybe_import, classproperty, DotDict, DerivableMeta, Derivable, pattern_matcher, + UNSET, maybe_import, classproperty, DotDict, DerivableMeta, CachedDerivableMeta, Derivable, pattern_matcher, get_source_code, real_path, freeze, get_docs_url, ) @@ -2260,26 +2260,10 @@ def skip_column( return tagged_column("skip", *routes) -class TaskArrayFunctionMeta(DerivableMeta): +class TaskArrayFunctionMeta(CachedDerivableMeta): - def __new__(metacls, cls_name: str, bases: tuple, cls_dict: dict) -> TaskArrayFunctionMeta: - # add an instance cache if not disabled - cls_dict.setdefault("cache_instances", True) - cls_dict["_instances"] = {} if cls_dict["cache_instances"] else None - - return super().__new__(metacls, cls_name, bases, cls_dict) - - def __call__(cls, *args, **kwargs) -> TaskArrayFunction: - # when not caching instances, return right away - if not cls.cache_instances: - return super().__call__(*args, **kwargs) - - # build the cache key from the inst_dict in kwargs - key = freeze((cls, kwargs.get("inst_dict", {}))) - if key not in cls._instances: - cls._instances[key] = super().__call__(*args, **kwargs) - - return cls._instances[key] + def _get_inst_cache_key(cls, args: tuple, kwargs: dict) -> Hashable: + return freeze((cls, kwargs.get("inst_dict", {}))) class TaskArrayFunction(ArrayFunction, metaclass=TaskArrayFunctionMeta): diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index d5c3ab01e..fac1d6831 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -13,9 +13,11 @@ import order as od import yaml -from columnflow.types import Generator, Callable, TextIO, Sequence, Any -from columnflow.util import DerivableMeta, Derivable, DotDict, is_pattern, is_regex, pattern_matcher, get_docs_url - +from columnflow.types import Generator, Callable, TextIO, Sequence, Any, Hashable, Type, T +from columnflow.util import ( + CachedDerivableMeta, Derivable, DotDict, is_pattern, is_regex, pattern_matcher, get_docs_url, + freeze, +) logger = law.logger.get_logger(__name__) @@ -189,7 +191,13 @@ def __str__(self) -> str: return self.value -class InferenceModel(Derivable): +class InferenceModelMeta(CachedDerivableMeta): + + def _get_inst_cache_key(cls, args: tuple, kwargs: dict) -> Hashable: + return freeze((cls, kwargs.get("inst_dict", {}))) + + +class InferenceModel(Derivable, metaclass=InferenceModelMeta): """ Interface to statistical inference models with connections to config objects (such as py:class:`order.Config` or :py:class:`order.Dataset`). @@ -322,11 +330,11 @@ def ignore_aliases(self, *args, **kwargs) -> bool: @classmethod def inference_model( - cls, + cls: T, func: Callable | None = None, bases: tuple[type] = (), **kwargs, - ) -> DerivableMeta | Callable: + ) -> Type[T] | Callable: """ Decorator for creating a new :py:class:`InferenceModel` subclass with additional, optional *bases* and attaching the decorated function to it as ``init_func``. All additional *kwargs* @@ -336,7 +344,7 @@ def inference_model( :param bases: Optional tuple of base classes for the new subclass. :returns: The new subclass or a decorator function. """ - def decorator(func: Callable) -> DerivableMeta: + def decorator(func: Callable) -> Type[T]: # create the class dict cls_dict = { **kwargs, diff --git a/columnflow/util.py b/columnflow/util.py index ef796eb1a..0d5348ebf 100644 --- a/columnflow/util.py +++ b/columnflow/util.py @@ -29,7 +29,7 @@ import luigi from columnflow import env_is_dev, env_is_remote, docs_url, github_url -from columnflow.types import Callable, Any, Sequence, Union, ModuleType +from columnflow.types import Callable, Any, Sequence, Union, ModuleType, Type, T, Hashable #: Placeholder for an unset value. @@ -932,6 +932,31 @@ def derived_by(cls, other: DerivableMeta) -> bool: return isinstance(other, DerivableMeta) and issubclass(other, cls) +class CachedDerivableMeta(DerivableMeta): + + def __new__(metacls, cls_name: str, bases: tuple, cls_dict: dict) -> CachedDerivableMeta: + # add an instance cache if not disabled + cls_dict.setdefault("cache_instances", True) + cls_dict["_instances"] = {} if cls_dict["cache_instances"] else None + + return super().__new__(metacls, cls_name, bases, cls_dict) + + def __call__(cls: Type[T], *args, **kwargs) -> T: + # when not caching instances, return right away + if not cls.cache_instances: + return super().__call__(*args, **kwargs) + + # build the cache key from the inst_dict in kwargs + key = cls._get_inst_cache_key(args, kwargs) + if key not in cls._instances: + cls._instances[key] = super().__call__(*args, **kwargs) + + return cls._instances[key] + + def _get_inst_cache_key(cls, args: tuple, kwargs: dict) -> Hashable: + raise NotImplementedError("__get_inst_cache_key method must be implemented by the derived meta class") + + class Derivable(object, metaclass=DerivableMeta): """ Derivable base class with features provided by the meta :py:class:`DerivableMeta`. From 4396f8182af13cacc999769f8006a7d7d6e3ad02 Mon Sep 17 00:00:00 2001 From: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Date: Mon, 21 Jul 2025 16:26:23 +0200 Subject: [PATCH 050/123] pad with nominal if shift source missing in config (#715) * pad with nominal if shift source missing in config * Update columnflow/tasks/cms/inference.py Co-authored-by: Marcel Rieger * Update columnflow/tasks/cms/inference.py --------- Co-authored-by: Marcel Rieger --- columnflow/tasks/cms/inference.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index f0bfae242..493125cf7 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -117,10 +117,14 @@ def run(self): if not self.inference_model_inst.require_shapes_for_parameter(param_obj): continue # store the varied hists - shift_source = param_obj.config_data[config_inst.name].shift_source + shift_source = ( + param_obj.config_data[config_inst.name].shift_source + if config_inst.name in param_obj.config_data + else None + ) for d in ["up", "down"]: shift_hists[(param_obj.name, d)] = h_proc[{ - "shift": hist.loc(config_inst.get_shift(f"{shift_source}_{d}").name), + "shift": hist.loc(f"{shift_source}_{d}" if shift_source else "nominal"), }] # forward objects to the datacard writer From 32df6b739c225f19e5642b61a96d3bbc2dbf4c56 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 24 Jul 2025 11:35:06 +0200 Subject: [PATCH 051/123] Add option for hook to update dataset_selection_stats in norm weight prod. --- columnflow/production/normalization.py | 40 ++++++++++++++++++-------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 273b156fc..dbe5a7fd9 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -67,7 +67,7 @@ def get_stitching_datasets(self: Producer) -> list[od.Dataset]: def get_br_from_inclusive_dataset( self: Producer, inclusive_dataset: od.Dataset, - stats: dict, + dataset_selection_stats: dict[str, dict[str, float]], ) -> dict[int, float]: """ Helper function to compute the branching ratios from the inclusive sample. This is done with ratios of event weights @@ -77,7 +77,7 @@ def get_br_from_inclusive_dataset( proc_ds_map = { d.processes.get_first().id: d for d in self.config_inst.datasets - if d.name in stats.keys() + if d.name in dataset_selection_stats.keys() } inclusive_proc = inclusive_dataset.processes.get_first() N = lambda x: sn.Number(x, np.sqrt(x)) # alias for Number with counting error @@ -92,11 +92,11 @@ def get_br_from_inclusive_dataset( dataset_name = proc_ds_map[proc.id].name # get the mc weights for the "mother" dataset and add an entry for the process - sum_mc_weight: float = stats[dataset_name]["sum_mc_weight"] - sum_mc_weight_per_process: dict[str, float] = stats[dataset_name]["sum_mc_weight_per_process"] + sum_mc_weight: float = dataset_selection_stats[dataset_name]["sum_mc_weight"] + sum_mc_weight_per_process: dict[str, float] = dataset_selection_stats[dataset_name]["sum_mc_weight_per_process"] # use the number of events to compute the error on the branching ratio - num_events: int = stats[dataset_name]["num_events"] - num_events_per_process: dict[str, int] = stats[dataset_name]["num_events_per_process"] + num_events: int = dataset_selection_stats[dataset_name]["num_events"] + num_events_per_process: dict[str, int] = dataset_selection_stats[dataset_name]["num_events_per_process"] # loop over all child processes for child_proc in child_procs: @@ -166,6 +166,16 @@ def multiply_branching_ratios(proc_id: int, proc_br: sn.Number) -> None: return branching_ratios +def update_dataset_selection_stats( + self: Producer, + dataset_selection_stats: dict[str, dict[str, float]], +) -> dict[str, dict[str, float]]: + """ + Hook to optionally update the per-dataset selection stats. + """ + return dataset_selection_stats + + @producer( uses={"process_id", "mc_weight"}, # name of the output column @@ -178,6 +188,7 @@ def multiply_branching_ratios(proc_id: int, proc_br: sn.Number) -> None: get_stitching_datasets=get_stitching_datasets, get_inclusive_dataset=get_inclusive_dataset, get_br_from_inclusive_dataset=get_br_from_inclusive_dataset, + update_dataset_selection_stats=update_dataset_selection_stats, # only run on mc mc_only=True, ) @@ -278,21 +289,26 @@ def normalization_weights_setup( weights per process. """ # load the selection stats - selection_stats = { + dataset_selection_stats = { dataset: task.cached_value( key=f"selection_stats_{dataset}", func=lambda: inp["stats"].load(formatter="json"), ) for dataset, inp in inputs["selection_stats"].items() } + + # optional hook to amend the per-dataset selection stats + if callable(self.update_dataset_selection_stats): + dataset_selection_stats = self.update_dataset_selection_stats(dataset_selection_stats) + # if necessary, merge the selection stats across datasets - if len(selection_stats) > 1: + if len(dataset_selection_stats) > 1: from columnflow.tasks.selection import MergeSelectionStats merged_selection_stats = defaultdict(float) - for stats in selection_stats.values(): + for stats in dataset_selection_stats.values(): MergeSelectionStats.merge_counts(merged_selection_stats, stats) else: - merged_selection_stats = selection_stats[self.dataset_inst.name] + merged_selection_stats = dataset_selection_stats[self.dataset_inst.name] # determine all proceses at any depth in the stitching datasets process_insts = { @@ -333,7 +349,7 @@ def normalization_weights_setup( else: branching_ratios = self.get_br_from_inclusive_dataset( inclusive_dataset=inclusive_dataset, - stats=selection_stats, + dataset_selection_stats=dataset_selection_stats, ) if not branching_ratios: raise Exception( @@ -343,7 +359,7 @@ def normalization_weights_setup( # compute the weight the inclusive dataset would have on its own without stitching inclusive_xsec = inclusive_proc.get_xsec(self.config_inst.campaign.ecm).nominal self.inclusive_weight = ( - lumi * inclusive_xsec / selection_stats[inclusive_dataset.name]["sum_mc_weight"] + lumi * inclusive_xsec / dataset_selection_stats[inclusive_dataset.name]["sum_mc_weight"] if self.dataset_inst == inclusive_dataset else 0 ) From 0bdff9f0d8b2325398e7594c5b0ea41f9541a2af Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Tue, 29 Jul 2025 14:05:34 +0200 Subject: [PATCH 052/123] fix missing datasets in MultiConfig --- columnflow/tasks/plotting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/columnflow/tasks/plotting.py b/columnflow/tasks/plotting.py index 73f4e36ca..12638ebe2 100644 --- a/columnflow/tasks/plotting.py +++ b/columnflow/tasks/plotting.py @@ -505,9 +505,9 @@ def requires(self): if self.is_branch() and self.bypass_branch_requirements: return reqs - req_cls = lambda dataset_name: ( + req_cls = lambda dataset_name, config_inst: ( self.reqs.MergeShiftedHistograms - if self.config_inst.get_dataset(dataset_name).is_mc + if config_inst.get_dataset(dataset_name).is_mc else self.reqs.MergeHistograms ) @@ -516,7 +516,7 @@ def requires(self): for d in datasets: if d not in config_inst.datasets: continue - reqs[config_inst.name][d] = req_cls(d).req( + reqs[config_inst.name][d] = req_cls(d, config_inst).req( self, config=config_inst.name, dataset=d, From f3af6f25300ceef4fa3d8aa31e501508490c7c63 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 30 Jul 2025 09:27:46 +0200 Subject: [PATCH 053/123] Update law. --- analysis_templates/cms_minimal/law.cfg | 6 ++++++ modules/law | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/analysis_templates/cms_minimal/law.cfg b/analysis_templates/cms_minimal/law.cfg index aee2c061b..d5e69e9f1 100644 --- a/analysis_templates/cms_minimal/law.cfg +++ b/analysis_templates/cms_minimal/law.cfg @@ -162,6 +162,12 @@ remote_lcg_setup_el9: /cvmfs/grid.cern.ch/alma9-ui-test/etc/profile.d/setup-alma remote_lcg_setup_force: False +[target] + +# when removing target collections, use multi-threading +collection_remove_threads: 2 + + [local_fs] base: / diff --git a/modules/law b/modules/law index 9f6ccffdd..679b88941 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit 9f6ccffddc219265aeb70fbc31735f26a600c154 +Subproject commit 679b88941ff32ff4a9724fe44848865daebce027 From b9cfaac0785b9adea29df6b41300ae489bc4d79d Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 30 Jul 2025 13:58:44 +0200 Subject: [PATCH 054/123] Add local directory check to cf_remove_tmp. --- bin/cf_remove_tmp | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/bin/cf_remove_tmp b/bin/cf_remove_tmp index 0d9f454f5..8d6f4f0a5 100755 --- a/bin/cf_remove_tmp +++ b/bin/cf_remove_tmp @@ -24,6 +24,8 @@ cf_remove_tmp() { fi # get the directory + local prompt + local confirm local tmp_dir="$( law config target.tmp_dir )" local ret="$?" if [ "${ret}" != "0" ]; then @@ -35,14 +37,30 @@ cf_remove_tmp() { elif [ ! -d "${tmp_dir}" ]; then >&2 echo "'law config target.tmp_dir' is not a directory" return "3" + elif [ -z "${LAW_TARGET_TMP_DIR}" ] && [ "$( cd "${tmp_dir}" && pwd )" = "${PWD}" ]; then + prompt="'law config target.tmp_dir' reports that the tmp directory is set to the current working directory '${PWD}'. Continue? (y/n) " + read -rp "${prompt}" confirm + case "${confirm}" in + [Yy]) + ;; + *) + >&2 echo "canceled" + return "4" + ;; + esac fi + + local shell_is_zsh=$( [ -z "${ZSH_VERSION}" ] && echo "false" || echo "true" ) + local this_file="$( ${shell_is_zsh} && echo "${(%):-%x}" || echo "${BASH_SOURCE[0]}" )" + local this_dir="$( cd "$( dirname "${this_file}" )" && pwd )" + local pattern="luigi-tmp-*" [ "${mode}" = "all" ] && pattern="*" prompt="Are you sure you want to delete all files in path \"${tmp_dir}\" matching \"${pattern}\"? (y/n) " - read -rp "$prompt" confirm - case "$confirm" in + read -rp "${prompt}" confirm + case "${confirm}" in [Yy]) # remove all files and directories in tmp_dir owned by the user echo "deleting files..." From 6600f0ee6ccaa51e23fa9e047cf6bb398a303867 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 30 Jul 2025 14:00:08 +0200 Subject: [PATCH 055/123] Hotfix typo. --- bin/cf_remove_tmp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/bin/cf_remove_tmp b/bin/cf_remove_tmp index 8d6f4f0a5..d5deb67d3 100755 --- a/bin/cf_remove_tmp +++ b/bin/cf_remove_tmp @@ -50,14 +50,11 @@ cf_remove_tmp() { esac fi - - local shell_is_zsh=$( [ -z "${ZSH_VERSION}" ] && echo "false" || echo "true" ) - local this_file="$( ${shell_is_zsh} && echo "${(%):-%x}" || echo "${BASH_SOURCE[0]}" )" - local this_dir="$( cd "$( dirname "${this_file}" )" && pwd )" - + # define the search pattern local pattern="luigi-tmp-*" [ "${mode}" = "all" ] && pattern="*" + # ask for confirmation prompt="Are you sure you want to delete all files in path \"${tmp_dir}\" matching \"${pattern}\"? (y/n) " read -rp "${prompt}" confirm case "${confirm}" in From 611dd8541daf25039b364836daee9f32a37af96c Mon Sep 17 00:00:00 2001 From: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Date: Thu, 31 Jul 2025 08:29:12 +0200 Subject: [PATCH 056/123] update met_phi Calibrator to new format (#719) * update met_phi Calibrator to new format * use npvsGood * add npvsGood to uses as well... * Minor adjustments, apply mask to all inputs. --------- Co-authored-by: Mathis Frahm Co-authored-by: Marcel R. --- columnflow/calibration/cms/met.py | 201 ++++++++++++++++++++---------- 1 file changed, 138 insertions(+), 63 deletions(-) diff --git a/columnflow/calibration/cms/met.py b/columnflow/calibration/cms/met.py index 229b4c9cb..942700a64 100644 --- a/columnflow/calibration/cms/met.py +++ b/columnflow/calibration/cms/met.py @@ -4,8 +4,12 @@ MET corrections. """ +from __future__ import annotations + import law +from dataclasses import dataclass + from columnflow.calibration import Calibrator, calibrator from columnflow.util import maybe_import, load_correction_set, DotDict from columnflow.columnar_util import set_ak_column @@ -15,69 +19,142 @@ ak = maybe_import("awkward") +@dataclass +class METPhiConfig: + variable_config: dict[str, tuple[str]] + correction_set: str = "met_xy_corrections" + met_name: str = "PuppiMET" + met_type: str = "MET" + keep_uncorrected: bool = False + + @classmethod + def new( + cls, + obj: METPhiConfig | tuple[str, list[str]] | tuple[str, list[str], str], + ) -> METPhiConfig: + # purely for backwards compatibility with the old string format + if isinstance(obj, cls): + return obj + if isinstance(obj, str): + return cls(correction_set=obj, variable_config={"pt": ("pt",), "phi": ("phi",)}) + if isinstance(obj, dict): + return cls(**obj) + raise ValueError(f"cannot convert {obj} to METPhiConfig") + + @calibrator( - uses={"run", "PV.npvs"}, - # name of the MET collection to calibrate - met_name="MET", + uses={"run", "PV.npvs", "PV.npvsGood"}, # function to determine the correction file get_met_file=(lambda self, external_files: external_files.met_phi_corr), # function to determine met correction config - get_met_config=(lambda self: self.config_inst.x.met_phi_correction_set), + get_met_config=(lambda self: METPhiConfig.new(self.config_inst.x.met_phi_correction)), ) def met_phi(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: """ - Performs the MET phi (type II) correction using the - :external+correctionlib:doc:`index` for events there the - uncorrected MET pt is below the beam energy (extracted from ``config_inst.campaign.ecm * 0.5``). - Requires an external file in the config under ``met_phi_corr``: + Performs the MET phi (type II) correction using the :external+correctionlib:doc:`index` for events there the + uncorrected MET pt is below the beam energy (extracted from ``config_inst.campaign.ecm * 0.5``). Requires an + external file in the config under ``met_phi_corr``: .. code-block:: python cfg.x.external_files = DotDict.wrap({ - "met_phi_corr": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/met.json.gz", # noqa + "met_phi_corr": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-406118ec/POG/JME/2022_Summer22EE/met_xyCorrections_2022_2022EE.json.gz", # noqa }) - *get_met_file* can be adapted in a subclass in case it is stored differently in the external - files. + *get_met_file* can be adapted in a subclass in case it is stored differently in the external files. - The name of the correction set should be present as an auxiliary entry in the config: + The met_phi Calibrator should be configured with an auxiliary entry in the config that can contain: - .. code-block:: python + - the name of the correction set + - the name of the MET column + - the MET type that is passed as an input to the correction set + - a boolean flag to keep the uncorrected MET pt and phi values as additional output columns + - a dictionary that maps the input variable names ("pt", "phi") to a list of output variable names that should + be produced. - cfg.x.met_phi_correction_set = "{variable}_metphicorr_pfmet_{data_source}" + Exemplary config entry: - where "variable" and "data_source" are placeholders that are inserted in the + .. code-block:: python + + from columnflow.calibration.cms.met import METPhiConfig + cfg.x.met_phi_correction = METPhiConfig( + met_name="PuppiMET", + met_type="MET", + correction_set="met_xy_corrections", + keep_uncorrected=False, + variable_config={ + "pt": ( + "pt", + "pt_stat_yup", + "pt_stat_ydn", + "pt_stat_xup", + "pt_stat_xdn", + ), + "phi": ( + "phi", + "phi_stat_yup", + "phi_stat_ydn", + "phi_stat_xup", + "phi_stat_xdn", + ), + }, + ) + + The `correction_set` value can also contain the placeholders "variable" and "data_source" that are replaced in the calibrator setup :py:meth:`~.met_phi.setup_func`. - *get_met_correction_set* can be adapted in a subclass in case it is stored - differently in the config. + + *get_met_config* can be adapted in a subclass in case it is stored differently in the config. + + Resources: + - https://twiki.cern.ch/twiki/bin/view/CMS/MissingETRun2Corrections?rev=79#xy_Shift_Correction_MET_phi_modu :param events: awkward array containing events to process """ # get Met columns - met = events[self.met_name] - - # copy the intial pt and phi values - corr_pt = np.array(met.pt, dtype=np.float32) - corr_phi = np.array(met.phi, dtype=np.float32) - - # select only events where MET pt is below the expected beam energy - mask = met.pt < (0.5 * self.config_inst.campaign.ecm) - - # arguments for evaluation - args = ( - met.pt[mask], - met.phi[mask], - ak.values_astype(events.PV.npvs[mask], np.float32), - ak.values_astype(events.run[mask], np.float32), - ) - - # evaluate and insert - corr_pt[mask] = self.met_pt_corrector.evaluate(*args) - corr_phi[mask] = self.met_phi_corrector.evaluate(*args) - - # save the corrected values - events = set_ak_column(events, f"{self.met_name}.pt", corr_pt, value_type=np.float32) - events = set_ak_column(events, f"{self.met_name}.phi", corr_phi, value_type=np.float32) + met = events[self.met_config.met_name] + + # correct only events where MET pt is below the expected beam energy + mask = met.pt < (0.5 * self.config_inst.campaign.ecm * 1000) # convert TeV to GeV + + variable_map = { + "met_type": self.met_config.met_type, + "epoch": f"{self.config_inst.campaign.x.year}{self.config_inst.campaign.x.postfix}", + "dtmc": "DATA" if self.dataset_inst.is_data else "MC", + "variation": "nom", + "met_pt": ak.values_astype(met.pt[mask], np.float32), + "met_phi": ak.values_astype(met.phi[mask], np.float32), + "npvGood": ak.values_astype(events.PV.npvsGood[mask], np.float32), + "npvs": ak.values_astype(events.PV.npvs[mask], np.float32), # needed for old-style corrections + "run": ak.values_astype(events.run[mask], np.float32), + } + + for variable, outp_variables in self.met_config.variable_config.items(): + met_corrector = self.met_correctors[variable] + if self.met_config.keep_uncorrected: + events = set_ak_column( + events, + f"{self.met_config.met_name}.{variable}_xy_uncorrected", + met[variable], + value_type=np.float32, + ) + for out_var in outp_variables: + # copy initial value every time + # NOTE: this needs to be within the loop to ensure that the output values are not + # overwritten by the next iteration + corr_var = np.array(met[variable], dtype=np.float32) + + # get the input variables for the correction + variable_map_syst = { + **variable_map, + "pt_phi": out_var, + } + inputs = [variable_map_syst[inp.name] for inp in met_corrector.inputs] + + # insert the corrected values + corr_var[mask] = met_corrector(*inputs) + + # save the corrected values + events = set_ak_column(events, f"{self.met_config.met_name}.{out_var}", corr_var, value_type=np.float32) return events @@ -87,8 +164,16 @@ def met_phi_init(self: Calibrator, **kwargs) -> None: """ Initialize the :py:attr:`met_pt_corrector` and :py:attr:`met_phi_corrector` attributes. """ - self.uses.add(f"{self.met_name}.{{pt,phi}}") - self.produces.add(f"{self.met_name}.{{pt,phi}}") + self.met_config = self.get_met_config() + + self.uses.add(f"{self.met_config.met_name}.{{pt,phi}}") + + for variable in self.met_config.variable_config.keys(): + if self.met_config.keep_uncorrected: + self.produces.add(f"{self.met_config.met_name}.{variable}_xy_uncorrected") + for out_var in self.met_config.variable_config[variable]: + # add the produced columns to the uses set + self.produces.add(f"{self.met_config.met_name}.{out_var}") @met_phi.requires @@ -116,30 +201,20 @@ def met_phi_setup( ) -> None: """ Load the correct met files using the :py:func:`from_string` method of the - :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` - function and apply the corrections as needed. + :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` function and apply the corrections as + needed. - :param reqs: Requirement dictionary for this :py:class:`~columnflow.calibration.Calibrator` - instance + :param reqs: Requirement dictionary for this :py:class:`~columnflow.calibration.Calibrator` instance :param inputs: Additional inputs, currently not used. :param reader_targets: Additional targets, currently not used. """ # create the pt and phi correctors met_file = self.get_met_file(reqs["external_files"].files) correction_set = load_correction_set(met_file) - - name_tmpl = self.get_met_config() - self.met_pt_corrector = correction_set[name_tmpl.format( - variable="pt", - data_source=self.dataset_inst.data_source, - )] - self.met_phi_corrector = correction_set[name_tmpl.format( - variable="phi", - data_source=self.dataset_inst.data_source, - )] - - # check versions - if self.met_pt_corrector.version not in (1,): - raise Exception(f"unsuppprted met pt corrector version {self.met_pt_corrector.version}") - if self.met_phi_corrector.version not in (1,): - raise Exception(f"unsuppprted met phi corrector version {self.met_phi_corrector.version}") + name_tmpl = self.met_config.correction_set + self.met_correctors = { + variable: correction_set[name_tmpl.format( + variable=variable, + data_source=self.dataset_inst.data_source, + )] for variable in self.met_config.variable_config.keys() + } From d3689a3f58108a5b534a6894037e95d7df00ca47 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 31 Jul 2025 08:30:24 +0200 Subject: [PATCH 057/123] Update law. --- modules/law | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/law b/modules/law index 679b88941..c8c40094d 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit 679b88941ff32ff4a9724fe44848865daebce027 +Subproject commit c8c40094d53a42849e408e77436319be7f9764c5 From 03efd8020ea534b4f5ecd73db7ad597c4edbdb6a Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Wed, 6 Aug 2025 15:00:11 +0200 Subject: [PATCH 058/123] Generalize normalization weight producer. (#718) * Generalize normalization weight producer. * Add pull warning. * Add per-dataset weight norm. * Update. * Optionally log brs. * Improve combinatoric treatment, fix single br calculation. * Helper to fill weight table. * Minor adjustments before review. --- columnflow/production/normalization.py | 600 +++++++++++++++---------- sandboxes/cf.txt | 2 +- 2 files changed, 363 insertions(+), 239 deletions(-) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index dbe5a7fd9..c70de84bf 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -6,7 +6,10 @@ from __future__ import annotations -from collections import defaultdict +import copy +import itertools +import dataclasses +import collections import law import order as od @@ -15,7 +18,7 @@ from columnflow.production import Producer, producer from columnflow.util import maybe_import, DotDict from columnflow.columnar_util import set_ak_column -from columnflow.types import Any +from columnflow.types import Any, Sequence np = maybe_import("numpy") sp = maybe_import("scipy") @@ -26,150 +29,238 @@ logger = law.logger.get_logger(__name__) -def get_inclusive_dataset(self: Producer) -> od.Dataset: +def get_stitching_datasets(self: Producer) -> tuple[od.Dataset, list[od.Dataset]]: """ - Helper function to obtain the inclusive dataset from a list of datasets that are required to stitch this - *dataset_inst*. + Helper function to obtain information about stitching datasets: + + - the inclusive dataset, which is the dataset that contains all processes + - all datasets that are required to stitch this *dataset_inst* """ - process_map = {d.processes.get_first(): d for d in self.stitching_datasets} + # first collect all datasets that are needed to stitch the current dataset + required_datasets = { + d + for d in self.config_inst.datasets + if ( + d.has_process(self.dataset_inst.processes.get_first(), deep=True) or + self.dataset_inst.has_process(d.processes.get_first(), deep=True) + ) + } + # determine the inclusive dataset + process_map = {d.processes.get_first(): d for d in required_datasets} process_inst = self.dataset_inst.processes.get_first() - incl_dataset = None + inclusive_dataset = None while process_inst: if process_inst in process_map: - incl_dataset = process_map[process_inst] + inclusive_dataset = process_map[process_inst] process_inst = process_inst.parent_processes.get_first(default=None) - - if not incl_dataset: + if not inclusive_dataset: raise Exception("inclusive dataset not found") - - unmatched_processes = {p for p in process_map if not incl_dataset.has_process(p, deep=True)} + # cross-check if there are processes in the required datasets that are not covered by the inclusive dataset + unmatched_processes = {p for p in process_map if not inclusive_dataset.has_process(p, deep=True)} if unmatched_processes: raise Exception(f"processes {unmatched_processes} not found in inclusive dataset") - return incl_dataset + return inclusive_dataset, list(required_datasets) -def get_stitching_datasets(self: Producer) -> list[od.Dataset]: - """ - Helper function to obtain all datasets that are required to stitch this *dataset_inst*. - """ - stitching_datasets = { - d for d in self.config_inst.datasets - if ( - d.has_process(self.dataset_inst.processes.get_first(), deep=True) or - self.dataset_inst.has_process(d.processes.get_first(), deep=True) - ) - } - return list(stitching_datasets) - - -def get_br_from_inclusive_dataset( +def get_br_from_inclusive_datasets( self: Producer, - inclusive_dataset: od.Dataset, - dataset_selection_stats: dict[str, dict[str, float]], -) -> dict[int, float]: + process_insts: Sequence[od.Process] | set[od.Process], + dataset_selection_stats: dict[str, dict[str, float | dict[str, float]]], + merged_selection_stats: dict[str, float | dict[str, float]], + log_brs: bool = False, +) -> dict[od.Process, float]: """ - Helper function to compute the branching ratios from the inclusive sample. This is done with ratios of event weights - isolated per dataset and thus independent of the overall mc weight normalization. + Helper function to compute the branching ratios from sum of weights of inclusive samples. """ - # define helper variables and mapping between process ids and dataset names - proc_ds_map = { - d.processes.get_first().id: d - for d in self.config_inst.datasets - if d.name in dataset_selection_stats.keys() - } - inclusive_proc = inclusive_dataset.processes.get_first() - N = lambda x: sn.Number(x, np.sqrt(x)) # alias for Number with counting error - - # create a dictionary "parent process id" -> {"child process id" -> "branching ratio", ...} - # each ratio is based on gen weight sums - child_brs: dict[int, dict[int, sn.Number]] = defaultdict(dict) - for proc, _, child_procs in inclusive_dataset.walk_processes(): - # the process must be covered by a dataset and should not be a leaf process - if proc.id not in proc_ds_map or proc.is_leaf_process: - continue - dataset_name = proc_ds_map[proc.id].name - - # get the mc weights for the "mother" dataset and add an entry for the process - sum_mc_weight: float = dataset_selection_stats[dataset_name]["sum_mc_weight"] - sum_mc_weight_per_process: dict[str, float] = dataset_selection_stats[dataset_name]["sum_mc_weight_per_process"] - # use the number of events to compute the error on the branching ratio - num_events: int = dataset_selection_stats[dataset_name]["num_events"] - num_events_per_process: dict[str, int] = dataset_selection_stats[dataset_name]["num_events_per_process"] - - # loop over all child processes - for child_proc in child_procs: - # skip processes that are not covered by any dataset or irrelevant for the used dataset - # (identified as leaf processes that have no occurrences in the stats - # or as non-leaf processes that are not in the stitching datasets) - is_leaf = child_proc.is_leaf_process - if ( - (is_leaf and str(child_proc.id) not in sum_mc_weight_per_process) or - (not is_leaf and child_proc.id not in proc_ds_map) - ): + # step 1: per desired process, collect datasets that contain them + process_datasets = collections.defaultdict(set) + for process_inst in process_insts: + for dataset_name, dstats in dataset_selection_stats.items(): + if str(process_inst.id) in dstats["sum_mc_weight_per_process"]: + process_datasets[process_inst].add(self.config_inst.get_dataset(dataset_name)) + + # step 2: per dataset, collect all "lowest level" processes that are contained in them + dataset_processes = collections.defaultdict(set) + for dataset_name in dataset_selection_stats: + dataset_inst = self.config_inst.get_dataset(dataset_name) + dataset_process_inst = dataset_inst.processes.get_first() + for process_inst in process_insts: + if process_inst == dataset_process_inst or dataset_process_inst.has_process(process_inst, deep=True): + dataset_processes[dataset_inst].add(process_inst) + + # step 3: per process, structure the assigned datasets and corresponding processes in DAGs, from more inclusive down + # to more exclusive phase spaces; usually each DAG can contain multiple paths to compute the BR of a single + # process; this is resolved in step 4 + @dataclasses.dataclass + class Node: + process_inst: od.Process + dataset_inst: od.Dataset | None = None + next: set[Node] = dataclasses.field(default_factory=set) + + def __hash__(self) -> int: + return hash((self.process_inst, self.dataset_inst)) + + def str_lines(self) -> list[str]: + lines = [ + f"{self.__class__.__name__}(", + f" process={self.process_inst.name}({self.process_inst.id})", + f" dataset={self.dataset_inst.name if self.dataset_inst else 'None'}", + ] + if self.next: + lines.append(" next={") + for n in self.next: + lines.extend(f" {line}" for line in n.str_lines()) + lines.append(" }") + else: + lines.append(r" next={}") + lines.append(")") + return lines + + def __str__(self) -> str: + return "\n".join(self.str_lines()) + + process_dags = {} + for process_inst, dataset_insts in process_datasets.items(): + # first, per dataset, remember all sub (more exclusive) datasets + # (the O(n^2) is not necessarily optimal, but we are dealing with very small numbers here, thus acceptable) + sub_datasets = {} + for d_incl, d_excl in itertools.permutations(dataset_insts, 2): + if d_incl.processes.get_first().has_process(d_excl.processes.get_first(), deep=True): + sub_datasets.setdefault(d_incl, set()).add(d_excl) + # then, expand to a DAG structure + nodes = {} + excl_nodes = set() + for d_incl, d_excls in sub_datasets.items(): + for d_excl in d_excls: + if d_incl not in nodes: + nodes[d_incl] = Node(d_incl.processes.get_first(), d_incl) + if d_excl not in nodes: + nodes[d_excl] = Node(d_excl.processes.get_first(), d_excl) + nodes[d_incl].next.add(nodes[d_excl]) + excl_nodes.add(nodes[d_excl]) + # mark the root node as the head of the DAG + dag = (set(nodes.values()) - excl_nodes).pop() + # add another node to leaves that only contains the process instance + for node in excl_nodes: + if node.next or node.process_inst == process_inst: continue - - # determine relevant leaf processes that will be summed over - # (since the all stats are only derived for those) - leaf_proc_ids = ( - [child_proc.id] - if is_leaf or str(child_proc.id) in sum_mc_weight_per_process - else [ - p.id for p, _, _ in child_proc.walk_processes() - if str(p.id) in sum_mc_weight_per_process - ] - ) - - # compute the br and its uncertainty using the bare number of events - # NOTE: we assume that the uncertainty is independent of the mc weights, so we can use - # the same relative uncertainty; this is a simplification, but should be fine for most - # cases; we can improve this by switching from jsons to hists when storing sum of weights - leaf_sum = lambda d: sum(d.get(str(proc_id), 0) for proc_id in leaf_proc_ids) - br_nom = leaf_sum(sum_mc_weight_per_process) / sum_mc_weight - br_unc = N(leaf_sum(num_events_per_process)) / N(num_events) - child_brs[proc.id][child_proc.id] = sn.Number( - br_nom, - br_unc(sn.UP, unc=True, factor=True) * 1j, # same relative uncertainty + if process_inst not in nodes: + nodes[process_inst] = Node(process_inst) + node.next.add(nodes[process_inst]) + process_dags[process_inst] = dag + + # step 4: per process, compute the branching ratio for each possible path in the DAG, while keeping track of the + # statistical precision of each combination, evaluated based on the raw number of events; then pick the + # most precise path; again, there should usually be just a single path, but multiple ones are possible when + # datasets have complex overlap + def get_single_br(dataset_inst: od.Dataset, process_inst: od.Process) -> sn.Number | None: + # process_inst might refer to a mid-layer process, so check which lowest-layer processes it is made of + lowest_process_ids = ( + [process_inst.id] + if process_inst in process_insts + else [ + int(process_id_str) + for process_id_str in dataset_selection_stats[dataset_inst.name]["sum_mc_weight_per_process"] + if process_inst.has_process(int(process_id_str), deep=True) + ] + ) + # extract stats + process_sum_weights = sum( + dataset_selection_stats[dataset_inst.name]["sum_mc_weight_per_process"].get(str(process_id), 0.0) + for process_id in lowest_process_ids + ) + dataset_sum_weights = sum(dataset_selection_stats[dataset_inst.name]["sum_mc_weight_per_process"].values()) + process_num_events = sum( + dataset_selection_stats[dataset_inst.name]["num_events_per_process"].get(str(process_id), 0.0) + for process_id in lowest_process_ids + ) + dataset_num_events = sum(dataset_selection_stats[dataset_inst.name]["num_events_per_process"].values()) + # when there are no events, return None + if process_num_events == 0: + logger.warning( + f"found no events for process '{process_inst.name}' ({process_inst.id}) with subprocess ids " + f"'{','.join(map(str, lowest_process_ids))}' in selection stats of dataset {dataset_inst.name}", ) + return None + # compute the ratio of events, assuming correlated poisson counting errors since numbers come from the same + # dataset, then compute the relative uncertainty + num_ratio = ( + sn.Number(process_num_events, process_num_events**0.5) / + sn.Number(dataset_num_events, dataset_num_events**0.5) + ) + rel_unc = num_ratio(sn.UP, unc=True, factor=True) + # compute the branching ratio, using the same relative uncertainty and store using the dataset name to mark its + # limited statistics as the source of uncertainty which is important for consistent error propagation + br = sn.Number(process_sum_weights / dataset_sum_weights, {f"{dataset_inst.name}_stats": rel_unc * 1j}) + return br + + def path_repr(br_path: tuple[sn.Number, ...], dag_path: tuple[Node, ...]) -> str: + return " X ".join( + f"{node.process_inst.name} (br = {br.combine_uncertainties().str(format=3)})" + for br, node in zip(br_path, dag_path) + ) - # define actual per-process branching ratios - branching_ratios: dict[int, float] = {} - - def multiply_branching_ratios(proc_id: int, proc_br: sn.Number) -> None: - """ - Recursively multiply the branching ratios from the nested dictionary. - """ - # when the br for proc_id can be created from sub processes, calculate it via product - if proc_id in child_brs: - for child_id, child_br in child_brs[proc_id].items(): - # multiply the branching ratios assuming no correlation - prod_br = child_br.mul(proc_br, rho=0, inplace=False) - multiply_branching_ratios(child_id, prod_br) - return - - # warn the user if the relative (statistical) error is large - rel_unc = proc_br(sn.UP, unc=True, factor=True) - if rel_unc > 0.05: + process_brs = {} + process_brs_log = {} + for process_inst, dag in process_dags.items(): + brs = [] + queue = collections.deque([(dag, (br := sn.Number(1.0, 0.0)), (br,), (dag,))]) + while queue: + node, br, br_path, dag_path = queue.popleft() + if not node.next: + brs.append((br, br_path, dag_path)) + continue + for sub_node in node.next: + sub_br = get_single_br(node.dataset_inst, sub_node.process_inst) + if sub_br is not None: + queue.append((sub_node, br * sub_br, br_path + (sub_br,), dag_path + (sub_node,))) + # combine all uncertainties + brs = [(br.combine_uncertainties(), *paths) for br, *paths in brs] + # select the most certain one + brs.sort(key=lambda tpl: tpl[0](sn.UP, unc=True, factor=True)) + best_br, best_br_path, best_dag_path = brs[0] + process_brs[process_inst] = best_br.nominal + process_brs_log[process_inst] = (best_br.nominal, best_br(sn.UP, unc=True, factor=True)) # value and % unc + # show a warning in case the relative uncertainty is large + if (rel_unc := best_br(sn.UP, unc=True, factor=True)) > 0.1: logger.warning( - f"large error on the branching ratio for process {inclusive_proc.get_process(proc_id).name} with " - f"process id {proc_id} ({rel_unc * 100:.2f}%)", + f"large error on the branching ratio of {rel_unc * 100:.2f}% for process '{process_inst.name}' " + f"({process_inst.id}), calculated along\n {path_repr(best_br_path, best_dag_path)}", ) + # in case there were multiple values, check their compatibility with the best one and warn if they diverge + for i, (br, br_path, dag_path) in enumerate(brs[1:], 2): + abs_diff = abs(best_br.n - br.n) + rel_diff = abs_diff / best_br.n + pull = abs(best_br.n - br.n) / (best_br.u(direction="up")**2 + br.u(direction="up")**2)**0.5 + if rel_diff > 0.1 and pull > 3: + logger.warning( + f"detected diverging branching ratios between the best and the one on position {i} for process " + f"'{process_inst.name}' (abs_diff={abs_diff:.4f}, rel_diff={rel_diff:.4f}, pull={pull:.2f} ):" + f"\nbest path: {best_br.str(format=3)} from {path_repr(best_br_path, best_dag_path)}" + f"\npath {i} : {br.str(format=3)} from {path_repr(br_path, dag_path)}", + ) - # just store the nominal value - branching_ratios[proc_id] = proc_br.nominal - - # fill all branching ratios - for proc_id, br in child_brs[inclusive_proc.id].items(): - multiply_branching_ratios(proc_id, br) + if log_brs: + from tabulate import tabulate + header = ["process name", "process id", "branching ratio", "uncertainty (%)"] + rows = [ + [ + process_inst.name, process_inst.id, process_brs_log[process_inst][0], + f"{process_brs_log[process_inst][1] * 100:.4f}", + ] + for process_inst in sorted(process_brs_log) + ] + logger.info(f"extracted branching ratios from process occurrence in datasets:\n{tabulate(rows, header)}") - return branching_ratios + return process_brs def update_dataset_selection_stats( self: Producer, - dataset_selection_stats: dict[str, dict[str, float]], -) -> dict[str, dict[str, float]]: + dataset_selection_stats: dict[str, dict[str, float | dict[str, float]]], +) -> dict[str, dict[str, float | dict[str, float]]]: """ Hook to optionally update the per-dataset selection stats. """ @@ -182,13 +273,16 @@ def update_dataset_selection_stats( weight_name="normalization_weight", # which luminosity to apply, uses the value stored in the config when None luminosity=None, + # whether to normalize weights per dataset to the mean weight first (to cancel out numeric differences) + normalize_weights_per_dataset=True, # whether to allow stitching datasets allow_stitching=False, - get_xsecs_from_inclusive_dataset=False, + get_xsecs_from_inclusive_datasets=False, get_stitching_datasets=get_stitching_datasets, - get_inclusive_dataset=get_inclusive_dataset, - get_br_from_inclusive_dataset=get_br_from_inclusive_dataset, + get_br_from_inclusive_datasets=get_br_from_inclusive_datasets, update_dataset_selection_stats=update_dataset_selection_stats, + update_dataset_selection_stats_br=None, + update_dataset_selection_stats_sum_weights=None, # only run on mc mc_only=True, ) @@ -217,33 +311,54 @@ def normalization_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Arra process_id = np.asarray(events.process_id) # ensure all ids were assigned a cross section - unique_process_ids = set(process_id) - invalid_ids = unique_process_ids - self.xs_process_ids + unique_process_ids = set(np.unique(process_id)) + invalid_ids = unique_process_ids - self.known_process_ids if invalid_ids: + invalid_names = [ + f"{self.config_inst.get_process(proc_id).name} ({proc_id})" + for proc_id in invalid_ids + ] raise Exception( - f"process_id field contains id(s) {invalid_ids} for which no cross sections were found; process ids with " - f"cross sections: {self.xs_process_ids}", + f"process_id field contains entries {', '.join(invalid_names)} for which no cross sections were found; " + f"process ids with cross sections: {self.known_process_ids}", ) # read the weight per process (defined as lumi * xsec / sum_weights) from the lookup table - process_weight = np.squeeze(np.asarray(self.process_weight_table[0, process_id].todense())) + process_weight = np.squeeze(np.asarray(self.process_weight_table[process_id, 0].todense())) # compute the weight and store it norm_weight = events.mc_weight * process_weight events = set_ak_column(events, self.weight_name, norm_weight, value_type=np.float32) - # if we are stitching, we also compute the inclusive weight for debugging purposes - if ( - self.allow_stitching and - self.get_xsecs_from_inclusive_dataset and - self.dataset_inst == self.inclusive_dataset - ): + # when stitching, also compute the inclusive-only weight + if self.allow_stitching and self.dataset_inst == self.inclusive_dataset: incl_norm_weight = events.mc_weight * self.inclusive_weight events = set_ak_column(events, self.weight_name_incl, incl_norm_weight, value_type=np.float32) return events +@normalization_weights.init +def normalization_weights_init(self: Producer, **kwargs) -> None: + """ + Initializes the normalization weights producer by setting up the normalization weight column. + """ + # declare the weight name to be a produced column + self.produces.add(self.weight_name) + + # when stitching is enabled, store specific information + if self.allow_stitching: + # remember the inclusive dataset and all datasets needed to determine the weights of processes in _this_ dataset + self.inclusive_dataset, self.required_datasets = self.get_stitching_datasets() + + # potentially also store the weight needed for only using the inclusive dataset + if self.dataset_inst == self.inclusive_dataset: + self.weight_name_incl = f"{self.weight_name}_inclusive" + self.produces.add(self.weight_name_incl) + else: + self.required_datasets = [self.dataset_inst] + + @normalization_weights.requires def normalization_weights_requires( self: Producer, @@ -255,7 +370,7 @@ def normalization_weights_requires( Adds the requirements needed by the underlying py:attr:`task` to access selection stats into *reqs*. """ # check that all datasets are known - for dataset in self.stitching_datasets: + for dataset in self.required_datasets: if not self.config_inst.has_dataset(dataset): raise Exception(f"unknown dataset '{dataset}' required for normalization weights computation") @@ -266,7 +381,7 @@ def normalization_weights_requires( dataset=dataset.name, branch=-1 if task.is_workflow() else 0, ) - for dataset in self.stitching_datasets + for dataset in self.required_datasets } return reqs @@ -287,144 +402,153 @@ def normalization_weights_setup( - py: attr: `process_weight_table`: A sparse array serving as a lookup table for the calculated process weights. This weight is defined as the product of the luminosity, the cross section, divided by the sum of event weights per process. + - py: attr: `known_process_ids`: A set of all process ids that are known by the lookup table. """ # load the selection stats dataset_selection_stats = { - dataset: task.cached_value( + dataset: copy.deepcopy(task.cached_value( key=f"selection_stats_{dataset}", func=lambda: inp["stats"].load(formatter="json"), - ) + )) for dataset, inp in inputs["selection_stats"].items() } - # optional hook to amend the per-dataset selection stats - if callable(self.update_dataset_selection_stats): - dataset_selection_stats = self.update_dataset_selection_stats(dataset_selection_stats) - - # if necessary, merge the selection stats across datasets - if len(dataset_selection_stats) > 1: - from columnflow.tasks.selection import MergeSelectionStats - merged_selection_stats = defaultdict(float) - for stats in dataset_selection_stats.values(): - MergeSelectionStats.merge_counts(merged_selection_stats, stats) - else: - merged_selection_stats = dataset_selection_stats[self.dataset_inst.name] + # optionally normalize weights per dataset to their mean, to potentially align different numeric domains + norm_factor = 1.0 + if self.normalize_weights_per_dataset: + for dataset, stats in dataset_selection_stats.items(): + dataset_mean_weight = ( + sum(stats["sum_mc_weight_per_process"].values()) / + sum(stats["num_events_per_process"].values()) + ) + for process_id_str in stats["sum_mc_weight_per_process"]: + stats["sum_mc_weight_per_process"][process_id_str] /= dataset_mean_weight + if dataset == self.dataset_inst.name: + norm_factor = 1.0 / dataset_mean_weight - # determine all proceses at any depth in the stitching datasets - process_insts = { - process_inst - for dataset_inst in self.stitching_datasets - for process_inst, _, _ in dataset_inst.walk_processes() + # drop unused stats + dataset_selection_stats = { + dataset: {field: stats[field] for field in ["num_events_per_process", "sum_mc_weight_per_process"]} + for dataset, stats in dataset_selection_stats.items() } - # determine ids of processes that were identified in the selection stats - allowed_ids = set(map(int, merged_selection_stats["sum_mc_weight_per_process"])) - - # complain if there are processes seen/id'ed during selection that are not part of the datasets - unknown_process_ids = allowed_ids - {p.id for p in process_insts} - if unknown_process_ids: + # separately treat stats for extracting BRs and sum of mc weights + def extract_stats(*update_funcs): + # create copy + stats = copy.deepcopy(dataset_selection_stats) + # update through one of the functions + for update_func in update_funcs: + if callable(update_func): + stats = update_func(stats) + break + # merge + if len(stats) > 1: + from columnflow.tasks.selection import MergeSelectionStats + merged_stats = collections.defaultdict(float) + for _stats in stats.values(): + MergeSelectionStats.merge_counts(merged_stats, _stats) + else: + merged_stats = stats[self.dataset_inst.name] + return stats, merged_stats + + dataset_selection_stats_br, merged_selection_stats_br = extract_stats( + self.update_dataset_selection_stats_br, + self.update_dataset_selection_stats, + ) + _, merged_selection_stats_sum_weights = extract_stats( + self.update_dataset_selection_stats_sum_weights, + self.update_dataset_selection_stats, + ) + + # get all process ids and instances seen and assigned during selection of this dataset + # (i.e., all possible processes that might be encountered during event processing) + process_ids = set(map(int, dataset_selection_stats_br[self.dataset_inst.name]["sum_mc_weight_per_process"])) + process_insts = set(map(self.config_inst.get_process, process_ids)) + + # consistency check: when the main process of the current dataset is part of these "lowest level" processes, + # there should only be this single process, otherwise the manual (sub) process assignment does not match the + # general dataset -> main process info + if self.dataset_inst.processes.get_first() in process_insts and len(process_insts) > 1: raise Exception( - f"selection stats contain ids of processes that were not previously registered to the config " - f"'{self.config_inst.name}': {', '.join(map(str, unknown_process_ids))}", + f"dataset '{self.dataset_inst.name}' has main process '{self.dataset_inst.processes.get_first().name}' " + "assigned to it (likely as per cmsdb), but the dataset selection stats for this dataset contain multiple " + "sub processes, which is likely a misconfiguration of the manual sub process assignment upstream; found " + f"sub processes: {', '.join(f'{process_inst.name} ({process_inst.id})' for process_inst in process_insts)}", ) - # likewise, drop processes that were not seen during selection - process_insts = {p for p in process_insts if p.id in allowed_ids} - max_id = max(process_inst.id for process_inst in process_insts) + # setup the event weight lookup table + process_weight_table = sp.sparse.lil_matrix((max(process_ids) + 1, 1), dtype=np.float32) - # get the luminosity - lumi = self.config_inst.x.luminosity if self.luminosity is None else self.luminosity - lumi = lumi.nominal if isinstance(lumi, sn.Number) else float(lumi) - - # create a event weight lookup table - process_weight_table = sp.sparse.lil_matrix((1, max_id + 1), dtype=np.float32) - if self.allow_stitching and self.get_xsecs_from_inclusive_dataset: - inclusive_dataset = self.inclusive_dataset - logger.debug(f"using inclusive dataset {inclusive_dataset.name} for cross section lookup") - - # extract branching ratios from inclusive dataset(s) - inclusive_proc = inclusive_dataset.processes.get_first() - if self.dataset_inst == inclusive_dataset and process_insts == {inclusive_proc}: - branching_ratios = {inclusive_proc.id: 1.0} - else: - branching_ratios = self.get_br_from_inclusive_dataset( - inclusive_dataset=inclusive_dataset, - dataset_selection_stats=dataset_selection_stats, + def fill_weight_table(process_inst: od.Process, xsec: float, sum_weights: float) -> None: + if sum_weights == 0: + logger.warning( + f"zero sum of weights found for computing normalization weight for process '{process_inst.name}' " + f"({process_inst.id}) in dataset '{self.dataset_inst.name}', going to use weight of 0.0", ) - if not branching_ratios: - raise Exception( - f"no branching ratios could be computed based on the inclusive dataset {inclusive_dataset}", - ) + weight = 0.0 + else: + weight = norm_factor * xsec * lumi / sum_weights + process_weight_table[process_inst.id, 0] = weight + + # get the luminosity + lumi = float(self.config_inst.x.luminosity if self.luminosity is None else self.luminosity) - # compute the weight the inclusive dataset would have on its own without stitching - inclusive_xsec = inclusive_proc.get_xsec(self.config_inst.campaign.ecm).nominal - self.inclusive_weight = ( - lumi * inclusive_xsec / dataset_selection_stats[inclusive_dataset.name]["sum_mc_weight"] - if self.dataset_inst == inclusive_dataset - else 0 + # prepare info for the inclusive dataset + inclusive_proc = self.inclusive_dataset.processes.get_first() + inclusive_xsec = inclusive_proc.get_xsec(self.config_inst.campaign.ecm).nominal + + # compute the weight the inclusive dataset would have on its own without stitching + if self.allow_stitching and self.dataset_inst == self.inclusive_dataset: + inclusive_sum_weights = sum( + dataset_selection_stats[self.inclusive_dataset.name]["sum_mc_weight_per_process"].values(), + ) + self.inclusive_weight = norm_factor * inclusive_xsec * lumi / inclusive_sum_weights + + # fill weights into the lut, depending on whether stitching is allowed / needed or not + do_stitch = ( + self.allow_stitching and + self.get_xsecs_from_inclusive_datasets and + len(self.required_datasets) > 1 + ) + if do_stitch: + logger.debug( + f"using inclusive dataset '{self.inclusive_dataset.name}' and process '{inclusive_proc.name}' for cross " + "section lookup", + ) + + # extract branching ratios + branching_ratios = self.get_br_from_inclusive_datasets( + process_insts, + dataset_selection_stats_br, + merged_selection_stats_br, ) # fill the process weight table - for proc_id, br in branching_ratios.items(): - sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(proc_id)] - process_weight_table[0, proc_id] = lumi * inclusive_xsec * br / sum_weights - - # fill in cross sections of missing leaf processes - missing_proc_ids = set(proc.id for proc in inclusive_proc.get_leaf_processes()) - set(branching_ratios.keys()) - for proc_id in missing_proc_ids: - process_inst = inclusive_proc.get_process(proc_id) - if ( - self.config_inst.campaign.ecm in process_inst.xsecs and - str(proc_id) in merged_selection_stats["sum_mc_weight_per_process"] - ): - xsec = process_inst.get_xsec(self.config_inst.campaign.ecm).nominal - sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(proc_id)] - process_weight_table[0, process_inst.id] = lumi * xsec / sum_weights - logger.warning( - f"added cross section for missing leaf process {process_inst.name} ({proc_id}) from xsec entry", - ) + for process_inst, br in branching_ratios.items(): + sum_weights = merged_selection_stats_sum_weights["sum_mc_weight_per_process"][str(process_inst.id)] + fill_weight_table(process_inst, br * inclusive_xsec, sum_weights) else: # fill the process weight table with per-process cross sections for process_inst in process_insts: - if self.config_inst.campaign.ecm not in process_inst.xsecs.keys(): + if self.config_inst.campaign.ecm not in process_inst.xsecs: raise KeyError( f"no cross section registered for process {process_inst} for center-of-mass energy of " f"{self.config_inst.campaign.ecm}", ) - sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(process_inst.id)] xsec = process_inst.get_xsec(self.config_inst.campaign.ecm).nominal - process_weight_table[0, process_inst.id] = lumi * xsec / sum_weights + sum_weights = merged_selection_stats_sum_weights["sum_mc_weight_per_process"][str(process_inst.id)] + fill_weight_table(process_inst, xsec, sum_weights) + # store lookup table and known process ids self.process_weight_table = process_weight_table - self.xs_process_ids = set(self.process_weight_table.rows[0]) - - -@normalization_weights.init -def normalization_weights_init(self: Producer, **kwargs) -> None: - """ - Initializes the normalization weights producer by setting up the normalization weight column. - """ - self.produces.add(self.weight_name) - if self.allow_stitching: - self.stitching_datasets = self.get_stitching_datasets() - self.inclusive_dataset = self.get_inclusive_dataset() - else: - self.stitching_datasets = [self.dataset_inst] - - if ( - self.allow_stitching and - self.get_xsecs_from_inclusive_dataset and - self.dataset_inst == self.inclusive_dataset - ): - self.weight_name_incl = f"{self.weight_name}_inclusive" - self.produces.add(self.weight_name_incl) + self.known_process_ids = process_ids stitched_normalization_weights = normalization_weights.derive( "stitched_normalization_weights", cls_dict={ "weight_name": "normalization_weight", - "get_xsecs_from_inclusive_dataset": True, + "get_xsecs_from_inclusive_datasets": True, "allow_stitching": True, }, ) @@ -432,6 +556,6 @@ def normalization_weights_init(self: Producer, **kwargs) -> None: stitched_normalization_weights_brs_from_processes = stitched_normalization_weights.derive( "stitched_normalization_weights_brs_from_processes", cls_dict={ - "get_xsecs_from_inclusive_dataset": False, + "get_xsecs_from_inclusive_datasets": False, }, ) diff --git a/sandboxes/cf.txt b/sandboxes/cf.txt index 46861d8c1..91e7306d0 100644 --- a/sandboxes/cf.txt +++ b/sandboxes/cf.txt @@ -1,7 +1,7 @@ # version 14 luigi~=3.6.0 -scinum~=2.2.0 +scinum~=2.2.1 six~=1.17.0 pyyaml~=6.0.2 typing_extensions~=4.13.0 From 0619ee33e560f000f742d3d8582bdf8c1e52333d Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Fri, 8 Aug 2025 08:15:47 +0200 Subject: [PATCH 059/123] Hotfix inclusive dataset lookup in norm producer. --- columnflow/production/normalization.py | 49 ++++++++++++++++---------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index c70de84bf..dbc529161 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -29,7 +29,7 @@ logger = law.logger.get_logger(__name__) -def get_stitching_datasets(self: Producer) -> tuple[od.Dataset, list[od.Dataset]]: +def get_stitching_datasets(self: Producer, debug: bool = False) -> tuple[od.Dataset, list[od.Dataset]]: """ Helper function to obtain information about stitching datasets: @@ -47,19 +47,27 @@ def get_stitching_datasets(self: Producer) -> tuple[od.Dataset, list[od.Dataset] } # determine the inclusive dataset - process_map = {d.processes.get_first(): d for d in required_datasets} - process_inst = self.dataset_inst.processes.get_first() inclusive_dataset = None - while process_inst: - if process_inst in process_map: - inclusive_dataset = process_map[process_inst] - process_inst = process_inst.parent_processes.get_first(default=None) + for dataset_inst in required_datasets: + for other_dataset_inst in required_datasets: + if dataset_inst == other_dataset_inst: + continue + # check if the other dataset is a sub-dataset of the current one by comparing their leading process + if not dataset_inst.has_process(other_dataset_inst.processes.get_first(), deep=True): + break + else: + # if we did not break, the dataset is the inclusive one + inclusive_dataset = dataset_inst + break if not inclusive_dataset: raise Exception("inclusive dataset not found") - # cross-check if there are processes in the required datasets that are not covered by the inclusive dataset - unmatched_processes = {p for p in process_map if not inclusive_dataset.has_process(p, deep=True)} - if unmatched_processes: - raise Exception(f"processes {unmatched_processes} not found in inclusive dataset") + + if debug: + logger.info( + f"determined info for stitching content of dataset '{self.dataset_inst.name}':\n" + f" - inclusive dataset: {inclusive_dataset.name}\n" + f" - required datasets: {', '.join(d.name for d in required_datasets)}", + ) return inclusive_dataset, list(required_datasets) @@ -69,7 +77,7 @@ def get_br_from_inclusive_datasets( process_insts: Sequence[od.Process] | set[od.Process], dataset_selection_stats: dict[str, dict[str, float | dict[str, float]]], merged_selection_stats: dict[str, float | dict[str, float]], - log_brs: bool = False, + debug: bool = False, ) -> dict[od.Process, float]: """ Helper function to compute the branching ratios from sum of weights of inclusive samples. @@ -203,7 +211,7 @@ def path_repr(br_path: tuple[sn.Number, ...], dag_path: tuple[Node, ...]) -> str ) process_brs = {} - process_brs_log = {} + process_brs_debug = {} for process_inst, dag in process_dags.items(): brs = [] queue = collections.deque([(dag, (br := sn.Number(1.0, 0.0)), (br,), (dag,))]) @@ -222,7 +230,7 @@ def path_repr(br_path: tuple[sn.Number, ...], dag_path: tuple[Node, ...]) -> str brs.sort(key=lambda tpl: tpl[0](sn.UP, unc=True, factor=True)) best_br, best_br_path, best_dag_path = brs[0] process_brs[process_inst] = best_br.nominal - process_brs_log[process_inst] = (best_br.nominal, best_br(sn.UP, unc=True, factor=True)) # value and % unc + process_brs_debug[process_inst] = (best_br.nominal, best_br(sn.UP, unc=True, factor=True)) # value and % unc # show a warning in case the relative uncertainty is large if (rel_unc := best_br(sn.UP, unc=True, factor=True)) > 0.1: logger.warning( @@ -242,15 +250,15 @@ def path_repr(br_path: tuple[sn.Number, ...], dag_path: tuple[Node, ...]) -> str f"\npath {i} : {br.str(format=3)} from {path_repr(br_path, dag_path)}", ) - if log_brs: + if debug: from tabulate import tabulate header = ["process name", "process id", "branching ratio", "uncertainty (%)"] rows = [ [ - process_inst.name, process_inst.id, process_brs_log[process_inst][0], - f"{process_brs_log[process_inst][1] * 100:.4f}", + process_inst.name, process_inst.id, process_brs_debug[process_inst][0], + f"{process_brs_debug[process_inst][1] * 100:.4f}", ] - for process_inst in sorted(process_brs_log) + for process_inst in sorted(process_brs_debug) ] logger.info(f"extracted branching ratios from process occurrence in datasets:\n{tabulate(rows, header)}") @@ -404,6 +412,10 @@ def normalization_weights_setup( weights per process. - py: attr: `known_process_ids`: A set of all process ids that are known by the lookup table. """ + # optionally run the dataset lookup again in debug mode + if getattr(task, "branch", None) == 0: + self.get_stitching_datasets(debug=True) + # load the selection stats dataset_selection_stats = { dataset: copy.deepcopy(task.cached_value( @@ -521,6 +533,7 @@ def fill_weight_table(process_inst: od.Process, xsec: float, sum_weights: float) process_insts, dataset_selection_stats_br, merged_selection_stats_br, + debug=getattr(task, "branch", None) == 0, ) # fill the process weight table From 13d65231ce9a408d7106822a3380d73052fafe8c Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Fri, 8 Aug 2025 11:17:55 +0200 Subject: [PATCH 060/123] Update boost-histogram version. --- sandboxes/columnar.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/sandboxes/columnar.txt b/sandboxes/columnar.txt index 8d6293ecc..e007dd1d7 100644 --- a/sandboxes/columnar.txt +++ b/sandboxes/columnar.txt @@ -7,6 +7,7 @@ pyarrow==19.0.1 dask-awkward==2025.3.0 correctionlib==2.6.4 coffea==2024.11.0 +boost-histogram==1.6.0 # minimum versions for general packages zstandard~=0.23.0 From c586eb973afede0f86f718f43c9f2fe3d4d2c8b7 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 11 Aug 2025 18:25:04 +0200 Subject: [PATCH 061/123] Hotfix norm weight logging. --- columnflow/production/normalization.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index dbc529161..c97e13a2c 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -412,8 +412,14 @@ def normalization_weights_setup( weights per process. - py: attr: `known_process_ids`: A set of all process ids that are known by the lookup table. """ - # optionally run the dataset lookup again in debug mode - if getattr(task, "branch", None) == 0: + # optionally run the dataset lookup again in debug mode when stitching + do_stitch = ( + self.allow_stitching and + self.get_xsecs_from_inclusive_datasets and + len(self.required_datasets) > 1 + ) + is_first_branch = getattr(task, "branch", None) == 0 + if do_stitch and is_first_branch: self.get_stitching_datasets(debug=True) # load the selection stats @@ -517,11 +523,6 @@ def fill_weight_table(process_inst: od.Process, xsec: float, sum_weights: float) self.inclusive_weight = norm_factor * inclusive_xsec * lumi / inclusive_sum_weights # fill weights into the lut, depending on whether stitching is allowed / needed or not - do_stitch = ( - self.allow_stitching and - self.get_xsecs_from_inclusive_datasets and - len(self.required_datasets) > 1 - ) if do_stitch: logger.debug( f"using inclusive dataset '{self.inclusive_dataset.name}' and process '{inclusive_proc.name}' for cross " @@ -533,7 +534,7 @@ def fill_weight_table(process_inst: od.Process, xsec: float, sum_weights: float) process_insts, dataset_selection_stats_br, merged_selection_stats_br, - debug=getattr(task, "branch", None) == 0, + debug=is_first_branch, ) # fill the process weight table From 2c77fb13a2c2cdb114096ed12346f00e1108e7ff Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Tue, 12 Aug 2025 13:54:57 +0200 Subject: [PATCH 062/123] Revert boost-histogram update. --- sandboxes/cf.txt | 2 +- sandboxes/columnar.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sandboxes/cf.txt b/sandboxes/cf.txt index 91e7306d0..9c1585164 100644 --- a/sandboxes/cf.txt +++ b/sandboxes/cf.txt @@ -1,4 +1,4 @@ -# version 14 +# version 15 luigi~=3.6.0 scinum~=2.2.1 diff --git a/sandboxes/columnar.txt b/sandboxes/columnar.txt index e007dd1d7..8d6293ecc 100644 --- a/sandboxes/columnar.txt +++ b/sandboxes/columnar.txt @@ -7,7 +7,6 @@ pyarrow==19.0.1 dask-awkward==2025.3.0 correctionlib==2.6.4 coffea==2024.11.0 -boost-histogram==1.6.0 # minimum versions for general packages zstandard~=0.23.0 From 650b4f6ec487e2e13aceabf0997af263c30ec6f5 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 13 Aug 2025 18:27:48 +0200 Subject: [PATCH 063/123] Hotfix combined jets calibrator. --- columnflow/calibration/cms/jets.py | 45 +++++++++++++++--------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index 5ad98852f..964143c6c 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -12,7 +12,7 @@ from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import ak_random, propagate_met, sum_transverse from columnflow.production.util import attach_coffea_behavior -from columnflow.util import maybe_import, DotDict, load_correction_set +from columnflow.util import UNSET, maybe_import, DotDict, load_correction_set from columnflow.columnar_util import set_ak_column, layout_ak_array, optional_column as optional np = maybe_import("numpy") @@ -1129,8 +1129,6 @@ def deterministic_normal(loc, scale, seed): # @calibrator( - uses={jec, jer}, - produces={jec, jer}, # name of the jet collection to smear jet_name="Jet", # name of the associated gen jet collection (for JER smearing) @@ -1151,32 +1149,35 @@ def jets(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: :param events: awkward array containing events to process """ # apply jet energy corrections - events = self[jec](events, **kwargs) + events = self[self.jec_cls](events, **kwargs) # apply jer smearing on MC only if self.dataset_inst.is_mc: - events = self[jer](events, **kwargs) + events = self[self.jer_cls](events, **kwargs) return events -@jets.pre_init -def jets_pre_init(self: Calibrator, **kwargs) -> None: - # forward argument to the producers - self.deps_kwargs[jec]["jet_name"] = self.jet_name - self.deps_kwargs[jer]["jet_name"] = self.jet_name - self.deps_kwargs[jer]["gen_jet_name"] = self.gen_jet_name - if self.propagate_met is not None: - self.deps_kwargs[jec]["propagate_met"] = self.propagate_met - self.deps_kwargs[jer]["propagate_met"] = self.propagate_met - if self.get_jec_file is not None: - self.deps_kwargs[jec]["get_jec_file"] = self.get_jec_file - if self.get_jec_config is not None: - self.deps_kwargs[jec]["get_jec_config"] = self.get_jec_config - if self.get_jer_file is not None: - self.deps_kwargs[jer]["get_jer_file"] = self.get_jer_file - if self.get_jer_config is not None: - self.deps_kwargs[jer]["get_jer_config"] = self.get_jer_config +@jets.init +def jets_init(self: Calibrator, **kwargs) -> None: + # create custom jec and jer calibrators, using the jet name as the identifying value + def get_attrs(attrs): + cls_dict = {} + for attr in attrs: + if (value := getattr(self, attr, UNSET)) is not UNSET: + cls_dict[attr] = value + return cls_dict + + jec_attrs = ["jet_name", "gen_jet_name", "propagate_met", "get_jec_file", "get_jec_config"] + self.jec_cls = jec.derive(f"jec_{self.jet_name}", cls_dict=get_attrs(jec_attrs)) + self.uses.add(self.jec_cls) + self.produces.add(self.jec_cls) + + if self.dataset_inst.is_mc: + jer_attrs = ["jet_name", "gen_jet_name", "propagate_met", "get_jer_file", "get_jer_config"] + self.jer_cls = jer.derive(f"jer_{self.jet_name}", cls_dict=get_attrs(jer_attrs)) + self.uses.add(self.jer_cls) + self.produces.add(self.jer_cls) # explicit calibrators for standard jet collections From 433803d67c245445d264553ce59739ce85456406 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 14 Aug 2025 14:25:10 +0200 Subject: [PATCH 064/123] Hotfix inclusive dataset attribute in norm weight producer. --- columnflow/production/normalization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index c97e13a2c..9b9e4be5a 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -364,6 +364,7 @@ def normalization_weights_init(self: Producer, **kwargs) -> None: self.weight_name_incl = f"{self.weight_name}_inclusive" self.produces.add(self.weight_name_incl) else: + self.inclusive_dataset = [self.dataset_inst] self.required_datasets = [self.dataset_inst] From 9be20725c5d2bf3d2b8559a575c6b0517eddad15 Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Fri, 15 Aug 2025 11:13:23 +0200 Subject: [PATCH 065/123] make plotting faster using non-interactive backend --- columnflow/plotting/plot_all.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/columnflow/plotting/plot_all.py b/columnflow/plotting/plot_all.py index 14ab9ab55..73407b736 100644 --- a/columnflow/plotting/plot_all.py +++ b/columnflow/plotting/plot_all.py @@ -345,6 +345,9 @@ def plot_all( # general mplhep style plt.style.use(mplhep.style.CMS) + # use non-interactive Agg backend for plotting + mpl.use("Agg") + # setup figure and axes rax = None grid_spec = {"left": 0.15, "right": 0.95, "top": 0.95, "bottom": 0.1} From 433ba76d58a435ffa9d2d61002105b9a0377404f Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 1 Sep 2025 15:57:51 +0200 Subject: [PATCH 066/123] [cms] Make datacard writer class configurable in task. --- columnflow/tasks/cms/inference.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index 493125cf7..6414a23ba 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -13,11 +13,15 @@ from columnflow.tasks.framework.inference import SerializeInferenceModelBase from columnflow.tasks.histograms import MergeHistograms +from columnflow.inference.cms.datacard import DatacardHists, ShiftHists, DatacardWriter + class CreateDatacards(SerializeInferenceModelBase): resolution_task_cls = MergeHistograms + datacard_writer_cls = DatacardWriter + def output(self): hooks_repr = self.hist_hooks_repr cat_obj = self.branch_data @@ -39,7 +43,6 @@ def basename(name: str, ext: str) -> str: @law.decorator.safe_output def run(self): import hist - from columnflow.inference.cms.datacard import DatacardHists, ShiftHists, DatacardWriter # prepare inputs inputs = self.input() @@ -129,7 +132,7 @@ def run(self): # forward objects to the datacard writer outputs = self.output() - writer = DatacardWriter(self.inference_model_inst, datacard_hists) + writer = self.datacard_writer_cls(self.inference_model_inst, datacard_hists) with outputs["card"].localize("w") as tmp_card, outputs["shapes"].localize("w") as tmp_shapes: writer.write(tmp_card.abspath, tmp_shapes.abspath, shapes_path_ref=outputs["shapes"].basename) From 36eb11ae0b8dc0fe7cbf598a979b9bd74feb3609 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 1 Sep 2025 15:58:31 +0200 Subject: [PATCH 067/123] Hotfix typo in norm weight producer. --- columnflow/production/normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 9b9e4be5a..9fad0561a 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -364,7 +364,7 @@ def normalization_weights_init(self: Producer, **kwargs) -> None: self.weight_name_incl = f"{self.weight_name}_inclusive" self.produces.add(self.weight_name_incl) else: - self.inclusive_dataset = [self.dataset_inst] + self.inclusive_dataset = self.dataset_inst self.required_datasets = [self.dataset_inst] From e170f92a82a10d7b6e54a74b1c9df9b0b34e901a Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 1 Sep 2025 17:01:49 +0200 Subject: [PATCH 068/123] Adapt norm weight producer to more generic cases. --- columnflow/production/normalization.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 9fad0561a..4e4c3f70e 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -413,16 +413,6 @@ def normalization_weights_setup( weights per process. - py: attr: `known_process_ids`: A set of all process ids that are known by the lookup table. """ - # optionally run the dataset lookup again in debug mode when stitching - do_stitch = ( - self.allow_stitching and - self.get_xsecs_from_inclusive_datasets and - len(self.required_datasets) > 1 - ) - is_first_branch = getattr(task, "branch", None) == 0 - if do_stitch and is_first_branch: - self.get_stitching_datasets(debug=True) - # load the selection stats dataset_selection_stats = { dataset: copy.deepcopy(task.cached_value( @@ -524,12 +514,22 @@ def fill_weight_table(process_inst: od.Process, xsec: float, sum_weights: float) self.inclusive_weight = norm_factor * inclusive_xsec * lumi / inclusive_sum_weights # fill weights into the lut, depending on whether stitching is allowed / needed or not + do_stitch = ( + self.allow_stitching and + self.get_xsecs_from_inclusive_datasets and + (len(process_insts) > 1 or len(self.required_datasets) > 1) + ) if do_stitch: logger.debug( f"using inclusive dataset '{self.inclusive_dataset.name}' and process '{inclusive_proc.name}' for cross " "section lookup", ) + # optionally run the dataset lookup again in debug mode when stitching + is_first_branch = getattr(task, "branch", None) == 0 + if is_first_branch: + self.get_stitching_datasets(debug=True) + # extract branching ratios branching_ratios = self.get_br_from_inclusive_datasets( process_insts, From 23e992daaaa76878ca5a6d0915900b9eb127cd93 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Tue, 2 Sep 2025 16:03:22 +0200 Subject: [PATCH 069/123] Hotfix cf_inspect for root files. --- bin/cf_inspect.py | 21 ++++++++++++++++++--- columnflow/tasks/union.py | 3 +-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/bin/cf_inspect.py b/bin/cf_inspect.py index 5c2ecdc4a..66b47ba01 100644 --- a/bin/cf_inspect.py +++ b/bin/cf_inspect.py @@ -13,8 +13,6 @@ import pickle import awkward as ak -import coffea.nanoevents -import uproot from columnflow.util import ipython_shell from columnflow.types import Any @@ -34,10 +32,27 @@ def _load_parquet(fname: str) -> ak.Array: return ak.from_parquet(fname) -def _load_nano_root(fname: str) -> ak.Array: +def _load_nano_root(fname: str, treepath: str | None = None) -> ak.Array: + import uproot + import coffea.nanoevents + source = uproot.open(fname) + + # get the default treepath + if treepath is None: + for treepath in source.keys(): + treepath = treepath.split(";", 1)[0] + obj = source[treepath] + if isinstance(obj, uproot.TTree): + break + else: + print(f"no default treepath determined in {fname}") + treepath = None + return coffea.nanoevents.NanoEventsFactory.from_root( source, + treepath=treepath, + delayed=False, runtime_cache=None, persistent_cache=None, ).events() diff --git a/columnflow/tasks/union.py b/columnflow/tasks/union.py index 5b52d22b3..f987e7594 100644 --- a/columnflow/tasks/union.py +++ b/columnflow/tasks/union.py @@ -105,8 +105,7 @@ def output(self): @law.decorator.safe_output def run(self): from columnflow.columnar_util import ( - Route, RouteFilter, mandatory_coffea_columns, update_ak_array, sorted_ak_to_parquet, - sorted_ak_to_root, + Route, RouteFilter, mandatory_coffea_columns, update_ak_array, sorted_ak_to_parquet, sorted_ak_to_root, ) # prepare inputs and outputs From 4407e52263f3d9fc11d47a865727e14177e2986a Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Tue, 2 Sep 2025 16:12:11 +0200 Subject: [PATCH 070/123] Optimize UniteColumns compression for root files. --- columnflow/tasks/union.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/tasks/union.py b/columnflow/tasks/union.py index f987e7594..2797537c2 100644 --- a/columnflow/tasks/union.py +++ b/columnflow/tasks/union.py @@ -173,7 +173,7 @@ def run(self): self, sorted_chunks, output["events"], local=True, writer_opts=self.get_parquet_writer_opts(), ) else: # root - law.root.hadd_task(self, sorted_chunks, output["events"], local=True) + law.root.hadd_task(self, sorted_chunks, output["events"], local=True, hadd_args=["-O", "-f501"]) # overwrite class defaults From 17b3968586c4e50b630cedacc6dd020ea21a4ac5 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 4 Sep 2025 15:11:56 +0200 Subject: [PATCH 071/123] Improve treepath detection in cf_insepct. --- bin/cf_inspect.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/bin/cf_inspect.py b/bin/cf_inspect.py index 66b47ba01..9db4d8045 100644 --- a/bin/cf_inspect.py +++ b/bin/cf_inspect.py @@ -18,21 +18,21 @@ from columnflow.types import Any -def _load_json(fname: str) -> Any: +def _load_json(fname: str, **kwargs) -> Any: with open(fname, "r") as fobj: return json.load(fobj) -def _load_pickle(fname: str) -> Any: +def _load_pickle(fname: str, **kwargs) -> Any: with open(fname, "rb") as fobj: return pickle.load(fobj) -def _load_parquet(fname: str) -> ak.Array: +def _load_parquet(fname: str, **kwargs) -> ak.Array: return ak.from_parquet(fname) -def _load_nano_root(fname: str, treepath: str | None = None) -> ak.Array: +def _load_nano_root(fname: str, treepath: str | None = None, **kwargs) -> ak.Array: import uproot import coffea.nanoevents @@ -40,14 +40,13 @@ def _load_nano_root(fname: str, treepath: str | None = None) -> ak.Array: # get the default treepath if treepath is None: - for treepath in source.keys(): + for treepath in ["events", "Events"] + list(source.keys()): treepath = treepath.split(";", 1)[0] - obj = source[treepath] - if isinstance(obj, uproot.TTree): + if treepath in source and isinstance(source[treepath], uproot.TTree): + print(f"using treepath '{treepath}' in root file {fname}") break else: - print(f"no default treepath determined in {fname}") - treepath = None + raise ValueError(f"no default treepath determined in {fname}") return coffea.nanoevents.NanoEventsFactory.from_root( source, @@ -58,19 +57,19 @@ def _load_nano_root(fname: str, treepath: str | None = None) -> ak.Array: ).events() -def load(fname: str) -> Any: +def load(fname: str, **kwargs) -> Any: """ Load file contents based on file extension. """ basename, ext = os.path.splitext(fname) if ext == ".pickle": - return _load_pickle(fname) + return _load_pickle(fname, **kwargs) if ext == ".parquet": - return _load_parquet(fname) + return _load_parquet(fname, **kwargs) if ext == ".root": - return _load_nano_root(fname) + return _load_nano_root(fname, **kwargs) if ext == ".json": - return _load_json(fname) + return _load_json(fname, **kwargs) raise NotImplementedError(f"no loader implemented for extension '{ext}'") @@ -105,12 +104,16 @@ def list_content(data: Any) -> None: ap.add_argument("files", metavar="FILE", nargs="+", help="one or more supported files") ap.add_argument("--events", "-e", action="store_true", help="assume files to contain event info") ap.add_argument("--hists", "-h", action="store_true", help="assume files to contain histograms") + ap.add_argument("--treepath", "-t", type=str, help="name of the tree in ROOT files") ap.add_argument("--list", "-l", action="store_true", help="list contents of the loaded file") ap.add_argument("--help", action="help", help="show this help message and exit") args = ap.parse_args() - objects = [load(fname) for fname in args.files] + load_kwargs = { + "treepath": args.treepath, + } + objects = [load(fname, **load_kwargs) for fname in args.files] if len(objects) == 1: objects = objects[0] print("file content loaded into variable 'objects'") From 1c04755e1216e381b47639616b77c5fd2cdc7248 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Tue, 9 Sep 2025 15:38:32 +0200 Subject: [PATCH 072/123] Hotfix fowarding of known_shifts for instance caching. --- columnflow/tasks/framework/base.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index 04c84f2f5..1335134ae 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -1375,14 +1375,17 @@ def get_known_shifts( resolution_task_cls = None @classmethod - def req_params(cls, inst: law.Task, *args, **kwargs) -> dict[str, Any]: - params = super().req_params(inst, *args, **kwargs) - + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: # manually add known shifts between workflows and branches - if isinstance(inst, law.BaseWorkflow) and inst.__class__ == cls and getattr(inst, "known_shifts", None): - params["known_shifts"] = inst.known_shifts + if ( + "known_shifts" not in kwargs and + isinstance(inst, law.BaseWorkflow) and + inst.__class__ == cls and + getattr(inst, "known_shifts", None) + ): + kwargs["known_shifts"] = inst.known_shifts - return params + return super().req_params(inst, **kwargs) @classmethod def _multi_sequence_repr( From 87db4871f18214021a0275dda468a91d7fff3429 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Tue, 9 Sep 2025 15:38:50 +0200 Subject: [PATCH 073/123] Minor consistency change. --- columnflow/tasks/framework/base.py | 3 +++ columnflow/tasks/framework/mixins.py | 17 +---------------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index 1335134ae..c0dee5b3b 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -80,6 +80,9 @@ class TaskShifts: local: set[str] = field(default_factory=set) upstream: set[str] = field(default_factory=set) + def __hash__(self) -> int: + return hash((frozenset(self.local), frozenset(self.upstream))) + class BaseTask(law.Task): diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index abde3afb0..ac4f93f86 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -1211,21 +1211,8 @@ def ml_model_repr(self) -> str: @classmethod def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: - """ - Get the required parameters for the task, preferring the ``--ml-model`` set on task-level - via CLI. - - This method first checks if the ``--ml-model`` parameter is set at the task-level via the command line. If it - is, this parameter is preferred and added to the '_prefer_cli' key in the kwargs dictionary. The method then - calls the 'req_params' method of the superclass with the updated kwargs. - - :param inst: The current task instance. - :param kwargs: Additional keyword arguments that may contain parameters for the task. - :return: A dictionary of parameters required for the task. - """ # prefer --ml-model set on task-level via cli kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"ml_model"} - return super().req_params(inst, **kwargs) @classmethod @@ -1588,7 +1575,6 @@ def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any] def req_params(cls, inst: law.Task, **kwargs) -> dict: # prefer --ml-models set on task-level via cli kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"ml_models"} - return super().req_params(inst, **kwargs) @property @@ -1852,10 +1838,9 @@ def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any] return params @classmethod - def req_params(cls, inst: law.Task, **kwargs) -> dict: + def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: # prefer --inference-model set on task-level via cli kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"inference_model"} - return super().req_params(inst, **kwargs) @property From 9bbdbbecfe9dd98dcffbf53dcc788c1879f1613a Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Tue, 9 Sep 2025 17:13:34 +0200 Subject: [PATCH 074/123] Fix brace expansion in ProduceColumnsWrapper. --- columnflow/tasks/production.py | 1 + 1 file changed, 1 insertion(+) diff --git a/columnflow/tasks/production.py b/columnflow/tasks/production.py index ceb5c736d..842dfbdcd 100644 --- a/columnflow/tasks/production.py +++ b/columnflow/tasks/production.py @@ -204,6 +204,7 @@ class ProduceColumnsWrapper(_ProduceColumnsWrapperBase): producers = law.CSVParameter( default=(), description="names of producers to use; if empty, the default producer is used", + brace_expand=True, ) def __init__(self, *args, **kwargs): From d83e07b9406ebec7ca663640b30f63f76b10ef85 Mon Sep 17 00:00:00 2001 From: Ana Andrade <99343616+aalvesan@users.noreply.github.com> Date: Wed, 10 Sep 2025 18:03:14 +0200 Subject: [PATCH 075/123] adapting dy weight producer for custom weights (#724) --- columnflow/production/cms/dy.py | 98 ++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 44 deletions(-) diff --git a/columnflow/production/cms/dy.py b/columnflow/production/cms/dy.py index f38d57ba3..2c97424aa 100644 --- a/columnflow/production/cms/dy.py +++ b/columnflow/production/cms/dy.py @@ -24,18 +24,17 @@ @dataclass class DrellYanConfig: era: str - order: str correction: str - unc_correction: str + unc_correction: str | None = None + order: str | None = None + njets: bool = False + systs: list[str] | None = None def __post_init__(self) -> None: - if ( - not self.era or - not self.order or - not self.correction or - not self.unc_correction - ): - raise ValueError("incomplete dy_weight_config: missing era, order, correction or unc_correction") + if not self.era or not self.correction: + raise ValueError(f"{self.__class__.__name__}: missing era or correction") + if self.unc_correction and not self.order: + raise ValueError(f"{self.__class__.__name__}: when unc_correction is defined, order must be set") @producer( @@ -149,49 +148,65 @@ def dy_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Array: *get_dy_weight_config* can be adapted in a subclass in case it is stored differently in the config. """ - # map the input variable names from the corrector to our columns variable_map = { "era": self.dy_config.era, - "order": self.dy_config.order, "ptll": events.gen_dilepton_pt, } - # initializing the list of weight variations - weights_list = [("dy_weight", "nom")] + # optionals + if self.dy_config.order: + variable_map["order"] = self.dy_config.order + if self.dy_config.njets: + variable_map["njets"] = ak.num(events.Jet, axis=1) + + # initializing the list of weight variations (called syst in the dy files) + systs = [("nom", "")] - # appending the respective number of uncertainties to the weight list - for i in range(self.n_unc): - for shift in ("up", "down"): - tmp_tuple = (f"dy_weight{i + 1}_{shift}", f"{shift}{i + 1}") - weights_list.append(tmp_tuple) + # add specific uncertainties or additional systs + if self.dy_config.unc_correction: + for i in range(self.n_unc): + for direction in ["up", "down"]: + systs.append((f"{direction}{i + 1}", f"_{direction}{i + 1}")) + elif self.dy_config.systs: + for syst in self.dy_config.systs: + systs.append((syst, f"_{syst}")) # preparing the input variables for the corrector - for column_name, syst in weights_list: - variable_map_syst = {**variable_map, "syst": syst} + for syst, postfix in systs: + _variable_map = {**variable_map, "syst": syst} # evaluating dy weights given a certain era, ptll array and sytematic shift - inputs = [variable_map_syst[inp.name] for inp in self.dy_corrector.inputs] + inputs = [_variable_map[inp.name] for inp in self.dy_corrector.inputs] dy_weight = self.dy_corrector.evaluate(*inputs) # save the weights in a new column - events = set_ak_column(events, column_name, dy_weight, value_type=np.float32) + events = set_ak_column(events, f"dy_weight{postfix}", dy_weight, value_type=np.float32) return events @dy_weights.init def dy_weights_init(self: Producer) -> None: - # the number of weights in partial run 3 is always 10 if self.config_inst.campaign.x.year not in {2022, 2023}: raise NotImplementedError( f"campaign year {self.config_inst.campaign.x.year} is not yet supported by {self.cls_name}", ) - self.n_unc = 10 - # register dynamically produced weight columns - for i in range(self.n_unc): - self.produces.add(f"dy_weight{i + 1}_{{up,down}}") + # declare additional used columns + self.dy_config: DrellYanConfig = self.get_dy_weight_config() + if self.dy_config.njets: + self.uses.add("Jet.pt") + + # declare additional produced columns + if self.dy_config.unc_correction: + # the number should always be 10 + self.n_unc = 10 + for i in range(self.n_unc): + self.produces.add(f"dy_weight{i + 1}_{{up,down}}") + elif self.dy_config.systs: + for syst in self.dy_config.systs: + self.produces.add(f"dy_weight_{syst}") @dy_weights.requires @@ -215,31 +230,26 @@ def dy_weights_setup( reader_targets: law.util.InsertableDict, ) -> None: """ - Loads the Drell-Yan weight calculator from the external files bundle and saves them in the - py:attr:`dy_corrector` attribute for simpler access in the actual callable. The number of uncertainties - is calculated, per era, by another correcter in the external file and is saved in the - py:attr:`dy_unc_corrector` attribute. + Loads the Drell-Yan weight calculator from the external files bundle and saves them in the py:attr:`dy_corrector` + attribute for simpler access in the actual callable. The number of uncertainties is calculated, per era, by another + correcter in the external file and is saved in the py:attr:`dy_unc_corrector` attribute. """ bundle = reqs["external_files"] # import all correctors from the external file correction_set = load_correction_set(self.get_dy_weight_file(bundle.files)) - # check number of fetched correctors - if len(correction_set.keys()) != 2: - raise Exception("Expected exactly two types of Drell-Yan correction") - - # create the weight and uncertainty correctors - self.dy_config: DrellYanConfig = self.get_dy_weight_config() + # create the weight corrector self.dy_corrector = correction_set[self.dy_config.correction] - self.dy_unc_corrector = correction_set[self.dy_config.unc_correction] - dy_n_unc = int(self.dy_unc_corrector.evaluate(self.dy_config.order)) - - if dy_n_unc != self.n_unc: - raise ValueError( - f"Expected {self.n_unc} uncertainties, got {dy_n_unc}", - ) + # create the uncertainty corrector + if self.dy_config.unc_correction: + self.dy_unc_corrector = correction_set[self.dy_config.unc_correction] + dy_n_unc = int(self.dy_unc_corrector.evaluate(self.dy_config.order)) + if dy_n_unc != self.n_unc: + raise ValueError( + f"Expected {self.n_unc} uncertainties, got {dy_n_unc}", + ) @producer( From fc9553db6098fbf57b2e2cacc72135de85ff2c2c Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 15 Sep 2025 16:45:08 +0200 Subject: [PATCH 076/123] Consistent handling of kwargs in teardown functions. --- columnflow/tasks/framework/mixins.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index ac4f93f86..84d8dbb5e 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -209,9 +209,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.calibrator_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_calibrator_inst(self) -> None: + def teardown_calibrator_inst(self, **kwargs) -> None: if self.calibrator_inst: - self.calibrator_inst.run_teardown(task=self) + self.calibrator_inst.run_teardown(task=self, **kwargs) @property def calibrator_repr(self) -> str: @@ -596,9 +596,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.selector_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_selector_inst(self) -> None: + def teardown_selector_inst(self, **kwargs) -> None: if self.selector_inst: - self.selector_inst.run_teardown(task=self) + self.selector_inst.run_teardown(task=self, **kwargs) @property def selector_repr(self) -> str: @@ -794,9 +794,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.reducer_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_reducer_inst(self) -> None: + def teardown_reducer_inst(self, **kwargs) -> None: if self.reducer_inst: - self.reducer_inst.run_teardown(task=self) + self.reducer_inst.run_teardown(task=self, **kwargs) @property def reducer_repr(self) -> str: @@ -970,9 +970,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.producer_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_producer_inst(self) -> None: + def teardown_producer_inst(self, **kwargs) -> None: if self.producer_inst: - self.producer_inst.run_teardown(task=self) + self.producer_inst.run_teardown(task=self, **kwargs) @property def producer_repr(self) -> str: @@ -1485,9 +1485,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.preparation_producer_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_preparation_producer_inst(self) -> None: + def teardown_preparation_producer_inst(self, **kwargs) -> None: if self.preparation_producer_inst: - self.preparation_producer_inst.run_teardown(task=self) + self.preparation_producer_inst.run_teardown(task=self, **kwargs) @classmethod def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]: @@ -1801,9 +1801,9 @@ def _array_function_post_init(self, **kwargs) -> None: self.hist_producer_inst.run_post_init(task=self, **kwargs) super()._array_function_post_init(**kwargs) - def teardown_hist_producer_inst(self) -> None: + def teardown_hist_producer_inst(self, **kwargs) -> None: if self.hist_producer_inst: - self.hist_producer_inst.run_teardown(task=self) + self.hist_producer_inst.run_teardown(task=self, **kwargs) @property def hist_producer_repr(self) -> str: From d35482450ca7f865fdce83a3b018b7414fc9057f Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 15 Sep 2025 16:45:42 +0200 Subject: [PATCH 077/123] Hotfix TAF instance method defaults. --- columnflow/columnar_util.py | 104 +++++++++++++++++++++++++----------- 1 file changed, 74 insertions(+), 30 deletions(-) diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index 625209b00..1bcabda96 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -1635,11 +1635,6 @@ def my_other_func_init(self): """ # class-level attributes as defaults - call_func = None - pre_init_func = None - init_func = None - skip_func = None - uses = set() produces = set() check_used_columns = True @@ -1731,15 +1726,20 @@ def PRODUCES(cls) -> IOFlagged: return cls.IOFlagged(cls, cls.IOFlag.PRODUCES) @classmethod - def call(cls, func: Callable[[Any, ...], Any]) -> None: + def pre_init(cls, func: Callable[[], None]) -> None: """ - Decorator to wrap a function *func* that should be registered as :py:meth:`call_func` - which defines the main callable for processing chunks of data. The function should accept - arbitrary arguments and can return arbitrary objects. + Decorator to wrap a function *func* that should be registered as :py:meth:`pre_init_func` + which is invoked prior to any dependency creation. The function should not accept arguments. The decorator does not return the wrapped function. """ - cls.call_func = func + cls.pre_init_func = func + + def pre_init_func(self) -> None: + """ + Default pre-init function. + """ + return @classmethod def init(cls, func: Callable[[], None]) -> None: @@ -1752,16 +1752,11 @@ def init(cls, func: Callable[[], None]) -> None: """ cls.init_func = func - @classmethod - def pre_init(cls, func: Callable[[], None]) -> None: + def init_func(self) -> None: """ - Decorator to wrap a function *func* that should be registered as :py:meth:`pre_init_func` - which is invoked prior to any dependency creation. The function should not accept positional - arguments. - - The decorator does not return the wrapped function. + Default init function. """ - cls.pre_init_func = func + return @classmethod def skip(cls, func: Callable[[], bool]) -> None: @@ -1774,12 +1769,35 @@ def skip(cls, func: Callable[[], bool]) -> None: """ cls.skip_func = func + def skip_func(self) -> None: + """ + Default skip function. + """ + return + + @classmethod + def call(cls, func: Callable[[Any, ...], Any]) -> None: + """ + Decorator to wrap a function *func* that should be registered as :py:meth:`call_func` + which defines the main callable for processing chunks of data. The function should accept + arbitrary arguments and can return arbitrary objects. + + The decorator does not return the wrapped function. + """ + cls.call_func = func + + def call_func(self, *args, **kwargs) -> Any: + """ + Default call function. + """ + return + def __init__( self, - call_func: Callable | None = law.no_value, - pre_init_func: Callable | None = law.no_value, - init_func: Callable | None = law.no_value, - skip_func: Callable | None = law.no_value, + pre_init_func: Callable | law.NoValue | None = law.no_value, + init_func: Callable | law.NoValue | None = law.no_value, + skip_func: Callable | law.NoValue | None = law.no_value, + call_func: Callable | law.NoValue | None = law.no_value, check_used_columns: bool | None = None, check_produced_columns: bool | None = None, instance_cache: dict | None = None, @@ -1790,14 +1808,14 @@ def __init__( super().__init__() # add class-level attributes as defaults for unset arguments (no_value) - if call_func == law.no_value: - call_func = self.__class__.call_func if pre_init_func == law.no_value: pre_init_func = self.__class__.pre_init_func if init_func == law.no_value: init_func = self.__class__.init_func if skip_func == law.no_value: skip_func = self.__class__.skip_func + if call_func == law.no_value: + call_func = self.__class__.call_func if check_used_columns is not None: self.check_used_columns = check_used_columns if check_produced_columns is not None: @@ -1806,14 +1824,14 @@ def __init__( self.log_runtime = log_runtime # when a custom funcs are passed, bind them to this instance - if call_func: - self.call_func = call_func.__get__(self, self.__class__) if pre_init_func: self.pre_init_func = pre_init_func.__get__(self, self.__class__) if init_func: self.init_func = init_func.__get__(self, self.__class__) if skip_func: self.skip_func = skip_func.__get__(self, self.__class__) + if call_func: + self.call_func = call_func.__get__(self, self.__class__) # create instance-level sets of dependent ArrayFunction classes, # optionally with priority to sets passed in keyword arguments @@ -2402,10 +2420,6 @@ class the normal way, or use a decorator to wrap the main callable first and by """ # class-level attributes as defaults - post_init_func = None - requires_func = None - setup_func = None - teardown_func = None sandbox = None call_force = None max_chunk_size = None @@ -2459,6 +2473,12 @@ def post_init(cls, func: Callable[[dict], None]) -> None: """ cls.post_init_func = func + def post_init_func(self, task: law.Task) -> None: + """ + Default post-init function. + """ + return + @classmethod def requires(cls, func: Callable[[dict], None]) -> None: """ @@ -2481,6 +2501,12 @@ def requires(cls, func: Callable[[dict], None]) -> None: """ cls.requires_func = func + def requires_func(self, task: law.Task, reqs: dict[str, DotDict[str, Any]]) -> None: + """ + Default requires function. + """ + return + @classmethod def setup(cls, func: Callable[[dict], None]) -> None: """ @@ -2499,6 +2525,18 @@ def setup(cls, func: Callable[[dict], None]) -> None: """ cls.setup_func = func + def setup_func( + self, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + ) -> None: + """ + Default setup function. + """ + return + @classmethod def teardown(cls, func: Callable[[dict], None]) -> None: """ @@ -2512,6 +2550,12 @@ def teardown(cls, func: Callable[[dict], None]) -> None: """ cls.teardown_func = func + def teardown_func(self, task: law.Task) -> None: + """ + Default teardown function. + """ + return + def __init__( self, *args, From 17d2fafeb153ceb283253aabc3e9a5feeb97ce11 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Mon, 15 Sep 2025 16:46:40 +0200 Subject: [PATCH 078/123] Refactor and fix met phi calibration. --- columnflow/calibration/cms/met.py | 372 +++++++++++++++++++----------- 1 file changed, 231 insertions(+), 141 deletions(-) diff --git a/columnflow/calibration/cms/met.py b/columnflow/calibration/cms/met.py index 942700a64..9774c8006 100644 --- a/columnflow/calibration/cms/met.py +++ b/columnflow/calibration/cms/met.py @@ -6,11 +6,12 @@ from __future__ import annotations -import law +import functools +from dataclasses import dataclass, field -from dataclasses import dataclass +import law -from columnflow.calibration import Calibrator, calibrator +from columnflow.calibration import Calibrator from columnflow.util import maybe_import, load_correction_set, DotDict from columnflow.columnar_util import set_ak_column from columnflow.types import Any @@ -19,41 +20,178 @@ ak = maybe_import("awkward") +# helpers +set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) + + +class _met_phi_base(Calibrator): + """" + Common base class for MET phi calibrators. + """ + + exposed = False + + # function to determine the correction file + get_met_file = lambda self, external_files: external_files.met_phi_corr + + # function to determine met correction config + get_met_config = lambda self: self.config_inst.x.met_phi_correction + + def requires_func(self, task: law.Task, reqs: dict[str, DotDict[str, Any]], **kwargs) -> None: + if "external_files" in reqs: + return + + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(task) + + +# +# Run 2 implementation +# + +@dataclass +class METPhiConfigRun2: + correction_set_template = r"{variable}_metphicorr_pfmet_{data_source}" + met_name: str = "MET" + keep_uncorrected: bool = False + + +@_met_phi_base.calibrator(exposed=True) +def met_phi_run2(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: + """ + Performs the MET phi (type II) correction using :external+correctionlib:doc:`index`. Events whose uncorrected MET pt + is below the beam energy (extracted from ``config_inst.campaign.ecm * 0.5``) are skipped. Requires an external file + in the config under ``met_phi_corr``: + + .. code-block:: python + + cfg.x.external_files = DotDict.wrap({ + "met_phi_corr": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-406118ec/POG/JME/2022_Summer22EE/met_xyCorrections_2022_2022EE.json.gz", # noqa + }) + + *get_met_file* can be adapted in a subclass in case it is stored differently in the external files. + + The calibrator should be configured with an :py:class:`METPhiConfigRun2` as an auxiliary entry in the config named + ``met_phi_correction``. *get_met_config* can be adapted in a subclass in case it is stored differently in the + config. Exemplary config entry: + + .. code-block:: python + + from columnflow.calibration.cms.met import METPhiConfigRun2 + cfg.x.met_phi_correction = METPhiConfigRun2( + met_name="MET", + correction_set_template="{variable}_metphicorr_pfmet_{data_source}", + keep_uncorrected=False, + ) + + "variable" and "data_source" are placeholders that will be replace with "pt" or "phi", and the data source of the + current dataset, respectively. + + Resources: + - https://twiki.cern.ch/twiki/bin/view/CMS/MissingETRun2Corrections?rev=79#xy_Shift_Correction_MET_phi_modu + """ + # get met columns + met_name = self.met_config.met_name + met = events[met_name] + + # store uncorrected values if requested + if self.met_config.keep_uncorrected: + events = set_ak_column_f32(events, f"{met_name}.pt_metphi_uncorrected", met.pt) + events = set_ak_column_f32(events, f"{met_name}.phi_metphi_uncorrected", met.phi) + + # copy the intial pt and phi values + corr_pt = np.array(met.pt, dtype=np.float32) + corr_phi = np.array(met.phi, dtype=np.float32) + + # select only events where MET pt is below the expected beam energy + mask = met.pt < (0.5 * self.config_inst.campaign.ecm) + + # arguments for evaluation + args = ( + met.pt[mask], + met.phi[mask], + ak.values_astype(events.PV.npvs[mask], np.float32), + ak.values_astype(events.run[mask], np.float32), + ) + + # evaluate and insert + corr_pt[mask] = self.met_pt_corrector.evaluate(*args) + corr_phi[mask] = self.met_phi_corrector.evaluate(*args) + + # save the corrected values + events = set_ak_column_f32(events, f"{met_name}.pt", corr_pt) + events = set_ak_column_f32(events, f"{met_name}.phi", corr_phi) + + return events + + +@met_phi_run2.init +def met_phi_run2_init(self: Calibrator, **kwargs) -> None: + self.met_config = self.get_met_config() + + # set used columns + self.uses.update({"run", "PV.npvs", f"{self.met_config.met_name}.{{pt,phi}}"}) + + # set produced columns + self.produces.add(f"{self.met_config.met_name}.{{pt,phi}}") + if self.met_config.keep_uncorrected: + self.produces.add(f"{self.met_config.met_name}.{{pt,phi}}_metphi_uncorrected") + + +@met_phi_run2.setup +def met_phi_run2_setup( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: + # create the pt and phi correctors + met_file = self.get_met_file(reqs["external_files"].files) + correction_set = load_correction_set(met_file) + + name_tmpl = self.met_config.correction_set_template + self.met_pt_corrector = correction_set[name_tmpl.format( + variable="pt", + data_source=self.dataset_inst.data_source, + )] + self.met_phi_corrector = correction_set[name_tmpl.format( + variable="phi", + data_source=self.dataset_inst.data_source, + )] + + +# +# Run 3 implementation +# + @dataclass class METPhiConfig: - variable_config: dict[str, tuple[str]] correction_set: str = "met_xy_corrections" met_name: str = "PuppiMET" - met_type: str = "MET" + met_type: str = "PuppiMET" keep_uncorrected: bool = False - - @classmethod - def new( - cls, - obj: METPhiConfig | tuple[str, list[str]] | tuple[str, list[str], str], - ) -> METPhiConfig: - # purely for backwards compatibility with the old string format - if isinstance(obj, cls): - return obj - if isinstance(obj, str): - return cls(correction_set=obj, variable_config={"pt": ("pt",), "phi": ("phi",)}) - if isinstance(obj, dict): - return cls(**obj) - raise ValueError(f"cannot convert {obj} to METPhiConfig") - - -@calibrator( - uses={"run", "PV.npvs", "PV.npvsGood"}, - # function to determine the correction file - get_met_file=(lambda self, external_files: external_files.met_phi_corr), - # function to determine met correction config - get_met_config=(lambda self: METPhiConfig.new(self.config_inst.x.met_phi_correction)), -) + # variations (intrinsic method uncertainties) for pt and phi + pt_phi_variations: dict[str, str] | None = field(default_factory=lambda: { + "stat_xdn": "metphi_statx_down", + "stat_xup": "metphi_statx_up", + "stat_ydn": "metphi_staty_down", + "stat_yup": "metphi_staty_up", + }) + # other variations (external uncertainties) + variations: dict[str, str] | None = field(default_factory=lambda: { + "pu_dn": "minbias_xs_down", + "pu_up": "minbias_xs_up", + }) + + +@_met_phi_base.calibrator(exposed=True) def met_phi(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: """ - Performs the MET phi (type II) correction using the :external+correctionlib:doc:`index` for events there the - uncorrected MET pt is below the beam energy (extracted from ``config_inst.campaign.ecm * 0.5``). Requires an - external file in the config under ``met_phi_corr``: + Performs the MET phi (type II) correction using :external+correctionlib:doc:`index`. Events whose uncorrected MET pt + is below the beam energy (extracted from ``config_inst.campaign.ecm * 0.5``) are skipped. Requires an external file + in the config under ``met_phi_corr``: .. code-block:: python @@ -63,131 +201,98 @@ def met_phi(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: *get_met_file* can be adapted in a subclass in case it is stored differently in the external files. - The met_phi Calibrator should be configured with an auxiliary entry in the config that can contain: - - - the name of the correction set - - the name of the MET column - - the MET type that is passed as an input to the correction set - - a boolean flag to keep the uncorrected MET pt and phi values as additional output columns - - a dictionary that maps the input variable names ("pt", "phi") to a list of output variable names that should - be produced. - - Exemplary config entry: + The calibrator should be configured with an :py:class:`METPhiConfig` as an auxiliary entry in the config named + ``met_phi_correction``. *get_met_config* can be adapted in a subclass in case it is stored differently in the + config. Exemplary config entry: .. code-block:: python from columnflow.calibration.cms.met import METPhiConfig cfg.x.met_phi_correction = METPhiConfig( - met_name="PuppiMET", - met_type="MET", correction_set="met_xy_corrections", + met_name="PuppiMET", + met_type="PuppiMET", keep_uncorrected=False, - variable_config={ - "pt": ( - "pt", - "pt_stat_yup", - "pt_stat_ydn", - "pt_stat_xup", - "pt_stat_xdn", - ), - "phi": ( - "phi", - "phi_stat_yup", - "phi_stat_ydn", - "phi_stat_xup", - "phi_stat_xdn", - ), + # mappings of method variation to column (pt/phi) postfixes + pt_phi_variations={ + "stat_xdn": "metphi_statx_down", + "stat_xup": "metphi_statx_up", + "stat_ydn": "metphi_staty_down", + "stat_yup": "metphi_staty_up", + }, + variations={ + "pu_dn": "minbias_xs_down", + "pu_up": "minbias_xs_up", }, ) - - The `correction_set` value can also contain the placeholders "variable" and "data_source" that are replaced in the - calibrator setup :py:meth:`~.met_phi.setup_func`. - - *get_met_config* can be adapted in a subclass in case it is stored differently in the config. - - Resources: - - https://twiki.cern.ch/twiki/bin/view/CMS/MissingETRun2Corrections?rev=79#xy_Shift_Correction_MET_phi_modu - - :param events: awkward array containing events to process """ - # get Met columns - met = events[self.met_config.met_name] + # get met + met_name = self.met_config.met_name + met = events[met_name] + + # store uncorrected values if requested + if self.met_config.keep_uncorrected: + events = set_ak_column_f32(events, f"{met_name}.pt_metphi_uncorrected", met.pt) + events = set_ak_column_f32(events, f"{met_name}.phi_metphi_uncorrected", met.phi) # correct only events where MET pt is below the expected beam energy mask = met.pt < (0.5 * self.config_inst.campaign.ecm * 1000) # convert TeV to GeV + # gather variables variable_map = { "met_type": self.met_config.met_type, "epoch": f"{self.config_inst.campaign.x.year}{self.config_inst.campaign.x.postfix}", "dtmc": "DATA" if self.dataset_inst.is_data else "MC", - "variation": "nom", "met_pt": ak.values_astype(met.pt[mask], np.float32), "met_phi": ak.values_astype(met.phi[mask], np.float32), "npvGood": ak.values_astype(events.PV.npvsGood[mask], np.float32), - "npvs": ak.values_astype(events.PV.npvs[mask], np.float32), # needed for old-style corrections - "run": ak.values_astype(events.run[mask], np.float32), } - for variable, outp_variables in self.met_config.variable_config.items(): - met_corrector = self.met_correctors[variable] - if self.met_config.keep_uncorrected: - events = set_ak_column( - events, - f"{self.met_config.met_name}.{variable}_xy_uncorrected", - met[variable], - value_type=np.float32, - ) - for out_var in outp_variables: - # copy initial value every time - # NOTE: this needs to be within the loop to ensure that the output values are not - # overwritten by the next iteration - corr_var = np.array(met[variable], dtype=np.float32) - - # get the input variables for the correction - variable_map_syst = { - **variable_map, - "pt_phi": out_var, - } - inputs = [variable_map_syst[inp.name] for inp in met_corrector.inputs] - - # insert the corrected values - corr_var[mask] = met_corrector(*inputs) - - # save the corrected values - events = set_ak_column(events, f"{self.met_config.met_name}.{out_var}", corr_var, value_type=np.float32) + # evaluate pt and phi separately + for var in ["pt", "phi"]: + # remember initial values + vals_orig = np.array(met[var], dtype=np.float32) + # loop over general variations, then pt/phi variations + # (needed since the JME correction file is inconsistent in how intrinsic and external variations are treated) + general_vars = {"nom": ""} + if self.dataset_inst.is_mc: + general_vars.update(self.met_config.variations or {}) + for variation, postfix in general_vars.items(): + pt_phi_vars = {"": ""} + if variation == "nom" and self.dataset_inst.is_mc: + pt_phi_vars.update(self.met_config.pt_phi_variations or {}) + for pt_phi_variation, pt_phi_postfix in pt_phi_vars.items(): + _postfix = postfix or pt_phi_postfix + out_var = f"{var}{_postfix and '_' + _postfix}" + # prepare evaluator inputs + _variable_map = { + **variable_map, + "pt_phi": f"{var}{pt_phi_variation and '_' + pt_phi_variation}", + "variation": variation, + } + inputs = [_variable_map[inp.name] for inp in self.met_corrector.inputs] + # evaluate and create new column + corr_vals = np.array(vals_orig) + corr_vals[mask] = self.met_corrector(*inputs) + events = set_ak_column_f32(events, f"{met_name}.{out_var}", corr_vals) return events @met_phi.init def met_phi_init(self: Calibrator, **kwargs) -> None: - """ - Initialize the :py:attr:`met_pt_corrector` and :py:attr:`met_phi_corrector` attributes. - """ self.met_config = self.get_met_config() - self.uses.add(f"{self.met_config.met_name}.{{pt,phi}}") - - for variable in self.met_config.variable_config.keys(): - if self.met_config.keep_uncorrected: - self.produces.add(f"{self.met_config.met_name}.{variable}_xy_uncorrected") - for out_var in self.met_config.variable_config[variable]: - # add the produced columns to the uses set - self.produces.add(f"{self.met_config.met_name}.{out_var}") - + # set used columns + self.uses.update({"PV.npvsGood", f"{self.met_config.met_name}.{{pt,phi}}"}) -@met_phi.requires -def met_phi_requires( - self: Calibrator, - task: law.Task, - reqs: dict[str, DotDict[str, Any]], - **kwargs, -) -> None: - if "external_files" in reqs: - return - - from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(task) + # set produced columns + self.produces.add(f"{self.met_config.met_name}.{{pt,phi}}") + if self.dataset_inst.is_mc: + for postfix in {**(self.met_config.pt_phi_variations or {}), **(self.met_config.variations or {})}.values(): + self.produces.add(f"{self.met_config.met_name}.{{pt,phi}}_{postfix}") + if self.met_config.keep_uncorrected: + self.produces.add(f"{self.met_config.met_name}.{{pt,phi}}_metphi_uncorrected") @met_phi.setup @@ -199,22 +304,7 @@ def met_phi_setup( reader_targets: law.util.InsertableDict, **kwargs, ) -> None: - """ - Load the correct met files using the :py:func:`from_string` method of the - :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` function and apply the corrections as - needed. - - :param reqs: Requirement dictionary for this :py:class:`~columnflow.calibration.Calibrator` instance - :param inputs: Additional inputs, currently not used. - :param reader_targets: Additional targets, currently not used. - """ - # create the pt and phi correctors + # load the corrector met_file = self.get_met_file(reqs["external_files"].files) correction_set = load_correction_set(met_file) - name_tmpl = self.met_config.correction_set - self.met_correctors = { - variable: correction_set[name_tmpl.format( - variable=variable, - data_source=self.dataset_inst.data_source, - )] for variable in self.met_config.variable_config.keys() - } + self.met_corrector = correction_set[self.met_config.correction_set] From 49c8a38b80857195bc4e0e4e1db01c03e26f5a67 Mon Sep 17 00:00:00 2001 From: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Date: Tue, 16 Sep 2025 14:08:57 +0200 Subject: [PATCH 079/123] add plot function for efficeincy plots (#723) Co-authored-by: Mathis Frahm Co-authored-by: Marcel Rieger --- columnflow/plotting/plot_functions_1d.py | 53 +++++++++++++++++++++--- columnflow/plotting/plot_util.py | 19 ++++++--- 2 files changed, 60 insertions(+), 12 deletions(-) diff --git a/columnflow/plotting/plot_functions_1d.py b/columnflow/plotting/plot_functions_1d.py index 13dc40014..1899cbc4d 100644 --- a/columnflow/plotting/plot_functions_1d.py +++ b/columnflow/plotting/plot_functions_1d.py @@ -99,6 +99,9 @@ def plot_variable_stack( shape_norm, yscale, ) + # additional, plot function specific changes + if shape_norm: + default_style_config["ax_cfg"]["ylabel"] = "Normalized entries" style_config = law.util.merge_dicts( default_style_config, process_style_config, @@ -107,13 +110,51 @@ def plot_variable_stack( deep=True, ) - # additional, plot function specific changes - if shape_norm: - style_config["ax_cfg"]["ylabel"] = "Normalized entries" - return plot_all(plot_config, style_config, **kwargs) +def plot_variable_efficiency( + hists: OrderedDict, + config_inst: od.Config, + category_inst: od.Category, + variable_insts: list[od.Variable], + shift_insts: list[od.Shift] | None, + style_config: dict | None = None, + shape_norm: bool = True, + cumsum_reverse: bool = True, + **kwargs, +): + """ + This plot function allows users to plot the efficiency of a cut on a variable as a function of the cut value. + Per default, each bin shows the efficiency of requiring value >= bin edge (cumsum_reverse=True). + Setting cumsum_reverse=False will instead show the efficiency of requiring value <= bin edge. + """ + for proc_inst, proc_hist in hists.items(): + if cumsum_reverse: + proc_hist.values()[...] = np.cumsum(proc_hist.values()[..., ::-1], axis=-1)[..., ::-1] + shape_norm_func = kwargs.get("shape_norm_func", lambda h, shape_norm: h.values()[0] if shape_norm else 1) + else: + proc_hist.values()[...] = np.cumsum(proc_hist.values(), axis=-1) + shape_norm_func = kwargs.get("shape_norm_func", lambda h, shape_norm: h.values()[-1] if shape_norm else 1) + + default_style_config = { + "ax_cfg": {"ylabel": "Efficiency" if shape_norm else "Cumulative entries"}, + } + style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) + + return plot_variable_stack( + hists, + config_inst, + category_inst, + variable_insts, + shift_insts, + shape_norm=shape_norm, + shape_norm_func=shape_norm_func, + style_config=style_config, + **kwargs, + ) + + def plot_variable_variants( hists: OrderedDict, config_inst: od.Config, @@ -274,6 +315,8 @@ def plot_shifted_variable( default_style_config["rax_cfg"]["ylabel"] = "Ratio" if legend_title: default_style_config["legend_cfg"]["title"] = legend_title + if shape_norm: + style_config["ax_cfg"]["ylabel"] = "Normalized entries" style_config = law.util.merge_dicts( default_style_config, process_style_config, @@ -281,8 +324,6 @@ def plot_shifted_variable( style_config, deep=True, ) - if shape_norm: - style_config["ax_cfg"]["ylabel"] = "Normalized entries" return plot_all(plot_config, style_config, **kwargs) diff --git a/columnflow/plotting/plot_util.py b/columnflow/plotting/plot_util.py index 8aa5a0302..c2088414f 100644 --- a/columnflow/plotting/plot_util.py +++ b/columnflow/plotting/plot_util.py @@ -541,9 +541,13 @@ def prepare_stack_plot_config( # setup plotting configs plot_config = OrderedDict() + # take first (non-underflow) bin + # shape_norm_func = kwargs.get("shape_norm_func", lambda h, shape_norm: h.values()[0] if shape_norm else 1) + shape_norm_func = kwargs.get("shape_norm_func", lambda h, shape_norm: sum(h.values()) if shape_norm else 1) + # draw stack if h_mc_stack is not None: - mc_norm = sum(h_mc.values()) if shape_norm else 1 + mc_norm = shape_norm_func(h_mc, shape_norm) plot_config["mc_stack"] = { "method": "draw_stack", "hist": h_mc_stack, @@ -558,7 +562,7 @@ def prepare_stack_plot_config( # draw lines for i, h in enumerate(line_hists): - line_norm = sum(h.values()) if shape_norm else 1 + line_norm = shape_norm_func(h, shape_norm) plot_config[f"line_{i}"] = plot_cfg = { "method": "draw_hist", "hist": h, @@ -582,7 +586,7 @@ def prepare_stack_plot_config( # draw statistical error for stack if h_mc_stack is not None and not hide_stat_errors: - mc_norm = sum(h_mc.values()) if shape_norm else 1 + mc_norm = shape_norm_func(h_mc, shape_norm) plot_config["mc_stat_unc"] = { "method": "draw_stat_error_bands", "hist": h_mc, @@ -592,7 +596,7 @@ def prepare_stack_plot_config( # draw systematic error for stack if h_mc_stack is not None and mc_syst_hists: - mc_norm = sum(h_mc.values()) if shape_norm else 1 + mc_norm = shape_norm_func(h_mc, shape_norm) plot_config["mc_syst_unc"] = { "method": "draw_syst_error_bands", "hist": h_mc, @@ -611,7 +615,7 @@ def prepare_stack_plot_config( # draw data if data_hists: - data_norm = sum(h_data.values()) if shape_norm else 1 + data_norm = shape_norm_func(h_data, shape_norm) plot_config["data"] = plot_cfg = { "method": "draw_errorbars", "hist": h_data, @@ -908,7 +912,7 @@ def blind_sensitive_bins( # set data points in masked region to zero for proc, h in data.items(): - h.values()[..., mask] = 0 + h.values()[..., mask] = -999 h.variances()[..., mask] = 0 # merge all histograms @@ -1048,6 +1052,9 @@ def calculate_stat_error( values = hist.view().value confidence_interval = poisson_interval(values, variances) + # negative values are considerd as blinded bins -> set confidence interval to 0 + confidence_interval[:, values < 0] = 0 + if error_type == "poisson_weighted": # might happen if some bins are empty, see https://github.com/scikit-hep/hist/blob/5edbc25503f2cb8193cc5ff1eb71e1d8fa877e3e/src/hist/intervals.py#L74 # noqa: E501 confidence_interval[np.isnan(confidence_interval)] = 0 From f0cc020b777489e0b5ea367731b04b7eb2369299 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Tue, 16 Sep 2025 16:37:13 +0200 Subject: [PATCH 080/123] Hotfix parameter group cleaning in inference model. --- columnflow/inference/__init__.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index fac1d6831..a54aa67ee 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -1799,23 +1799,17 @@ def remove_empty_categories(self) -> None: def remove_dangling_parameters_from_groups( self, keep_parameters: str | Sequence[str] | None = None, - match_mode: Callable = any, ) -> None: """ - Removes names of parameters from parameter groups that are not assigned to any process in - any category. + Removes names of parameters from parameter groups that are not assigned to any process in any category. :param keep_parameters: A string, pattern, or sequence of them to specify parameters to keep. - :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see - :py:func:`pattern_matcher`). """ # get a list of all parameters parameter_names = self.get_parameters("*", flat=True) # get set of parameters to keep - _keep_parameters = set() - if keep_parameters: - _keep_parameters = set(self.get_parameters(keep_parameters, match_mode=match_mode, flat=True)) + _keep_parameters = law.util.make_set(keep_parameters) if keep_parameters else set() # go through groups and remove dangling parameters for group in self.parameter_groups: @@ -1824,7 +1818,7 @@ def remove_dangling_parameters_from_groups( for parameter_name in group.parameter_names if ( parameter_name in parameter_names or - (_keep_parameters and parameter_name in _keep_parameters) + law.util.multi_match(parameter_name, _keep_parameters, mode=any) ) ] From bbd86c1a57369894468067d256321e8c313535ac Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 17 Sep 2025 09:00:11 +0200 Subject: [PATCH 081/123] Hotfix: allow brace patterns in TAF shifts. --- columnflow/columnar_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index 1bcabda96..6b5fa3461 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -2694,7 +2694,7 @@ def _get_all_shifts(self, _cache: set | None = None) -> set[str]: if isinstance(shift, od.Shift): shifts.add(shift.name) elif isinstance(shift, str): - shifts.add(shift) + shifts.update(law.util.brace_expand(shift)) _cache.add(self) # add shifts of all dependent objects From e0953a948a2bbc3cc1b5142fa715339259c5490c Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Wed, 17 Sep 2025 12:20:35 +0200 Subject: [PATCH 082/123] Remove year from intrinsic btag weight names. (#726) --- columnflow/production/cms/btag.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/columnflow/production/cms/btag.py b/columnflow/production/cms/btag.py index 65fd9c290..0577c3d59 100644 --- a/columnflow/production/cms/btag.py +++ b/columnflow/production/cms/btag.py @@ -248,18 +248,17 @@ def add_weight(syst_name, syst_direction, column_name): events = add_weight("central", None, self.weight_name) for syst_name, col_name in self.btag_uncs.items(): for direction in ["up", "down"]: - name = col_name.format(year=self.config_inst.campaign.x.year) events = add_weight( syst_name, direction, - f"{self.weight_name}_{name}_{direction}", + f"{self.weight_name}_{col_name}_{direction}", ) if syst_name in ["cferr1", "cferr2"]: # for c flavor uncertainties, multiply the uncertainty with the nominal btag weight events = set_ak_column( events, - f"{self.weight_name}_{name}_{direction}", - events[self.weight_name] * events[f"{self.weight_name}_{name}_{direction}"], + f"{self.weight_name}_{col_name}_{direction}", + events[self.weight_name] * events[f"{self.weight_name}_{col_name}_{direction}"], value_type=np.float32, ) elif self.shift_is_known_jec_source: @@ -287,7 +286,7 @@ def btag_weights_post_init(self: Producer, task: law.Task, **kwargs) -> None: # NOTE: we currently setup the produced columns only during the post_init. This means # that the `produces` of this Producer will be empty during task initialization, meaning - # that this Producer would be skipped if one would directly request it on command line + # that this Producer would be skipped if one would directly request it on the command line # gather info self.btag_config = self.get_btag_config() @@ -303,14 +302,14 @@ def btag_weights_post_init(self: Producer, task: law.Task, **kwargs) -> None: self.jec_source and btag_sf_jec_source in self.btag_config.jec_sources ) - # save names of method-intrinsic uncertainties + # names of method-intrinsic uncertainties, mapped to how they are namend in produced columns self.btag_uncs = { "hf": "hf", "lf": "lf", - "hfstats1": "hfstats1_{year}", - "hfstats2": "hfstats2_{year}", - "lfstats1": "lfstats1_{year}", - "lfstats2": "lfstats2_{year}", + "hfstats1": "hfstats1", + "hfstats2": "hfstats2", + "lfstats1": "lfstats1", + "lfstats2": "lfstats2", "cferr1": "cferr1", "cferr2": "cferr2", } @@ -321,9 +320,7 @@ def btag_weights_post_init(self: Producer, task: law.Task, **kwargs) -> None: self.produces.add(self.weight_name) # all varied columns for col_name in self.btag_uncs.values(): - name = col_name.format(year=self.config_inst.campaign.x.year) - for direction in ["up", "down"]: - self.produces.add(f"{self.weight_name}_{name}_{direction}") + self.produces.add(f"{self.weight_name}_{col_name}_{{up,down}}") elif self.shift_is_known_jec_source: # jec varied column self.produces.add(f"{self.weight_name}_jec_{self.jec_source}_{shift_inst.direction}") From fccc8532b0785d909b7a26731b668e6975497640 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Tue, 23 Sep 2025 11:14:52 +0200 Subject: [PATCH 083/123] Forward remote claw sandbox. --- columnflow/tasks/framework/base.py | 2 +- columnflow/tasks/framework/remote.py | 19 +++++++++++++++++++ modules/law | 2 +- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index c0dee5b3b..680724024 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -179,7 +179,7 @@ def req_params(cls, inst: AnalysisTask, **kwargs) -> dict[str, Any]: _prefer_cli = law.util.make_set(kwargs.get("_prefer_cli", [])) | { "version", "workflow", "job_workers", "poll_interval", "walltime", "max_runtime", "retries", "acceptance", "tolerance", "parallel_jobs", "shuffle_jobs", "htcondor_cpus", - "htcondor_gpus", "htcondor_memory", "htcondor_disk", "htcondor_pool", "pilot", + "htcondor_gpus", "htcondor_memory", "htcondor_disk", "htcondor_pool", "pilot", "remote_claw_sandbox", } kwargs["_prefer_cli"] = _prefer_cli diff --git a/columnflow/tasks/framework/remote.py b/columnflow/tasks/framework/remote.py index bac3affdb..6ba3bb72d 100644 --- a/columnflow/tasks/framework/remote.py +++ b/columnflow/tasks/framework/remote.py @@ -347,8 +347,17 @@ class RemoteWorkflowMixin(AnalysisTask): Mixin class for custom remote workflows adding common functionality. """ + remote_claw_sandbox = luigi.Parameter( + default=law.NO_STR, + significant=False, + description="the name of a non-dev sandbox to use in remote jobs for the 'claw' executable rather than using " + "using 'law' directly; not used when empty; default: empty", + ) + skip_destination_info: bool = False + exclude_params_req = {"remote_claw_sandbox"} + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -576,6 +585,14 @@ def add_common_configs( render=False, ) + # claw sandbox + if self.remote_claw_sandbox not in {None, "", law.NO_STR}: + if self.remote_claw_sandbox.endswith("_dev"): + raise ValueError( + f"remote_claw_sandbox must not refer to a dev sandbox, got '{self.remote_claw_sandbox}'", + ) + config.render_variables["law_exe"] = f"CLAW_SANDBOX='{self.remote_claw_sandbox}' claw" + def common_destination_info(self, info: dict[str, str]) -> dict[str, str]: """ Hook to modify the additional info printed along logs of the workflow. @@ -800,6 +817,8 @@ def htcondor_job_config(self, config, job_num, branches): batch_name += f"_{info['config']}" if "dataset" in info: batch_name += f"_{info['dataset']}" + if "shift" in info: + batch_name += f"_{info['shift']}" config.custom_content.append(("batch_name", batch_name)) # CERN settings, https://batchdocs.web.cern.ch/local/submit.html#os-selection-via-containers diff --git a/modules/law b/modules/law index c8c40094d..0578967c0 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit c8c40094d53a42849e408e77436319be7f9764c5 +Subproject commit 0578967c014323324f25006e0ba05a03de429cc6 From 9ea6c7253a8de3c355cae8e264e8e9f2e6ba1cc6 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 24 Sep 2025 11:14:07 +0200 Subject: [PATCH 084/123] Add pilot option to MergeShiftedHistograms. --- columnflow/inference/cms/datacard.py | 3 +++ columnflow/tasks/histograms.py | 36 +++++++++++++++------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index 032a988ce..e94916eca 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -260,6 +260,9 @@ def write( type_str = "lnU" blocks.tabular_parameters.append([param_name, type_str, effects]) + # alphabetical, case-insensitive order by name + blocks.tabular_parameters.sort(key=lambda line: line[0].lower()) + if blocks.tabular_parameters: empty_lines.add("tabular_parameters") diff --git a/columnflow/tasks/histograms.py b/columnflow/tasks/histograms.py index 883661e05..013b21680 100644 --- a/columnflow/tasks/histograms.py +++ b/columnflow/tasks/histograms.py @@ -417,10 +417,12 @@ def requires(self): ) def output(self): - return {"hists": law.SiblingFileCollection({ - variable_name: self.target(f"hist__var_{variable_name}.pickle") - for variable_name in self.variables - })} + return { + "hists": law.SiblingFileCollection({ + variable_name: self.target(f"hist__var_{variable_name}.pickle") + for variable_name in self.variables + }), + } @law.decorator.notify @law.decorator.log @@ -503,9 +505,10 @@ def create_branch_map(self): def workflow_requires(self): reqs = super().workflow_requires() - # add nominal and both directions per shift source - for shift in ["nominal"] + self.shifts: - reqs[shift] = self.reqs.MergeHistograms.req(self, shift=shift, _prefer_cli={"variables"}) + if not self.pilot: + # add nominal and both directions per shift source + for shift in ["nominal"] + self.shifts: + reqs[shift] = self.reqs.MergeHistograms.req(self, shift=shift, _prefer_cli={"variables"}) return reqs @@ -531,17 +534,16 @@ def run(self): outputs = self.output()["hists"].targets for variable_name, outp in self.iter_progress(outputs.items(), len(outputs)): - self.publish_message(f"merging histograms for '{variable_name}'") - - # load hists - variable_hists = [ - coll["hists"].targets[variable_name].load(formatter="pickle") - for coll in inputs.values() - ] + with self.publish_step(f"merging histograms for '{variable_name}' ..."): + # load hists + variable_hists = [ + coll["hists"].targets[variable_name].load(formatter="pickle") + for coll in inputs.values() + ] - # merge and write the output - merged = sum(variable_hists[1:], variable_hists[0].copy()) - outp.dump(merged, formatter="pickle") + # merge and write the output + merged = sum(variable_hists[1:], variable_hists[0].copy()) + outp.dump(merged, formatter="pickle") MergeShiftedHistogramsWrapper = wrapper_factory( From b1028a3471a131cc7f720c1b6161a08e302a6b89 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 24 Sep 2025 11:17:18 +0200 Subject: [PATCH 085/123] Forward known values to hist hooks. --- columnflow/tasks/framework/mixins.py | 3 ++- columnflow/tasks/plotting.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index 84d8dbb5e..3bbebce9c 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -2531,6 +2531,7 @@ def _get_hist_hook(self, name: str) -> Callable: def invoke_hist_hooks( self, hists: dict[od.Config, dict[od.Process, Any]], + hook_kwargs: dict | None = None, ) -> dict[od.Config, dict[od.Process, Any]]: """ Invoke hooks to modify histograms before further processing such as plotting. @@ -2552,7 +2553,7 @@ def invoke_hist_hooks( # invoke it self.publish_message(f"invoking hist hook '{hook}'") - hists = func(self, hists) + hists = func(self, hists, **(hook_kwargs or {})) return hists diff --git a/columnflow/tasks/plotting.py b/columnflow/tasks/plotting.py index 12638ebe2..6ec08235c 100644 --- a/columnflow/tasks/plotting.py +++ b/columnflow/tasks/plotting.py @@ -236,7 +236,10 @@ def run(self): ) # update histograms using custom hooks - hists = self.invoke_hist_hooks(hists) + hists = self.invoke_hist_hooks( + hists, + hook_kwargs={"category_name": self.branch_data.category, "variable_name": self.branch_data.variable}, + ) # merge configs if len(self.config_insts) != 1: From 1de8924d4e20236328d7dabe82a3bda25957d945 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Fri, 26 Sep 2025 15:35:13 +0200 Subject: [PATCH 086/123] Hook column union, update law. --- columnflow/tasks/union.py | 19 ++++++++++++++++++- modules/law | 2 +- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/columnflow/tasks/union.py b/columnflow/tasks/union.py index 2797537c2..933c7dd53 100644 --- a/columnflow/tasks/union.py +++ b/columnflow/tasks/union.py @@ -4,6 +4,8 @@ Task to unite columns horizontally into a single file for further, possibly external processing. """ +from __future__ import annotations + import luigi import law @@ -13,7 +15,9 @@ from columnflow.tasks.reduction import ReducedEventsUser from columnflow.tasks.production import ProduceColumns from columnflow.tasks.ml import MLEvaluation +from columnflow.columnar_util import Route from columnflow.util import dev_sandbox +from columnflow.types import Callable class _UniteColumns( @@ -47,6 +51,9 @@ class UniteColumns(_UniteColumns): MLEvaluation=MLEvaluation, ) + # a column that is evaluated to decide whether to keep or drop an event before writing + filter_events: str | Route | Callable | None = None + def workflow_requires(self): reqs = super().workflow_requires() @@ -105,7 +112,7 @@ def output(self): @law.decorator.safe_output def run(self): from columnflow.columnar_util import ( - Route, RouteFilter, mandatory_coffea_columns, update_ak_array, sorted_ak_to_parquet, sorted_ak_to_root, + RouteFilter, mandatory_coffea_columns, update_ak_array, sorted_ak_to_parquet, sorted_ak_to_root, ) # prepare inputs and outputs @@ -151,6 +158,16 @@ def run(self): # add additional columns events = update_ak_array(events, *columns) + # optionally filter events + if self.filter_events: + if callable(self.filter_events): + filter_func = self.filter_events + else: + r = Route(self.filter_events) + filter_func = r.apply + mask = filter_func(events) + events = events[mask] + # remove columns events = route_filter(events) diff --git a/modules/law b/modules/law index 0578967c0..b881450a1 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit 0578967c014323324f25006e0ba05a03de429cc6 +Subproject commit b881450a1927bf30c6e504da6ed6f394e7e49b93 From 7caa0fca5b4c883ae1cdc0063a35384b4daedd03 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Fri, 26 Sep 2025 17:45:10 +0200 Subject: [PATCH 087/123] Hotfix default version injection into tasks with same family. --- columnflow/tasks/framework/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index 680724024..e0e791fc7 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -191,8 +191,7 @@ def req_params(cls, inst: AnalysisTask, **kwargs) -> dict[str, Any]: if ( isinstance(getattr(cls, "version", None), luigi.Parameter) and "version" not in kwargs and - not law.parser.global_cmdline_values().get(f"{cls.task_family}_version") and - cls.task_family != law.parser.root_task_cls().task_family + not law.parser.global_cmdline_values().get(f"{cls.task_family}_version") ): default_version = cls.get_default_version(inst, params) if default_version and default_version != law.NO_STR: From 72fde041994c645c5fee1d05bafa20c7b338aec8 Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Mon, 29 Sep 2025 10:22:23 +0200 Subject: [PATCH 088/123] add ParameterTransformation for ratifying + envelope if one-sided --- columnflow/inference/__init__.py | 2 ++ columnflow/inference/cms/datacard.py | 23 ++++++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index a54aa67ee..43277a729 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -95,6 +95,8 @@ class ParameterTransformation(enum.Enum): normalize = "normalize" effect_from_shape = "effect_from_shape" effect_from_rate = "effect_from_rate" + ratify = "ratify" + envelope_if_one_sided = "envelope_if_one_sided" def __str__(self: ParameterTransformation) -> str: """ diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index e94916eca..2dd390f78 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -526,7 +526,28 @@ def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: # apply optional transformations integral = lambda h: h.sum().value for trafo in param_obj.transformations: - if trafo == ParameterTransformation.centralize: + if trafo == ParameterTransformation.ratify: + n, d, u = integral(h_nom), integral(h_down), integral(h_up) + ratio_up = safe_div(u, n) + ratio_down = safe_div(d, n) + h_down = h_nom.copy() * ratio_down + h_up = h_nom.copy() * ratio_up + + elif trafo == ParameterTransformation.envelope_if_one_sided: + n, d, u = integral(h_nom), integral(h_down), integral(h_up) + if (n - d) * (n - u) > 0: + # one-sided effect, use the larger variation + if abs(n - d) > abs(n - u): + # use the down variation with effect flipped + h_up = 2 * h_nom.copy() - h_down.view() + # TODO: better estimate of the variance + h_up.view().variance = h_down.variances() + else: + # use the up variation with effect flipped + h_down = 2 * h_nom.copy() - h_up.view() + h_down.view().variance = h_up.variances() + + elif trafo == ParameterTransformation.centralize: # get the absolute spread based on integrals n, d, u = integral(h_nom), integral(h_down), integral(h_up) if not (min(d, n) <= n <= max(d, n)): From b0342460d9331429aa5638b94f35c6a95b35c771 Mon Sep 17 00:00:00 2001 From: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:15:53 +0200 Subject: [PATCH 089/123] allow removing negative contributions per process (#730) * allow removing negative contributions per process * Update columnflow/plotting/plot_util.py Co-authored-by: Marcel Rieger --------- Co-authored-by: Mathis Frahm Co-authored-by: Marcel Rieger --- columnflow/plotting/plot_functions_1d.py | 12 ++++++++++++ columnflow/plotting/plot_util.py | 9 +++++++++ 2 files changed, 21 insertions(+) diff --git a/columnflow/plotting/plot_functions_1d.py b/columnflow/plotting/plot_functions_1d.py index 1899cbc4d..4c1bf4b60 100644 --- a/columnflow/plotting/plot_functions_1d.py +++ b/columnflow/plotting/plot_functions_1d.py @@ -27,6 +27,7 @@ get_position, get_profile_variations, blind_sensitive_bins, + remove_negative_contributions, join_labels, ) from columnflow.hist_util import add_missing_shifts @@ -64,6 +65,11 @@ def plot_variable_stack( blinding_threshold = kwargs.get("blinding_threshold", None) if blinding_threshold: hists = blind_sensitive_bins(hists, config_inst, blinding_threshold) + + # remove negative contributions per process if requested + if kwargs.get("remove_negative", None): + hists = remove_negative_contributions(hists) + # process scaling hists = apply_process_scaling(hists) # density scaling per bin @@ -175,6 +181,8 @@ def plot_variable_variants( variable_inst = variable_insts[0] hists = apply_variable_settings(hists, variable_insts, variable_settings) + if kwargs.get("remove_negative", None): + hists = remove_negative_contributions(hists) if density: hists = apply_density(hists, density) @@ -245,6 +253,8 @@ def plot_shifted_variable( hists, process_style_config = apply_process_settings(hists, process_settings) hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings) + if kwargs.get("remove_negative", None): + hists = remove_negative_contributions(hists) hists = apply_process_scaling(hists) if density: hists = apply_density(hists, density) @@ -449,6 +459,8 @@ def plot_profile( hists, process_style_config = apply_process_settings(hists, process_settings) hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings) + if kwargs.get("remove_negative", None): + hists = remove_negative_contributions(hists) hists = apply_process_scaling(hists) if density: hists = apply_density(hists, density) diff --git a/columnflow/plotting/plot_util.py b/columnflow/plotting/plot_util.py index c2088414f..3f09b724e 100644 --- a/columnflow/plotting/plot_util.py +++ b/columnflow/plotting/plot_util.py @@ -314,6 +314,15 @@ def apply_variable_settings( return hists, variable_style_config +def remove_negative_contributions(hists: dict[Hashable, hist.Hist]) -> dict[Hashable, hist.Hist]: + _hists = hists.copy() + for proc_inst, h in hists.items(): + h = h.copy() + h.view().value[h.view().value < 0] = 0 + _hists[proc_inst] = h + return _hists + + def use_flow_bins( h_in: hist.Hist, axis_name: str | int, From b0b1aa76856c38531dad7893442b0c9d83a447ac Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Mon, 29 Sep 2025 15:20:17 +0200 Subject: [PATCH 090/123] Apply suggestion from @riga --- columnflow/inference/cms/datacard.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index 2dd390f78..30d313e0d 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -526,14 +526,7 @@ def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: # apply optional transformations integral = lambda h: h.sum().value for trafo in param_obj.transformations: - if trafo == ParameterTransformation.ratify: - n, d, u = integral(h_nom), integral(h_down), integral(h_up) - ratio_up = safe_div(u, n) - ratio_down = safe_div(d, n) - h_down = h_nom.copy() * ratio_down - h_up = h_nom.copy() * ratio_up - - elif trafo == ParameterTransformation.envelope_if_one_sided: + if trafo == ParameterTransformation.envelope_if_one_sided: n, d, u = integral(h_nom), integral(h_down), integral(h_up) if (n - d) * (n - u) > 0: # one-sided effect, use the larger variation From b66573ec3932bde3b39448cee9c53da72bde2afd Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Mon, 29 Sep 2025 15:20:24 +0200 Subject: [PATCH 091/123] Apply suggestion from @riga --- columnflow/inference/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index 43277a729..070f4a159 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -95,7 +95,6 @@ class ParameterTransformation(enum.Enum): normalize = "normalize" effect_from_shape = "effect_from_shape" effect_from_rate = "effect_from_rate" - ratify = "ratify" envelope_if_one_sided = "envelope_if_one_sided" def __str__(self: ParameterTransformation) -> str: From b4a3a6cddb3c6661fafbcc9d8b10ddabfe22b9fe Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Mon, 29 Sep 2025 15:52:34 +0200 Subject: [PATCH 092/123] Sequentialize and optimize datacard writing. (#728) * Sequentialize and optimize datacard writing. * Adjust tests. * Update law. * Minor nuisance group fix. * Loop over variables, eagerly clear memory. --- columnflow/inference/__init__.py | 33 +-- columnflow/tasks/cms/inference.py | 267 ++++++++++++++-------- columnflow/tasks/framework/inference.py | 292 +++++++++++++----------- tests/test_inference.py | 31 --- 4 files changed, 336 insertions(+), 287 deletions(-) diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index 070f4a159..bbd268a2c 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -573,35 +573,6 @@ def parameter_config_spec( ("shift_source", str(shift_source) if shift_source else None), ]) - @classmethod - def require_shapes_for_parameter(self, param_obj: dict) -> bool: - """ - Function to check if for a certain parameter object *param_obj* varied - shapes are needed. - - :param param_obj: The parameter object to check. - :returns: *True* if varied shapes are needed, *False* otherwise. - """ - if param_obj.type.is_shape: - # the shape might be build from a rate, in which case input shapes are not required - if param_obj.transformations.any_from_rate: - return False - # in any other case, shapes are required - return True - - if param_obj.type.is_rate: - # when the rate effect is extracted from shapes, they are required - if param_obj.transformations.any_from_shape: - return True - # in any other case, shapes are not required - return False - - # other cases are not supported - raise Exception( - f"shape requirement cannot be evaluated for parameter '{param_obj.name}' with type " + - f"'{param_obj.type}' and transformations {param_obj.transformations}", - ) - def __init__(self, config_insts: list[od.Config]) -> None: super().__init__() @@ -1336,8 +1307,8 @@ def add_parameter( for process in _processes: process.parameters.append(_copy.deepcopy(parameter)) - # add to groups - if group: + # add to groups if it was added to at least one process + if group and processes and any(_processes for _processes in processes.values()): self.add_parameter_to_group(parameter.name, group) return parameter diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index 6414a23ba..ca8bb633b 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -6,6 +6,8 @@ from __future__ import annotations +import collections + import law import order as od @@ -19,122 +21,197 @@ class CreateDatacards(SerializeInferenceModelBase): resolution_task_cls = MergeHistograms - datacard_writer_cls = DatacardWriter def output(self): - hooks_repr = self.hist_hooks_repr - cat_obj = self.branch_data - - def basename(name: str, ext: str) -> str: + def basename(cat_obj, name, ext): parts = [name, cat_obj.name] - if hooks_repr: + if (hooks_repr := self.hist_hooks_repr): parts.append(f"hooks_{hooks_repr}") if cat_obj.postfix is not None: parts.append(cat_obj.postfix) return f"{'__'.join(map(str, parts))}.{ext}" - return { - "card": self.target(basename("datacard", "txt")), - "shapes": self.target(basename("shapes", "root")), - } + return law.SiblingFileCollection({ + cat_obj.name: { + "card": self.target(basename(cat_obj, "datacard", "txt")), + "shapes": self.target(basename(cat_obj, "shapes", "root")), + } + for cat_obj in self.inference_model_inst.categories + }) @law.decorator.log @law.decorator.safe_output def run(self): import hist - # prepare inputs + # prepare inputs and outputs inputs = self.input() + outputs = self.output() - # loop over all configs required by the datacard category and gather histograms - cat_obj = self.branch_data - datacard_hists: DatacardHists = {cat_obj.name: {}} - - # step 1: gather histograms per process for each config - input_hists: dict[od.Config, dict[od.Process, hist.Hist]] = {} - for config_inst in self.config_insts: - # skip configs that are not required - if not cat_obj.config_data.get(config_inst.name): - continue - # load them - input_hists[config_inst] = self.load_process_hists(inputs, cat_obj, config_inst) - - # step 2: apply hist hooks - input_hists = self.invoke_hist_hooks(input_hists) - - # step 3: transform to nested histogram as expected by the datacard writer - for config_inst in input_hists.keys(): - config_data = cat_obj.config_data.get(config_inst.name) - - # determine leaf categories to gather - category_inst = config_inst.get_category(config_data.category) - leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] - - # start the transformation - proc_objs = list(cat_obj.processes) - if config_data.data_datasets and not cat_obj.data_from_processes: - proc_objs.append(self.inference_model_inst.process_spec(name="data")) - for proc_obj in proc_objs: - # get the corresponding process instance - if proc_obj.name == "data": - process_inst = config_inst.get_process("data") - elif config_inst.name in proc_obj.config_data: - process_inst = config_inst.get_process(proc_obj.config_data[config_inst.name].process) - else: - # skip process objects that rely on data from a different config - continue - - # extract the histogram for the process - if not (h_proc := input_hists[config_inst].get(process_inst, None)): - self.logger.warning( - f"found no histogram to model datacard process '{proc_obj.name}', please check your " - f"inference model '{self.inference_model}'", - ) + # overall strategy to load data efficiently and to write datacards: + # 1) determine which variables have to be loaded for which config (stored in a map), then loop over variables + # 2) load all histograms per config + # 3) start datacard writing by looping over datacard categories that use the specific variable + # 4) apply hist hooks + # 5) prepare histogram in the nested format expected by the datacard writer and write the card + + # step 1: gather variable info, then loop + variable_data = collections.defaultdict(set) + for config_inst, data in self.combined_config_data.items(): + for variable in data["variables"]: + variable_data[variable].add(config_inst) + + for variable, variable_config_insts in variable_data.items(): + # step 2 + input_hists: dict[od.Config, dict[od.Process, hist.Hist]] = {} + for config_inst in variable_config_insts: + data = self.combined_config_data[config_inst] + input_hists[config_inst] = self.load_process_hists( + config_inst, + list(data["mc_datasets"]) + list(data["data_datasets"]), + variable, + inputs[config_inst.name], + ) + + # step 3 + for cat_obj in self.inference_model_inst.categories: + # skip if the variable is not used in this category + if not any(d.variable == variable for d in cat_obj.config_data.values()): continue - - # select relevant categories - h_proc = h_proc[{ - "category": [ - hist.loc(c.name) - for c in leaf_category_insts - if c.name in h_proc.axes["category"] - ], - }][{"category": sum}] - - # create the nominal hist - datacard_hists[cat_obj.name].setdefault(proc_obj.name, {}).setdefault(config_inst.name, {}) - shift_hists: ShiftHists = datacard_hists[cat_obj.name][proc_obj.name][config_inst.name] - shift_hists["nominal"] = h_proc[{ - - "shift": hist.loc(config_inst.get_shift("nominal").name), - }] - - # no additional shifts need to be created for data - if proc_obj.name == "data": + # cross check that all configs use the same variable (should already be guarded by the model validation) + assert all(d.variable == variable for d in cat_obj.config_data.values()) + + # check which configs contribute to this category + config_insts = [ + config_inst for config_inst in self.config_insts + if config_inst.name in cat_obj.config_data + ] + if not config_insts: continue - - # create histograms per shift - for param_obj in proc_obj.parameters: - # skip the parameter when varied hists are not needed - if not self.inference_model_inst.require_shapes_for_parameter(param_obj): - continue - # store the varied hists - shift_source = ( - param_obj.config_data[config_inst.name].shift_source - if config_inst.name in param_obj.config_data - else None - ) - for d in ["up", "down"]: - shift_hists[(param_obj.name, d)] = h_proc[{ - "shift": hist.loc(f"{shift_source}_{d}" if shift_source else "nominal"), + self.publish_message(f"processing inputs for category '{cat_obj.name}' with variable '{variable}'") + + # get config-based category name + category = cat_obj.config_data[config_insts[0].name].category + + # step 4: hist hooks + _input_hists = self.invoke_hist_hooks( + {config_inst: input_hists[config_inst].copy() for config_inst in config_insts}, + hook_kwargs={"variable_name": variable, "category_name": category}, + ) + + # step 5: transform to datacard format + datacard_hists: DatacardHists = {cat_obj.name: {}} + for config_inst in _input_hists.keys(): + config_data = cat_obj.config_data.get(config_inst.name) + + # determine leaf categories to gather + category_inst = config_inst.get_category(category) + leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] + + # eagerly remove data histograms in case data is supposed to be faked from mc processes + if cat_obj.data_from_processes: + for process_inst in list(_input_hists[config_inst]): + if process_inst.is_data: + del _input_hists[config_inst][process_inst] + + # start the transformation + proc_objs = list(cat_obj.processes) + if config_data.data_datasets and not cat_obj.data_from_processes: + proc_objs.append(self.inference_model_inst.process_spec(name="data")) + for proc_obj in proc_objs: + # get all process instances (keys in _input_hists) to be combined + if proc_obj.is_dynamic: + if not (process_name := proc_obj.config_data[config_inst.name].get("process", None)): + raise ValueError( + f"dynamic datacard process object misses 'process' entry in config data for " + f"'{config_inst.name}': {proc_obj}", + ) + process_insts = [config_inst.get_process(process_name)] + else: + process_insts = [ + config_inst.get_dataset(dataset_name).processes.get_first() + for dataset_name in proc_obj.config_data[config_inst.name].mc_datasets + ] + + # collect per-process histograms + h_procs = [] + for process_inst in process_insts: + # extract the histogram for the process + # (removed from hists to eagerly cleanup memory) + h_proc = _input_hists[config_inst].pop(process_inst, None) + if h_proc is None: + self.logger.error( + f"found no histogram to model datacard process '{proc_obj.name}', please check your " + f"inference model '{self.inference_model}'", + ) + continue + + # select relevant categories + h_proc = h_proc[{ + "category": [ + hist.loc(c.name) + for c in leaf_category_insts + if c.name in h_proc.axes["category"] + ], + }] + h_proc = h_proc[{"category": sum}] + + h_procs.append(h_proc) + + if h_procs is None: + continue + + # combine them + h_proc = sum(h_procs[1:], h_procs[0].copy()) + + # create the nominal hist + datacard_hists[cat_obj.name].setdefault(proc_obj.name, {}).setdefault(config_inst.name, {}) + shift_hists: ShiftHists = datacard_hists[cat_obj.name][proc_obj.name][config_inst.name] + shift_hists["nominal"] = h_proc[{ + "shift": hist.loc(config_inst.get_shift("nominal").name), }] - # forward objects to the datacard writer - outputs = self.output() - writer = self.datacard_writer_cls(self.inference_model_inst, datacard_hists) - with outputs["card"].localize("w") as tmp_card, outputs["shapes"].localize("w") as tmp_shapes: - writer.write(tmp_card.abspath, tmp_shapes.abspath, shapes_path_ref=outputs["shapes"].basename) + # no additional shifts need to be created for data + if proc_obj.name == "data": + continue + + # create histograms per shape shift + for param_obj in proc_obj.parameters: + # skip the parameter when varied hists are not needed + if ( + not param_obj.type.is_shape and + not any(trafo.from_shape for trafo in param_obj.transformations) + ): + continue + # store the varied hists + shift_source = ( + param_obj.config_data[config_inst.name].shift_source + if config_inst.name in param_obj.config_data + else None + ) + for d in ["up", "down"]: + if shift_source and f"{shift_source}_{d}" not in h_proc.axes["shift"]: + raise ValueError( + f"cannot find '{shift_source}_{d}' in shift axis of histogram for process " + f"'{proc_obj.name}' in config '{config_inst.name}' while handling parameter " + f"'{param_obj.name}' in datacard category '{cat_obj.name}', available shifts " + f"are: {list(h_proc.axes['shift'])}", + ) + shift_hists[(param_obj.name, d)] = h_proc[{ + "shift": hist.loc(f"{shift_source}_{d}" if shift_source else "nominal"), + }] + + # forward objects to the datacard writer + outp = outputs[cat_obj.name] + writer = self.datacard_writer_cls(self.inference_model_inst, datacard_hists) + with outp["card"].localize("w") as tmp_card, outp["shapes"].localize("w") as tmp_shapes: + writer.write(tmp_card.abspath, tmp_shapes.abspath, shapes_path_ref=outp["shapes"].basename) + self.publish_message(f"datacard written to {outp['card'].abspath}") + + # eager cleanup + del _input_hists + del input_hists CreateDatacardsWrapper = wrapper_factory( diff --git a/columnflow/tasks/framework/inference.py b/columnflow/tasks/framework/inference.py index 6c202aa31..18b3679bc 100644 --- a/columnflow/tasks/framework/inference.py +++ b/columnflow/tasks/framework/inference.py @@ -6,6 +6,8 @@ from __future__ import annotations +import pickle + import law import order as od @@ -15,7 +17,7 @@ InferenceModelMixin, HistHookMixin, MLModelsMixin, ) from columnflow.tasks.framework.remote import RemoteWorkflow -from columnflow.tasks.histograms import MergeHistograms, MergeShiftedHistograms +from columnflow.tasks.histograms import MergeShiftedHistograms from columnflow.util import dev_sandbox, DotDict, maybe_import from columnflow.config_util import get_datasets_from_process @@ -42,7 +44,6 @@ class SerializeInferenceModelBase( # upstream requirements reqs = Requirements( RemoteWorkflow.reqs, - MergeHistograms=MergeHistograms, MergeShiftedHistograms=MergeShiftedHistograms, ) @@ -106,153 +107,184 @@ def get_data_datasets(cls, config_inst: od.Config, cat_obj: DotDict) -> list[str ) ] - def create_branch_map(self): - return list(self.inference_model_inst.categories) + @law.workflow_property(cache=True) + def combined_config_data(self) -> dict[od.ConfigInst, dict[str, dict | set]]: + # prepare data extracted from the inference model + config_data = { + config_inst: { + # all variables used in this config in any datacard category + "variables": set(), + # plain set of names of real data datasets + "data_datasets": set(), + # per name of mc dataset, the set of shift sources and the name of the datacard process object + "mc_datasets": {}, + } + for config_inst in self.config_insts + } + + # iterate over all model categories + for cat_obj in self.inference_model_inst.categories: + # keep track of per-category information for consistency checks + variables = set() + categories = set() + + # iterate over configs relevant for this category + config_insts = [config_inst for config_inst in self.config_insts if config_inst.name in cat_obj.config_data] + for config_inst in config_insts: + data = config_data[config_inst] + + # variables + data["variables"].add(cat_obj.config_data[config_inst.name].variable) + + # data datasets, but only if + # - data in that category is not faked from mc processes, or + # - at least one process object is dynamic (that usually means data-driven) + if not cat_obj.data_from_processes or any(proc_obj.is_dynamic for proc_obj in cat_obj.processes): + data["data_datasets"].update(self.get_data_datasets(config_inst, cat_obj)) + + # mc datasets over all process objects + # - the process is not dynamic + for proc_obj in cat_obj.processes: + mc_datasets = self.get_mc_datasets(config_inst, proc_obj) + for dataset_name in mc_datasets: + if dataset_name not in data["mc_datasets"]: + data["mc_datasets"][dataset_name] = { + "proc_name": proc_obj.name, + "shift_sources": set(), + } + elif data["mc_datasets"][dataset_name]["proc_name"] != proc_obj.name: + raise ValueError( + f"dataset '{dataset_name}' was already assigned to datacard process " + f"'{data['mc_datasets'][dataset_name]['proc_name']}' and cannot be re-assigned to " + f"'{proc_obj.name}' in config '{config_inst.name}'", + ) - def _requires_cat_obj(self, cat_obj: DotDict, merge_variables: bool = False, **req_kwargs): - """ - Helper to create the requirements for a single category object. + # shift sources + for param_obj in proc_obj.parameters: + if config_inst.name not in param_obj.config_data: + continue + # only add if a shift is required for this parameter + if param_obj.type.is_shape or any(trafo.from_shape for trafo in param_obj.transformations): + shift_source = param_obj.config_data[config_inst.name].shift_source + for mc_dataset in mc_datasets: + data["mc_datasets"][mc_dataset]["shift_sources"].add(shift_source) + + # for consistency checks later + variables.add(cat_obj.config_data[config_inst.name].variable) + categories.add(cat_obj.config_data[config_inst.name].category) + + # consistency checks: the config-based variable and category names must be identical + if len(variables) != 1: + raise ValueError( + f"found diverging variables to be used in datacard category '{cat_obj.name}' across configs " + f"{', '.join(c.name for c in config_insts)}: {variables}", + ) + if len(categories) != 1: + raise ValueError( + f"found diverging categories to be used in datacard category '{cat_obj.name}' across configs " + f"{', '.join(c.name for c in config_insts)}: {categories}", + ) - :param cat_obj: category object from an InferenceModel - :param merge_variables: whether to merge the variables from all requested category objects - :return: requirements for the category object - """ + return config_data + + def create_branch_map(self): + # dummy branch map + return {0: None} + + def _hist_requirements(self, **kwargs): + # gather data from inference model to define requirements in the structure + # config_name -> dataset_name -> MergeHistogramsTask reqs = {} - for config_inst in self.config_insts: - if not (config_data := cat_obj.config_data.get(config_inst.name)): - continue - - if merge_variables: - variables = tuple( - _cat_obj.config_data.get(config_inst.name).variable - for _cat_obj in self.branch_map.values() + for config_inst, data in self.combined_config_data.items(): + reqs[config_inst.name] = {} + # mc datasets + for dataset_name in sorted(data["mc_datasets"]): + reqs[config_inst.name][dataset_name] = self.reqs.MergeShiftedHistograms.req_different_branching( + self, + config=config_inst.name, + dataset=dataset_name, + shift_sources=tuple(sorted(data["mc_datasets"][dataset_name]["shift_sources"])), + variables=tuple(sorted(data["variables"])), + **kwargs, + ) + # data datasets + for dataset_name in sorted(data["data_datasets"]): + reqs[config_inst.name][dataset_name] = self.reqs.MergeShiftedHistograms.req_different_branching( + self, + config=config_inst.name, + dataset=dataset_name, + shift_sources=(), + variables=tuple(sorted(data["variables"])), + **kwargs, ) - else: - variables = (config_data.variable,) - - # add merged shifted histograms for mc - reqs[config_inst.name] = { - proc_obj.name: { - dataset: self.reqs.MergeShiftedHistograms.req_different_branching( - self, - config=config_inst.name, - dataset=dataset, - shift_sources=tuple( - param_obj.config_data[config_inst.name].shift_source - for param_obj in proc_obj.parameters - if ( - config_inst.name in param_obj.config_data and - self.inference_model_inst.require_shapes_for_parameter(param_obj) - ) - ), - variables=variables, - **req_kwargs, - ) - for dataset in self.get_mc_datasets(config_inst, proc_obj) - } - for proc_obj in cat_obj.processes - if config_inst.name in proc_obj.config_data and not proc_obj.is_dynamic - } - # add merged histograms for data, but only if - # - data in that category is not faked from mc, or - # - at least one process object is dynamic (that usually means data-driven) - if ( - (not cat_obj.data_from_processes or any(proc_obj.is_dynamic for proc_obj in cat_obj.processes)) and - (data_datasets := self.get_data_datasets(config_inst, cat_obj)) - ): - reqs[config_inst.name]["data"] = { - dataset: self.reqs.MergeHistograms.req_different_branching( - self, - config=config_inst.name, - dataset=dataset, - variables=variables, - **req_kwargs, - ) - for dataset in data_datasets - } return reqs def workflow_requires(self): reqs = super().workflow_requires() - - reqs["merged_hists"] = hist_reqs = {} - for cat_obj in self.branch_map.values(): - cat_reqs = self._requires_cat_obj(cat_obj, merge_variables=True) - for config_name, proc_reqs in cat_reqs.items(): - hist_reqs.setdefault(config_name, {}) - for proc_name, dataset_reqs in proc_reqs.items(): - hist_reqs[config_name].setdefault(proc_name, {}) - for dataset_name, task in dataset_reqs.items(): - hist_reqs[config_name][proc_name].setdefault(dataset_name, set()).add(task) + reqs["merged_hists"] = self._hist_requirements() return reqs def requires(self): - cat_obj = self.branch_data - return self._requires_cat_obj(cat_obj, branch=-1, workflow="local") + return self._hist_requirements(branch=-1, workflow="local") def load_process_hists( self, - inputs: dict, - cat_obj: DotDict, config_inst: od.Config, - ) -> dict[od.Process, hist.Hist]: - # loop over all configs required by the datacard category and gather histograms - config_data = cat_obj.config_data.get(config_inst.name) - - # collect histograms per config process + dataset_names: list[str], + variable: str, + inputs: dict, + ) -> dict[str, dict[od.Process, hist.Hist]]: + # collect histograms per variable and process hists: dict[od.Process, hist.Hist] = {} - with self.publish_step( - f"extracting {config_data.variable} in {config_data.category} for config {config_inst.name}...", - ): - for proc_obj_name, inp in inputs[config_inst.name].items(): - if proc_obj_name == "data": + + with self.publish_step(f"extracting '{variable}' for config {config_inst.name} ..."): + for dataset_name in dataset_names: + dataset_inst = config_inst.get_dataset(dataset_name) + process_inst = dataset_inst.processes.get_first() + + # for real data, fallback to the main data process + if process_inst.is_data: process_inst = config_inst.get_process("data") - else: - proc_obj = self.inference_model_inst.get_process(proc_obj_name, category=cat_obj.name) - process_inst = config_inst.get_process(proc_obj.config_data[config_inst.name].process) + + # gather all subprocesses for a full query later sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] - # loop over per-dataset inputs and extract histograms containing the process - h_proc = None - for dataset_name, _inp in inp.items(): - dataset_inst = config_inst.get_dataset(dataset_name) - - # skip when the dataset is already known to not contain any sub process - if not any(map(dataset_inst.has_process, sub_process_insts)): - self.logger.warning( - f"dataset '{dataset_name}' does not contain process '{process_inst.name}' or any of " - "its subprocesses which indicates a misconfiguration in the inference model " - f"'{self.inference_model}'", - ) - continue - - # open the histogram and work on a copy - h = _inp["collection"][0]["hists"][config_data.variable].load(formatter="pickle").copy() - - # axis selections - h = h[{ - "process": [ - hist.loc(p.name) - for p in sub_process_insts - if p.name in h.axes["process"] - ], - }] - - # axis reductions - h = h[{"process": sum}] - - # add the histogram for this dataset - if h_proc is None: - h_proc = h - else: - h_proc += h - - # there must be a histogram - if h_proc is None: - raise Exception(f"no histograms found for process '{process_inst.name}'") - - # save histograms mapped to processes - hists[process_inst] = h_proc + # open the histogram and work on a copy + inp = inputs[dataset_name]["collection"][0]["hists"][variable] + try: + h = inp.load(formatter="pickle").copy() + except pickle.UnpicklingError as e: + raise Exception( + f"failed to load '{variable}' histogram for dataset '{dataset_name}' in config " + f"'{config_inst.name}' from {inp.abspath}", + ) from e + + # there must be at least one matching sub process + if not any(p.name in h.axes["process"] for p in sub_process_insts): + raise Exception(f"no '{variable}' histograms found for process '{process_inst.name}'") + + # select and reduce over relevant processes + h = h[{"process": [hist.loc(p.name) for p in sub_process_insts if p.name in h.axes["process"]]}] + h = h[{"process": sum}] + + # additional custom reductions + h = self.modify_process_hist(process_inst, h) + + # store it + if process_inst in hists: + hists[process_inst] += h + else: + hists[process_inst] = h return hists + + def modify_process_hist(self, process_inst: od.Process, h: hist.Hist) -> hist.Hist: + """ + Hook to modify a process histogram after it has been loaded. This can be helpful to reduce memory early on. + + :param process_inst: The process instance the histogram belongs to. + :param histo: The histogram to modify. + :return: The modified histogram. + """ + return h diff --git a/tests/test_inference.py b/tests/test_inference.py index 5f1bde4b4..ad14549b8 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -177,34 +177,3 @@ def test_parameter_group_spec_with_no_parameter_names(self): ) self.assertDictEqual(result, expected_result) - - def test_require_shapes_for_parameter_shape(self): - # No shape is required if the parameter type is a rate - types = [ParameterType.rate_gauss, ParameterType.rate_uniform, ParameterType.rate_unconstrained] - for t in types: - with self.subTest(t=t): - param_obj = DotDict( - type=t, - transformations=ParameterTransformations([ParameterTransformation.effect_from_rate]), - name="test_param", - ) - result = InferenceModel.require_shapes_for_parameter(param_obj) - self.assertFalse(result) - - # if the transformation is shape-based expect True - param_obj.transformations = ParameterTransformations([ParameterTransformation.effect_from_shape]) - result = InferenceModel.require_shapes_for_parameter(param_obj) - self.assertTrue(result) - - # No shape is required if the transformation is from a rate - param_obj = DotDict( - type=ParameterType.shape, - transformations=ParameterTransformations([ParameterTransformation.effect_from_rate]), - name="test_param", - ) - result = InferenceModel.require_shapes_for_parameter(param_obj) - self.assertFalse(result) - - param_obj.transformations = ParameterTransformations([ParameterTransformation.effect_from_shape]) - result = InferenceModel.require_shapes_for_parameter(param_obj) - self.assertTrue(result) From 7216e56fb4a94a9de5591dd5d4c8be1157508222 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Tue, 30 Sep 2025 17:41:09 +0200 Subject: [PATCH 093/123] Refactor datacard parameter transformations. --- columnflow/inference/__init__.py | 54 ++-- columnflow/inference/cms/datacard.py | 355 +++++++++++++++++------- columnflow/tasks/cms/inference.py | 9 +- columnflow/tasks/framework/inference.py | 5 +- 4 files changed, 301 insertions(+), 122 deletions(-) diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index bbd268a2c..e6f4d3315 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -40,7 +40,7 @@ class ParameterType(enum.Enum): rate_unconstrained = "rate_unconstrained" shape = "shape" - def __str__(self: ParameterType) -> str: + def __str__(self) -> str: """ Returns the string representation of the parameter type. @@ -49,7 +49,7 @@ def __str__(self: ParameterType) -> str: return self.value @property - def is_rate(self: ParameterType) -> bool: + def is_rate(self) -> bool: """ Checks if the parameter type is a rate type. @@ -62,7 +62,7 @@ def is_rate(self: ParameterType) -> bool: } @property - def is_shape(self: ParameterType) -> bool: + def is_shape(self) -> bool: """ Checks if the parameter type is a shape type. @@ -77,27 +77,48 @@ class ParameterTransformation(enum.Enum): """ Flags denoting transformations to be applied on parameters. + Implementation details depend on the routines that apply these transformations, usually as part for a serialization + processes (such as so-called "datacards" in the CMS context). As such, the exact implementation may also differ + depending on the type of the parameter that a transformation is applied to (e.g. shape vs rate). + + The general purpose of each transformation is described below. + :cvar none: No transformation. - :cvar centralize: Centralize the parameter. - :cvar symmetrize: Symmetrize the parameter. - :cvar asymmetrize: Asymmetrize the parameter. - :cvar asymmetrize_if_large: Asymmetrize the parameter if it is large. - :cvar normalize: Normalize the parameter. - :cvar effect_from_shape: Derive effect from shape. - :cvar effect_from_rate: Derive effect from rate. + :cvar effect_from_rate: Creates shape variations for a shape-type parameter using the single- or two-valued effect + usually attributed to rate-type parameters. Only applies to shape-type parameters. + :cvar effect_from_shape: Derive the effect of a rate-type parameter using the overall, integral effect of shape + variations. Only applies to rate-type parameters. + :cvar effect_from_shape_if_small: Same as :py:attr:`effect_from_shape`, but depending on a threshold on the size of + the effect which can be subject to the serialization routine. Only applies to rate-type parameters. + :cvar symmetrize: The overall (integral) effect of up and down variations is measured and centralized, updating the + variations such that they are equidistant to the nominal one. Can apply to both rate- and shape-type parameters. + :cvar asymmetrize: The symmetric effect on a rate-type parameter (usually given as a single value) is converted into + an asymmetric representation (using two values). Only applies to rate-type parameters. + :cvar asymmetrize_if_large: Same as :py:attr:`asymmetrize`, but depending on a threshold on the size of the + symmetric effect which can be subject to the serialization routine. Only applies to rate-type parameters. + :cvar normalize: Variations of shape-type parameters are changed such that their integral effect identical to the + nominal one. Should only apply to shape-type parameters. + :cvar envelope: Builds an evelope of the up and down variations of a shape-type parameter, potentially on a + bin-by-bin basis. Only applies to shape-type parameters. + :cvar envelope_if_one_sided: Same as :py:attr:`envelope`, but only if the shape variations are one-sided following + a definition that can be subject to the serialization routine. Only applies to shape-type parameters. + :cvar envelope_enforce_two_sided: Same as :py:attr:`envelope`, but it enforces that the up (down) variation of the + constructed envelope is always above (below) the nominal one. Only applies to shape-type parameters. """ none = "none" - centralize = "centralize" + effect_from_rate = "effect_from_rate" + effect_from_shape = "effect_from_shape" + effect_from_shape_if_small = "effect_from_shape_if_small" symmetrize = "symmetrize" asymmetrize = "asymmetrize" asymmetrize_if_large = "asymmetrize_if_large" normalize = "normalize" - effect_from_shape = "effect_from_shape" - effect_from_rate = "effect_from_rate" + envelope = "envelope" envelope_if_one_sided = "envelope_if_one_sided" + envelope_enforce_two_sided = "envelope_enforce_two_sided" - def __str__(self: ParameterTransformation) -> str: + def __str__(self) -> str: """ Returns the string representation of the parameter transformation. @@ -106,7 +127,7 @@ def __str__(self: ParameterTransformation) -> str: return self.value @property - def from_shape(self: ParameterTransformation) -> bool: + def from_shape(self) -> bool: """ Checks if the transformation is derived from shape. @@ -114,10 +135,11 @@ def from_shape(self: ParameterTransformation) -> bool: """ return self in { self.effect_from_shape, + self.effect_from_shape_if_small, } @property - def from_rate(self: ParameterTransformation) -> bool: + def from_rate(self) -> bool: """ Checks if the transformation is derived from rate. diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index 30d313e0d..d22412c0a 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -41,17 +41,110 @@ class DatacardWriter(object): At the moment, all shapes are written into the same root file and a shape line with wildcards for both bin and process resolution is created. + + As per the definition in :py:class:`ParameterTransformation`, the following parameter effect transormations are + implemented with the following details. + + - :py:attr:`ParameterTransformation.effect_from_rate`: Creates shape variations from a rate-style effect. + Shape-type parameters only. + - :py:attr:`ParameterTransformation.effect_from_shape`: Converts the integral effect of shape variations to an + asymmetric rate-style effect. Rate-type parameters only. + - :py:attr:`ParameterTransformation.effect_from_shape_if_small`: Same as above with a default threshold of 2%. + Configurable via *effect_from_shape_if_small_threshold*. The parameter should initially be of rate-type, but + in case the threshold is not met, the effect is interpreted as shape-type. + - :py:attr:`ParameterTransformation.symmetrize`: Changes up and down variations of either rate effects and + shapes to symmetrize them around the nominal value. For rate-type parameters, this has no effect if the + effect strength was provided by a single value. There is no conversion into a single value and consequently, + the result is always a two-valued effect. + - :py:attr:`ParameterTransformation.asymmetrize`: Converts single-valued to two-valued effects for rate-style + parameters. + - :py:attr:`ParameterTransformation.asymmetrize_if_large`: Same as above, with a default threshold of 20%. + Configurable via *asymmetrize_if_large_threshold*. + - :py:attr:`ParameterTransformation.normalize`: Normalizes shape variations such that their integrals match that + of the nominal shape. + - :py:attr:`ParameterTransformation.envelope`: Takes the bin-wise maximum in each direction of the up and down + variations of shape-type parameters and constructs new shapes. + - :py:attr:`ParameterTransformation.envelope_if_one_sided`: Same as above, but only in bins where up and down + contributions are one-sided. + - :py:attr:`ParameterTransformation.envelope_enforce_two_sided`: Same as :py:attr:`envelope`, but it enforces + that the up (down) variation of the constructed envelope is always above (below) the nominal one. + + .. note:: + + If used, the transformations :py:attr:`ParameterTransformation.effect_from_rate`, + :py:attr:`ParameterTransformation.effect_from_shape`, and + :py:attr:`ParameterTransformation.effect_from_shape_if_small` must be the first element in the sequence of + transformations to be applied. The remaining transformations are applied in order based on the outcome of the + effect conversion. """ # minimum separator between columns col_sep = " " + # specific sets of transformations + first_index_trafos = { + ParameterTransformation.effect_from_rate, + ParameterTransformation.effect_from_shape, + ParameterTransformation.effect_from_shape_if_small, + } + shape_only_trafos = { + ParameterTransformation.effect_from_rate, + ParameterTransformation.normalize, + ParameterTransformation.envelope, + ParameterTransformation.envelope_if_one_sided, + ParameterTransformation.envelope_enforce_two_sided, + } + rate_only_trafos = { + ParameterTransformation.effect_from_shape, + ParameterTransformation.effect_from_shape_if_small, + ParameterTransformation.asymmetrize, + ParameterTransformation.asymmetrize_if_large, + } + + @classmethod + def validate_model(cls, inference_model_inst: InferenceModel, silent: bool = False) -> bool: + # perform parameter checks one after another, collect errors along the way + errors: list[str] = [] + for cat_name, proc_name, param_obj in inference_model_inst.iter_parameters(): + # check the transformations + _errors: list[str] = [] + for i, trafo in enumerate(param_obj.transformations): + if i != 0 and trafo in cls.first_index_trafos: + _errors.append( + f"parameter transformation '{trafo}' must be the first one to apply, but found at index {i}", + ) + if not param_obj.type.is_shape and trafo in cls.shape_only_trafos: + _errors.append( + f"parameter transformation '{trafo}' only applies to shape-type parameters, but found type " + f"'{param_obj.type}'", + ) + if not param_obj.type.is_rate and trafo in cls.rate_only_trafos: + _errors.append( + f"parameter transformation '{trafo}' only applies to rate-type parameters, but found type " + f"'{param_obj.type}'", + ) + errors.extend( + f"for parameter '{param_obj}' in process '{proc_name}' in category '{cat_name}': {err}" + for err in _errors + ) + + # handle errors + if errors: + if silent: + return False + errors_repr = "\n - ".join(errors) + raise ValueError(f"inference model invalid, reasons:\n - {errors_repr}") + + return True + def __init__( self, inference_model_inst: InferenceModel, histograms: DatacardHists, rate_precision: int = 4, effect_precision: int = 4, + effect_from_shape_if_small_threshold: float = 0.02, + asymmetrize_if_large_threshold: float = 0.2, ) -> None: super().__init__() @@ -60,6 +153,11 @@ def __init__( self.histograms = histograms self.rate_precision = rate_precision self.effect_precision = effect_precision + self.effect_from_shape_if_small_threshold = effect_from_shape_if_small_threshold + self.asymmetrize_if_large_threshold = asymmetrize_if_large_threshold + + # validate the inference model + self.validate_model(self.inference_model_inst) def write( self, @@ -152,10 +250,10 @@ def write( # tabular-style parameters blocks.tabular_parameters = [] for param_name in self.inference_model_inst.get_parameters(flat=True): - param_obj = None + types = set() effects = [] for cat_name, proc_name in flat_rates: - _param_obj = self.inference_model_inst.get_parameter( + param_obj = self.inference_model_inst.get_parameter( param_name, category=cat_name, process=proc_name, @@ -163,52 +261,41 @@ def write( ) # skip line-style parameters as they are handled separately below - if _param_obj and _param_obj.type == ParameterType.rate_unconstrained: + if param_obj and param_obj.type == ParameterType.rate_unconstrained: continue # empty effect - if _param_obj is None: + if param_obj is None: effects.append("-") continue - # compare with previous param_obj - if param_obj is None: - param_obj = _param_obj - elif _param_obj.type != param_obj.type: + # compare with previously seen types as combine cannot mix arbitrary parameter types acting differently + # on different processes + types.add(param_obj.type) + if len(types) > 1 and types != {ParameterType.rate_gauss, ParameterType.shape}: raise ValueError( - f"misconfigured parameter '{param_name}' with type '{_param_obj.type}' that was previously " - f"seen with incompatible type '{param_obj.type}'", + f"misconfigured parameter '{param_name}' with type '{param_obj.type}' that was previously " + f"seen with incompatible type(s) '{types - {param_obj.type}}'", ) # get the effect - effect = _param_obj.effect + effect = param_obj.effect # rounding helper depending on the effect precision effect_precision = ( self.effect_precision - if _param_obj.effect_precision <= 0 - else _param_obj.effect_precision + if param_obj.effect_precision <= 0 + else param_obj.effect_precision ) rnd = lambda f: round(f, effect_precision) # update and transform effects - if _param_obj.type.is_rate: - # obtain from shape effects when requested - if _param_obj.transformations.any_from_shape: - effect = shape_effects[cat_name][proc_name][param_name] - + if param_obj.type.is_rate: # apply transformations one by one - for trafo in _param_obj.transformations: - if trafo == ParameterTransformation.centralize: - # skip symmetric effects - if not isinstance(effect, tuple) and len(effect) != 2: - continue - # skip one sided effects - if not (min(effect) <= 1 <= max(effect)): - continue - d, u = effect - diff = 0.5 * (d + u) - 1.0 - effect = (effect[0] - diff, effect[1] - diff) + for trafo in param_obj.transformations: + if trafo.from_shape: + # take effect from shape variations + effect = shape_effects[cat_name][proc_name][param_name] elif trafo == ParameterTransformation.symmetrize: # skip symmetric effects @@ -218,28 +305,34 @@ def write( if not (min(effect) <= 1 <= max(effect)): continue d, u = effect - effect = 0.5 * (u - d) + 1.0 + diff = 0.5 * (d + u) - 1.0 + effect = (effect[0] - diff, effect[1] - diff) - elif trafo == ParameterTransformation.asymmetrize or ( - trafo == ParameterTransformation.asymmetrize_if_large and - isinstance(effect, float) and - abs(effect - 1.0) >= 0.2 + elif ( + trafo == ParameterTransformation.asymmetrize or + ( + trafo == ParameterTransformation.asymmetrize_if_large and + isinstance(effect, float) and + abs(effect - 1.0) >= self.asymmetrize_if_large_threshold + ) ): # skip asymmetric effects if not isinstance(effect, float): continue effect = (2.0 - effect, effect) - elif _param_obj.type.is_shape: - # when the shape was constructed from a rate, reset the effect to 1 - if _param_obj.transformations.any_from_rate: - effect = 1.0 + elif param_obj.type.is_shape: + # apply transformations one by one + for trafo in param_obj.transformations: + if trafo.from_rate: + # when the shape was constructed from a rate, reset the effect to 1 + effect = 1.0 # encode the effect if isinstance(effect, (int, float)): if effect == 0.0: effects.append("-") - elif effect == 1.0 and _param_obj.type.is_shape: + elif effect == 1.0 and param_obj.type.is_shape: effects.append("1") else: effects.append(str(rnd(effect))) @@ -252,12 +345,22 @@ def write( ) # add the tabular line - if param_obj and effects: - type_str = "shape" - if param_obj.type == ParameterType.rate_gauss: - type_str = "lnN" - elif param_obj.type == ParameterType.rate_uniform: - type_str = "lnU" + if types and effects: + type_str = None + if len(types) == 1: + _type = list(types)[0] + if _type == ParameterType.rate_gauss: + type_str = "lnN" + elif _type == ParameterType.rate_uniform: + type_str = "lnU" + elif _type == ParameterType.shape: + type_str = "shape" + elif types == {ParameterType.rate_gauss, ParameterType.shape}: + # when mixing lnN and shape effects, combine expects the "?" type and makes the actual decision + # dependend on the presence of shape variations in the accompanying shape files + type_str = "?" + if not type_str: + raise ValueError(f"misconfigured parameter '{param_name}' with incompatible type(s) '{types}'") blocks.tabular_parameters.append([param_name, type_str, effects]) # alphabetical, case-insensitive order by name @@ -477,27 +580,52 @@ def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: ) return sum(map(get, hists[1:]), get(hists[0]).copy()) + # helper to extract sum of hists, apply scale, handle flow and fill empty bins + def load( + hist_name: str, + hist_key: Hashable, + fallback_key: Hashable | None = None, + scale: float = 1.0, + ) -> hist.Hist: + h = sum_hists(hist_key, fallback_key) * scale + handle_flow(cat_obj, h, hist_name) + fill_empty(cat_obj, h) + return h + # get the process scale (usually 1) proc_obj = self.inference_model_inst.get_process(proc_name, category=cat_name) scale = proc_obj.scale # nominal shape - h_nom = sum_hists("nominal") * scale nom_name = nom_pattern.format(category=cat_name, process=proc_name) - fill_empty(cat_obj, h_nom) - handle_flow(cat_obj, h_nom, nom_name) + h_nom = load(nom_name, "nominal", scale=scale) out_file[nom_name] = h_nom _rates[proc_name] = h_nom.sum().value + integral = lambda h: h.sum().value # prepare effects __effects = _effects[proc_name] = OrderedDict() - # go through all parameters and check if varied shapes need to be processed + # go through all parameters and potentially handle varied shapes for _, _, param_obj in self.inference_model_inst.iter_parameters(category=cat_name, process=proc_name): + down_name = syst_pattern.format( + category=cat_name, + process=proc_name, + parameter=param_obj.name, + direction="Down", + ) + up_name = syst_pattern.format( + category=cat_name, + process=proc_name, + parameter=param_obj.name, + direction="Up", + ) + # read or create the varied histograms, or skip the parameter if param_obj.type.is_shape: # the source of the shape depends on the transformation if param_obj.transformations.any_from_rate: + # create the shape from the nominal one and an integral rate effect if isinstance(param_obj.effect, float): f_down, f_up = 2.0 - param_obj.effect, param_obj.effect elif isinstance(param_obj.effect, tuple) and len(param_obj.effect) == 2: @@ -510,43 +638,46 @@ def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: h_down = h_nom.copy() * f_down h_up = h_nom.copy() * f_up else: - # just extract the shapes - h_down = sum_hists((param_obj.name, "down"), "nominal") * scale - h_up = sum_hists((param_obj.name, "up"), "nominal") * scale + # just extract the shapes from the inputs + h_down = load(down_name, (param_obj.name, "down"), "nominal", scale=scale) + h_up = load(up_name, (param_obj.name, "up"), "nominal", scale=scale) elif param_obj.type.is_rate: if param_obj.transformations.any_from_shape: # just extract the shapes - h_down = sum_hists((param_obj.name, "down"), "nominal") * scale - h_up = sum_hists((param_obj.name, "up"), "nominal") * scale + h_down = load(down_name, (param_obj.name, "down"), "nominal", scale=scale) + h_up = load(up_name, (param_obj.name, "up"), "nominal", scale=scale) + + # in case the transformation is effect_from_shape_if_small, and any of the two relative + # integral effects are above the required "small" threshold, convert the parameter to + # shape-type and drop all transformations that do not apply to shapes + if param_obj.transformations[0] == ParameterTransformation.effect_from_shape_if_small: + n, d, u = integral(h_nom), integral(h_down), integral(h_up) + rel_diff_d = safe_div(abs(n - d), n) + rel_diff_u = safe_div(abs(u - n), n) + if min(rel_diff_d, rel_diff_u) > self.effect_from_shape_if_small_threshold: + param_obj.type = ParameterType.shape + param_obj.transformations = type(param_obj.transformations)( + trafo for trafo in param_obj.transformations[1:] + if trafo not in self.rate_only_trafos + ) else: - # skip the parameter continue - # apply optional transformations - integral = lambda h: h.sum().value + else: + # other effect type that is not handled yet + logger.warning(f"datacard parameter '{param_obj.name}' has unsupported type '{param_obj.type}'") + continue + + # apply optional transformations one by one for trafo in param_obj.transformations: - if trafo == ParameterTransformation.envelope_if_one_sided: - n, d, u = integral(h_nom), integral(h_down), integral(h_up) - if (n - d) * (n - u) > 0: - # one-sided effect, use the larger variation - if abs(n - d) > abs(n - u): - # use the down variation with effect flipped - h_up = 2 * h_nom.copy() - h_down.view() - # TODO: better estimate of the variance - h_up.view().variance = h_down.variances() - else: - # use the up variation with effect flipped - h_down = 2 * h_nom.copy() - h_up.view() - h_down.view().variance = h_up.variances() - - elif trafo == ParameterTransformation.centralize: + if trafo == ParameterTransformation.symmetrize: # get the absolute spread based on integrals n, d, u = integral(h_nom), integral(h_down), integral(h_up) + # skip one sided effects if not (min(d, n) <= n <= max(d, n)): - # skip one sided effects logger.info( - f"skipping shape centralization of parameter '{param_obj.name}' for process " + f"skipping shape symmetrization of parameter '{param_obj.name}' for process " f"'{proc_name}' in category '{cat_name}' as effect is one-sided", ) continue @@ -557,42 +688,64 @@ def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: elif trafo == ParameterTransformation.normalize: # normale varied hists to the nominal integral - h_down *= safe_div(integral(h_nom), integral(h_down)) - h_up *= safe_div(integral(h_nom), integral(h_up)) - - else: - # no other transormation is applied at this point - continue - - # empty bins are always filled + n, d, u = integral(h_nom), integral(h_down), integral(h_up) + h_down *= safe_div(n, d) + h_up *= safe_div(n, u) + + elif trafo in {ParameterTransformation.envelope, ParameterTransformation.envelope_if_one_sided}: + d, u = integral(h_down), integral(h_up) + v_nom = h_nom.view() + v_down = h_down.view() + v_up = h_up.view() + # compute masks denoting at which locations a variation is abs larger than the other + diffs_up = v_up.value - v_nom.value + diffs_down = v_down.value - v_nom.value + up_mask = abs(diffs_up) > abs(diffs_down) + down_mask = abs(diffs_down) > abs(diffs_up) + # when only checking one-sided, remove True's from the masks where variations are two-sided + if trafo == ParameterTransformation.envelope_if_one_sided: + one_sided = (diffs_up * diffs_down) > 0 + up_mask &= one_sided + down_mask &= one_sided + # fill values from the larger variation + v_up.value[down_mask] = v_nom.value[down_mask] - diffs_down[down_mask] + v_up.variance[down_mask] = v_down.variance[down_mask] + v_down.value[up_mask] = v_nom.value[up_mask] - diffs_up[up_mask] + v_down.variance[up_mask] = v_up.variance[up_mask] + + elif trafo == ParameterTransformation.envelope_enforce_two_sided: + # envelope creation with enforced two-sidedness + v_nom = h_nom.view() + v_down = h_down.view() + v_up = h_up.view() + # compute masks denoting at which locations a variation is abs larger than the other + abs_diffs_up = abs(v_up.value - v_nom.value) + abs_diffs_down = abs(v_down.value - v_nom.value) + up_mask = abs_diffs_up >= abs_diffs_down + down_mask = ~up_mask + # fill values from the absolute larger variation + v_up.value[up_mask] = v_nom.value[up_mask] + abs_diffs_up[up_mask] + v_up.value[down_mask] = v_nom.value[down_mask] + abs_diffs_down[down_mask] + v_up.variance[down_mask] = v_down.variance[down_mask] + v_down.value[down_mask] = v_nom.value[down_mask] - abs_diffs_down[down_mask] + v_down.value[up_mask] = v_nom.value[up_mask] - abs_diffs_up[up_mask] + v_down.variance[up_mask] = v_up.variance[up_mask] + + # fill empty bins again after all transformations fill_empty(cat_obj, h_down) fill_empty(cat_obj, h_up) - # save them when they represent real shapes - if param_obj.type.is_shape: - down_name = syst_pattern.format( - category=cat_name, - process=proc_name, - parameter=param_obj.name, - direction="Down", - ) - up_name = syst_pattern.format( - category=cat_name, - process=proc_name, - parameter=param_obj.name, - direction="Up", - ) - handle_flow(cat_obj, h_down, down_name) - handle_flow(cat_obj, h_up, up_name) - out_file[down_name] = h_down - out_file[up_name] = h_up - # save the effect __effects[param_obj.name] = ( safe_div(integral(h_down), integral(h_nom)), safe_div(integral(h_up), integral(h_nom)), ) + # save them to file if they have shape-type + if param_obj.type.is_shape: + out_file[down_name] = h_down + out_file[up_name] = h_up + # data handling, first checking if data should be faked, then if real data exists if cat_obj.data_from_processes: # fake data from processes diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index ca8bb633b..43a2f24e4 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -179,10 +179,11 @@ def run(self): # create histograms per shape shift for param_obj in proc_obj.parameters: # skip the parameter when varied hists are not needed - if ( - not param_obj.type.is_shape and - not any(trafo.from_shape for trafo in param_obj.transformations) - ): + need_shapes = ( + (param_obj.type.is_shape and not param_obj.transformations.any_from_rate) or + (param_obj.type.is_rate and param_obj.transformations.any_from_shape) + ) + if not need_shapes: continue # store the varied hists shift_source = ( diff --git a/columnflow/tasks/framework/inference.py b/columnflow/tasks/framework/inference.py index 18b3679bc..79d95f7c9 100644 --- a/columnflow/tasks/framework/inference.py +++ b/columnflow/tasks/framework/inference.py @@ -164,7 +164,10 @@ def combined_config_data(self) -> dict[od.ConfigInst, dict[str, dict | set]]: if config_inst.name not in param_obj.config_data: continue # only add if a shift is required for this parameter - if param_obj.type.is_shape or any(trafo.from_shape for trafo in param_obj.transformations): + if ( + (param_obj.type.is_shape and not param_obj.transformations.any_from_rate) or + (param_obj.type.is_rate and param_obj.transformations.any_from_shape) + ): shift_source = param_obj.config_data[config_inst.name].shift_source for mc_dataset in mc_datasets: data["mc_datasets"][mc_dataset]["shift_sources"].add(shift_source) From 4c18d30119be386f072d30f6493d9d76d421fee2 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Tue, 30 Sep 2025 17:47:04 +0200 Subject: [PATCH 094/123] Update tests. --- tests/test_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index ad14549b8..16e4161f9 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -106,7 +106,7 @@ def test_parameter_spec(self): # Test data name = "test_parameter" type = ParameterType.rate_gauss - transformations = [ParameterTransformation.centralize, ParameterTransformation.symmetrize] + transformations = [ParameterTransformation.symmetrize] config_name = "test_config" config_shift_source = "test_shift_source" effect = 1.5 From b02ee6b5f4fa04048da141f613ea48ad6fe4aaa6 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 1 Oct 2025 13:12:57 +0200 Subject: [PATCH 095/123] Add flip_(smaller|larger)_if_one_sided transormations. --- columnflow/inference/__init__.py | 9 ++- columnflow/inference/cms/datacard.py | 92 +++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 3 deletions(-) diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index e6f4d3315..2214237cf 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -104,6 +104,11 @@ class ParameterTransformation(enum.Enum): a definition that can be subject to the serialization routine. Only applies to shape-type parameters. :cvar envelope_enforce_two_sided: Same as :py:attr:`envelope`, but it enforces that the up (down) variation of the constructed envelope is always above (below) the nominal one. Only applies to shape-type parameters. + :cvar flip_smaller_if_one_sided: For asymmetric rate effects (usually given by two values) that are found to be + one-sided (e.g. after applying :py:attr:`effect_from_shape`), flips the smaller effect to the other side of the + nominal value. Only applies to rate-type parameters. + :cvar flip_larger_if_one_sided: Same as :py:attr:`flip_smaller_if_one_sided`, but flips the larger effect. Only + applies to rate-type parameters. """ none = "none" @@ -117,6 +122,8 @@ class ParameterTransformation(enum.Enum): envelope = "envelope" envelope_if_one_sided = "envelope_if_one_sided" envelope_enforce_two_sided = "envelope_enforce_two_sided" + flip_smaller_if_one_sided = "flip_smaller_if_one_sided" + flip_larger_if_one_sided = "flip_larger_if_one_sided" def __str__(self) -> str: """ @@ -395,7 +402,7 @@ def model_spec(cls) -> DotDict: Returns a dictionary representing the top-level structure of the model. - *categories*: List of :py:meth:`category_spec` objects. - - *parameter_groups*: List of :py:meth:`paramter_group_spec` objects. + - *parameter_groups*: List of :py:meth:`parameter_group_spec` objects. """ return DotDict([ ("categories", []), diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index d22412c0a..0c22651eb 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -68,6 +68,12 @@ class DatacardWriter(object): contributions are one-sided. - :py:attr:`ParameterTransformation.envelope_enforce_two_sided`: Same as :py:attr:`envelope`, but it enforces that the up (down) variation of the constructed envelope is always above (below) the nominal one. + - :py:attr:`ParameterTransformation.flip_smaller_if_one_sided`: For asymmetric (two-valued) rate effects that + are found to be one-sided (e.g. after :py:attr:`ParameterTransformation.effect_from_shape`), flips the + smaller effect to the other side. Rate-type parameters only. + - :py:attr:`ParameterTransformation.flip_larger_if_one_sided`: Same as + :py:attr:`ParameterTransformation.flip_smaller_if_one_sided`, but flips the larger effect. Rate-type + parameters only. .. note:: @@ -99,6 +105,8 @@ class DatacardWriter(object): ParameterTransformation.effect_from_shape_if_small, ParameterTransformation.asymmetrize, ParameterTransformation.asymmetrize_if_large, + ParameterTransformation.flip_smaller_if_one_sided, + ParameterTransformation.flip_larger_if_one_sided, } @classmethod @@ -299,7 +307,7 @@ def write( elif trafo == ParameterTransformation.symmetrize: # skip symmetric effects - if not isinstance(effect, tuple) and len(effect) != 2: + if not isinstance(effect, tuple) or len(effect) != 2: continue # skip one sided effects if not (min(effect) <= 1 <= max(effect)): @@ -321,6 +329,31 @@ def write( continue effect = (2.0 - effect, effect) + elif trafo in { + ParameterTransformation.flip_smaller_if_one_sided, + ParameterTransformation.flip_larger_if_one_sided, + }: + # skip symmetric effects + if not isinstance(effect, tuple) or len(effect) != 2: + continue + # check sidedness and determine which of the two effect values to flip, identified by index + if max(effect) < 1.0: + # both below nominal + flip_index = int( + (effect[1] > effect[0] and ParameterTransformation.flip_smaller_if_one_sided) or + (effect[1] < effect[0] and ParameterTransformation.flip_larger_if_one_sided), + ) + elif min(effect) > 1.0: + # both above nominal + flip_index = int( + (effect[1] > effect[0] and ParameterTransformation.flip_larger_if_one_sided) or + (effect[1] < effect[0] and ParameterTransformation.flip_smaller_if_one_sided), + ) + else: + # skip onde-sided effects + continue + effect = tuple(((2.0 - e) if i == flip_index else e) for i, e in enumerate(effect)) + elif param_obj.type.is_shape: # apply transformations one by one for trafo in param_obj.transformations: @@ -328,6 +361,9 @@ def write( # when the shape was constructed from a rate, reset the effect to 1 effect = 1.0 + # custom hook to adjust effect + effect = self.modify_parameter_effect(cat_name, proc_name, param_obj, effect) + # encode the effect if isinstance(effect, (int, float)): if effect == 0.0: @@ -357,7 +393,7 @@ def write( type_str = "shape" elif types == {ParameterType.rate_gauss, ParameterType.shape}: # when mixing lnN and shape effects, combine expects the "?" type and makes the actual decision - # dependend on the presence of shape variations in the accompanying shape files + # dependend on the presence of shape variations in the accompaying shape files type_str = "?" if not type_str: raise ValueError(f"misconfigured parameter '{param_name}' with incompatible type(s) '{types}'") @@ -731,6 +767,16 @@ def load( v_down.value[up_mask] = v_nom.value[up_mask] - abs_diffs_up[up_mask] v_down.variance[up_mask] = v_up.variance[up_mask] + # custom hook to adjust shapes + h_nom, h_down, h_up = self.modify_parameter_shape( + cat_name, + proc_name, + param_obj, + h_nom, + h_down, + h_up, + ) + # fill empty bins again after all transformations fill_empty(cat_obj, h_down) fill_empty(cat_obj, h_up) @@ -845,3 +891,45 @@ def align_rates_and_parameters( lines = cls.align_lines(rates + parameters) return lines[:n_rate_lines], lines[n_rate_lines:] + + def modify_parameter_effect( + self, + category: str, + process: str, + param_obj: DotDict, + effect: float | tuple[float, float], + ) -> float | tuple[float, float]: + """ + Custom hook to modify the effect of a parameter on a given category and process before it is encoded into the + datacard. By default, this does nothing and simply returns the given effect. + + :param category: The category name. + :param process: The process name. + :param param_obj: The parameter object, following :py:meth:`columnflow.inference.InferenceModel.parameter_spec`. + :param effect: The effect value(s) to be modified. + :returns: The modified effect value(s). + """ + return effect + + def modify_parameter_shape( + self, + category: str, + process: str, + param_obj: DotDict, + h_nom: hist.Hist, + h_down: hist.Hist, + h_up: hist.Hist, + ) -> tuple[hist.Hist, hist.Hist, hist.Hist]: + """ + Custom hook to modify the nominal and varied (down, up) shapes of a parameter on a given category and process + before they are saved to the shapes file. By default, this does nothing and simply returns the given histograms. + + :param category: The category name. + :param process: The process name. + :param param_obj: The parameter object, following :py:meth:`columnflow.inference.InferenceModel.parameter_spec`. + :param h_nom: The nominal histogram. + :param h_down: The down-varied histogram. + :param h_up: The up-varied histogram. + :returns: The modified nominal and varied (down, up) histograms. + """ + return h_nom, h_down, h_up From 7edba46d051aa4b3155262625a7c31d824222dba Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 1 Oct 2025 13:19:11 +0200 Subject: [PATCH 096/123] Typos. --- columnflow/inference/cms/datacard.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index 0c22651eb..5a14e9fdc 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -350,7 +350,7 @@ def write( (effect[1] < effect[0] and ParameterTransformation.flip_smaller_if_one_sided), ) else: - # skip onde-sided effects + # skip one-sided effects continue effect = tuple(((2.0 - e) if i == flip_index else e) for i, e in enumerate(effect)) @@ -361,7 +361,7 @@ def write( # when the shape was constructed from a rate, reset the effect to 1 effect = 1.0 - # custom hook to adjust effect + # custom hook to modify the effect effect = self.modify_parameter_effect(cat_name, proc_name, param_obj, effect) # encode the effect @@ -767,7 +767,7 @@ def load( v_down.value[up_mask] = v_nom.value[up_mask] - abs_diffs_up[up_mask] v_down.variance[up_mask] = v_up.variance[up_mask] - # custom hook to adjust shapes + # custom hook to modify the shapes h_nom, h_down, h_up = self.modify_parameter_shape( cat_name, proc_name, From 14c0a10882c2a4ced0b1dc4a141be138a24267a8 Mon Sep 17 00:00:00 2001 From: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Date: Wed, 1 Oct 2025 15:43:32 +0200 Subject: [PATCH 097/123] allow multiple processes per dataset in datacard writer (#733) * allow multiple processes per dataset in datacard writer * cleanup * add arguments to modify_process_hist hook * Disentangle process-to-dataset mapping. * Update inference model instance cache key. * hotfix flip_if_one_sided * Drop effect_from_shape_if_small in favor of effect_from_shape_if_flat. * Typo. --------- Co-authored-by: Mathis Frahm Co-authored-by: Marcel Rieger Co-authored-by: Marcel R. --- columnflow/inference/__init__.py | 13 ++-- columnflow/inference/cms/datacard.py | 64 ++++++++++++------- columnflow/tasks/cms/inference.py | 64 +++++++++---------- columnflow/tasks/framework/inference.py | 85 ++++++++++++++----------- 4 files changed, 125 insertions(+), 101 deletions(-) diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index 2214237cf..46351ed81 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -88,8 +88,9 @@ class ParameterTransformation(enum.Enum): usually attributed to rate-type parameters. Only applies to shape-type parameters. :cvar effect_from_shape: Derive the effect of a rate-type parameter using the overall, integral effect of shape variations. Only applies to rate-type parameters. - :cvar effect_from_shape_if_small: Same as :py:attr:`effect_from_shape`, but depending on a threshold on the size of - the effect which can be subject to the serialization routine. Only applies to rate-type parameters. + :cvar effect_from_shape_if_flat: Same as :py:attr:`effect_from_shape`, but applies only if both shape variations are + reasonably flat. The definition of "reasonably flat" can be subject to the serialization routine. Only applies + to rate-type parameters. :cvar symmetrize: The overall (integral) effect of up and down variations is measured and centralized, updating the variations such that they are equidistant to the nominal one. Can apply to both rate- and shape-type parameters. :cvar asymmetrize: The symmetric effect on a rate-type parameter (usually given as a single value) is converted into @@ -114,7 +115,7 @@ class ParameterTransformation(enum.Enum): none = "none" effect_from_rate = "effect_from_rate" effect_from_shape = "effect_from_shape" - effect_from_shape_if_small = "effect_from_shape_if_small" + effect_from_shape_if_flat = "effect_from_shape_if_flat" symmetrize = "symmetrize" asymmetrize = "asymmetrize" asymmetrize_if_large = "asymmetrize_if_large" @@ -142,7 +143,7 @@ def from_shape(self) -> bool: """ return self in { self.effect_from_shape, - self.effect_from_shape_if_small, + self.effect_from_shape_if_flat, } @property @@ -224,7 +225,9 @@ def __str__(self) -> str: class InferenceModelMeta(CachedDerivableMeta): def _get_inst_cache_key(cls, args: tuple, kwargs: dict) -> Hashable: - return freeze((cls, kwargs.get("inst_dict", {}))) + config_insts = args[0] + config_names = tuple(sorted(config_inst.name for config_inst in config_insts)) + return freeze((cls, config_names, kwargs.get("inst_dict", {}))) class InferenceModel(Derivable, metaclass=InferenceModelMeta): diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index 5a14e9fdc..d4652a7fe 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -49,9 +49,13 @@ class DatacardWriter(object): Shape-type parameters only. - :py:attr:`ParameterTransformation.effect_from_shape`: Converts the integral effect of shape variations to an asymmetric rate-style effect. Rate-type parameters only. - - :py:attr:`ParameterTransformation.effect_from_shape_if_small`: Same as above with a default threshold of 2%. - Configurable via *effect_from_shape_if_small_threshold*. The parameter should initially be of rate-type, but - in case the threshold is not met, the effect is interpreted as shape-type. + - :py:attr:`ParameterTransformation.effect_from_shape_if_flat`: Same as above but only applies to cases where + both shape variations are reasonably flat. The flatness per varied shape is determined by two criteria that + both must be met: 1. the maximum relative outlier of bin contents with respect to their mean (defaults to + 20%, configurable via *effect_from_shape_if_flat_max_outlier*), 2. the deviation / dispersion of bin + contents, i.e., the square root of the variance of bin contents, relative to their mean (defaults to 10%, + configurable via *effect_from_shape_if_flat_max_deviation*). The parameter should initially be of rate-type, + but in case the criteria are not met, the effect is interpreted as shape-type. - :py:attr:`ParameterTransformation.symmetrize`: Changes up and down variations of either rate effects and shapes to symmetrize them around the nominal value. For rate-type parameters, this has no effect if the effect strength was provided by a single value. There is no conversion into a single value and consequently, @@ -79,7 +83,7 @@ class DatacardWriter(object): If used, the transformations :py:attr:`ParameterTransformation.effect_from_rate`, :py:attr:`ParameterTransformation.effect_from_shape`, and - :py:attr:`ParameterTransformation.effect_from_shape_if_small` must be the first element in the sequence of + :py:attr:`ParameterTransformation.effect_from_shape_if_flat` must be the first element in the sequence of transformations to be applied. The remaining transformations are applied in order based on the outcome of the effect conversion. """ @@ -91,7 +95,7 @@ class DatacardWriter(object): first_index_trafos = { ParameterTransformation.effect_from_rate, ParameterTransformation.effect_from_shape, - ParameterTransformation.effect_from_shape_if_small, + ParameterTransformation.effect_from_shape_if_flat, } shape_only_trafos = { ParameterTransformation.effect_from_rate, @@ -102,7 +106,7 @@ class DatacardWriter(object): } rate_only_trafos = { ParameterTransformation.effect_from_shape, - ParameterTransformation.effect_from_shape_if_small, + ParameterTransformation.effect_from_shape_if_flat, ParameterTransformation.asymmetrize, ParameterTransformation.asymmetrize_if_large, ParameterTransformation.flip_smaller_if_one_sided, @@ -151,7 +155,8 @@ def __init__( histograms: DatacardHists, rate_precision: int = 4, effect_precision: int = 4, - effect_from_shape_if_small_threshold: float = 0.02, + effect_from_shape_if_flat_max_outlier: float = 0.2, + effect_from_shape_if_flat_max_deviation: float = 0.1, asymmetrize_if_large_threshold: float = 0.2, ) -> None: super().__init__() @@ -161,7 +166,8 @@ def __init__( self.histograms = histograms self.rate_precision = rate_precision self.effect_precision = effect_precision - self.effect_from_shape_if_small_threshold = effect_from_shape_if_small_threshold + self.effect_from_shape_if_flat_max_outlier = effect_from_shape_if_flat_max_outlier + self.effect_from_shape_if_flat_max_deviation = effect_from_shape_if_flat_max_deviation self.asymmetrize_if_large_threshold = asymmetrize_if_large_threshold # validate the inference model @@ -336,18 +342,20 @@ def write( # skip symmetric effects if not isinstance(effect, tuple) or len(effect) != 2: continue + flip_larger = trafo == ParameterTransformation.flip_larger_if_one_sided + flip_smaller = trafo == ParameterTransformation.flip_smaller_if_one_sided # check sidedness and determine which of the two effect values to flip, identified by index if max(effect) < 1.0: # both below nominal flip_index = int( - (effect[1] > effect[0] and ParameterTransformation.flip_smaller_if_one_sided) or - (effect[1] < effect[0] and ParameterTransformation.flip_larger_if_one_sided), + (effect[1] > effect[0] and flip_larger) or + (effect[1] < effect[0] and flip_smaller), ) elif min(effect) > 1.0: # both above nominal flip_index = int( - (effect[1] > effect[0] and ParameterTransformation.flip_larger_if_one_sided) or - (effect[1] < effect[0] and ParameterTransformation.flip_smaller_if_one_sided), + (effect[1] > effect[0] and flip_smaller) or + (effect[1] < effect[0] and flip_larger), ) else: # skip one-sided effects @@ -684,19 +692,27 @@ def load( h_down = load(down_name, (param_obj.name, "down"), "nominal", scale=scale) h_up = load(up_name, (param_obj.name, "up"), "nominal", scale=scale) - # in case the transformation is effect_from_shape_if_small, and any of the two relative - # integral effects are above the required "small" threshold, convert the parameter to - # shape-type and drop all transformations that do not apply to shapes - if param_obj.transformations[0] == ParameterTransformation.effect_from_shape_if_small: - n, d, u = integral(h_nom), integral(h_down), integral(h_up) - rel_diff_d = safe_div(abs(n - d), n) - rel_diff_u = safe_div(abs(u - n), n) - if min(rel_diff_d, rel_diff_u) > self.effect_from_shape_if_small_threshold: - param_obj.type = ParameterType.shape - param_obj.transformations = type(param_obj.transformations)( - trafo for trafo in param_obj.transformations[1:] - if trafo not in self.rate_only_trafos + # in case the transformation is effect_from_shape_if_flat, and any of the two variations + # do not qualify as "flat", convert the parameter to shape-type and drop all transformations + # that do not apply to shapes + if param_obj.transformations[0] == ParameterTransformation.effect_from_shape_if_flat: + # check if flatness criteria are met + for h in [h_down, h_up]: + values = h.view().value + mean, std = values.mean(), values.std() + rel_deviation = safe_div(std, mean) + max_rel_outlier = safe_div(max(abs(values - mean)), mean) + is_flat = ( + rel_deviation <= self.effect_from_shape_if_flat_max_deviation and + max_rel_outlier <= self.effect_from_shape_if_flat_max_outlier ) + if not is_flat: + param_obj.type = ParameterType.shape + param_obj.transformations = type(param_obj.transformations)( + trafo for trafo in param_obj.transformations[1:] + if trafo not in self.rate_only_trafos + ) + break else: continue diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index 43a2f24e4..f01812a16 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -69,7 +69,13 @@ def run(self): data = self.combined_config_data[config_inst] input_hists[config_inst] = self.load_process_hists( config_inst, - list(data["mc_datasets"]) + list(data["data_datasets"]), + { + dataset_name: list(data["mc_datasets"][dataset_name]["proc_names"]) + for dataset_name in data["mc_datasets"] + } | { + dataset_name: ["data"] + for dataset_name in data["data_datasets"] + }, variable, inputs[config_inst.name], ) @@ -127,43 +133,33 @@ def run(self): f"dynamic datacard process object misses 'process' entry in config data for " f"'{config_inst.name}': {proc_obj}", ) - process_insts = [config_inst.get_process(process_name)] + process_inst = config_inst.get_process(process_name) else: - process_insts = [ - config_inst.get_dataset(dataset_name).processes.get_first() - for dataset_name in proc_obj.config_data[config_inst.name].mc_datasets - ] - - # collect per-process histograms - h_procs = [] - for process_inst in process_insts: - # extract the histogram for the process - # (removed from hists to eagerly cleanup memory) - h_proc = _input_hists[config_inst].pop(process_inst, None) - if h_proc is None: - self.logger.error( - f"found no histogram to model datacard process '{proc_obj.name}', please check your " - f"inference model '{self.inference_model}'", - ) - continue - - # select relevant categories - h_proc = h_proc[{ - "category": [ - hist.loc(c.name) - for c in leaf_category_insts - if c.name in h_proc.axes["category"] - ], - }] - h_proc = h_proc[{"category": sum}] - - h_procs.append(h_proc) + process_inst = config_inst.get_process( + proc_obj.name + if proc_obj.name == "data" + else proc_obj.config_data[config_inst.name].process, + ) - if h_procs is None: + # extract the histogram for the process + # (removed from hists to eagerly cleanup memory) + h_proc = _input_hists[config_inst].get(process_inst, None) + if h_proc is None: + self.logger.error( + f"found no histogram to model datacard process '{proc_obj.name}', please check your " + f"inference model '{self.inference_model}'", + ) continue - # combine them - h_proc = sum(h_procs[1:], h_procs[0].copy()) + # select relevant categories + h_proc = h_proc[{ + "category": [ + hist.loc(c.name) + for c in leaf_category_insts + if c.name in h_proc.axes["category"] + ], + }] + h_proc = h_proc[{"category": sum}] # create the nominal hist datacard_hists[cat_obj.name].setdefault(proc_obj.name, {}).setdefault(config_inst.name, {}) diff --git a/columnflow/tasks/framework/inference.py b/columnflow/tasks/framework/inference.py index 79d95f7c9..7cff1ae88 100644 --- a/columnflow/tasks/framework/inference.py +++ b/columnflow/tasks/framework/inference.py @@ -116,7 +116,7 @@ def combined_config_data(self) -> dict[od.ConfigInst, dict[str, dict | set]]: "variables": set(), # plain set of names of real data datasets "data_datasets": set(), - # per name of mc dataset, the set of shift sources and the name of the datacard process object + # per mc dataset name, the set of shift sources and the names processes to be extracted from them "mc_datasets": {}, } for config_inst in self.config_insts @@ -149,15 +149,12 @@ def combined_config_data(self) -> dict[od.ConfigInst, dict[str, dict | set]]: for dataset_name in mc_datasets: if dataset_name not in data["mc_datasets"]: data["mc_datasets"][dataset_name] = { - "proc_name": proc_obj.name, "shift_sources": set(), + "proc_names": set(), } - elif data["mc_datasets"][dataset_name]["proc_name"] != proc_obj.name: - raise ValueError( - f"dataset '{dataset_name}' was already assigned to datacard process " - f"'{data['mc_datasets'][dataset_name]['proc_name']}' and cannot be re-assigned to " - f"'{proc_obj.name}' in config '{config_inst.name}'", - ) + data["mc_datasets"][dataset_name]["proc_names"].add( + proc_obj.config_data[config_inst.name].process, + ) # shift sources for param_obj in proc_obj.parameters: @@ -234,7 +231,7 @@ def requires(self): def load_process_hists( self, config_inst: od.Config, - dataset_names: list[str], + dataset_processes: dict[str, list[str]], variable: str, inputs: dict, ) -> dict[str, dict[od.Process, hist.Hist]]: @@ -242,17 +239,7 @@ def load_process_hists( hists: dict[od.Process, hist.Hist] = {} with self.publish_step(f"extracting '{variable}' for config {config_inst.name} ..."): - for dataset_name in dataset_names: - dataset_inst = config_inst.get_dataset(dataset_name) - process_inst = dataset_inst.processes.get_first() - - # for real data, fallback to the main data process - if process_inst.is_data: - process_inst = config_inst.get_process("data") - - # gather all subprocesses for a full query later - sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] - + for dataset_name, process_names in dataset_processes.items(): # open the histogram and work on a copy inp = inputs[dataset_name]["collection"][0]["hists"][variable] try: @@ -263,31 +250,53 @@ def load_process_hists( f"'{config_inst.name}' from {inp.abspath}", ) from e - # there must be at least one matching sub process - if not any(p.name in h.axes["process"] for p in sub_process_insts): - raise Exception(f"no '{variable}' histograms found for process '{process_inst.name}'") - - # select and reduce over relevant processes - h = h[{"process": [hist.loc(p.name) for p in sub_process_insts if p.name in h.axes["process"]]}] - h = h[{"process": sum}] - - # additional custom reductions - h = self.modify_process_hist(process_inst, h) - - # store it - if process_inst in hists: - hists[process_inst] += h - else: - hists[process_inst] = h + # determine processes to extract + process_insts = [config_inst.get_process(name) for name in process_names] + + # loop over all proceses assigned to this dataset + for process_inst in process_insts: + # gather all subprocesses for a full query later + sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] + + # there must be at least one matching sub process + if not any(p.name in h.axes["process"] for p in sub_process_insts): + raise Exception(f"no '{variable}' histograms found for process '{process_inst.name}'") + + # select and reduce over relevant processes + h_proc = h[{ + "process": [hist.loc(p.name) for p in sub_process_insts if p.name in h.axes["process"]], + }] + h_proc = h_proc[{"process": sum}] + + # additional custom reductions + h_proc = self.modify_process_hist( + config_inst=config_inst, + process_inst=process_inst, + variable=variable, + h=h_proc, + ) + + # store it + if process_inst in hists: + hists[process_inst] += h_proc + else: + hists[process_inst] = h_proc return hists - def modify_process_hist(self, process_inst: od.Process, h: hist.Hist) -> hist.Hist: + def modify_process_hist( + self, + config_inst: od.Config, + process_inst: od.Process, + variable: str, + h: hist.Hist, + ) -> hist.Hist: """ Hook to modify a process histogram after it has been loaded. This can be helpful to reduce memory early on. + :param config_inst: The config instance the histogram belongs to. :param process_inst: The process instance the histogram belongs to. - :param histo: The histogram to modify. + :param h: The histogram to modify. :return: The modified histogram. """ return h From af56133dfbe1088e82576d28376eede42d5c292c Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 2 Oct 2025 11:11:12 +0200 Subject: [PATCH 098/123] Hotfix process object selection for multi-config datacards. --- columnflow/inference/__init__.py | 16 ++++++---------- columnflow/tasks/cms/inference.py | 4 ++++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index 46351ed81..43870d29e 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -40,12 +40,10 @@ class ParameterType(enum.Enum): rate_unconstrained = "rate_unconstrained" shape = "shape" - def __str__(self) -> str: - """ - Returns the string representation of the parameter type. + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.value}>" - :returns: The string representation of the parameter type. - """ + def __str__(self) -> str: return self.value @property @@ -126,12 +124,10 @@ class ParameterTransformation(enum.Enum): flip_smaller_if_one_sided = "flip_smaller_if_one_sided" flip_larger_if_one_sided = "flip_larger_if_one_sided" - def __str__(self) -> str: - """ - Returns the string representation of the parameter transformation. + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.value}>" - :returns: The string representation of the parameter transformation. - """ + def __str__(self) -> str: return self.value @property diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index f01812a16..dc37f82ec 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -126,6 +126,10 @@ def run(self): if config_data.data_datasets and not cat_obj.data_from_processes: proc_objs.append(self.inference_model_inst.process_spec(name="data")) for proc_obj in proc_objs: + # skip the process objects if it does not contribute to this config_inst + if config_inst.name not in proc_obj.config_data: + continue + # get all process instances (keys in _input_hists) to be combined if proc_obj.is_dynamic: if not (process_name := proc_obj.config_data[config_inst.name].get("process", None)): From 76c3f62ffc02b0edaf071cea376b16ff0363f43e Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 2 Oct 2025 11:32:53 +0200 Subject: [PATCH 099/123] Hotfix variable shape? type in combine datacard writer. --- columnflow/inference/cms/datacard.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index d4652a7fe..951e7e086 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -400,9 +400,10 @@ def write( elif _type == ParameterType.shape: type_str = "shape" elif types == {ParameterType.rate_gauss, ParameterType.shape}: - # when mixing lnN and shape effects, combine expects the "?" type and makes the actual decision - # dependend on the presence of shape variations in the accompaying shape files - type_str = "?" + # when mixing lnN and shape effects, combine expects the "shape?" type and makes the actual decision + # dependend on the presence of shape variations in the accompaying shape files, see + # https://cms-analysis.github.io/HiggsAnalysis-CombinedLimit/v10.2.X/part2/settinguptheanalysis/?h=shape%3F#template-shape-uncertainties # noqa + type_str = "shape?" if not type_str: raise ValueError(f"misconfigured parameter '{param_name}' with incompatible type(s) '{types}'") blocks.tabular_parameters.append([param_name, type_str, effects]) From 77c36dad7f38ddc1a1a51f4eb449e1655bb6cd8d Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Thu, 2 Oct 2025 16:39:35 +0200 Subject: [PATCH 100/123] Hotfix abs eta in cms muon weight producer. --- columnflow/production/cms/muon.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/columnflow/production/cms/muon.py b/columnflow/production/cms/muon.py index 071b3122f..ee790a7d5 100644 --- a/columnflow/production/cms/muon.py +++ b/columnflow/production/cms/muon.py @@ -84,14 +84,14 @@ def muon_weights( Optionally, a *muon_mask* can be supplied to compute the scale factor weight based only on a subset of muons. """ - # flat absolute eta and pt views - abs_eta = flat_np_view(abs(events.Muon["eta"][muon_mask]), axis=1) + # flat eta and pt views + eta = flat_np_view(events.Muon["eta"][muon_mask], axis=1) pt = flat_np_view(events.Muon["pt"][muon_mask], axis=1) variable_map = { "year": self.muon_config.campaign, - "abseta": abs_eta, - "eta": abs_eta, + "eta": eta, + "abseta": abs(eta), "pt": pt, } From 80bff98157e4bd942fef9af77360ed1a190d08d4 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Sat, 4 Oct 2025 15:05:27 +0200 Subject: [PATCH 101/123] Make default_remote_claw_sandbox configurable via law.cfg. --- analysis_templates/cms_minimal/law.cfg | 3 +++ columnflow/tasks/framework/remote.py | 7 +++++-- law.cfg | 3 +++ modules/law | 2 +- 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/analysis_templates/cms_minimal/law.cfg b/analysis_templates/cms_minimal/law.cfg index d5e69e9f1..fe0ed4bb1 100644 --- a/analysis_templates/cms_minimal/law.cfg +++ b/analysis_templates/cms_minimal/law.cfg @@ -56,6 +56,9 @@ default_create_selection_hists: False # wether or not the ensure_proxy decorator should be skipped, even if used by task's run methods skip_ensure_proxy: False +# the name of a sandbox to use for tasks in remote jobs initially (invoked with claw when set) +default_remote_claw_sandbox: None + # some remote workflow parameter defaults # (resources like memory and disk can also be set in [resources] with more granularity) htcondor_flavor: $CF_HTCONDOR_FLAVOR diff --git a/columnflow/tasks/framework/remote.py b/columnflow/tasks/framework/remote.py index 6ba3bb72d..ef1cb807c 100644 --- a/columnflow/tasks/framework/remote.py +++ b/columnflow/tasks/framework/remote.py @@ -342,16 +342,19 @@ def __post_init__(self): ) +_default_remove_claw_sandbox = law.config.get_expanded("analysis", "default_remote_claw_sandbox", None) or law.NO_STR + + class RemoteWorkflowMixin(AnalysisTask): """ Mixin class for custom remote workflows adding common functionality. """ remote_claw_sandbox = luigi.Parameter( - default=law.NO_STR, + default=_default_remove_claw_sandbox, significant=False, description="the name of a non-dev sandbox to use in remote jobs for the 'claw' executable rather than using " - "using 'law' directly; not used when empty; default: empty", + f"using 'law' directly; not used when empty; default: {_default_remove_claw_sandbox}", ) skip_destination_info: bool = False diff --git a/law.cfg b/law.cfg index 763eaf103..2fa6c1065 100644 --- a/law.cfg +++ b/law.cfg @@ -50,6 +50,9 @@ default_create_selection_hists: True # wether or not the ensure_proxy decorator should be skipped, even if used by task's run methods skip_ensure_proxy: False +# the name of a sandbox to use for tasks in remote jobs initially (invoked with claw when set) +default_remote_claw_sandbox: None + # some remote workflow parameter defaults # (resources like memory and disk can also be set in [resources] with more granularity) htcondor_flavor: $CF_HTCONDOR_FLAVOR diff --git a/modules/law b/modules/law index b881450a1..44b98b7dc 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit b881450a1927bf30c6e504da6ed6f394e7e49b93 +Subproject commit 44b98b7dcd434badd003fd498eaf399e14c3ee53 From f247eeb1d72d0d90e5f6e4dc50b1cb8c9909a146 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Sat, 4 Oct 2025 16:42:30 +0200 Subject: [PATCH 102/123] Hotfix version lookup. --- columnflow/tasks/framework/base.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index e0e791fc7..177ed8e84 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -23,7 +23,7 @@ import order as od from columnflow.columnar_util import mandatory_coffea_columns, Route, ColumnCollection -from columnflow.util import get_docs_url, is_regex, prettify, DotDict +from columnflow.util import get_docs_url, is_regex, prettify, DotDict, freeze from columnflow.types import Sequence, Callable, Any, T @@ -186,12 +186,19 @@ def req_params(cls, inst: AnalysisTask, **kwargs) -> dict[str, Any]: # build the params params = super().req_params(inst, **kwargs) - # when not explicitly set in kwargs and no global value was defined on the cli for the task - # family, evaluate and use the default value + # evaluate and use the default version in case + # - "version" is an actual parameter object of cls, and + # - "version" is not explicitly set in kwargs, and + # - no global value was defined on the cli for the task family, and + # - if cls and inst belong to the same family, they differ in the keys used for the config lookup if ( isinstance(getattr(cls, "version", None), luigi.Parameter) and "version" not in kwargs and - not law.parser.global_cmdline_values().get(f"{cls.task_family}_version") + not law.parser.global_cmdline_values().get(f"{cls.task_family}_version") and + ( + cls.task_family != inst.task_family or + freeze(cls.get_config_lookup_keys(params)) != freeze(inst.get_config_lookup_keys(params)) + ) ): default_version = cls.get_default_version(inst, params) if default_version and default_version != law.NO_STR: From 781cb8ffec372fe45783fe0dbeac597d69930b07 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Sun, 5 Oct 2025 10:40:24 +0200 Subject: [PATCH 103/123] Raise explicit error in reduction on option type masks. --- columnflow/reduction/util.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/columnflow/reduction/util.py b/columnflow/reduction/util.py index e9c4ab826..6dd5c7637 100644 --- a/columnflow/reduction/util.py +++ b/columnflow/reduction/util.py @@ -89,8 +89,21 @@ def create_collections_from_masks( # add collections for dst_name in dst_names: - object_mask = object_masks[src_name, dst_name] - dst_collection = events[src_name][object_mask] + object_mask = ak.drop_none(object_masks[src_name, dst_name]) + try: + dst_collection = events[src_name][object_mask] + except ValueError as e: + # check f the object mask refers to an option type + mask_type = getattr(getattr(ak.type(object_mask), "content", None), "cotent", None) + if isinstance(mask_type, ak.types.OptionType): + msg = ( + f"object mask to create dst collection '{dst_name}' from src collection '{src_name}' uses an " + f"option type '{object_mask.typestr}' which is not supported; please adjust your mask to not " + "contain missing values (most likely by using ak.drop_none() in your event selection)" + ) + raise ValueError(msg) from e + # no further custom handling, re-raise + raise e events = set_ak_column(events, dst_name, dst_collection) return events From 0757a60063dc5d88ff226ed9c3e663567361555f Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Mon, 6 Oct 2025 14:13:27 +0200 Subject: [PATCH 104/123] Add req helpers to mixins. (#738) --- columnflow/tasks/framework/mixins.py | 243 ++++++++++++++++++--------- 1 file changed, 164 insertions(+), 79 deletions(-) diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index 3bbebce9c..659058019 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -91,21 +91,6 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"calibrator"} return super().req_params(inst, **kwargs) - @property - def calibrator_repr(self) -> str: - """ - Return a string representation of the calibrator class. - """ - return self.build_repr(self.array_function_cls_repr(self.calibrator)) - - def store_parts(self) -> law.util.InsertableDict: - """ - :return: Dictionary with parts that will be translated into an output directory path. - """ - parts = super().store_parts() - parts.insert_after(self.config_store_anchor, "calibrator", f"calib__{self.calibrator_repr}") - return parts - @classmethod def get_config_lookup_keys( cls, @@ -125,6 +110,21 @@ def get_config_lookup_keys( return keys + @property + def calibrator_repr(self) -> str: + """ + Return a string representation of the calibrator class. + """ + return self.build_repr(self.array_function_cls_repr(self.calibrator)) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "calibrator", f"calib__{self.calibrator_repr}") + return parts + class CalibratorMixin(ArrayFunctionInstanceMixin, CalibratorClassMixin): """ @@ -198,6 +198,23 @@ def get_known_shifts( super().get_known_shifts(params, shifts) + @classmethod + def req_other_calibrator(cls, inst: CalibratorMixin, **kwargs) -> CalibratorMixin: + """ + Same as :py:meth:`req` but overwrites specific arguments for instantiation that simplify requesting a different + calibrator instance. + + :param inst: The reference instance to request parameters from. + :param kwargs: Additional arguments forwarded to :py:meth:`req`. + :return: A new instance of *this* class. + """ + # calibrator_inst and known_shifts must be set to None to by-pass calibrator instance cache lookup and thus, + # also full parameter resolution + kwargs.setdefault("calibrator_inst", None) + kwargs.setdefault("known_shifts", None) + + return cls.req(inst, **kwargs) + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -477,6 +494,25 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: } return super().req_params(inst, **kwargs) + @classmethod + def get_config_lookup_keys( + cls, + inst_or_params: SelectorClassMixin | dict[str, Any], + ) -> law.util.InsertiableDict: + keys = super().get_config_lookup_keys(inst_or_params) + + # add the selector name + selector = ( + inst_or_params.get("selector") + if isinstance(inst_or_params, dict) + else getattr(inst_or_params, "selector", None) + ) + if selector not in (law.NO_STR, None, ""): + prefix = "sel" + keys[prefix] = f"{prefix}_{selector}" + + return keys + @property def selector_repr(self) -> str: """ @@ -498,25 +534,6 @@ def store_parts(self) -> law.util.InsertableDict: parts.insert_after(self.config_store_anchor, "selector", f"sel__{self.selector_repr}") return parts - @classmethod - def get_config_lookup_keys( - cls, - inst_or_params: SelectorClassMixin | dict[str, Any], - ) -> law.util.InsertiableDict: - keys = super().get_config_lookup_keys(inst_or_params) - - # add the selector name - selector = ( - inst_or_params.get("selector") - if isinstance(inst_or_params, dict) - else getattr(inst_or_params, "selector", None) - ) - if selector not in (law.NO_STR, None, ""): - prefix = "sel" - keys[prefix] = f"{prefix}_{selector}" - - return keys - class SelectorMixin(ArrayFunctionInstanceMixin, SelectorClassMixin): """ @@ -585,6 +602,23 @@ def get_known_shifts( super().get_known_shifts(params, shifts) + @classmethod + def req_other_selector(cls, inst: SelectorMixin, **kwargs) -> SelectorMixin: + """ + Same as :py:meth:`req` but overwrites specific arguments for instantiation that simplify requesting a different + selector instance. + + :param inst: The reference instance to request parameters from. + :param kwargs: Additional arguments forwarded to :py:meth:`req`. + :return: A new instance of *this* class. + """ + # selector_inst and known_shifts must be set to None to by-pass selector instance cache lookup and thus, also + # full parameter resolution + kwargs.setdefault("selector_inst", None) + kwargs.setdefault("known_shifts", None) + + return cls.req(inst, **kwargs) + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -676,21 +710,6 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"reducer"} return super().req_params(inst, **kwargs) - @property - def reducer_repr(self) -> str: - """ - Return a string representation of the reducer class. - """ - return self.build_repr(self.array_function_cls_repr(self.reducer)) - - def store_parts(self) -> law.util.InsertableDict: - """ - :return: Dictionary with parts that will be translated into an output directory path. - """ - parts = super().store_parts() - parts.insert_after(self.config_store_anchor, "reducer", f"red__{self.reducer_repr}") - return parts - @classmethod def get_config_lookup_keys( cls, @@ -710,6 +729,21 @@ def get_config_lookup_keys( return keys + @property + def reducer_repr(self) -> str: + """ + Return a string representation of the reducer class. + """ + return self.build_repr(self.array_function_cls_repr(self.reducer)) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "reducer", f"red__{self.reducer_repr}") + return parts + class ReducerMixin(ArrayFunctionInstanceMixin, ReducerClassMixin): """ @@ -783,6 +817,23 @@ def get_known_shifts( super().get_known_shifts(params, shifts) + @classmethod + def req_other_reducer(cls, inst: ReducerMixin, **kwargs) -> ReducerMixin: + """ + Same as :py:meth:`req` but overwrites specific arguments for instantiation that simplify requesting a different + reducer instance. + + :param inst: The reference instance to request parameters from. + :param kwargs: Additional arguments forwarded to :py:meth:`req`. + :return: A new instance of *this* class. + """ + # reducer_inst and known_shifts must be set to None to by-pass reducer instance cache lookup and thus, also full + # parameter resolution + kwargs.setdefault("reducer_inst", None) + kwargs.setdefault("known_shifts", None) + + return cls.req(inst, **kwargs) + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -852,21 +903,6 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"producer"} return super().req_params(inst, **kwargs) - @property - def producer_repr(self) -> str: - """ - Return a string representation of the producer class. - """ - return self.build_repr(self.array_function_cls_repr(self.producer)) - - def store_parts(self) -> law.util.InsertableDict: - """ - :return: Dictionary with parts that will be translated into an output directory path. - """ - parts = super().store_parts() - parts.insert_after(self.config_store_anchor, "producer", f"prod__{self.producer_repr}") - return parts - @classmethod def get_config_lookup_keys( cls, @@ -886,6 +922,21 @@ def get_config_lookup_keys( return keys + @property + def producer_repr(self) -> str: + """ + Return a string representation of the producer class. + """ + return self.build_repr(self.array_function_cls_repr(self.producer)) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "producer", f"prod__{self.producer_repr}") + return parts + class ProducerMixin(ArrayFunctionInstanceMixin, ProducerClassMixin): """ @@ -959,6 +1010,23 @@ def get_known_shifts( super().get_known_shifts(params, shifts) + @classmethod + def req_other_producer(cls, inst: ProducerMixin, **kwargs) -> ProducerMixin: + """ + Same as :py:meth:`req` but overwrites specific arguments for instantiation that simplify requesting a different + producer instance. + + :param inst: The reference instance to request parameters from. + :param kwargs: Additional arguments forwarded to :py:meth:`req`. + :return: A new instance of *this* class. + """ + # producer_inst and known_shifts must be set to None to by-pass producer instance cache lookup and thus, also + # full parameter resolution + kwargs.setdefault("producer_inst", None) + kwargs.setdefault("known_shifts", None) + + return cls.req(inst, **kwargs) + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -1680,21 +1748,6 @@ def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]: kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"hist_producer"} return super().req_params(inst, **kwargs) - @property - def hist_producer_repr(self) -> str: - """ - Return a string representation of the hist producer class. - """ - return self.build_repr(self.array_function_cls_repr(self.hist_producer)) - - def store_parts(self) -> law.util.InsertableDict: - """ - :return: Dictionary with parts that will be translated into an output directory path. - """ - parts = super().store_parts() - parts.insert_after(self.config_store_anchor, "hist_producer", f"hist__{self.hist_producer_repr}") - return parts - @classmethod def get_config_lookup_keys( cls, @@ -1714,6 +1767,21 @@ def get_config_lookup_keys( return keys + @property + def hist_producer_repr(self) -> str: + """ + Return a string representation of the hist producer class. + """ + return self.build_repr(self.array_function_cls_repr(self.hist_producer)) + + def store_parts(self) -> law.util.InsertableDict: + """ + :return: Dictionary with parts that will be translated into an output directory path. + """ + parts = super().store_parts() + parts.insert_after(self.config_store_anchor, "hist_producer", f"hist__{self.hist_producer_repr}") + return parts + class HistProducerMixin(ArrayFunctionInstanceMixin, HistProducerClassMixin): """ @@ -1790,6 +1858,23 @@ def get_known_shifts( super().get_known_shifts(params, shifts) + @classmethod + def req_other_hist_producer(cls, inst: HistProducerMixin, **kwargs) -> HistProducerMixin: + """ + Same as :py:meth:`req` but overwrites specific arguments for instantiation that simplify requesting a different + hist producer instance. + + :param inst: The reference instance to request parameters from. + :param kwargs: Additional arguments forwarded to :py:meth:`req`. + :return: A new instance of *this* class. + """ + # hist_producer_inst and known_shifts must be set to None to by-pass hist producer instance cache lookup and + # thus, also full parameter resolution + kwargs.setdefault("hist_producer_inst", None) + kwargs.setdefault("known_shifts", None) + + return cls.req(inst, **kwargs) + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) From b5081ad7fd271712c81f64aef7f41dddad17a09f Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Mon, 6 Oct 2025 15:09:40 +0200 Subject: [PATCH 105/123] Accelerate loading time and imports (#737) * Accelerate imports, allowing only 3rd party np and ak. * Fix missing sub imports. * Add missing plt import. * Adjust imports in tests. * Add missing coffea sub import. --- columnflow/__init__.py | 38 ++++++++-------- columnflow/__version__.py | 1 + columnflow/calibration/cms/jets.py | 8 +++- columnflow/columnar_util.py | 20 +++++---- columnflow/hist_util.py | 15 ++++++- columnflow/histogramming/__init__.py | 11 +++-- columnflow/histogramming/default.py | 15 ++++--- columnflow/inference/__init__.py | 2 +- columnflow/inference/cms/datacard.py | 17 +++---- columnflow/ml/__init__.py | 7 +-- columnflow/plotting/plot_all.py | 23 +++++++--- columnflow/plotting/plot_functions_1d.py | 16 ++++--- columnflow/plotting/plot_functions_2d.py | 20 +++++---- columnflow/plotting/plot_ml_evaluation.py | 54 ++++++++++++++--------- columnflow/plotting/plot_util.py | 25 +++++++---- columnflow/production/cms/btag.py | 1 + columnflow/production/cms/dy.py | 4 +- columnflow/production/normalization.py | 6 +-- columnflow/production/util.py | 18 +++++--- columnflow/selection/cms/json_filter.py | 6 +-- columnflow/tasks/cms/inference.py | 5 ++- columnflow/tasks/framework/histograms.py | 4 +- columnflow/tasks/framework/inference.py | 8 +++- columnflow/tasks/ml.py | 1 - columnflow/tasks/selection.py | 1 - columnflow/types.py | 4 +- setup.sh | 7 ++- tests/test_columnar_util.py | 1 + 28 files changed, 209 insertions(+), 129 deletions(-) diff --git a/columnflow/__init__.py b/columnflow/__init__.py index dda0895c7..19177c644 100644 --- a/columnflow/__init__.py +++ b/columnflow/__init__.py @@ -6,6 +6,7 @@ import os import re +import time import logging import law @@ -17,7 +18,7 @@ ) -logger = logging.getLogger(__name__) +logger = logging.getLogger(f"{__name__}_module_loader") # version info m = re.match(r"^(\d+)\.(\d+)\.(\d+)(-.+)?$", __version__) @@ -79,62 +80,59 @@ for fs in law.config.get_expanded("outputs", "wlcg_file_systems", [], split_csv=True) ] - # initialize producers, calibrators, selectors, categorizers, ml models and stat models + # initialize producers, calibrators, selectors, reducers, categorizers, ml models, hist producers and stat models from columnflow.util import maybe_import + def load(module, group): + t0 = time.perf_counter() + maybe_import(module) + duration = law.util.human_duration(seconds=time.perf_counter() - t0) + logger.debug(f"loaded {group} module '{module}', took {duration}") + import columnflow.calibration # noqa if law.config.has_option("analysis", "calibration_modules"): for m in law.config.get_expanded("analysis", "calibration_modules", [], split_csv=True): - logger.debug(f"loading calibration module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "calibration") import columnflow.selection # noqa if law.config.has_option("analysis", "selection_modules"): for m in law.config.get_expanded("analysis", "selection_modules", [], split_csv=True): - logger.debug(f"loading selection module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "selection") import columnflow.reduction # noqa if law.config.has_option("analysis", "reduction_modules"): for m in law.config.get_expanded("analysis", "reduction_modules", [], split_csv=True): - logger.debug(f"loading reduction module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "reduction") import columnflow.production # noqa if law.config.has_option("analysis", "production_modules"): for m in law.config.get_expanded("analysis", "production_modules", [], split_csv=True): - logger.debug(f"loading production module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "production") import columnflow.histogramming # noqa if law.config.has_option("analysis", "hist_production_modules"): for m in law.config.get_expanded("analysis", "hist_production_modules", [], split_csv=True): - logger.debug(f"loading hist production module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "hist production") import columnflow.categorization # noqa if law.config.has_option("analysis", "categorization_modules"): for m in law.config.get_expanded("analysis", "categorization_modules", [], split_csv=True): - logger.debug(f"loading categorization module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "categorization") import columnflow.ml # noqa if law.config.has_option("analysis", "ml_modules"): for m in law.config.get_expanded("analysis", "ml_modules", [], split_csv=True): - logger.debug(f"loading ml module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "ml") import columnflow.inference # noqa if law.config.has_option("analysis", "inference_modules"): for m in law.config.get_expanded("analysis", "inference_modules", [], split_csv=True): - logger.debug(f"loading inference module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "inference") # preload all task modules so that task parameters are globally known and accepted if law.config.has_section("modules"): for m in law.config.options("modules"): - logger.debug(f"loading task module '{m}'") - maybe_import(m.strip()) + load(m.strip(), "task") # cleanup del m diff --git a/columnflow/__version__.py b/columnflow/__version__.py index e09464834..bf1b3f34b 100644 --- a/columnflow/__version__.py +++ b/columnflow/__version__.py @@ -20,6 +20,7 @@ "Tobias Kramer", "Matthias Schroeder", "Johannes Lange", + "Ana Andrade", ] __contact__ = "https://github.com/columnflow/columnflow" __license__ = "BSD-3-Clause" diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index 964143c6c..7a770204d 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -4,20 +4,24 @@ Jet energy corrections and jet resolution smearing. """ +from __future__ import annotations + import functools import law -from columnflow.types import Any from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import ak_random, propagate_met, sum_transverse from columnflow.production.util import attach_coffea_behavior from columnflow.util import UNSET, maybe_import, DotDict, load_correction_set from columnflow.columnar_util import set_ak_column, layout_ak_array, optional_column as optional +from columnflow.types import TYPE_CHECKING, Any np = maybe_import("numpy") ak = maybe_import("awkward") -correctionlib = maybe_import("correctionlib") +if TYPE_CHECKING: + correctionlib = maybe_import("correctionlib") + logger = law.logger.get_logger(__name__) diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index 6b5fa3461..8657bec82 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -32,13 +32,7 @@ np = maybe_import("numpy") ak = maybe_import("awkward") -dak = maybe_import("dask_awkward") uproot = maybe_import("uproot") -coffea = maybe_import("coffea") -maybe_import("coffea.nanoevents") -maybe_import("coffea.nanoevents.methods.base") -maybe_import("coffea.nanoevents.methods.nanoaod") -pq = maybe_import("pyarrow.parquet") # loggers @@ -1237,6 +1231,9 @@ def attach_behavior( (*skip_fields*) can contain names or name patterns of fields that are kept (filtered). *keep_fields* has priority, i.e., when it is set, *skip_fields* is not considered. """ + import coffea.nanoevents + import coffea.nanoevents.methods.nanoaod + if behavior is None: behavior = getattr(ak_array, "behavior", None) or coffea.nanoevents.methods.nanoaod.behavior if behavior is None: @@ -3076,6 +3073,7 @@ def __init__( open_options["split_row_groups"] = False # open the file + import dask_awkward as dak self.dak_array = dak.from_parquet(path, **open_options) self.path = path @@ -3614,7 +3612,7 @@ def read_coffea_root( chunk_pos: ChunkPosition, read_options: dict | None = None, read_columns: set[str | Route] | None = None, - ) -> coffea.nanoevents.methods.base.NanoEventsArray: + ) -> ak.Array: """ Given a file location or opened uproot file, and a tree name in a 2-tuple *source_object*, returns an awkward array chunk referred to by *chunk_pos*, assuming nanoAOD structure. @@ -3622,6 +3620,8 @@ def read_coffea_root( *read_columns* are converted to strings and, if not already present, added as nested fields ``iteritems_options.filter_name`` to *read_options*. """ + import coffea.nanoevents + # default read options read_options = read_options or {} read_options["delayed"] = False @@ -3669,6 +3669,8 @@ def open_coffea_parquet( Given a parquet file located at *source*, returns a 2-tuple *(source, entries)*. Passing *open_options* or *read_columns* has no effect. """ + import pyarrow.parquet as pq + return (source, pq.ParquetFile(source).metadata.num_rows) @classmethod @@ -3688,7 +3690,7 @@ def read_coffea_parquet( chunk_pos: ChunkPosition, read_options: dict | None = None, read_columns: set[str | Route] | None = None, - ) -> coffea.nanoevents.methods.base.NanoEventsArray: + ) -> ak.Array: """ Given a the location of a parquet file *source_object*, returns an awkward array chunk referred to by *chunk_pos*, assuming nanoAOD structure. *read_options* are passed to @@ -3696,6 +3698,8 @@ def read_coffea_parquet( strings and, if not already present, added as nested field ``parquet_options.read_dictionary`` to *read_options*. """ + import coffea.nanoevents + # default read options read_options = read_options or {} read_options["runtime_cache"] = None diff --git a/columnflow/hist_util.py b/columnflow/hist_util.py index 7f16da17a..f579c0af5 100644 --- a/columnflow/hist_util.py +++ b/columnflow/hist_util.py @@ -14,11 +14,12 @@ from columnflow.columnar_util import flat_np_view from columnflow.util import maybe_import -from columnflow.types import Any +from columnflow.types import TYPE_CHECKING, Any -hist = maybe_import("hist") np = maybe_import("numpy") ak = maybe_import("awkward") +if TYPE_CHECKING: + hist = maybe_import("hist") logger = law.logger.get_logger(__name__) @@ -38,6 +39,8 @@ def fill_hist( determined automatically and depends on the variable axis type. In this case, shifting is applied to all continuous, non-circular axes. """ + import hist + if fill_kwargs is None: fill_kwargs = {} @@ -163,6 +166,8 @@ def get_axis_kwargs(axis: hist.axis.AxesMixin) -> dict[str, Any]: :param axis: The axis instance to extract information from. :return: The extracted information in a dict. """ + import hist + axis_attrs = ["name", "label"] traits_attrs = [] kwargs = {} @@ -213,6 +218,8 @@ def create_hist_from_variables( weight: bool = True, storage: str | None = None, ) -> hist.Hist: + import hist + histogram = hist.Hist.new # additional category axes @@ -259,6 +266,8 @@ def translate_hist_intcat_to_strcat( axis_name: str, id_map: dict[int, str], ) -> hist.Hist: + import hist + out_axes = [ ax if ax.name != axis_name else hist.axis.StrCategory( [id_map[v] for v in list(ax)], @@ -280,6 +289,8 @@ def add_missing_shifts( """ Adds missing shift bins to a histogram *h*. """ + import hist + # get the set of bins that are missing in the histogram shift_bins = set(h.axes[str_axis]) missing_shifts = set(expected_shifts_bins) - shift_bins diff --git a/columnflow/histogramming/__init__.py b/columnflow/histogramming/__init__.py index 2282f94fb..41a9438c7 100644 --- a/columnflow/histogramming/__init__.py +++ b/columnflow/histogramming/__init__.py @@ -11,13 +11,12 @@ import law import order as od -from columnflow.types import Callable -from columnflow.util import DerivableMeta, maybe_import from columnflow.columnar_util import TaskArrayFunction -from columnflow.types import Any - +from columnflow.util import DerivableMeta, maybe_import +from columnflow.types import TYPE_CHECKING, Any, Callable -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") class HistProducer(TaskArrayFunction): @@ -247,7 +246,7 @@ def run_post_process_hist(self, h: Any, task: law.Task) -> Any: return h return self.post_process_hist_func(h, task=task) - def run_post_process_merged_hist(self, h: Any, task: law.Task) -> hist.Histogram: + def run_post_process_merged_hist(self, h: Any, task: law.Task) -> hist.Hist: """ Invokes the :py:meth:`post_process_merged_hist_func` of this instance and returns its result, forwarding all arguments. diff --git a/columnflow/histogramming/default.py b/columnflow/histogramming/default.py index 8171031ef..c702b0d87 100644 --- a/columnflow/histogramming/default.py +++ b/columnflow/histogramming/default.py @@ -10,14 +10,15 @@ import order as od from columnflow.histogramming import HistProducer, hist_producer -from columnflow.util import maybe_import -from columnflow.hist_util import create_hist_from_variables, fill_hist, translate_hist_intcat_to_strcat from columnflow.columnar_util import has_ak_column, Route -from columnflow.types import Any +from columnflow.hist_util import create_hist_from_variables, fill_hist, translate_hist_intcat_to_strcat +from columnflow.util import maybe_import +from columnflow.types import TYPE_CHECKING, Any np = maybe_import("numpy") ak = maybe_import("awkward") -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") @hist_producer() @@ -39,7 +40,7 @@ def cf_default_create_hist( variables: list[od.Variable], task: law.Task, **kwargs, -) -> hist.Histogram: +) -> hist.Hist: """ Define the histogram structure for the default histogram producer. """ @@ -55,7 +56,7 @@ def cf_default_create_hist( @cf_default.fill_hist -def cf_default_fill_hist(self: HistProducer, h: hist.Histogram, data: dict[str, Any], task: law.Task) -> None: +def cf_default_fill_hist(self: HistProducer, h: hist.Hist, data: dict[str, Any], task: law.Task) -> None: """ Fill the histogram with the data. """ @@ -63,7 +64,7 @@ def cf_default_fill_hist(self: HistProducer, h: hist.Histogram, data: dict[str, @cf_default.post_process_hist -def cf_default_post_process_hist(self: HistProducer, h: hist.Histogram, task: law.Task) -> hist.Histogram: +def cf_default_post_process_hist(self: HistProducer, h: hist.Hist, task: law.Task) -> hist.Hist: """ Post-process the histogram, converting integer to string axis for consistent lookup across configs where ids might be different. diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index 43870d29e..41046bf1d 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -19,8 +19,8 @@ freeze, ) -logger = law.logger.get_logger(__name__) +logger = law.logger.get_logger(__name__) default_dataset = law.config.get_expanded("analysis", "default_dataset") diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index 951e7e086..647cf9258 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -14,18 +14,19 @@ from columnflow import __version__ as cf_version from columnflow.inference import InferenceModel, ParameterType, ParameterTransformation, FlowStrategy from columnflow.util import DotDict, maybe_import, real_path, ensure_dir, safe_div, maybe_int -from columnflow.types import Sequence, Any, Union, Hashable +from columnflow.types import TYPE_CHECKING, Sequence, Any, Union, Hashable -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") + # type aliases for nested histogram structs + ShiftHists = dict[Union[str, tuple[str, str]], hist.Hist] # "nominal" or (param_name, "up|down") -> hists + ConfigHists = dict[str, ShiftHists] # config name -> hists + ProcHists = dict[str, ConfigHists] # process name -> hists + DatacardHists = dict[str, ProcHists] # category name -> hists -logger = law.logger.get_logger(__name__) -# type aliases for nested histogram structs -ShiftHists = dict[Union[str, tuple[str, str]], hist.Hist] # "nominal" or (param_name, "up|down") -> hists -ConfigHists = dict[str, ShiftHists] # config name -> hists -ProcHists = dict[str, ConfigHists] # process name -> hists -DatacardHists = dict[str, ProcHists] # category name -> hists +logger = law.logger.get_logger(__name__) class DatacardWriter(object): diff --git a/columnflow/ml/__init__.py b/columnflow/ml/__init__.py index 43419ea86..e50b22bf9 100644 --- a/columnflow/ml/__init__.py +++ b/columnflow/ml/__init__.py @@ -12,11 +12,12 @@ import law import order as od -from columnflow.types import Any, Sequence -from columnflow.util import maybe_import, Derivable, DotDict, KeyValueMessage from columnflow.columnar_util import Route +from columnflow.util import maybe_import, Derivable, DotDict, KeyValueMessage +from columnflow.types import TYPE_CHECKING, Any, Sequence -ak = maybe_import("awkward") +if TYPE_CHECKING: + ak = maybe_import("awkward") class MLModel(Derivable): diff --git a/columnflow/plotting/plot_all.py b/columnflow/plotting/plot_all.py index 73407b736..dbfc369b6 100644 --- a/columnflow/plotting/plot_all.py +++ b/columnflow/plotting/plot_all.py @@ -10,7 +10,6 @@ import order as od -from columnflow.types import Sequence from columnflow.util import maybe_import, try_float from columnflow.config_util import group_shifts from columnflow.plotting.plot_util import ( @@ -21,12 +20,12 @@ apply_label_placeholders, calculate_stat_error, ) +from columnflow.types import TYPE_CHECKING, Sequence -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") def draw_stat_error_bands( @@ -71,6 +70,8 @@ def draw_syst_error_bands( method: str = "quadratic_sum", **kwargs, ) -> None: + import hist + assert len(h.axes) == 1 assert method in ("quadratic_sum", "envelope") @@ -169,6 +170,8 @@ def draw_stack( norm: float | Sequence | np.ndarray = 1.0, **kwargs, ) -> None: + import hist + # check if norm is a number if try_float(norm): h = hist.Stack(*[i / norm for i in h]) @@ -202,6 +205,8 @@ def draw_hist( error_type: str = "variance", **kwargs, ) -> None: + import hist + assert error_type in {"variance", "poisson_unweighted", "poisson_weighted"} if kwargs.get("color", "") is None: @@ -243,6 +248,8 @@ def draw_profile( """ Profiled histograms contains the storage type "Mean" and can therefore not be normalized """ + import hist + assert error_type in {"variance", "poisson_unweighted", "poisson_weighted"} if kwargs.get("color", "") is None: @@ -272,6 +279,8 @@ def draw_errorbars( error_type: str = "poisson_unweighted", **kwargs, ) -> None: + import hist + assert error_type in {"variance", "poisson_unweighted", "poisson_weighted"} values = h.values() / norm @@ -342,6 +351,10 @@ def plot_all( :param magnitudes: Optional float parameter that defines the displayed ymin when plotting with a logarithmic scale. :return: tuple of plot figure and axes """ + import matplotlib as mpl + import matplotlib.pyplot as plt + import mplhep + # general mplhep style plt.style.use(mplhep.style.CMS) diff --git a/columnflow/plotting/plot_functions_1d.py b/columnflow/plotting/plot_functions_1d.py index 4c1bf4b60..69e26562e 100644 --- a/columnflow/plotting/plot_functions_1d.py +++ b/columnflow/plotting/plot_functions_1d.py @@ -11,8 +11,8 @@ from collections import OrderedDict import law +import order as od -from columnflow.types import Iterable from columnflow.util import maybe_import from columnflow.plotting.plot_all import plot_all from columnflow.plotting.plot_util import ( @@ -31,14 +31,12 @@ join_labels, ) from columnflow.hist_util import add_missing_shifts +from columnflow.types import TYPE_CHECKING, Iterable - -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") -od = maybe_import("order") +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") def plot_variable_stack( @@ -249,6 +247,8 @@ def plot_shifted_variable( """ TODO. """ + import hist + variable_inst = variable_insts[0] hists, process_style_config = apply_process_settings(hists, process_settings) @@ -451,6 +451,8 @@ def plot_profile( :param base_distribution_yscale: yscale of the base distributions :param skip_variations: whether to skip adding the up and down variation of the profile plot """ + import matplotlib.pyplot as plt + if len(variable_insts) != 2: raise Exception("The plot_profile function can only be used for 2-dimensional input histograms.") diff --git a/columnflow/plotting/plot_functions_2d.py b/columnflow/plotting/plot_functions_2d.py index d8f58ae01..c731c4822 100644 --- a/columnflow/plotting/plot_functions_2d.py +++ b/columnflow/plotting/plot_functions_2d.py @@ -6,11 +6,14 @@ from __future__ import annotations +__all__ = [] + from collections import OrderedDict from functools import partial from unittest.mock import patch import law +import order as od from columnflow.util import maybe_import from columnflow.plotting.plot_util import ( @@ -22,14 +25,11 @@ get_position, reduce_with, ) +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") -od = maybe_import("order") -mticker = maybe_import("matplotlib.ticker") +if TYPE_CHECKING: + plt = maybe_import("matplotlib.pyplot") def plot_2d( @@ -55,6 +55,10 @@ def plot_2d( variable_settings: dict | None = None, **kwargs, ) -> plt.Figure: + import matplotlib as mpl + import matplotlib.pyplot as plt + import mplhep + # remove shift axis from histograms hists = remove_residual_axis(hists, "shift") @@ -273,10 +277,10 @@ def plot_2d( _scale = cbar.ax.yaxis._scale _scale.subs = [2, 3, 4, 5, 6, 7, 8, 9] cbar.ax.yaxis.set_minor_locator( - mticker.SymmetricalLogLocator(_scale.get_transform(), subs=_scale.subs), + mpl.ticker.SymmetricalLogLocator(_scale.get_transform(), subs=_scale.subs), ) cbar.ax.yaxis.set_minor_formatter( - mticker.LogFormatterSciNotation(_scale.base), + mpl.ticker.LogFormatterSciNotation(_scale.base), ) plt.tight_layout() diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 26c4fe6cc..ec5e951c4 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -6,42 +6,45 @@ from __future__ import annotations +__all__ = [] + import re -from columnflow.types import Sequence +import order as od +import scinum + from columnflow.util import maybe_import from columnflow.plotting.plot_util import get_cms_label +from columnflow.types import TYPE_CHECKING, Sequence -ak = maybe_import("awkward") -od = maybe_import("order") np = maybe_import("numpy") -sci = maybe_import("scinum") -plt = maybe_import("matplotlib.pyplot") -hep = maybe_import("mplhep") -colors = maybe_import("matplotlib.colors") +ak = maybe_import("awkward") +if TYPE_CHECKING: + plt = maybe_import("matplotlib.pyplot") + # define a CF custom color maps cf_colors = { - "cf_green_cmap": colors.ListedColormap([ + "cf_green_cmap": [ "#212121", "#242723", "#262D25", "#283426", "#2A3A26", "#2C4227", "#2E4927", "#305126", "#325A25", "#356224", "#386B22", "#3B7520", "#3F7F1E", "#43891B", "#479418", "#4C9F14", "#52AA10", "#58B60C", "#5FC207", "#67cf02", - ]), - "cf_ygb_cmap": colors.ListedColormap([ + ], + "cf_ygb_cmap": [ "#003675", "#005B83", "#008490", "#009A83", "#00A368", "#00AC49", "#00B428", "#00BC06", "#0CC300", "#39C900", "#67cf02", "#72DB02", "#7EE605", "#8DF207", "#9CFD09", "#AEFF0B", "#C1FF0E", "#D5FF10", "#EBFF12", "#FFFF14", - ]), - "cf_cmap": colors.ListedColormap([ + ], + "cf_cmap": [ "#002C9C", "#00419F", "#0056A2", "#006BA4", "#0081A7", "#0098AA", "#00ADAB", "#00B099", "#00B287", "#00B574", "#00B860", "#00BB4C", "#00BD38", "#00C023", "#00C20D", "#06C500", "#1EC800", "#36CA00", "#4ECD01", "#67cf02", - ]), - "viridis": colors.ListedColormap([ + ], + "viridis": [ "#263DA8", "#1652CC", "#1063DB", "#1171D8", "#1380D5", "#0E8ED0", "#089DCC", "#0DA7C2", "#1DAFB3", "#2DB7A3", "#52BA91", "#73BD80", "#94BE71", "#B2BC65", "#D0BA59", "#E1BF4A", "#F4C53A", "#FCD12B", "#FAE61C", "#F9F90E", - ]), + ], } @@ -111,6 +114,10 @@ def plot_cm( is not *None* and its shape doesn't match *predictions*. :raises AssertionError: If *normalization* is not one of *None*, "row", "column". """ + import matplotlib as mpl + import matplotlib.pyplot as plt + import mplhep + # defining some useful properties and output shapes true_labels = list(events.keys()) pred_labels = [s.removeprefix("score_") for s in list(events.values())[0].fields] @@ -136,7 +143,7 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: counts[ind, index] += count if not skip_uncertainties: - vecNumber = np.vectorize(lambda n, count: sci.Number(n, float(n / np.sqrt(count) if count else 0))) + vecNumber = np.vectorize(lambda n, count: scinum.Number(n, float(n / np.sqrt(count) if count else 0))) result = vecNumber(result, counts) # normalize Matrix if needed @@ -203,7 +210,7 @@ def get_errors(matrix): Useful for seperating the error from the data """ if matrix.dtype.name == "object": - get_errors_vec = np.vectorize(lambda x: x.get(sci.UP, unc=True)) + get_errors_vec = np.vectorize(lambda x: x.get(scinum.UP, unc=True)) return get_errors_vec(matrix) return np.zeros_like(matrix) @@ -219,13 +226,13 @@ def fmt(v): return "{}\n\u00B1{}".format(fmt(values[i][j]), fmt(np.nan_to_num(uncs[i][j]))) # create the plot - plt.style.use(hep.style.CMS) + plt.style.use(mplhep.style.CMS) fig, ax = plt.subplots(dpi=300) # some useful variables and functions n_processes = cm.shape[0] n_classes = cm.shape[1] - cmap = cf_colors.get(colormap, cf_colors["cf_cmap"]) + cmap = mpl.colors.ListedColormap(cf_colors.get(colormap, cf_colors["cf_cmap"])) x_labels = x_labels if x_labels else [f"out{i}" for i in range(n_classes)] y_labels = y_labels if y_labels else true_labels font_ax = 20 @@ -292,7 +299,7 @@ def fmt(v): if cms_llabel != "skip": cms_label_kwargs = get_cms_label(ax=ax, llabel=cms_llabel) cms_label_kwargs["rlabel"] = cms_rlabel - hep.cms.label(**cms_label_kwargs) + mplhep.cms.label(**cms_label_kwargs) plt.tight_layout() return fig @@ -349,6 +356,9 @@ def plot_roc( is not *None* and its shape doesn't match *predictions*. :raises ValueError: If *normalization* is not one of *None*, 'row', 'column'. """ + import matplotlib.pyplot as plt + import mplhep + # defining some useful properties and output shapes thresholds = np.linspace(0, 1, n_thresholds) weights = create_sample_weights(sample_weights, events, list(events.keys())) @@ -478,7 +488,7 @@ def auc_score(fpr: list, tpr: list, *args) -> np.float64: fpr = roc_data["fpr"] tpr = roc_data["tpr"] - plt.style.use(hep.style.CMS) + plt.style.use(mplhep.style.CMS) fig, ax = plt.subplots(dpi=300) ax.set_xlabel("FPR", loc="right", labelpad=10, fontsize=25) ax.set_ylabel("TPR", loc="top", labelpad=15, fontsize=25) @@ -499,7 +509,7 @@ def auc_score(fpr: list, tpr: list, *args) -> np.float64: if cms_llabel != "skip": cms_label_kwargs = get_cms_label(ax=ax, llabel=cms_llabel) cms_label_kwargs["rlabel"] = cms_rlabel - hep.cms.label(**cms_label_kwargs) + mplhep.cms.label(**cms_label_kwargs) plt.tight_layout() return fig diff --git a/columnflow/plotting/plot_util.py b/columnflow/plotting/plot_util.py index 3f09b724e..e664e4b15 100644 --- a/columnflow/plotting/plot_util.py +++ b/columnflow/plotting/plot_util.py @@ -9,6 +9,7 @@ __all__ = [] import re +import math import operator import functools from collections import OrderedDict @@ -19,13 +20,12 @@ from columnflow.util import maybe_import, try_int, try_complex, UNSET from columnflow.hist_util import copy_axis -from columnflow.types import Iterable, Any, Callable, Sequence, Hashable +from columnflow.types import TYPE_CHECKING, Iterable, Any, Callable, Sequence, Hashable -math = maybe_import("math") -hist = maybe_import("hist") np = maybe_import("numpy") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") logger = law.logger.get_logger(__name__) @@ -255,6 +255,8 @@ def apply_variable_settings( applies settings from *variable_settings* dictionary to the *variable_insts*; the *rebin*, *overflow*, *underflow*, and *slice* settings are directly applied to the histograms """ + import hist + # store info gathered along application of variable settings that can be inserted to the style config variable_style_config = {} @@ -382,12 +384,12 @@ def apply_density(hists: dict, density: bool = True) -> dict: if not density: return hists - for key, hist in hists.items(): + for key, h in hists.items(): # bin area safe for multi-dimensional histograms - area = functools.reduce(operator.mul, hist.axes.widths) + area = functools.reduce(operator.mul, h.axes.widths) # scale hist by bin area - hists[key] = hist / area + hists[key] = h / area return hists @@ -398,6 +400,8 @@ def remove_residual_axis_single( max_bins: int = 1, select_value: Any = None, ) -> hist.Hist: + import hist + # force always returning a copy h = h.copy() @@ -510,6 +514,8 @@ def prepare_stack_plot_config( backgrounds with uncertainty bands, unstacked processes as lines and data entrys with errorbars. """ + import hist + # separate histograms into stack, lines and data hists mc_hists, mc_colors, mc_edgecolors, mc_labels = [], [], [], [] mc_syst_hists = [] @@ -943,6 +949,8 @@ def rebin_equal_width( :param axis_name: Name of the axis to rebin. :return: Tuple of the rebinned histograms and the new bin edges. """ + import hist + # get the variable axis from the first histogram assert hists for var_index, var_axis in enumerate(list(hists.values())[0].axes): @@ -1049,6 +1057,7 @@ def calculate_stat_error( - 'poisson_unweighted': the plotted error is the poisson error for each bin - 'poisson_weighted': the plotted error is the poisson error for each bin, weighted by the variance """ + import hist # determine the error type if error_type == "variance": diff --git a/columnflow/production/cms/btag.py b/columnflow/production/cms/btag.py index 0577c3d59..87e7654b8 100644 --- a/columnflow/production/cms/btag.py +++ b/columnflow/production/cms/btag.py @@ -18,6 +18,7 @@ np = maybe_import("numpy") ak = maybe_import("awkward") + logger = law.logger.get_logger(__name__) diff --git a/columnflow/production/cms/dy.py b/columnflow/production/cms/dy.py index 2c97424aa..46201d28d 100644 --- a/columnflow/production/cms/dy.py +++ b/columnflow/production/cms/dy.py @@ -16,7 +16,7 @@ np = maybe_import("numpy") ak = maybe_import("awkward") -vector = maybe_import("vector") + logger = law.logger.get_logger(__name__) @@ -301,6 +301,8 @@ def recoil_corrected_met(self: Producer, events: ak.Array, **kwargs) -> ak.Array *get_dy_recoil_config* can be adapted in a subclass in case it is stored differently in the config. """ + import vector + # steps: # 1) Build transverse vectors for MET and the generator-level boson (full and visible). # 2) Compute the recoil vector U = MET + vis - full in the transverse plane. diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 4e4c3f70e..e85a0a37d 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -21,8 +21,6 @@ from columnflow.types import Any, Sequence np = maybe_import("numpy") -sp = maybe_import("scipy") -maybe_import("scipy.sparse") ak = maybe_import("awkward") @@ -413,6 +411,8 @@ def normalization_weights_setup( weights per process. - py: attr: `known_process_ids`: A set of all process ids that are known by the lookup table. """ + import scipy.sparse + # load the selection stats dataset_selection_stats = { dataset: copy.deepcopy(task.cached_value( @@ -486,7 +486,7 @@ def extract_stats(*update_funcs): ) # setup the event weight lookup table - process_weight_table = sp.sparse.lil_matrix((max(process_ids) + 1, 1), dtype=np.float32) + process_weight_table = scipy.sparse.lil_matrix((max(process_ids) + 1, 1), dtype=np.float32) def fill_weight_table(process_inst: od.Process, xsec: float, sum_weights: float) -> None: if sum_weights == 0: diff --git a/columnflow/production/util.py b/columnflow/production/util.py index 5c3df8fa0..1df6d49f9 100644 --- a/columnflow/production/util.py +++ b/columnflow/production/util.py @@ -3,9 +3,10 @@ """ General producers that might be utilized in various places. """ + from __future__ import annotations -from functools import partial +import functools from columnflow.types import Iterable, Sequence, Union from columnflow.production import Producer, producer @@ -13,7 +14,6 @@ from columnflow.columnar_util import attach_coffea_behavior as attach_coffea_behavior_fn ak = maybe_import("awkward") -coffea = maybe_import("coffea") @producer(call_force=True) @@ -69,15 +69,21 @@ def ak_extract_fields(arr: ak.Array, fields: list[str], **kwargs): # functions for operating on lorentz vectors # -_lv_base = partial(ak_extract_fields, behavior=coffea.nanoevents.methods.nanoaod.behavior) +def _lv_base(*args, **kwargs): + # scoped partial to defer coffea import + import coffea.nanoevents + import coffea.nanoevents.methods.nanoaod + kwargs["behavior"] = coffea.nanoevents.methods.nanoaod.behavior + return ak_extract_fields(*args, **kwargs) + -lv_xyzt = partial(_lv_base, fields=["x", "y", "z", "t"], with_name="LorentzVector") +lv_xyzt = functools.partial(_lv_base, fields=["x", "y", "z", "t"], with_name="LorentzVector") lv_xyzt.__doc__ = """Construct a `LorentzVectorArray` from an input array.""" -lv_mass = partial(_lv_base, fields=["pt", "eta", "phi", "mass"], with_name="PtEtaPhiMLorentzVector") +lv_mass = functools.partial(_lv_base, fields=["pt", "eta", "phi", "mass"], with_name="PtEtaPhiMLorentzVector") lv_mass.__doc__ = """Construct a `PtEtaPhiMLorentzVectorArray` from an input array.""" -lv_energy = partial(_lv_base, fields=["pt", "eta", "phi", "energy"], with_name="PtEtaPhiELorentzVector") +lv_energy = functools.partial(_lv_base, fields=["pt", "eta", "phi", "energy"], with_name="PtEtaPhiELorentzVector") lv_energy.__doc__ = """Construct a `PtEtaPhiELorentzVectorArray` from an input array.""" diff --git a/columnflow/selection/cms/json_filter.py b/columnflow/selection/cms/json_filter.py index 2b750a563..6eddb84d1 100644 --- a/columnflow/selection/cms/json_filter.py +++ b/columnflow/selection/cms/json_filter.py @@ -14,8 +14,6 @@ ak = maybe_import("awkward") np = maybe_import("numpy") -sp = maybe_import("scipy") -maybe_import("scipy.sparse") def get_lumi_file_default(self, external_files: DotDict) -> str: @@ -124,6 +122,8 @@ def json_filter_setup( :param inputs: Additional inputs, currently not used :param reader_targets: Additional targets, currently not used """ + import scipy.sparse + bundle = reqs["external_files"] # import the correction sets from the external file @@ -134,7 +134,7 @@ def json_filter_setup( max_run = max(map(int, json.keys())) # build lookup table - self.run_ls_lookup = sp.sparse.lil_matrix((max_run + 1, max_ls + 1), dtype=bool) + self.run_ls_lookup = scipy.sparse.lil_matrix((max_run + 1, max_ls + 1), dtype=bool) for run, ls_ranges in json.items(): run = int(run) for ls_range in ls_ranges: diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index dc37f82ec..e88c41975 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -14,8 +14,11 @@ from columnflow.tasks.framework.base import AnalysisTask, wrapper_factory from columnflow.tasks.framework.inference import SerializeInferenceModelBase from columnflow.tasks.histograms import MergeHistograms +from columnflow.inference.cms.datacard import DatacardWriter +from columnflow.types import TYPE_CHECKING -from columnflow.inference.cms.datacard import DatacardHists, ShiftHists, DatacardWriter +if TYPE_CHECKING: + from columnflow.inference.cms.datacard import DatacardHists, ShiftHists class CreateDatacards(SerializeInferenceModelBase): diff --git a/columnflow/tasks/framework/histograms.py b/columnflow/tasks/framework/histograms.py index a0ecdaa2e..03df5e0b7 100644 --- a/columnflow/tasks/framework/histograms.py +++ b/columnflow/tasks/framework/histograms.py @@ -16,8 +16,10 @@ ) from columnflow.tasks.histograms import MergeHistograms, MergeShiftedHistograms from columnflow.util import dev_sandbox, maybe_import +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") class HistogramsUserBase( diff --git a/columnflow/tasks/framework/inference.py b/columnflow/tasks/framework/inference.py index 7cff1ae88..508b67e17 100644 --- a/columnflow/tasks/framework/inference.py +++ b/columnflow/tasks/framework/inference.py @@ -18,10 +18,12 @@ ) from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.tasks.histograms import MergeShiftedHistograms -from columnflow.util import dev_sandbox, DotDict, maybe_import from columnflow.config_util import get_datasets_from_process +from columnflow.util import dev_sandbox, DotDict, maybe_import +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") class SerializeInferenceModelBase( @@ -235,6 +237,8 @@ def load_process_hists( variable: str, inputs: dict, ) -> dict[str, dict[od.Process, hist.Hist]]: + import hist + # collect histograms per variable and process hists: dict[od.Process, hist.Hist] = {} diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 679a62575..806782727 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -35,7 +35,6 @@ from columnflow.util import dev_sandbox, safe_div, DotDict, maybe_import from columnflow.columnar_util import set_ak_column - ak = maybe_import("awkward") diff --git a/columnflow/tasks/selection.py b/columnflow/tasks/selection.py index 0081fee36..3c8fca7fc 100644 --- a/columnflow/tasks/selection.py +++ b/columnflow/tasks/selection.py @@ -21,7 +21,6 @@ from columnflow.util import maybe_import, ensure_proxy, dev_sandbox, safe_div, DotDict from columnflow.tasks.framework.parameters import DerivableInstParameter - np = maybe_import("numpy") ak = maybe_import("awkward") diff --git a/columnflow/types.py b/columnflow/types.py index cfe437207..05d783daa 100644 --- a/columnflow/types.py +++ b/columnflow/types.py @@ -22,8 +22,8 @@ from collections.abc import KeysView, ValuesView # noqa from types import ModuleType, GeneratorType, GenericAlias # noqa from typing import ( # noqa - Any, Union, TypeVar, ClassVar, Sequence, Callable, Generator, TextIO, Iterable, Hashable, - Type, + TYPE_CHECKING, Any, Union, TypeVar, ClassVar, Sequence, Callable, Generator, TextIO, Iterable, Hashable, Type, + Literal, ) from typing_extensions import Annotated, _AnnotatedAlias as AnnotatedType, TypeAlias # noqa diff --git a/setup.sh b/setup.sh index 4bc024729..b45d7197e 100644 --- a/setup.sh +++ b/setup.sh @@ -760,6 +760,8 @@ cf_setup_post_install() { # Optional environment variables: # CF_SKIP_SETUP_GIT_HOOKS # When set to true, the setup of git hooks is skipped. + # CF_SKIP_LAW_INDEX + # When set to true, the initial indexing of law tasks is skipped. # CF_SKIP_CHECK_TMP_DIR # When set to true, the check of the size of the target tmp directory is skipped. @@ -788,7 +790,9 @@ cf_setup_post_install() { complete -o bashdefault -o default -F _law_complete claw # silently index - law index -q + if ! ${CF_SKIP_LAW_INDEX}; then + law index -q + fi fi fi @@ -1113,6 +1117,7 @@ for flag_name in \ CF_REINSTALL_HOOKS \ CF_SKIP_BANNER \ CF_SKIP_SETUP_GIT_HOOKS \ + CF_SKIP_LAW_INDEX \ CF_SKIP_CHECK_TMP_DIR \ CF_ON_HTCONDOR \ CF_ON_SLURM \ diff --git a/tests/test_columnar_util.py b/tests/test_columnar_util.py index d06ac9882..952e27c3a 100644 --- a/tests/test_columnar_util.py +++ b/tests/test_columnar_util.py @@ -16,6 +16,7 @@ ak = maybe_import("awkward") dak = maybe_import("dask_awkward") coffea = maybe_import("coffea") +maybe_import("coffea.nanoevents") class RouteTest(unittest.TestCase): From ade427b528ff443c7713ecd4cf1c89047ab47d65 Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Tue, 7 Oct 2025 14:14:31 +0200 Subject: [PATCH 106/123] Add parent_mode flag to create_category_combinations, add change tracking. (#736) * Add skip_parents flag to create_category_combinations. * Refactor category combinations, add category tracker. * Update columnflow/config_util.py Co-authored-by: Ana Andrade <99343616+aalvesan@users.noreply.github.com> * Update columnflow/config_util.py Co-authored-by: Ana Andrade <99343616+aalvesan@users.noreply.github.com> * Typo. --------- Co-authored-by: Ana Andrade <99343616+aalvesan@users.noreply.github.com> --- columnflow/config_util.py | 146 ++++++++++++++++++++++++++++++++++---- 1 file changed, 134 insertions(+), 12 deletions(-) diff --git a/columnflow/config_util.py b/columnflow/config_util.py index 3a3da34f3..5d74ee482 100644 --- a/columnflow/config_util.py +++ b/columnflow/config_util.py @@ -18,7 +18,7 @@ from columnflow.util import maybe_import, get_docs_url from columnflow.columnar_util import flat_np_view, layout_ak_array -from columnflow.types import Callable, Any, Sequence +from columnflow.types import Callable, Any, Sequence, Literal ak = maybe_import("awkward") np = maybe_import("numpy") @@ -467,6 +467,10 @@ class CategoryGroup: Container to store information about a group of categories, mostly used for creating combinations in :py:func:`create_category_combinations`. + .. note:: + + A group is considered a full partition of the phase space if it is both complete and non-overlapping. + :param categories: List of :py:class:`order.Category` objects or names that refer to the desired category. :param is_complete: Should be *True* if the union of category selections covers the full phase space (no gaps). :param has_overlap: Should be *False* if all categories are pairwise disjoint (no overlap). @@ -490,6 +494,7 @@ def create_category_combinations( config: od.Config, categories: dict[str, CategoryGroup | list[od.Category]], name_fn: Callable[[Any], str], + parent_mode: Literal["all", "none", "safe"] = "safe", kwargs_fn: Callable[[Any], dict] | None = None, skip_existing: bool = True, skip_fn: Callable[[dict[str, od.Category], str], bool] | None = None, @@ -500,9 +505,9 @@ def create_category_combinations( returns the number of newly created categories. *categories* should be a dictionary that maps string names to :py:class:`CategoryGroup` objects which are thin - wrappers around sequences of categories (objects or names). Group names (dictionary keys) are used as keyword - arguments in a callable *name_fn* that is supposed to return the name of newly created categories (see example - below). + wrappers around sequences of categories (objects or names) and provide additional information about the group as a + whole. Group names (dictionary keys) are used as keyword arguments in a callable *name_fn* that is supposed to + return the name of newly created categories (see example below). .. note:: @@ -511,6 +516,18 @@ def create_category_combinations( over-counting when combining leaf categories. These checks may be performed by other functions and tools based on information derived from groups and stored in auxiliary fields of the newly created categories. + All intermediate layers of categories can be built and connected automatically to one another by parent - child + category relations. The exact behavior is controlled by *parent_mode*: + + - ``"all"``: All intermediate parent category layers are created and connected. Please note that this choice + omits information about group completeness and overlaps (see :py:attr:`CategoryGroup.is_partition`) of child + categories which - in cases such as child category summation - can lead to unintended results. + - ``"none"``: No intermediate parent category layers but only leaf categories are created and connected to their + root categories. + - ``"safe"``: Intermediate parent category layers are created and connected only if the group of child + categories is both complete and non-overlapping (see :py:attr:`CategoryGroup.is_partition`). This is the + recommended choice (and the default) as it avoids unintended results as mentioned in ``"all"``. + Each newly created category is instantiated with this name as well as arbitrary keyword arguments as returned by *kwargs_fn*. This function is called with the categories (in a dictionary, mapped to the sequence names as given in *categories*) that contribute to the newly created category and should return a dictionary. If the fields ``"id"`` @@ -547,6 +564,8 @@ def kwargs_fn(categories): :param categories: Dictionary that maps group names to :py:class:`CategoryGroup` containers. :param name_fn: Callable that receives a dictionary mapping group names to categories and returns the name of the newly created category. + :param parent_mode: Controls how intermediate parent categories are created and connected. Either of ``"all"``, + ``"none"``, or ``"safe"``. :param kwargs_fn: Callable that receives a dictionary mapping group names to categories and returns a dictionary of keyword arguments that are forwarded to the category constructor. :param skip_existing: If *True*, skip the creation of a category when it already exists in *config*. @@ -557,6 +576,12 @@ def kwargs_fn(categories): :raises ValueError: If a non-unique category id is detected. :return: Number of newly created categories. """ + # check parent mode + parent_mode = parent_mode.lower() + known_parent_modes = ["all", "none", "safe"] + if parent_mode not in known_parent_modes: + raise ValueError(f"unknown parent_mode {parent_mode}, known values are {', '.join(known_parent_modes)}") + # cast categories for name, _categories in categories.items(): # ensure CategoryGroup is used @@ -567,6 +592,7 @@ def kwargs_fn(categories): f"using a list to define a sequence of categories for create_category_combinations() is depcreated " f"and will be removed in a future version, please use a CategoryGroup instance instead: {docs_url}", ) + # create a group assuming (!) it describes a full, valid phasespace partition _categories = CategoryGroup( categories=law.util.make_list(_categories), is_complete=True, @@ -582,6 +608,8 @@ def kwargs_fn(categories): unique_ids_cache = {cat.id for cat, _, _ in config.walk_categories()} n_groups = len(categories) group_names = list(categories.keys()) + safe_groups = {name for name, group in categories.items() if group.is_partition} + unsafe_groups = set(group_names) - safe_groups # nothing to do when there are less than 2 groups if n_groups < 2: @@ -593,11 +621,19 @@ def kwargs_fn(categories): if kwargs_fn and not callable(kwargs_fn): raise TypeError(f"when set, kwargs_fn must be a function, but got {kwargs_fn}") - # start combining, considering one additional groups for combinatorics at a time - for _n_groups in range(2, n_groups + 1): + # lookup table with created categories for faster access when connecting parents + created_categories: dict[str, od.Category] = {} + # start combining, considering one additional group for combinatorics at a time + # if skipping parents entirely, only consider the iteration that contains all groups + for _n_groups in ([n_groups] if parent_mode == "none" else range(2, n_groups + 1)): # build all group combinations for _group_names in itertools.combinations(group_names, _n_groups): + # when creating parents in "safe" mode, skip combinations that miss unsafe groups + # (i.e. they must be part of _group_names to be used later) + if parent_mode == "safe": + if (set(group_names) - set(_group_names)) & unsafe_groups: + continue # build the product of all categories for the given groups _categories = [categories[group_name].categories for group_name in _group_names] @@ -623,7 +659,7 @@ def kwargs_fn(categories): # create the new category cat = od.Category(name=cat_name, **kwargs) - n_created_categories += 1 + created_categories[cat_name] = cat # ID uniqueness check: raise an error when a non-unique id is detected for a new category if isinstance(kwargs["id"], int): @@ -636,19 +672,105 @@ def kwargs_fn(categories): ) unique_ids_cache.add(kwargs["id"]) - # find direct parents and connect them - for _parent_group_names in itertools.combinations(_group_names, _n_groups - 1): + # find combinations of parents and connect them, depending on parent_mode + if parent_mode == "all": + # all direct parents, obtained by combinations with one missing group + parent_gen = itertools.combinations(_group_names, _n_groups - 1) + elif parent_mode == "none": + # only connect to root categories + parent_gen = ((name,) for name in _group_names) + else: # safe + # same as "all", but unsafe groups must be part of the combinations + def _parent_gen(): + seen = set() + # choose 1 group to sum over from _n_groups available + for names in itertools.combinations(_group_names, _n_groups - 1): + # as above, if there is at least one unsafe group missing, the parent was not created + if (set(_group_names) - set(names)) & unsafe_groups: + continue + if names and names not in seen: + seen.add(names) + yield names + # in case no parent combination was yielded, yield all root categories separately + if not seen: + yield from ((name,) for name in _group_names) + parent_gen = _parent_gen() + + # actual connections + for _parent_group_names in parent_gen: + # find the parent if len(_parent_group_names) == 1: - parent_cat_name = root_cats[_parent_group_names[0]].name + parent_cat = root_cats[_parent_group_names[0]] else: parent_cat_name = name_fn({ group_name: root_cats[group_name] for group_name in _parent_group_names }) - parent_cat = config.get_category(parent_cat_name, deep=True) + if parent_cat_name in created_categories: + parent_cat = created_categories[parent_cat_name] + else: + parent_cat = config.get_category(parent_cat_name, deep=True) + # connect parent_cat.add_category(cat) - return n_created_categories + return len(created_categories) + + +def track_category_changes(config: od.Config, summary_path: str | None = None) -> None: + """ + Scans the categories in *config* and saves a summary in a file located at *summary_path*. If the file exists, + the summary from a previous run is loaded first and compare to the current categories. If changes are found, a + warning is shown with details about these changes. + + :param config: :py:class:`~order.config.Config` instance to scan for categories. + :param summary_path: Path to the summary file. Defaults to "$LAW_HOME/category_summary_{config.name}.json". + """ + # build summary file as law target + if not summary_path: + summary_path = law.config.law_home_path(f"category_summary_{config.name}.json") + summary_file = law.LocalFileTarget(summary_path) + + # gather category info + cat_pairs = sorted((cat.name, cat.id) for cat, *_ in config.walk_categories(include_self=True)) + cat_summary = { + "hash": law.util.create_hash(cat_pairs), + "categories": dict(cat_pairs), + } + + save_summary = True + if summary_file.exists(): + previous_summary = summary_file.load(formatter="json") + if previous_summary["hash"] == cat_summary["hash"]: + save_summary = False + else: + msgs = [ + f"the category definitions in config '{config.name}' seem to have changed based on a hash comparison, " + "ignore this message in case you knowingly adjusted categories fully aware of the changes:", + f"old hash: {previous_summary['hash']}, new hash: {cat_summary['hash']}", + ] + curr = cat_summary["categories"] + prev = previous_summary["categories"] + # track added and removed names + curr_names = set(curr) + prev_names = set(prev) + if (added_names := curr_names - prev_names): + msgs.append(f"added categories : {', '.join(sorted(added_names))}") + if (removed_names := prev_names - curr_names): + msgs.append(f"removed categories : {', '.join(sorted(removed_names))}") + # track id changes for names present in both + changed_ids = { + name: (prev[name], curr[name]) + for name in curr_names & prev_names + if prev[name] != curr[name] + } + if changed_ids: + pair_repr = lambda pair: f"{pair[0]}: {pair[1][0]} -> {pair[1][1]}" + msgs.append("changed category ids:\n - " + "\n - ".join(map(pair_repr, changed_ids.items()))) + + logger.warning_once(f"categories_changed_{config.name}", "\n".join(msgs)) + + if save_summary: + summary_file.dump(cat_summary, formatter="json", indent=4) def verify_config_processes(config: od.Config, warn: bool = False) -> None: From b06111be549b28d7c87e3900522c895c08965c99 Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Tue, 7 Oct 2025 15:17:33 +0200 Subject: [PATCH 107/123] Update sandboxes, avoid dask_awkward in IO (#735) * Update sandboxes, add ChunkedParquetReader. * Update docstring. * Remove dask-awkward from columnar requirements. * Adjustable sandbox in cf_inspect. * Avoid using np.bool. * Cleanup venv setup. * Update columnflow/columnar_util.py Co-authored-by: Nathan Prouvost <49162277+nprouvost@users.noreply.github.com> * Update columnflow/columnar_util.py Co-authored-by: Nathan Prouvost <49162277+nprouvost@users.noreply.github.com> --------- Co-authored-by: Nathan Prouvost <49162277+nprouvost@users.noreply.github.com> --- bin/cf_inspect | 12 +- columnflow/calibration/cms/jets.py | 2 +- columnflow/columnar_util.py | 226 +++++++++++++++++++++++++++-- columnflow/selection/empty.py | 2 +- sandboxes/_setup_venv.sh | 4 +- sandboxes/columnar.txt | 17 +-- sandboxes/dev.txt | 15 +- sandboxes/ml_tf.txt | 2 +- setup.sh | 80 +++++++--- 9 files changed, 304 insertions(+), 56 deletions(-) diff --git a/bin/cf_inspect b/bin/cf_inspect index 68ecf4a52..fdefffe38 100755 --- a/bin/cf_inspect +++ b/bin/cf_inspect @@ -1,15 +1,25 @@ #!/bin/sh action () { + # local variables local shell_is_zsh="$( [ -z "${ZSH_VERSION}" ] && echo "false" || echo "true" )" local this_file="$( ${shell_is_zsh} && echo "${(%):-%x}" || echo "${BASH_SOURCE[0]}" )" local this_dir="$( cd "$( dirname "${this_file}" )" && pwd )" + # check arguments [ "$#" -eq 0 ] && { echo "ERROR: at least one file must be provided" return 1 } - cf_sandbox venv_columnar_dev python "${this_dir}/cf_inspect.py" "$@" + # determine the sandbox to use + local cf_inspect_sandbox="${CF_INSPECT_SANDBOX:-venv_columnar_dev}" + + # run the inspection script, potentially switching to the sandbox if not already in it + if [ "${CF_VENV_NAME}" = "${cf_inspect_sandbox}" ]; then + python "${this_dir}/cf_inspect.py" "$@" + else + cf_sandbox "${cf_inspect_sandbox}" python "${this_dir}/cf_inspect.py" "$@" + fi } action "$@" diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index 7a770204d..ab1e4361f 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -737,7 +737,7 @@ def get_jer_config_default(self: Calibrator) -> DotDict: # whether gen jet matching should be performed relative to the nominal jet pt, or the jec varied values gen_jet_matching_nominal=False, # regions where stochastic smearing is applied - stochastic_smearing_mask=lambda self, jets: ak.ones_like(jets.pt, dtype=np.bool), + stochastic_smearing_mask=lambda self, jets: ak.ones_like(jets.pt, dtype=bool), ) def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: """ diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index 8657bec82..3e4cef879 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -3225,6 +3225,140 @@ def _materialize_via_partitions( return arr +class ChunkedParquetReader(object): + """ + Class that wraps a parquet file containing an awkward array and handles chunked reading via splitting and merging of + row groups. To allow memory efficient caching in case of overlaps between groups on disk and chunks to be read + (possibly with different sizes) this process is implemented as a one-time-only read operation. Hence, in situations + where particular chunks need to be read more than once, another instance of this class should be used. + """ + + def __init__(self, path: str, open_options: dict | None = None) -> None: + super().__init__() + if not open_options: + open_options = {} + + # store attributes + self.path = path + self.open_options = open_options.copy() + + # open and store meta data with updated open options + # (when closing the reader, this attribute is set to None) + meta_options = open_options.copy() + meta_options.pop("row_groups", None) + meta_options.pop("ignore_metadata", None) + self.metadata = ak.metadata_from_parquet(path, **meta_options) + + # extract row group sizes for chunked reading + if "col_counts" not in self.metadata: + raise Exception( + f"{self.__class__.__name__}: entry 'col_counts' is missing in meta data of file '{path}', but it is " + "strictly required for chunked reading; please debug", + ) + self.group_sizes = list(self.metadata["col_counts"]) + + # compute cumulative division boundaries + divs = [0] + for s in self.group_sizes: + divs.append(divs[-1] + s) + self.group_divisions = tuple(divs) + + # fixed mapping of chunk indices to group indices, created in materialize + self.chunk_to_groups = {} + + # mapping of group indices to cache information (chunks still to be handled and a cached array) that changes + # during the read process in materialize + self.group_cache = {g: DotDict(chunks=set(), array=None) for g in range(len(self.group_sizes))} + + # locks to protect against RCs during read operations by different threads + self.chunk_to_groups_lock = threading.Lock() + self.group_locks = {g: threading.Lock() for g in self.group_cache} + + def __del__(self) -> None: + self.close() + + def __len__(self) -> int: + return self.group_divisions[-1] + + @property + def closed(self) -> bool: + return self.metadata is None + + def close(self) -> None: + self.metadata = None + if getattr(self, "group_cache", None): + for g in self.group_cache: + self.group_cache[g] = None + + def materialize( + self, + *, + chunk_index: int, + entry_start: int, + entry_stop: int, + max_chunk_size: int, + ) -> ak.Array: + # strategy: read from disk with granularity given by row group sizes + # - use chunk info to determine which groups need to be read + # - guard each read operation of a group by locks + # - add materialized groups that might overlap with another chunk in a temporary cache + # - remove cached groups eagerly once it becomes clear that no chunk will need it + + # fill the chunk -> groups mapping once + with self.chunk_to_groups_lock: + if not self.chunk_to_groups: + # note: a hare-and-tortoise algorithm could be possible to get the mapping with less + # than n^2 complexity, but for our case with ~30 chunks this should be ok (for now) + n_chunks = int(math.ceil(len(self) / max_chunk_size)) + # in case there are no entries, ensure that at least one empty chunk is created + for _chunk_index in range(max(n_chunks, 1)): + _entry_start = _chunk_index * max_chunk_size + _entry_stop = min(_entry_start + max_chunk_size, len(self)) + groups = [] + for g, (g_start, g_stop) in enumerate(zip(self.group_divisions[:-1], self.group_divisions[1:])): + # note: check strict increase of chunk size to accommodate zero-length size + if g_stop <= _entry_start < _entry_stop: + continue + if g_start >= _entry_stop > _entry_start: + break + groups.append(g) + self.group_cache[g].chunks.add(_chunk_index) + self.chunk_to_groups[_chunk_index] = groups + + # read groups one at a time and store parts that make up the chunk for concatenation + parts = [] + for g in self.chunk_to_groups[chunk_index]: + # obtain the array + with self.group_locks[g]: + # remove this chunk from the list of chunks to be handled + self.group_cache[g].chunks.remove(chunk_index) + + if self.group_cache[g].array is None: + arr = ak.from_parquet(self.path, row_groups=[g], **self.open_options) + # add to cache when there is a chunk left that will need it + if self.group_cache[g].chunks: + self.group_cache[g].array = arr + else: + arr = self.group_cache[g].array + # remove from cache when there is no chunk left that would need it + if not self.group_cache[g].chunks: + self.group_cache[g].array = None + + # add part for concatenation using entry info + div_start, div_stop = self.group_divisions[g:g + 2] + part_start = max(entry_start - div_start, 0) + part_stop = min(entry_stop - div_start, div_stop - div_start) + parts.append(arr[part_start:part_stop]) + + # construct the full array + arr = parts[0] if len(parts) == 1 else ak.concatenate(parts, axis=0) + + # cleanup + del parts + + return arr + + class ChunkedIOHandler(object): """ Allows reading one or multiple files and iterating through chunks of their content with @@ -3427,6 +3561,7 @@ def get_source_handler( - "coffea_root" - "coffea_parquet" - "awkward_parquet" + - "dask_awkward_parquet" """ if source_type is None: if isinstance(source, uproot.ReadOnlyDirectory): @@ -3438,7 +3573,7 @@ def get_source_handler( # priotize coffea nano events source_type = "coffea_root" elif source.endswith(".parquet"): - # priotize awkward nano events + # prioritize non-dask awkward reader source_type = "awkward_parquet" if not source_type: @@ -3472,6 +3607,13 @@ def get_source_handler( cls.close_awkward_parquet, cls.read_awkward_parquet, ) + if source_type == "dask_awkward_parquet": + return cls.SourceHandler( + source_type, + cls.open_dask_awkward_parquet, + cls.close_dask_awkward_parquet, + cls.read_dask_awkward_parquet, + ) raise NotImplementedError(f"unknown source_type '{source_type}'") @@ -3624,7 +3766,7 @@ def read_coffea_root( # default read options read_options = read_options or {} - read_options["delayed"] = False + read_options["mode"] = "eager" read_options["runtime_cache"] = None read_options["persistent_cache"] = None @@ -3702,6 +3844,7 @@ def read_coffea_parquet( # default read options read_options = read_options or {} + read_options["mode"] = "eager" read_options["runtime_cache"] = None read_options["persistent_cache"] = None @@ -3738,20 +3881,19 @@ def open_awkward_parquet( source: str, open_options: dict | None = None, read_columns: set[str | Route] | None = None, - ) -> tuple[ak.Array, int]: + ) -> tuple[ChunkedParquetReader, int]: """ - Opens a parquet file saved at *source*, loads the content as an dask awkward array, - wrapped by a :py:class:`DaskArrayReader`, and returns a 2-tuple *(array, length)*. - *open_options* and *chunk_size* are forwarded to :py:class:`DaskArrayReader`. *read_columns* - are converted to strings and, if not already present, added as field ``columns`` to - *open_options*. + Opens a parquet file saved at *source*, loads the content as chunks of an awkward array wrapped by a + :py:class:`ChunkedParquetReader`, and returns a 2-tuple *(reader, length)*. + + *open_options* and *chunk_size* are forwarded accordingly. *read_columns* are converted to strings and, if not + already present, added as field ``columns`` to *open_options*. """ if not isinstance(source, str): raise Exception(f"'{source}' cannot be opened as awkward_parquet") # default open options open_options = open_options or {} - open_options.setdefault("split_row_groups", True) # preserve input file partitions # inject read_columns if read_columns and "columns" not in open_options: @@ -3759,12 +3901,72 @@ def open_awkward_parquet( open_options["columns"] = filter_name # load the array wrapper - arr = DaskArrayReader(source, open_options) + reader = ChunkedParquetReader(source, open_options) - return (arr, len(arr)) + return (reader, len(reader)) @classmethod def close_awkward_parquet( + cls, + source_object: ChunkedParquetReader, + ) -> None: + """ + Closes the chunked parquet reader referred to by *source_object*. + """ + source_object.close() + + @classmethod + def read_awkward_parquet( + cls, + source_object: ChunkedParquetReader, + chunk_pos: ChunkedIOHandler.ChunkPosition, + read_options: dict | None = None, + read_columns: set[str | Route] | None = None, + ) -> ak.Array: + """ + Given a :py:class:`ChunkedParquetReader` *source_object*, returns the chunk referred to by *chunk_pos* as a + full copy loaded into memory. Passing neither *read_options* nor *read_columns* has an effect. + """ + # get the materialized ak array for that chunk + return source_object.materialize( + chunk_index=chunk_pos.index, + entry_start=chunk_pos.entry_start, + entry_stop=chunk_pos.entry_stop, + max_chunk_size=chunk_pos.max_chunk_size, + ) + + @classmethod + def open_dask_awkward_parquet( + cls, + source: str, + open_options: dict | None = None, + read_columns: set[str | Route] | None = None, + ) -> tuple[DaskArrayReader, int]: + """ + Opens a parquet file saved at *source*, loads the content as an dask awkward array, wrapped by a + :py:class:`DaskArrayReader`, and returns a 2-tuple *(reader, length)*. + + *open_options* and *chunk_size* are forwarded to :py:class:`DaskArrayReader`. *read_columns* are converted to + strings and, if not already present, added as field ``columns`` to *open_options*. + """ + if not isinstance(source, str): + raise Exception(f"'{source}' cannot be opened as awkward_parquet") + + # default open options + open_options = open_options or {} + + # inject read_columns + if read_columns and "columns" not in open_options: + filter_name = [Route(s).string_column for s in read_columns] + open_options["columns"] = filter_name + + # load the array wrapper + reader = DaskArrayReader(source, open_options) + + return (reader, len(reader)) + + @classmethod + def close_dask_awkward_parquet( cls, source_object: DaskArrayReader, ) -> None: @@ -3774,7 +3976,7 @@ def close_awkward_parquet( source_object.close() @classmethod - def read_awkward_parquet( + def read_dask_awkward_parquet( cls, source_object: DaskArrayReader, chunk_pos: ChunkedIOHandler.ChunkPosition, diff --git a/columnflow/selection/empty.py b/columnflow/selection/empty.py index 0be227402..563846e47 100644 --- a/columnflow/selection/empty.py +++ b/columnflow/selection/empty.py @@ -61,7 +61,7 @@ def empty( events = set_ak_column(events, "category_ids", category_ids) # empty selection result with a trivial event mask - results = SelectionResult(event=ak.Array(np.ones(len(events), dtype=np.bool_))) + results = SelectionResult(event=ak.Array(np.ones(len(events), dtype=bool))) # increment stats weight_map = { diff --git a/sandboxes/_setup_venv.sh b/sandboxes/_setup_venv.sh index 31e8e8fd4..fc45460e1 100644 --- a/sandboxes/_setup_venv.sh +++ b/sandboxes/_setup_venv.sh @@ -248,7 +248,9 @@ setup_venv() { # install if not existing if [ ! -f "${CF_SANDBOX_FLAG_FILE}" ]; then - cf_color cyan "installing venv ${CF_VENV_NAME} from ${sandbox_file} at ${install_path}" + echo -n "$( cf_color cyan "installing venv" )" + echo -n " $( cf_color cyan_bright "${CF_VENV_NAME}" )" + echo " $( cf_color cyan "from ${sandbox_file} at ${install_path}" )" rm -rf "${install_path}" cf_create_venv "${venv_name_hashed}" diff --git a/sandboxes/columnar.txt b/sandboxes/columnar.txt index 8d6293ecc..36cbda25c 100644 --- a/sandboxes/columnar.txt +++ b/sandboxes/columnar.txt @@ -1,14 +1,13 @@ -# version 17 +# version 18 # exact versions for core array packages -awkward==2.8.1 -uproot==5.6.0 -pyarrow==19.0.1 -dask-awkward==2025.3.0 -correctionlib==2.6.4 -coffea==2024.11.0 +awkward==2.8.9 +uproot==5.6.6 +pyarrow==21.0.0 +correctionlib==2.7.0 +coffea==2025.9.0 # minimum versions for general packages -zstandard~=0.23.0 -lz4~=4.4.3 +zstandard~=0.25.0 +lz4~=4.4.4 xxhash~=3.5.0 diff --git a/sandboxes/dev.txt b/sandboxes/dev.txt index 73ab68008..16fd687de 100644 --- a/sandboxes/dev.txt +++ b/sandboxes/dev.txt @@ -1,12 +1,13 @@ -# version 11 +# version 12 # last version to support python 3.9 ipython~=8.18.1 -pytest~=8.3.5 -pytest-cov~=6.0.0 -flake8~=7.1.2 +pytest~=8.4.2 +pytest-cov~=7.0.0 +flake8~=7.3.0 flake8-commas~=4.0.0 flake8-quotes~=3.4.0 -pipdeptree~=2.26.0 -pymarkdownlnt~=0.9.29 -uniplot~=0.17.1 +pymarkdownlnt~=0.9.32 +uniplot~=0.21.4 +pipdeptree~=2.28.0 +mermaidmro~=0.2.1 diff --git a/sandboxes/ml_tf.txt b/sandboxes/ml_tf.txt index 382f89151..ce18e6828 100644 --- a/sandboxes/ml_tf.txt +++ b/sandboxes/ml_tf.txt @@ -1,4 +1,4 @@ -# version 11 +# version 12 # use packages from columnar sandbox as baseline -r columnar.txt diff --git a/setup.sh b/setup.sh index b45d7197e..932a96357 100644 --- a/setup.sh +++ b/setup.sh @@ -376,6 +376,7 @@ cf_setup_interactive_common_variables() { query CF_VENV_SETUP_MODE_UPDATE "Automatically update virtual envs if needed" "false" [ "${CF_VENV_SETUP_MODE_UPDATE}" != "true" ] && export_and_save CF_VENV_SETUP_MODE "update" unset CF_VENV_SETUP_MODE_UPDATE + query CF_INTERACTIVE_VENV_FILE "Custom venv setup fill to use for interactive work instead of 'cf_dev'" "" "''" query CF_LOCAL_SCHEDULER "Use a local scheduler for law tasks" "true" if [ "${CF_LOCAL_SCHEDULER}" != "true" ]; then @@ -530,8 +531,12 @@ cf_setup_software_stack() { # Optional environments variables: # CF_REMOTE_ENV # When true-ish, the software stack is sourced but not built. - # CF_CI_ENV - # When true-ish, the "cf" venv is skipped and only the "cf_dev" env is built. + # CF_LOCAL_ENV + # When not true-ish, the context is not meant for local development and only the "cf_dev" venv is built. + # CF_INTERACTIVE_VENV_FILE is ignored in this case. + # CF_INTERACTIVE_VENV_FILE + # IF CF_LOCAL_ENV is true-ish, the venv setup of this file is sourced to start the interactive shell. When + # empty, defaults to ${CF_BASE}/sandboxes/cf_dev.sh. # CF_REINSTALL_SOFTWARE # When true-ish, any existing software stack is removed and freshly installed. # CF_CONDA_ARCH @@ -620,8 +625,8 @@ cf_setup_software_stack() { # conda / micromamba setup # - # not needed in CI or RTD jobs - if ! ${CF_CI_ENV} && ! ${CF_RTD_ENV}; then + # only needed in local envs + if ${CF_LOCAL_ENV}; then # base environment local conda_missing="$( [ -d "${CF_CONDA_BASE}" ] && echo "false" || echo "true" )" if ${conda_missing}; then @@ -690,32 +695,61 @@ EOF # - "cf" : contains the minimal stack to run tasks and is sent alongside jobs # - "cf_dev" : "cf" + additional python tools for local development (e.g. ipython) + # - custom : when CF_INTERACTIVE_VENV_FILE is set, source the venv setup from there + + source_venv() { + # all parameters must be given + local venv_file="$1" + local venv_name="$2" + # must be true or false + local use_subshell="$3" + + # source the file and catch the return code + local ret="0" + if ${use_subshell}; then + ( source "${venv_file}" "" "silent" ) + ret="$?" + else + source "${venv_file}" "" "silent" + ret="$?" + fi - show_version_warning() { - >&2 echo - >&2 echo "WARNING: your venv '$1' is not up to date, please consider updating it in a new shell with" - >&2 echo "WARNING: > CF_REINSTALL_SOFTWARE=1 source setup.sh $( ${setup_is_default} || echo "${setup_name}" )" - >&2 echo - } - - # source the production sandbox, potentially skipped in CI and RTD jobs - if ! ${CF_CI_ENV} && ! ${CF_RTD_ENV}; then - ( source "${CF_BASE}/sandboxes/cf.sh" "" "silent" ) - ret="$?" + # code 21 means "version outdated", all others are as usual if [ "${ret}" = "21" ]; then - show_version_warning "cf" + >&2 echo + >&2 echo "WARNING: your venv '${venv_name}' is not up to date, please consider updating it in a new shell with" + >&2 echo "WARNING: > CF_REINSTALL_SOFTWARE=1 source setup.sh $( ${setup_is_default} || echo "${setup_name}" )" + >&2 echo elif [ "${ret}" != "0" ]; then return "${ret}" fi + + return "0" + } + + # build the production sandbox in a subshell, only in local envs + if ${CF_LOCAL_ENV}; then + source_venv "${CF_BASE}/sandboxes/cf.sh" "cf" true || return "$?" fi - # source the dev sandbox - source "${CF_BASE}/sandboxes/cf_dev.sh" "" "silent" - ret="$?" - if [ "${ret}" = "21" ]; then - show_version_warning "cf_dev" - elif [ "${ret}" != "0" ]; then - return "${ret}" + # check if a custom interactive venv should be used, check the file, but only in local envs + local use_custom_interactive_venv="false" + if ${CF_LOCAL_ENV} && [ ! -z "${CF_INTERACTIVE_VENV_FILE}" ]; then + # check existence + if [ ! -f "${CF_INTERACTIVE_VENV_FILE}" ]; then + >&2 echo "the interactive venv setup file ${CF_INTERACTIVE_VENV_FILE} does not exist" + return "2" + fi + use_custom_interactive_venv="true" + fi + + # build the dev sandbox, using a subshell if a custom venv is given that should be sourced afterwards + source_venv "${CF_BASE}/sandboxes/cf_dev.sh" "cf_dev" "${use_custom_interactive_venv}" || return "$?" + + # source the custom interactive venv setup file if given + if ${use_custom_interactive_venv}; then + echo "activating custom interactive venv from $( cf_color magenta "${CF_INTERACTIVE_VENV_FILE}" )" + source_venv "${CF_INTERACTIVE_VENV_FILE}" "$( basename "${CF_INTERACTIVE_VENV_FILE%.*}" )" false || return "$?" fi # initialze submodules From fe2d28ab351968e5fd35fd4e40c3806b5e1ea12d Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 8 Oct 2025 10:09:42 +0200 Subject: [PATCH 108/123] Hotfix ChunkedParquetReader. --- columnflow/columnar_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index 3e4cef879..1e61e8ce1 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -3247,6 +3247,7 @@ def __init__(self, path: str, open_options: dict | None = None) -> None: meta_options = open_options.copy() meta_options.pop("row_groups", None) meta_options.pop("ignore_metadata", None) + meta_options.pop("columns", None) self.metadata = ak.metadata_from_parquet(path, **meta_options) # extract row group sizes for chunked reading From 8004828cd7f1f61600b027cff335476de95b9647 Mon Sep 17 00:00:00 2001 From: "Marcel R." Date: Wed, 8 Oct 2025 13:50:42 +0200 Subject: [PATCH 109/123] Hotfix bad import in plot utils. --- columnflow/plotting/plot_util.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/columnflow/plotting/plot_util.py b/columnflow/plotting/plot_util.py index e664e4b15..70755ae67 100644 --- a/columnflow/plotting/plot_util.py +++ b/columnflow/plotting/plot_util.py @@ -1046,28 +1046,25 @@ def remove_label_placeholders( return re.sub(f"__{sel}__", "", label) -def calculate_stat_error( - hist: hist.Hist, - error_type: str, -) -> dict: +def calculate_stat_error(h: hist.Hist, error_type: str) -> np.ndarray: """ - Calculate the error to be plotted for the given histogram *hist*. + Calculate the error to be plotted for the given histogram *h*. Supported error types are: - - 'variance': the plotted error is the square root of the variance for each bin - - 'poisson_unweighted': the plotted error is the poisson error for each bin - - 'poisson_weighted': the plotted error is the poisson error for each bin, weighted by the variance - """ - import hist + - "variance": the plotted error is the square root of the variance for each bin + - "poisson_unweighted": the plotted error is the poisson error for each bin + - "poisson_weighted": the plotted error is the poisson error for each bin, weighted by the variance + """ # determine the error type if error_type == "variance": - yerr = hist.view().variance ** 0.5 + yerr = h.view().variance ** 0.5 + elif error_type in {"poisson_unweighted", "poisson_weighted"}: # compute asymmetric poisson confidence interval from hist.intervals import poisson_interval - variances = hist.view().variance if error_type == "poisson_weighted" else None - values = hist.view().value + variances = h.view().variance if error_type == "poisson_weighted" else None + values = h.view().value confidence_interval = poisson_interval(values, variances) # negative values are considerd as blinded bins -> set confidence interval to 0 @@ -1080,19 +1077,14 @@ def calculate_stat_error( raise ValueError("Unweighted Poisson interval calculation returned NaN values, check Hist package") # calculate the error - # yerr_lower is the lower error yerr_lower = values - confidence_interval[0] - # yerr_upper is the upper error yerr_upper = confidence_interval[1] - values - # yerr is the size of the errorbars to be plotted yerr = np.array([yerr_lower, yerr_upper]) if np.any(yerr < 0): - logger.warning( - "yerr < 0, setting to 0. " - "This should not happen, please check your histogram.", - ) + logger.warning("found yerr < 0, forcing to 0; this should not happen, please check your histogram") yerr[yerr < 0] = 0 + else: raise ValueError(f"unknown error type '{error_type}'") From 81c2621b4d8f3c3da671d5a7992608aa1871908c Mon Sep 17 00:00:00 2001 From: maadcoen Date: Tue, 14 Oct 2025 12:21:16 +0200 Subject: [PATCH 110/123] fix deprecated use of masked_sorted_indices --- .../__cf_module_name__/selection/objects.py | 8 ++++---- columnflow/selection/cmsGhent/lepton_mva_cuts.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/analysis_templates/ghent_template/__cf_module_name__/selection/objects.py b/analysis_templates/ghent_template/__cf_module_name__/selection/objects.py index 568350b58..86a3952b5 100644 --- a/analysis_templates/ghent_template/__cf_module_name__/selection/objects.py +++ b/analysis_templates/ghent_template/__cf_module_name__/selection/objects.py @@ -10,7 +10,7 @@ from columnflow.util import maybe_import, four_vec from columnflow.columnar_util import set_ak_column from columnflow.selection import Selector, SelectionResult, selector -from columnflow.reduction.util import masked_sorted_indices +from columnflow.columnar_util import sorted_indices_from_mask ak = maybe_import("awkward") @@ -53,7 +53,7 @@ def muon_object( steps={}, objects={ "Muon": { - "Muon": masked_sorted_indices(mu_mask, muon.pt) + "Muon": sorted_indices_from_mask(mu_mask, muon.pt) } }, ) @@ -108,7 +108,7 @@ def electron_object( steps={}, objects={ "Electron": { - "Electron": masked_sorted_indices(e_mask, electron.pt) + "Electron": sorted_indices_from_mask(e_mask, electron.pt) } }, ) @@ -142,7 +142,7 @@ def jet_object( (dR_mask) ) - jet_indices = masked_sorted_indices(jet_mask, events.Jet.pt) + jet_indices = sorted_indices_from_mask(jet_mask, events.Jet.pt) n_jets = ak.sum(jet_mask, axis=-1) return events, SelectionResult( diff --git a/columnflow/selection/cmsGhent/lepton_mva_cuts.py b/columnflow/selection/cmsGhent/lepton_mva_cuts.py index a1e238530..df3772cec 100644 --- a/columnflow/selection/cmsGhent/lepton_mva_cuts.py +++ b/columnflow/selection/cmsGhent/lepton_mva_cuts.py @@ -13,7 +13,7 @@ from columnflow.columnar_util import set_ak_column, optional_column # from columnflow.production.util import attach_coffea_behavior from columnflow.selection import Selector, SelectionResult, selector -from columnflow.reduction.util import masked_sorted_indices +from columnflow.columnar_util import sorted_indices_from_mask ak = maybe_import("awkward") @@ -83,7 +83,7 @@ def lepton_mva_object( steps={}, objects={ lep: - {lep: masked_sorted_indices(events[lep][working_point[lep]], events[lep].pt)} + {lep: sorted_indices_from_mask(events[lep][working_point[lep]], events[lep].pt)} for lep in ["Muon", "Electron"] }, ) From 30cef66ffd20eaaf1f52225a39965dbce61efe7f Mon Sep 17 00:00:00 2001 From: maadcoen Date: Tue, 14 Oct 2025 13:40:17 +0200 Subject: [PATCH 111/123] fix missing make_plot_2d --- .../plotting/cmsGhent/plot_functions_2d.py | 37 ++++++++++++++++--- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/columnflow/plotting/cmsGhent/plot_functions_2d.py b/columnflow/plotting/cmsGhent/plot_functions_2d.py index ed58a89c9..1b7f68eb4 100644 --- a/columnflow/plotting/cmsGhent/plot_functions_2d.py +++ b/columnflow/plotting/cmsGhent/plot_functions_2d.py @@ -9,7 +9,6 @@ mplhep = maybe_import("mplhep") hist = maybe_import("hist") -from columnflow.plotting.plot_all import make_plot_2d from columnflow.plotting.plot_util import ( apply_variable_settings, remove_residual_axis, @@ -153,17 +152,45 @@ def plot_migration_matrices( default_style_config["annotate_cfg"]["bbox"] = dict(alpha=0.5, facecolor="white") style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) - + + # # make main central migration plot - make_plot_2d(plot_config, style_config, figaxes=(fig, axes[0, 1])) + # + + central_ax = axes[0, 1] + + # apply style_config + if ax_cfg := style_config.get("ax_cfg", {}): + for tickname in ["xticks", "yticks"]: + ticks = ax_cfg.pop(tickname) + for ticksize in ["major", "minor"]: + if subticks := ticks.get(ticksize, {}): + getattr(central_ax, "set_" + tickname)(**subticks, minor=ticksize == "minor") + central_ax.set(**ax_cfg) + + if "legend_cfg" in style_config: + central_ax.legend(**style_config["legend_cfg"]) + + # annotation of category label + if annotate_kwargs := style_config.get("annotate_cfg", {}): + central_ax.annotate(**annotate_kwargs) + + if cms_label_kwargs := style_config.get("cms_label_cfg", {}): + mplhep.cms.label(ax=central_ax, **cms_label_kwargs) + + # call plot method, patching the colorbar function + # called internally by mplhep to draw the extension symbols + with patch.object(plt, "colorbar", partial(plt.colorbar, **plot_config.get("cbar_kwargs", {}))): + plot_config["hist"].plot2d(ax=central_ax, **plot_config.get("kwargs", {})) + if label_numbers: for i, x in enumerate(migrations_eq_ax.axes[0].centers): for j, y in enumerate(migrations_eq_ax.axes[1].centers): if abs(i - j) <= 1: lbl = f"{migrations_eq_ax.values()[i, j] * 100:.0f}" - axes[0, 1].text(x, y, lbl, ha="center", va="center", size="large") + central_ax.text(x, y, lbl, ha="center", va="center", size="large") - cbar = plt.colorbar(axes[0, 1].collections[0], **plot_config["cbar_kwargs"]) + cbar = plt.colorbar(central_ax.collections[0], **plot_config["cbar_kwargs"]) fix_cbar_minor_ticks(cbar) # set cbar range From 5a4761cdaf7d794ec278585110fe964ba533c97e Mon Sep 17 00:00:00 2001 From: maadcoen Date: Tue, 14 Oct 2025 14:16:49 +0200 Subject: [PATCH 112/123] fix missing make_plot_2d --- columnflow/plotting/cmsGhent/plot_functions_2d.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/columnflow/plotting/cmsGhent/plot_functions_2d.py b/columnflow/plotting/cmsGhent/plot_functions_2d.py index 1b7f68eb4..620876c72 100644 --- a/columnflow/plotting/cmsGhent/plot_functions_2d.py +++ b/columnflow/plotting/cmsGhent/plot_functions_2d.py @@ -1,6 +1,8 @@ import law from collections import OrderedDict from columnflow.util import maybe_import +from unittest.mock import patch +from functools import partial plt = maybe_import("matplotlib.pyplot") np = maybe_import("numpy") From 3fb4b872ca1441c78cf805c2f337ef740b7f7175 Mon Sep 17 00:00:00 2001 From: maadcoen Date: Tue, 14 Oct 2025 14:26:53 +0200 Subject: [PATCH 113/123] refactor import of matplotlib, hist, coffea, correctionlib --- .../__cf_module_name__/plotting/example.py | 13 ++++++++---- .../__cf_module_name__/plotting/example.py | 13 ++++++++---- .../__cf_module_name__/production/default.py | 6 ++++-- .../__cf_module_name__/selection/default.py | 10 ++++++++-- columnflow/calibration/cmsGhent/lepton_mva.py | 2 -- columnflow/columnar_util_Ghent.py | 3 ++- .../plotting/cmsGhent/plot_functions_1d.py | 13 +++++++----- .../plotting/cmsGhent/plot_functions_2d.py | 13 +++++++----- columnflow/plotting/cmsGhent/plot_util.py | 6 +++++- columnflow/plotting/cmsGhent/unrolled.py | 20 +++++++++++-------- .../production/cmsGhent/btag_weights.py | 7 +++++-- .../production/cmsGhent/gen_features.py | 1 - columnflow/production/cmsGhent/lepton.py | 2 -- .../cmsGhent/trigger/hist_producer.py | 6 +++++- .../cmsGhent/trigger/sf_producer.py | 2 -- .../cmsGhent/trigger/uncertainties.py | 10 +++++----- .../production/cmsGhent/trigger/util.py | 11 ++++++++-- columnflow/tasks/cmsGhent/btagefficiency.py | 5 ++++- columnflow/tasks/cmsGhent/selection_hists.py | 7 ++++--- .../tasks/cmsGhent/trigger_scale_factors.py | 6 +++--- 20 files changed, 100 insertions(+), 56 deletions(-) diff --git a/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py b/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py index da7e34817..f160e5e6f 100644 --- a/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py +++ b/analysis_templates/cms_minimal/__cf_module_name__/plotting/example.py @@ -14,14 +14,16 @@ apply_variable_settings, apply_process_settings, ) +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") od = maybe_import("order") +# import hist, matplotlib... for type checking only like this! import them then also locallu. +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") + def my_plot1d_func( hists: OrderedDict[od.Process, hist.Hist], @@ -45,6 +47,9 @@ def my_plot1d_func( --plot-function __cf_module_name__.plotting.example.my_plot1d_func \ --general-settings example_param=some_text """ + import mplhep + import matplotlib.pyplot as plt + # we can add arbitrary parameters via the `general_settings` parameter to access them in the # plotting function. They are automatically parsed either to a bool, float, or string print(f"the example_param has been set to '{example_param}' (type: {type(example_param)})") diff --git a/analysis_templates/ghent_template/__cf_module_name__/plotting/example.py b/analysis_templates/ghent_template/__cf_module_name__/plotting/example.py index 943d3ce33..2166f1f22 100644 --- a/analysis_templates/ghent_template/__cf_module_name__/plotting/example.py +++ b/analysis_templates/ghent_template/__cf_module_name__/plotting/example.py @@ -14,14 +14,16 @@ apply_variable_settings, apply_process_settings, ) +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") od = maybe_import("order") +# import hist, matplotlib... for type checking only like this! import them then also locallu. +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") + def my_plot1d_func( hists: OrderedDict[od.Process, hist.Hist], @@ -45,6 +47,9 @@ def my_plot1d_func( --plot-function __cf_module_name__.plotting.example.my_plot1d_func \ --general-settings example_param=some_text """ + import mplhep + import matplotlib.pyplot as plt + # we can add arbitrary parameters via the `general_settings` parameter to access them in the # plotting function. They are automatically parsed either to a bool, float, or string print(f"The example_param has been set to '{example_param}' (type: {type(example_param)})") diff --git a/analysis_templates/ghent_template/__cf_module_name__/production/default.py b/analysis_templates/ghent_template/__cf_module_name__/production/default.py index a4617ef4b..fb546987c 100644 --- a/analysis_templates/ghent_template/__cf_module_name__/production/default.py +++ b/analysis_templates/ghent_template/__cf_module_name__/production/default.py @@ -16,8 +16,10 @@ np = maybe_import("numpy") ak = maybe_import("awkward") -coffea = maybe_import("coffea") -maybe_import("coffea.nanoevents.methods.nanoaod") + +# do not import coffea globally! Do this inside the function +# coffea = maybe_import("coffea") +# maybe_import("coffea.nanoevents.methods.nanoaod") @producer( diff --git a/analysis_templates/ghent_template/__cf_module_name__/selection/default.py b/analysis_templates/ghent_template/__cf_module_name__/selection/default.py index 1e237ff5b..e2ea6669b 100644 --- a/analysis_templates/ghent_template/__cf_module_name__/selection/default.py +++ b/analysis_templates/ghent_template/__cf_module_name__/selection/default.py @@ -28,15 +28,21 @@ from __cf_short_name_lc__.selection.stats import __cf_short_name_lc___increment_stats from __cf_short_name_lc__.selection.trigger import trigger_selection +# only numpy and awkward are okay to import globally np = maybe_import("numpy") ak = maybe_import("awkward") -coffea = maybe_import("coffea") -maybe_import("coffea.nanoevents.methods.nanoaod") + +# do not import coffea globally! Do this inside the function +# coffea = maybe_import("coffea") +# maybe_import("coffea.nanoevents.methods.nanoaod") logger = law.logger.get_logger(__name__) def TetraVec(arr: ak.Array) -> ak.Array: + import coffea + import coffea.nanoevents.methods.nanoaod + TetraVec = ak.zip({"pt": arr.pt, "eta": arr.eta, "phi": arr.phi, "mass": arr.mass}, with_name="PtEtaPhiMLorentzVector", behavior=coffea.nanoevents.methods.vector.behavior) diff --git a/columnflow/calibration/cmsGhent/lepton_mva.py b/columnflow/calibration/cmsGhent/lepton_mva.py index 95c4199ad..739f22d81 100644 --- a/columnflow/calibration/cmsGhent/lepton_mva.py +++ b/columnflow/calibration/cmsGhent/lepton_mva.py @@ -13,8 +13,6 @@ np = maybe_import("numpy") ak = maybe_import("awkward") -coffea = maybe_import("coffea") -maybe_import("coffea.nanoevents.methods.nanoaod") @producer( diff --git a/columnflow/columnar_util_Ghent.py b/columnflow/columnar_util_Ghent.py index 0e985b106..6cd73cffb 100644 --- a/columnflow/columnar_util_Ghent.py +++ b/columnflow/columnar_util_Ghent.py @@ -16,13 +16,14 @@ from columnflow.columnar_util import remove_ak_column, has_ak_column ak = maybe_import("awkward") -coffea = maybe_import("coffea") def TetraVec(arr: ak.Array, keep: Sequence | str | Literal[-1] = -1) -> ak.Array: """ create a Lorentz for fector from an awkward array with pt, eta, phi, and mass fields """ + import coffea + mandatory_fields = ("pt", "eta", "phi", "mass") exclude_fields = ("x", "y", "z", "t") for field in mandatory_fields: diff --git a/columnflow/plotting/cmsGhent/plot_functions_1d.py b/columnflow/plotting/cmsGhent/plot_functions_1d.py index 589643ca1..e6e9d4b2d 100644 --- a/columnflow/plotting/cmsGhent/plot_functions_1d.py +++ b/columnflow/plotting/cmsGhent/plot_functions_1d.py @@ -1,5 +1,7 @@ from __future__ import annotations +import math + import order as od import law from collections import OrderedDict @@ -14,13 +16,12 @@ from columnflow.plotting.plot_all import plot_all from columnflow.plotting.cmsGhent.plot_util import cumulate +from columnflow.types import TYPE_CHECKING -plt = maybe_import("matplotlib.pyplot") np = maybe_import("numpy") -mtrans = maybe_import("matplotlib.transforms") -mplhep = maybe_import("mplhep") -math = maybe_import("math") -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") def plot_multi_variables( @@ -252,6 +253,8 @@ def plot_1d_line( """ TODO. """ + import hist + n_bins = math.prod([v.n_bins for v in variable_insts]) def flatten_data(data: hist.Hist | np.ndarray): diff --git a/columnflow/plotting/cmsGhent/plot_functions_2d.py b/columnflow/plotting/cmsGhent/plot_functions_2d.py index 620876c72..426dc9d83 100644 --- a/columnflow/plotting/cmsGhent/plot_functions_2d.py +++ b/columnflow/plotting/cmsGhent/plot_functions_2d.py @@ -1,15 +1,11 @@ +import order as od import law from collections import OrderedDict from columnflow.util import maybe_import from unittest.mock import patch from functools import partial -plt = maybe_import("matplotlib.pyplot") np = maybe_import("numpy") -od = maybe_import("order") -mtrans = maybe_import("matplotlib.transforms") -mplhep = maybe_import("mplhep") -hist = maybe_import("hist") from columnflow.plotting.plot_util import ( apply_variable_settings, @@ -24,6 +20,7 @@ def merge_migration_bins(h): """ binning both axes in equal bins """ + import hist x_edges = h.axes[0].edges y_edges = h.axes[1].edges @@ -98,6 +95,12 @@ def plot_migration_matrices( keep_bins_in_bkg: bool = False, **kwargs, ): + import mplhep + import matplotlib.transforms as mtrans + import matplotlib.pyplot as plt + + + plt.style.use(mplhep.style.CMS) fig, axes = plt.subplots( 2, 3, diff --git a/columnflow/plotting/cmsGhent/plot_util.py b/columnflow/plotting/cmsGhent/plot_util.py index 08b11643c..74cbdbb46 100644 --- a/columnflow/plotting/cmsGhent/plot_util.py +++ b/columnflow/plotting/cmsGhent/plot_util.py @@ -1,12 +1,16 @@ from __future__ import annotations import order as od from columnflow.util import maybe_import +from columnflow.types import TYPE_CHECKING, Sequence -hist = maybe_import("hist") np = maybe_import("numpy") +if TYPE_CHECKING: + hist = maybe_import("hist") def cumulate(h: np.ndarray | hist.Hist, direction="below", axis: str | int | od.Variable = 0): + import hist + idx_slice = np.s_[::-1] if direction == "above" else np.s_[:] arr = h if isinstance(h, np.ndarray) else h.view(flow=False) if isinstance(axis, od.Variable): diff --git a/columnflow/plotting/cmsGhent/unrolled.py b/columnflow/plotting/cmsGhent/unrolled.py index 746f2bd0c..10b08e58e 100644 --- a/columnflow/plotting/cmsGhent/unrolled.py +++ b/columnflow/plotting/cmsGhent/unrolled.py @@ -33,6 +33,7 @@ from collections import OrderedDict import law +import order as od from columnflow.util import maybe_import from columnflow.plotting.plot_all import ( @@ -48,25 +49,25 @@ get_cms_label, get_position, ) +from columnflow.types import TYPE_CHECKING - -hist = maybe_import("hist") np = maybe_import("numpy") -mpl = maybe_import("matplotlib") -plt = maybe_import("matplotlib.pyplot") -mplhep = maybe_import("mplhep") -od = maybe_import("order") -mticker = maybe_import("matplotlib.ticker") -colorsys = maybe_import("colorsys") +if TYPE_CHECKING: + hist = maybe_import("hist") + plt = maybe_import("matplotlib.pyplot") def change_saturation(hls, saturation_factor): + import colorsys + # Convert back to RGB new_rgb = colorsys.hls_to_rgb(hls[0], hls[1], saturation_factor) return new_rgb def get_new_colors(original_color, n_new_colors=2): + import colorsys + # Convert RGB to HLS hls = colorsys.rgb_to_hls(*original_color) @@ -178,6 +179,9 @@ def plot_unrolled( variable_settings: dict | None = None, **kwargs, ) -> plt.Figure: + import mplhep + import matplotlib as mpl + import matplotlib.pyplot as plt # remove shift axis from histograms if len(shift_insts) == 1: diff --git a/columnflow/production/cmsGhent/btag_weights.py b/columnflow/production/cmsGhent/btag_weights.py index 179ac438c..eaf1879a7 100644 --- a/columnflow/production/cmsGhent/btag_weights.py +++ b/columnflow/production/cmsGhent/btag_weights.py @@ -20,8 +20,6 @@ ak = maybe_import("awkward") np = maybe_import("numpy") -hist = maybe_import("hist") -correctionlib = maybe_import("correctionlib") logger = law.logger.get_logger(__name__) @@ -64,6 +62,8 @@ def init_btag(self: Producer, add_eff_vars=True): def setup_btag(self: Producer, task: law.Task, reqs: dict): + import correctionlib + bundle = reqs["external_files"] correction_set_btag_wp_corr = correctionlib.CorrectionSet.from_string( self.get_btag_sf(bundle.files).load(formatter="gzip").decode("utf-8"), @@ -299,6 +299,8 @@ def fixed_wp_btag_weights_setup( inputs: dict, reader_targets: law.util.InsertableDict, ) -> None: + import correctionlib + correction_set_btag_wp_corr = setup_btag(self, task, reqs) # fix for change in nomenclature of deepJet scale factors for light hadronFlavour jets @@ -366,6 +368,7 @@ def btag_efficiency_hists( hists: DotDict | dict = None, **kwargs, ) -> ak.Array: + import hist if hists is None: return events diff --git a/columnflow/production/cmsGhent/gen_features.py b/columnflow/production/cmsGhent/gen_features.py index 2f170aef4..6667e6697 100644 --- a/columnflow/production/cmsGhent/gen_features.py +++ b/columnflow/production/cmsGhent/gen_features.py @@ -10,7 +10,6 @@ np = maybe_import("numpy") ak = maybe_import("awkward") -coffea = maybe_import("coffea") def _geometric_matching(particles1: ak.Array, particles2: ak.Array) -> (ak.Array, ak.Array): diff --git a/columnflow/production/cmsGhent/lepton.py b/columnflow/production/cmsGhent/lepton.py index 051e4358e..31e080139 100644 --- a/columnflow/production/cmsGhent/lepton.py +++ b/columnflow/production/cmsGhent/lepton.py @@ -12,8 +12,6 @@ ak = maybe_import("awkward") np = maybe_import("numpy") -hist = maybe_import("hist") -correctionlib = maybe_import("correctionlib") logger = law.logger.get_logger(__name__) diff --git a/columnflow/production/cmsGhent/trigger/hist_producer.py b/columnflow/production/cmsGhent/trigger/hist_producer.py index 0f3abaa0b..7b0840689 100644 --- a/columnflow/production/cmsGhent/trigger/hist_producer.py +++ b/columnflow/production/cmsGhent/trigger/hist_producer.py @@ -15,10 +15,12 @@ import columnflow.production.cmsGhent.trigger.util as util from columnflow.selection import SelectionResult import order as od +from columnflow.types import TYPE_CHECKING np = maybe_import("numpy") ak = maybe_import("awkward") -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") logger = law.logger.get_logger(__name__) @@ -34,6 +36,8 @@ def trigger_efficiency_hists( object_mask: dict = None, **kwargs, ) -> ak.Array: + import hist + if hists is None: logger.warning_once(self.cls_name + " did not get any histograms") return events diff --git a/columnflow/production/cmsGhent/trigger/sf_producer.py b/columnflow/production/cmsGhent/trigger/sf_producer.py index 82e320214..fbe5c03b1 100644 --- a/columnflow/production/cmsGhent/trigger/sf_producer.py +++ b/columnflow/production/cmsGhent/trigger/sf_producer.py @@ -12,10 +12,8 @@ from columnflow.columnar_util import set_ak_column, has_ak_column, Route import columnflow.production.cmsGhent.trigger.util as util - np = maybe_import("numpy") ak = maybe_import("awkward") -hist = maybe_import("hist") logger = law.logger.get_logger(__name__) diff --git a/columnflow/production/cmsGhent/trigger/uncertainties.py b/columnflow/production/cmsGhent/trigger/uncertainties.py index 6e0cec903..4ec341862 100644 --- a/columnflow/production/cmsGhent/trigger/uncertainties.py +++ b/columnflow/production/cmsGhent/trigger/uncertainties.py @@ -4,12 +4,12 @@ from columnflow.util import maybe_import from columnflow.production.cmsGhent.trigger.Koopman_test import koopman_confint import columnflow.production.cmsGhent.trigger.util as util +from columnflow.types import TYPE_CHECKING -import numpy as np - -hist = maybe_import("hist") - -Hist = hist.Hist +np = maybe_import("numpy") +if TYPE_CHECKING: + hist = maybe_import("hist") + Hist = hist.Hist def calc_stat( diff --git a/columnflow/production/cmsGhent/trigger/util.py b/columnflow/production/cmsGhent/trigger/util.py index 6054ecb1a..fc0963ec1 100644 --- a/columnflow/production/cmsGhent/trigger/util.py +++ b/columnflow/production/cmsGhent/trigger/util.py @@ -9,10 +9,13 @@ from columnflow.production import Producer from columnflow.util import maybe_import from columnflow.plotting.plot_util import use_flow_bins +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") -Hist = hist.Hist np = maybe_import("numpy") +ak = maybe_import("awkward") +if TYPE_CHECKING: + hist = maybe_import("hist") + Hist = hist.Hist logger = law.logger.get_logger(__name__) @@ -23,6 +26,8 @@ def reduce_hist( exclude: str | Collection[str] = tuple(), keepdims=True, ): + import hist + exclude = law.util.make_list(exclude) if reduce is Ellipsis: return histogram.project(*exclude) @@ -74,6 +79,8 @@ def syst_hist( syst_name: str = "", arrays: np.ndarray | tuple[np.ndarray, np.ndarray] = None, ) -> Hist: + import hist + if syst_name == "central": variations = [syst_name] else: diff --git a/columnflow/tasks/cmsGhent/btagefficiency.py b/columnflow/tasks/cmsGhent/btagefficiency.py index 32c07cc8b..333ff05ea 100644 --- a/columnflow/tasks/cmsGhent/btagefficiency.py +++ b/columnflow/tasks/cmsGhent/btagefficiency.py @@ -17,8 +17,11 @@ from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.util import dev_sandbox, dict_add_strict, DotDict, maybe_import +from columnflow.types import TYPE_CHECKING -hist = maybe_import("hist") + +if TYPE_CHECKING: + hist = maybe_import("hist") class BTagEfficiencyBase: diff --git a/columnflow/tasks/cmsGhent/selection_hists.py b/columnflow/tasks/cmsGhent/selection_hists.py index 13db8a6ed..4176a7fa7 100644 --- a/columnflow/tasks/cmsGhent/selection_hists.py +++ b/columnflow/tasks/cmsGhent/selection_hists.py @@ -14,10 +14,10 @@ PlotBase, PlotBase1D, VariablePlotSettingMixin, ProcessPlotSettingMixin, ) -from columnflow.types import Any +from columnflow.types import TYPE_CHECKING, Any - -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") class CustomDefaultVariablesMixin( @@ -189,6 +189,7 @@ def efficiency(cls, selected_counts: hist.Hist, incl: hist.Hist, **kwargs) -> hi @param incl: histogram with event counts before selection @param kwargs: keyword arguments passed to **proportion_confint** """ + import hist from statsmodels.stats.proportion import proportion_confint efficiency = selected_counts / incl.values() eff_sample_size_corr = incl.values() / incl.variances() diff --git a/columnflow/tasks/cmsGhent/trigger_scale_factors.py b/columnflow/tasks/cmsGhent/trigger_scale_factors.py index 412760b0f..006f4c044 100644 --- a/columnflow/tasks/cmsGhent/trigger_scale_factors.py +++ b/columnflow/tasks/cmsGhent/trigger_scale_factors.py @@ -6,7 +6,6 @@ from itertools import product import luigi -from columnflow.types import Any from columnflow.tasks.framework.base import Requirements, ConfigTask from columnflow.tasks.framework.mixins import ( CalibratorClassesMixin, SelectorClassMixin, DatasetsMixin, @@ -19,10 +18,11 @@ import columnflow.production.cmsGhent.trigger.util as util from columnflow.tasks.framework.remote import RemoteWorkflow from columnflow.util import dev_sandbox, dict_add_strict, maybe_import - +from columnflow.types import TYPE_CHECKING, Any np = maybe_import("numpy") -hist = maybe_import("hist") +if TYPE_CHECKING: + hist = maybe_import("hist") logger = law.logger.get_logger(__name__) From f0198c94a2e03c2d5e72e649e0d08cea9c5e5b64 Mon Sep 17 00:00:00 2001 From: juvanden Date: Mon, 27 Oct 2025 12:24:21 +0100 Subject: [PATCH 114/123] linting fixes --- bin/cf_inspect.py | 3 +- columnflow/inference/__init__.py | 54 ------------------- .../plotting/cmsGhent/plot_functions_2d.py | 6 +-- columnflow/plotting/cmsGhent/plot_util.py | 2 +- columnflow/tasks/plotting.py | 15 +++--- 5 files changed, 12 insertions(+), 68 deletions(-) diff --git a/bin/cf_inspect.py b/bin/cf_inspect.py index 3a4db1c39..f4b13742a 100644 --- a/bin/cf_inspect.py +++ b/bin/cf_inspect.py @@ -13,8 +13,6 @@ import pickle import awkward as ak -import coffea.nanoevents -import uproot import numpy as np # noqa from columnflow.util import ipython_shell @@ -61,6 +59,7 @@ def _load_nano_root(fname: str, treepath: str | None = None, **kwargs) -> ak.Arr except: return uproot.open(fname) + def _load_h5(fname: str, **kwargs): import h5py return h5py.File(fname, "r") diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index 0e0b07c07..70e1f7785 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -601,60 +601,6 @@ def parameter_config_spec( ("shift_source", str(shift_source) if shift_source else None), ]) - @classmethod - def category_config_spec( - cls, - category: str | None = None, - variable: str | None = None, - data_datasets: Sequence[str] | None = None, - ) -> DotDict: - """ - Returns a dictionary representing configuration specific data, forwarding all arguments. - - :param category: The name of the source category in the config to use. - :param variable: The name of the variable in the config to use. - :param data_datasets: List of names or patterns of datasets in the config to use for real data. - :returns: A dictionary representing category specific settings. - """ - return DotDict([ - ("category", str(category) if category else None), - ("variable", str(variable) if variable else None), - ("data_datasets", list(map(str, data_datasets or []))), - ]) - - @classmethod - def process_config_spec( - cls, - process: str | None = None, - mc_datasets: Sequence[str] | None = None, - ) -> DotDict: - """ - Returns a dictionary representing configuration specific data, forwarding all arguments. - - :param process: The name of the process in the config to use. - :param mc_datasets: List of names or patterns of datasets in the config to use for mc. - :returns: A dictionary representing process specific settings. - """ - return DotDict([ - ("process", str(process) if process else None), - ("mc_datasets", list(map(str, mc_datasets or []))), - ]) - - @classmethod - def parameter_config_spec( - cls, - shift_source: str | None = None, - ) -> DotDict: - """ - Returns a dictionary representing configuration specific data, forwarding all arguments. - - :param shift_source: The name of a systematic shift source in the config. - :returns: A dictionary representing parameter specific settings. - """ - return DotDict([ - ("shift_source", str(shift_source) if shift_source else None), - ]) - def __init__(self, config_insts: list[od.Config]) -> None: super().__init__() diff --git a/columnflow/plotting/cmsGhent/plot_functions_2d.py b/columnflow/plotting/cmsGhent/plot_functions_2d.py index 426dc9d83..98ac2551a 100644 --- a/columnflow/plotting/cmsGhent/plot_functions_2d.py +++ b/columnflow/plotting/cmsGhent/plot_functions_2d.py @@ -99,8 +99,6 @@ def plot_migration_matrices( import matplotlib.transforms as mtrans import matplotlib.pyplot as plt - - plt.style.use(mplhep.style.CMS) fig, axes = plt.subplots( 2, 3, @@ -157,7 +155,7 @@ def plot_migration_matrices( default_style_config["annotate_cfg"]["bbox"] = dict(alpha=0.5, facecolor="white") style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) - + # # make main central migration plot # @@ -187,7 +185,7 @@ def plot_migration_matrices( # called internally by mplhep to draw the extension symbols with patch.object(plt, "colorbar", partial(plt.colorbar, **plot_config.get("cbar_kwargs", {}))): plot_config["hist"].plot2d(ax=central_ax, **plot_config.get("kwargs", {})) - + if label_numbers: for i, x in enumerate(migrations_eq_ax.axes[0].centers): for j, y in enumerate(migrations_eq_ax.axes[1].centers): diff --git a/columnflow/plotting/cmsGhent/plot_util.py b/columnflow/plotting/cmsGhent/plot_util.py index 74cbdbb46..f3e83d8ce 100644 --- a/columnflow/plotting/cmsGhent/plot_util.py +++ b/columnflow/plotting/cmsGhent/plot_util.py @@ -1,7 +1,7 @@ from __future__ import annotations import order as od from columnflow.util import maybe_import -from columnflow.types import TYPE_CHECKING, Sequence +from columnflow.types import TYPE_CHECKING np = maybe_import("numpy") if TYPE_CHECKING: diff --git a/columnflow/tasks/plotting.py b/columnflow/tasks/plotting.py index 67eca251d..6dd2acf3b 100644 --- a/columnflow/tasks/plotting.py +++ b/columnflow/tasks/plotting.py @@ -212,7 +212,7 @@ def run(self): # create expected shift bins and fill them with the nominal histogram # change Ghent: replace all expected shifts with nominal. # not preffered by columnflow: https://github.com/columnflow/columnflow/pull/692 - expected_shifts = plot_shift_names # & process_shift_map[process_inst.name] + expected_shifts = plot_shift_names # & process_shift_map[process_inst.name] add_missing_shifts(h, expected_shifts, str_axis="shift", nominal_bin="nominal") # add the histogram @@ -266,14 +266,15 @@ def run(self): for process_inst in hists.keys(): h = hists[process_inst] # determine expected shifts from the intersection of requested shifts and those known for the process - process_shifts = ( - process_shift_map[process_inst.name] - if process_inst.name in process_shift_map - else {"nominal"} - ) + # process_shifts = ( + # process_shift_map[process_inst.name] + # if process_inst.name in process_shift_map + # else {"nominal"} + # ) + # change Ghent: replace all expected shifts with nominal. # not preffered by columnflow: https://github.com/columnflow/columnflow/pull/692 - expected_shifts = plot_shift_names # & process_shifts + expected_shifts = plot_shift_names # & process_shifts if not expected_shifts: raise Exception(f"no shifts to plot found for process {process_inst.name}") # selections From 0ab442a38a6b196e46fc2fd9582637328a5c782c Mon Sep 17 00:00:00 2001 From: juvanden Date: Wed, 29 Oct 2025 11:29:59 +0100 Subject: [PATCH 115/123] addition of met_uncertainty_sources parameter for jer calibrator that propagates the jet energy resolution smearing to the met uncertainty variations given to jer. --- columnflow/calibration/cms/jets.py | 38 +++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index e3e6fa2ca..30b5a9262 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -738,6 +738,8 @@ def get_jer_config_default(self: Calibrator) -> DotDict: get_jec_config=get_jec_config_default, # jec uncertainty sources to propagate jer to, defaults to config when empty jec_uncertainty_sources=None, + # MET uncertainty sources to propagate jer to, defaults to None when empty + met_uncertainty_sources=None, # whether gen jet matching should be performed relative to the nominal jet pt, or the jec varied values gen_jet_matching_nominal=False, # regions where stochastic smearing is applied @@ -946,7 +948,7 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: if self.propagate_met: jetsum_pt_before = {} jetsum_phi_before = {} - for postfix in self.postfixes: + for postfix in self.jet_postfixes: jetsum_pt_before[postfix], jetsum_phi_before[postfix] = sum_transverse( events[jet_name][f"pt{postfix}"], events[jet_name].phi, @@ -954,7 +956,7 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: # apply the smearing # (note: this requires that postfixes and smear_factors have the same order, but this should be the case) - for i, postfix in enumerate(self.postfixes): + for i, postfix in enumerate(self.jet_postfixes): pt_name = f"pt{postfix}" m_name = f"mass{postfix}" events = set_ak_column_f32(events, f"{jet_name}.{pt_name}", events[jet_name][pt_name] * smear_factors[..., i]) @@ -970,27 +972,33 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: events = set_ak_column_f32(events, f"{met_name}.phi_unsmeared", events[met_name].phi) # propagate per variation - for postfix in self.postfixes: + for postfix in self.met_postfixes: # get pt and phi of all jets after correcting + + if hasattr(events[jet_name], f"pt{postfix}"): + jet_postfix = postfix + else: + jet_postfix = "" + + # jet variation exists, use it jetsum_pt_after, jetsum_phi_after = sum_transverse( - events[jet_name][f"pt{postfix}"], + events[jet_name][f"pt{jet_postfix}"], events[jet_name].phi, ) # propagate changes to MET met_pt, met_phi = propagate_met( - jetsum_pt_before[postfix], - jetsum_phi_before[postfix], + jetsum_pt_before[jet_postfix], + jetsum_phi_before[jet_postfix], jetsum_pt_after, jetsum_phi_after, events[met_name][f"pt{postfix}"], events[met_name][f"phi{postfix}"], ) + events = set_ak_column_f32(events, f"{met_name}.pt{postfix}", met_pt) events = set_ak_column_f32(events, f"{met_name}.phi{postfix}", met_phi) - return events - jer_horn_handling = jer.derive("jer_horn_handling", cls_dict={ # source: https://cms-jerc.web.cern.ch/Recommendations/#note-25eta30 @@ -1019,7 +1027,7 @@ def jer_init(self: Calibrator, **kwargs) -> None: # prepare jer variations and postfixes self.jer_variations = ["nom", "up", "down"] - self.postfixes = ["", "_jer_up", "_jer_down"] + [f"_{jec_var}" for jec_var in self.jec_variations] + self.jet_postfixes = ["", "_jer_up", "_jer_down"] + [f"_{jec_var}" for jec_var in self.jec_variations] # register used jet columns self.uses.add(f"{self.jet_name}.{{pt,eta,phi,mass,{self.gen_jet_idx_column}}}") @@ -1039,11 +1047,23 @@ def jer_init(self: Calibrator, **kwargs) -> None: if jec_sources: self.uses |= met_jec_columns + met_sources = self.met_uncertainty_sources or [] + self.met_variations = sum(([f"{unc}_up", f"{unc}_down"] for unc in met_sources), []) + self.met_postfixes = ["", "_jer_up", "_jer_down"] + \ + [f"_{jec_var}" for jec_var in self.jec_variations] + \ + [f"_{met_source}" for met_source in self.met_variations] + + if met_sources: + self.uses |= {f"{self.met_name}.{{pt,phi}}_{met_source}" for met_source in self.met_variations} + # register produced MET columns self.produces.add(f"{self.met_name}.{{pt,phi}}{{,_jer_up,_jer_down,_unsmeared}}") if jec_sources: self.produces |= met_jec_columns + if met_sources: + self.produces |= {f"{self.met_name}.{{pt,phi}}_{met_source}" for met_source in self.met_variations} + @jer.requires def jer_requires( From 34a1f335f57fdc224bfad664c7cce169d8efb6bb Mon Sep 17 00:00:00 2001 From: juvanden Date: Wed, 29 Oct 2025 11:32:37 +0100 Subject: [PATCH 116/123] cleaning of jets.py --- columnflow/calibration/cms/jets.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index 30b5a9262..0a8872098 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -975,10 +975,7 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: for postfix in self.met_postfixes: # get pt and phi of all jets after correcting - if hasattr(events[jet_name], f"pt{postfix}"): - jet_postfix = postfix - else: - jet_postfix = "" + jet_postfix = postfix if hasattr(events[jet_name], f"pt{postfix}") else "" # jet variation exists, use it jetsum_pt_after, jetsum_phi_after = sum_transverse( From d958c40e82199ebe8b195f7f3cc2a11dc4cdace3 Mon Sep 17 00:00:00 2001 From: Maarten De Coen <52047931+maadcoen@users.noreply.github.com> Date: Thu, 30 Oct 2025 12:29:14 +0100 Subject: [PATCH 117/123] add jer return --- columnflow/calibration/cms/jets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index 0a8872098..b0d36dc3c 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -995,6 +995,8 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: events = set_ak_column_f32(events, f"{met_name}.pt{postfix}", met_pt) events = set_ak_column_f32(events, f"{met_name}.phi{postfix}", met_phi) + + return events jer_horn_handling = jer.derive("jer_horn_handling", cls_dict={ From 5b363ee517b7517486f8305581daf5a5fb556804 Mon Sep 17 00:00:00 2001 From: Maarten De Coen <52047931+maadcoen@users.noreply.github.com> Date: Thu, 30 Oct 2025 15:39:44 +0100 Subject: [PATCH 118/123] fix linting --- columnflow/calibration/cms/jets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index b0d36dc3c..09f381770 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -995,7 +995,7 @@ def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: events = set_ak_column_f32(events, f"{met_name}.pt{postfix}", met_pt) events = set_ak_column_f32(events, f"{met_name}.phi{postfix}", met_phi) - + return events From e764265536ac531ff058a8f67ea0d93764ee65ab Mon Sep 17 00:00:00 2001 From: JulesVandenbroeck <93740577+JulesVandenbroeck@users.noreply.github.com> Date: Fri, 31 Oct 2025 11:26:35 +0100 Subject: [PATCH 119/123] Upstream changing to upstream merge request (#113) * Extend dy weight application to use btag multiplicity. (#739) * Extend dy weight application to use btag multiplicity. * Update docstring. * Hotfix nbtags variable in dy weight producer. * fix skipping data in CreateDatacards * Add objects for interacting with CMS CAT meta data. (#740) * Add objects for interacting with CAT meta data. * Remove namespace for now. * Cleanup. * Update fixed law. * Use cf.cms task namespace. * Add CMSDatasetInfo. * Allow pathlib input. * Add dc pog to CATSnapshot. * More flexible POG overrides. * Typo. * Simplify. * Hotfix CAT metadata update check for missing POG dirs. * add subplots_cfg in plot_all (#742) Co-authored-by: Mathis Frahm * Update law. * Refactor generator-level top and top decay product lookup (#741) * Refactor gen top lookup. * Add theory-based top pt weight method. * Comments. * Comments. * Rename field wDecay -> wChildren. * Update kept fields in gen_particles.py Removed 'status' and 'statusFlags' from kept generator particle fields. * Fix gen part field transformations. * Add suggestion by @jolange * Add gen_higgs_lookup. * Hotfix saving of columns in gen_particle lookups. * Hotfix depth limit of gen particles. * Add gen_dy_lookup. * Hotfix multi-config lookup via patterns. * Hotfix reduction to skip empty chunks. * Hotfix higgs gen lookup, considering effective gluon/photon decays. * Hotfix single shift selection in plotting. * Allow patterns in get_shifts_from_sources. * Hotfix save_div in plot scale factor. * [cms] Update log in CheckCATUpdates task. * Skip string columns in finiteness checks, fixes #743. * Hotfix repo bunlding, add missing user config. * [cms] Refactor egamma calibrators. (#745) * docs: add Bogdan-Wiederspan as a contributor for review (#746) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] --------- Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> * docs: add aalvesan as a contributor for review (#747) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] --------- Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> * Add t->w->tau children in gen_top_lookup. * Hotfix typo in gen_top lookup. * Add and use sum_hists helper. * Extend tes versions. --------- Co-authored-by: Marcel Rieger Co-authored-by: Marcel R. Co-authored-by: Mathis Frahm Co-authored-by: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 6 +- README.md | 4 +- analysis_templates/cms_minimal/law.cfg | 2 +- columnflow/calibration/cms/egamma.py | 808 ++++++--------------- columnflow/calibration/cms/tau.py | 2 +- columnflow/cms_util.py | 201 +++++ columnflow/config_util.py | 29 +- columnflow/hist_util.py | 36 +- columnflow/inference/cms/datacard.py | 11 +- columnflow/plotting/plot_all.py | 8 +- columnflow/plotting/plot_functions_1d.py | 6 +- columnflow/plotting/plot_functions_2d.py | 3 +- columnflow/plotting/plot_util.py | 10 +- columnflow/production/cms/dy.py | 36 +- columnflow/production/cms/gen_particles.py | 359 +++++++++ columnflow/production/cms/gen_top_decay.py | 90 --- columnflow/production/cms/top_pt_weight.py | 190 ++--- columnflow/production/util.py | 9 +- columnflow/tasks/cms/external.py | 77 ++ columnflow/tasks/cms/inference.py | 2 +- columnflow/tasks/external.py | 7 +- columnflow/tasks/framework/base.py | 11 + columnflow/tasks/framework/mixins.py | 16 +- columnflow/tasks/framework/remote.py | 4 + columnflow/tasks/histograms.py | 5 +- columnflow/tasks/plotting.py | 15 +- columnflow/tasks/reduction.py | 6 +- modules/law | 2 +- 28 files changed, 1083 insertions(+), 872 deletions(-) create mode 100644 columnflow/cms_util.py create mode 100644 columnflow/production/cms/gen_particles.py delete mode 100644 columnflow/production/cms/gen_top_decay.py diff --git a/.all-contributorsrc b/.all-contributorsrc index 09e5078cf..78194e308 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -71,7 +71,8 @@ "profile": "https://github.com/Bogdan-Wiederspan", "contributions": [ "code", - "test" + "test", + "review" ] }, { @@ -153,7 +154,8 @@ "avatar_url": "https://avatars.githubusercontent.com/u/99343616?v=4", "profile": "https://github.com/aalvesan", "contributions": [ - "code" + "code", + "review" ] }, { diff --git a/README.md b/README.md index 5c6368a59..43abbe91b 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ For a better overview of the tasks that are triggered by the commands below, che Daniel Savoiu
Daniel Savoiu

💻 👀 pkausw
pkausw

💻 👀 nprouvost
nprouvost

💻 ⚠️ - Bogdan-Wiederspan
Bogdan-Wiederspan

💻 ⚠️ + Bogdan-Wiederspan
Bogdan-Wiederspan

💻 ⚠️ 👀 Tobias Kramer
Tobias Kramer

💻 👀 @@ -151,7 +151,7 @@ For a better overview of the tasks that are triggered by the commands below, che JulesVandenbroeck
JulesVandenbroeck

💻 - Ana Andrade
Ana Andrade

💻 + Ana Andrade
Ana Andrade

💻 👀 philippgadow
philippgadow

💻 Lukas Schaller
Lukas Schaller

💻 diff --git a/analysis_templates/cms_minimal/law.cfg b/analysis_templates/cms_minimal/law.cfg index 06307d72f..d2db0c3aa 100644 --- a/analysis_templates/cms_minimal/law.cfg +++ b/analysis_templates/cms_minimal/law.cfg @@ -30,7 +30,7 @@ default_dataset: st_tchannel_t_4f_powheg calibration_modules: columnflow.calibration.cms.{jets,met,tau}, __cf_module_name__.calibration.example selection_modules: columnflow.selection.empty, columnflow.selection.cms.{json_filter,met_filters}, __cf_module_name__.selection.example reduction_modules: columnflow.reduction.default, __cf_module_name__.reduction.example -production_modules: columnflow.production.{categories,matching,normalization,processes}, columnflow.production.cms.{btag,electron,jet,matching,mc_weight,muon,pdf,pileup,scale,parton_shower,seeds}, __cf_module_name__.production.example +production_modules: columnflow.production.{categories,matching,normalization,processes}, columnflow.production.cms.{btag,electron,jet,matching,mc_weight,muon,pdf,pileup,scale,parton_shower,seeds,gen_particles}, __cf_module_name__.production.example categorization_modules: __cf_module_name__.categorization.example hist_production_modules: columnflow.histogramming.default, __cf_module_name__.histogramming.example ml_modules: columnflow.ml, __cf_module_name__.ml.example diff --git a/columnflow/calibration/cms/egamma.py b/columnflow/calibration/cms/egamma.py index fc31a289e..137735329 100644 --- a/columnflow/calibration/cms/egamma.py +++ b/columnflow/calibration/cms/egamma.py @@ -1,649 +1,245 @@ # coding: utf-8 """ -Egamma energy correction methods. -Source: https://twiki.cern.ch/twiki/bin/view/CMS/EgammSFandSSRun3#Scale_And_Smearings_Correctionli +CMS-specific calibrators applying electron and photon energy scale and smearing. + +1. Scale corrections are applied to data. +2. Resolution smearing is applied to simulation. +3. Both scale and resolution uncertainties are applied to simulation. + +Resources: + - https://twiki.cern.ch/twiki/bin/viewauth/CMS/EgammSFandSSRun3#Scale_And_Smearings_Correctionli + - https://egammapog.docs.cern.ch/Run3/SaS + - https://cms-analysis-corrections.docs.cern.ch/corrections_era/Run3-22CDSep23-Summer22-NanoAODv12/EGM/2025-10-22 """ from __future__ import annotations -import abc import functools +import dataclasses + import law -from dataclasses import dataclass, field from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import ak_random from columnflow.util import maybe_import, load_correction_set, DotDict -from columnflow.columnar_util import set_ak_column, flat_np_view, ak_copy, optional_column +from columnflow.columnar_util import set_ak_column, full_like from columnflow.types import Any ak = maybe_import("awkward") np = maybe_import("numpy") +logger = law.logger.get_logger(__name__) + # helper set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) -@dataclass +@dataclasses.dataclass class EGammaCorrectionConfig: - correction_set: str - value_type: str - uncertainty_type: str - compound: bool = False - corrector_kwargs: dict[str, Any] = field(default_factory=dict) - - -class egamma_scale_corrector(Calibrator): - - with_uncertainties = True - """Switch to control whether uncertainties are calculated.""" - - @property - @abc.abstractmethod - def source_field(self) -> str: - """Fields required for the current calibrator.""" - ... - - @abc.abstractmethod - def get_correction_file(self, external_files: law.FileTargetCollection) -> law.LocalFileTarget: - """Function to retrieve the correction file from the external files. - - :param external_files: File target containing the files as requested - in the current config instance under ``config_inst.x.external_files`` - """ - ... - - @abc.abstractmethod - def get_scale_config(self) -> EGammaCorrectionConfig: - """Function to retrieve the configuration for the photon energy correction.""" - ... - - def call_func(self, events: ak.Array, **kwargs) -> ak.Array: - """ - Apply energy corrections to EGamma objects in the events array. There are two types of implementations: standard - and Et dependent. - For Run2 the standard implementation is used, while for Run3 the Et dependent is recommended by the EGammaPog: - https://twiki.cern.ch/twiki/bin/viewauth/CMS/EgammSFandSSRun3?rev=41 - The Et dependendent recipe follows the example given in: - https://gitlab.cern.ch/cms-nanoAOD/jsonpog-integration/-/blob/66f581d0549e8d67fc55420d8bba15c9369fff7c/examples/egmScaleAndSmearingExample.py - - Requires an external file in the config under ``electron_ss``. Example: - - .. code-block:: python - - cfg.x.external_files = DotDict.wrap({ - "electron_ss": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-120c4271/POG/EGM/2022_Summer22//electronSS_EtDependent.json.gz", # noqa - }) - - The pairs of correction set, value and uncertainty type names, and if a compound method is used should be configured using the :py:class:`EGammaCorrectionConfig` as an - auxiliary entry in the config: - - .. code-block:: python - - cfg.x.eec = EGammaCorrectionConfig( - correction_set="EGMScale_Compound_Ele_2022preEE", - value_type="scale", - uncertainty_type="escale", - compound=True, - ) - - Derivatives of this base class require additional member variables and functions: - - - *source_field*: The field name of the EGamma objects in the events array (i.e. `Electron` - or `Photon`). - - *get_correction_file*: Function to retrieve the correction file, e.g.from - the list, of external files in the current `config_inst`. - - *get_scale_config*: Function to retrieve the configuration for the energy correction. - This config must be an instance of - :py:class:`~columnflow.calibration.cms.egamma.EGammaCorrectionConfig`. - - If no raw pt (i.e., pt before any corrections) is available, use the nominal pt. The - correction tool only supports flat arrays, so inputs are converted to a flat numpy view - first. Corrections are always applied to the raw pt, which is important if more than one - correction is applied in a row. The final corrections must be applied to the current pt. - - If :py:attr:`with_uncertainties` is set to `True`, the scale uncertainties are calculated. - The scale uncertainties are only available for simulated data. - - :param events: The events array containing EGamma objects. - :return: The events array with applied scale corrections. - - :notes: - - Varied corrections are only applied to Monte Carlo (MC) data. - - EGamma energy correction is only applied to real data. - - Changes are applied to the views and directly propagate to the original awkward - arrays. - """ - # if no raw pt (i.e. pt for any corrections) is available, use the nominal pt - if "rawPt" not in events[self.source_field].fields: - events = set_ak_column_f32(events, f"{self.source_field}.rawPt", events[self.source_field].pt) - - # the correction tool only supports flat arrays, so convert inputs to flat np view first - # corrections are always applied to the raw pt - this is important if more than - # one correction is applied in a row - pt_eval = flat_np_view(events[self.source_field].rawPt, axis=1) - - # the final corrections must be applied to the current pt though - pt_application = flat_np_view(events[self.source_field].pt, axis=1) - - broadcasted_run = ak.broadcast_arrays(events[self.source_field].pt, events.run) - run = flat_np_view(broadcasted_run[1], axis=1) - gain = flat_np_view(events[self.source_field].seedGain, axis=1) - sceta = flat_np_view(events[self.source_field].superclusterEta, axis=1) - r9 = flat_np_view(events[self.source_field].r9, axis=1) - - # prepare arguments - # (energy is part of the LorentzVector behavior) - variable_map = { - "et": pt_eval, - "eta": sceta, - "gain": gain, - "r9": r9, - "run": run, - "seedGain": gain, - "pt": pt_eval, - "AbsScEta": np.abs(sceta), - "ScEta": sceta, - **self.scale_config.corrector_kwargs, - } - args = tuple( - variable_map[inp.name] for inp in self.scale_corrector.inputs - if inp.name in variable_map - ) - - # varied corrections are only applied to MC - if self.with_uncertainties and self.dataset_inst.is_mc: - scale_uncertainties = self.scale_corrector.evaluate(self.scale_config.uncertainty_type, *args) - scales_up = (1 + scale_uncertainties) - scales_down = (1 - scale_uncertainties) - - for (direction, scales) in [("up", scales_up), ("down", scales_down)]: - # copy pt and mass - pt_varied = ak_copy(events[self.source_field].pt) - pt_view = flat_np_view(pt_varied, axis=1) - - # apply the scale variation - pt_view *= scales - - # save columns - postfix = f"scale_{direction}" - events = set_ak_column_f32(events, f"{self.source_field}.pt_{postfix}", pt_varied) - - # apply the nominal correction - # note: changes are applied to the views and directly propagate to the original ak arrays - # and do not need to be inserted into the events chunk again - # EGamma energy correction is ONLY applied to DATA - if self.dataset_inst.is_data: - scales_nom = self.scale_corrector.evaluate(self.scale_config.value_type, *args) - pt_application *= scales_nom - - return events - - def init_func(self, **kwargs) -> None: - """Function to initialize the calibrator. - - Sets the required and produced columns for the calibrator. - """ - self.uses |= { - # nano columns - f"{self.source_field}.{{seedGain,pt,eta,phi,superclusterEta,r9}}", - "run", - optional_column(f"{self.source_field}.rawPt"), - } - self.produces |= { - f"{self.source_field}.pt", - optional_column(f"{self.source_field}.rawPt"), - } - - # if we do not calculate uncertainties, this module - # should only run on observed DATA - self.data_only = not self.with_uncertainties - - # add columns with unceratinties if requested - # photon scale _uncertainties_ are only available for MC - if self.with_uncertainties and self.dataset_inst.is_mc: - self.produces |= {f"{self.source_field}.pt_scale_{{up,down}}"} - - def requires_func(self, task: law.Task, reqs: dict[str, DotDict[str, Any]], **kwargs) -> None: - """Function to add necessary requirements. - - This function add the :py:class:`~columnflow.tasks.external.BundleExternalFiles` - task to the requirements. - - :param reqs: Dictionary of requirements. - """ - if "external_files" in reqs: - return - - from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(task) - - def setup_func( - self, - task: law.Task, - reqs: dict[str, DotDict[str, Any]], - inputs: dict[str, Any], - reader_targets: law.util.InsertableDict, - **kwargs, - ) -> None: - """Setup function before event chunk loop. - - This function loads the correction file and sets up the correction tool. - Additionally, the *scale_config* is retrieved. - - :param reqs: Dictionary with resolved requirements. - :param inputs: Dictionary with inputs (not used). - :param reader_targets: Dictionary for optional additional columns to load - """ - self.scale_config = self.get_scale_config() - # create the egamma corrector - corr_file = self.get_correction_file(reqs["external_files"].files) - # init and extend the correction set - corr_set = load_correction_set(corr_file) - if self.scale_config.compound: - corr_set = corr_set.compound - self.scale_corrector = corr_set[self.scale_config.correction_set] - - -class egamma_resolution_corrector(Calibrator): - - with_uncertainties = True - """Switch to control whether uncertainties are calculated.""" - - # smearing of the energy resolution is only applied to MC - mc_only = True - """This calibrator is only applied to simulated data.""" - - deterministic_seed_index = -1 - """ use deterministic seeds for random smearing and - take the "index"-th random number per seed when not -1 """ + Container class to describe energy scaling and smearing configurations. Example: - @property - @abc.abstractmethod - def source_field(self) -> str: - """Fields required for the current calibrator.""" - ... - - @abc.abstractmethod - def get_correction_file(self, external_files: law.FileTargetCollection) -> law.LocalFile: - """Function to retrieve the correction file from the external files. - - :param external_files: File target containing the files as requested - in the current config instance under ``config_inst.x.external_files`` - """ - ... - - @abc.abstractmethod - def get_resolution_config(self) -> EGammaCorrectionConfig: - """Function to retrieve the configuration for the photon energy correction.""" - ... - - def call_func(self, events: ak.Array, **kwargs) -> ak.Array: - """ - Apply energy resolution corrections to EGamma objects in the events array. - - There are two types of implementations: standard and Et dependent. For Run2 the standard - implementation is used, while for Run3 the Et dependent is recommended by the EGammaPog: - https://twiki.cern.ch/twiki/bin/viewauth/CMS/EgammSFandSSRun3?rev=41 The Et dependendent - recipe follows the example given in: - https://gitlab.cern.ch/cms-nanoAOD/jsonpog-integration/-/blob/66f581d0549e8d67fc55420d8bba15c9369fff7c/examples/egmScaleAndSmearingExample.py - - Requires an external file in the config under ``electron_ss``. Example: - - .. code-block:: python - - cfg.x.external_files = DotDict.wrap({ - "electron_ss": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-120c4271/POG/EGM/2022_Summer22/electronSS_EtDependent.json.gz", # noqa - }) - - The pairs of correction set, value and uncertainty type names, and if a compound method is used should be configured using the :py:class:`EGammaCorrectionConfig` as an - auxiliary entry in the config: - - .. code-block:: python - - cfg.x.eec = EGammaCorrectionConfig( - correction_set="EGMSmearAndSyst_ElePTsplit_2022preEE", - value_type="smear", - uncertainty_type="esmear", - ) - - Derivatives of this base class require additional member variables and functions: - - - *source_field*: The field name of the EGamma objects in the events array (i.e. `Electron` or `Photon`). - - *get_correction_file*: Function to retrieve the correction file, e.g. - from the list of external files in the current `config_inst`. - - *get_resolution_config*: Function to retrieve the configuration for the energy resolution correction. - This config must be an instance of :py:class:`~columnflow.calibration.cms.egamma.EGammaCorrectionConfig`. - - If no raw pt (i.e., pt before any corrections) is available, use the nominal pt. - The correction tool only supports flat arrays, so inputs are converted to a flat numpy view first. - Corrections are always applied to the raw pt, which is important if more than one correction is applied in a - row. The final corrections must be applied to the current pt. - - If :py:attr:`with_uncertainties` is set to `True`, the resolution uncertainties are calculated. - - If :py:attr:`deterministic_seed_index` is set to a value greater than or equal to 0, deterministic seeds - are used for random smearing. The "index"-th random number per seed is taken for the nominal resolution - correction. The "index+1"-th random number per seed is taken for the up variation and the "index+2"-th random - number per seed is taken for the down variation. - - :param events: The events array containing EGamma objects. - :return: The events array with applied resolution corrections. - - :notes: - - Energy resolution correction are only to be applied to simulation. - - Changes are applied to the views and directly propagate to the original awkward arrays. - """ - - # if no raw pt (i.e. pt for any corrections) is available, use the nominal pt - if "rawPt" not in events[self.source_field].fields: - events = set_ak_column_f32(events, f"{self.source_field}.rawPt", ak_copy(events[self.source_field].pt)) - - # the correction tool only supports flat arrays, so convert inputs to flat np view first - sceta = flat_np_view(events[self.source_field].superclusterEta, axis=1) - r9 = flat_np_view(events[self.source_field].r9, axis=1) - flat_seeds = flat_np_view(events[self.source_field].deterministic_seed, axis=1) - pt = flat_np_view(events[self.source_field].rawPt, axis=1) - - # prepare arguments - variable_map = { - "AbsScEta": np.abs(sceta), - "ScEta": sceta, # 2024 version - "eta": sceta, - "r9": r9, - "pt": pt, - **self.resolution_cfg.corrector_kwargs, - } + .. code-block:: python - args = tuple( - variable_map[inp.name] - for inp in self.resolution_corrector.inputs - if inp.name in variable_map + cfg.x.ess = EGammaCorrectionConfig( + scale_correction_set="Scale", + scale_compound=True, + smear_syst_correction_set="SmearAndSyst", + systs=["scale_down", "scale_up", "smear_down", "smear_up"], ) - - # calculate the smearing scale - # as mentioned in the example above, allows us to apply them directly to the MC simulation. - rho = self.resolution_corrector.evaluate(self.resolution_cfg.value_type, *args) - - # varied corrections - if self.with_uncertainties and self.dataset_inst.is_mc: - rho_unc = self.resolution_corrector.evaluate(self.resolution_cfg.uncertainty_type, *args) - random_normal_number = functools.partial(ak_random, 0, 1) - smearing_func = lambda rng_array, variation: rng_array * variation + 1 - - smearing_up = ( - smearing_func( - random_normal_number(flat_seeds, rand_func=self.deterministic_normal_up), - rho + rho_unc, - ) - if self.deterministic_seed_index >= 0 - else smearing_func( - random_normal_number(rand_func=np.random.Generator(np.random.SFC64(events.event.to_list())).normal), - rho + rho_unc, - ) - ) - - smearing_down = ( - smearing_func( - random_normal_number(flat_seeds, rand_func=self.deterministic_normal_down), - rho - rho_unc, - ) - if self.deterministic_seed_index >= 0 - else smearing_func( - random_normal_number(rand_func=np.random.Generator(np.random.SFC64(events.event.to_list())).normal), - rho - rho_unc, - ) - ) - - for (direction, smear) in [("up", smearing_up), ("down", smearing_down)]: - # copy pt and mass - pt_varied = ak_copy(events[self.source_field].pt) - pt_view = flat_np_view(pt_varied, axis=1) - - # apply the scale variation - # cast ak to numpy array for convenient usage of *= - pt_view *= smear.to_numpy() - - # save columns - postfix = f"res_{direction}" - events = set_ak_column_f32(events, f"{self.source_field}.pt_{postfix}", pt_varied) - - # apply the nominal correction - # note: changes are applied to the views and directly propagate to the original ak arrays - # and do not need to be inserted into the events chunk again - # EGamma energy resolution correction is ONLY applied to MC - if self.dataset_inst.is_mc: - smearing = ( - ak_random(1, rho, flat_seeds, rand_func=self.deterministic_normal) - if self.deterministic_seed_index >= 0 - else ak_random(1, rho, rand_func=np.random.Generator( - np.random.SFC64(events.event.to_list())).normal, - ) - ) - # the final corrections must be applied to the current pt though - pt = flat_np_view(events[self.source_field].pt, axis=1) - pt *= smearing.to_numpy() - - return events - - def init_func(self, **kwargs) -> None: - """Function to initialize the calibrator. - - Sets the required and produced columns for the calibrator. - """ - self.uses |= { - # nano columns - f"{self.source_field}.{{pt,eta,phi,superclusterEta,r9}}", - optional_column(f"{self.source_field}.rawPt"), - } - self.produces |= { - f"{self.source_field}.pt", - optional_column(f"{self.source_field}.rawPt"), - } - - # add columns with unceratinties if requested - if self.with_uncertainties and self.dataset_inst.is_mc: - self.produces |= {f"{self.source_field}.pt_res_{{up,down}}"} - - def requires_func(self, task: law.Task, reqs: dict[str, DotDict[str, Any]], **kwargs) -> None: - """Function to add necessary requirements. - - This function add the :py:class:`~columnflow.tasks.external.BundleExternalFiles` - task to the requirements. - - :param reqs: Dictionary of requirements. - """ - if "external_files" in reqs: - return - - from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(task) - - def setup_func( - self, - task: law.Task, - reqs: dict[str, DotDict[str, Any]], - inputs: dict[str, Any], - reader_targets: law.util.InsertableDict, - **kwargs, - ) -> None: - """Setup function before event chunk loop. - - This function loads the correction file and sets up the correction tool. - Additionally, the *resolution_config* is retrieved. - If :py:attr:`deterministic_seed_index` is set to a value greater than or equal to 0, - random generator based on object-specific random seeds are setup. - - :param reqs: Dictionary with resolved requirements. - :param inputs: Dictionary with inputs (not used). - :param reader_targets: Dictionary for optional additional columns to load - (not used). - """ - self.resolution_cfg = self.get_resolution_config() - # create the egamma corrector - corr_file = self.get_correction_file(reqs["external_files"].files) - corr_set = load_correction_set(corr_file) - if self.resolution_cfg.compound: - corr_set = corr_set.compound - self.resolution_corrector = corr_set[self.resolution_cfg.correction_set] - - # use deterministic seeds for random smearing if requested - if self.deterministic_seed_index >= 0: - idx = self.deterministic_seed_index - bit_generator = np.random.SFC64 - - def deterministic_normal(loc, scale, seed, idx_offset=0): - return np.asarray([ - np.random.Generator(bit_generator(_seed)).normal(_loc, _scale, size=idx + 1 + idx_offset)[-1] - for _loc, _scale, _seed in zip(loc, scale, seed) - ]) - self.deterministic_normal = functools.partial(deterministic_normal, idx_offset=0) - self.deterministic_normal_up = functools.partial(deterministic_normal, idx_offset=1) - self.deterministic_normal_down = functools.partial(deterministic_normal, idx_offset=2) - - -pec = egamma_scale_corrector.derive( - "pec", cls_dict={ - "source_field": "Photon", - "with_uncertainties": True, - "get_correction_file": (lambda self, external_files: external_files.photon_ss), - "get_scale_config": (lambda self: self.config_inst.x.pec), - }, -) - -per = egamma_resolution_corrector.derive( - "per", cls_dict={ - "source_field": "Photon", - "with_uncertainties": True, - # function to determine the correction file - "get_correction_file": (lambda self, external_files: external_files.photon_ss), - # function to determine the tec config - "get_resolution_config": (lambda self: self.config_inst.x.per), - }, -) + """ + scale_correction_set: str + smear_syst_correction_set: str + scale_compound: bool = False + smear_syst_compound: bool = False + systs: list[str] = dataclasses.field(default_factory=list) + corrector_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) @calibrator( - uses={per, pec}, - produces={per, pec}, + exposed=False, + # used and produced columns are defined dynamically in init function with_uncertainties=True, - get_correction_file=None, - get_scale_config=None, - get_resolution_config=None, - deterministic_seed_index=-1, + collection_name=None, # to be set in derived classes to "Electron" or "Photon" + get_scale_smear_config=None, # to be set in derived classes + get_correction_file=None, # to be set in derived classes + deterministic_seed_index=-1, # use deterministic seeds for random smearing when >=0 + store_original=False, # if original columns (pt, energyErr) should be stored as "*_uncorrected" ) -def photons(self, events: ak.Array, **kwargs) -> ak.Array: - """ - Calibrator for photons. This calibrator runs the energy scale and resolution calibrators - for photons. - - Careful! Always apply resolution before scale corrections for MC. - """ +def _egamma_scale_smear(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: + # gather inputs + coll = events[self.collection_name] + variable_map = { + "run": events.run, + "pt": coll.pt, + "ScEta": coll.superclusterEta, + "r9": coll.r9, + "seedGain": coll.seedGain, + **self.cfg.corrector_kwargs, + } + def get_inputs(corrector, **additional_variables): + _variable_map = variable_map | additional_variables + return (_variable_map[inp.name] for inp in corrector.inputs if inp.name in _variable_map) + + # apply scale correction to data + if self.dataset_inst.is_data: + # store uncorrected values before correcting + if self.store_original: + events = set_ak_column(events, f"{self.collection_name}.pt_scale_uncorrected", coll.pt) + events = set_ak_column(events, f"{self.collection_name}.energyErr_scale_uncorrected", coll.energyErr) + + # get scaled pt + scale = self.scale_corrector.evaluate("scale", *get_inputs(self.scale_corrector)) + pt_scaled = coll.pt * scale + + # get scaled energy error + smear = self.smear_syst_corrector.evaluate("smear", *get_inputs(self.smear_syst_corrector, pt=pt_scaled)) + energy_err_scaled = (((coll.energyErr)**2 + (coll.energy * smear)**2) * scale)**0.5 + + # store columns + events = set_ak_column_f32(events, f"{self.collection_name}.pt", pt_scaled) + events = set_ak_column_f32(events, f"{self.collection_name}.energyErr", energy_err_scaled) + + # apply smearing to MC if self.dataset_inst.is_mc: - events = self[per](events, **kwargs) - - if self.with_uncertainties or self.dataset_inst.is_data: - events = self[pec](events, **kwargs) + # store uncorrected values before correcting + if self.store_original: + events = set_ak_column(events, f"{self.collection_name}.pt_smear_uncorrected", coll.pt) + events = set_ak_column(events, f"{self.collection_name}.energyErr_smear_uncorrected", coll.energyErr) + + # helper to compute random variables in the shape of the collection + def get_rnd(syst): + args = (full_like(coll.pt, 0.0), full_like(coll.pt, 1.0)) + if self.use_deterministic_seeds: + args += (coll.deterministic_seed,) + rand_func = self.deterministic_normal[syst] + else: + # TODO: bit generator could be configurable + rand_func = np.random.Generator(np.random.SFC64((events.event + sum(map(ord, syst))).to_list())).normal + return ak_random(*args, rand_func=rand_func) + + # helper to compute smeared pt and energy error values given a syst + def apply_smearing(syst): + # get smeared pt + smear = self.smear_syst_corrector.evaluate(syst, *get_inputs(self.smear_syst_corrector)) + smear_factor = 1.0 + smear * get_rnd(syst) + pt_smeared = coll.pt * smear_factor + # get smeared energy error + energy_err_smeared = (((coll.energyErr)**2 + (coll.energy * smear)**2) * smear_factor)**0.5 + # return both + return pt_smeared, energy_err_smeared + + # compute and store columns + pt_smeared, energy_err_smeared = apply_smearing("smear") + events = set_ak_column_f32(events, f"{self.collection_name}.pt", pt_smeared) + events = set_ak_column_f32(events, f"{self.collection_name}.energyErr", energy_err_smeared) + + # apply scale and smearing uncertainties to MC + if self.with_uncertainties and self.cfg.systs: + for syst in self.cfg.systs: + # exact behavior depends on syst itself + if syst in {"scale_up", "scale_down"}: + # compute scale with smeared pt and apply muliplicatively to smeared values + scale = self.smear_syst_corrector.evaluate(syst, *get_inputs(self.smear_syst_corrector, pt=pt_smeared)) # noqa: E501 + events = set_ak_column_f32(events, f"{self.collection_name}.pt_{syst}", pt_smeared * scale) + events = set_ak_column_f32(events, f"{self.collection_name}.energyErr_{syst}", energy_err_smeared * scale) # noqa: E501 + + elif syst in {"smear_up", "smear_down"}: + # compute smearing variations on original variables with same method as above + pt_smeared_syst, energy_err_smeared_syst = apply_smearing(syst) + events = set_ak_column_f32(events, f"{self.collection_name}.pt_{syst}", pt_smeared_syst) + events = set_ak_column_f32(events, f"{self.collection_name}.energyErr_{syst}", energy_err_smeared_syst) # noqa: E501 + + else: + logger.error(f"{self.cls_name} calibrator received unknown systematic '{syst}', skipping") return events -@photons.pre_init -def photons_pre_init(self, **kwargs) -> None: - # forward argument to the producers - if pec not in self.deps_kwargs: - self.deps_kwargs[pec] = dict() - if per not in self.deps_kwargs: - self.deps_kwargs[per] = dict() - self.deps_kwargs[pec]["with_uncertainties"] = self.with_uncertainties - self.deps_kwargs[per]["with_uncertainties"] = self.with_uncertainties - - self.deps_kwargs[per]["deterministic_seed_index"] = self.deterministic_seed_index - if self.get_correction_file is not None: - self.deps_kwargs[pec]["get_correction_file"] = self.get_correction_file - self.deps_kwargs[per]["get_correction_file"] = self.get_correction_file - - if self.get_resolution_config is not None: - self.deps_kwargs[per]["get_resolution_config"] = self.get_resolution_config - if self.get_scale_config is not None: - self.deps_kwargs[pec]["get_scale_config"] = self.get_scale_config - - -photons_nominal = photons.derive("photons_nominal", cls_dict={"with_uncertainties": False}) - - -eer = egamma_resolution_corrector.derive( - "eer", cls_dict={ - "source_field": "Electron", - # calculation of superclusterEta for electrons requires the deltaEtaSC - "uses": {"Electron.deltaEtaSC"}, - "with_uncertainties": True, - # function to determine the correction file - "get_correction_file": (lambda self, external_files: external_files.electron_ss), - # function to determine the tec config - "get_resolution_config": (lambda self: self.config_inst.x.eer), - }, -) +@_egamma_scale_smear.init +def _egamma_scale_smear_init(self: Calibrator, **kwargs) -> None: + # store the config + self.cfg = self.get_scale_smear_config() + + # update used columns + self.uses |= {"run", f"{self.collection_name}.{{pt,eta,phi,mass,energyErr,superclusterEta,r9,seedGain}}"} + + # update produced columns + if self.dataset_inst.is_data: + self.produces |= {f"{self.collection_name}.{{pt,energyErr}}"} + if self.store_original: + self.produces |= {f"{self.collection_name}.{{pt,energyErr}}_scale_uncorrected"} + else: + self.produces |= {f"{self.collection_name}.{{pt,energyErr}}"} + if self.store_original: + self.produces |= {f"{self.collection_name}.{{pt,energyErr}}_smear_uncorrected"} + if self.with_uncertainties: + for syst in self.cfg.systs: + self.produces |= {f"{self.collection_name}.{{pt,energyErr}}_{syst}"} + + +@_egamma_scale_smear.requires +def _egamma_scale_smear_requires(self, task: law.Task, reqs: dict[str, DotDict[str, Any]], **kwargs) -> None: + if "external_files" in reqs: + return + + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(task) + + +@_egamma_scale_smear.setup +def _egamma_scale_smear_setup( + self, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: + # get and load the correction file + corr_file = self.get_correction_file(reqs["external_files"].files) + corr_set = load_correction_set(corr_file) + + # setup the correctors + get_set = lambda set_name, compound: (corr_set.compound if compound else corr_set)[set_name] + self.scale_corrector = get_set(self.cfg.scale_correction_set, self.cfg.scale_compound) + self.smear_syst_corrector = get_set(self.cfg.smear_syst_correction_set, self.cfg.smear_syst_compound) + + # use deterministic seeds for random smearing if requested + self.use_deterministic_seeds = self.deterministic_seed_index >= 0 + if self.use_deterministic_seeds: + idx = self.deterministic_seed_index + bit_generator = np.random.SFC64 + + def _deterministic_normal(loc, scale, seed, idx_offset=0): + return np.asarray([ + np.random.Generator(bit_generator(_seed)).normal(_loc, _scale, size=idx + 1 + idx_offset)[-1] + for _loc, _scale, _seed in zip(loc, scale, seed) + ]) + + self.deterministic_normal = { + "smear": functools.partial(_deterministic_normal, idx_offset=0), + "smear_up": functools.partial(_deterministic_normal, idx_offset=1), + "smear_down": functools.partial(_deterministic_normal, idx_offset=2), + } + -eec = egamma_scale_corrector.derive( - "eec", cls_dict={ - "source_field": "Electron", - # calculation of superclusterEta for electrons requires the deltaEtaSC - "uses": {"Electron.deltaEtaSC"}, - "with_uncertainties": True, - "get_correction_file": (lambda self, external_files: external_files.electron_ss), - "get_scale_config": (lambda self: self.config_inst.x.eec), +electron_scale_smear = _egamma_scale_smear.derive( + "electron_scale_smear", + cls_dict={ + "collection_name": "Electron", + "get_scale_smear_config": lambda self: self.config_inst.x.ess, + "get_correction_file": lambda self, external_files: external_files.electron_ss, }, ) - -@calibrator( - uses={eer, eec}, - produces={eer, eec}, - with_uncertainties=True, - get_correction_file=None, - get_scale_config=None, - get_resolution_config=None, - deterministic_seed_index=-1, +photon_scale_smear = _egamma_scale_smear.derive( + "photon_scale_smear", + cls_dict={ + "collection_name": "Photon", + "get_scale_smear_config": lambda self: self.config_inst.x.gss, + "get_correction_file": lambda self, external_files: external_files.photon_ss, + }, ) -def electrons(self, events: ak.Array, **kwargs) -> ak.Array: - """ - Calibrator for electrons. This calibrator runs the energy scale and resolution calibrators - for electrons. - - Careful! Always apply resolution before scale corrections for MC. - """ - if self.dataset_inst.is_mc: - events = self[eer](events, **kwargs) - - if self.with_uncertainties or self.dataset_inst.is_data: - events = self[eec](events, **kwargs) - - return events - - -@electrons.pre_init -def electrons_pre_init(self, **kwargs) -> None: - # forward argument to the producers - if eec not in self.deps_kwargs: - self.deps_kwargs[eec] = dict() - if eer not in self.deps_kwargs: - self.deps_kwargs[eer] = dict() - self.deps_kwargs[eec]["with_uncertainties"] = self.with_uncertainties - self.deps_kwargs[eer]["with_uncertainties"] = self.with_uncertainties - - self.deps_kwargs[eer]["deterministic_seed_index"] = self.deterministic_seed_index - if self.get_correction_file is not None: - self.deps_kwargs[eec]["get_correction_file"] = self.get_correction_file - self.deps_kwargs[eer]["get_correction_file"] = self.get_correction_file - - if self.get_resolution_config is not None: - self.deps_kwargs[eer]["get_resolution_config"] = self.get_resolution_config - if self.get_scale_config is not None: - self.deps_kwargs[eec]["get_scale_config"] = self.get_scale_config - - -electrons_nominal = photons.derive("electrons_nominal", cls_dict={"with_uncertainties": False}) diff --git a/columnflow/calibration/cms/tau.py b/columnflow/calibration/cms/tau.py index 4cd4e7081..69e5a6760 100644 --- a/columnflow/calibration/cms/tau.py +++ b/columnflow/calibration/cms/tau.py @@ -263,7 +263,7 @@ def tec_setup( self.tec_corrector = load_correction_set(tau_file)[self.tec_cfg.correction_set] # check versions - assert self.tec_corrector.version in [0, 1] + assert self.tec_corrector.version in {0, 1, 2} tec_nominal = tec.derive("tec_nominal", cls_dict={"with_uncertainties": False}) diff --git a/columnflow/cms_util.py b/columnflow/cms_util.py new file mode 100644 index 000000000..2e283009f --- /dev/null +++ b/columnflow/cms_util.py @@ -0,0 +1,201 @@ +# coding: utf-8 + +""" +Collection of CMS specific helpers and utilities. +""" + +from __future__ import annotations + +__all__ = [] + +import os +import re +import copy +import pathlib +import dataclasses + +from columnflow.types import ClassVar, Generator + + +#: Default root path to CAT metadata. +cat_metadata_root = "/cvmfs/cms-griddata.cern.ch/cat/metadata" + + +@dataclasses.dataclass +class CATSnapshot: + """ + Dataclass to wrap YYYY-MM-DD stype timestamps of CAT metadata per POG stored in + "/cvmfs/cms-griddata.cern.ch/cat/metadata". No format parsing or validation is done, leaving responsibility to the + user. + """ + btv: str = "" + dc: str = "" + egm: str = "" + jme: str = "" + lum: str = "" + muo: str = "" + tau: str = "" + + def items(self) -> Generator[tuple[str, str], None, None]: + return ((k, getattr(self, k)) for k in self.__dataclass_fields__.keys()) + + +@dataclasses.dataclass +class CATInfo: + """ + Dataclass to describe and wrap information about a specific CAT-defined metadata era. + + .. code-block:: python + + CATInfo( + run=3, + era="22CDSep23-Summer22", + vnano=12, + snapshot=CATSnapshot( + btv="2025-08-20", + dc="2025-07-25", + egm="2025-04-15", + jme="2025-09-23", + lum="2024-01-31", + muo="2025-08-14", + tau="2025-10-01", + ), + # pog-specific settings + pog_directories={"dc": "Collisions22"}, + ) + """ + run: int + era: str + vnano: int + snapshot: CATSnapshot + # optional POG-specific overrides + pog_eras: dict[str, str] = dataclasses.field(default_factory=dict) + pog_directories: dict[str, str] = dataclasses.field(default_factory=dict) + + metadata_root: ClassVar[str] = cat_metadata_root + + def get_era_directory(self, pog: str = "") -> str: + """ + Returns the era directory name for a given *pog*. + + :param pog: The POG to get the era for. Leave empty if the common POG-unspecific directory name should be used. + """ + pog = pog.lower() + + # use specific directory if defined + if pog in self.pog_directories: + return self.pog_directories[pog] + + # build common directory name from run, era, and vnano + era = self.pog_eras.get(pog.lower(), self.era) if pog else self.era + return f"Run{self.run}-{era}-NanoAODv{self.vnano}" + + def get_file(self, pog: str, *paths: str | pathlib.Path) -> str: + """ + Returns the full path to a specific file or directory defined by *paths* in the CAT metadata structure for a + given *pog*. + """ + return os.path.join( + self.metadata_root, + pog.upper(), + self.get_era_directory(pog), + getattr(self.snapshot, pog.lower()), + *(str(p).strip("/") for p in paths), + ) + + +@dataclasses.dataclass +class CMSDatasetInfo: + """ + Container to wrap a CMS dataset given by its *key* with access to its components. The key should be in the format + ``//--/AOD``. + + .. code-block:: python + + d = CMSDatasetInfo.from_key("/TTtoLNu2Q_TuneCP5_13p6TeV_powheg-pythia8/RunIII2024Summer24MiniAODv6-150X_mcRun3_2024_realistic_v2-v2/MINIAODSIM") # noqa + print(d.name) # TTtoLNu2Q_TuneCP5_13p6TeV_powheg-pythia8 + print(d.campaign) # RunIII2024Summer24MiniAODv6 + print(d.campaign_version) # 150X_mcRun3_2024_realistic_v2 + print(d.dataset_version) # v2 + print(d.tier) # mini (lower case) + print(d.mc) # True + print(d.data) # False + print(d.kind) # mc + """ + name: str + campaign: str + campaign_version: str + dataset_version: str # this is usually the GT for MC + tier: str + mc: bool + + @classmethod + def from_key(cls, key: str) -> CMSDatasetInfo: + """ + Takes a dataset *key*, splits it into its components, and returns a new :py:class:`CMSDatasetInfo` instance. + + :param key: The dataset key: + :return: A new instance of :py:class:`CMSDatasetInfo`. + """ + # split + if not (m := re.match(r"^/([^/]+)/([^/-]+)-([^/-]+)-([^/-]+)/([^/-]+)AOD(SIM)?$", key)): + raise ValueError(f"invalid dataset key '{key}'") + + # create instance + return cls( + name=m.group(1), + campaign=m.group(2), + campaign_version=m.group(3), + dataset_version=m.group(4), + tier=m.group(5).lower(), + mc=m.group(6) == "SIM", + ) + + @property + def key(self) -> str: + # transform back to key format + return ( + f"/{self.name}" + f"/{self.campaign}-{self.campaign_version}-{self.dataset_version}" + f"/{self.tier.upper()}AOD{'SIM' if self.mc else ''}" + ) + + @property + def data(self) -> bool: + return not bool(self.mc) + + @data.setter + def data(self, value: bool) -> None: + self.mc = not bool(value) + + @property + def kind(self) -> str: + return "mc" if self.mc else "data" + + @kind.setter + def kind(self, value: str) -> None: + if (_value := str(value).lower()) not in {"mc", "data"}: + raise ValueError(f"invalid kind '{value}', expected 'mc' or 'data'") + self.mc = _value == "mc" + + @property + def store_path(self) -> str: + return ( + "/store" + f"/{self.kind}" + f"/{self.campaign}" + f"/{self.name}" + f"/{self.tier.upper()}AOD{'SIM' if self.mc else ''}" + f"/{self.campaign_version}-{self.dataset_version}" + ) + + def copy(self, **kwargs) -> CMSDatasetInfo: + """ + Creates a copy of this instance, allowing to override specific attributes via *kwargs*. + + :param kwargs: Attributes to override in the copy. + :return: A new instance of :py:class:`CMSDatasetInfo`. + """ + attrs = copy.deepcopy(self.__dict__) + attrs.update(kwargs) + return self.__class__(**attrs) diff --git a/columnflow/config_util.py b/columnflow/config_util.py index 3875ea502..0958e0ec7 100644 --- a/columnflow/config_util.py +++ b/columnflow/config_util.py @@ -333,16 +333,27 @@ def get_shift_from_configs(configs: list[od.Config], shift: str | od.Shift, sile def get_shifts_from_sources(config: od.Config, *shift_sources: Sequence[str]) -> list[od.Shift]: """ - Takes a *config* object and returns a list of shift instances for both directions given a - sequence *shift_sources*. + Takes a *config* object and returns a list of shift instances for both directions given a sequence of + *shift_sources*. Each source should be the name of a shift source (no direction suffix) or a pattern. + + :param config: :py:class:`order.Config` object from which to retrieve the shifts. + :param shift_sources: Sequence of shift source names or patterns. + :return: List of :py:class:`order.Shift` instances obtained from the given sources. """ - return sum( - ( - [config.get_shift(f"{s}_{od.Shift.UP}"), config.get_shift(f"{s}_{od.Shift.DOWN}")] - for s in shift_sources - ), - [], - ) + # since each passed source can be a pattern, all existing sources need to be checked + # however, the order should be preserved, so loop through each pattern and check for matching sources + existing_sources = {shift.source for shift in config.shifts} + found_sources = set() + shifts = [] + for pattern in shift_sources: + for source in existing_sources: + if source not in found_sources and law.util.multi_match(source, pattern): + found_sources.add(source) + shifts += [ + config.get_shift(f"{source}_{od.Shift.UP}"), + config.get_shift(f"{source}_{od.Shift.DOWN}"), + ] + return shifts def group_shifts( diff --git a/columnflow/hist_util.py b/columnflow/hist_util.py index f579c0af5..1a82c8617 100644 --- a/columnflow/hist_util.py +++ b/columnflow/hist_util.py @@ -14,7 +14,7 @@ from columnflow.columnar_util import flat_np_view from columnflow.util import maybe_import -from columnflow.types import TYPE_CHECKING, Any +from columnflow.types import TYPE_CHECKING, Any, Sequence np = maybe_import("numpy") ak = maybe_import("awkward") @@ -306,3 +306,37 @@ def add_missing_shifts( h.fill(*dummy_fill, weight=0) # TODO: this might skip overflow and underflow bins h[{str_axis: hist.loc(missing_shift)}] = nominal.view() + + +def sum_hists(hists: Sequence[hist.Hist]) -> hist.Hist: + """ + Sums a sequence of histograms into a new histogram. In case axis labels differ, which typically leads to errors + ("axes not mergable"), the labels of the first histogram are used. + + :param hists: The histograms to sum. + :return: The summed histogram. + """ + hists = list(hists) + if not hists: + raise ValueError("no histograms given for summation") + + # copy the first histogram + h_sum = hists[0].copy() + if len(hists) == 1: + return h_sum + + # store labels of first histogram + axis_labels = {ax.name: ax.label for ax in h_sum.axes} + + for h in hists[1:]: + # align axis labels if needed, only copy if necessary + h_aligned_labels = None + for ax in h.axes: + if ax.name not in axis_labels or ax.label == axis_labels[ax.name]: + continue + if h_aligned_labels is None: + h_aligned_labels = h.copy() + h_aligned_labels.axes[ax.name].label = axis_labels[ax.name] + h_sum = h_sum + (h if h_aligned_labels is None else h_aligned_labels) + + return h_sum diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index bca00f5fd..394960c6a 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -13,6 +13,7 @@ from columnflow import __version__ as cf_version from columnflow.inference import InferenceModel, ParameterType, ParameterTransformation, FlowStrategy +from columnflow.hist_util import sum_hists from columnflow.util import DotDict, maybe_import, real_path, ensure_dir, safe_div, maybe_int from columnflow.types import TYPE_CHECKING, Sequence, Any, Union, Hashable @@ -616,7 +617,7 @@ def fill_empty(cat_obj, h): continue # helper to sum over them for a given shift key and an optional fallback - def sum_hists(key: Hashable, fallback_key: Hashable | None = None) -> hist.Hist: + def get_hist_sum(key: Hashable, fallback_key: Hashable | None = None) -> hist.Hist: def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: if key in hd: return hd[key] @@ -625,7 +626,7 @@ def get(hd: dict[Hashable, hist.Hist]) -> hist.Hist: raise Exception( f"'{key}' shape for process '{proc_name}' in category '{cat_name}' misconfigured: {hd}", ) - return sum(map(get, hists[1:]), get(hists[0]).copy()) + return sum_hists(map(get, hists)) # helper to extract sum of hists, apply scale, handle flow and fill empty bins def load( @@ -634,7 +635,7 @@ def load( fallback_key: Hashable | None = None, scale: float = 1.0, ) -> hist.Hist: - h = sum_hists(hist_key, fallback_key) * scale + h = get_hist_sum(hist_key, fallback_key) * scale handle_flow(cat_obj, h, hist_name) fill_empty(cat_obj, h) return h @@ -826,7 +827,7 @@ def load( if not h_data: proc_str = ",".join(map(str, cat_obj.data_from_processes)) raise Exception(f"none of requested processes '{proc_str}' found to create fake data") - h_data = sum(h_data[1:], h_data[0].copy()) + h_data = sum_hists(h_data) data_name = data_pattern.format(category=cat_name) fill_empty(cat_obj, h_data) handle_flow(cat_obj, h_data, data_name) @@ -845,7 +846,7 @@ def load( h_data.append(proc_hists["data"][config_name]["nominal"]) # simply save the data histogram that was already built from the requested datasets - h_data = sum(h_data[1:], h_data[0].copy()) + h_data = sum_hists(h_data) data_name = data_pattern.format(category=cat_name) handle_flow(cat_obj, h_data, data_name) out_file[data_name] = h_data diff --git a/columnflow/plotting/plot_all.py b/columnflow/plotting/plot_all.py index 3a424bf30..ef93d9566 100644 --- a/columnflow/plotting/plot_all.py +++ b/columnflow/plotting/plot_all.py @@ -365,13 +365,17 @@ def plot_all( rax = None grid_spec = {"left": 0.15, "right": 0.95, "top": 0.95, "bottom": 0.1} grid_spec |= style_config.get("gridspec_cfg", {}) + + # Get figure size from style_config, with default values + subplots_cfg = style_config.get("subplots_cfg", {}) + if not skip_ratio: grid_spec = {"height_ratios": [3, 1], "hspace": 0, **grid_spec} - fig, axs = plt.subplots(2, 1, gridspec_kw=grid_spec, sharex=True) + fig, axs = plt.subplots(2, 1, gridspec_kw=grid_spec, sharex=True, **subplots_cfg) (ax, rax) = axs else: grid_spec.pop("height_ratios", None) - fig, ax = plt.subplots(gridspec_kw=grid_spec) + fig, ax = plt.subplots(gridspec_kw=grid_spec, **subplots_cfg) axs = (ax,) # invoke all plots methods diff --git a/columnflow/plotting/plot_functions_1d.py b/columnflow/plotting/plot_functions_1d.py index 69e26562e..34b6d02a7 100644 --- a/columnflow/plotting/plot_functions_1d.py +++ b/columnflow/plotting/plot_functions_1d.py @@ -30,7 +30,7 @@ remove_negative_contributions, join_labels, ) -from columnflow.hist_util import add_missing_shifts +from columnflow.hist_util import add_missing_shifts, sum_hists from columnflow.types import TYPE_CHECKING, Iterable np = maybe_import("numpy") @@ -76,7 +76,7 @@ def plot_variable_stack( if len(shift_insts) == 1: # when there is exactly one shift bin, we can remove the shift axis - hists = remove_residual_axis(hists, "shift", select_value=shift_insts[0].name) + hists = remove_residual_axis(hists, "shift") else: # remove shift axis of histograms that are not to be stacked unstacked_hists = { @@ -265,7 +265,7 @@ def plot_shifted_variable( add_missing_shifts(h, all_shifts, str_axis="shift", nominal_bin="nominal") # create the sum of histograms over all processes - h_sum = sum(list(hists.values())[1:], list(hists.values())[0].copy()) + h_sum = sum_hists(hists.values()) # setup plotting configs plot_config = {} diff --git a/columnflow/plotting/plot_functions_2d.py b/columnflow/plotting/plot_functions_2d.py index c731c4822..2009586fe 100644 --- a/columnflow/plotting/plot_functions_2d.py +++ b/columnflow/plotting/plot_functions_2d.py @@ -16,6 +16,7 @@ import order as od from columnflow.util import maybe_import +from columnflow.hist_util import sum_hists from columnflow.plotting.plot_util import ( remove_residual_axis, apply_variable_settings, @@ -81,7 +82,7 @@ def plot_2d( extremes = "color" # add all processes into 1 histogram - h_sum = sum(list(hists.values())[1:], list(hists.values())[0].copy()) + h_sum = sum_hists(hists.values()) if shape_norm: h_sum = h_sum / h_sum.sum().value diff --git a/columnflow/plotting/plot_util.py b/columnflow/plotting/plot_util.py index 3c892c974..c680cc46a 100644 --- a/columnflow/plotting/plot_util.py +++ b/columnflow/plotting/plot_util.py @@ -18,8 +18,8 @@ import order as od import scinum as sn -from columnflow.util import maybe_import, try_int, try_complex, UNSET -from columnflow.hist_util import copy_axis +from columnflow.util import maybe_import, try_int, try_complex, safe_div, UNSET +from columnflow.hist_util import copy_axis, sum_hists from columnflow.types import TYPE_CHECKING, Iterable, Any, Callable, Sequence, Hashable np = maybe_import("numpy") @@ -225,7 +225,7 @@ def get_stack_integral() -> float: if scale_factor == "stack": # compute the scale factor and round h_no_shift = remove_residual_axis_single(h, "shift", select_value="nominal") - scale_factor = round_dynamic(get_stack_integral() / h_no_shift.sum().value) or 1 + scale_factor = round_dynamic(safe_div(get_stack_integral(), h_no_shift.sum().value)) or 1 if try_int(scale_factor): scale_factor = int(scale_factor) hists[proc_inst] = h * scale_factor @@ -571,9 +571,9 @@ def prepare_stack_plot_config( h_data, h_mc, h_mc_stack = None, None, None if data_hists: - h_data = sum(data_hists[1:], data_hists[0].copy()) + h_data = sum_hists(data_hists) if mc_hists: - h_mc = sum(mc_hists[1:], mc_hists[0].copy()) + h_mc = sum_hists(mc_hists) h_mc_stack = hist.Stack(*mc_hists) # setup plotting configs diff --git a/columnflow/production/cms/dy.py b/columnflow/production/cms/dy.py index 46201d28d..9e618c007 100644 --- a/columnflow/production/cms/dy.py +++ b/columnflow/production/cms/dy.py @@ -6,9 +6,9 @@ from __future__ import annotations -import law +import dataclasses -from dataclasses import dataclass +import law from columnflow.production import Producer, producer from columnflow.util import maybe_import, load_correction_set @@ -21,14 +21,23 @@ logger = law.logger.get_logger(__name__) -@dataclass +@dataclasses.dataclass class DrellYanConfig: + # era, e.g. "2022preEE" era: str + # correction set name correction: str + # uncertainty correction set name unc_correction: str | None = None + # generator order order: str | None = None - njets: bool = False + # list of systematics to be considered systs: list[str] | None = None + # functions to get the number of jets and b-tagged jets from the events in case they should be used as inputs + get_njets: callable[["dy_weights", ak.Array], ak.Array] | None = None + get_nbtags: callable[["dy_weights", ak.Array], ak.Array] | None = None + # additional columns to be loaded, e.g. as needed for njets or nbtags + used_columns: set = dataclasses.field(default_factory=set) def __post_init__(self) -> None: if not self.era or not self.correction: @@ -135,7 +144,8 @@ def dy_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Array: *get_dy_weight_file* can be adapted in a subclass in case it is stored differently in the external files. - The campaign era and name of the correction set (see link above) should be given as an auxiliary entry in the config: + The analysis config should contain an auxiliary entry *dy_weight_config* pointing to a :py:class:`DrellYanConfig` + object: .. code-block:: python @@ -157,8 +167,12 @@ def dy_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Array: # optionals if self.dy_config.order: variable_map["order"] = self.dy_config.order - if self.dy_config.njets: - variable_map["njets"] = ak.num(events.Jet, axis=1) + if callable(self.dy_config.get_njets): + variable_map["njets"] = self.dy_config.get_njets(self, events) + if callable(self.dy_config.get_nbtags): + variable_map["nbtags"] = self.dy_config.get_nbtags(self, events) + # for compatibility + variable_map["ntags"] = variable_map["nbtags"] # initializing the list of weight variations (called syst in the dy files) systs = [("nom", "")] @@ -193,10 +207,12 @@ def dy_weights_init(self: Producer) -> None: f"campaign year {self.config_inst.campaign.x.year} is not yet supported by {self.cls_name}", ) - # declare additional used columns + # get the dy weight config self.dy_config: DrellYanConfig = self.get_dy_weight_config() - if self.dy_config.njets: - self.uses.add("Jet.pt") + + # declare additional used columns + if self.dy_config.used_columns: + self.uses.update(self.dy_config.used_columns) # declare additional produced columns if self.dy_config.unc_correction: diff --git a/columnflow/production/cms/gen_particles.py b/columnflow/production/cms/gen_particles.py new file mode 100644 index 000000000..294af1a81 --- /dev/null +++ b/columnflow/production/cms/gen_particles.py @@ -0,0 +1,359 @@ +# coding: utf-8 + +""" +Producers that determine the generator-level particles and bring them into a structured format. This is most likely +useful for generator studies and truth definitions of physics objects. +""" + +from __future__ import annotations + +import law + +from columnflow.production import Producer, producer +from columnflow.columnar_util import set_ak_column +from columnflow.util import UNSET, maybe_import + +np = maybe_import("numpy") +ak = maybe_import("awkward") + + +logger = law.logger.get_logger(__name__) + +_keep_gen_part_fields = ["pt", "eta", "phi", "mass", "pdgId"] + + +# helper to transform generator particles by dropping / adding fields +def transform_gen_part(gen_parts: ak.Array, *, depth_limit: int, optional: bool = False) -> ak.Array: + # reduce down to relevant fields + arr = {} + for f in _keep_gen_part_fields: + if optional: + if (v := getattr(gen_parts, f, UNSET)) is not UNSET: + arr[f] = v + else: + arr[f] = getattr(gen_parts, f) + arr = ak.zip(arr, depth_limit=depth_limit) + + # remove parameters and add Lorentz vector behavior + arr = ak.without_parameters(arr) + arr = ak.with_name(arr, "PtEtaPhiMLorentzVector") + + return arr + + +@producer( + uses={ + "GenPart.{genPartIdxMother,status,statusFlags}", # required by the gen particle identification + f"GenPart.{{{','.join(_keep_gen_part_fields)}}}", # additional fields that should be read and added to gen_top + }, + produces={"gen_top.*.*"}, +) +def gen_top_lookup(self: Producer, events: ak.Array, strict: bool = True, **kwargs) -> ak.Array: + """ + Creates a new ragged column "gen_top" containing information about generator-level top quarks and their decay + products in a structured array with the following fields: + + - ``t``: list of all top quarks in the event, sorted such that top quarks precede anti-top quarks + - ``b``: list of bottom quarks from top quark decays, consistent ordering w.r.t. ``t`` (note that, in rare + cases, the decay into charm or down quarks is realized, and therefore stored in this field) + - ``w``: list of W bosons from top quark decays, consistent ordering w.r.t. ``t`` + - ``w_children``: list of W boson decay products, consistent ordering w.r.t. ``w``, the first entry is the + down-type quark or charged lepton, the second entry is the up-type quark or neutrino, and additional decay + products (e.g photons) are appended afterwards + - ``w_tau_children``: list of decay products from tau lepton decays stemming from W boson decays, however, + skipping the W boson from the tau lepton decay itself; the first entry is the tau neutrino, the second and + third entries are either the charged lepton and neutrino, or quarks or hadrons sorted by ascending absolute + pdg id; additional decay products (e.g photons) are appended afterwards + """ + # helper to extract unique values + unique_set = lambda a: set(np.unique(ak.flatten(a, axis=None))) + + # find hard top quarks + t = events.GenPart[abs(events.GenPart.pdgId) == 6] + t = t[t.hasFlags("isLastCopy")] # they are either fromHardProcess _or_ isLastCopy + + # sort them so that that top quarks come before anti-top quarks + t = t[ak.argsort(t.pdgId, axis=1, ascending=False)] + + # distinct top quark children + # (asking for isLastCopy leads to some tops that miss children, usually b's) + t_children = ak.drop_none(t.distinctChildren[t.distinctChildren.hasFlags("fromHardProcess", "isFirstCopy")]) + + # strict mode: check that there are exactly two children that are b and w + if strict: + if (tcn := unique_set(ak.num(t_children, axis=2))) != {2}: + raise Exception(f"found top quarks that have != 2 children: {tcn - {2}}") + if (tci := unique_set(abs(t_children.pdgId))) - {1, 3, 5, 24}: + raise Exception(f"found top quark children with unexpected pdgIds: {tci - {1, 3, 5, 24}}") + + # store b's (or s/d) and w's + abs_tc_ids = abs(t_children.pdgId) + b = ak.drop_none(ak.firsts(t_children[(abs_tc_ids == 1) | (abs_tc_ids == 3) | (abs_tc_ids == 5)], axis=2)) + w = ak.drop_none(ak.firsts(t_children[abs(t_children.pdgId) == 24], axis=2)) + + # distinct w children + w_children = ak.drop_none(w.distinctChildrenDeep) + + # distinguish into "hard" and additional ones + w_children_hard = w_children[(hard_mask := w_children.hasFlags("fromHardProcess"))] + w_children_rest = w_children[~hard_mask] + + # strict: check that there are exactly two hard children + if strict: + if (wcn := unique_set(ak.num(w_children_hard, axis=2))) != {2}: + raise Exception(f"found W bosons that have != 2 children: {wcn - {2}}") + + # sort them so that down-type quarks and charged leptons (odd pdgIds) come first, followed by up-type quarks and + # neutrinos (even pdgIds), then add back the remaining ones + w_children_hard = w_children_hard[ak.argsort(-(w_children_hard.pdgId % 2), axis=2)] + w_children = ak.concatenate([w_children_hard, w_children_rest], axis=2) + + # further distinguish tau decays in w_children + w_tau_children = ak.drop_none(w_children[abs(w_children.pdgId) == 15].distinctChildrenDeep) + # sort: nu tau first, photons last, rest in between sorted by ascending absolute pdgId + w_tau_nu_mask = abs(w_tau_children.pdgId) == 16 + w_tau_photon_mask = w_tau_children.pdgId == 22 + w_tau_rest = w_tau_children[~(w_tau_nu_mask | w_tau_photon_mask)] + w_tau_rest = w_tau_rest[ak.argsort(abs(w_tau_rest.pdgId), axis=3, ascending=True)] + w_tau_children = ak.concatenate( + [w_tau_children[w_tau_nu_mask], w_tau_rest, w_tau_children[w_tau_photon_mask]], + axis=3, + ) + + # zip into a single array with named fields + gen_top = ak.zip( + { + "t": transform_gen_part(t, depth_limit=2), + "b": transform_gen_part(b, depth_limit=2), + "w": transform_gen_part(w, depth_limit=2), + "w_children": transform_gen_part(w_children, depth_limit=3), + "w_tau_children": transform_gen_part(w_tau_children, depth_limit=4), + }, + depth_limit=1, + ) + + # save the column + events = set_ak_column(events, "gen_top", gen_top) + + return events + + +@producer( + uses={ + "GenPart.{genPartIdxMother,status,statusFlags}", # required by the gen particle identification + f"GenPart.{{{','.join(_keep_gen_part_fields)}}}", # additional fields that should be read and added to gen_top + }, + produces={"gen_higgs.*.*"}, +) +def gen_higgs_lookup(self: Producer, events: ak.Array, strict: bool = True, **kwargs) -> ak.Array: + """ + Creates a new ragged column "gen_higgs" containing information about generator-level Higgs bosons and their decay + products in a structured array with the following fields: + + - ``h``: list of all Higgs bosons in the event, sorted by the pdgId of their decay products such that Higgs + bosons decaying to quarks (b's) come first, followed by leptons, and then gauge bosons + - ``h_children``: list of direct Higgs boson children, consistent ordering w.r.t. ``h``, with the first entry + being the particle and the second one being the anti-particle; for Z bosons and (effective) gluons and + photons, no ordering is applied + - ``tau_children``: list of decay products from tau lepton decays coming from Higgs bosons, with the first entry + being the neutrino and the second one being the W boson + - ``tau_w_children``: list of the decay products from W boson decays from tau lepton decays, with the first + entry being the down-type quark or charged lepton, the second entry being the up-type quark or neutrino, and + additional decay products (e.g photons) are appended afterwards + - ``z_children``: not yet implemented + - ``w_children``: not yet implemented + """ + # helper to extract unique values + unique_set = lambda a: set(np.unique(ak.flatten(a, axis=None))) + + # find higgs + h = events.GenPart[events.GenPart.pdgId == 25] + h = h[h.hasFlags("fromHardProcess", "isLastCopy")] + + # sort them by increasing pdgId of their children (quarks, leptons, Z, W, effective gluons/photons) + h = h[ak.argsort(abs(ak.drop_none(ak.min(h.children.pdgId, axis=2))), axis=1, ascending=True)] + + # get distinct children + h_children = ak.drop_none(h.distinctChildren[h.distinctChildren.hasFlags("fromHardProcess", "isFirstCopy")]) + + # strict mode: check that there are exactly two children + if strict: + if (hcn := unique_set(ak.num(h_children, axis=2))) != {2}: + raise Exception(f"found Higgs bosons that have != 2 children: {hcn - {2}}") + + # sort them by decreasing pdgId + h_children = h_children[ak.argsort(h_children.pdgId, axis=2, ascending=False)] + # in strict mode, fix the children dimension to 2 + if strict: + h_children = h_children[:, :, [0, 1]] + + # further treatment of tau decays + tau_mask = h_children.pdgId[:, :, 0] == 15 + tau = ak.fill_none(h_children[ak.mask(tau_mask, tau_mask)], [], axis=1) + tau_children = tau.distinctChildrenDeep[tau.distinctChildrenDeep.hasFlags("isFirstCopy", "isTauDecayProduct")] + tau_children = ak.drop_none(tau_children) + # prepare neutrino and W boson handling + tau_nu_mask = abs(tau_children.pdgId) == 16 + tau_w_mask = abs(tau_children.pdgId) == 24 + tau_rest_mask = ~(tau_nu_mask | tau_w_mask) + tau_has_rest = ak.any(tau_rest_mask, axis=3) + # strict mode: there should always be a neutrino, and _either_ a W and nothing else _or_ no W at all + if strict: + if not ak.all(ak.any(tau_nu_mask[tau_mask], axis=3)): + raise Exception("found tau leptons without a tau neutrino among their children") + tau_has_w = ak.any(tau_w_mask, axis=3) + if not ak.all((tau_has_w ^ tau_has_rest)[tau_mask]): + raise Exception("found tau leptons with both W bosons and other decay products among their children") + # get the tau neutrino + tau_nu = tau_children[tau_nu_mask].sum(axis=3) + tau_nu = set_ak_column(tau_nu, "pdgId", ak.values_astype(16 * np.sign(tau.pdgId), np.int32)) + # get the W boson in case it is part of the tau children, otherwise build it from the sum of children + tau_w = tau_children[tau_w_mask].sum(axis=3) + if ak.any(tau_has_rest): + tau_w_rest = tau_children[tau_rest_mask].sum(axis=-1) + tau_w = ak.where(tau_has_rest, tau_w_rest, tau_w) + tau_w = set_ak_column(tau_w, "pdgId", ak.values_astype(-24 * np.sign(tau.pdgId), np.int32)) + # combine nu and w again + tau_nuw = ak.concatenate([tau_nu[..., None], tau_w[..., None]], axis=3) + # define w children + tau_w_children = ak.concatenate( + [tau_children[tau_rest_mask], ak.drop_none(ak.firsts(tau_children[tau_w_mask], axis=3).children)], + axis=2, + ) + + # children for decays other than taus are not yet implemented, so show a warning in case they are found + unhandled_ids = unique_set(abs(h_children.pdgId)) - set(range(1, 6 + 1)) - set(range(11, 16 + 1)) + if unhandled_ids: + logger.warning_once( + f"gen_higgs_undhandled_children_{'_'.join(map(str, sorted(unhandled_ids)))}", + f"found Higgs boson decays in the {self.cls_name} producer with pdgIds {unhandled_ids}, for which the " + "lookup of children is not yet implemented", + ) + + # zip into a single array with named fields + gen_higgs = ak.zip( + { + "h": transform_gen_part(h, depth_limit=2), + "h_children": transform_gen_part(h_children, depth_limit=3), + "tau_children": transform_gen_part(tau_nuw, depth_limit=4), + "tau_w_children": transform_gen_part(tau_w_children, depth_limit=4), + # "z_children": None, # not yet implemented + # "w_children": None, # not yet implemented + }, + depth_limit=1, + ) + + # save the column + events = set_ak_column(events, "gen_higgs", gen_higgs) + + return events + + +@producer( + uses={ + "GenPart.{genPartIdxMother,status,statusFlags}", # required by the gen particle identification + f"GenPart.{{{','.join(_keep_gen_part_fields)}}}", # additional fields that should be read and added to gen_top + }, + produces={"gen_dy.*.*"}, +) +def gen_dy_lookup(self: Producer, events: ak.Array, strict: bool = True, **kwargs) -> ak.Array: + """ + Creates a new ragged column "gen_dy" containing information about generator-level Z/g bosons and their decay + products in a structured array with the following fields: + + - ``z``: list of all Z/g bosons in the event, sorted by the pdgId of their decay products + - ``lep``: list of direct Z/g boson children, consistent ordering w.r.t. ``z``, with the first entry being the + lepton and the second one being the anti-lepton + - ``tau_children``: list of decay products from tau lepton decays coming from Z/g bosons, with the first entry + being the neutrino and the second one being the W boson + - ``tau_w_children``: list of the decay products from W boson decays from tau lepton decays, with the first + entry being the down-type quark or charged lepton, the second entry being the up-type quark or neutrino, and + additional decay products (e.g photons) are appended afterwards + """ + # note: in about 4% of DY events, the Z/g boson is missing, so this lookup starts at lepton level, see + # -> https://indico.cern.ch/event/1495537/contributions/6359516/attachments/3014424/5315938/HLepRare_25.02.14.pdf + # -> https://indico.cern.ch/event/1495537/contributions/6359516/attachments/3014424/5315938/HLepRare_25.02.14.pdf + + # helper to extract unique values + unique_set = lambda a: set(np.unique(ak.flatten(a, axis=None))) + + # get the e/mu and tau masks + abs_id = abs(events.GenPart.pdgId) + emu_mask = ( + ((abs_id == 11) | (abs_id == 13)) & + (events.GenPart.status == 1) & + events.GenPart.hasFlags("fromHardProcess") + ) + # taus need to have status == 2 + tau_mask = ( + (abs_id == 15) & + (events.GenPart.status == 2) & + events.GenPart.hasFlags("fromHardProcess") + ) + lep_mask = emu_mask | tau_mask + + # strict mode: there must be exactly two charged leptons per event + if strict: + if (nl := unique_set(ak.num(events.GenPart[lep_mask], axis=1))) - {2}: + raise Exception(f"found events that have != 2 charged leptons: {nl - {2}}") + + # get the leptons and sort by decreasing pdgId (lepton before anti-lepton) + lep = events.GenPart[lep_mask] + lep = lep[ak.argsort(lep.pdgId, axis=1, ascending=False)] + + # in strict mode, fix the lep dimension to 2 + if strict: + lep = lep[:, [0, 1]] + + # build the z from them + z = lep.sum(axis=-1) + z = set_ak_column(z, "pdgId", np.int32(23)) + + # further treatment of tau decays + tau = events.GenPart[tau_mask] + tau_children = tau.distinctChildren[tau.distinctChildren.hasFlags("isFirstCopy", "isTauDecayProduct")] + tau_children = ak.drop_none(tau_children) + # prepare neutrino and W boson handling + tau_nu_mask = abs(tau_children.pdgId) == 16 + tau_w_mask = abs(tau_children.pdgId) == 24 + tau_rest_mask = ~(tau_nu_mask | tau_w_mask) + tau_has_rest = ak.any(tau_rest_mask, axis=2) + # strict mode: there should always be a neutrino, and _either_ a W and nothing else _or_ no W at all + if strict: + if not ak.all(ak.any(tau_nu_mask, axis=2)): + raise Exception("found tau leptons without a tau neutrino among their children") + tau_has_w = ak.any(tau_w_mask, axis=2) + if not ak.all(tau_has_w ^ tau_has_rest): + raise Exception("found tau leptons with both W bosons and other decay products among their children") + # get the tau neutrino + tau_nu = tau_children[tau_nu_mask].sum(axis=2) + tau_nu = set_ak_column(tau_nu, "pdgId", ak.values_astype(16 * np.sign(tau.pdgId), np.int32)) + # get the W boson in case it is part of the tau children, otherwise build it from the sum of children + tau_w = tau_children[tau_w_mask].sum(axis=2) + if ak.any(tau_has_rest): + tau_w_rest = tau_children[tau_rest_mask].sum(axis=-1) + tau_w = ak.where(tau_has_rest, tau_w_rest, tau_w) + tau_w = set_ak_column(tau_w, "pdgId", ak.values_astype(-24 * np.sign(tau.pdgId), np.int32)) + # combine nu and w again + tau_nuw = ak.concatenate([tau_nu[..., None], tau_w[..., None]], axis=2) + # define w children + tau_w_children = ak.concatenate( + [tau_children[tau_rest_mask], ak.drop_none(ak.firsts(tau_children[tau_w_mask], axis=2).children)], + axis=1, + ) + + # zip into a single array with named fields + gen_dy = ak.zip( + { + "z": transform_gen_part(z, depth_limit=1), + "lep": transform_gen_part(lep, depth_limit=2), + "tau_children": transform_gen_part(tau_nuw, depth_limit=3), + "tau_w_children": transform_gen_part(tau_w_children, depth_limit=3), + }, + depth_limit=1, + ) + + # save the column + events = set_ak_column(events, "gen_dy", gen_dy) + + return events diff --git a/columnflow/production/cms/gen_top_decay.py b/columnflow/production/cms/gen_top_decay.py deleted file mode 100644 index 8e925aaa0..000000000 --- a/columnflow/production/cms/gen_top_decay.py +++ /dev/null @@ -1,90 +0,0 @@ -# coding: utf-8 - -""" -Producers that determine the generator-level particles related to a top quark decay. -""" - -from __future__ import annotations - -from columnflow.production import Producer, producer -from columnflow.util import maybe_import -from columnflow.columnar_util import set_ak_column - -ak = maybe_import("awkward") - - -@producer( - uses={"GenPart.{genPartIdxMother,pdgId,statusFlags}"}, - produces={"gen_top_decay"}, -) -def gen_top_decay_products(self: Producer, events: ak.Array, **kwargs) -> ak.Array: - """ - Creates a new ragged column "gen_top_decay" with one element per hard top quark. Each element is - a GenParticleArray with five or more objects in a distinct order: top quark, bottom quark, - W boson, down-type quark or charged lepton, up-type quark or neutrino, and any additional decay - produces of the W boson (if any, then most likly photon radiations). Per event, the structure - will be similar to: - - .. code-block:: python - - [ - # event 1 - [ - # top 1 - [t1, b1, W1, q1/l, q2/n(, additional_w_decay_products)], - # top 2 - [...], - ], - # event 2 - ... - ] - """ - # find hard top quarks - abs_id = abs(events.GenPart.pdgId) - t = events.GenPart[abs_id == 6] - t = t[t.hasFlags("isHardProcess")] - t = t[~ak.is_none(t, axis=1)] - - # distinct top quark children (b's and W's) - t_children = t.distinctChildrenDeep[t.distinctChildrenDeep.hasFlags("isHardProcess")] - - # get b's - b = t_children[abs(t_children.pdgId) == 5][:, :, 0] - - # get W's - w = t_children[abs(t_children.pdgId) == 24][:, :, 0] - - # distinct W children - w_children = w.distinctChildrenDeep[w.distinctChildrenDeep.hasFlags("isHardProcess")] - - # reorder the first two W children (leptons or quarks) so that the charged lepton / down-type - # quark is listed first (they have an odd pdgId) - w_children_firsttwo = w_children[:, :, :2] - w_children_firsttwo = w_children_firsttwo[(w_children_firsttwo.pdgId % 2 == 0) * 1] - w_children_rest = w_children[:, :, 2:] - - # concatenate to create the structure to return - groups = ak.concatenate( - [ - t[:, :, None], - b[:, :, None], - w[:, :, None], - w_children_firsttwo, - w_children_rest, - ], - axis=2, - ) - - # save the column - events = set_ak_column(events, "gen_top_decay", groups) - - return events - - -@gen_top_decay_products.skip -def gen_top_decay_products_skip(self: Producer, **kwargs) -> bool: - """ - Custom skip function that checks whether the dataset is a MC simulation containing top - quarks in the first place. - """ - return self.dataset_inst.is_data or not self.dataset_inst.has_tag("has_top") diff --git a/columnflow/production/cms/top_pt_weight.py b/columnflow/production/cms/top_pt_weight.py index bb1fb4c4e..8207414d2 100644 --- a/columnflow/production/cms/top_pt_weight.py +++ b/columnflow/production/cms/top_pt_weight.py @@ -6,13 +6,13 @@ from __future__ import annotations -from dataclasses import dataclass +import dataclasses import law from columnflow.production import Producer, producer from columnflow.util import maybe_import -from columnflow.columnar_util import set_ak_column +from columnflow.columnar_util import set_ak_column, full_like ak = maybe_import("awkward") np = maybe_import("numpy") @@ -21,134 +21,101 @@ logger = law.logger.get_logger(__name__) -@dataclass -class TopPtWeightConfig: - params: dict[str, float] - pt_max: float = 500.0 - - @classmethod - def new(cls, obj: TopPtWeightConfig | dict[str, float]) -> TopPtWeightConfig: - # backward compatibility only - if isinstance(obj, cls): - return obj - return cls(params=obj) - - -@producer( - uses={"GenPart.{pdgId,statusFlags}"}, - # requested GenPartonTop columns, passed to the *uses* and *produces* - produced_top_columns={"pt"}, - mc_only=True, - # skip the producer unless the datasets has this specified tag (no skip check performed when none) - require_dataset_tag="has_top", -) -def gen_parton_top(self: Producer, events: ak.Array, **kwargs) -> ak.Array: +@dataclasses.dataclass +class TopPtWeightFromDataConfig: """ - Produce parton-level top quarks (before showering and detector simulation). - Creates new collection named "GenPartonTop" - - *produced_top_columns* can be adapted to change the columns that will be produced - for the GenPartonTop collection. - - The function is skipped when the dataset is data or when it does not have the tag *has_top*. - - :param events: awkward array containing events to process + Container to configure the top pt reweighting parameters for the method based on fits to data. For more info, see + https://twiki.cern.ch/twiki/bin/viewauth/CMS/TopPtReweighting?rev=31#TOP_PAG_corrections_based_on_dat """ - # find parton-level top quarks - abs_id = abs(events.GenPart.pdgId) - t = events.GenPart[abs_id == 6] - t = t[t.hasFlags("isLastCopy")] - t = t[~ak.is_none(t, axis=1)] - - # save the column - events = set_ak_column(events, "GenPartonTop", t) - - return events - - -@gen_parton_top.init -def gen_parton_top_init(self: Producer, **kwargs) -> bool: - for col in self.produced_top_columns: - self.uses.add(f"GenPart.{col}") - self.produces.add(f"GenPartonTop.{col}") + params: dict[str, float] = dataclasses.field(default_factory=lambda: { + "a": 0.0615, + "a_up": 0.0615 * 1.5, + "a_down": 0.0615 * 0.5, + "b": -0.0005, + "b_up": -0.0005 * 1.5, + "b_down": -0.0005 * 0.5, + }) + pt_max: float = 500.0 -@gen_parton_top.skip -def gen_parton_top_skip(self: Producer, **kwargs) -> bool: +@dataclasses.dataclass +class TopPtWeightFromTheoryConfig: """ - Custom skip function that checks whether the dataset is a MC simulation containing top quarks in the first place - using the :py:attr:`require_dataset_tag` attribute. + Container to configure the top pt reweighting parameters for the theory-based method. For more info, see + https://twiki.cern.ch/twiki/bin/viewauth/CMS/TopPtReweighting?rev=31#TOP_PAG_corrections_based_on_the """ - # never skip if the tag is not set - if self.require_dataset_tag is None: - return False - - return self.dataset_inst.is_data or not self.dataset_inst.has_tag(self.require_dataset_tag) - - -def get_top_pt_weight_config(self: Producer) -> TopPtWeightConfig: - if self.config_inst.has_aux("top_pt_reweighting_params"): - logger.info_once( - "deprecated_top_pt_weight_config", - "the config aux field 'top_pt_reweighting_params' is deprecated and will be removed in " - "a future release, please use 'top_pt_weight' instead", + params: dict[str, float] = dataclasses.field(default_factory=lambda: { + "a": 0.103, + "b": -0.0118, + "c": -0.000134, + "d": 0.973, + }) + + +# for backward compatibility +class TopPtWeightConfig(TopPtWeightFromDataConfig): + + def __init__(self, *args, **kwargs): + logger.warning_once( + "TopPtWeightConfig is deprecated and will be removed in future versions, please use " + "TopPtWeightFromDataConfig instead to keep using the data-based method, or TopPtWeightFromTheoryConfig to " + "use the theory-based method", ) - params = self.config_inst.x.top_pt_reweighting_params - else: - params = self.config_inst.x.top_pt_weight - - return TopPtWeightConfig.new(params) + super().__init__(*args, **kwargs) @producer( - uses={"GenPartonTop.pt"}, + uses={"gen_top.t.pt"}, produces={"top_pt_weight{,_up,_down}"}, - get_top_pt_weight_config=get_top_pt_weight_config, - # skip the producer unless the datasets has this specified tag (no skip check performed when none) - require_dataset_tag="is_ttbar", + get_top_pt_weight_config=(lambda self: self.config_inst.x.top_pt_weight), ) def top_pt_weight(self: Producer, events: ak.Array, **kwargs) -> ak.Array: - """ - Compute SF to be used for top pt reweighting. + r""" + Compute SF to be used for top pt reweighting, either with information from a fit to data or from theory. See https://twiki.cern.ch/twiki/bin/view/CMS/TopPtReweighting?rev=31 for more information. - The *GenPartonTop.pt* column can be produced with the :py:class:`gen_parton_top` Producer. The - SF should *only be applied in ttbar MC* as an event weight and is computed based on the - gen-level top quark transverse momenta. - - The top pt reweighting parameters should be given as an auxiliary entry in the config: + The method to be used depends on the config entry obtained with *get_top_pt_config* which should either be of + type :py:class:`TopPtWeightFromDataConfig` or :py:class:`TopPtWeightFromTheoryConfig`. - .. code-block:: python + - data-based: $SF(p_T)=e^{a + b \cdot p_T}$ + - theory-based: $SF(p_T)=a \cdot e^{b \cdot p_T} + c \cdot p_T + d$ - cfg.x.top_pt_reweighting_params = { - "a": 0.0615, - "a_up": 0.0615 * 1.5, - "a_down": 0.0615 * 0.5, - "b": -0.0005, - "b_up": -0.0005 * 1.5, - "b_down": -0.0005 * 0.5, - } + The *gen_top.t.pt* column can be produced with the :py:class:`gen_top_lookup` producer. The SF should *only be + applied in ttbar MC* as an event weight and is computed based on the gen-level top quark transverse momenta. + The top pt weight configuration should be given as an auxiliary entry "top_pt_weight" in the config. *get_top_pt_config* can be adapted in a subclass in case it is stored differently in the config. - - :param events: awkward array containing events to process """ # check the number of gen tops - if ak.any((n_tops := ak.num(events.GenPartonTop, axis=1)) != 2): + if ak.any((n_tops := ak.num(events.gen_top.t, axis=1)) != 2): raise Exception( - f"{self.cls_name} can only run on events with two generator top quarks, but found " - f"counts of {','.join(map(str, sorted(set(n_tops))))}", + f"{self.cls_name} can only run on events with two generator top quarks, but found counts of " + f"{','.join(map(str, sorted(set(n_tops))))}", ) - # clamp top pt - top_pt = events.GenPartonTop.pt - if self.cfg.pt_max >= 0.0: + # get top pt + top_pt = events.gen_top.t.pt + if not self.theory_method and self.cfg.pt_max >= 0.0: top_pt = ak.where(top_pt > self.cfg.pt_max, self.cfg.pt_max, top_pt) - for variation in ("", "_up", "_down"): - # evaluate SF function - sf = np.exp(self.cfg.params[f"a{variation}"] + self.cfg.params[f"b{variation}"] * top_pt) + for variation in ["", "_up", "_down"]: + # evaluate SF function, implementation is method dependent + if self.theory_method: + # up variation: apply twice the effect + # down variation: no weight at all + if variation != "_down": + sf = ( + self.cfg.params["a"] * np.exp(self.cfg.params["b"] * top_pt) + + self.cfg.params["c"] * top_pt + + self.cfg.params["d"] + ) + if variation == "_up": + sf = 1.0 + 2.0 * (sf - 1.0) + elif variation == "_down": + sf = full_like(top_pt, 1.0) + else: + sf = np.exp(self.cfg.params[f"a{variation}"] + self.cfg.params[f"b{variation}"] * top_pt) # compute weight from SF product for top and anti-top weight = np.sqrt(np.prod(sf, axis=1)) @@ -163,14 +130,9 @@ def top_pt_weight(self: Producer, events: ak.Array, **kwargs) -> ak.Array: def top_pt_weight_init(self: Producer) -> None: # store the top pt weight config self.cfg = self.get_top_pt_weight_config() - - -@top_pt_weight.skip -def top_pt_weight_skip(self: Producer, **kwargs) -> bool: - """ - Skip if running on anything except ttbar MC simulation, evaluated via the :py:attr:`require_dataset_tag` attribute. - """ - if self.require_dataset_tag is None: - return self.dataset_inst.is_data - - return self.dataset_inst.is_data or not self.dataset_inst.has_tag("is_ttbar") + if not isinstance(self.cfg, (TopPtWeightFromDataConfig, TopPtWeightFromTheoryConfig)): + raise Exception( + f"{self.cls_name} expects the config entry obtained with get_top_pt_weight_config to be of type " + f"TopPtWeightFromDataConfig or TopPtWeightFromTheoryConfig, but got {type(self.cfg)}", + ) + self.theory_method = isinstance(self.cfg, TopPtWeightFromTheoryConfig) diff --git a/columnflow/production/util.py b/columnflow/production/util.py index 1df6d49f9..938876282 100644 --- a/columnflow/production/util.py +++ b/columnflow/production/util.py @@ -47,11 +47,14 @@ def attach_coffea_behavior( # general awkward array functions # -def ak_extract_fields(arr: ak.Array, fields: list[str], **kwargs): +def ak_extract_fields(arr: ak.Array, fields: list[str], optional_fields: list[str] | None = None, **kwargs): """ Build an array containing only certain `fields` of an input array `arr`, preserving behaviors. """ + if optional_fields is None: + optional_fields = [] + # reattach behavior if "behavior" not in kwargs: kwargs["behavior"] = arr.behavior @@ -60,6 +63,10 @@ def ak_extract_fields(arr: ak.Array, fields: list[str], **kwargs): { field: getattr(arr, field) for field in fields + } | { + field: getattr(arr, field) + for field in optional_fields + if field in arr.fields }, **kwargs, ) diff --git a/columnflow/tasks/cms/external.py b/columnflow/tasks/cms/external.py index 03eb98220..148b0bac5 100644 --- a/columnflow/tasks/cms/external.py +++ b/columnflow/tasks/cms/external.py @@ -6,6 +6,11 @@ from __future__ import annotations +__all__ = [] + +import os +import glob + import luigi import law @@ -20,6 +25,8 @@ class CreatePileupWeights(ConfigTask): + task_namespace = "cf.cms" + single_config = True data_mode = luigi.ChoiceParameter( @@ -162,3 +169,73 @@ def normalize_values(cls, values: Sequence[float]) -> list[float]: enable=["configs", "skip_configs"], attributes={"version": None}, ) + + +class CheckCATUpdates(ConfigTask, law.tasks.RunOnceTask): + """ + CMS specific task that checks for updates in the metadata managed and stored by the CAT group. See + https://cms-analysis-corrections.docs.cern.ch for more info. + + To function correctly, this task requires an auxiliary entry ``cat_info`` in the analysis config, pointing to a + :py:class:`columnflow.cms_util.CATInfo` instance that defines the era information and the current POG correction + timestamps. The task will then check in the CAT metadata structure if newer timestamps are available. + """ + + task_namespace = "cf.cms" + + version = None + + single_config = False + + def run(self): + # helpers to convert date strings to tuples for numeric comparisons + decode_date_str = lambda s: tuple(map(int, s.split("-"))) + + # loop through configs + for config_inst in self.config_insts: + with self.publish_step( + f"checking CAT metadata updates for config '{law.util.colored(config_inst.name, style='bright')}' in " + f"{config_inst.x.cat_info.metadata_root}", + ): + newest_dates = {} + updated_any = False + for pog, date_str in config_inst.x.cat_info.snapshot.items(): + if not date_str: + continue + + # get all versions in the cat directory, split by date numbers + pog_era_dir = os.path.join( + config_inst.x.cat_info.metadata_root, + pog.upper(), + config_inst.x.cat_info.get_era_directory(pog), + ) + if not os.path.isdir(pog_era_dir): + self.logger.warning(f"CAT metadata directory '{pog_era_dir}' does not exist, skipping") + continue + dates = [ + os.path.basename(path) + for path in glob.glob(os.path.join(pog_era_dir, "*-*-*")) + ] + if not dates: + raise ValueError(f"no CAT snapshots found in '{pog_era_dir}'") + + # compare with current date + latest_date_str = max(dates, key=decode_date_str) + if date_str == "latest" or decode_date_str(date_str) < decode_date_str(latest_date_str): + newest_dates[pog] = latest_date_str + updated_any = True + self.publish_message( + f"found newer {law.util.colored(pog.upper(), color='cyan')} snapshot: {date_str} -> " + f"{latest_date_str} ({os.path.join(pog_era_dir, latest_date_str)})", + ) + else: + newest_dates[pog] = date_str + + # print a new CATSnapshot line that can be copy-pasted into the config + if updated_any: + args_str = ", ".join(f"{pog}=\"{date_str}\"" for pog, date_str in newest_dates.items() if date_str) + self.publish_message( + f"{law.util.colored('new CATSnapshot line ->', style='bright')} CATSnapshot({args_str})\n", + ) + else: + self.publish_message("no updates found\n") diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py index e88c41975..abf8ec2ec 100644 --- a/columnflow/tasks/cms/inference.py +++ b/columnflow/tasks/cms/inference.py @@ -130,7 +130,7 @@ def run(self): proc_objs.append(self.inference_model_inst.process_spec(name="data")) for proc_obj in proc_objs: # skip the process objects if it does not contribute to this config_inst - if config_inst.name not in proc_obj.config_data: + if config_inst.name not in proc_obj.config_data and proc_obj.name != "data": continue # get all process instances (keys in _input_hists) to be combined diff --git a/columnflow/tasks/external.py b/columnflow/tasks/external.py index 8f37ede77..33c4a4793 100644 --- a/columnflow/tasks/external.py +++ b/columnflow/tasks/external.py @@ -591,7 +591,12 @@ def fetch(src, dst): # copy local dir shutil.copytree(src, dst) else: - raise NotImplementedError(f"fetching {src} is not supported") + err = f"cannot fetch {src}" + if src.startswith("/") and os.path.isdir("/".join(src.split("/", 2)[:2])): + err += ", file or directory does not exist" + else: + err += ", resource type is not supported" + raise NotImplementedError(err) # helper function to fetch generic files def fetch_file(ext_file, counter=[0]): diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index 177ed8e84..0cf4e0d42 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -1270,6 +1270,17 @@ def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params["config_insts"] = [params["config_inst"]] else: if "config_insts" not in params and "configs" in params: + # custom pattern matching + matched_config_names = [] + for pattern in params["configs"]: + matched_config_names.extend( + config_name for config_name in analysis_inst.configs.names() + if law.util.multi_match(config_name, pattern) + ) + matched_config_names = law.util.make_unique(matched_config_names) + if matched_config_names: + params["configs"] = matched_config_names + # load config instances params["config_insts"] = list(map(analysis_inst.get_config, params["configs"])) # resolving of parameters that is required before ArrayFunctions etc. can be initialized diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index 54fd42d35..e4a4ce9e7 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -30,6 +30,7 @@ from columnflow.timing import Timer +np = maybe_import("numpy") ak = maybe_import("awkward") @@ -2634,18 +2635,25 @@ class ChunkedIOMixin(ConfigTask): @classmethod def raise_if_not_finite(cls, ak_array: ak.Array) -> None: """ - Checks whether all values in array *ak_array* are finite. + Checks whether values of all columns in *ak_array* are finite. String and bytestring types are skipped. The check is performed using the :external+numpy:py:func:`numpy.isfinite` function. - :param ak_array: Array with events to check. + :param ak_array: Array with columns to check. :raises ValueError: If any value in *ak_array* is not finite. """ - import numpy as np from columnflow.columnar_util import get_ak_routes for route in get_ak_routes(ak_array): - if ak.any(~np.isfinite(ak.flatten(route.apply(ak_array), axis=None))): + # flatten + flat = ak.flatten(route.apply(ak_array), axis=None) + # perform parameter dependent checks + if isinstance((params := getattr(getattr(flat, "layout", None), "parameters", None)), dict): + # skip string and bytestring arrays + if params.get("__array__") in {"string", "bytestring"}: + continue + # check finiteness + if ak.any(~np.isfinite(flat)): raise ValueError(f"found one or more non-finite values in column '{route.column}' of array {ak_array}") @classmethod diff --git a/columnflow/tasks/framework/remote.py b/columnflow/tasks/framework/remote.py index 8f3393ba4..fae1d3559 100644 --- a/columnflow/tasks/framework/remote.py +++ b/columnflow/tasks/framework/remote.py @@ -48,6 +48,10 @@ class BundleRepo(AnalysisTask, law.git.BundleGitRepository, law.tasks.TransferLo os.environ["CF_CONDA_BASE"], ] + include_files = [ + "law_user.cfg", + ] + def get_repo_path(self): # required by BundleGitRepository return os.environ["CF_REPO_BASE"] diff --git a/columnflow/tasks/histograms.py b/columnflow/tasks/histograms.py index 6fc0fc604..b78003302 100644 --- a/columnflow/tasks/histograms.py +++ b/columnflow/tasks/histograms.py @@ -21,6 +21,7 @@ from columnflow.tasks.reduction import ReducedEventsUser from columnflow.tasks.production import ProduceColumns from columnflow.tasks.ml import MLEvaluation +from columnflow.hist_util import sum_hists from columnflow.util import dev_sandbox @@ -446,7 +447,7 @@ def run(self): # merge them variable_hists = [h[variable_name] for h in hists] - merged = sum(variable_hists[1:], variable_hists[0].copy()) + merged = sum_hists(variable_hists) # post-process the merged histogram merged = self.hist_producer_inst.run_post_process_merged_hist(merged, task=self) @@ -544,7 +545,7 @@ def run(self): ] # merge and write the output - merged = sum(variable_hists[1:], variable_hists[0].copy()) + merged = sum_hists(variable_hists) outp.dump(merged, formatter="pickle") diff --git a/columnflow/tasks/plotting.py b/columnflow/tasks/plotting.py index 6dd2acf3b..ba91c0492 100644 --- a/columnflow/tasks/plotting.py +++ b/columnflow/tasks/plotting.py @@ -266,15 +266,12 @@ def run(self): for process_inst in hists.keys(): h = hists[process_inst] # determine expected shifts from the intersection of requested shifts and those known for the process - # process_shifts = ( - # process_shift_map[process_inst.name] - # if process_inst.name in process_shift_map - # else {"nominal"} - # ) - - # change Ghent: replace all expected shifts with nominal. - # not preffered by columnflow: https://github.com/columnflow/columnflow/pull/692 - expected_shifts = plot_shift_names # & process_shifts + process_shifts = ( + process_shift_map[process_inst.name] + if process_inst.name in process_shift_map + else {"nominal"} + ) + expected_shifts = (process_shifts & plot_shift_names) or (process_shifts & {"nominal"}) if not expected_shifts: raise Exception(f"no shifts to plot found for process {process_inst.name}") # selections diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index 08aadca45..5deef6bb8 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -213,12 +213,16 @@ def run(self): ) # invoke the reducer - if len(events): + if len(events) > 0: n_all += len(events) events = attach_coffea_behavior(events) events = self.reducer_inst(events, selection=sel, task=self) n_reduced += len(events) + # no need to proceed when no events are left + if len(events) == 0: + continue + # remove columns events = route_filter(events) diff --git a/modules/law b/modules/law index 44b98b7dc..3adec62db 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit 44b98b7dcd434badd003fd498eaf399e14c3ee53 +Subproject commit 3adec62db42d1fe8021c792538fe66ee1ed77b91 From 7ab29f31a032e353a9b5752d4e267b3aae1e507e Mon Sep 17 00:00:00 2001 From: JulesVandenbroeck <93740577+JulesVandenbroeck@users.noreply.github.com> Date: Thu, 13 Nov 2025 10:27:20 +0100 Subject: [PATCH 120/123] Get upstream changes (#114) * Extend dy weight application to use btag multiplicity. (#739) * Extend dy weight application to use btag multiplicity. * Update docstring. * Hotfix nbtags variable in dy weight producer. * fix skipping data in CreateDatacards * Add objects for interacting with CMS CAT meta data. (#740) * Add objects for interacting with CAT meta data. * Remove namespace for now. * Cleanup. * Update fixed law. * Use cf.cms task namespace. * Add CMSDatasetInfo. * Allow pathlib input. * Add dc pog to CATSnapshot. * More flexible POG overrides. * Typo. * Simplify. * Hotfix CAT metadata update check for missing POG dirs. * add subplots_cfg in plot_all (#742) Co-authored-by: Mathis Frahm * Update law. * Refactor generator-level top and top decay product lookup (#741) * Refactor gen top lookup. * Add theory-based top pt weight method. * Comments. * Comments. * Rename field wDecay -> wChildren. * Update kept fields in gen_particles.py Removed 'status' and 'statusFlags' from kept generator particle fields. * Fix gen part field transformations. * Add suggestion by @jolange * Add gen_higgs_lookup. * Hotfix saving of columns in gen_particle lookups. * Hotfix depth limit of gen particles. * Add gen_dy_lookup. * Hotfix multi-config lookup via patterns. * Hotfix reduction to skip empty chunks. * Hotfix higgs gen lookup, considering effective gluon/photon decays. * Hotfix single shift selection in plotting. * Allow patterns in get_shifts_from_sources. * Hotfix save_div in plot scale factor. * [cms] Update log in CheckCATUpdates task. * Skip string columns in finiteness checks, fixes #743. * Hotfix repo bunlding, add missing user config. * [cms] Refactor egamma calibrators. (#745) * docs: add Bogdan-Wiederspan as a contributor for review (#746) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] --------- Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> * docs: add aalvesan as a contributor for review (#747) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] --------- Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> * Add t->w->tau children in gen_top_lookup. * Hotfix typo in gen_top lookup. * Add and use sum_hists helper. * Extend tes versions. * [cms] Hotfix tau energy calibration, skip e-fake mask. * [cms] Hotfix egamma calibrator, use same random numbers for all smearing variations. * Add option to skip auto categories in track_category_changes. * Add n_chunks entry to ChunkPosition. * mutliple fixes regarding empty files or (almost) empty chunks (#750) * mutliple fixes regarding empty files or (almost) empty chunks * move chunk skip out of variable loop * add AbsScEta to variable_map for backwards compatibility * use last instead of first chunk for empty outputs * Fix broadcasting with empty egamma collection. --------- Co-authored-by: Mathis Frahm Co-authored-by: Marcel R. * Add simple column selection to UniteColumns. * Remove unneeded columns in cms tec calibrator. * Add variabble_repr to control paths. (#751) * Hotfix tec, add back charge. * Log broken parquet file paths. * Cleanup of e/mu id, update law. * Fix cf_inspect script after coffea update. (#753) * Hotfix electron weight producer with nested working points. * Hotfix attributes added by taf decorators. * Rename max-runtime -> {htcondor,slurm}-runtime. (#755) * Simplify requiring producers. (#756) * Simplify requiring producers. * Add same mechanism for calibrators. * Revert pilot decisions. * Add muon_sr calibrator. (#754) * Hotfix version resolution from config. * Hotfix required producers/calibrators for workflows. * Persistent local files of BundleExternalFiles. (#752) * Presistent local files of BundleExternalFiles. * Fix files_dir property. * Better caching. * Preserve types. * Ensure clean dir. * Allow unpacking in remote envs. * Pass-through workflow requirements in CreateHistograms. * Feature/histogram user multiconfig (#709) * make HistogramsUserBase compatible with multi-config * backwards compatibility to single-config * improve flexibility & runtime of helper functions * make shifts a set * add inputs as argument to load_histograms --------- Co-authored-by: Marcel Rieger Co-authored-by: Mathis Frahm * update hist axis labels during histogram merging (#705) * update labels during histogram merging * move update_ax_labels to hist_util.py * Linting --------- Co-authored-by: Marcel Rieger Co-authored-by: Mathis Frahm * Fix variance of fake data in datacard writer, better logs. * Update law. * Fix mamba setup. --------- Co-authored-by: Marcel Rieger Co-authored-by: Marcel R. Co-authored-by: Mathis Frahm Co-authored-by: Mathis Frahm <49306645+mafrahm@users.noreply.github.com> Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> Co-authored-by: jomatthi <82223346+jomatthi@users.noreply.github.com> Co-authored-by: juvanden --- analysis_templates/cms_minimal/law.cfg | 3 +- bin/cf_inspect.py | 13 +- columnflow/calibration/__init__.py | 63 ++++-- columnflow/calibration/cms/egamma.py | 37 ++-- columnflow/calibration/cms/muon.py | 222 ++++++++++++++++++++ columnflow/calibration/cms/tau.py | 35 ++-- columnflow/columnar_util.py | 50 ++++- columnflow/config_util.py | 18 +- columnflow/hist_util.py | 26 +++ columnflow/histogramming/__init__.py | 13 +- columnflow/inference/__init__.py | 9 +- columnflow/inference/cms/datacard.py | 31 +-- columnflow/production/__init__.py | 44 +++- columnflow/production/cms/electron.py | 123 ++++++----- columnflow/production/cms/muon.py | 59 +++--- columnflow/production/normalization.py | 2 +- columnflow/reduction/__init__.py | 13 +- columnflow/selection/__init__.py | 32 +-- columnflow/tasks/calibration.py | 4 +- columnflow/tasks/cms/external.py | 2 +- columnflow/tasks/external.py | 256 +++++++++++++---------- columnflow/tasks/framework/base.py | 44 ++-- columnflow/tasks/framework/histograms.py | 146 ++++++++----- columnflow/tasks/framework/mixins.py | 1 - columnflow/tasks/framework/remote.py | 40 ++-- columnflow/tasks/histograms.py | 46 ++-- columnflow/tasks/inspection.py | 73 +++++-- columnflow/tasks/production.py | 4 +- columnflow/tasks/reduction.py | 13 +- columnflow/tasks/selection.py | 6 +- columnflow/tasks/union.py | 22 +- law.cfg | 2 +- modules/law | 2 +- setup.sh | 1 + 34 files changed, 1015 insertions(+), 440 deletions(-) create mode 100644 columnflow/calibration/cms/muon.py diff --git a/analysis_templates/cms_minimal/law.cfg b/analysis_templates/cms_minimal/law.cfg index d2db0c3aa..35a233f00 100644 --- a/analysis_templates/cms_minimal/law.cfg +++ b/analysis_templates/cms_minimal/law.cfg @@ -27,7 +27,7 @@ default_analysis: __cf_module_name__.config.analysis___cf_short_name_lc__.analys default_config: run2_2017_nano_v9 default_dataset: st_tchannel_t_4f_powheg -calibration_modules: columnflow.calibration.cms.{jets,met,tau}, __cf_module_name__.calibration.example +calibration_modules: columnflow.calibration.cms.{jets,met,tau,egamma,muon}, __cf_module_name__.calibration.example selection_modules: columnflow.selection.empty, columnflow.selection.cms.{json_filter,met_filters}, __cf_module_name__.selection.example reduction_modules: columnflow.reduction.default, __cf_module_name__.reduction.example production_modules: columnflow.production.{categories,matching,normalization,processes}, columnflow.production.cms.{btag,electron,jet,matching,mc_weight,muon,pdf,pileup,scale,parton_shower,seeds,gen_particles}, __cf_module_name__.production.example @@ -65,6 +65,7 @@ htcondor_flavor: $CF_HTCONDOR_FLAVOR htcondor_share_software: False htcondor_memory: -1 htcondor_disk: -1 +htcondor_runtime: 3h slurm_flavor: $CF_SLURM_FLAVOR slurm_partition: $CF_SLURM_PARTITION diff --git a/bin/cf_inspect.py b/bin/cf_inspect.py index f4b13742a..8e5465508 100644 --- a/bin/cf_inspect.py +++ b/bin/cf_inspect.py @@ -59,10 +59,13 @@ def _load_nano_root(fname: str, treepath: str | None = None, **kwargs) -> ak.Arr except: return uproot.open(fname) - -def _load_h5(fname: str, **kwargs): - import h5py - return h5py.File(fname, "r") + return coffea.nanoevents.NanoEventsFactory.from_root( + source, + treepath=treepath, + mode="eager", + runtime_cache=None, + persistent_cache=None, + ).events() def load(fname: str, **kwargs) -> Any: @@ -78,8 +81,6 @@ def load(fname: str, **kwargs) -> Any: return _load_nano_root(fname, **kwargs) if ext == ".json": return _load_json(fname, **kwargs) - if ext in [".h5", ".hdf5"]: - return _load_h5(fname, **kwargs) raise NotImplementedError(f"no loader implemented for extension '{ext}'") diff --git a/columnflow/calibration/__init__.py b/columnflow/calibration/__init__.py index 276e22c6d..f0ed046bf 100644 --- a/columnflow/calibration/__init__.py +++ b/columnflow/calibration/__init__.py @@ -8,18 +8,55 @@ import inspect -from columnflow.types import Callable +import law + from columnflow.util import DerivableMeta from columnflow.columnar_util import TaskArrayFunction +from columnflow.types import Callable, Sequence, Any + + +class TaskArrayFunctionWithCalibratorRequirements(TaskArrayFunction): + + require_calibrators: Sequence[str] | set[str] | None = None + + def _req_calibrator(self, task: law.Task, calibrator: str) -> Any: + # hook to customize how required calibrators are requested + from columnflow.tasks.calibration import CalibrateEvents + return CalibrateEvents.req_other_calibrator(task, calibrator=calibrator) + def requires_func(self, task: law.Task, reqs: dict, **kwargs) -> None: + # no requirements for workflows in pilot mode + if callable(getattr(task, "is_workflow", None)) and task.is_workflow() and getattr(task, "pilot", False): + return -class Calibrator(TaskArrayFunction): + # add required calibrators when set + if (calibs := self.require_calibrators): + reqs["required_calibrators"] = {calib: self._req_calibrator(task, calib) for calib in calibs} + + def setup_func( + self, + task: law.Task, + reqs: dict, + inputs: dict, + reader_targets: law.util.InsertableDict, + **kwargs, + ) -> None: + if "required_calibrators" in inputs: + for calib, inp in inputs["required_calibrators"].items(): + reader_targets[f"required_calibrator_{calib}"] = inp["columns"] + + +class Calibrator(TaskArrayFunctionWithCalibratorRequirements): """ Base class for all calibrators. """ exposed = True + # register attributes for arguments accepted by decorator + mc_only: bool = False + data_only: bool = False + @classmethod def calibrator( cls, @@ -27,25 +64,26 @@ def calibrator( bases: tuple = (), mc_only: bool = False, data_only: bool = False, + require_calibrators: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ - Decorator for creating a new :py:class:`~.Calibrator` subclass with additional, optional - *bases* and attaching the decorated function to it as ``call_func``. + Decorator for creating a new :py:class:`~.Calibrator` subclass with additional, optional *bases* and attaching + the decorated function to it as ``call_func``. - When *mc_only* (*data_only*) is *True*, the calibrator is skipped and not considered by - other calibrators, selectors and producers in case they are evalauted on a - :py:class:`order.Dataset` (using the :py:attr:`dataset_inst` attribute) whose ``is_mc`` - (``is_data``) attribute is *False*. + When *mc_only* (*data_only*) is *True*, the calibrator is skipped and not considered by other calibrators, + selectors and producers in case they are evalauted on a :py:class:`order.Dataset` (using the + :py:attr:`dataset_inst` attribute) whose ``is_mc`` (``is_data``) attribute is *False*. All additional *kwargs* are added as class members of the new subclasses. :param func: Function to be wrapped and integrated into new :py:class:`Calibrator` class. :param bases: Additional bases for the new :py:class:`Calibrator`. - :param mc_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on - Monte Carlo simulation and skipped for real data. - :param data_only: Boolean flag indicating that this :py:class:`Calibrator` should only run - on real data and skipped for Monte Carlo simulation. + :param mc_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on Monte Carlo + simulation and skipped for real data. + :param data_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on real data and + skipped for Monte Carlo simulation. + :param require_calibrators: Sequence of names of other calibrators to add to the requirements. :return: New :py:class:`Calibrator` subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -55,6 +93,7 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, + "require_calibrators": require_calibrators, } # get the module name diff --git a/columnflow/calibration/cms/egamma.py b/columnflow/calibration/cms/egamma.py index 137735329..54993bf01 100644 --- a/columnflow/calibration/cms/egamma.py +++ b/columnflow/calibration/cms/egamma.py @@ -23,7 +23,7 @@ from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import ak_random from columnflow.util import maybe_import, load_correction_set, DotDict -from columnflow.columnar_util import set_ak_column, full_like +from columnflow.columnar_util import TAFConfig, set_ak_column, full_like from columnflow.types import Any ak = maybe_import("awkward") @@ -37,7 +37,7 @@ @dataclasses.dataclass -class EGammaCorrectionConfig: +class EGammaCorrectionConfig(TAFConfig): """ Container class to describe energy scaling and smearing configurations. Example: @@ -54,7 +54,7 @@ class EGammaCorrectionConfig: smear_syst_correction_set: str scale_compound: bool = False smear_syst_compound: bool = False - systs: list[str] = dataclasses.field(default_factory=list) + systs: list[str] = dataclasses.field(default_factory=lambda: ["scale_down", "scale_up", "smear_down", "smear_up"]) corrector_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -72,9 +72,10 @@ def _egamma_scale_smear(self: Calibrator, events: ak.Array, **kwargs) -> ak.Arra # gather inputs coll = events[self.collection_name] variable_map = { - "run": events.run, + "run": events.run if ak.sum(ak.num(coll, axis=1), axis=0) else [], "pt": coll.pt, "ScEta": coll.superclusterEta, + "AbsScEta": abs(coll.superclusterEta), "r9": coll.r9, "seedGain": coll.seedGain, **self.cfg.corrector_kwargs, @@ -109,22 +110,21 @@ def get_inputs(corrector, **additional_variables): events = set_ak_column(events, f"{self.collection_name}.pt_smear_uncorrected", coll.pt) events = set_ak_column(events, f"{self.collection_name}.energyErr_smear_uncorrected", coll.energyErr) - # helper to compute random variables in the shape of the collection - def get_rnd(syst): - args = (full_like(coll.pt, 0.0), full_like(coll.pt, 1.0)) - if self.use_deterministic_seeds: - args += (coll.deterministic_seed,) - rand_func = self.deterministic_normal[syst] - else: - # TODO: bit generator could be configurable - rand_func = np.random.Generator(np.random.SFC64((events.event + sum(map(ord, syst))).to_list())).normal - return ak_random(*args, rand_func=rand_func) + # compute random variables in the shape of the collection once + rnd_args = (full_like(coll.pt, 0.0), full_like(coll.pt, 1.0)) + if self.use_deterministic_seeds: + rnd_args += (coll.deterministic_seed,) + rand_func = self.deterministic_normal + else: + # TODO: bit generator could be configurable + rand_func = np.random.Generator(np.random.SFC64((events.event).to_list())).normal + rnd = ak_random(*rnd_args, rand_func=rand_func) # helper to compute smeared pt and energy error values given a syst def apply_smearing(syst): # get smeared pt smear = self.smear_syst_corrector.evaluate(syst, *get_inputs(self.smear_syst_corrector)) - smear_factor = 1.0 + smear * get_rnd(syst) + smear_factor = 1.0 + smear * rnd pt_smeared = coll.pt * smear_factor # get smeared energy error energy_err_smeared = (((coll.energyErr)**2 + (coll.energy * smear)**2) * smear_factor)**0.5 @@ -219,11 +219,8 @@ def _deterministic_normal(loc, scale, seed, idx_offset=0): for _loc, _scale, _seed in zip(loc, scale, seed) ]) - self.deterministic_normal = { - "smear": functools.partial(_deterministic_normal, idx_offset=0), - "smear_up": functools.partial(_deterministic_normal, idx_offset=1), - "smear_down": functools.partial(_deterministic_normal, idx_offset=2), - } + # each systematic is to be evaluated with the same random number so use a fixed offset + self.deterministic_normal = functools.partial(_deterministic_normal, idx_offset=0) electron_scale_smear = _egamma_scale_smear.derive( diff --git a/columnflow/calibration/cms/muon.py b/columnflow/calibration/cms/muon.py new file mode 100644 index 000000000..d096c94f1 --- /dev/null +++ b/columnflow/calibration/cms/muon.py @@ -0,0 +1,222 @@ +# coding: utf-8 + +""" +Muon calibration methods. +""" + +from __future__ import annotations + +import functools +import dataclasses +import inspect + +import law + +from columnflow.calibration import Calibrator, calibrator +from columnflow.columnar_util import TAFConfig, set_ak_column, IF_MC +from columnflow.util import maybe_import, load_correction_set, import_file, DotDict +from columnflow.types import Any + +ak = maybe_import("awkward") +np = maybe_import("numpy") + + +logger = law.logger.get_logger(__name__) + +# helper +set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) + + +@dataclasses.dataclass +class MuonSRConfig(TAFConfig): + """ + Container class to configure muon momentum scale and resolution corrections. Example: + + .. code-block:: python + + cfg.x.muon_sr = MuonSRConfig( + systs=["scale_up", "scale_down", "res_up", "res_down"], + ) + """ + systs: list[str] = dataclasses.field(default_factory=lambda: ["scale_up", "scale_down", "res_up", "res_down"]) + + +@calibrator( + uses={ + "Muon.{pt,eta,phi,mass,charge}", + IF_MC("event", "luminosityBlock", "Muon.nTrackerLayers"), + }, + # uncertainty variations added in init + produces={"Muon.pt"}, + # whether to produce also uncertainties + with_uncertainties=True, + # functions to determine the correction and tool files + get_muon_sr_file=(lambda self, external_files: external_files.muon_sr), + get_muon_sr_tool_file=(lambda self, external_files: external_files.muon_sr_tools), + # function to determine the muon config + get_muon_sr_config=(lambda self: self.config_inst.x.muon_sr), + # if the original pt columns should be stored as "pt_sr_uncorrected" + store_original=False, +) +def muon_sr( + self: Calibrator, + events: ak.Array, + **kwargs, +) -> ak.Array: + """ + Calibrator for muon scale and resolution smearing. Requires two external file in the config under the ``muon_sr`` + and ``muon_sr_tools`` keys, pointing to the json correction file and the "MuonScaRe" tools script, respectively, + + .. code-block:: python + + cfg.x.external_files = DotDict.wrap({ + "muon_sr": "/cvmfs/cms-griddata.cern.ch/cat/metadata/MUO/Run3-22CDSep23-Summer22-NanoAODv12/2025-08-14/muon_scalesmearing.json.gz", # noqa + "muon_sr_tools": "/path/to/MuonScaRe.py", + }) + + and a :py:class:`MuonSRConfig` configuration object in the auxiliary field ``muon_sr``, + + .. code-block:: python + + from columnflow.calibration.cms.muon import MuonSRConfig + cfg.x.muon_sr = MuonSRConfig( + systs=["scale_up", "scale_down", "res_up", "res_down"], + ) + + *get_muon_sr_file*, *get_muon_sr_tool_file* and *get_muon_sr_config* can be adapted in a subclass in case they are + stored differently in the config. + + Resources: + + - https://gitlab.cern.ch/cms-muonPOG/muonscarekit + - https://cms-analysis-corrections.docs.cern.ch/corrections_era/Run3-22CDSep23-Summer22-NanoAODv12/MUO/latest/#muon_scalesmearingjsongz # noqa + """ + # store the original pt column if requested + if self.store_original: + events = set_ak_column(events, "Muon.pt_sr_uncorrected", events.Muon.pt) + + # apply scale correction to data + if self.dataset_inst.is_data: + pt_scale_corr = self.muon_sr_tools.pt_scale( + 1, + events.Muon.pt, + events.Muon.eta, + events.Muon.phi, + events.Muon.charge, + self.muon_correction_set, + nested=True, + ) + events = set_ak_column_f32(events, "Muon.pt", pt_scale_corr) + + # apply scale and resolution correction to mc + if self.dataset_inst.is_mc: + pt_scale_corr = self.muon_sr_tools.pt_scale( + 0, + events.Muon.pt, + events.Muon.eta, + events.Muon.phi, + events.Muon.charge, + self.muon_correction_set, + nested=True, + ) + pt_scale_res_corr = self.muon_sr_tools.pt_resol( + pt_scale_corr, + events.Muon.eta, + events.Muon.phi, + events.Muon.nTrackerLayers, + events.event, + events.luminosityBlock, + self.muon_correction_set, + rnd_gen="np", + nested=True, + ) + events = set_ak_column_f32(events, "Muon.pt", pt_scale_res_corr) + + # apply scale and resolution uncertainties to mc + if self.with_uncertainties and self.muon_cfg.systs: + for syst in self.muon_cfg.systs: + # the sr tools use up/dn naming + sr_direction = {"up": "up", "down": "dn"}[syst.rsplit("_", 1)[-1]] + + # exact behavior depends on syst itself + if syst in {"scale_up", "scale_down"}: + pt_syst = self.muon_sr_tools.pt_scale_var( + pt_scale_res_corr, + events.Muon.eta, + events.Muon.phi, + events.Muon.charge, + sr_direction, + self.muon_correction_set, + nested=True, + ) + events = set_ak_column_f32(events, f"Muon.pt_{syst}", pt_syst) + + elif syst in {"res_up", "res_down"}: + pt_syst = self.muon_sr_tools.pt_resol_var( + pt_scale_corr, + pt_scale_res_corr, + events.Muon.eta, + sr_direction, + self.muon_correction_set, + nested=True, + ) + events = set_ak_column_f32(events, f"Muon.pt_{syst}", pt_syst) + + else: + logger.error(f"{self.cls_name} calibrator received unknown systematic '{syst}', skipping") + + return events + + +@muon_sr.init +def muon_sr_init(self: Calibrator, **kwargs) -> None: + self.muon_cfg = self.get_muon_sr_config() + + # add produced columns with unceratinties if requested + if self.dataset_inst.is_mc and self.with_uncertainties and self.muon_cfg.systs: + for syst in self.muon_cfg.systs: + self.produces.add(f"Muon.pt_{syst}") + + # original column + if self.store_original: + self.produces.add("Muon.pt_sr_uncorrected") + + +@muon_sr.requires +def muon_sr_requires( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + **kwargs, +) -> None: + if "external_files" in reqs: + return + + from columnflow.tasks.external import BundleExternalFiles + reqs["external_files"] = BundleExternalFiles.req(task) + + +@muon_sr.setup +def muon_sr_setup( + self: Calibrator, + task: law.Task, + reqs: dict[str, DotDict[str, Any]], + inputs: dict[str, Any], + reader_targets: law.util.InsertableDict, + **kwargs, +) -> None: + # load the correction set + muon_sr_file = self.get_muon_sr_file(reqs["external_files"].files) + self.muon_correction_set = load_correction_set(muon_sr_file) + + # also load the tools as an external package + muon_sr_tool_file = self.get_muon_sr_tool_file(reqs["external_files"].files) + self.muon_sr_tools = import_file(muon_sr_tool_file.abspath) + + # silence printing of the filter_boundaries function + spec = inspect.getfullargspec(self.muon_sr_tools.filter_boundaries) + if "silent" in spec.args or "silent" in spec.kwonlyargs: + self.muon_sr_tools.filter_boundaries = functools.partial(self.muon_sr_tools.filter_boundaries, silent=True) + + +muon_sr_nominal = muon_sr.derive("muon_sr_nominal", cls_dict={"with_uncertainties": False}) diff --git a/columnflow/calibration/cms/tau.py b/columnflow/calibration/cms/tau.py index 69e5a6760..897ebea4f 100644 --- a/columnflow/calibration/cms/tau.py +++ b/columnflow/calibration/cms/tau.py @@ -8,14 +8,14 @@ import functools import itertools -from dataclasses import dataclass, field +import dataclasses import law from columnflow.calibration import Calibrator, calibrator from columnflow.calibration.util import propagate_met from columnflow.util import maybe_import, load_correction_set, DotDict -from columnflow.columnar_util import set_ak_column, flat_np_view, ak_copy +from columnflow.columnar_util import TAFConfig, set_ak_column, flat_np_view, ak_copy from columnflow.types import Any ak = maybe_import("awkward") @@ -26,11 +26,11 @@ set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) -@dataclass -class TECConfig: +@dataclasses.dataclass +class TECConfig(TAFConfig): tagger: str correction_set: str = "tau_energy_scale" - corrector_kwargs: dict[str, Any] = field(default_factory=dict) + corrector_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) @classmethod def new(cls, obj: TECConfig | tuple[str] | dict[str, str]) -> TECConfig: @@ -44,14 +44,8 @@ def new(cls, obj: TECConfig | tuple[str] | dict[str, str]) -> TECConfig: @calibrator( - uses={ - # nano columns - "nTau", "Tau.pt", "Tau.eta", "Tau.phi", "Tau.mass", "Tau.charge", "Tau.genPartFlav", - "Tau.decayMode", - }, - produces={ - "Tau.pt", "Tau.mass", - }, + uses={"Tau.{pt,eta,phi,mass,charge,genPartFlav,decayMode}"}, + produces={"Tau.{pt,mass}"}, # whether to produce also uncertainties with_uncertainties=True, # toggle for propagation to MET @@ -142,17 +136,12 @@ def tec( scales_down = np.ones_like(dm_mask, dtype=np.float32) scales_down[dm_mask] = self.tec_corrector(*args, "down") - # custom adjustment 1: reset where the matching value is unhandled - # custom adjustment 2: reset electrons faking taus where the pt is too small - mask1 = (match < 1) | (match > 5) - mask2 = ((match == 1) | (match == 3)) & (pt <= 20.0) - - # apply reset masks - mask = mask1 | mask2 - scales_nom[mask] = 1.0 + # custom adjustment: reset where the matching value is unhandled + reset_mask = (match < 1) | (match > 5) + scales_nom[reset_mask] = 1.0 if self.with_uncertainties: - scales_up[mask] = 1.0 - scales_down[mask] = 1.0 + scales_up[reset_mask] = 1.0 + scales_down[reset_mask] = 1.0 # create varied collections per decay mode if self.with_uncertainties: diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py index 0727e48ec..00878e579 100644 --- a/columnflow/columnar_util.py +++ b/columnflow/columnar_util.py @@ -16,6 +16,7 @@ import enum import inspect import threading +import dataclasses import multiprocessing import multiprocessing.pool from functools import partial @@ -2217,6 +2218,30 @@ def __call__(self, *args, **kwargs) -> Any: deferred_column = ArrayFunction.DeferredColumn.deferred_column +@deferred_column +def IF_DATA(self: ArrayFunction.DeferredColumn, func: ArrayFunction) -> Any | set[Any]: + return self.get() if func.dataset_inst.is_data else None + + +@deferred_column +def IF_MC(self: ArrayFunction.DeferredColumn, func: ArrayFunction) -> Any | set[Any]: + return self.get() if func.dataset_inst.is_mc else None + + +def IF_DATASET_HAS_TAG(*args, negate: bool = False, **kwargs) -> ArrayFunction.DeferredColumn: + @deferred_column + def deferred( + self: ArrayFunction.DeferredColumn, + func: ArrayFunction, + ) -> Any | set[Any]: + return self.get() if func.dataset_inst.has_tag(*args, **kwargs) is not negate else None + + return deferred + + +IF_DATASET_NOT_HAS_TAG = partial(IF_DATASET_HAS_TAG, negate=True) + + def tagged_column( tag: str | Sequence[str] | set[str], *routes: Route | Any | set[Route | Any], @@ -2926,6 +2951,19 @@ def get_min_chunk_size(self) -> int | None: return min((s for s in sizes if isinstance(s, int)), default=None) +@dataclasses.dataclass +class TAFConfig: + + def copy(self, **kwargs) -> TAFConfig: + """ + Returns a copy of this TAFConfig instance, updated by any given *kwargs*. + + :param kwargs: Attributes to update in the copied instance. + :return: The copied and updated TAFConfig instance. + """ + return self.__class__(self.__dict__ | kwargs) + + class NoThreadPool(object): """ Dummy implementation that mimics parts of the usual thread pool interface but instead of @@ -3251,7 +3289,11 @@ def __init__(self, path: str, open_options: dict | None = None) -> None: meta_options.pop("row_groups", None) meta_options.pop("ignore_metadata", None) meta_options.pop("columns", None) - self.metadata = ak.metadata_from_parquet(path, **meta_options) + try: + self.metadata = ak.metadata_from_parquet(path, **meta_options) + except: + logger.error(f"unable to read {path}") + raise # extract row group sizes for chunked reading if "col_counts" not in self.metadata: @@ -3433,7 +3475,7 @@ class ChunkedIOHandler(object): # chunk position container ChunkPosition = namedtuple( "ChunkPosition", - ["index", "entry_start", "entry_stop", "max_chunk_size"], + ["index", "entry_start", "entry_stop", "max_chunk_size", "n_chunks"], ) # read result container @@ -3539,11 +3581,13 @@ def create_chunk_position( if n_entries == 0: entry_start = 0 entry_stop = 0 + n_chunks = 0 else: entry_start = chunk_index * chunk_size entry_stop = min((chunk_index + 1) * chunk_size, n_entries) + n_chunks = int(math.ceil(n_entries / chunk_size)) - return cls.ChunkPosition(chunk_index, entry_start, entry_stop, chunk_size) + return cls.ChunkPosition(chunk_index, entry_start, entry_stop, chunk_size, n_chunks) @classmethod def get_source_handler( diff --git a/columnflow/config_util.py b/columnflow/config_util.py index 0958e0ec7..d728f5800 100644 --- a/columnflow/config_util.py +++ b/columnflow/config_util.py @@ -680,6 +680,9 @@ def kwargs_fn(categories): cat = od.Category(name=cat_name, **kwargs) created_categories[cat_name] = cat + # add a tag to denote this category was auto-created + cat.add_tag("auto_created_by_combinations") + # ID uniqueness check: raise an error when a non-unique id is detected for a new category if isinstance(kwargs["id"], int): if kwargs["id"] in unique_ids_cache: @@ -735,14 +738,21 @@ def _parent_gen(): return len(created_categories) -def track_category_changes(config: od.Config, summary_path: str | None = None) -> None: +def track_category_changes( + config: od.Config, + summary_path: str | None = None, + skip_auto_created: bool = False, +) -> None: """ Scans the categories in *config* and saves a summary in a file located at *summary_path*. If the file exists, the summary from a previous run is loaded first and compare to the current categories. If changes are found, a warning is shown with details about these changes. + Categories automatically created via :py:func:`create_category_combinations` can be skipped via *skip_auto_created*. + :param config: :py:class:`~order.config.Config` instance to scan for categories. :param summary_path: Path to the summary file. Defaults to "$LAW_HOME/category_summary_{config.name}.json". + :param skip_auto_created: If *True*, categories with the tag "auto_created_by_combinations" are skipped. """ # build summary file as law target if not summary_path: @@ -750,7 +760,11 @@ def track_category_changes(config: od.Config, summary_path: str | None = None) - summary_file = law.LocalFileTarget(summary_path) # gather category info - cat_pairs = sorted((cat.name, cat.id) for cat, *_ in config.walk_categories(include_self=True)) + cat_pairs = sorted( + (cat.name, cat.id) + for cat, *_ in config.walk_categories(include_self=True) + if not skip_auto_created or not cat.has_tag("auto_created_by_combinations") + ) cat_summary = { "hash": law.util.create_hash(cat_pairs), "categories": dict(cat_pairs), diff --git a/columnflow/hist_util.py b/columnflow/hist_util.py index 1a82c8617..4efd3c73d 100644 --- a/columnflow/hist_util.py +++ b/columnflow/hist_util.py @@ -308,6 +308,32 @@ def add_missing_shifts( h[{str_axis: hist.loc(missing_shift)}] = nominal.view() +def update_ax_labels(hists: list[hist.Hist], config_inst: od.Config, variable_name: str) -> None: + """ + Helper function to update the axis labels of histograms based on variable instances from + the *config_inst*. + + :param hists: List of histograms to update. + :param config_inst: Configuration instance containing variable definitions. + :param variable_name: Name of the variable to update labels for, formatted as a string + with variable names separated by hyphens (e.g., "var1-var2"). + :raises ValueError: If a variable name is not found in the histogram axes. + """ + labels = {} + for var_name in variable_name.split("-"): + var_inst = config_inst.get_variable(var_name, None) + if var_inst: + labels[var_name] = var_inst.x_title + + for h in hists: + for var_name, label in labels.items(): + ax_names = [ax.name for ax in h.axes] + if var_name in ax_names: + h.axes[var_name].label = label + else: + raise ValueError(f"variable '{var_name}' not found in histogram axes: {h.axes}") + + def sum_hists(hists: Sequence[hist.Hist]) -> hist.Hist: """ Sums a sequence of histograms into a new histogram. In case axis labels differ, which typically leads to errors diff --git a/columnflow/histogramming/__init__.py b/columnflow/histogramming/__init__.py index 41a9438c7..f8b76ff20 100644 --- a/columnflow/histogramming/__init__.py +++ b/columnflow/histogramming/__init__.py @@ -11,15 +11,15 @@ import law import order as od -from columnflow.columnar_util import TaskArrayFunction +from columnflow.production import TaskArrayFunctionWithProducerRequirements from columnflow.util import DerivableMeta, maybe_import -from columnflow.types import TYPE_CHECKING, Any, Callable +from columnflow.types import TYPE_CHECKING, Any, Callable, Sequence if TYPE_CHECKING: hist = maybe_import("hist") -class HistProducer(TaskArrayFunction): +class HistProducer(TaskArrayFunctionWithProducerRequirements): """ Base class for all histogram producers, i.e., functions that control the creation of histograms, event weights, and optional post-processing. @@ -57,6 +57,10 @@ class HistProducer(TaskArrayFunction): skip_compatibility_check = False exposed = True + # register attributes for arguments accepted by decorator + mc_only: bool = False + data_only: bool = False + @classmethod def hist_producer( cls, @@ -64,6 +68,7 @@ def hist_producer( bases: tuple = (), mc_only: bool = False, data_only: bool = False, + require_producers: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ @@ -82,6 +87,7 @@ def hist_producer( skipped for real data. :param data_only: Boolean flag indicating that this hist producer should only run on real data and skipped for Monte Carlo simulation. + :param require_producers: Sequence of names of other producers to add to the requirements. :return: New hist producer subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -91,6 +97,7 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, + "require_producers": require_producers, } # get the module name diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index 70e1f7785..0023c1352 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -15,8 +15,7 @@ from columnflow.types import Generator, Callable, TextIO, Sequence, Any, Hashable, Type, T from columnflow.util import ( - CachedDerivableMeta, Derivable, DotDict, is_pattern, is_regex, pattern_matcher, get_docs_url, - freeze, + CachedDerivableMeta, Derivable, DotDict, is_pattern, is_regex, pattern_matcher, get_docs_url, freeze, ) @@ -221,7 +220,7 @@ def __str__(self) -> str: class InferenceModelMeta(CachedDerivableMeta): def _get_inst_cache_key(cls, args: tuple, kwargs: dict) -> Hashable: - config_insts = args[0] + config_insts = args[0] if args else kwargs.get("config_insts", []) config_names = tuple(sorted(config_inst.name for config_inst in config_insts)) return freeze((cls, config_names, kwargs.get("inst_dict", {}))) @@ -601,11 +600,11 @@ def parameter_config_spec( ("shift_source", str(shift_source) if shift_source else None), ]) - def __init__(self, config_insts: list[od.Config]) -> None: + def __init__(self, config_insts: list[od.Config] | None = None) -> None: super().__init__() # store attributes - self.config_insts = config_insts + self.config_insts = config_insts or [] # temporary attributes for as long as we issue deprecation warnings self.__config_inst = None diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index 394960c6a..373f0875a 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -556,12 +556,14 @@ def handle_flow(cat_obj, h, name): # warn in case of flow content if cat_obj.flow_strategy == FlowStrategy.warn: if underflow[0]: - logger.warning( + logger.warning_once( + f"underflow_warn_{self.inference_model_inst.cls_name}_{cat_obj.name}_{name}", f"underflow content detected in category '{cat_obj.name}' for histogram " f"'{name}' ({underflow[0] / view.value.sum() * 100:.1f}% of integral)", ) if overflow[0]: - logger.warning( + logger.warning_once( + f"overflow_warn_{self.inference_model_inst.cls_name}_{cat_obj.name}_{name}", f"overflow content detected in category '{cat_obj.name}' for histogram " f"'{name}' ({overflow[0] / view.value.sum() * 100:.1f}% of integral)", ) @@ -611,7 +613,7 @@ def fill_empty(cat_obj, h): # flat list of hists for configs that contribute to this category hists: list[dict[Hashable, hist.Hist]] = [ hd for config_name, hd in config_hists.items() - if config_name in cat_obj.config_data + if not cat_obj.config_data or config_name in cat_obj.config_data ] if not hists: continue @@ -827,23 +829,25 @@ def load( if not h_data: proc_str = ",".join(map(str, cat_obj.data_from_processes)) raise Exception(f"none of requested processes '{proc_str}' found to create fake data") - h_data = sum_hists(h_data) data_name = data_pattern.format(category=cat_name) - fill_empty(cat_obj, h_data) + h_data = sum_hists(h_data) handle_flow(cat_obj, h_data, data_name) + h_data.view().variance = h_data.view().value out_file[data_name] = h_data _rates["data"] = float(h_data.sum().value) - elif any(cd.data_datasets for cd in cat_obj.config_data.values()): + elif proc_hists.get("data"): + # real data h_data = [] - for config_name, config_data in cat_obj.config_data.items(): - if "data" not in proc_hists or config_name not in proc_hists["data"]: + for config_name, config_hists in proc_hists["data"].items(): + if cat_obj.config_data and config_name not in cat_obj.config_data: raise Exception( - f"the inference model '{self.inference_model_inst.cls_name}' is configured to use real " - f"data for config '{config_name}' in category '{cat_name}' but no histogram received at " - f"entry ['data']['{config_name}']: {proc_hists}", + f"received real data in datacard category '{cat_name}' for config '{config_name}', but the " + f"inference model '{self.inference_model_inst.cls_name}' is not configured to use it in " + f"the config_data for that config; configured config_names are " + f"'{','.join(cat_obj.config_data.keys())}'", ) - h_data.append(proc_hists["data"][config_name]["nominal"]) + h_data.append(config_hists["nominal"]) # simply save the data histogram that was already built from the requested datasets h_data = sum_hists(h_data) @@ -852,6 +856,9 @@ def load( out_file[data_name] = h_data _rates["data"] = h_data.sum().value + else: + logger.warning(f"neither real data found nor fake data created in category '{cat_name}'") + return (rates, effects, nom_pattern_comb, syst_pattern_comb) @classmethod diff --git a/columnflow/production/__init__.py b/columnflow/production/__init__.py index 529191cf3..03ff6faf9 100644 --- a/columnflow/production/__init__.py +++ b/columnflow/production/__init__.py @@ -8,18 +8,55 @@ import inspect -from columnflow.types import Callable +import law + from columnflow.util import DerivableMeta from columnflow.columnar_util import TaskArrayFunction +from columnflow.types import Callable, Sequence, Any + + +class TaskArrayFunctionWithProducerRequirements(TaskArrayFunction): + + require_producers: Sequence[str] | set[str] | None = None + + def _req_producer(self, task: law.Task, producer: str) -> Any: + # hook to customize how required producers are requested + from columnflow.tasks.production import ProduceColumns + return ProduceColumns.req_other_producer(task, producer=producer) + def requires_func(self, task: law.Task, reqs: dict, **kwargs) -> None: + # no requirements for workflows in pilot mode + if callable(getattr(task, "is_workflow", None)) and task.is_workflow() and getattr(task, "pilot", False): + return -class Producer(TaskArrayFunction): + # add required producers when set + if (prods := self.require_producers): + reqs["required_producers"] = {prod: self._req_producer(task, prod) for prod in prods} + + def setup_func( + self, + task: law.Task, + reqs: dict, + inputs: dict, + reader_targets: law.util.InsertableDict, + **kwargs, + ) -> None: + if "required_producers" in inputs: + for prod, inp in inputs["required_producers"].items(): + reader_targets[f"required_producer_{prod}"] = inp["columns"] + + +class Producer(TaskArrayFunctionWithProducerRequirements): """ Base class for all producers. """ exposed = True + # register attributes for arguments accepted by decorator + mc_only: bool = False + data_only: bool = False + @classmethod def producer( cls, @@ -27,6 +64,7 @@ def producer( bases: tuple = (), mc_only: bool = False, data_only: bool = False, + require_producers: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ @@ -46,6 +84,7 @@ def producer( Monte Carlo simulation and skipped for real data. :param data_only: Boolean flag indicating that this :py:class:`Producer` should only run on real data and skipped for Monte Carlo simulation. + :param require_producers: Sequence of names of other producers to add to the requirements. :return: New :py:class:`Producer` subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -55,6 +94,7 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, + "require_producers": require_producers, } # get the module name diff --git a/columnflow/production/cms/electron.py b/columnflow/production/cms/electron.py index 38e628a62..89b739daa 100644 --- a/columnflow/production/cms/electron.py +++ b/columnflow/production/cms/electron.py @@ -6,31 +6,35 @@ from __future__ import annotations -from dataclasses import dataclass +import dataclasses import law from columnflow.production import Producer, producer from columnflow.util import maybe_import, load_correction_set, DotDict -from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array, EMPTY_FLOAT +from columnflow.columnar_util import set_ak_column, full_like, flat_np_view from columnflow.types import Any, Callable np = maybe_import("numpy") ak = maybe_import("awkward") -@dataclass +@dataclasses.dataclass class ElectronSFConfig: correction: str campaign: str working_point: str | dict[str, Callable] = "" hlt_path: str = "" + min_pt: float = 0.0 + max_pt: float = 0.0 def __post_init__(self) -> None: if not self.working_point and not self.hlt_path: raise ValueError("either working_point or hlt_path must be set") if self.working_point and self.hlt_path: raise ValueError("only one of working_point or hlt_path must be set") + if 0.0 < self.max_pt <= self.min_pt: + raise ValueError(f"{self.__class__.__name__}: max_pt must be larger than min_pt") @classmethod def new(cls, obj: ElectronSFConfig | tuple[str, str, str]) -> ElectronSFConfig: @@ -52,11 +56,12 @@ def new(cls, obj: ElectronSFConfig | tuple[str, str, str]) -> ElectronSFConfig: # function to determine the correction file get_electron_file=(lambda self, external_files: external_files.electron_sf), # function to determine the electron weight config - get_electron_config=(lambda self: ElectronSFConfig.new(self.config_inst.x.electron_sf_names)), + get_electron_config=(lambda self: ElectronSFConfig.new(self.config_inst.x("electron_sf", self.config_inst.x("electron_sf_names", None)))), # noqa: E501 # choose if the eta variable should be the electron eta or the super cluster eta use_supercluster_eta=True, + # name of the saved weight column weight_name="electron_weight", - supported_versions=(1, 2, 3), + supported_versions={1, 2, 3}, ) def electron_weights( self: Producer, @@ -65,8 +70,7 @@ def electron_weights( **kwargs, ) -> ak.Array: """ - Creates electron weights using the correctionlib. Requires an external file in the config under - ``electron_sf``: + Creates electron weights using the correctionlib. Requires an external file in the config under ``electron_sf``: .. code-block:: python @@ -74,26 +78,26 @@ def electron_weights( "electron_sf": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/EGM/2017_UL/electron.json.gz", # noqa }) - *get_electron_file* can be adapted in a subclass in case it is stored differently in the - external files. + *get_electron_file* can be adapted in a subclass in case it is stored differently in the external files. - The name of the correction set, the year string for the weight evaluation, and the name of the - working point should be given as an auxiliary entry in the config: + The name of the correction set, the year string for the weight evaluation, and the name of the working point should + be given as an auxiliary entry in the config: .. code-block:: python - cfg.x.electron_sf_names = ElectronSFConfig( + cfg.x.electron_sf = ElectronSFConfig( correction="UL-Electron-ID-SF", campaign="2017", working_point="wp80iso", # for trigger weights use hlt_path instead ) - The *working_point* can also be a dictionary mapping working point names to functions - that return a boolean mask for the electrons. This is useful to compute scale factors for - multiple working points at once, e.g. for the electron reconstruction scale factors: + The *working_point* can also be a dictionary mapping working point names to functions that return a boolean mask for + the electrons. This is useful to compute scale factors for multiple working points at once, e.g. for the electron + reconstruction scale factors: .. code-block:: python - cfg.x.electron_sf_names = ElectronSFConfig( + + cfg.x.electron_sf = ElectronSFConfig( correction="Electron-ID-SF", campaign="2022Re-recoE+PromptFG", working_point={ @@ -103,65 +107,68 @@ def electron_weights( }, ) - *get_electron_config* can be adapted in a subclass in case it is stored differently in the - config. + *get_electron_config* can be adapted in a subclass in case it is stored differently in the config. - Optionally, an *electron_mask* can be supplied to compute the scale factor weight - based only on a subset of electrons. + Optionally, an *electron_mask* can be supplied to compute the scale factor weight based only on a subset of + electrons. """ - # flat super cluster eta/flat eta and pt views + # fold electron mask with pt cuts if given + if self.electron_config.min_pt > 0.0: + pt_mask = events.Electron.pt >= self.electron_config.min_pt + electron_mask = pt_mask if electron_mask is Ellipsis else (pt_mask & electron_mask) + if self.electron_config.max_pt > 0.0: + pt_mask = events.Electron.pt <= self.electron_config.max_pt + electron_mask = pt_mask if electron_mask is Ellipsis else (pt_mask & electron_mask) + + # prepare input variables + electrons = events.Electron[electron_mask] + eta = electrons.eta if self.use_supercluster_eta: - eta = flat_np_view(( - events.Electron.eta[electron_mask] + - events.Electron.deltaEtaSC[electron_mask] - ), axis=1) - else: - eta = flat_np_view(events.Electron.eta[electron_mask], axis=1) - pt = flat_np_view(events.Electron.pt[electron_mask], axis=1) - phi = flat_np_view(events.Electron.phi[electron_mask], axis=1) - + eta = ( + electrons.superclusterEta + if "superclusterEta" in electrons.fields + else electrons.eta + electrons.deltaEtaSC + ) variable_map = { "year": self.electron_config.campaign, - "WorkingPoint": self.electron_config.working_point, "Path": self.electron_config.hlt_path, - "pt": pt, + "pt": electrons.pt, + "phi": electrons.phi, "eta": eta, - "phi": phi, } # loop over systematics for syst, postfix in zip(self.sf_variations, ["", "_up", "_down"]): # get the inputs for this type of variation - variable_map_syst = { - **variable_map, - "ValType": syst, - } - if isinstance(variable_map["WorkingPoint"], str): - inputs = [variable_map_syst[inp.name] for inp in self.electron_sf_corrector.inputs] - sf_flat = self.electron_sf_corrector(*inputs) - elif isinstance(variable_map["WorkingPoint"], dict): - sf_flat = np.ones_like(pt, dtype=np.float32) * EMPTY_FLOAT - for working_point, mask_fn in variable_map_syst["WorkingPoint"].items(): + variable_map_syst = variable_map | {"ValType": syst} + + # add working point + wp = self.electron_config.working_point + if isinstance(wp, str): + # single wp, just evaluate + variable_map_syst_wp = variable_map_syst | {"WorkingPoint": wp} + inputs = [variable_map_syst_wp[inp.name] for inp in self.electron_sf_corrector.inputs] + sf = self.electron_sf_corrector.evaluate(*inputs) + elif isinstance(wp, dict): + # mapping of wps to masks, evaluate per wp and combine + sf = full_like(eta, 1.0) + sf_flat = flat_np_view(sf) + for _wp, mask_fn in wp.items(): mask = mask_fn(variable_map) - variable_map_syst_wp = { - **variable_map_syst, - "WorkingPoint": working_point, - } - for key, value in variable_map_syst_wp.items(): - # apply mask to array-like values - if isinstance(value, np.ndarray) or isinstance(value, ak.Array): - variable_map_syst_wp[key] = value[mask] + variable_map_syst_wp = variable_map_syst | {"WorkingPoint": _wp} # call the corrector with the masked inputs - inputs = [variable_map_syst_wp[inp.name] for inp in self.electron_sf_corrector.inputs] - sf_flat[mask] = self.electron_sf_corrector(*inputs) - if np.any(sf_flat == EMPTY_FLOAT): - raise ValueError("some electrons did not have a valid scale factor, check your inputs") + inputs = [ + ( + variable_map_syst_wp[inp.name][mask] + if isinstance(variable_map_syst_wp[inp.name], (np.ndarray, ak.Array)) + else variable_map_syst_wp[inp.name] + ) + for inp in self.electron_sf_corrector.inputs + ] + sf_flat[flat_np_view(mask)] = flat_np_view(self.electron_sf_corrector.evaluate(*inputs)) else: raise ValueError(f"unsupported working point type {type(variable_map['WorkingPoint'])}") - # add the correct layout to it - sf = layout_ak_array(sf_flat, events.Electron.pt[electron_mask]) - # create the product over all electrons in one event weight = ak.prod(sf, axis=1, mask_identity=False) diff --git a/columnflow/production/cms/muon.py b/columnflow/production/cms/muon.py index ee790a7d5..762e6b544 100644 --- a/columnflow/production/cms/muon.py +++ b/columnflow/production/cms/muon.py @@ -8,21 +8,23 @@ import law -from dataclasses import dataclass +import dataclasses from columnflow.production import Producer, producer from columnflow.util import maybe_import, load_correction_set, DotDict -from columnflow.columnar_util import set_ak_column, flat_np_view, layout_ak_array +from columnflow.columnar_util import set_ak_column from columnflow.types import Any np = maybe_import("numpy") ak = maybe_import("awkward") -@dataclass +@dataclasses.dataclass class MuonSFConfig: correction: str campaign: str = "" + min_pt: float = 0.0 + max_pt: float = 0.0 @classmethod def new(cls, obj: MuonSFConfig | tuple[str, str]) -> MuonSFConfig: @@ -37,18 +39,23 @@ def new(cls, obj: MuonSFConfig | tuple[str, str]) -> MuonSFConfig: return cls(**obj) raise ValueError(f"cannot convert {obj} to MuonSFConfig") + def __post_init__(self): + if 0.0 < self.max_pt <= self.min_pt: + raise ValueError(f"{self.__class__.__name__}: max_pt must be larger than min_pt") + @producer( uses={"Muon.{pt,eta}"}, - # produces in the init + # produces defined in init # only run on mc mc_only=True, # function to determine the correction file get_muon_file=(lambda self, external_files: external_files.muon_sf), # function to determine the muon weight config - get_muon_config=(lambda self: MuonSFConfig.new(self.config_inst.x.muon_sf_names)), + get_muon_config=(lambda self: MuonSFConfig.new(self.config_inst.x("muon_sf", self.config_inst.x("muon_sf_names", None)))), # noqa: E501 + # name of the saved weight column weight_name="muon_weight", - supported_versions=(1, 2), + supported_versions={1, 2}, ) def muon_weights( self: Producer, @@ -57,8 +64,7 @@ def muon_weights( **kwargs, ) -> ak.Array: """ - Creates muon weights using the correctionlib. Requires an external file in the config under - ``muon_sf``: + Creates muon weights using the correctionlib. Requires an external file in the config under ``muon_sf``: .. code-block:: python @@ -66,33 +72,37 @@ def muon_weights( "muon_sf": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/MUO/2017_UL/muon_z.json.gz", # noqa }) - *get_muon_file* can be adapted in a subclass in case it is stored differently in the external - files. + *get_muon_file* can be adapted in a subclass in case it is stored differently in the external files. - The name of the correction set and the year string for the weight evaluation should be given as - an auxiliary entry in the config: + The name of the correction set and the year string for the weight evaluation should be given as an auxiliary entry + in the config: .. code-block:: python - cfg.x.muon_sf_names = MuonSFConfig( + cfg.x.muon_sf = MuonSFConfig( correction="NUM_TightRelIso_DEN_TightIDandIPCut", campaign="2017_UL", ) *get_muon_config* can be adapted in a subclass in case it is stored differently in the config. - Optionally, a *muon_mask* can be supplied to compute the scale factor weight based only on a - subset of muons. + Optionally, a *muon_mask* can be supplied to compute the scale factor weight based only on a subset of muons. """ - # flat eta and pt views - eta = flat_np_view(events.Muon["eta"][muon_mask], axis=1) - pt = flat_np_view(events.Muon["pt"][muon_mask], axis=1) - + # fold muon mask with pt cuts if given + if self.muon_config.min_pt > 0.0: + pt_mask = events.Muon.pt >= self.muon_config.min_pt + muon_mask = pt_mask if muon_mask is Ellipsis else (pt_mask & muon_mask) + if self.muon_config.max_pt > 0.0: + pt_mask = events.Muon.pt <= self.muon_config.max_pt + muon_mask = pt_mask if muon_mask is Ellipsis else (pt_mask & muon_mask) + + # prepare input variables + muons = events.Muon[muon_mask] variable_map = { "year": self.muon_config.campaign, - "eta": eta, - "abseta": abs(eta), - "pt": pt, + "eta": muons.eta, + "abseta": abs(muons.eta), + "pt": muons.pt, } # loop over systematics @@ -108,10 +118,7 @@ def muon_weights( "ValType": syst, # syst key in 2017 } inputs = [variable_map_syst[inp.name] for inp in self.muon_sf_corrector.inputs] - sf_flat = self.muon_sf_corrector(*inputs) - - # add the correct layout to it - sf = layout_ak_array(sf_flat, events.Muon["pt"][muon_mask]) + sf = self.muon_sf_corrector.evaluate(*inputs) # create the product over all muons in one event weight = ak.prod(sf, axis=1, mask_identity=False) diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 0d3478956..c118b3468 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -330,7 +330,7 @@ def normalization_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Arra ) # read the weight per process (defined as lumi * xsec / sum_weights) from the lookup table - process_weight = np.squeeze(np.asarray(self.process_weight_table[process_id, 0].todense())) + process_weight = np.squeeze(np.asarray(self.process_weight_table[process_id].todense()), axis=-1) # compute the weight and store it norm_weight = events.mc_weight * process_weight diff --git a/columnflow/reduction/__init__.py b/columnflow/reduction/__init__.py index c35975c2e..e58c3ba61 100644 --- a/columnflow/reduction/__init__.py +++ b/columnflow/reduction/__init__.py @@ -8,18 +8,22 @@ import inspect -from columnflow.types import Callable +from columnflow.calibration import TaskArrayFunctionWithCalibratorRequirements from columnflow.util import DerivableMeta -from columnflow.columnar_util import TaskArrayFunction +from columnflow.types import Callable, Sequence -class Reducer(TaskArrayFunction): +class Reducer(TaskArrayFunctionWithCalibratorRequirements): """ Base class for all reducers. """ exposed = True + # register attributes for arguments accepted by decorator + mc_only: bool = False + data_only: bool = False + @classmethod def reducer( cls, @@ -27,6 +31,7 @@ def reducer( bases: tuple = (), mc_only: bool = False, data_only: bool = False, + require_calibrators: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ @@ -45,6 +50,7 @@ def reducer( for real data. :param data_only: Boolean flag indicating that this reducer should only run on real data and skipped for Monte Carlo simulation. + :param require_calibrators: Sequence of names of calibrators to add to the requirements. :return: New reducer subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -54,6 +60,7 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, + "require_calibrators": require_calibrators, } # get the module name diff --git a/columnflow/selection/__init__.py b/columnflow/selection/__init__.py index 5f0af3ce7..1f4368fc0 100644 --- a/columnflow/selection/__init__.py +++ b/columnflow/selection/__init__.py @@ -12,9 +12,9 @@ import law import order as od -from columnflow.types import Callable, T +from columnflow.calibration import TaskArrayFunctionWithCalibratorRequirements from columnflow.util import maybe_import, DotDict, DerivableMeta -from columnflow.columnar_util import TaskArrayFunction +from columnflow.types import Callable, T, Sequence ak = maybe_import("awkward") np = maybe_import("numpy") @@ -22,13 +22,17 @@ logger = law.logger.get_logger(__name__) -class Selector(TaskArrayFunction): +class Selector(TaskArrayFunctionWithCalibratorRequirements): """ Base class for all selectors. """ exposed = False + # register attributes for arguments accepted by decorator + mc_only: bool = False + data_only: bool = False + def __init__(self: Selector, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -44,25 +48,26 @@ def selector( bases=(), mc_only: bool = False, data_only: bool = False, + require_calibrators: Sequence[str] | set[str] | None = None, **kwargs, ) -> DerivableMeta | Callable: """ - Decorator for creating a new :py:class:`~.Selector` subclass with additional, optional - *bases* and attaching the decorated function to it as ``call_func``. + Decorator for creating a new :py:class:`~.Selector` subclass with additional, optional *bases* and attaching the + decorated function to it as ``call_func``. - When *mc_only* (*data_only*) is *True*, the selector is skipped and not considered by - other calibrators, selectors and producers in case they are evaluated on a - :py:class:`order.Dataset` (using the :py:attr:`dataset_inst` attribute) whose ``is_mc`` - (``is_data``) attribute is *False*. + When *mc_only* (*data_only*) is *True*, the selector is skipped and not considered by other calibrators, + selectors and producers in case they are evaluated on a :py:class:`order.Dataset` (using the + :py:attr:`dataset_inst` attribute) whose ``is_mc`` (``is_data``) attribute is *False*. All additional *kwargs* are added as class members of the new subclasses. :param func: Function to be wrapped and integrated into new :py:class:`Selector` class. :param bases: Additional bases for the new :py:class:`Selector`. - :param mc_only: Boolean flag indicating that this :py:class:`Selector` should only run on - Monte Carlo simulation and skipped for real data. - :param data_only: Boolean flag indicating that this :py:class:`Selector` should only run on - real data and skipped for Monte Carlo simulation. + :param mc_only: Boolean flag indicating that this :py:class:`Selector` should only run on Monte Carlo simulation + and skipped for real data. + :param data_only: Boolean flag indicating that this :py:class:`Selector` should only run on real data and + skipped for Monte Carlo simulation. + :param require_calibrators: Sequence of names of calibrators to add to the requirements. :return: New :py:class:`Selector` subclass. """ def decorator(func: Callable) -> DerivableMeta: @@ -72,6 +77,7 @@ def decorator(func: Callable) -> DerivableMeta: "call_func": func, "mc_only": mc_only, "data_only": data_only, + "require_calibrators": require_calibrators, } # get the module name diff --git a/columnflow/tasks/calibration.py b/columnflow/tasks/calibration.py index 2158487de..8b58be61f 100644 --- a/columnflow/tasks/calibration.py +++ b/columnflow/tasks/calibration.py @@ -60,9 +60,7 @@ def workflow_requires(self) -> dict: reqs["lfns"] = self.reqs.GetDatasetLFNs.req(self) # add calibrator dependent requirements - reqs["calibrator"] = law.util.make_unique(law.util.flatten( - self.calibrator_inst.run_requires(task=self), - )) + reqs["calibrator"] = law.util.make_unique(law.util.flatten(self.calibrator_inst.run_requires(task=self))) return reqs diff --git a/columnflow/tasks/cms/external.py b/columnflow/tasks/cms/external.py index 148b0bac5..89d527244 100644 --- a/columnflow/tasks/cms/external.py +++ b/columnflow/tasks/cms/external.py @@ -67,7 +67,7 @@ def run(self): # since this tasks uses stage-in into and stage-out from the sandbox, # prepare external files with the staged-in inputs - externals.get_files(self.input()) + externals.get_files_collection(self.input()) # read the mc profile mc_profile = self.read_mc_profile_from_cfg(externals.files.pu.mc_profile) diff --git a/columnflow/tasks/external.py b/columnflow/tasks/external.py index 33c4a4793..0d901bc54 100644 --- a/columnflow/tasks/external.py +++ b/columnflow/tasks/external.py @@ -17,6 +17,7 @@ import law import order as od +from columnflow import env_is_local from columnflow.tasks.framework.base import AnalysisTask, ConfigTask, DatasetTask, wrapper_factory from columnflow.tasks.framework.parameters import user_parameter_inst from columnflow.tasks.framework.decorators import only_local_env @@ -404,7 +405,7 @@ class ExternalFile: """ location: str - subpaths: dict[str, str] = field(default_factory=str) + subpaths: dict[str, str] = field(default_factory=dict) version: str = "v1" def __str__(self) -> str: @@ -428,6 +429,11 @@ def new(cls, resource: ExternalFile | str | tuple[str] | tuple[str, str]) -> Ext return cls(location=resource[0], version=resource[1]) raise ValueError(f"invalid resource type and format: {resource}") + def __getattr__(self, attr: str) -> str: + if attr in self.subpaths: + return self.subpaths[attr] + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'") + class BundleExternalFiles(ConfigTask, law.tasks.TransferLocalFile): """ @@ -436,8 +442,8 @@ class BundleExternalFiles(ConfigTask, law.tasks.TransferLocalFile): This task is intended to download source files for other tasks, such as files containing corrections for objects, the "golden" json files, source files for the calculation of pileup weights, and others. - All information about the relevant external files is extracted from the given ``config_inst``, which must contain - the keyword ``external_files`` in the auxiliary information. This can look like this: + All information about the relevant external files is extracted from the given ``config_inst``, which must contain an + auxiliary field ``external_files`` like the following (all entries are optional and user-defined): .. code-block:: python @@ -452,7 +458,7 @@ class BundleExternalFiles(ConfigTask, law.tasks.TransferLocalFile): "electron_sf": ExternalFile(f"{SOURCE_URL}/POG/EGM/{year}{corr_postfix}_UL/electron.json.gz", version="v1"), }) - The entries in this DotDict should be :py:class:`ExternalFile` instances. + All entries should be :py:class:`ExternalFile` instances. """ single_config = True @@ -461,6 +467,11 @@ class BundleExternalFiles(ConfigTask, law.tasks.TransferLocalFile): default=5, description="number of replicas to generate; default: 5", ) + recreate = luigi.BoolParameter( + default=False, + significant=False, + description="when True, forces the recreation of the bundle even if it exists; default: False", + ) user = user_parameter_inst version = None @@ -477,18 +488,10 @@ def __init__(self, *args, **kwargs): self._file_names = None # cached dict for lazy access to files in fetched bundle - self.files_dir = None - self._files = None + self._files_collection = None @classmethod def create_unique_basename(cls, path: str | ExternalFile) -> str | dict[str, str]: - """ - Create a unique basename for a given path. When *path* is an :py:class:`ExternalFile` with one or more subpaths - defined, a dictionary mapping subpaths to unique basenames is returned. - - :param path: path or external file object. - :return: Unique basename(s). - """ if isinstance(path, str): return f"{law.util.create_hash(path)}_{os.path.basename(path)}" @@ -503,11 +506,6 @@ def create_unique_basename(cls, path: str | ExternalFile) -> str | dict[str, str @property def files_hash(self) -> str: - """ - Create a hash based on all external files. - - :return: Hash based on the flattened list of external files in the current config instance. - """ if self._files_hash is None: # take the external files and flatten them into a deterministic order, then hash def deterministic_flatten(d): @@ -522,122 +520,156 @@ def deterministic_flatten(d): @property def file_names(self) -> DotDict: - """ - Create a unique basename for each external file. - - :return: DotDict of same shape as ``external_files`` DotDict with unique basenames. - """ if self._file_names is None: self._file_names = law.util.map_struct(self.create_unique_basename, self.ext_files) return self._file_names - def get_files(self, output=None): - if self._files is None: + def get_files_collection(self, output=None) -> law.SiblingFileCollection: + if self._files_collection is None: # get the output if not output: output = self.output() - if not output.exists(): + if not output["local_files"].exists(): raise Exception( - f"accessing external files from the bundle requires the output of {self} to " - "exist, but it appears to be missing", + f"accessing external files from the bundle requires the output of {self} to exist, but it appears " + "to be missing", ) - if isinstance(output, law.FileCollection): - output = output.random_target() - self.files_dir = law.LocalDirectoryTarget(is_tmp=True) - output.load(self.files_dir, formatter="tar") + self._files_collection = output["local_files"] - # resolve basenames in the bundle directory and map to local targets - def resolve_basename(unique_basename): - return self.files_dir.child(unique_basename) + return self._files_collection - self._files = law.util.map_struct(resolve_basename, self.file_names) - - return self._files + @property + def files(self) -> DotDict: + return self.get_files_collection().targets @property - def files(self): - return self.get_files() + def files_dir(self) -> law.LocalDirectoryTarget: + return self.get_files_collection().dir def single_output(self): # required by law.tasks.TransferLocalFile return self.target(f"externals_{self.files_hash}.tgz") - @only_local_env + def output(self): + def local_target(basename): + path = os.path.join(f"externals_{self.files_hash}", basename) + is_dir = "." not in basename # simple heuristic, but type actually checked after unpacking below + return self.local_target(path, dir=is_dir) + + return DotDict( + bundle=super().output(), + local_files=law.SiblingFileCollection(law.util.map_struct(local_target, self.file_names)), + ) + + def trace_transfer_output(self, output): + return output["bundle"] + @law.decorator.notify @law.decorator.log @law.decorator.safe_output def run(self): - # create a tmp dir to work in - tmp_dir = law.LocalDirectoryTarget(is_tmp=True) - tmp_dir.touch() - - # create a scratch directory for temporary downloads that will not be bundled - scratch_dir = tmp_dir.child("scratch", type="d") - scratch_dir.touch() - - # progress callback - progress = self.create_progress_callback(len(law.util.flatten(self.ext_files))) - - # helper to fetch a single src to dst - def fetch(src, dst): - if src.startswith(("http://", "https://")): - # download via wget - wget(src, dst) - elif os.path.isfile(src): - # copy local file - shutil.copy2(src, dst) - elif os.path.isdir(src): - # copy local dir - shutil.copytree(src, dst) - else: - err = f"cannot fetch {src}" - if src.startswith("/") and os.path.isdir("/".join(src.split("/", 2)[:2])): - err += ", file or directory does not exist" + outputs = self.output() + + # remove the bundle if recreating + if outputs["bundle"].exists() and self.recreate: + outputs["bundle"].remove() + + # bundle only if needed + if not outputs["bundle"].exists(): + if not env_is_local: + raise RuntimeError( + f"the output bundle {outputs['bundle'].basename} is missing, but cannot be created in non-local " + "environments", + ) + + # create a tmp dir to work in + tmp_dir = law.LocalDirectoryTarget(is_tmp=True) + tmp_dir.touch() + + # create a scratch directory for temporary downloads that will not be bundled + scratch_dir = tmp_dir.child("scratch", type="d") + scratch_dir.touch() + + # progress callback + progress = self.create_progress_callback(len(law.util.flatten(self.ext_files))) + + # helper to fetch a single src to dst + def fetch(src, dst): + if src.startswith(("http://", "https://")): + # download via wget + wget(src, dst) + elif os.path.isfile(src): + # copy local file + shutil.copy2(src, dst) + elif os.path.isdir(src): + # copy local dir + shutil.copytree(src, dst) else: - err += ", resource type is not supported" - raise NotImplementedError(err) - - # helper function to fetch generic files - def fetch_file(ext_file, counter=[0]): - if ext_file.subpaths: - # copy to scratch dir first in case a subpath is requested - basename = self.create_unique_basename(ext_file.location) - scratch_dst = os.path.join(scratch_dir.abspath, basename) - fetch(ext_file.location, scratch_dst) - # when not a directory, assume the file is an archive and unpack it - if not os.path.isdir(scratch_dst): - arc_dir = scratch_dir.child(basename.split(".")[0] + "_unpacked", type="d") - self.publish_message(f"unpacking {scratch_dst}") - law.LocalFileTarget(scratch_dst).load(arc_dir) - scratch_src = arc_dir.abspath + err = f"cannot fetch {src}" + if src.startswith("/") and os.path.isdir("/".join(src.split("/", 2)[:2])): + err += ", file or directory does not exist" + else: + err += ", resource type is not supported" + raise NotImplementedError(err) + + # helper function to fetch generic files + def fetch_file(ext_file, counter=[0]): + if ext_file.subpaths: + # copy to scratch dir first in case a subpath is requested + basename = self.create_unique_basename(ext_file.location) + scratch_dst = os.path.join(scratch_dir.abspath, basename) + fetch(ext_file.location, scratch_dst) + # when not a directory, assume the file is an archive and unpack it + if not os.path.isdir(scratch_dst): + arc_dir = scratch_dir.child(basename.split(".")[0] + "_unpacked", type="d") + self.publish_message(f"unpacking {scratch_dst}") + law.LocalFileTarget(scratch_dst).load(arc_dir) + scratch_src = arc_dir.abspath + else: + scratch_src = scratch_dst + # copy all subpaths + basenames = self.create_unique_basename(ext_file) + for name, subpath in ext_file.subpaths.items(): + fetch(os.path.join(scratch_src, subpath), os.path.join(tmp_dir.abspath, basenames[name])) else: - scratch_src = scratch_dst - # copy all subpaths - basenames = self.create_unique_basename(ext_file) - for name, subpath in ext_file.subpaths.items(): - fetch(os.path.join(scratch_src, subpath), os.path.join(tmp_dir.abspath, basenames[name])) - else: - # copy directly to the bundle dir - src = ext_file.location - dst = os.path.join(tmp_dir.abspath, self.create_unique_basename(ext_file.location)) - fetch(src, dst) - # log - self.publish_message(f"fetched {ext_file}") - progress(counter[0]) - counter[0] += 1 - - # fetch all files and cleanup scratch dir - law.util.map_struct(fetch_file, self.ext_files) - scratch_dir.remove() - - # create the bundle - tmp = law.LocalFileTarget(is_tmp="tgz") - tmp.dump(tmp_dir, formatter="tar") - - # log the file size - bundle_size = law.util.human_bytes(tmp.stat().st_size, fmt=True) - self.publish_message(f"bundle size is {bundle_size}") - - # transfer the result - self.transfer(tmp) + # copy directly to the bundle dir + src = ext_file.location + dst = os.path.join(tmp_dir.abspath, self.create_unique_basename(ext_file.location)) + fetch(src, dst) + # log + self.publish_message(f"fetched {ext_file}") + progress(counter[0]) + counter[0] += 1 + + # fetch all files and cleanup scratch dir + law.util.map_struct(fetch_file, self.ext_files) + scratch_dir.remove() + + # create the bundle + tmp = law.LocalFileTarget(is_tmp="tgz") + tmp.dump(tmp_dir, formatter="tar") + + # log the file size + bundle_size = law.util.human_bytes(tmp.stat().st_size, fmt=True) + self.publish_message(f"bundle size is {bundle_size}") + + # transfer the result + self.transfer(tmp, outputs["bundle"]) + + # unpack the bundle to have local files available + with self.publish_step(f"unpacking to {outputs['local_files'].dir.abspath} ..."): + outputs["local_files"].dir.remove() + bundle = outputs["bundle"] + if isinstance(bundle, law.FileCollection): + bundle = bundle.random_target() + bundle.load(outputs["local_files"].dir, formatter="tar") + + # check if unpacked files/directories are described by the correct target class + for target in outputs["local_files"]._flat_target_list: + mismatch = ( + (isinstance(target, law.FileSystemFileTarget) and not os.path.isfile(target.abspath)) or + (isinstance(target, law.FileSystemDirectoryTarget) and not os.path.isdir(target.abspath)) + ) + if mismatch: + raise Exception(f"mismatching file/directory type of unpacked target {target!r}") diff --git a/columnflow/tasks/framework/base.py b/columnflow/tasks/framework/base.py index 0cf4e0d42..4315af4e4 100644 --- a/columnflow/tasks/framework/base.py +++ b/columnflow/tasks/framework/base.py @@ -38,6 +38,11 @@ default_repr_max_count = law.config.get_expanded_int("analysis", "repr_max_count") default_repr_hash_len = law.config.get_expanded_int("analysis", "repr_hash_len") +# cached and parsed sections of the law config for faster lookup +_cfg_outputs_dict = None +_cfg_versions_dict = None +_cfg_resources_dict = None + # placeholder to denote a default value that is resolved dynamically RESOLVE_DEFAULT = "DEFAULT" @@ -130,11 +135,6 @@ class AnalysisTask(BaseTask, law.SandboxTask): exclude_params_branch = {"user"} exclude_params_workflow = {"user", "notify_slack", "notify_mattermost", "notify_custom"} - # cached and parsed sections of the law config for faster lookup - _cfg_outputs_dict = None - _cfg_versions_dict = None - _cfg_resources_dict = None - @classmethod def modify_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params = super().modify_param_values(params) @@ -197,7 +197,7 @@ def req_params(cls, inst: AnalysisTask, **kwargs) -> dict[str, Any]: not law.parser.global_cmdline_values().get(f"{cls.task_family}_version") and ( cls.task_family != inst.task_family or - freeze(cls.get_config_lookup_keys(params)) != freeze(inst.get_config_lookup_keys(params)) + freeze(cls.get_config_lookup_keys(params)) != freeze(inst.get_config_lookup_keys(inst)) ) ): default_version = cls.get_default_version(inst, params) @@ -233,17 +233,19 @@ def _structure_cfg_items(cls, items: list[tuple[str, Any]]) -> dict[str, Any]: d[part] = {"*": d[part]} d = d[part] else: - # assign value to the last nesting level - if part in d and isinstance(d[part], dict): - d[part]["*"] = value - else: + # assign value to the last nesting level, do not overwrite + if part not in d: d[part] = value + elif isinstance(d[part], dict): + d[part]["*"] = value return items_dict @classmethod def _get_cfg_outputs_dict(cls) -> dict[str, Any]: - if cls._cfg_outputs_dict is None and law.config.has_section("outputs"): + global _cfg_outputs_dict + + if _cfg_outputs_dict is None and law.config.has_section("outputs"): # collect config item pairs skip_keys = {"wlcg_file_systems", "lfn_sources"} items = [ @@ -251,26 +253,30 @@ def _get_cfg_outputs_dict(cls) -> dict[str, Any]: for key, value in law.config.items("outputs") if value and key not in skip_keys ] - cls._cfg_outputs_dict = cls._structure_cfg_items(items) + _cfg_outputs_dict = cls._structure_cfg_items(items) - return cls._cfg_outputs_dict + return _cfg_outputs_dict @classmethod def _get_cfg_versions_dict(cls) -> dict[str, Any]: - if cls._cfg_versions_dict is None and law.config.has_section("versions"): + global _cfg_versions_dict + + if _cfg_versions_dict is None and law.config.has_section("versions"): # collect config item pairs items = [ (key, value) for key, value in law.config.items("versions") if value ] - cls._cfg_versions_dict = cls._structure_cfg_items(items) + _cfg_versions_dict = cls._structure_cfg_items(items) - return cls._cfg_versions_dict + return _cfg_versions_dict @classmethod def _get_cfg_resources_dict(cls) -> dict[str, Any]: - if cls._cfg_resources_dict is None and law.config.has_section("resources"): + global _cfg_resources_dict + + if _cfg_resources_dict is None and law.config.has_section("resources"): # helper to split resource values into key-value pairs themselves def parse(key: str, value: str) -> tuple[str, list[tuple[str, Any]]]: params = [] @@ -294,9 +300,9 @@ def parse(key: str, value: str) -> tuple[str, list[tuple[str, Any]]]: for key, value in law.config.items("resources") if value and not key.startswith("_") ] - cls._cfg_resources_dict = cls._structure_cfg_items(items) + _cfg_resources_dict = cls._structure_cfg_items(items) - return cls._cfg_resources_dict + return _cfg_resources_dict @classmethod def get_default_version(cls, inst: AnalysisTask, params: dict[str, Any]) -> str | None: diff --git a/columnflow/tasks/framework/histograms.py b/columnflow/tasks/framework/histograms.py index 03df5e0b7..81a8682b1 100644 --- a/columnflow/tasks/framework/histograms.py +++ b/columnflow/tasks/framework/histograms.py @@ -33,7 +33,7 @@ class HistogramsUserBase( CategoriesMixin, VariablesMixin, ): - single_config = True + single_config = False sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) @@ -44,29 +44,49 @@ def store_parts(self) -> law.util.InsertableDict: def load_histogram( self, + inputs: dict, + config: str | od.Config, dataset: str | od.Dataset, variable: str | od.Variable, + update_label: bool = True, ) -> hist.Hist: """ Helper function to load the histogram from the input for a given dataset and variable. + :param inputs: The inputs dictionary containing the histograms. + :param config: The config name or instance. :param dataset: The dataset name or instance. :param variable: The variable name or instance. + :param update_label: Whether to update the label of the variable axis in the histogram. + If True, the label will be updated based on the first config instance's variable label. :return: The loaded histogram. """ + if isinstance(dataset, od.Dataset): dataset = dataset.name if isinstance(variable, od.Variable): variable = variable.name - histogram = self.input()[dataset]["collection"][0]["hists"].targets[variable].load(formatter="pickle") + if isinstance(config, od.Config): + config = config.name + histogram = inputs[config][dataset]["collection"][0]["hists"].targets[variable].load(formatter="pickle") + + if update_label: + # get variable label from first config instance + for var_name in variable.split("-"): + label = self.config_insts[0].get_variable(var_name).x_title + ax_names = [ax.name for ax in histogram.axes] + if var_name in ax_names: + # update the label of the variable axis + histogram.axes[var_name].label = label return histogram def slice_histogram( self, histogram: hist.Hist, - processes: str | list[str], - categories: str | list[str], - shifts: str | list[str], + config_inst: od.Config, + processes: str | list[str] | None = None, + categories: str | list[str] | None = None, + shifts: str | list[str] | None = None, reduce_axes: bool = False, ) -> hist.Hist: """ @@ -87,49 +107,56 @@ def slice_histogram( def flatten_nested_list(nested_list): return [item for sublist in nested_list for item in sublist] - # transform into lists if necessary - processes = law.util.make_list(processes) - categories = law.util.make_list(categories) - shifts = law.util.make_list(shifts) - - # get all leaf categories - category_insts = list(map(self.config_inst.get_category, categories)) - leaf_category_insts = set(flatten_nested_list([ - category_inst.get_leaf_categories() or [category_inst] - for category_inst in category_insts - ])) + selection_dict = {} - # get all sub processes - process_insts = list(map(self.config_inst.get_process, processes)) - sub_process_insts = set(flatten_nested_list([ - [sub for sub, _, _ in proc.walk_processes(include_self=True)] - for proc in process_insts - ])) + if processes: + # transform into lists if necessary + processes = law.util.make_list(processes) + # get all sub processes - # get all shift instances - shift_insts = [self.config_inst.get_shift(shift) for shift in shifts] - - # work on a copy - h = histogram.copy() - - # axis selections - h = h[{ - "process": [ + process_insts = list(map(config_inst.get_process, processes)) + sub_process_insts = set(flatten_nested_list([ + [sub for sub, _, _ in proc.walk_processes(include_self=True)] + for proc in process_insts + ])) + selection_dict["process"] = [ hist.loc(p.name) for p in sub_process_insts - if p.name in h.axes["process"] - ], - "category": [ + if p.name in histogram.axes["process"] + ] + if categories: + # transform into lists if necessary + categories = law.util.make_list(categories) + + # get all leaf categories + category_insts = list(map(config_inst.get_category, categories)) + leaf_category_insts = set(flatten_nested_list([ + category_inst.get_leaf_categories() or [category_inst] + for category_inst in category_insts + ])) + selection_dict["category"] = [ hist.loc(c.name) for c in leaf_category_insts - if c.name in h.axes["category"] - ], - "shift": [ + if c.name in histogram.axes["category"] + ] + + if shifts: + # transform into lists if necessary + shifts = law.util.make_list(shifts) + + # get all shift instances + shift_insts = [config_inst.get_shift(shift) for shift in shifts] + selection_dict["shift"] = [ hist.loc(s.name) for s in shift_insts - if s.name in h.axes["shift"] - ], - }] + if s.name in histogram.axes["shift"] + ] + + # work on a copy + h = histogram.copy() + + # axis selections + h = h[selection_dict] if reduce_axes: # axis reductions @@ -156,14 +183,20 @@ def workflow_requires(self): return reqs def requires(self): + datasets = [self.datasets] if self.single_config else self.datasets return { - d: self.reqs.MergeHistograms.req_different_branching( - self, - dataset=d, - branch=-1, - _prefer_cli={"variables"}, - ) - for d in self.datasets + config_inst.name: { + d: self.reqs.MergeHistograms.req_different_branching( + self, + config=config_inst.name, + dataset=d, + branch=-1, + _prefer_cli={"variables"}, + ) + for d in datasets[i] + if config_inst.has_dataset(d) + } + for i, config_inst in enumerate(self.config_insts) } @@ -185,12 +218,17 @@ def workflow_requires(self): return reqs def requires(self): + datasets = [self.datasets] if self.single_config else self.datasets return { - d: self.reqs.MergeShiftedHistograms.req_different_branching( - self, - dataset=d, - branch=-1, - _prefer_cli={"variables"}, - ) - for d in self.datasets + config_inst.name: { + d: self.reqs.MergeShiftedHistograms.req_different_branching( + self, + config=config_inst.name, + dataset=d, + branch=-1, + _prefer_cli={"variables"}, + ) + for d in datasets[i] + } + for i, config_inst in enumerate(self.config_insts) } diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index e4a4ce9e7..36b99d5ad 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -29,7 +29,6 @@ from columnflow.types import Callable from columnflow.timing import Timer - np = maybe_import("numpy") ak = maybe_import("awkward") diff --git a/columnflow/tasks/framework/remote.py b/columnflow/tasks/framework/remote.py index fae1d3559..da2c59c2c 100644 --- a/columnflow/tasks/framework/remote.py +++ b/columnflow/tasks/framework/remote.py @@ -665,6 +665,11 @@ def handle_scheduler_message(self, msg, _attr_value=None): input_unit="GB", unit="GB", ) +_default_htcondor_runtime = law.util.parse_duration( + law.config.get_expanded("analysis", "htcondor_runtime", 3.0), + input_unit="h", + unit="h", +) class HTCondorWorkflow(RemoteWorkflowMixin, law.htcondor.HTCondorWorkflow): @@ -674,11 +679,11 @@ class HTCondorWorkflow(RemoteWorkflowMixin, law.htcondor.HTCondorWorkflow): significant=False, description="transfer job logs to the output directory; default: True", ) - max_runtime = law.DurationParameter( - default=2.0, + htcondor_runtime = law.DurationParameter( + default=_default_htcondor_runtime, unit="h", significant=False, - description="maximum runtime; default unit is hours; default: 2", + description=f"maximum runtime; default unit is hours; default: {_default_htcondor_runtime}", ) htcondor_logs = luigi.BoolParameter( default=False, @@ -732,12 +737,12 @@ class HTCondorWorkflow(RemoteWorkflowMixin, law.htcondor.HTCondorWorkflow): # parameters that should not be passed to a workflow required upstream exclude_params_req_set = { - "max_runtime", "htcondor_cpus", "htcondor_gpus", "htcondor_memory", "htcondor_disk", + "htcondor_runtime", "htcondor_cpus", "htcondor_gpus", "htcondor_memory", "htcondor_disk", } # parameters that should not be passed from workflow to branches exclude_params_branch = { - "max_runtime", "htcondor_logs", "htcondor_cpus", "htcondor_gpus", "htcondor_memory", + "htcondor_runtime", "htcondor_logs", "htcondor_cpus", "htcondor_gpus", "htcondor_memory", "htcondor_disk", "htcondor_flavor", "htcondor_share_software", } @@ -767,7 +772,7 @@ def __init__(self, *args, **kwargs): self.bundle_repo_req = self.reqs.BundleRepo.req(self) # add scheduler message handlers - self.add_message_handler("max_runtime") + self.add_message_handler("htcondor_runtime") self.add_message_handler("htcondor_logs") self.add_message_handler("htcondor_cpus") self.add_message_handler("htcondor_gpus") @@ -847,8 +852,8 @@ def htcondor_job_config(self, config, job_num, branches): config.custom_content.append(("Request_OpSysAndVer", "\"RedHat9\"")) # maximum runtime, compatible with multiple batch systems - if self.max_runtime is not None and self.max_runtime > 0: - max_runtime = int(math.floor(self.max_runtime * 3600)) - 1 + if self.htcondor_runtime is not None and self.htcondor_runtime > 0: + max_runtime = int(math.floor(self.htcondor_runtime * 3600)) - 1 config.custom_content.append(("+MaxRuntime", max_runtime)) config.custom_content.append(("+RequestRuntime", max_runtime)) @@ -902,6 +907,11 @@ def htcondor_destination_info(self, info: dict[str, str]) -> dict[str, str]: _default_slurm_flavor = law.config.get_expanded("analysis", "slurm_flavor", "maxwell") _default_slurm_partition = law.config.get_expanded("analysis", "slurm_partition", "cms-uhh") +_default_slurm_runtime = law.util.parse_duration( + law.config.get_expanded("analysis", "slurm_runtime", 3.0), + input_unit="h", + unit="h", +) class SlurmWorkflow(RemoteWorkflowMixin, law.slurm.SlurmWorkflow): @@ -911,11 +921,11 @@ class SlurmWorkflow(RemoteWorkflowMixin, law.slurm.SlurmWorkflow): significant=False, description="transfer job logs to the output directory; default: True", ) - max_runtime = law.DurationParameter( - default=2.0, + slurm_runtime = law.DurationParameter( + default=_default_slurm_runtime, unit="h", significant=False, - description="maximum runtime; default unit is hours; default: 2", + description=f"maximum runtime; default unit is hours; default: {_default_slurm_runtime}", ) slurm_partition = luigi.Parameter( default=_default_slurm_partition, @@ -931,10 +941,10 @@ class SlurmWorkflow(RemoteWorkflowMixin, law.slurm.SlurmWorkflow): ) # parameters that should not be passed to a workflow required upstream - exclude_params_req_set = {"max_runtime"} + exclude_params_req_set = {"slurm_runtime"} # parameters that should not be passed from workflow to branches - exclude_params_branch = {"max_runtime", "slurm_partition", "slurm_flavor"} + exclude_params_branch = {"slurm_runtime", "slurm_partition", "slurm_flavor"} # mapping of environment variables to render variables that are forwarded slurm_forward_env_variables = { @@ -986,9 +996,9 @@ def slurm_job_config(self, config, job_num, branches): ) # set job time - if self.max_runtime is not None: + if self.slurm_runtime is not None and self.slurm_runtime > 0: job_time = law.util.human_duration( - seconds=int(math.floor(self.max_runtime * 3600)) - 1, + seconds=int(math.floor(self.slurm_runtime * 3600)) - 1, colon_format=True, ) config.custom_content.append(("time", job_time)) diff --git a/columnflow/tasks/histograms.py b/columnflow/tasks/histograms.py index b78003302..f57e18b14 100644 --- a/columnflow/tasks/histograms.py +++ b/columnflow/tasks/histograms.py @@ -21,19 +21,27 @@ from columnflow.tasks.reduction import ReducedEventsUser from columnflow.tasks.production import ProduceColumns from columnflow.tasks.ml import MLEvaluation -from columnflow.hist_util import sum_hists +from columnflow.hist_util import update_ax_labels, sum_hists from columnflow.util import dev_sandbox +class VariablesMixinWorkflow( + VariablesMixin, + law.LocalWorkflow, + RemoteWorkflow, +): + + def control_output_postfix(self) -> str: + return f"{super().control_output_postfix()}__vars_{self.variables_repr}" + + class _CreateHistograms( ReducedEventsUser, ProducersMixin, MLModelsMixin, HistProducerMixin, ChunkedIOMixin, - VariablesMixin, - law.LocalWorkflow, - RemoteWorkflow, + VariablesMixinWorkflow, ): """ Base classes for :py:class:`CreateHistograms`. @@ -100,11 +108,13 @@ def workflow_requires(self): self.reqs.MLEvaluation.req(self, ml_model=ml_model_inst.cls_name) for ml_model_inst in self.ml_model_insts ] + elif self.producer_insts: + # pass-through pilot workflow requirements of upstream task + t = self.reqs.ProduceColumns.req(self) + law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) - # add hist_producer dependent requirements - reqs["hist_producer"] = law.util.make_unique(law.util.flatten( - self.hist_producer_inst.run_requires(task=self), - )) + # add hist producer dependent requirements + reqs["hist_producer"] = law.util.make_unique(law.util.flatten(self.hist_producer_inst.run_requires(task=self))) return reqs @@ -241,6 +251,10 @@ def run(self): events = attach_coffea_behavior(events) events, weight = self.hist_producer_inst(events, task=self) + if len(events) == 0: + self.publish_message(f"no events found in chunk {pos}") + continue + # merge category ids and check that they are defined as leaf categories category_ids = ak.concatenate( [Route(c).apply(events) for c in self.category_id_columns], @@ -340,9 +354,7 @@ class _MergeHistograms( ProducersMixin, MLModelsMixin, HistProducerMixin, - VariablesMixin, - law.LocalWorkflow, - RemoteWorkflow, + VariablesMixinWorkflow, ): """ Base classes for :py:class:`MergeHistograms`. @@ -444,9 +456,12 @@ def run(self): variable_names = list(hists[0].keys()) for variable_name in self.iter_progress(variable_names, len(variable_names), reach=(50, 100)): self.publish_message(f"merging histograms for '{variable_name}'") + variable_hists = [h[variable_name] for h in hists] + + # update axis labels from variable insts for consistency + update_ax_labels(variable_hists, self.config_inst, variable_name) # merge them - variable_hists = [h[variable_name] for h in hists] merged = sum_hists(variable_hists) # post-process the merged histogram @@ -479,9 +494,7 @@ class _MergeShiftedHistograms( ProducerClassesMixin, MLModelsMixin, HistProducerClassMixin, - VariablesMixin, - law.LocalWorkflow, - RemoteWorkflow, + VariablesMixinWorkflow, ): """ Base classes for :py:class:`MergeShiftedHistograms`. @@ -544,6 +557,9 @@ def run(self): for coll in inputs.values() ] + # update axis labels from variable insts for consistency + update_ax_labels(variable_hists, self.config_inst, variable_name) + # merge and write the output merged = sum_hists(variable_hists) outp.dump(merged, formatter="pickle") diff --git a/columnflow/tasks/inspection.py b/columnflow/tasks/inspection.py index 3d9a1ce3b..11b472577 100644 --- a/columnflow/tasks/inspection.py +++ b/columnflow/tasks/inspection.py @@ -26,17 +26,42 @@ def output(self): return {"always_incomplete_dummy": self.target("dummy.txt")} def run(self): + """ + Loads histograms for all configs, variables, and datasets, + sums them up for each variable and + slices them according to the processes, categories, and shift, + The resulting histograms are stored in a dictionary with variable names as keys. + If `debugger` is set to True, an IPython debugger session is started for + interactive inspection of the histograms. + """ + inputs = self.input() + shifts = {self.shift, "nominal"} hists = {} - for dataset in self.datasets: - for variable in self.variables: - h_in = self.load_histogram(dataset, variable) - h_in = self.slice_histogram(h_in, self.processes, self.categories, self.shift) + for variable in self.variables: + for i, config_inst in enumerate(self.config_insts): + hist_per_config = None + sub_processes = self.processes[i] + for dataset in self.datasets[i]: + # sum over all histograms of the same variable and config + if hist_per_config is None: + hist_per_config = self.load_histogram(inputs, config_inst, dataset, variable) + else: + hist_per_config += self.load_histogram(inputs, config_inst, dataset, variable) + + # slice histogram per config according to the sub_processes and categories + hist_per_config = self.slice_histogram( + histogram=hist_per_config, + config_inst=config_inst, + processes=sub_processes, + categories=self.categories, + shifts=shifts, + ) if variable in hists.keys(): - hists[variable] += h_in + hists[variable] += hist_per_config else: - hists[variable] = h_in + hists[variable] = hist_per_config if self.debugger: from IPython import embed @@ -57,18 +82,42 @@ def output(self): return {"always_incomplete_dummy": self.target("dummy.txt")} def run(self): + """ + Loads histograms for all configs, variables, and datasets, + sums them up for each variable and + slices them according to the processes, categories, and shift, + The resulting histograms are stored in a dictionary with variable names as keys. + If `debugger` is set to True, an IPython debugger session is started for + interactive inspection of the histograms. + """ + inputs = self.input() shifts = ["nominal"] + self.shifts hists = {} - for dataset in self.datasets: - for variable in self.variables: - h_in = self.load_histogram(dataset, variable) - h_in = self.slice_histogram(h_in, self.processes, self.categories, shifts) + for variable in self.variables: + for i, config_inst in enumerate(self.config_insts): + hist_per_config = None + sub_processes = self.processes[i] + for dataset in self.datasets[i]: + # sum over all histograms of the same variable and config + if hist_per_config is None: + hist_per_config = self.load_histogram(inputs, config_inst, dataset, variable) + else: + hist_per_config += self.load_histogram(inputs, config_inst, dataset, variable) + + # slice histogram per config according to the sub_processes and categories + hist_per_config = self.slice_histogram( + histogram=hist_per_config, + config_inst=config_inst, + processes=sub_processes, + categories=self.categories, + shifts=shifts, + ) if variable in hists.keys(): - hists[variable] += h_in + hists[variable] += hist_per_config else: - hists[variable] = h_in + hists[variable] = hist_per_config if self.debugger: from IPython import embed diff --git a/columnflow/tasks/production.py b/columnflow/tasks/production.py index 842dfbdcd..bb5f8d090 100644 --- a/columnflow/tasks/production.py +++ b/columnflow/tasks/production.py @@ -50,9 +50,7 @@ def workflow_requires(self): reqs["events"] = self.reqs.ProvideReducedEvents.req(self) # add producer dependent requirements - reqs["producer"] = law.util.make_unique(law.util.flatten( - self.producer_inst.run_requires(task=self), - )) + reqs["producer"] = law.util.make_unique(law.util.flatten(self.producer_inst.run_requires(task=self))) return reqs diff --git a/columnflow/tasks/reduction.py b/columnflow/tasks/reduction.py index 5deef6bb8..1d62a5d71 100644 --- a/columnflow/tasks/reduction.py +++ b/columnflow/tasks/reduction.py @@ -74,14 +74,13 @@ def workflow_requires(self): if calibrator_inst.produced_columns ] reqs["selection"] = self.reqs.SelectEvents.req(self) - # reducer dependent requirements - reqs["reducer"] = law.util.make_unique(law.util.flatten( - self.reducer_inst.run_requires(task=self), - )) else: # pass-through pilot workflow requirements of upstream task t = self.reqs.SelectEvents.req(self) - reqs = law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) + law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) + + # add reducer dependent requirements + reqs["reducer"] = law.util.make_unique(law.util.flatten(self.reducer_inst.run_requires(task=self))) return reqs @@ -219,8 +218,8 @@ def run(self): events = self.reducer_inst(events, selection=sel, task=self) n_reduced += len(events) - # no need to proceed when no events are left - if len(events) == 0: + # no need to proceed when no events are left (except for the last chunk to create empty output) + if len(events) == 0 and (output_chunks or pos.index < pos.n_chunks - 1): continue # remove columns diff --git a/columnflow/tasks/selection.py b/columnflow/tasks/selection.py index cc5d9e7d7..cc2b32fae 100644 --- a/columnflow/tasks/selection.py +++ b/columnflow/tasks/selection.py @@ -90,12 +90,10 @@ def workflow_requires(self): elif self.calibrator_insts: # pass-through pilot workflow requirements of upstream task t = self.reqs.CalibrateEvents.req(self) - reqs = law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) + law.util.merge_dicts(reqs, t.workflow_requires(), inplace=True) # add selector dependent requirements - reqs["selector"] = law.util.make_unique(law.util.flatten( - self.selector_inst.run_requires(task=self), - )) + reqs["selector"] = law.util.make_unique(law.util.flatten(self.selector_inst.run_requires(task=self))) return reqs diff --git a/columnflow/tasks/union.py b/columnflow/tasks/union.py index 933c7dd53..e73784175 100644 --- a/columnflow/tasks/union.py +++ b/columnflow/tasks/union.py @@ -37,6 +37,12 @@ class UniteColumns(_UniteColumns): sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + keep_columns_key = luigi.Parameter( + default=law.NO_STR, + description="if the 'keep_columns' auxiliary config entry for the task family 'cf.UniteColumns' is defined as " + "a dictionary, this key can selects which of the entries of columns to use; uses all columns when empty; " + "default: empty", + ) file_type = luigi.ChoiceParameter( default="parquet", choices=("parquet", "root"), @@ -104,7 +110,8 @@ def requires(self): @workflow_condition.output def output(self): - return {"events": self.target(f"data_{self.branch}.{self.file_type}")} + key_postfix = "" if self.keep_columns_key in {law.NO_STR, "", None} else f"_{self.keep_columns_key}" + return {"events": self.target(f"data_{self.branch}{key_postfix}.{self.file_type}")} @law.decorator.notify @law.decorator.log @@ -127,7 +134,18 @@ def run(self): # define columns that will be written write_columns: set[Route] = set() skip_columns: set[Route] = set() - for c in self.config_inst.x.keep_columns.get(self.task_family, ["*"]): + keep_struct = self.config_inst.x.keep_columns.get(self.task_family, ["*"]) + if isinstance(keep_struct, dict): + if self.keep_columns_key not in {law.NO_STR, "", None}: + if self.keep_columns_key not in keep_struct: + raise KeyError( + f"keep_columns_key '{self.keep_columns_key}' not found in keep_columns config entry for " + f"task family '{self.task_family}', existing keys: {list(keep_struct.keys())}", + ) + keep_struct = keep_struct[self.keep_columns_key] + else: + keep_struct = law.util.flatten(keep_struct.values()) + for c in law.util.make_unique(keep_struct): for r in self._expand_keep_column(c): if r.has_tag("skip"): skip_columns.add(r) diff --git a/law.cfg b/law.cfg index aaefc05fb..2ba8f2ba3 100644 --- a/law.cfg +++ b/law.cfg @@ -275,4 +275,4 @@ wait_interval: 20 check_unfulfilled_deps: False cache_task_completion: True keep_alive: $CF_WORKER_KEEP_ALIVE -force_multiprocessing: False +force_multiprocessing: $CF_REMOTE_ENV diff --git a/modules/law b/modules/law index 3adec62db..44b98b7dc 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit 3adec62db42d1fe8021c792538fe66ee1ed77b91 +Subproject commit 44b98b7dcd434badd003fd498eaf399e14c3ee53 diff --git a/setup.sh b/setup.sh index 059a51433..bdcf9337b 100644 --- a/setup.sh +++ b/setup.sh @@ -663,6 +663,7 @@ EOF echo cf_color cyan "setting up conda / micromamba environment" micromamba install \ + gcc \ libgcc \ bash \ zsh \ From db65c2c4a93230108790ce80886f03e264529f47 Mon Sep 17 00:00:00 2001 From: Jules Vandenbroeck Date: Thu, 13 Nov 2025 14:39:05 +0100 Subject: [PATCH 121/123] updated law module --- modules/law | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/law b/modules/law index 44b98b7dc..73ff4fd52 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit 44b98b7dcd434badd003fd498eaf399e14c3ee53 +Subproject commit 73ff4fd520ddecff5288ee3804aa4b4c8d929858 From 8cfa9a58d0eda9a2c09c7fba3f053dc900e0ed52 Mon Sep 17 00:00:00 2001 From: Jules Vandenbroeck Date: Thu, 20 Nov 2025 11:44:34 +0100 Subject: [PATCH 122/123] missing import of hist --- columnflow/plotting/plot_all.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/columnflow/plotting/plot_all.py b/columnflow/plotting/plot_all.py index ef93d9566..7f8d67ad6 100644 --- a/columnflow/plotting/plot_all.py +++ b/columnflow/plotting/plot_all.py @@ -34,6 +34,8 @@ def draw_stat_error_bands( norm: float | Sequence | np.ndarray = 1.0, **kwargs, ) -> None: + import hist + assert len(h.axes) == 1 # compute relative statistical errors From 776f441e5057a3697435850229a7da728a3232f8 Mon Sep 17 00:00:00 2001 From: Jules Vandenbroeck Date: Thu, 20 Nov 2025 11:45:28 +0100 Subject: [PATCH 123/123] fix pt overflow slowing down b-tagging scale factors --- columnflow/production/cmsGhent/btag_weights.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/columnflow/production/cmsGhent/btag_weights.py b/columnflow/production/cmsGhent/btag_weights.py index eaf1879a7..bb5112023 100644 --- a/columnflow/production/cmsGhent/btag_weights.py +++ b/columnflow/production/cmsGhent/btag_weights.py @@ -140,6 +140,9 @@ def fixed_wp_btag_weights( # get the total number of jets in the chunk jets = events.Jet[jet_mask] if jet_mask is not None else events.Jet jets = set_ak_column(jets, "abseta", abs(jets.eta)) + # currently set hard max on pt for efficiency since overflow could not be changed in correctionlib + # (could also manually change the flow) + jets = set_ak_column(jets, "minpt", ak.where(jets.pt <= 999, jets.pt, 999)) # helper to create and store the weight def add_weight(flavour_group, systematic, variation=None): @@ -185,9 +188,7 @@ def sf_eff_wp(working_point, none_value=0.): ) eff = self.btag_eff_corrector( flat_input.hadronFlavour, - # currently set hard max on pt since overflow could not be changed in correctionlib - # (could also manually change the flow) - ak.min([flat_input.pt, 999 * ak.ones_like(flat_input.pt)], axis=0), + flat_input.minpt, flat_input.abseta, working_point, )