diff --git a/bin/cf_sandbox_file_hash b/bin/cf_sandbox_file_hash index 18f846c35..bf3ae5387 100755 --- a/bin/cf_sandbox_file_hash +++ b/bin/cf_sandbox_file_hash @@ -11,6 +11,6 @@ action() { setopt globdots fi - python "${this_dir}/$( basename "${this_file}" ).py" "$@" + python3 "${this_dir}/$( basename "${this_file}" ).py" "$@" } action "$@" diff --git a/columnflow/calibration/cms/jets.py b/columnflow/calibration/cms/jets.py index bd910264b..20e600fa3 100644 --- a/columnflow/calibration/cms/jets.py +++ b/columnflow/calibration/cms/jets.py @@ -1,1085 +1,1109 @@ -# coding: utf-8 - -""" -Jet energy corrections and jet resolution smearing. -""" - -import functools - -import law - -from columnflow.types import Any -from columnflow.calibration import Calibrator, calibrator -from columnflow.calibration.util import ak_random, propagate_met -from columnflow.production.util import attach_coffea_behavior -from columnflow.util import maybe_import, InsertableDict, DotDict -from columnflow.columnar_util import set_ak_column, layout_ak_array, optional_column as optional - -np = maybe_import("numpy") -ak = maybe_import("awkward") -correctionlib = maybe_import("correctionlib") - -logger = law.logger.get_logger(__name__) - - -# -# helper functions -# - -set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) - - -def get_evaluators( - correction_set: correctionlib.highlevel.CorrectionSet, - names: list[str], -) -> list[Any]: - """ - Helper function to get a list of correction evaluators from a - :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` object given - a list of *names*. The *names* can refer to either simple or compound - corrections. - - :param correction_set: evaluator provided by :external+correctionlib:doc:`index` - :param names: List of names of corrections to be applied - :raises RuntimeError: If a requested correction in *names* is not available - :return: List of compounded corrections, see - :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` - """ - # raise nice error if keys not found - available_keys = set(correction_set.keys()).union(correction_set.compound.keys()) - missing_keys = set(names) - available_keys - if missing_keys: - raise RuntimeError("corrections not found:" + "".join( - f"\n - {name}" for name in names if name in missing_keys - ) + "\navailable:" + "".join( - f"\n - {name}" for name in sorted(available_keys) - )) - - # retrieve the evaluators - return [ - correction_set.compound[name] - if name in correction_set.compound - else correction_set[name] - for name in names - ] - - -def ak_evaluate(evaluator: correctionlib.highlevel.Correction, *args) -> float: - """ - Evaluate a :external+correctionlib:py:class:`correctionlib.highlevel.Correction` - using one or more :external+ak:py:class:`awkward arrays ` as inputs. 
- - :param evaluator: Evaluator instance - :raises ValueError: If no :external+ak:py:class:`awkward arrays ` are provided - :return: The correction factor derived from the input arrays - """ - # fail if no arguments - if not args: - raise ValueError("Expected at least one argument.") - - # collect arguments that are awkward arrays - ak_args = [ - arg for arg in args if isinstance(arg, ak.Array) - ] - - # broadcast akward arrays together and flatten - if ak_args: - bc_args = ak.broadcast_arrays(*ak_args) - flat_args = ( - np.asarray(ak.flatten(bc_arg, axis=None)) - for bc_arg in bc_args - ) - output_layout_array = bc_args[0] - else: - flat_args = iter(()) - output_layout_array = None - - # multiplex flattened and non-awkward inputs - all_flat_args = [ - next(flat_args) if isinstance(arg, ak.Array) else arg - for arg in args - ] - - # apply evaluator to flattened/multiplexed inputs - result = evaluator.evaluate(*all_flat_args) - - # apply broadcasted layout to result - if output_layout_array is not None: - result = layout_ak_array(result, output_layout_array) - - return result - - -# -# jet energy corrections -# - -# define default functions for jec calibrator -def get_jerc_file_default(self: Calibrator, external_files: DotDict) -> str: - """ - Function to obtain external correction files for JEC and/or JER. - - By default, this function extracts the location of the jec correction - files from the current config instance *config_inst*. The key of the - external file depends on the jet collection. For ``Jet`` (AK4 jets), this - resolves to ``jet_jerc``, and for ``FatJet`` it is resolved to - ``fat_jet_jerc``. - - .. code-block:: python - - cfg.x.external_files = DotDict.wrap({ - "jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/jet_jerc.json.gz", - "fat_jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/fatJet_jerc.json.gz", - }) - - :param external_files: Dictionary containing the information about the file location - :return: path or url to correction file(s) - """ # noqa - - # get config - try_attrs = ("get_jec_config", "get_jer_config") - jerc_config = None - for try_attr in try_attrs: - try: - jerc_config = getattr(self, try_attr)() - except AttributeError: - continue - else: - break - - # fail if not found - if jerc_config is None: - raise ValueError( - "could not retrieve jer/jec config, none of the following methods " - f"were found: {try_attrs}", - ) - - # first check config for user-supplied `external_file_key` - ext_file_key = jerc_config.get("external_file_key", None) - if ext_file_key is not None: - return external_files[ext_file_key] - - # if not found, try to resolve from jet collection name and fail if not standard NanoAOD - if self.jet_name not in get_jerc_file_default.map_jet_name_file_key: - available_keys = ", ".join(sorted(get_jerc_file_default.map_jet_name_file_key)) - raise ValueError( - f"could not determine external file key for jet collection '{self.jet_name}', " - f"name is not one of standard NanoAOD jet collections: {available_keys}", - ) - - # return external file - ext_file_key = get_jerc_file_default.map_jet_name_file_key[self.jet_name] - return external_files[ext_file_key] - - -# default external file keys for known jet collections -get_jerc_file_default.map_jet_name_file_key = { - "Jet": "jet_jerc", - "FatJet": "fat_jet_jerc", -} - - -def get_jec_config_default(self: Calibrator) -> DotDict: - """ - Load config relevant to the jet energy corrections (JEC). 
- - By default, this is extracted from the current *config_inst*, - assuming the JEC configurations are stored under the 'jec' - aux key. Separate configurations should be specified for each - jet collection, using the collection name as a key. For example, - the configuration for the default jet collection ``Jet`` will - be retrieved from the following config entry: - - .. code-block:: python - - self.config_inst.x.jec.Jet - - Used in :py:meth:`~.jec.setup_func`. - - :return: Dictionary containing configuration for jet energy calibration - """ - jec_cfg = self.config_inst.x.jec - - # check for old-style config - if self.jet_name not in jec_cfg: - # if jet collection is `Jet`, issue deprecation warning - if self.jet_name == "Jet": - logger.warning_once( - f"{id(self)}_depr_jec_config", - "config aux 'jec' does not contain key for input jet " - f"collection '{self.jet_name}'. This may be due to " - "an outdated config. Continuing under the assumption that " - "the entire 'jec' entry refers to this jet collection. " - "This assumption will be removed in future versions of " - "columnflow, so please adapt the config according to the " - "documentation to remove this warning and ensure future " - "compatibility of the code.", - ) - return jec_cfg - - # otherwise raise exception - raise ValueError( - "config aux 'jec' does not contain key for input jet " - f"collection '{self.jet_name}'.", - ) - - return jec_cfg[self.jet_name] - - -@calibrator( - uses={ - optional("fixedGridRhoFastjetAll"), - optional("Rho.fixedGridRhoFastjetAll"), - attach_coffea_behavior, - }, - # name of the jet collection to calibrate - jet_name="Jet", - # name of the associated MET collection - met_name="MET", - # name of the associated Raw MET collection - raw_met_name="RawMET", - # custom uncertainty sources, defaults to config when empty - uncertainty_sources=None, - # toggle for propagation to MET - propagate_met=True, - # function to determine the correction file - get_jec_file=get_jerc_file_default, - # function to determine the jec configuration dict - get_jec_config=get_jec_config_default, -) -def jec( - self: Calibrator, - events: ak.Array, - min_pt_met_prop: float = 15.0, - max_eta_met_prop: float = 5.2, - **kwargs, -) -> ak.Array: - """Performs the jet energy corrections (JECs) and uncertainty shifts using the - :external+correctionlib:doc:`index`, optionally - propagating the changes to the MET. - - The *jet_name* should be set to the name of the NanoAOD jet collection to calibrate - (default: ``Jet``, i.e. AK4 jets). - - Requires an external file in the config pointing to the JSON files containing the JECs. - The file key can be specified via an optional ``external_file_key`` in the ``jec`` config entry. - If not given, the file key will be determined automatically based on the jet collection name: - ``jet_jerc`` for ``Jet`` (AK4 jets), ``fat_jet_jerc`` for``FatJet`` (AK8 jets). A full set of JSON files - can be specified as: - - .. code-block:: python - - cfg.x.external_files = DotDict.wrap({ - "jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/jet_jerc.json.gz", - "fat_jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/fatJet_jerc.json.gz", - }) - - For more file-grained control, the *get_jec_file* can be adapted in a subclass in case it is stored - differently in the external files - - The JEC configuration should be an auxiliary entry in the config, specifying the correction - details under "jec". 
Separate configs should be given for each jet collection to calibrate, - using the jet collection name as a subkey. An example of a valid configuration for correction - AK4 jets with JEC is: - - .. code-block:: python - - cfg.x.jec = { - "Jet": { - "campaign": "Summer19UL17", - "version": "V5", - "jet_type": "AK4PFchs", - "levels": ["L1L2L3Res"], # or individual correction levels - "levels_for_type1_met": ["L1FastJet"], - "uncertainty_sources": [ - "Total", - "CorrelationGroupMPFInSitu", - "CorrelationGroupIntercalibration", - "CorrelationGroupbJES", - "CorrelationGroupFlavor", - "CorrelationGroupUncorrelated", - ] - }, - } - - *get_jec_config* can be adapted in a subclass in case it is stored differently in the config. - - If running on data, the datasets must have an auxiliary field *jec_era* defined, e.g. "RunF", - or an auxiliary field *era*, e.g. "F". - - This instance of :py:class:`~columnflow.calibration.Calibrator` is - initialized with the following parameters by default: - - :param events: awkward array containing events to process - - :param min_pt_met_prop: If *propagate_met* variable is ``True`` propagate the updated jet values - to the missing transverse energy (MET) using - :py:func:`~columnflow.calibration.util.propagate_met` for events where - ``met.pt > *min_pt_met_prop*``. - :param max_eta_met_prop: If *propagate_met* variable is ``True`` propagate the updated jet - values to the missing transverse energy (MET) using - :py:func:`~columnflow.calibration.util.propagate_met` for events where - ``met.eta > *min_eta_met_prop*``. - """ # noqa - # use local variable for convenience - jet_name = self.jet_name - - # calculate uncorrected pt, mass - events = set_ak_column_f32(events, f"{jet_name}.pt_raw", events[jet_name].pt * (1 - events[jet_name].rawFactor)) - events = set_ak_column_f32(events, f"{jet_name}.mass_raw", events[jet_name].mass * (1 - events[jet_name].rawFactor)) - - def correct_jets(*, pt, eta, phi, area, rho, evaluator_key="jec"): - # variable naming convention - variable_map = { - "JetA": area, - "JetEta": eta, - "JetPt": pt, - "JetPhi": phi, - "Rho": ak.values_astype(rho, np.float32), - } - - # apply all correctors sequentially, updating the pt each time - full_correction = ak.ones_like(pt, dtype=np.float32) - for corrector in self.evaluators[evaluator_key]: - # determine correct inputs (change depending on corrector) - inputs = [ - variable_map[inp.name] - for inp in corrector.inputs - ] - correction = ak_evaluate(corrector, *inputs) - # update pt for subsequent correctors - variable_map["JetPt"] = variable_map["JetPt"] * correction - full_correction = full_correction * correction - - return full_correction - - # obtain rho, which might be located at different routes, depending on the nano version - rho = ( - events.fixedGridRhoFastjetAll - if "fixedGridRhoFastjetAll" in events.fields - else events.Rho.fixedGridRhoFastjetAll - ) - - # correct jets with only a subset of correction levels - # (for calculating TypeI MET correction) - if self.propagate_met: - # get correction factors - jec_factors_subset_type1_met = correct_jets( - pt=events[jet_name].pt_raw, - eta=events[jet_name].eta, - phi=events[jet_name].phi, - area=events[jet_name].area, - rho=rho, - evaluator_key="jec_subset_type1_met", - ) - - # temporarily apply the new factors with only subset of corrections - events = set_ak_column_f32(events, f"{jet_name}.pt", events[jet_name].pt_raw * jec_factors_subset_type1_met) - events = set_ak_column_f32(events, f"{jet_name}.mass", events[jet_name].mass_raw * 
jec_factors_subset_type1_met) - events = self[attach_coffea_behavior](events, collections=[jet_name], **kwargs) - - # store pt and phi of the full jet system for MET propagation, including a selection in raw info - # see https://twiki.cern.ch/twiki/bin/view/CMS/JECAnalysesRecommendations?rev=19#Minimum_jet_selection_cuts - met_prop_mask = (events[jet_name].pt_raw > min_pt_met_prop) & (abs(events[jet_name].eta) < max_eta_met_prop) - jetsum = events[jet_name][met_prop_mask].sum(axis=1) - jetsum_pt_subset_type1_met = jetsum.pt - jetsum_phi_subset_type1_met = jetsum.phi - - # factors for full jet correction with all levels - jec_factors = correct_jets( - pt=events[jet_name].pt_raw, - eta=events[jet_name].eta, - phi=events[jet_name].phi, - area=events[jet_name].area, - rho=rho, - evaluator_key="jec", - ) - - # apply full jet correction - events = set_ak_column_f32(events, f"{jet_name}.pt", events[jet_name].pt_raw * jec_factors) - events = set_ak_column_f32(events, f"{jet_name}.mass", events[jet_name].mass_raw * jec_factors) - rawFactor = ak.nan_to_num(1 - events[jet_name].pt_raw / events[jet_name].pt, nan=0.0) - events = set_ak_column_f32(events, f"{jet_name}.rawFactor", rawFactor) - events = self[attach_coffea_behavior](events, collections=[jet_name], **kwargs) - - # nominal met propagation - if self.propagate_met: - # get pt and phi of all jets after correcting - jetsum = events[jet_name][met_prop_mask].sum(axis=1) - jetsum_pt_all_levels = jetsum.pt - jetsum_phi_all_levels = jetsum.phi - # propagate changes to MET, starting from jets corrected with subset of JEC levels - # (recommendation is to propagate only L2 corrections and onwards) - met_pt, met_phi = propagate_met( - jetsum_pt_subset_type1_met, - jetsum_phi_subset_type1_met, - jetsum_pt_all_levels, - jetsum_phi_all_levels, - events[self.raw_met_name].pt, - events[self.raw_met_name].phi, - ) - events = set_ak_column_f32(events, f"{self.met_name}.pt", met_pt) - events = set_ak_column_f32(events, f"{self.met_name}.phi", met_phi) - - # variable naming conventions - variable_map = { - "JetEta": events[jet_name].eta, - "JetPt": events[jet_name].pt_raw, - } - - # jet energy uncertainty components - for name, evaluator in self.evaluators["junc"].items(): - # get uncertainty - inputs = [variable_map[inp.name] for inp in evaluator.inputs] - jec_uncertainty = ak_evaluate(evaluator, *inputs) - - # apply jet uncertainty shifts - events = set_ak_column_f32( - events, f"{jet_name}.pt_jec_{name}_up", events[jet_name].pt * (1.0 + jec_uncertainty), - ) - events = set_ak_column_f32( - events, f"{jet_name}.pt_jec_{name}_down", events[jet_name].pt * (1.0 - jec_uncertainty), - ) - events = set_ak_column_f32( - events, f"{jet_name}.mass_jec_{name}_up", events[jet_name].mass * (1.0 + jec_uncertainty), - ) - events = set_ak_column_f32( - events, f"{jet_name}.mass_jec_{name}_down", events[jet_name].mass * (1.0 - jec_uncertainty), - ) - - # propagate shifts to MET - if self.propagate_met: - jet_pt_up = events[jet_name][met_prop_mask][f"pt_jec_{name}_up"] - jet_pt_down = events[jet_name][met_prop_mask][f"pt_jec_{name}_down"] - met_pt_up, met_phi_up = propagate_met( - jetsum_pt_all_levels, - jetsum_phi_all_levels, - jet_pt_up, - events[jet_name][met_prop_mask].phi, - met_pt, - met_phi, - ) - met_pt_down, met_phi_down = propagate_met( - jetsum_pt_all_levels, - jetsum_phi_all_levels, - jet_pt_down, - events[jet_name][met_prop_mask].phi, - met_pt, - met_phi, - ) - events = set_ak_column_f32(events, f"{self.met_name}.pt_jec_{name}_up", met_pt_up) - events = 
set_ak_column_f32(events, f"{self.met_name}.pt_jec_{name}_down", met_pt_down) - events = set_ak_column_f32(events, f"{self.met_name}.phi_jec_{name}_up", met_phi_up) - events = set_ak_column_f32(events, f"{self.met_name}.phi_jec_{name}_down", met_phi_down) - - return events - - -@jec.init -def jec_init(self: Calibrator) -> None: - jec_cfg = self.get_jec_config() - - sources = self.uncertainty_sources - if sources is None: - sources = jec_cfg.uncertainty_sources - - # register used jet columns - self.uses.add(f"{self.jet_name}.{{pt,eta,phi,mass,area,rawFactor}}") - - # register produced jet columns - self.produces.add(f"{self.jet_name}.{{pt,mass,rawFactor}}") - - # add shifted jet variables - self.produces |= { - f"{self.jet_name}.{shifted_var}_jec_{junc_name}_{junc_dir}" - for shifted_var in ("pt", "mass") - for junc_name in sources - for junc_dir in ("up", "down") - } - - # add MET variables - if self.propagate_met: - self.uses.add(f"{self.raw_met_name}.{{pt,phi}}") - self.produces.add(f"{self.met_name}.{{pt,phi}}") - - # add shifted MET variables - self.produces |= { - f"{self.met_name}.{shifted_var}_jec_{junc_name}_{junc_dir}" - for shifted_var in ("pt", "phi") - for junc_name in sources - for junc_dir in ("up", "down") - } - - -@jec.requires -def jec_requires(self: Calibrator, reqs: dict) -> None: - if "external_files" in reqs: - return - - from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) - - -@jec.setup -def jec_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: - """ - Load the correct jec files using the :py:func:`from_string` method of the - :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` - function and apply the corrections as needed. - - The source files for the :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` - instance are extracted with the :py:meth:`~.jec.get_jec_file`. - - Uses the member function :py:meth:`~.jec.get_jec_config` to construct the - required keys, which are based on the following information about the JEC: - - - levels - - campaign - - version - - jet_type - - A corresponding example snippet wihtin the *config_inst* could like something - like this: - - .. 
code-block:: python - - cfg.x.jec = DotDict.wrap({ - "Jet": { - # campaign name for this JEC correctiono - "campaign": f"Summer19UL{year2}{jerc_postfix}", - # version of the corrections - "version": "V7", - # Type of jets that the corrections should be applied on - "jet_type": "AK4PFchs", - # relevant levels in the derivation process of the JEC - "levels": ["L1FastJet", "L2Relative", "L2L3Residual", "L3Absolute"], - # relevant levels in the derivation process of the Type 1 MET JEC - "levels_for_type1_met": ["L1FastJet"], - # names of the uncertainties to be applied - "uncertainty_sources": [ - "Total", - "CorrelationGroupMPFInSitu", - "CorrelationGroupIntercalibration", - "CorrelationGroupbJES", - "CorrelationGroupFlavor", - "CorrelationGroupUncorrelated", - ], - }, - }) - - :param reqs: Requirement dictionary for this - :py:class:`~columnflow.calibration.Calibrator` instance - :param inputs: Additional inputs, currently not used - :param reader_targets: TODO: add documentation - """ - bundle = reqs["external_files"] - - # import the correction sets from the external file - import correctionlib - correction_set = correctionlib.CorrectionSet.from_string( - self.get_jec_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) - - # compute JEC keys from config information - jec_cfg = self.get_jec_config() - - def make_jme_keys(names, jec=jec_cfg, is_data=self.dataset_inst.is_data): - if is_data: - jec_era = self.dataset_inst.get_aux("jec_era", None) - # if no special JEC era is specified, infer based on 'era' - if jec_era is None: - jec_era = "Run" + self.dataset_inst.get_aux("era") - - return [ - f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{name}_{jec.jet_type}" - if is_data else - f"{jec.campaign}_{jec.version}_MC_{name}_{jec.jet_type}" - for name in names - ] - - # take sources from constructor or config - sources = self.uncertainty_sources - if sources is None: - sources = jec_cfg.uncertainty_sources - - jec_keys = make_jme_keys(jec_cfg.levels) - jec_keys_subset_type1_met = make_jme_keys(jec_cfg.levels_for_type1_met) - junc_keys = make_jme_keys(sources, is_data=False) # uncertainties only stored as MC keys - - # store the evaluators - self.evaluators = { - "jec": get_evaluators(correction_set, jec_keys), - "jec_subset_type1_met": get_evaluators(correction_set, jec_keys_subset_type1_met), - "junc": dict(zip(sources, get_evaluators(correction_set, junc_keys))), - } - - -# custom jec calibrator that only runs nominal correction -jec_nominal = jec.derive("jec_nominal", cls_dict={"uncertainty_sources": []}) - -# explicit calibrators for standard jet collections -jec_ak4 = jec.derive("jec_ak4", cls_dict={"jet_name": "Jet"}) -jec_ak8 = jec.derive("jec_ak8", cls_dict={"jet_name": "FatJet", "propagate_met": False}) -jec_ak4_nominal = jec_ak4.derive("jec_ak4", cls_dict={"uncertainty_sources": []}) -jec_ak8_nominal = jec_ak8.derive("jec_ak8", cls_dict={"uncertainty_sources": []}) - - -def get_jer_config_default(self: Calibrator) -> DotDict: - """ - Load config relevant to the jet energy resolution (JER) smearing. - - By default, this is extracted from the current *config_inst*, - assuming the JER configurations are stored under the 'jer' - aux key. Separate configurations should be specified for each - jet collection, using the collection name as a key. For example, - the configuration for the default jet collection ``Jet`` will - be retrieved from the following config entry: - - .. code-block:: python - - self.config_inst.x.jer.Jet - - Used in :py:meth:`~.jer.setup_func`. 
- - :return: Dictionary containing configuration for JER smearing - """ - jer_cfg = self.config_inst.x.jer - - # check for old-style config - if self.jet_name not in jer_cfg: - # if jet collection is `Jet`, issue deprecation warning - if self.jet_name == "Jet": - logger.warning_once( - f"{id(self)}_depr_jer_config", - "config aux 'jer' does not contain key for input jet " - f"collection '{self.jet_name}'. This may be due to " - "an outdated config. Continuing under the assumption that " - "the entire 'jer' entry refers to this jet collection. " - "This assumption will be removed in future versions of " - "columnflow, so please adapt the config according to the " - "documentation to remove this warning and ensure future " - "compatibility of the code.", - ) - return jer_cfg - - # otherwise raise exception - raise ValueError( - "config aux 'jer' does not contain key for input jet " - f"collection '{self.jet_name}'.", - ) - - return jer_cfg[self.jet_name] - - -# -# jet energy resolution smearing -# - -@calibrator( - uses={ - optional("Rho.fixedGridRhoFastjetAll"), - optional("fixedGridRhoFastjetAll"), - attach_coffea_behavior, - }, - # name of the jet collection to smear - jet_name="Jet", - # name of the associated gen jet collection - gen_jet_name="GenJet", - # name of the associated MET collection - met_name="MET", - # toggle for propagation to MET - propagate_met=True, - # only run on mc - mc_only=True, - # use deterministic seeds for random smearing and - # take the "index"-th random number per seed when not -1 - deterministic_seed_index=-1, - # function to determine the correction file - get_jer_file=get_jerc_file_default, - # function to determine the jer configuration dict - get_jer_config=get_jer_config_default, -) -def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: - """ - Applies the jet energy resolution smearing in MC and calculates the associated uncertainty - shifts using the :external+correctionlib:doc:`index`, following the recommendations given in - https://twiki.cern.ch/twiki/bin/viewauth/CMS/JetResolution. - - The *jet_name* and *gen_jet_name* should be set to the name of the NanoAOD jet and gen jet - collections to use as an input for JER smearing (default: ``Jet`` and ``GenJet``, respectively, - i.e. AK4 jets). - - Requires an external file in the config pointing to the JSON files containing the JER information. - The file key can be specified via an optional ``external_file_key`` in the ``jer`` config entry. - If not given, the file key will be determined automatically based on the jet collection name: - ``jet_jerc`` for ``Jet`` (AK4 jets), ``fat_jet_jerc`` for``FatJet`` (AK8 jets). A full set of JSON files - can be specified as: - - .. code-block:: python - - cfg.x.external_files = DotDict.wrap({ - "jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/jet_jerc.json.gz", - "fat_jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/fatJet_jerc.json.gz", - }) - - For more fine-grained control, the *get_jer_file* can be adapted in a subclass in case it is stored - differently in the external files. - - The JER smearing configuration should be an auxiliary entry in the config, specifying the input - JER to use under "jer". Separate configs should be given for each jet collection to smear, using - the jet collection name as a subkey. An example of a valid configuration for smearing - AK4 jets with JER is: - - .. 
code-block:: python - - cfg.x.jer = { - "Jet": { - "campaign": "Summer19UL17", - "version": "JRV2", - "jet_type": "AK4PFchs", - }, - } - - *get_jer_config* can be adapted in a subclass in case it is stored differently in the config. - - Throws an error if running on data. - - :param events: awkward array containing events to process - """ # noqa - # use local variables for convenience - jet_name = self.jet_name - gen_jet_name = self.gen_jet_name - - # fail when running on data - if self.dataset_inst.is_data: - raise ValueError("attempt to apply jet energy resolution smearing in data") - - # save the unsmeared properties in case they are needed later - events = set_ak_column_f32(events, f"{jet_name}.pt_unsmeared", events[jet_name].pt) - events = set_ak_column_f32(events, f"{jet_name}.mass_unsmeared", events[jet_name].mass) - - # obtain rho, which might be located at different routes, depending on the nano version - rho = ( - events.fixedGridRhoFastjetAll - if "fixedGridRhoFastjetAll" in events.fields else - events.Rho.fixedGridRhoFastjetAll - ) - - # variable naming convention - variable_map = { - "JetEta": events[jet_name].eta, - "JetPt": events[jet_name].pt, - "Rho": rho, - } - - # pt resolution - inputs = [variable_map[inp.name] for inp in self.evaluators["jer"].inputs] - jer = ak_evaluate(self.evaluators["jer"], *inputs) - - # JER scale factors and systematic variations - jersf = {} - for syst in ("nom", "up", "down"): - variable_map_syst = dict(variable_map, systematic=syst) - inputs = [variable_map_syst[inp.name] for inp in self.evaluators["sf"].inputs] - jersf[syst] = ak_evaluate(self.evaluators["sf"], *inputs) - - # array with all JER scale factor variations as an additional axis - # (note: axis needs to be regular for broadcasting to work correctly) - jersf = ak.concatenate( - [jersf[syst][..., None] for syst in ("nom", "up", "down")], - axis=-1, - ) - - # -- stochastic smearing - # normally distributed random numbers according to JER - jer_random_normal = ( - ak_random(0, jer, events[jet_name].deterministic_seed, rand_func=self.deterministic_normal) - if self.deterministic_seed_index >= 0 - else ak_random(0, jer, rand_func=np.random.Generator( - np.random.SFC64(events.event.to_list())).normal, - ) - ) - - # scale random numbers according to JER SF - jersf2_m1 = jersf ** 2 - 1 - add_smear = np.sqrt(ak.where(jersf2_m1 < 0, 0, jersf2_m1)) - - # broadcast over JER SF variations - jer_random_normal, jersf_z = ak.broadcast_arrays(jer_random_normal, add_smear) - - # compute smearing factors (stochastic method) - smear_factors_stochastic = 1.0 + jer_random_normal * add_smear - - # -- scaling method (using gen match) - - # mask negative gen jet indices (= no gen match) - gen_jet_idx = events[jet_name][self.gen_jet_idx_column] - valid_gen_jet_idxs = ak.mask(gen_jet_idx, gen_jet_idx >= 0) - - # pad list of gen jets to prevent index error on match lookup - max_gen_jet_idx = ak.max(valid_gen_jet_idxs) - padded_gen_jets = ak.pad_none( - events[gen_jet_name], - 0 if max_gen_jet_idx is None else (max_gen_jet_idx + 1), - ) - - # gen jets that match the reconstructed jets - matched_gen_jets = padded_gen_jets[valid_gen_jet_idxs] - - # compute the relative (reco - gen) pt difference - pt_relative_diff = (events[jet_name].pt - matched_gen_jets.pt) / events[jet_name].pt - - # test if matched gen jets are within 3 * resolution - is_matched_pt = np.abs(pt_relative_diff) < 3 * jer - is_matched_pt = ak.fill_none(is_matched_pt, False) # masked values = no gen match - - # (no check for Delta-R matching 
criterion; we assume this was done during - # nanoAOD production to get the `genJetIdx`) - - # broadcast over JER SF variations - pt_relative_diff, jersf = ak.broadcast_arrays(pt_relative_diff, jersf) - - # compute smearing factors (scaling method) - smear_factors_scaling = 1.0 + (jersf - 1.0) * pt_relative_diff - - # -- hybrid smearing: take smear factors from scaling if there was a match, - # otherwise take the stochastic ones - smear_factors = ak.where( - is_matched_pt[:, :, None], - smear_factors_scaling, - smear_factors_stochastic, - ) - - # ensure array is not nullable (avoid ambiguity on Arrow/Parquet conversion) - smear_factors = ak.fill_none(smear_factors, 0.0) - - # store pt and phi of the full jet system - if self.propagate_met: - jetsum = events[jet_name].sum(axis=1) - jetsum_pt_before = jetsum.pt - jetsum_phi_before = jetsum.phi - - # apply the smearing factors to the pt and mass - # (note: apply variations first since they refer to the original pt) - events = set_ak_column_f32(events, f"{jet_name}.pt_jer_up", events[jet_name].pt * smear_factors[:, :, 1]) - events = set_ak_column_f32(events, f"{jet_name}.mass_jer_up", events[jet_name].mass * smear_factors[:, :, 1]) - events = set_ak_column_f32(events, f"{jet_name}.pt_jer_down", events[jet_name].pt * smear_factors[:, :, 2]) - events = set_ak_column_f32(events, f"{jet_name}.mass_jer_down", events[jet_name].mass * smear_factors[:, :, 2]) - events = set_ak_column_f32(events, f"{jet_name}.pt", events[jet_name].pt * smear_factors[:, :, 0]) - events = set_ak_column_f32(events, f"{jet_name}.mass", events[jet_name].mass * smear_factors[:, :, 0]) - - # recover coffea behavior - events = self[attach_coffea_behavior](events, collections=[jet_name], **kwargs) - - # met propagation - if self.propagate_met: - - # save unsmeared quantities - events = set_ak_column_f32(events, f"{self.met_name}.pt_unsmeared", events[self.met_name].pt) - events = set_ak_column_f32(events, f"{self.met_name}.phi_unsmeared", events[self.met_name].phi) - - # get pt and phi of all jets after correcting - jetsum = events[jet_name].sum(axis=1) - jetsum_pt_after = jetsum.pt - jetsum_phi_after = jetsum.phi - - # propagate changes to MET - met_pt, met_phi = propagate_met( - jetsum_pt_before, - jetsum_phi_before, - jetsum_pt_after, - jetsum_phi_after, - events[self.met_name].pt, - events[self.met_name].phi, - ) - events = set_ak_column_f32(events, f"{self.met_name}.pt", met_pt) - events = set_ak_column_f32(events, f"{self.met_name}.phi", met_phi) - - # syst variations on top of corrected MET - met_pt_up, met_phi_up = propagate_met( - jetsum_pt_after, - jetsum_phi_after, - events[jet_name].pt_jer_up, - events[jet_name].phi, - met_pt, - met_phi, - ) - met_pt_down, met_phi_down = propagate_met( - jetsum_pt_after, - jetsum_phi_after, - events[jet_name].pt_jer_down, - events[jet_name].phi, - met_pt, - met_phi, - ) - events = set_ak_column_f32(events, f"{self.met_name}.pt_jer_up", met_pt_up) - events = set_ak_column_f32(events, f"{self.met_name}.pt_jer_down", met_pt_down) - events = set_ak_column_f32(events, f"{self.met_name}.phi_jer_up", met_phi_up) - events = set_ak_column_f32(events, f"{self.met_name}.phi_jer_down", met_phi_down) - - return events - - -@jer.init -def jer_init(self: Calibrator) -> None: - # determine gen-level jet index column - lower_first = lambda s: s[0].lower() + s[1:] if s else s - self.gen_jet_idx_column = lower_first(self.gen_jet_name) + "Idx" - - # register used jet columns - 
self.uses.add(f"{self.jet_name}.{{pt,eta,phi,mass,{self.gen_jet_idx_column}}}") - - # register used gen jet columns - self.uses.add(f"{self.gen_jet_name}.{{pt,eta,phi}}") - - # register produced jet columns - self.produces.add(f"{self.jet_name}.{{pt,mass}}{{,_unsmeared,_jer_up,_jer_down}}") - - # register produced MET columns - if self.propagate_met: - # register used MET columns - self.uses.add(f"{self.met_name}.{{pt,phi}}") - - # register produced MET columns - self.produces.add(f"{self.met_name}.{{pt,phi}}{{,_jer_up,_jer_down,_unsmeared}}") - - -@jer.requires -def jer_requires(self: Calibrator, reqs: dict) -> None: - if "external_files" in reqs: - return - - from columnflow.tasks.external import BundleExternalFiles - reqs["external_files"] = BundleExternalFiles.req(self.task) - - -@jer.setup -def jer_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: - """ - Load the correct jer files using the :py:func:`from_string` method of the - :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` function and apply the - corrections as needed. - - The source files for the :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` - instance are extracted with the :py:meth:`~.jer.get_jer_file`. - - Uses the member function :py:meth:`~.jer.get_jer_config` to construct the required keys, which - are based on the following information about the JER: - - - campaign - - version - - jet_type - - A corresponding example snippet within the *config_inst* could like something like this: - - .. code-block:: python - - cfg.x.jer = DotDict.wrap({ - "Jet": { - "campaign": f"Summer19UL{year2}{jerc_postfix}", - "version": "JRV3", - "jet_type": "AK4PFchs", - }, - }) - - :param reqs: Requirement dictionary for this :py:class:`~columnflow.calibration.Calibrator` - instance. - :param inputs: Additional inputs, currently not used. - :param reader_targets: TODO: add documentation. 
- """ - bundle = reqs["external_files"] - - # import the correction sets from the external file - import correctionlib - correction_set = correctionlib.CorrectionSet.from_string( - self.get_jer_file(bundle.files).load(formatter="gzip").decode("utf-8"), - ) - - # compute JER keys from config information - jer_cfg = self.get_jer_config() - jer_keys = { - "jer": f"{jer_cfg.campaign}_{jer_cfg.version}_MC_PtResolution_{jer_cfg.jet_type}", - "sf": f"{jer_cfg.campaign}_{jer_cfg.version}_MC_ScaleFactor_{jer_cfg.jet_type}", - } - - # store the evaluators - self.evaluators = { - name: get_evaluators(correction_set, [key])[0] - for name, key in jer_keys.items() - } - - # use deterministic seeds for random smearing if requested - if self.deterministic_seed_index >= 0: - idx = self.deterministic_seed_index - bit_generator = np.random.SFC64 - def deterministic_normal(loc, scale, seed): - return np.asarray([ - np.random.Generator(bit_generator(_seed)).normal(_loc, _scale, size=idx + 1)[-1] - for _loc, _scale, _seed in zip(loc, scale, seed) - ]) - self.deterministic_normal = deterministic_normal - - -# explicit calibrators for standard jet collections -jer_ak4 = jer.derive("jer_ak4", cls_dict={"jet_name": "Jet", "gen_jet_name": "GenJet"}) -jer_ak8 = jer.derive("jer_ak8", cls_dict={"jet_name": "FatJet", "gen_jet_name": "GenJetAK8", "propagate_met": False}) - - -# -# single calibrator for doing both JEC and JER smearing -# - -@calibrator( - uses={jec, jer}, - produces={jec, jer}, - # name of the jet collection to smear - jet_name="Jet", - # name of the associated gen jet collection (for JER smearing) - gen_jet_name="GenJet", - # toggle for propagation to MET - propagate_met=None, - # functions to determine configs and files - get_jec_file=None, - get_jec_config=None, - get_jer_file=None, - get_jer_config=None, -) -def jets(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: - """ - Instance of :py:class:`~columnflow.calibration.Calibrator` that does all relevant calibrations - for jets, i.e. JEC and JER. For more information, see :py:func:`~.jec` and :py:func:`~.jer`. - - :param events: awkward array containing events to process - """ - # apply jet energy corrections - events = self[jec](events, **kwargs) - - # apply jer smearing on MC only - if self.dataset_inst.is_mc: - events = self[jer](events, **kwargs) - - return events - - -@jets.init -def jets_init(self: Calibrator) -> None: - # forward argument to the producers - self.deps_kwargs[jec]["jet_name"] = self.jet_name - self.deps_kwargs[jer]["jet_name"] = self.jet_name - self.deps_kwargs[jer]["gen_jet_name"] = self.gen_jet_name - if self.propagate_met is not None: - self.deps_kwargs[jec]["propagate_met"] = self.propagate_met - self.deps_kwargs[jer]["propagate_met"] = self.propagate_met - if self.get_jec_file is not None: - self.deps_kwargs[jec]["get_jec_file"] = self.get_jec_file - if self.get_jec_config is not None: - self.deps_kwargs[jec]["get_jec_config"] = self.get_jec_config - if self.get_jer_file is not None: - self.deps_kwargs[jer]["get_jer_file"] = self.get_jer_file - if self.get_jer_config is not None: - self.deps_kwargs[jer]["get_jer_config"] = self.get_jer_config - - -# explicit calibrators for standard jet collections -jets_ak4 = jets.derive("jets_ak4", cls_dict={"jet_name": "Jet", "gen_jet_name": "GenJet"}) -jets_ak8 = jets.derive("jets_ak8", cls_dict={"jet_name": "FatJet", "gen_jet_name": "GenJetAK8"}) +# # coding: utf-8 + +# """ +# Jet energy corrections and jet resolution smearing. 
+# """ +# from pprint import pprint + +# import functools + +# import law + +# from columnflow.types import Any +# from columnflow.calibration import Calibrator, calibrator +# from columnflow.calibration.util import ak_random, propagate_met +# from columnflow.production.util import attach_coffea_behavior +# from columnflow.util import maybe_import, InsertableDict, DotDict +# from columnflow.columnar_util import set_ak_column, layout_ak_array, optional_column as optional + +# np = maybe_import("numpy") +# ak = maybe_import("awkward") +# correctionlib = maybe_import("correctionlib") + +# logger = law.logger.get_logger(__name__) + + +# # +# # helper functions +# # + +# set_ak_column_f32 = functools.partial(set_ak_column, value_type=np.float32) + + +# import difflib + +# def get_evaluators( +# correction_set: correctionlib.highlevel.CorrectionSet, +# names: list[str], +# ) -> list[Any]: +# """ +# Helper function to get a list of correction evaluators from a +# :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` object given +# a list of *names*. The *names* can refer to either simple or compound +# corrections. + +# :param correction_set: evaluator provided by :external+correctionlib:doc:`index` +# :param names: List of names of corrections to be applied +# :raises RuntimeError: If a requested correction in *names* is not available +# :return: List of compounded corrections, see +# :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` +# """ +# available_keys = set(correction_set.keys()).union(correction_set.compound.keys()) +# corrected_names = [] + +# for name in names: +# if name not in available_keys: +# # Find the closest match using difflib +# closest_matches = difflib.get_close_matches(name, available_keys, n=1) +# if closest_matches: +# closest_match = closest_matches[0] +# print( +# f"Correction '{name}' not found. Using closest match: '{closest_match}'", +# ) +# corrected_names.append(closest_match) +# else: +# raise RuntimeError(f"Correction '{name}' not found and no close match available.") +# else: +# corrected_names.append(name) + +# # Retrieve the evaluators +# return [ +# correction_set.compound[name] +# if name in correction_set.compound +# else correction_set[name] +# for name in corrected_names +# ] + +# def ak_evaluate(evaluator: correctionlib.highlevel.Correction, *args) -> float: +# """ +# Evaluate a :external+correctionlib:py:class:`correctionlib.highlevel.Correction` +# using one or more :external+ak:py:class:`awkward arrays ` as inputs. 
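+
+# As a minimal usage sketch (``cset``, ``jets`` and the correction key below are illustrative
+# assumptions, not objects defined in this module):
+
+# .. code-block:: python
+
+#     cset = correctionlib.CorrectionSet.from_file("jet_jerc.json.gz")
+#     # awkward inputs are broadcast, flattened, evaluated and re-wrapped into the original layout
+#     sf = ak_evaluate(cset["Summer19UL17_V5_MC_L2Relative_AK4PFchs"], jets.eta, jets.pt)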
+ +# :param evaluator: Evaluator instance +# :raises ValueError: If no :external+ak:py:class:`awkward arrays ` are provided +# :return: The correction factor derived from the input arrays +# """ +# # fail if no arguments +# if not args: +# raise ValueError("Expected at least one argument.") + +# # collect arguments that are awkward arrays +# ak_args = [ +# arg for arg in args if isinstance(arg, ak.Array) +# ] + +# # broadcast akward arrays together and flatten +# if ak_args: +# bc_args = ak.broadcast_arrays(*ak_args) +# flat_args = ( +# np.asarray(ak.flatten(bc_arg, axis=None)) +# for bc_arg in bc_args +# ) +# output_layout_array = bc_args[0] +# else: +# flat_args = iter(()) +# output_layout_array = None + +# # multiplex flattened and non-awkward inputs +# all_flat_args = [ +# next(flat_args) if isinstance(arg, ak.Array) else arg +# for arg in args +# ] + +# # apply evaluator to flattened/multiplexed inputs +# result = evaluator.evaluate(*all_flat_args) + +# # apply broadcasted layout to result +# if output_layout_array is not None: +# result = layout_ak_array(result, output_layout_array) + +# return result + + +# # +# # jet energy corrections +# # +# def get_jec_file_default(self, external_files: DotDict) -> str: +# """ +# Function to obtain external jec files. + +# By default, this function extracts the location of the jec correction +# files from the current config instance *config_inst*: + +# .. code-block:: python + +# cfg.x.external_files = DotDict.wrap({ +# "jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/jet_jerc.json.gz", +# }) + +# :param external_files: Dictionary containing the information about the file location +# :return: path or url to correction file(s) +# """ # noqa +# return external_files.jet_jerc + + +# # define default functions for jec calibrator +# def get_jerc_file_default(self: Calibrator, external_files: DotDict) -> str: +# """ +# Function to obtain external correction files for JEC and/or JER. + +# By default, this function extracts the location of the jec correction +# files from the current config instance *config_inst*. The key of the +# external file depends on the jet collection. For ``Jet`` (AK4 jets), this +# resolves to ``jet_jerc``, and for ``FatJet`` it is resolved to +# ``fat_jet_jerc``. + +# .. 
code-block:: python + +# cfg.x.external_files = DotDict.wrap({ +# "jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/jet_jerc.json.gz", +# "fat_jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/fatJet_jerc.json.gz", +# }) + +# :param external_files: Dictionary containing the information about the file location +# :return: path or url to correction file(s) +# """ # noqa + +# # get config +# try_attrs = ("get_jec_config", "get_jer_config") +# jerc_config = None +# for try_attr in try_attrs: +# try: +# jerc_config = getattr(self, try_attr)() +# except AttributeError: +# continue +# else: +# break + +# # fail if not found +# if jerc_config is None: +# raise ValueError( +# "could not retrieve jer/jec config, none of the following methods " +# f"were found: {try_attrs}", +# ) + +# # first check config for user-supplied `external_file_key` +# ext_file_key = jerc_config.get("external_file_key", None) +# if ext_file_key is not None: +# return external_files[ext_file_key] + +# # if not found, try to resolve from jet collection name and fail if not standard NanoAOD +# if self.jet_name not in get_jerc_file_default.map_jet_name_file_key: +# available_keys = ", ".join(sorted(get_jerc_file_default.map_jet_name_file_key)) +# raise ValueError( +# f"could not determine external file key for jet collection '{self.jet_name}', " +# f"name is not one of standard NanoAOD jet collections: {available_keys}", +# ) + +# # return external file +# ext_file_key = get_jerc_file_default.map_jet_name_file_key[self.jet_name] +# return external_files[ext_file_key] + + +# # default external file keys for known jet collections +# get_jerc_file_default.map_jet_name_file_key = { +# "Jet": "jet_jerc", +# "FatJet": "fat_jet_jerc", +# } + + +# def get_jec_config_default(self: Calibrator) -> DotDict: +# """ +# Load config relevant to the jet energy corrections (JEC). + +# By default, this is extracted from the current *config_inst*, +# assuming the JEC configurations are stored under the 'jec' +# aux key. Separate configurations should be specified for each +# jet collection, using the collection name as a key. For example, +# the configuration for the default jet collection ``Jet`` will +# be retrieved from the following config entry: + +# .. code-block:: python + +# self.config_inst.x.jec.Jet + +# Used in :py:meth:`~.jec.setup_func`. + +# :return: Dictionary containing configuration for jet energy calibration +# """ +# jec_cfg = self.config_inst.x.jec + +# # check for old-style config +# if self.jet_name not in jec_cfg: +# # if jet collection is `Jet`, issue deprecation warning +# if self.jet_name == "Jet": +# logger.warning_once( +# f"{id(self)}_depr_jec_config", +# "config aux 'jec' does not contain key for input jet " +# f"collection '{self.jet_name}'. This may be due to " +# "an outdated config. Continuing under the assumption that " +# "the entire 'jec' entry refers to this jet collection. 
" +# "This assumption will be removed in future versions of " +# "columnflow, so please adapt the config according to the " +# "documentation to remove this warning and ensure future " +# "compatibility of the code.", +# ) +# return jec_cfg + +# # otherwise raise exception +# raise ValueError( +# "config aux 'jec' does not contain key for input jet " +# f"collection '{self.jet_name}'.", +# ) + +# return jec_cfg[self.jet_name] + + +# @calibrator( +# uses={ +# optional("fixedGridRhoFastjetAll"), +# optional("Rho.fixedGridRhoFastjetAll"), +# attach_coffea_behavior, +# }, +# # name of the jet collection to calibrate +# jet_name="Jet", +# # name of the associated MET collection +# met_name="MET", +# # name of the associated Raw MET collection +# raw_met_name="RawMET", +# # custom uncertainty sources, defaults to config when empty +# uncertainty_sources=None, +# # toggle for propagation to PuppiMET +# propagate_met=True, +# # # function to determine the correction file +# get_jec_file=get_jec_file_default, +# # # function to determine the jec configuration dict +# get_jec_config=get_jec_config_default, +# ) + +# def jec( +# self: Calibrator, +# events: ak.Array, +# min_pt_met_prop: float = 15.0, +# max_eta_met_prop: float = 5.2, +# **kwargs, +# ) -> ak.Array: +# """Performs the jet energy corrections (JECs) and uncertainty shifts using the +# :external+correctionlib:doc:`index`, optionally +# propagating the changes to the PuppiMET. + +# The *jet_name* should be set to the name of the NanoAOD jet collection to calibrate +# (default: ``Jet``, i.e. AK4 jets). + +# Requires an external file in the config pointing to the JSON files containing the JECs. +# The file key can be specified via an optional ``external_file_key`` in the ``jec`` config entry. +# If not given, the file key will be determined automatically based on the jet collection name: +# ``jet_jerc`` for ``Jet`` (AK4 jets), ``fat_jet_jerc`` for``FatJet`` (AK8 jets). A full set of JSON files +# can be specified as: + +# .. code-block:: python + +# cfg.x.external_files = DotDict.wrap({ +# "jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/jet_jerc.json.gz", +# "fat_jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/fatJet_jerc.json.gz", +# }) + +# For more file-grained control, the *get_jec_file* can be adapted in a subclass in case it is stored +# differently in the external files + +# The JEC configuration should be an auxiliary entry in the config, specifying the correction +# details under "jec". Separate configs should be given for each jet collection to calibrate, +# using the jet collection name as a subkey. An example of a valid configuration for correction +# AK4 jets with JEC is: + +# .. code-block:: python + +# cfg.x.jec = { +# "Jet": { +# "campaign": "Summer19UL17", +# "version": "V5", +# "jet_type": "AK4PFchs", +# "levels": ["L1L2L3Res"], # or individual correction levels +# "levels_for_type1_met": ["L1FastJet"], +# "uncertainty_sources": [ +# "Total", +# "CorrelationGroupMPFInSitu", +# "CorrelationGroupIntercalibration", +# "CorrelationGroupbJES", +# "CorrelationGroupFlavor", +# "CorrelationGroupUncorrelated", +# ] +# }, +# } + +# *get_jec_config* can be adapted in a subclass in case it is stored differently in the config. + +# If running on data, the datasets must have an auxiliary field *jec_era* defined, e.g. "RunF", +# or an auxiliary field *era*, e.g. "F". 
+ +# This instance of :py:class:`~columnflow.calibration.Calibrator` is +# initialized with the following parameters by default: + +# :param events: awkward array containing events to process + +# :param min_pt_met_prop: If *propagate_met* variable is ``True`` propagate the updated jet values +# to the missing transverse energy (PuppiMET) using +# :py:func:`~columnflow.calibration.util.propagate_met` for events where +# ``met.pt > *min_pt_met_prop*``. +# :param max_eta_met_prop: If *propagate_met* variable is ``True`` propagate the updated jet +# values to the missing transverse energy (PuppiMET) using +# :py:func:`~columnflow.calibration.util.propagate_met` for events where +# ``met.eta > *min_eta_met_prop*``. +# """ # noqa + +# # calculate uncorrected pt, mass +# events = set_ak_column_f32(events, "Jet.pt_raw", events.Jet.pt * (1 - events.Jet.rawFactor)) +# events = set_ak_column_f32(events, "Jet.mass_raw", events.Jet.mass * (1 - events.Jet.rawFactor)) + +# # calculate uncorrected pt, mass +# events = set_ak_column_f32(events, f"{jet_name}.pt_raw", events[jet_name].pt * (1 - events[jet_name].rawFactor)) +# events = set_ak_column_f32(events, f"{jet_name}.mass_raw", events[jet_name].mass * (1 - events[jet_name].rawFactor)) + +# def correct_jets(*, pt, eta, phi, area, rho, evaluator_key="jec"): +# # variable naming convention +# variable_map = { +# "JetA": area, +# "JetEta": eta, +# "JetPt": pt, +# "JetPhi": phi, +# "Rho": ak.values_astype(rho, np.float32), +# } + +# # apply all correctors sequentially, updating the pt each time +# full_correction = ak.ones_like(pt, dtype=np.float32) + + +# for corrector in self.evaluators[evaluator_key]: +# # determine correct inputs (change depending on corrector) +# inputs = [ +# variable_map[inp.name] +# for inp in corrector.inputs +# ] +# correction = ak_evaluate(corrector, *inputs) +# # update pt for subsequent correctors +# #pprint(corrector.__dict__) # If `corrector` is a custom object with attributes +# variable_map["JetPt"] = variable_map["JetPt"] * correction +# full_correction = full_correction * correction + +# return full_correction + +# # obtain rho, which might be located at different routes, depending on the nano version +# rho = ( +# events.fixedGridRhoFastjetAll +# if "fixedGridRhoFastjetAll" in events.fields +# else events.Rho.fixedGridRhoFastjetAll +# ) + +# # correct jets with only a subset of correction levels +# # (for calculating TypeI PuppiMET correction) +# if self.propagate_met: +# # get correction factors +# jec_factors_subset_type1_met = correct_jets( +# pt=events[jet_name].pt_raw, +# eta=events[jet_name].eta, +# phi=events[jet_name].phi, +# area=events[jet_name].area, +# rho=rho, +# evaluator_key="jec_subset_type1_met", +# ) + +# # temporarily apply the new factors with only subset of corrections +# events = set_ak_column_f32(events, f"{jet_name}.pt", events[jet_name].pt_raw * jec_factors_subset_type1_met) +# events = set_ak_column_f32(events, f"{jet_name}.mass", events[jet_name].mass_raw * jec_factors_subset_type1_met) +# events = self[attach_coffea_behavior](events, collections=[jet_name], **kwargs) + +# # store pt and phi of the full jet system for PuppiMET propagation, including a selection in raw info +# # see https://twiki.cern.ch/twiki/bin/view/CMS/JECAnalysesRecommendations?rev=19#Minimum_jet_selection_cuts +# met_prop_mask = (events[jet_name].pt_raw > min_pt_met_prop) & (abs(events[jet_name].eta) < max_eta_met_prop) +# jetsum = events[jet_name][met_prop_mask].sum(axis=1) +# jetsum_pt_subset_type1_met = jetsum.pt +# 
jetsum_phi_subset_type1_met = jetsum.phi + +# # factors for full jet correction with all levels +# jec_factors = correct_jets( +# pt=events[jet_name].pt_raw, +# eta=events[jet_name].eta, +# phi=events[jet_name].phi, +# area=events[jet_name].area, +# rho=rho, +# evaluator_key="jec", +# ) + +# # apply full jet correction +# events = set_ak_column_f32(events, f"{jet_name}.pt", events[jet_name].pt_raw * jec_factors) +# events = set_ak_column_f32(events, f"{jet_name}.mass", events[jet_name].mass_raw * jec_factors) +# rawFactor = ak.nan_to_num(1 - events[jet_name].pt_raw / events[jet_name].pt, nan=0.0) +# events = set_ak_column_f32(events, f"{jet_name}.rawFactor", rawFactor) +# events = self[attach_coffea_behavior](events, collections=[jet_name], **kwargs) + +# # nominal met propagation +# if self.propagate_met: +# # get pt and phi of all jets after correcting +# jetsum = events[jet_name][met_prop_mask].sum(axis=1) +# jetsum_pt_all_levels = jetsum.pt +# jetsum_phi_all_levels = jetsum.phi + +# # propagate changes to PuppiMET, starting from jets corrected with subset of JEC levels +# # (recommendation is to propagate only L2 corrections and onwards) +# met_pt, met_phi = propagate_met( +# jetsum_pt_subset_type1_met, +# jetsum_phi_subset_type1_met, +# jetsum_pt_all_levels, +# jetsum_phi_all_levels, +# events.RawPuppiMET.pt, +# events.RawPuppiMET.phi, +# ) + +# events = set_ak_column_f32(events, "PuppiMET.pt", met_pt) +# events = set_ak_column_f32(events, "PuppiMET.phi", met_phi) + +# # variable naming conventions +# variable_map = { +# "JetEta": events[jet_name].eta, +# "JetPt": events[jet_name].pt_raw, +# } + +# # jet energy uncertainty components +# for name, evaluator in self.evaluators["junc"].items(): +# # get uncertainty +# inputs = [variable_map[inp.name] for inp in evaluator.inputs] +# jec_uncertainty = ak_evaluate(evaluator, *inputs) + +# # apply jet uncertainty shifts +# events = set_ak_column_f32( +# events, f"{jet_name}.pt_jec_{name}_up", events[jet_name].pt * (1.0 + jec_uncertainty), +# ) +# events = set_ak_column_f32( +# events, f"{jet_name}.pt_jec_{name}_down", events[jet_name].pt * (1.0 - jec_uncertainty), +# ) +# events = set_ak_column_f32( +# events, f"{jet_name}.mass_jec_{name}_up", events[jet_name].mass * (1.0 + jec_uncertainty), +# ) +# events = set_ak_column_f32( +# events, f"{jet_name}.mass_jec_{name}_down", events[jet_name].mass * (1.0 - jec_uncertainty), +# ) + +# # propagate shifts to PuppiMET +# if self.propagate_met: +# jet_pt_up = events[jet_name][met_prop_mask][f"pt_jec_{name}_up"] +# jet_pt_down = events[jet_name][met_prop_mask][f"pt_jec_{name}_down"] +# met_pt_up, met_phi_up = propagate_met( +# jetsum_pt_all_levels, +# jetsum_phi_all_levels, +# jet_pt_up, +# events[jet_name][met_prop_mask].phi, +# met_pt, +# met_phi, +# ) +# met_pt_down, met_phi_down = propagate_met( +# jetsum_pt_all_levels, +# jetsum_phi_all_levels, +# jet_pt_down, +# events[jet_name][met_prop_mask].phi, +# met_pt, +# met_phi, +# ) +# events = set_ak_column_f32(events, f"PuppiMET.pt_jec_{name}_up", met_pt_up) +# events = set_ak_column_f32(events, f"PuppiMET.pt_jec_{name}_down", met_pt_down) +# events = set_ak_column_f32(events, f"PuppiMET.phi_jec_{name}_up", met_phi_up) +# events = set_ak_column_f32(events, f"PuppiMET.phi_jec_{name}_down", met_phi_down) + +# return events + + +# @jec.init +# def jec_init(self: Calibrator) -> None: +# jec_cfg = self.get_jec_config() + +# sources = self.uncertainty_sources +# if sources is None: +# sources = jec_cfg.uncertainty_sources + +# # register used jet columns 
+# self.uses.add(f"{self.jet_name}.{{pt,eta,phi,mass,area,rawFactor}}") + +# # register produced jet columns +# self.produces.add(f"{self.jet_name}.{{pt,mass,rawFactor}}") + +# # add shifted jet variables +# self.produces |= { +# f"{self.jet_name}.{shifted_var}_jec_{junc_name}_{junc_dir}" +# for shifted_var in ("pt", "mass") +# for junc_name in sources +# for junc_dir in ("up", "down") +# } + +# # add PuppiMET variables +# if self.propagate_met: +# self.uses |= {"RawPuppiMET.pt", "RawPuppiMET.phi","PuppiMET.pt", "PuppiMET.phi"} +# self.produces |= {"PuppiMET.pt", "PuppiMET.phi"} + +# # add shifted PuppiMET variables +# self.produces |= { +# f"PuppiMET.{shifted_var}_jec_{junc_name}_{junc_dir}" +# for shifted_var in ("pt", "phi") +# for junc_name in sources +# for junc_dir in ("up", "down") +# } + + +# @jec.requires +# def jec_requires(self: Calibrator, reqs: dict) -> None: +# if "external_files" in reqs: +# return + +# from columnflow.tasks.external import BundleExternalFiles +# reqs["external_files"] = BundleExternalFiles.req(self.task) + + +# @jec.setup +# def jec_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: +# """ +# Load the correct jec files using the :py:func:`from_string` method of the +# :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` +# function and apply the corrections as needed. + +# The source files for the :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` +# instance are extracted with the :py:meth:`~.jec.get_jec_file`. + +# Uses the member function :py:meth:`~.jec.get_jec_config` to construct the +# required keys, which are based on the following information about the JEC: + +# - levels +# - campaign +# - version +# - jet_type + +# A corresponding example snippet wihtin the *config_inst* could like something +# like this: + +# .. 
code-block:: python + +# cfg.x.jec = DotDict.wrap({ +# # campaign name for this JEC correctiono +# "campaign": f"Summer19UL{year2}{jerc_postfix}", +# # version of the corrections +# "version": "V7", +# # Type of jets that the corrections should be applied on +# "jet_type": "AK4PFchs", +# # relevant levels in the derivation process of the JEC +# "levels": ["L1FastJet", "L2Relative", "L2L3Residual", "L3Absolute"], +# # relevant levels in the derivation process of the Type 1 PuppiMET JEC +# "levels_for_type1_met": ["L1FastJet"], +# # names of the uncertainties to be applied +# "uncertainty_sources": [ +# "Total", +# "CorrelationGroupMPFInSitu", +# "CorrelationGroupIntercalibration", +# "CorrelationGroupbJES", +# "CorrelationGroupFlavor", +# "CorrelationGroupUncorrelated", +# ], +# }) + +# :param reqs: Requirement dictionary for this +# :py:class:`~columnflow.calibration.Calibrator` instance +# :param inputs: Additional inputs, currently not used +# :param reader_targets: TODO: add documentation +# """ + +# bundle = reqs["external_files"] + +# # import the correction sets from the external file +# import correctionlib + +# correction_set = correctionlib.CorrectionSet.from_string( +# self.get_jec_file(bundle.files).load(formatter="gzip").decode("utf-8"), +# ) + +# # compute JEC keys from config information +# jec_cfg = self.get_jec_config() + +# def make_jme_keys(names, jec=jec_cfg, is_data=self.dataset_inst.is_data): +# if is_data: + +# jec_era = self.dataset_inst.get_aux("jec_era", None) +# # if no special JEC era is specified, infer based on 'era' +# if jec_era is None: +# jec_era = "Run" + self.dataset_inst.get_aux("era") + +# return [ +# f"{jec.campaign}_{jec_era}_{jec.version}_DATA_{name}_{jec.jet_type}" +# if is_data else +# f"{jec.campaign}_{jec.version}_MC_{name}_{jec.jet_type}" +# for name in names +# ] + +# # take sources from constructor or config +# sources = self.uncertainty_sources +# if sources is None: +# sources = jec_cfg.uncertainty_sources + +# if self.dataset_inst.is_data : +# jec_keys = make_jme_keys(jec_cfg.levels_DATA) +# else : +# jec_keys = make_jme_keys(jec_cfg.levels_MC) +# jec_keys_subset_type1_met = make_jme_keys(jec_cfg.levels_for_type1_met) +# junc_keys = make_jme_keys(sources, is_data=False) # uncertainties only stored as MC keys + +# # store the evaluators +# self.evaluators = { +# "jec": get_evaluators(correction_set, jec_keys), +# "jec_subset_type1_met": get_evaluators(correction_set, jec_keys_subset_type1_met), +# "junc": dict(zip(sources, get_evaluators(correction_set, junc_keys))), +# } + + +# # custom jec calibrator that only runs nominal correction +# jec_nominal = jec.derive("jec_nominal", cls_dict={"uncertainty_sources": []}) + +# # define default functions for jec calibrator +# def get_jer_file(self, external_files: DotDict) -> str: +# """ +# Load config relevant to the jet energy resolution (JER) smearing. + +# By default, this is extracted from the current *config_inst*, +# assuming the JER configurations are stored under the 'jer' +# aux key. Separate configurations should be specified for each +# jet collection, using the collection name as a key. For example, +# the configuration for the default jet collection ``Jet`` will +# be retrieved from the following config entry: + +# .. code-block:: python + +# self.config_inst.x.jer.Jet + +# Used in :py:meth:`~.jer.setup_func`. 
+ +# :return: Dictionary containing configuration for JER smearing +# """ +# jer_cfg = self.config_inst.x.jer + +# # check for old-style config +# if self.jet_name not in jer_cfg: +# # if jet collection is `Jet`, issue deprecation warning +# if self.jet_name == "Jet": +# logger.warning_once( +# f"{id(self)}_depr_jer_config", +# "config aux 'jer' does not contain key for input jet " +# f"collection '{self.jet_name}'. This may be due to " +# "an outdated config. Continuing under the assumption that " +# "the entire 'jer' entry refers to this jet collection. " +# "This assumption will be removed in future versions of " +# "columnflow, so please adapt the config according to the " +# "documentation to remove this warning and ensure future " +# "compatibility of the code.", +# ) +# return jer_cfg + +# # otherwise raise exception +# raise ValueError( +# "config aux 'jer' does not contain key for input jet " +# f"collection '{self.jet_name}'.", +# ) + +# return jer_cfg[self.jet_name] + + +# # +# # jet energy resolution smearing +# # + +# @calibrator( +# uses={ +# optional("Rho.fixedGridRhoFastjetAll"), +# optional("fixedGridRhoFastjetAll"), +# "GenJet.pt", "GenJet.eta", "GenJet.phi", +# "PuppiMET.pt", "PuppiMET.phi", +# attach_coffea_behavior, +# }, +# produces={ +# "Jet.pt", "Jet.mass", +# "Jet.pt_unsmeared", "Jet.mass_unsmeared", +# "Jet.pt_jer_up", "Jet.pt_jer_down", "Jet.mass_jer_up", "Jet.mass_jer_down", +# "PuppiMET.pt", "PuppiMET.phi", +# "PuppiMET.pt_jer_up", "PuppiMET.pt_jer_down", "PuppiMET.phi_jer_up", "PuppiMET.phi_jer_down", +# }, +# # toggle for propagation to PuppiMET +# propagate_met=True, +# # only run on mc +# mc_only=True, +# # use deterministic seeds for random smearing and +# # take the "index"-th random number per seed when not -1 +# deterministic_seed_index=-1, +# # function to determine the correction file +# get_jer_file=get_jerc_file_default, +# # function to determine the jer configuration dict +# get_jer_config=get_jer_config_default, +# ) +# def jer(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: +# """ +# Applies the jet energy resolution smearing in MC and calculates the associated uncertainty +# shifts using the :external+correctionlib:doc:`index`, following the recommendations given in +# https://twiki.cern.ch/twiki/bin/viewauth/CMS/JetResolution. + +# The *jet_name* and *gen_jet_name* should be set to the name of the NanoAOD jet and gen jet +# collections to use as an input for JER smearing (default: ``Jet`` and ``GenJet``, respectively, +# i.e. AK4 jets). + +# Requires an external file in the config pointing to the JSON files containing the JER information. +# The file key can be specified via an optional ``external_file_key`` in the ``jer`` config entry. +# If not given, the file key will be determined automatically based on the jet collection name: +# ``jet_jerc`` for ``Jet`` (AK4 jets), ``fat_jet_jerc`` for``FatJet`` (AK8 jets). A full set of JSON files +# can be specified as: + +# .. code-block:: python + +# cfg.x.external_files = DotDict.wrap({ +# "jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/jet_jerc.json.gz", +# "fat_jet_jerc": "/afs/cern.ch/work/m/mrieger/public/mirrors/jsonpog-integration-9ea86c4c/POG/JME/2017_UL/fatJet_jerc.json.gz", +# }) + +# For more fine-grained control, the *get_jer_file* can be adapted in a subclass in case it is stored +# differently in the external files. 
+ +# The JER smearing configuration should be an auxiliary entry in the config, specifying the input +# JER to use under "jer". Separate configs should be given for each jet collection to smear, using +# the jet collection name as a subkey. An example of a valid configuration for smearing +# AK4 jets with JER is: + +# .. code-block:: python + +# cfg.x.jer = { +# "Jet": { +# "campaign": "Summer19UL17", +# "version": "JRV2", +# "jet_type": "AK4PFchs", +# }, +# } + +# *get_jer_config* can be adapted in a subclass in case it is stored differently in the config. + +# Throws an error if running on data. + +# :param events: awkward array containing events to process +# """ # noqa +# # use local variables for convenience +# jet_name = self.jet_name +# gen_jet_name = self.gen_jet_name + +# # fail when running on data +# if self.dataset_inst.is_data: +# raise ValueError("attempt to apply jet energy resolution smearing in data") + +# # save the unsmeared properties in case they are needed later +# events = set_ak_column_f32(events, f"{jet_name}.pt_unsmeared", events[jet_name].pt) +# events = set_ak_column_f32(events, f"{jet_name}.mass_unsmeared", events[jet_name].mass) + +# # obtain rho, which might be located at different routes, depending on the nano version +# rho = ( +# events.fixedGridRhoFastjetAll +# if "fixedGridRhoFastjetAll" in events.fields else +# events.Rho.fixedGridRhoFastjetAll +# ) + +# # variable naming convention +# variable_map = { +# "JetEta": events[jet_name].eta, +# "JetPt": events[jet_name].pt, +# "Rho": rho, +# } + +# # pt resolution +# inputs = [variable_map[inp.name] for inp in self.evaluators["jer"].inputs] +# jer = ak_evaluate(self.evaluators["jer"], *inputs) + +# # JER scale factors and systematic variations +# jersf = {} +# for syst in ("nom", "up", "down"): +# variable_map_syst = dict(variable_map, systematic=syst) +# inputs = [variable_map_syst[inp.name] for inp in self.evaluators["sf"].inputs] +# jersf[syst] = ak_evaluate(self.evaluators["sf"], *inputs) + +# # array with all JER scale factor variations as an additional axis +# # (note: axis needs to be regular for broadcasting to work correctly) +# jersf = ak.concatenate( +# [jersf[syst][..., None] for syst in ("nom", "up", "down")], +# axis=-1, +# ) + +# # -- stochastic smearing +# # normally distributed random numbers according to JER +# jer_random_normal = ( +# ak_random(0, jer, events[jet_name].deterministic_seed, rand_func=self.deterministic_normal) +# if self.deterministic_seed_index >= 0 +# else ak_random(0, jer, rand_func=np.random.Generator( +# np.random.SFC64(events.event.to_list())).normal, +# ) +# ) + +# # scale random numbers according to JER SF +# jersf2_m1 = jersf ** 2 - 1 +# add_smear = np.sqrt(ak.where(jersf2_m1 < 0, 0, jersf2_m1)) + +# # broadcast over JER SF variations +# jer_random_normal, jersf_z = ak.broadcast_arrays(jer_random_normal, add_smear) + +# # compute smearing factors (stochastic method) +# smear_factors_stochastic = 1.0 + jer_random_normal * add_smear + +# # -- scaling method (using gen match) + +# # mask negative gen jet indices (= no gen match) +# gen_jet_idx = events[jet_name][self.gen_jet_idx_column] +# valid_gen_jet_idxs = ak.mask(gen_jet_idx, gen_jet_idx >= 0) + +# # pad list of gen jets to prevent index error on match lookup +# max_gen_jet_idx = ak.max(valid_gen_jet_idxs) +# padded_gen_jets = ak.pad_none( +# events[gen_jet_name], +# 0 if max_gen_jet_idx is None else (max_gen_jet_idx + 1), +# ) + +# # gen jets that match the reconstructed jets +# matched_gen_jets = 
padded_gen_jets[valid_gen_jet_idxs] + +# # compute the relative (reco - gen) pt difference +# pt_relative_diff = (events[jet_name].pt - matched_gen_jets.pt) / events[jet_name].pt + +# # test if matched gen jets are within 3 * resolution +# is_matched_pt = np.abs(pt_relative_diff) < 3 * jer +# is_matched_pt = ak.fill_none(is_matched_pt, False) # masked values = no gen match + +# # (no check for Delta-R matching criterion; we assume this was done during +# # nanoAOD production to get the `genJetIdx`) + +# # broadcast over JER SF variations +# pt_relative_diff, jersf = ak.broadcast_arrays(pt_relative_diff, jersf) + +# # compute smearing factors (scaling method) +# smear_factors_scaling = 1.0 + (jersf - 1.0) * pt_relative_diff + +# # -- hybrid smearing: take smear factors from scaling if there was a match, +# # otherwise take the stochastic ones +# smear_factors = ak.where( +# is_matched_pt[:, :, None], +# smear_factors_scaling, +# smear_factors_stochastic, +# ) + +# # ensure array is not nullable (avoid ambiguity on Arrow/Parquet conversion) +# smear_factors = ak.fill_none(smear_factors, 0.0) + +# # store pt and phi of the full jet system +# if self.propagate_met: +# jetsum = events[jet_name].sum(axis=1) +# jetsum_pt_before = jetsum.pt +# jetsum_phi_before = jetsum.phi + +# # apply the smearing factors to the pt and mass +# # (note: apply variations first since they refer to the original pt) +# events = set_ak_column_f32(events, f"{jet_name}.pt_jer_up", events[jet_name].pt * smear_factors[:, :, 1]) +# events = set_ak_column_f32(events, f"{jet_name}.mass_jer_up", events[jet_name].mass * smear_factors[:, :, 1]) +# events = set_ak_column_f32(events, f"{jet_name}.pt_jer_down", events[jet_name].pt * smear_factors[:, :, 2]) +# events = set_ak_column_f32(events, f"{jet_name}.mass_jer_down", events[jet_name].mass * smear_factors[:, :, 2]) +# events = set_ak_column_f32(events, f"{jet_name}.pt", events[jet_name].pt * smear_factors[:, :, 0]) +# events = set_ak_column_f32(events, f"{jet_name}.mass", events[jet_name].mass * smear_factors[:, :, 0]) + +# # recover coffea behavior +# events = self[attach_coffea_behavior](events, collections=[jet_name], **kwargs) + +# # met propagation +# if self.propagate_met: + +# # save unsmeared quantities +# events = set_ak_column_f32(events, "PuppiMET.pt_unsmeared", events.PuppiMET.pt) +# events = set_ak_column_f32(events, "PuppiMET.phi_unsmeared", events.PuppiMET.phi) + +# # get pt and phi of all jets after correcting +# jetsum = events[jet_name].sum(axis=1) +# jetsum_pt_after = jetsum.pt +# jetsum_phi_after = jetsum.phi + +# # propagate changes to PuppiMET +# met_pt, met_phi = propagate_met( +# jetsum_pt_before, +# jetsum_phi_before, +# jetsum_pt_after, +# jetsum_phi_after, +# events.PuppiMET.pt, +# events.PuppiMET.phi, +# ) +# events = set_ak_column_f32(events, "PuppiMET.pt", met_pt) +# events = set_ak_column_f32(events, "PuppiMET.phi", met_phi) + +# # syst variations on top of corrected PuppiMET +# met_pt_up, met_phi_up = propagate_met( +# jetsum_pt_after, +# jetsum_phi_after, +# events[jet_name].pt_jer_up, +# events[jet_name].phi, +# met_pt, +# met_phi, +# ) +# met_pt_down, met_phi_down = propagate_met( +# jetsum_pt_after, +# jetsum_phi_after, +# events[jet_name].pt_jer_down, +# events[jet_name].phi, +# met_pt, +# met_phi, +# ) +# events = set_ak_column_f32(events, "PuppiMET.pt_jer_up", met_pt_up) +# events = set_ak_column_f32(events, "PuppiMET.pt_jer_down", met_pt_down) +# events = set_ak_column_f32(events, "PuppiMET.phi_jer_up", met_phi_up) +# events = 
set_ak_column_f32(events, "PuppiMET.phi_jer_down", met_phi_down) + +# return events + + +# @jer.init +# def jer_init(self: Calibrator) -> None: +# # determine gen-level jet index column +# lower_first = lambda s: s[0].lower() + s[1:] if s else s +# self.gen_jet_idx_column = lower_first(self.gen_jet_name) + "Idx" + +# self.uses |= { +# "PuppiMET.pt", "PuppiMET.phi", +# } +# self.produces |= { +# "PuppiMET.pt", "PuppiMET.phi", "PuppiMET.pt_jer_up", "PuppiMET.pt_jer_down", "PuppiMET.phi_jer_up", +# "PuppiMET.phi_jer_down", "PuppiMET.pt_unsmeared", "PuppiMET.phi_unsmeared", +# } + + +# @jer.requires +# def jer_requires(self: Calibrator, reqs: dict) -> None: +# if "external_files" in reqs: +# return + +# from columnflow.tasks.external import BundleExternalFiles +# reqs["external_files"] = BundleExternalFiles.req(self.task) + + +# @jer.setup +# def jer_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None: +# """ +# Load the correct jer files using the :py:func:`from_string` method of the +# :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` function and apply the +# corrections as needed. + +# The source files for the :external+correctionlib:py:class:`correctionlib.highlevel.CorrectionSet` +# instance are extracted with the :py:meth:`~.jer.get_jer_file`. + +# Uses the member function :py:meth:`~.jer.get_jer_config` to construct the required keys, which +# are based on the following information about the JER: + +# - campaign +# - version +# - jet_type + +# A corresponding example snippet within the *config_inst* could like something like this: + +# .. code-block:: python + +# cfg.x.jer = DotDict.wrap({ +# "Jet": { +# "campaign": f"Summer19UL{year2}{jerc_postfix}", +# "version": "JRV3", +# "jet_type": "AK4PFchs", +# }, +# }) + +# :param reqs: Requirement dictionary for this :py:class:`~columnflow.calibration.Calibrator` +# instance. +# :param inputs: Additional inputs, currently not used. +# :param reader_targets: TODO: add documentation. 
+# """ +# bundle = reqs["external_files"] + +# # import the correction sets from the external file +# import correctionlib +# correction_set = correctionlib.CorrectionSet.from_string( +# self.get_jer_file(bundle.files).load(formatter="gzip").decode("utf-8"), +# ) + +# # compute JER keys from config information +# jer_cfg = self.get_jer_config() +# jer_keys = { +# "jer": f"{jer_cfg.campaign}_{jer_cfg.version}_MC_PtResolution_{jer_cfg.jet_type}", +# "sf": f"{jer_cfg.campaign}_{jer_cfg.version}_MC_ScaleFactor_{jer_cfg.jet_type}", +# } + +# # store the evaluators +# self.evaluators = { +# name: get_evaluators(correction_set, [key])[0] +# for name, key in jer_keys.items() +# } + +# # use deterministic seeds for random smearing if requested +# if self.deterministic_seed_index >= 0: +# idx = self.deterministic_seed_index +# bit_generator = np.random.SFC64 +# def deterministic_normal(loc, scale, seed): +# return np.asarray([ +# np.random.Generator(bit_generator(_seed)).normal(_loc, _scale, size=idx + 1)[-1] +# for _loc, _scale, _seed in zip(loc, scale, seed) +# ]) +# self.deterministic_normal = deterministic_normal + + +# # explicit calibrators for standard jet collections +# jer_ak4 = jer.derive("jer_ak4", cls_dict={"jet_name": "Jet", "gen_jet_name": "GenJet"}) +# jer_ak8 = jer.derive("jer_ak8", cls_dict={"jet_name": "FatJet", "gen_jet_name": "GenJetAK8", "propagate_met": False}) + + +# # +# # single calibrator for doing both JEC and JER smearing +# # + +# @calibrator( +# uses={jec, jer}, +# produces={jec, jer}, +# # toggle for propagation to PuppiMET +# propagate_met=None, +# # functions to determine configs and files +# get_jec_file=None, +# get_jec_config=None, +# get_jer_file=None, +# get_jer_config=None, +# ) +# def jets(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: +# """ +# Instance of :py:class:`~columnflow.calibration.Calibrator` that does all relevant calibrations +# for jets, i.e. JEC and JER. For more information, see :py:func:`~.jec` and :py:func:`~.jer`. 
+ +# :param events: awkward array containing events to process +# """ +# # apply jet energy corrections +# events = self[jec](events, **kwargs) + +# # apply jer smearing on MC only +# if self.dataset_inst.is_mc: +# events = self[jer](events, **kwargs) + +# return events + + +# @jets.init +# def jets_init(self: Calibrator) -> None: +# # forward argument to the producers +# self.deps_kwargs[jec]["jet_name"] = self.jet_name +# self.deps_kwargs[jer]["jet_name"] = self.jet_name +# self.deps_kwargs[jer]["gen_jet_name"] = self.gen_jet_name +# if self.propagate_met is not None: +# self.deps_kwargs[jec]["propagate_met"] = self.propagate_met +# self.deps_kwargs[jer]["propagate_met"] = self.propagate_met +# if self.get_jec_file is not None: +# self.deps_kwargs[jec]["get_jec_file"] = self.get_jec_file +# if self.get_jec_config is not None: +# self.deps_kwargs[jec]["get_jec_config"] = self.get_jec_config +# if self.get_jer_file is not None: +# self.deps_kwargs[jer]["get_jer_file"] = self.get_jer_file +# if self.get_jer_config is not None: +# self.deps_kwargs[jer]["get_jer_config"] = self.get_jer_config + + +# # explicit calibrators for standard jet collections +# jets_ak4 = jets.derive("jets_ak4", cls_dict={"jet_name": "Jet", "gen_jet_name": "GenJet"}) +# jets_ak8 = jets.derive("jets_ak8", cls_dict={"jet_name": "FatJet", "gen_jet_name": "GenJetAK8"}) diff --git a/columnflow/calibration/cms/met.py b/columnflow/calibration/cms/met.py index 01b6ea9ef..aec3ca73b 100644 --- a/columnflow/calibration/cms/met.py +++ b/columnflow/calibration/cms/met.py @@ -1,7 +1,7 @@ # coding: utf-8 """ -MET corrections. +PuppiMET corrections. """ from columnflow.calibration import Calibrator, calibrator @@ -13,9 +13,8 @@ @calibrator( - uses={"run", "PV.npvs"}, - # name of the MET collection to calibrate - met_name="MET", + uses={"run", "PV.npvs", "PuppiMET.pt", "PuppiMET.phi"}, + produces={"PuppiMET.pt", "PuppiMET.phi"}, # function to determine the correction file get_met_file=(lambda self, external_files: external_files.met_phi_corr), # function to determine met correction config @@ -23,9 +22,9 @@ ) def met_phi(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array: """ - Performs the MET phi (type II) correction using the + Performs the PuppiMET phi (type II) correction using the :external+correctionlib:doc:`index` for events there the - uncorrected MET pt is below the beam energy (extracted from ``config_inst.campaign.ecm * 0.5``). + uncorrected PuppiMET pt is below the beam energy (extracted from ``config_inst.campaign.ecm * 0.5``). Requires an external file in the config under ``met_phi_corr``: .. 
code-block:: python

@@ -54,16 +53,16 @@ def met_phi(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array:
     met = events[self.met_name]
 
     # copy the intial pt and phi values
-    corr_pt = np.array(met.pt, dtype=np.float32)
-    corr_phi = np.array(met.phi, dtype=np.float32)
+    corr_pt = np.array(events.PuppiMET.pt, dtype=np.float32)
+    corr_phi = np.array(events.PuppiMET.phi, dtype=np.float32)
 
-    # select only events where MET pt is below the expected beam energy
-    mask = met.pt < (0.5 * self.config_inst.campaign.ecm)
+    # select only events where PuppiMET pt is below the expected beam energy
+    mask = events.PuppiMET.pt < (0.5 * self.config_inst.campaign.ecm)
 
     # arguments for evaluation
     args = (
-        met.pt[mask],
-        met.phi[mask],
+        events.PuppiMET.pt[mask],
+        events.PuppiMET.phi[mask],
         ak.values_astype(events.PV.npvs[mask], np.float32),
         ak.values_astype(events.run[mask], np.float32),
     )
@@ -73,8 +72,8 @@ def met_phi(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array:
     corr_phi[mask] = self.met_phi_corrector.evaluate(*args)
 
     # save the corrected values
-    events = set_ak_column(events, f"{self.met_name}.pt", corr_pt, value_type=np.float32)
-    events = set_ak_column(events, f"{self.met_name}.phi", corr_phi, value_type=np.float32)
+    events = set_ak_column(events, "PuppiMET.pt", corr_pt, value_type=np.float32)
+    events = set_ak_column(events, "PuppiMET.phi", corr_phi, value_type=np.float32)
 
     return events
 
@@ -110,7 +109,7 @@ def met_phi_setup(self: Calibrator, reqs: dict, inputs: dict, reader_targets: di
     :param reader_targets: Additional targets, currently not used.
     """
     bundle = reqs["external_files"]
-
+
     # create the pt and phi correctors
     import correctionlib
     correction_set = correctionlib.CorrectionSet.from_string(
diff --git a/columnflow/calibration/util.py b/columnflow/calibration/util.py
index ac20de9bb..af9955836 100644
--- a/columnflow/calibration/util.py
+++ b/columnflow/calibration/util.py
@@ -13,6 +13,9 @@
 np = maybe_import("numpy")
 ak = maybe_import("awkward")
 
+import law
+
+logger = law.logger.get_logger(__name__)
 
 # https://github.com/scikit-hep/awkward/issues/489\#issuecomment-711090923
 def ak_random(*args, rand_func: Callable) -> ak.Array:
@@ -91,7 +94,29 @@ def propagate_met(
     if jet_pt2.ndim > 1:
         jet_px2 = ak.sum(jet_px2, axis=1)
         jet_py2 = ak.sum(jet_py2, axis=1)
-
+
+    # RawPuppiMET sanity check: flag unphysical values above 14 TeV
+
+    crazy_PuppiMET_values_mask = met_pt1 > 14 * 10**3
+
+    crazy_PuppiMET_values = met_pt1[crazy_PuppiMET_values_mask]
+
+    # get the indices of the out-of-range values
+    crazy_PuppiMET_indices = np.where(crazy_PuppiMET_values_mask)[0]
+
+    # count the number of out-of-range values
+    crazy_PuppiMET_count = ak.sum(crazy_PuppiMET_values_mask)
+
+    if crazy_PuppiMET_count > 0:
+        # replace out-of-range values with a placeholder value of 1000
+        met_pt1 = ak.where(~crazy_PuppiMET_values_mask, met_pt1, 1000)
+
+        # raise a warning about the replacement
+        logger.warning(
+            f"found and replaced {crazy_PuppiMET_count} unphysical value(s) {crazy_PuppiMET_values.tolist()} in 'RawPuppiMET.pt' with 1000.\n"
+            f"Indices in the chunk: {crazy_PuppiMET_indices.tolist()}\n"
+            "These events will be removed in the selection step")
+
     # propagate to met
     met_px2 = met_pt1 * np.cos(met_phi1) - (jet_px2 - jet_px1)
     met_py2 = met_pt1 * np.sin(met_phi1) - (jet_py2 - jet_py1)
diff --git a/columnflow/columnar_util.py b/columnflow/columnar_util.py
index 7651675f5..24be344f2 100644
--- a/columnflow/columnar_util.py
+++ b/columnflow/columnar_util.py
@@ -14,6 +14,7 @@
 import math
 import time
 import enum
+
 import inspect
 import threading
 import multiprocessing
@@ -40,6
+41,7 @@
 maybe_import("coffea.nanoevents.methods.base")
 maybe_import("coffea.nanoevents.methods.nanoaod")
 pq = maybe_import("pyarrow.parquet")
+hist = maybe_import("hist")
 
 
 # loggers
@@ -1354,6 +1356,68 @@ def ak_copy(ak_array: ak.Array) -> ak.Array:
     return layout_ak_array(np.array(ak.flatten(ak_array)), ak_array)
 
 
+def fill_hist(
+    h: hist.Hist,
+    data: ak.Array | np.array | dict[str, ak.Array | np.array],
+    *,
+    last_edge_inclusive: bool | None = None,
+    fill_kwargs: dict[str, Any] | None = None,
+) -> None:
+    """
+    Fills a histogram *h* with data from an awkward array, numpy array or nested dictionary *data*.
+    The data is assumed to be structured in the same way as the histogram axes. If
+    *last_edge_inclusive* is *True*, values that would land exactly on the upper-most bin edge of an
+    axis are shifted into the last bin. If it is *None*, the behavior is determined automatically
+    and depends on the variable axis type. In this case, shifting is applied to all continuous,
+    non-circular axes.
+    """
+    if fill_kwargs is None:
+        fill_kwargs = {}
+
+    # helper to decide whether the variable axis qualifies for shifting the last bin
+    def allows_shift(ax) -> bool:
+        return ax.traits.continuous and not ax.traits.circular
+
+    # determine the axis names and figure out for which axes the last bin correction should be done
+    axis_names = []
+    correct_last_bin_axes = []
+    for ax in h.axes:
+        axis_names.append(ax.name)
+        # include values hitting last edge?
+        if not len(ax.widths) or not isinstance(ax, hist.axis.Variable):
+            continue
+        if (last_edge_inclusive is None and allows_shift(ax)) or last_edge_inclusive:
+            correct_last_bin_axes.append(ax)
+
+    # check data
+    if not isinstance(data, dict):
+        if len(axis_names) != 1:
+            raise ValueError("got multi-dimensional hist but only one dimensional data")
+        data = {axis_names[0]: data}
+    else:
+        for name in axis_names:
+            if name not in data and name not in fill_kwargs:
+                raise ValueError(f"missing data for histogram axis '{name}'")
+
+    # correct last bin values
+    for ax in correct_last_bin_axes:
+        right_edge_mask = ak.flatten(data[ax.name], axis=None) == ax.edges[-1]
+        if np.any(right_edge_mask):
+            data[ax.name] = ak.copy(data[ax.name])
+            flat_np_view(data[ax.name])[right_edge_mask] -= ax.widths[-1] * 1e-5
+
+    # fill
+    if "event" in data.keys():
+        arrays = {}
+        for ax_name in axis_names:
+            if ax_name in data.keys():
+                arrays[ax_name] = data[ax_name]
+        h.fill(**fill_kwargs, **arrays)
+    else:
+        arrays = ak.flatten(ak.cartesian(data))
+        h.fill(**fill_kwargs, **{field: arrays[field] for field in arrays.fields})
+
+
 class RouteFilter(object):
     """
     Shallow helper class that handles removal of routes in an awkward array that do not match those
@@ -2401,6 +2465,19 @@ def setup(cls, func: Callable[[dict], None]) -> None:
         """
         cls.setup_func = func
 
+    @classmethod
+    def teardown(cls, func: Callable[[dict], None]) -> None:
+        """
+        Decorator to wrap a function *func* that should be registered as :py:meth:`teardown_func`
+        which is used to perform a custom teardown of objects at the end of processing. The function
+        should accept one argument:
+
+        - *task*, the invoking task instance.
+
+        The decorator does not return the wrapped function.
+        """
+        cls.teardown_func = func
+
     def __init__(
         self,
         *args,
diff --git a/columnflow/hist_util.py b/columnflow/hist_util.py
index 3c2b60ca6..ff44709d8 100644
--- a/columnflow/hist_util.py
+++ b/columnflow/hist_util.py
@@ -72,8 +72,27 @@ def allows_shift(ax) -> bool:
         flat_np_view(data[ax.name])[right_egde_mask] -= ax.widths[-1] * 1e-5
 
     # fill
-    arrays = ak.flatten(ak.cartesian(data))
-    h.fill(**fill_kwargs, **{field: arrays[field] for field in arrays.fields})
+
+    flat_data = {}
+    arr_shape = None
+    for key, arr in data.items():
+        if arr.ndim > 1:
+            logger.warning(
+                f"found axis '{key}' that is not 1-dimensional; trying to broadcast all other axes"
+            )
+            arr_shape = ak.local_index(arr)
+
+    for key, arr in data.items():
+        if arr_shape is not None:
+            if arr.ndim == 1:
+                _, br_arr = ak.broadcast_arrays(arr_shape, arr)
+                flat_data[key] = ak.flatten(br_arr)
+            else:
+                flat_data[key] = ak.flatten(arr)
+        else:
+            flat_data[key] = arr
+    h.fill(**fill_kwargs, **flat_data)
+
 
 def add_hist_axis(histogram: hist.Hist, variable_inst: od.Variable) -> hist.Hist:
diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py
index 7926a9f78..d60fd87c4 100644
--- a/columnflow/inference/__init__.py
+++ b/columnflow/inference/__init__.py
@@ -325,6 +325,7 @@ def process_spec(
         name: str,
         config_process: str | None = None,
         is_signal: bool = False,
+        data_driven: bool = False,
         config_mc_datasets: Sequence[str] | None = None,
         scale: float | int = 1.0,
     ) -> DotDict:
@@ -333,6 +334,7 @@ def process_spec(
 
         - *name*: The name of the process in the model.
         - *is_signal*: A boolean flag deciding whether this process describes signal.
+        - *data_driven*: A boolean flag deciding whether this process is data driven.
         - *config_process*: The name of the source process in the config to use.
         - *config_mc_datasets*: List of names or patterns of MC datasets in the config to use.
         - *scale*: A float value to scale the process, defaulting to 1.0.
@@ -340,6 +342,7 @@ def process_spec(
         return DotDict([
             ("name", str(name)),
             ("is_signal", bool(is_signal)),
+            ("data_driven", bool(data_driven)),
             ("config_process", str(config_process) if config_process else None),
             ("config_mc_datasets", list(map(str, config_mc_datasets or []))),
             ("scale", float(scale)),
diff --git a/columnflow/plotting/plot_all.py b/columnflow/plotting/plot_all.py
index 60207a301..a30db1e58 100644
--- a/columnflow/plotting/plot_all.py
+++ b/columnflow/plotting/plot_all.py
@@ -273,9 +273,15 @@ def plot_all(
     if not skip_legend:
         # resolve legend kwargs
         legend_kwargs = {
-            "ncols": 1,
-            "loc": "upper right",
+            "ncol": 2,
+            "loc": "center left",
+            "bbox_to_anchor": (0.35, 0.8),  # anchor point of the legend
+            # Places the legend anchor inside the axes.
+            # The first value controls the horizontal position of the anchor,
+            # and the second value controls the vertical position.
+            "fontsize": 16,
         }
+
         legend_kwargs.update(style_config.get("legend_cfg", {}))
 
         # retrieve the legend handles and their labels
diff --git a/columnflow/plotting/plot_functions_1d.py b/columnflow/plotting/plot_functions_1d.py
index f73ceac9c..839418c67 100644
--- a/columnflow/plotting/plot_functions_1d.py
+++ b/columnflow/plotting/plot_functions_1d.py
@@ -34,6 +34,7 @@
 mplhep = maybe_import("mplhep")
 od = maybe_import("order")
 
+logger = law.logger.get_logger(__name__)
 
 def plot_variable_per_process(
     hists: OrderedDict,
@@ -50,25 +51,78 @@
     **kwargs,
 ) -> plt.Figure:
     """
-    TODO.
+ Plots histograms for multiple processes, ordering them by a custom order: + the process with the highest number of events first, followed by the others, + and the process with the second highest number of events last. + Handles cases with only one or two processes. """ remove_residual_axis(hists, "shift") - variable_inst = variable_insts[0] - blinding_threshold = kwargs.get("blinding_threshold", None) + # Define the color maps + color_maps = { + "6": ["#5790fc", "#7a21dd", "#964a8b", "#9c9ca1", "#e42536", "#f89c20"], + "8": ["#1845fb", "#578dff", "#656364", "#86c8dd", "#adad7d", "#c849a9", "#c91f16", "#ff5e02"], + "10": ["#3f90da", "#717581", "#832db6", "#92dadd", "#94a4a2", "#a96b59", "#b9ac70", "#bd1f01", "#e76300", "#ffa90e"], + } - if blinding_threshold: - hists = blind_sensitive_bins(hists, config_inst, blinding_threshold) - hists = apply_variable_settings(hists, variable_insts, variable_settings) - hists = apply_process_settings(hists, process_settings) - hists = apply_density_to_hists(hists, density) + # Basic colors for more than 24 processes + basic_colors = ["#FF0000", "#0000FF", "#00FF00", "#FFFF00", "#FF00FF", "#00FFFF", "#800000", "#808000"] + + # Calculate the total number of events for each process + total_events = {key: sum(hist.values()) for key, hist in hists.items()} + + # Sort processes by total number of events in descending order + #sorted_hists_desc = OrderedDict(sorted(hists.items(), key=lambda item: total_events[item[0]], reverse=True)) + sorted_hists_desc = OrderedDict(hists.items()) + + # Get keys of sorted processes + sorted_keys = list(sorted_hists_desc.keys()) + + # Handle cases with 1 or 2 processes + if len(sorted_keys) == 1: + # Only one process, no special reordering needed + custom_order = sorted_keys + elif len(sorted_keys) == 2: + # Two processes, highest first, then second highest + custom_order = sorted_keys + else: + # More than two processes, custom order: highest, rest, then second highest + custom_order = sorted_keys #[sorted_keys[0]] + sorted_keys[2:] + [sorted_keys[1]] + + # Reorder histograms based on custom order + sorted_hists = OrderedDict((key, sorted_hists_desc[key]) for key in custom_order) + + variable_inst = variable_insts[0] + sorted_hists = apply_variable_settings(sorted_hists, variable_insts, variable_settings) + sorted_hists = apply_process_settings(sorted_hists, process_settings) + sorted_hists = apply_density_to_hists(sorted_hists, density) plot_config = prepare_plot_config( - hists, + sorted_hists, shape_norm=shape_norm, hide_errors=hide_errors, ) + if 'data' not in plot_config: + + # Determine the appropriate color map based on the number of processes + num_processes = len(sorted_hists) + if num_processes <= 6: + colors = color_maps["6"][:num_processes] + elif num_processes == 7: + colors = color_maps["8"][:num_processes] + elif num_processes <= 10: + colors = color_maps["8"][:num_processes] if num_processes == 8 else color_maps["10"][:num_processes] + elif num_processes <= 18: + colors = color_maps["10"] + color_maps["8"][:num_processes - 10] + elif num_processes <= 24: + colors = color_maps["10"] + color_maps["8"] + color_maps["6"][:num_processes - 18] + else: + logger.warning("You are about to plot more than 24 processes together, please reconsider... 
(Colors not in the approved palette will be assigned)") + colors = color_maps["10"] + color_maps["8"] + color_maps["6"] + colors += basic_colors[:num_processes - 24] + plot_config["mc_stack"]["kwargs"]["color"] = colors[:num_processes] + default_style_config = prepare_style_config( config_inst, category_inst, variable_inst, density, shape_norm, yscale, ) @@ -80,6 +134,52 @@ def plot_variable_per_process( return plot_all(plot_config, style_config, **kwargs) + +# def plot_variable_per_process( +# hists: OrderedDict, +# config_inst: od.Config, +# category_inst: od.Category, +# variable_insts: list[od.Variable], +# style_config: dict | None = None, +# density: bool | None = False, +# shape_norm: bool | None = False, +# yscale: str | None = "", +# hide_errors: bool | None = None, +# process_settings: dict | None = None, +# variable_settings: dict | None = None, +# **kwargs, +# ) -> plt.Figure: +# """ +# TODO. +# """ +# remove_residual_axis(hists, "shift") + +# variable_inst = variable_insts[0] +# blinding_threshold = kwargs.get("blinding_threshold", None) + +# if blinding_threshold: +# hists = blind_sensitive_bins(hists, config_inst, blinding_threshold) +# hists = apply_variable_settings(hists, variable_insts, variable_settings) +# hists = apply_process_settings(hists, process_settings) +# hists = apply_density_to_hists(hists, density) + +# plot_config = prepare_plot_config( +# hists, +# shape_norm=shape_norm, +# hide_errors=hide_errors, +# ) + +# default_style_config = prepare_style_config( +# config_inst, category_inst, variable_inst, density, shape_norm, yscale, +# ) + +# style_config = law.util.merge_dicts(default_style_config, style_config, deep=True) +# if shape_norm: +# style_config["ax_cfg"]["ylabel"] = r"$\Delta N/N$" + +# return plot_all(plot_config, style_config, **kwargs) + + def plot_variable_variants( hists: OrderedDict, config_inst: od.Config, @@ -171,20 +271,36 @@ def plot_shifted_variable( plot_config = {} colors = { "nominal": "black", - "up": "red", - "down": "blue", + "up": "blue", + "down": "red", + } + shift_names = { + "nominal": "max mixing", + "ts_up": "CP-odd", + "ts_down": "CP-even", } + + hist_up = None + hist_down = None + hist_up_err = None + hist_down_err = None for i, shift_id in enumerate(h_sum.axes["shift"]): shift_inst = config_inst.get_shift(shift_id) - + h = h_sum[{"shift": hist.loc(shift_id)}] + if "up" in shift_inst.label: + hist_up = h.values() + hist_up_err = h.variances() + elif "down" in shift_inst.label: + hist_down = h.values() + hist_down_err = h.variances() # assuming `nominal` always has shift id 0 ratio_norm = h_sum[{"shift": hist.loc(0)}].values() diff = sum(h.values()) / sum(ratio_norm) - 1 - label = shift_inst.label + label = shift_names[shift_inst.label] if not shift_inst.is_nominal: - label += " ({0:+.2f}%)".format(diff * 100) + pass #label += " ({0:+.2f}%)".format(diff * 100) plot_config[shift_inst.name] = plot_cfg = { "method": "draw_hist", @@ -202,8 +318,18 @@ def plot_shifted_variable( if hide_errors: for key in ("kwargs", "ratio_kwargs"): if key in plot_cfg: - plot_cfg[key]["yerr"] = None - + plot_cfg[key]["yerr"] = False + h_sum = (hist_up + hist_down) + mask = (h_sum > 0) + asym_hist = np.where(mask, + np.abs(hist_up - hist_down)/h_sum, + 0) + herr_num = np.sqrt(hist_up_err + hist_down_err) + herr_den = np.sqrt(hist_up_err + hist_down_err) + dA = np.average(np.sqrt( (herr_num/h_sum)**2 + (herr_den*np.abs(hist_up - hist_down)/h_sum/h_sum)**2)) + + A = np.average(asym_hist) + # legend title setting if not legend_title and len(hists) == 
1: # use process label as default if 1 process @@ -216,8 +342,11 @@ def plot_shifted_variable( default_style_config = prepare_style_config( config_inst, category_inst, variable_inst, density, shape_norm, yscale, ) - default_style_config["rax_cfg"]["ylim"] = (0.25, 1.75) + default_style_config["rax_cfg"]["ylim"] = (0.75, 1.25) default_style_config["rax_cfg"]["ylabel"] = "Ratio" + + default_style_config["annotate_cfg"]["text"] = f'A={A:1.3f}$\pm${dA:1.3f}' + default_style_config["annotate_cfg"]["fontsize"] = 22 if legend_title: default_style_config["legend_cfg"]["title"] = legend_title diff --git a/columnflow/production/cms/mc_weight.py b/columnflow/production/cms/mc_weight.py index 9994c5b5a..e56b60b6e 100644 --- a/columnflow/production/cms/mc_weight.py +++ b/columnflow/production/cms/mc_weight.py @@ -31,11 +31,14 @@ def mc_weight(self: Producer, events: ak.Array, **kwargs) -> ak.Array: [1] https://twiki.cern.ch/twiki/bin/view/CMSPublic/WorkBookNanoAOD?rev=99#Weigths """ + # # determine the mc_weight + # mc_weight = events.genWeight + # if has_ak_column(events, "LHEWeight.originalXWGTUP") and ak.all(events.genWeight == 1.0): + # mc_weight = events.LHEWeight.originalXWGTUP # determine the mc_weight - mc_weight = events.genWeight + mc_weight = np.sign(events.genWeight) if has_ak_column(events, "LHEWeight.originalXWGTUP") and ak.all(events.genWeight == 1.0): - mc_weight = events.LHEWeight.originalXWGTUP - + mc_weight = np.sign(events.LHEWeight.originalXWGTUP) # store the column events = set_ak_column(events, "mc_weight", mc_weight, value_type=np.float32) diff --git a/columnflow/production/cms/pileup.py b/columnflow/production/cms/pileup.py index 5e025c120..346be3125 100644 --- a/columnflow/production/cms/pileup.py +++ b/columnflow/production/cms/pileup.py @@ -54,6 +54,11 @@ def pu_weight(self: Producer, events: ak.Array, **kwargs) -> ak.Array: # evaluate and store the produced column pu_weight = self.pileup_corrector.evaluate(*inputs) + ##################################################### + ### Keeps the pu_weight lower then 300 + pu_weight[pu_weight > 300] = 0 + ##################################################### + events = set_ak_column(events, column_name, pu_weight, value_type=np.float32) return events diff --git a/columnflow/production/normalization.py b/columnflow/production/normalization.py index 9c2dd296f..2144a52be 100644 --- a/columnflow/production/normalization.py +++ b/columnflow/production/normalization.py @@ -207,10 +207,8 @@ def normalization_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Arra f"process_id field contains id(s) {invalid_ids} for which no cross sections were " f"found; process ids with cross sections: {self.xs_process_ids}", ) - # read the weight per process (defined as lumi * xsec / sum_weights) from the lookup table process_weight = np.squeeze(np.asarray(self.process_weight_table[0, process_id].todense())) - # compute the weight and store it norm_weight = events.mc_weight * process_weight events = set_ak_column(events, self.weight_name, norm_weight, value_type=np.float32) @@ -341,7 +339,8 @@ def normalization_weights_setup( # fill the process weight table for proc_id, br in branching_ratios.items(): - sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(proc_id)] + #sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(proc_id)] + sum_weights = self.dataset_inst.n_events process_weight_table[0, proc_id] = lumi * inclusive_xsec * br / sum_weights else: # fill the process weight table with per-process cross sections 
@@ -351,14 +350,18 @@ def normalization_weights_setup(
                     f"no cross section registered for process {process_inst} for center-of-mass "
                     f"energy of {self.config_inst.campaign.ecm}",
                 )
-            sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(process_inst.id)]
+            # sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(process_inst.id)]
+            # quick fix that needs to be fixed: use the dataset's total event count instead
+            ################################
+            sum_weights = self.dataset_inst.n_events
+            ################################
             xsec = process_inst.get_xsec(self.config_inst.campaign.ecm).nominal
             process_weight_table[0, process_inst.id] = lumi * xsec / sum_weights
+
 
     self.process_weight_table = process_weight_table
     self.xs_process_ids = set(self.process_weight_table.rows[0])
 
-
 @normalization_weights.init
 def normalization_weights_init(self: Producer) -> None:
     """
@@ -398,3 +401,5 @@ def normalization_weights_init(self: Producer) -> None:
         "get_xsecs_from_inclusive_dataset": False,
     },
 )
+
+
diff --git a/columnflow/selection/cms/jets.py b/columnflow/selection/cms/jets.py
index 945ed1b1e..89be152c3 100644
--- a/columnflow/selection/cms/jets.py
+++ b/columnflow/selection/cms/jets.py
@@ -22,7 +22,7 @@
 @selector(
     uses={
-        "Jet.{pt,eta,phi,mass,jetId,chEmEF}", optional("Jet.puId"),
+        "Jet.{pt,eta,phi,mass,jetId,chEmEF}",
         "Muon.{pt,eta,phi,mass,isPFcand}",
     },
     produces={"Jet.veto_map_mask"},
@@ -59,7 +59,7 @@ def jet_veto_map(
     # loose jet selection
     jet_mask = (
         (jet.pt > 15) &
-        (jet.jetId >= 2) & # tight id
+        (jet.jetId >= 2) &  # tight id
         (jet.chEmEF < 0.9) &
         ak.all(events.Jet.metric_table(muon) >= 0.2, axis=2)
     )
diff --git a/columnflow/selection/stats.py b/columnflow/selection/stats.py
index 5038a6a03..8141fb957 100644
--- a/columnflow/selection/stats.py
+++ b/columnflow/selection/stats.py
@@ -145,7 +145,7 @@ def increment_stats(
                 "'num' or 'sum'",
             )
 
-            # interpret obj based on the aoperation to be applied
+            # interpret obj based on the operation to be applied
             weights = None
             weight_mask = Ellipsis
             if isinstance(obj, (tuple, list)):
diff --git a/columnflow/tasks/cms/inference.py b/columnflow/tasks/cms/inference.py
index 9386a47f6..24abc829d 100644
--- a/columnflow/tasks/cms/inference.py
+++ b/columnflow/tasks/cms/inference.py
@@ -10,7 +10,7 @@
 
 from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory
 from columnflow.tasks.framework.mixins import (
-    CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin, InferenceModelMixin,
+    CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin, InferenceModelMixin, HistHookMixin,
 )
 from columnflow.tasks.framework.remote import RemoteWorkflow
 from columnflow.tasks.histograms import MergeHistograms, MergeShiftedHistograms
@@ -19,6 +19,7 @@
 
 
 class CreateDatacards(
+    HistHookMixin,
     InferenceModelMixin,
     MLModelsMixin,
     ProducersMixin,
@@ -91,6 +92,7 @@ def workflow_requires(self):
         for cat_obj in self.branch_map.values():
             for proc_obj in cat_obj.processes:
+                if proc_obj.data_driven: continue
                 for dataset in self.get_mc_datasets(proc_obj):
                     # add all required variables and shifts per dataset
                     mc_dataset_params[dataset]["variables"].add(cat_obj.config_variable)
@@ -99,10 +101,8 @@
                         for param_obj in proc_obj.parameters
                         if self.inference_model_inst.require_shapes_for_parameter(param_obj)
                     )
-
             for dataset in self.get_data_datasets(cat_obj):
                 data_dataset_params[dataset]["variables"].add(cat_obj.config_variable)
-
         # set workflow requirements per mc dataset
         reqs["merged_hists"] = set(
self.reqs.MergeShiftedHistograms.req_different_branching( @@ -128,6 +128,7 @@ def workflow_requires(self): def requires(self): cat_obj = self.branch_data + processes = [proc_obj for proc_obj in cat_obj.processes if not proc_obj.data_driven] reqs = { proc_obj.name: { dataset: self.reqs.MergeShiftedHistograms.req_different_branching( @@ -142,9 +143,9 @@ def requires(self): branch=-1, workflow="local", ) - for dataset in self.get_mc_datasets(proc_obj) + for dataset in self.get_mc_datasets(proc_obj) } - for proc_obj in cat_obj.processes + for proc_obj in processes } if cat_obj.config_data_datasets: reqs["data"] = { @@ -183,10 +184,11 @@ def run(self): category_inst = self.config_inst.get_category(cat_obj.config_category) variable_inst = self.config_inst.get_variable(cat_obj.config_variable) leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] - + # histogram data per process hists = OrderedDict() - + process_insts = [] + #prepare histogram objects with self.publish_step(f"extracting {variable_inst.name} in {category_inst.name} ..."): for proc_obj_name, inp in inputs.items(): if proc_obj_name == "data": @@ -194,76 +196,68 @@ def run(self): process_inst = self.config_inst.get_process("data") else: proc_obj = self.inference_model_inst.get_process(proc_obj_name, category=cat_obj.name) - process_inst = self.config_inst.get_process(proc_obj.config_process) + if not proc_obj.data_driven: # data driven processes will be added later with invoke_hist_hooks + process_inst = self.config_inst.get_process(proc_obj.config_process) + else: + pass + sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] - - h_proc = None for dataset, _inp in inp.items(): dataset_inst = self.config_inst.get_dataset(dataset) - - # skip when the dataset is already known to not contain any sub process - if not any(map(dataset_inst.has_process, sub_process_insts)): - self.logger.warning( - f"dataset '{dataset}' does not contain process '{process_inst.name}' " - "or any of its subprocesses which indicates a misconfiguration in the " - f"inference model '{self.inference_model}'", - ) - continue - - # open the histogram and work on a copy - h = _inp["collection"][0]["hists"][variable_inst.name].load(formatter="pickle").copy() - - # axis selections - h = h[{ - "process": [ - hist.loc(p.id) - for p in sub_process_insts - if p.id in h.axes["process"] - ], - "category": [ - hist.loc(c.id) - for c in leaf_category_insts - if c.id in h.axes["category"] - ], - }] - - # axis reductions - h = h[{"process": sum, "category": sum}] - - # add the histogram for this dataset - if h_proc is None: - h_proc = h - else: - h_proc += h - - # there must be a histogram - if h_proc is None: - raise Exception(f"no histograms found for process '{process_inst.name}'") - - # create the nominal hist - hists[proc_obj_name] = OrderedDict() - nominal_shift_inst = self.config_inst.get_shift("nominal") - hists[proc_obj_name]["nominal"] = h_proc[ - {"shift": hist.loc(nominal_shift_inst.id)} - ] - - # per shift - if proc_obj: - for param_obj in proc_obj.parameters: - # skip the parameter when varied hists are not needed - if not self.inference_model_inst.require_shapes_for_parameter(param_obj): + h_dict = _inp["collection"][0]["hists"][variable_inst.name].load(formatter="pickle").copy() + for region in h_dict.keys(): + if region not in hists: hists[region] = {} + # skip when the dataset is already known to not contain any sub process + if not any(map(dataset_inst.has_process, sub_process_insts)): + 
self.logger.warning( + f"dataset '{dataset}' does not contain process '{process_inst.name}' " + "or any of its subprocesses which indicates a misconfiguration in the " + f"inference model '{self.inference_model}'", + ) continue - # store the varied hists - hists[proc_obj_name][param_obj.name] = {} - for d in ["up", "down"]: - shift_inst = self.config_inst.get_shift(f"{param_obj.config_shift_source}_{d}") - hists[proc_obj_name][param_obj.name][d] = h_proc[ - {"shift": hist.loc(shift_inst.id)} + # open the histogram and work on a copy + h = h_dict[region] + # axis selections + h = h[{ + "process": [ + hist.loc(p.id) + for p in sub_process_insts + if p.id in h.axes["process"] ] - + }] + + # axis reductions + h = h[{"process": sum}] + if process_inst in hists[region]: + hists[region][process_inst] += h + else: + hists[region][process_inst] = h + + # there must be a histogra + if hists[region][process_inst] is None: + raise Exception(f"no histograms found for process '{process_inst.name}'") + if self.hist_hooks and category_inst.aux: #Assume that aux exists only for signal regions since it contains the information about application and determination regions + hists = self.invoke_hist_hooks(hists,category_inst) + else: + hists = hists[category_inst.name] + # prepare the hists to be used in the datacard writer + datacard_hists = OrderedDict() + for combine_proc, proc_name in self.inference_model_inst.proc_map.items(): + process_inst = [the_proc for the_proc in hists.keys() if the_proc.name == proc_name] + if len(process_inst) and not (hists[process_inst[0]].empty()): + # get the histogram for the process + datacard_hists[combine_proc] = OrderedDict() + nominal_shift_inst = self.config_inst.get_shift("nominal") + datacard_hists[combine_proc]["nominal"] = hists[process_inst[0]][{"shift": hist.loc(nominal_shift_inst.id)}] + # add data: + data_proc = [the_proc for the_proc in hists.keys() if the_proc.name == 'data'] + datacard_hists['data'] = OrderedDict() + nominal_shift_inst = self.config_inst.get_shift("nominal") + datacard_hists['data']["nominal"] = hists[data_proc[0]][{"shift": hist.loc(nominal_shift_inst.id)}] + # forward objects to the datacard writer outputs = self.output() - writer = DatacardWriter(self.inference_model_inst, {cat_obj.name: hists}) + writer = DatacardWriter(self.inference_model_inst, {cat_obj.name: datacard_hists}) with outputs["card"].localize("w") as tmp_card, outputs["shapes"].localize("w") as tmp_shapes: writer.write(tmp_card.abspath, tmp_shapes.abspath, shapes_path_ref=outputs["shapes"].basename) diff --git a/columnflow/tasks/data_driven_methods.py b/columnflow/tasks/data_driven_methods.py new file mode 100644 index 000000000..80162dacb --- /dev/null +++ b/columnflow/tasks/data_driven_methods.py @@ -0,0 +1,803 @@ + +""" +Task to produce and merge histograms. 
+""" + +from __future__ import annotations + +import luigi +import law + +from columnflow.tasks.framework.base import Requirements, AnalysisTask, DatasetTask, wrapper_factory +from columnflow.tasks.framework.mixins import ( + CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin, VariablesMixin, + ShiftSourcesMixin, WeightProducerMixin, ChunkedIOMixin, DatasetsProcessesMixin, CategoriesMixin +) +from columnflow.tasks.framework.plotting import ProcessPlotSettingMixin + +from columnflow.tasks.framework.remote import RemoteWorkflow +from columnflow.tasks.framework.parameters import last_edge_inclusive_inst +from columnflow.tasks.reduction import ReducedEventsUser +from columnflow.tasks.production import ProduceColumns +from columnflow.tasks.ml import MLEvaluation +from columnflow.util import dev_sandbox, DotDict + + +class PrepareFakeFactorHistograms( + CategoriesMixin, + WeightProducerMixin, + MLModelsMixin, + ProducersMixin, + ReducedEventsUser, + ChunkedIOMixin, + law.LocalWorkflow, + RemoteWorkflow, +): + last_edge_inclusive = last_edge_inclusive_inst + + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + + # upstream requirements + reqs = Requirements( + ReducedEventsUser.reqs, + RemoteWorkflow.reqs, + ProduceColumns=ProduceColumns, + ) + + # strategy for handling missing source columns when adding aliases on event chunks + missing_column_alias_strategy = "original" + + # names of columns that contain category ids + # (might become a parameter at some point) + category_id_columns = {"category_ids"} + + # register sandbox and shifts found in the chosen weight producer to this task + register_weight_producer_sandbox = True + register_weight_producer_shifts = True + + @law.util.classproperty + def mandatory_columns(cls) -> set[str]: + return set(cls.category_id_columns) | {"process_id"} + + # def create_branch_map(self): + # # create a dummy branch map so that this task could be submitted as a job + # return {0: None} + + def workflow_requires(self): + reqs = super().workflow_requires() + + # require the full merge forest + reqs["events"] = self.reqs.ProvideReducedEvents.req(self) + + if not self.pilot: + if self.producer_insts: + reqs["producers"] = [ + self.reqs.ProduceColumns.req(self, producer=producer_inst.cls_name) + for producer_inst in self.producer_insts + if producer_inst.produced_columns + ] + + # add weight_producer dependent requirements + reqs["weight_producer"] = law.util.make_unique(law.util.flatten(self.weight_producer_inst.run_requires())) + + return reqs + + def requires(self): + reqs = {"events": self.reqs.ProvideReducedEvents.req(self)} + + if self.producer_insts: + reqs["producers"] = [ + self.reqs.ProduceColumns.req(self, producer=producer_inst.cls_name) + for producer_inst in self.producer_insts + if producer_inst.produced_columns + ] + + # add weight_producer dependent requirements + reqs["weight_producer"] = law.util.make_unique(law.util.flatten(self.weight_producer_inst.run_requires())) + + return reqs + + workflow_condition = ReducedEventsUser.workflow_condition.copy() + + @workflow_condition.output + def output(self): + return {"hists": self.target(f"ff_hist_{self.branch}.pickle")} + @law.decorator.notify + @law.decorator.log + @law.decorator.localize(input=True, output=False) + @law.decorator.safe_output + def run(self): + import hist + import numpy as np + import awkward as ak + from columnflow.columnar_util import ( + Route, update_ak_array, add_ak_aliases, has_ak_column, attach_coffea_behavior, EMPTY_FLOAT + ) + 
from columnflow.hist_util import fill_hist + # prepare inputs + inputs = self.input() + + # declare output: dict of histograms + histograms = {} + + # run the weight_producer setup + producer_reqs = self.weight_producer_inst.run_requires() + reader_targets = self.weight_producer_inst.run_setup(producer_reqs, luigi.task.getpaths(producer_reqs)) + + # create a temp dir for saving intermediate files + tmp_dir = law.LocalDirectoryTarget(is_tmp=True) + tmp_dir.touch() + + # get shift dependent aliases + aliases = self.local_shift_inst.x("column_aliases", {}) + ff_variables = [var.var_route for var in self.config_inst.x.fake_factor_method.axes.values()] + # define columns that need to be read + + read_columns = {Route("process_id")} + read_columns |= set(map(Route, self.category_id_columns)) + read_columns |= set(self.weight_producer_inst.used_columns) + read_columns |= set(map(Route, aliases.values())) + read_columns |= set(map(Route, ff_variables)) + # empty float array to use when input files have no entries + empty_f32 = ak.Array(np.array([], dtype=np.float32)) + + # iterate over chunks of events and diffs + file_targets = [inputs["events"]["events"]] + if self.producer_insts: + file_targets.extend([inp["columns"] for inp in inputs["producers"]]) + + # prepare inputs for localization + with law.localize_file_targets( + [*file_targets, *reader_targets.values()], + mode="r", + ) as inps: + + for (events, *columns), pos in self.iter_chunked_io( + [inp.abspath for inp in inps], + source_type=len(file_targets) * ["awkward_parquet"] + [None] * len(reader_targets), + read_columns=(len(file_targets) + len(reader_targets)) * [read_columns], + chunk_size=self.weight_producer_inst.get_min_chunk_size(), + ): + # optional check for overlapping inputs + if self.check_overlapping_inputs: + self.raise_if_overlapping([events] + list(columns)) + # add additional columns + events = update_ak_array(events, *columns) + # add aliases + events = add_ak_aliases( + events, + aliases, + remove_src=True, + missing_strategy=self.missing_column_alias_strategy, + ) + + # attach coffea behavior aiding functional variable expressions + events = attach_coffea_behavior(events) + # build the full event weight + if hasattr(self.weight_producer_inst, "skip_func") and not self.weight_producer_inst.skip_func(): + events, weight = self.weight_producer_inst(events) + else: + weight = ak.Array(np.ones(len(events), dtype=np.float32)) + # define and fill histograms, taking into account multiple axes + category_ids = ak.concatenate( + [Route(c).apply(events) for c in self.category_id_columns], + axis=-1,) + sr_names = self.categories + for sr_name in sr_names: + the_sr = self.config_inst.get_category(sr_name) + regions = [sr_name] + if the_sr.aux: + for the_key in the_sr.aux.keys(): + if (the_key == 'abcd_regs') or (the_key == 'ff_regs'): + regions += list(the_sr.aux[the_key].values()) + else: + raise KeyError(f"Application and determination regions are not found for {the_sr}. 
\n Check aux field of the category map!") + + for region in regions: + #by accessing the list of categories we check if the category with this name exists + cat = self.config_inst.get_category(region) + + # get variable instances + mask = ak.any(category_ids == cat.id, axis = 1) + masked_events = events[mask] + masked_weight = weight[mask] + + h = (hist.Hist.new.IntCat([], name="process", growth=True)) + for (var_name, var_axis) in self.config_inst.x.fake_factor_method.axes.items(): + h = eval(f'h.{var_axis.ax_str}') + + h = h.Weight() + # broadcast arrays so that each event can be filled for all its categories + + fill_data = { + "process": masked_events.process_id, + "weight" : masked_weight, + } + for (var_name, var_axis) in self.config_inst.x.fake_factor_method.axes.items(): + route = Route(var_axis.var_route) + if len(masked_events) == 0 and not has_ak_column(masked_events, route): + values = empty_f32 + else: + values = route.apply(masked_events) + if values.ndim != 1: values = ak.firsts(values,axis=1) + values = ak.fill_none(values, EMPTY_FLOAT) + + if var_name == 'n_jets': values = ak.where (values > 2, + 2 * ak.ones_like(values), + values) + + if 'Int' in var_axis.ax_str: values = ak.values_astype(values, np.int64) + fill_data[var_name] = values + # fill it + fill_hist( + h, + fill_data, + ) + if cat.name not in histograms.keys(): + histograms[cat.name] = h + else: + histograms[cat.name] +=h + + # merge output files + self.output()["hists"].dump(histograms, formatter="pickle") + + + + +# overwrite class defaults +check_overlap_tasks = law.config.get_expanded("analysis", "check_overlapping_inputs", [], split_csv=True) +PrepareFakeFactorHistograms.check_overlapping_inputs = ChunkedIOMixin.check_overlapping_inputs.copy( + default=PrepareFakeFactorHistograms.task_family in check_overlap_tasks, + add_default_to_description=True, +) + + +PrepareFakeFactorHistogramsWrapper = wrapper_factory( + base_cls=AnalysisTask, + require_cls=PrepareFakeFactorHistograms, + enable=["configs", "skip_configs", "datasets", "skip_datasets", "shifts", "skip_shifts"], +) + + +class MergeFakeFactorHistograms( + #VariablesMixin, + #WeightProducerMixin, + #MLModelsMixin, + #ProducersMixin, + #SelectorStepsMixin, + #CalibratorsMixin, + DatasetTask, + law.LocalWorkflow, + RemoteWorkflow, +): + only_missing = luigi.BoolParameter( + default=False, + description="when True, identify missing variables first and only require histograms of " + "missing ones; default: False", + ) + remove_previous = luigi.BoolParameter( + default=False, + significant=False, + description="when True, remove particlar input histograms after merging; default: False", + ) + + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + + # upstream requirements + reqs = Requirements( + RemoteWorkflow.reqs, + PrepareFakeFactorHistograms=PrepareFakeFactorHistograms, + ) + + @classmethod + def req_params(cls, inst: AnalysisTask, **kwargs) -> dict: + _prefer_cli = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"variables"} + kwargs["_prefer_cli"] = _prefer_cli + return super().req_params(inst, **kwargs) + + def create_branch_map(self): + # create a dummy branch map so that this task could be submitted as a job + return {0: None} + + # def _get_variables(self): + # if self.is_workflow(): + # return self.as_branch()._get_variables() + + # variables = self.variables + + # # optional dynamic behavior: determine not yet created variables and require only those + # if self.only_missing: + # missing = 
self.output().count(existing=False, keys=True)[1] + # variables = sorted(missing, key=variables.index) + + # return variables + + def workflow_requires(self): + reqs = super().workflow_requires() + + if not self.pilot: + #variables = self._get_variables() + #if variables: + reqs["hists"] = self.reqs.PrepareFakeFactorHistograms.req_different_branching( + self, + branch=-1, + #variables=tuple(variables), + ) + + return reqs + + def requires(self): + #variables = self._get_variables() + #if not variables: + # return [] + + return self.reqs.PrepareFakeFactorHistograms.req_different_branching( + self, + branch=-1, + #variables=tuple(variables), + workflow="local", + ) + + def output(self): + return {"hists": self.target(f"merged_ff_hist.pickle")} + + @law.decorator.notify + @law.decorator.log + def run(self): + # preare inputs and outputs + inputs = self.input()["collection"] + outputs = self.output() + + # load input histograms + hists = [ + inp["hists"].load(formatter="pickle") + for inp in self.iter_progress(inputs.targets.values(), len(inputs), reach=(0, 50)) + ] + cats = list(hists[0].keys()) + get_hists = lambda hists, cat : [h[cat] for h in hists] + # create a separate file per output variable + merged_hists = {} + self.publish_message(f"merging {len(hists)} histograms for {self.dataset}") + for the_cat in cats: + h = get_hists(hists, the_cat) + merged_hists[the_cat] = sum(h[1:], h[0].copy()) + outputs["hists"].dump(merged_hists, formatter="pickle") + # optionally remove inputs + if self.remove_previous: + inputs.remove() + +MergeFakeFactorHistogramsWrapper = wrapper_factory( + base_cls=AnalysisTask, + require_cls=MergeFakeFactorHistograms, + enable=["configs", "skip_configs", "datasets", "skip_datasets", "shifts", "skip_shifts"], +) + +class ComputeFakeFactors( + DatasetsProcessesMixin, + CategoriesMixin, + WeightProducerMixin, + ProducersMixin, +): + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + + only_missing = luigi.BoolParameter( + default=False, + description="when True, identify missing variables first and only require histograms of " + "missing ones; default: False", + ) + remove_previous = luigi.BoolParameter( + default=False, + significant=False, + description="when True, remove particlar input histograms after merging; default: False", + ) + + # upstream requirements + reqs = Requirements( + RemoteWorkflow.reqs, + MergeFakeFactorHistograms=MergeFakeFactorHistograms, + ) + + def store_parts(self): + parts = super().store_parts() + parts.insert_before("version", "datasets", f"datasets_{self.datasets_repr}") + return parts + + @classmethod + def req_params(cls, inst: AnalysisTask, **kwargs) -> dict: + _prefer_cli = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"variables"} + kwargs["_prefer_cli"] = _prefer_cli + return super().req_params(inst, **kwargs) + + def create_branch_map(self): + # create a dummy branch map so that this task could be submitted as a job + return {0: None} + + return reqs + def requires(self): + return { + d: self.reqs.MergeFakeFactorHistograms.req_different_branching( + self, + branch=-1, + dataset=d, + workflow="local", + ) + for d in self.datasets + } + + def output(self): + year = self.config_inst.campaign.aux['year'] + tag = self.config_inst.campaign.aux['tag'] + channel = self.config_inst.channels.get_first().name + return {"ff_json": self.target('_'.join(('fake_factors', + channel, + str(year), + tag)) + '.json'), + "plots": {'_'.join((ff_type, + syst, + f'n_jets_{str(nj)}')): 
self.target(f"fake_factor_{ff_type}_{syst}_njets_{str(nj)}.png") + for syst in ['nominal', 'up', 'down'] + for ff_type in ['qcd','wj'] + for nj in [0,1,2]}, + "plots1d": {'_'.join((ff_type, + str(dm), + str(nj))): self.target(f"fake_factor_{ff_type}_PNet_dm_{str(dm)}_njets_{str(nj)}.png") + for ff_type in ['qcd','wj'] + for dm in [0,1,2,10,11] + for nj in [0,1,2]}, + "fitres": self.target('_'.join(('fitres', + channel, + str(year), + tag)) + '.json'), + } + + @law.decorator.log + def run(self): + import hist + import numpy as np + from scipy.optimize import curve_fit + from scipy.special import erf + import matplotlib.pyplot as plt + import correctionlib.schemav2 as cs + from numpy import exp + plt.figure(dpi=200) + plt.rcParams.update({ + "text.usetex": True, + "font.family": "monospace", + "font.monospace": 'Computer Modern Typewriter' + }) + # preare inputs and outputs + inputs = self.input() + outputs = self.output() + + hists_by_dataset = [] + merged_hists = {} + for (dataset_name, dataset) in inputs.items(): + files = dataset['collection'][0] + + # load input histograms per dataset + input_chunked_hists = [] + input_chunked_hists = [f.load(formatter='pickle') for f in files.values()] + + for hists in input_chunked_hists: + for the_cat, the_hist in hists.items(): + if the_cat not in merged_hists.keys(): + merged_hists[the_cat] = [] + merged_hists[the_cat].append(the_hist) + else: + merged_hists[the_cat].append(the_hist) + + #merge histograms + mc_hists = {} + data_hists = {} + #devide between data and mc + for the_cat, h_list in merged_hists.items(): + for the_hist in h_list: + for proc_name in self.config_inst.processes.names(): + proc = self.config_inst.processes.get(proc_name) + if proc.id in the_hist.axes["process"]: + h = the_hist.copy() + h = h[{"process": hist.loc(proc.id)}] + if proc.is_mc and not proc.has_tag("signal"): + if the_cat in mc_hists: mc_hists[the_cat] += h + else: mc_hists[the_cat] = h + if proc.is_data: + if the_cat in data_hists: data_hists[the_cat] += h + else: data_hists[the_cat] = h + + def eval_formula(formula_str, popt,make_rounding=False): + for i,p in enumerate(popt): + if make_rounding: + formula_str = formula_str.replace(f'p{i}', '{:.3e}'.format(p)) + else: + formula_str = formula_str.replace(f'p{i}',str(p)) + return formula_str + + #Function that performs the calculation of t + def get_ff_corr(self, h_data, h_mc, dr_num, dr_den, name='ff_hist', label='ff_hist'): + + def get_single_cat(self, h, reg_name): + cat_name = self.config_inst.get_category(self.categories[0]).aux['ff_regs'][reg_name] + return h[cat_name] + data_num = get_single_cat(self, h_data, dr_num) + data_den = get_single_cat(self, h_data, dr_den) + mc_num = get_single_cat(self, h_mc, dr_num) + mc_den = get_single_cat(self, h_mc, dr_den) + print(name) + for nj in [0,1,2]: + for dm in [0,1,2,10,11]: + print(f'DM {dm} Nj {nj}') + print(f"data_num: {data_num[{'tau_dm_pnet': hist.loc(dm), 'n_jets': hist.loc(nj)}].values()}") + print(f"data_den: {data_den[{'tau_dm_pnet': hist.loc(dm), 'n_jets': hist.loc(nj)}].values()}") + print(f"mc_num: {mc_num[{'tau_dm_pnet': hist.loc(dm), 'n_jets': hist.loc(nj)}].values()}") + print(f"mc_den: {mc_den[{'tau_dm_pnet': hist.loc(dm), 'n_jets': hist.loc(nj)}].values()}") + num = data_num.values() - mc_num.values() + + den = data_den.values() - mc_den.values() + ff_val = np.where((num > 0) & (den > 0), + num / np.maximum(den, 1), + -1) + def rel_err(x): + return x.variances()/np.maximum(x.values()**2, 1) + + ff_err = ff_val * ((data_num.variances() + 
mc_num.variances())**0.5 / np.abs(num) + (data_den.variances() + mc_den.variances())**0.5 / np.abs(den)) + + ff_err[ff_val < 0] = 1 + h = hist.Hist.new + for (var_name, var_axis) in self.config_inst.x.fake_factor_method.axes.items(): + h = eval(f'h.{var_axis.ax_str}') + axes = list(h.axes[1:]) + h = h.StrCategory(['nominal', 'up', 'down'], name='syst', label='Statistical uncertainty of the fake factor') + ff_raw = h.Weight() + ff_raw.view().value[...,0] = ff_val + ff_raw.view().variance[...,0] = ff_err**2 + ff_raw.name = name + '_raw' + ff_raw.label = label + '_raw' + + def get_fitf(dm): + if dm==0: + formula_str = 'p0+p1*x+p2*x*x' + def fitf(x,p0,p1,p2): + return eval(formula_str) + else: + formula_str = 'p0+p1*exp(-p2*x)' + def fitf(x,p0,p1,p2): + from numpy import exp + return eval(formula_str) + return fitf, formula_str + + def get_jac(dm): + if dm==0: + def jac(x,p): + from numpy import array + return array([ 1., x, x**2]) + else: + def jac(x,p): + from numpy import array,exp,outer + ders=array([ 1., + exp(-p[2]*x), + -1*p[1]*x*exp(-p[2]*x)]) + return ders + return jac + + ff_fitted = ff_raw.copy().reset() + ff_fitted.name = name + ff_fitted.label = label + + fitres = {} + dm_axis = ff_raw.axes['tau_dm_pnet'] + n_jets_axis = ff_raw.axes['n_jets'] + + for nj in n_jets_axis: + if nj not in fitres.keys(): fitres[nj] = {} + for dm in dm_axis: + if dm not in fitres[nj].keys(): fitres[nj][dm] = {} + + + + + h1d = ff_raw[{'tau_dm_pnet': hist.loc(dm), + 'n_jets': hist.loc(nj), + 'syst': hist.loc('nominal')}] + mask = h1d.values() > 0 + x = h1d.axes[0].centers + if np.sum(mask) < 2: + y = np.zeros_like(x) + y_err = np.ones_like(x) + x_masked = x + else: + y = h1d.values()[mask] + y_err = (h1d.variances()[mask])**0.5 + x_masked = x[mask] + + fitf, formula_str = get_fitf(dm) + if dm==0: + the_bounds = ([-10,-5,-1],[10,5,1]) + else: + the_bounds = ([-0.5, -3, 0],[0.5,3,0.1]) + popt, pcov, infodict, mesg, ier = curve_fit(fitf, + x_masked, + y, + sigma=y_err, + bounds=the_bounds, + absolute_sigma=True, + full_output=True + ) + fitres[nj][dm]['chi2'] = sum((infodict['fvec'])**2) + fitres[nj][dm]['ndf'] = len(y) - len(popt) + fitres[nj][dm]['popt'] = popt + fitres[nj][dm]['pcov'] = pcov + fitres[nj][dm]['x_max'] = np.max(x_masked) + + fitres[nj][dm]['jac'] = get_jac(dm) + fitres[nj][dm]['name'] = name + fitres[nj][dm]['fitf'] = fitf + fitres[nj][dm]['fitf_str'] = formula_str + + for c, shift_name in enumerate(['down', 'nominal', 'up']): # if down then c=-1, if up c=+1, nominal => c=0 + ff_fitted.view().value[:, + ff_fitted.axes[1].index(dm), + ff_fitted.axes[2].index(nj), + ff_fitted.axes[3].index(shift_name)] = fitf(x, *popt + (c-1) * np.sqrt(np.diag(pcov))) + + return ff_raw, ff_fitted, fitres + + wj_raw, wj_fitted, wj_fitres = get_ff_corr(self, + data_hists, + mc_hists, + dr_num = 'dr_num_wj', + dr_den = 'dr_den_wj', + name='ff_wjets', + label='Fake factor W+jets') + + qcd_raw, qcd_fitted, qcd_fitres = get_ff_corr(self, + data_hists, + mc_hists, + dr_num = 'dr_num_qcd', + dr_den = 'dr_den_qcd', + name='ff_qcd', + label='Fake factor QCD') + + + corr_list = [] + for fitres_per_proc in [wj_fitres, qcd_fitres]: + nj_categories = [] + for nj, fitres_per_nj in fitres_per_proc.items(): + single_nj = [] + for dm, fitres in fitres_per_nj.items(): + x_max = fitres['x_max'] + fitf = fitres['fitf'] + popt = fitres['popt'] + fitf_str = eval_formula(fitres['fitf_str'], popt) + fx_max = np.maximum(fitf(x_max,*popt),0) + single_nj.append(cs.CategoryItem( + key=dm, + value=cs.Formula( + nodetype="formula", + 
variables=["tau_pt"], + parser="TFormula", + expression=f'({fitf_str})*((x-{x_max})<0) + ({fx_max})*((x-{x_max})>=0)', + ))) + nj_categories.append(cs.CategoryItem( + key=nj, + value=cs.Category( + nodetype="category", + input="tau_dm_pnet", + content=single_nj, + ))) + corr_list.append(cs.Correction( + name=fitres_per_proc[0][0]['name'], + description=f"fake factor correcton for {fitres_per_proc[0][0]['name'].split('_')[1]}", + version=2, + inputs=[ + cs.Variable(name="tau_pt", type="real",description="pt of tau"), + cs.Variable(name="tau_dm_pnet", type="int", description="PNet decay mode of tau"), + cs.Variable(name="n_jets", type="int", description="Number of jets with pt > 20 GeV and eta < 4.7"), + ], + output=cs.Variable(name="weight", type="real", description="Multiplicative event weight"), + data=cs.Category( + nodetype="category", + input="n_jets", + content=nj_categories, + ) + )) + cset = cs.CorrectionSet( + schema_version=2, + description="Fake factors", + corrections=corr_list + ) + self.output()['ff_json'].dump(cset.json(exclude_unset=True), formatter="json") + + chi2_string = 'type nj dm chi2 ndf,' + for fitres_per_proc in [wj_fitres, qcd_fitres]: + for dm, fitres_per_dm in fitres_per_proc.items(): + for nj, fitres in fitres_per_dm.items(): + chi2_string += ' '.join((fitres['name'], + str(nj), + str(dm), + str(fitres['chi2']), + str(fitres['ndf']))) + chi2_string += ',' + self.output()['fitres'].dump(chi2_string, formatter="json") + + #Plot fake factors: + for h_name in ['wj', 'qcd']: + h_raw = eval(f'{h_name}_raw') + h_fitted = eval(f'{h_name}_fitted') + fitres_dict = eval(f'{h_name}_fitres') + dm_axis = h_raw.axes['tau_dm_pnet'] + nj_axis = h_raw.axes['n_jets'] + for nj in nj_axis: + print(f"Plotting 2d map for n jets = {nj}") + fig, ax = plt.subplots(figsize=(12, 8)) + + single2d_h = h_raw[{'n_jets': hist.loc(nj), + 'syst': hist.loc('nominal')}] + pcm = ax.pcolormesh(*np.meshgrid(*single2d_h.axes.edges), single2d_h.view().value.T, cmap="viridis", vmin=0, vmax=0.5) + ax.set_yticks(dm_axis.centers, labels=list(map(dm_axis.bin, range(dm_axis.size)))) + plt.colorbar(pcm, ax=ax) + plt.xlabel(single2d_h.axes.label[0]) + plt.ylabel(single2d_h.axes.label[1]) + plt.title(single2d_h.label) + + self.output()['plots']['_'.join((h_name,'nominal',f'n_jets_{str(nj)}'))].dump(fig, formatter="mpl") + for dm in dm_axis: + print(f"Plotting 1d plot for n jets = {nj}, dm = {dm}") + h1d = h_raw[{'tau_dm_pnet': hist.loc(dm), + 'n_jets': hist.loc(nj), + 'syst': hist.loc('nominal')}] + hfit = h_fitted[{'tau_dm_pnet': hist.loc(dm), + 'n_jets': hist.loc(nj),}] + fig, ax = plt.subplots(figsize=(8, 6)) + mask = h1d.counts() > 0 + if np.sum(mask) > 0: + x = h1d.axes[0].centers[mask] + y = h1d.counts()[mask] + xerr = (np.diff(h1d.axes[0]).flatten()/2.)[mask], + yerr = np.sqrt(h1d.variances()).flatten()[mask], + else: + x = h1d.axes[0].centers + y = np.zeros_like(x) + xerr = (np.diff(h1d.axes[0]).flatten()/2.) 
+ yerr = np.ones_like(y), + + ax.errorbar(x, y, xerr = xerr, yerr = yerr, + label=f"PNet decay mode = {dm}", + marker='o', + fmt='o', + line=None, color='#2478B7', capsize=4) + x_fine = np.linspace(x[0],x[-1],num=30) + fitres = fitres_dict[nj][dm] + popt = fitres['popt'] + pcov = fitres['pcov'] + jac = fitres['jac'] + def err(x,jac,pcov,popt): + from numpy import sqrt,einsum,abs + return sqrt(abs(einsum('i,ij,j',jac(x,popt).T,pcov,jac(x,popt)))) + + import functools + err_y = list(map(functools.partial(err, jac=jac,pcov=pcov,popt=popt), x_fine)) + + y_fitf = fitres['fitf'](x_fine,*popt) + y_fitf_up = fitres['fitf'](x_fine,*popt) + err_y + y_fitf_down = fitres['fitf'](x_fine,*(popt)) - err_y + + ax.plot(x_fine, + y_fitf, + color='#FF867B') + ax.fill_between(x_fine, y_fitf_up, y_fitf_down, color='#83d55f', alpha=0.5) + ax.set_ylabel('Fake Factor') + ax.set_xlabel('Tau pT [GeV]') + ax.set_title(f'Jet Fake Factors : Tau PNet Decay Mode {dm}, Njets {nj}') + ax.annotate(rf"$\frac{{\chi^2}}{{ndf}} = \frac{{{np.round(fitres['chi2'],2)}}}{{{fitres['ndf']}}}$", + (0.8, 0.75), + xycoords='axes fraction', + fontsize=20) + + formula_str = eval_formula(fitres['fitf_str'],popt, make_rounding=True) + + ax.annotate('y=' + formula_str, + (0.01, 0.95), + xycoords='axes fraction', + fontsize=12) + + self.output()['plots1d']['_'.join((h_name,str(dm),str(nj)))].dump(fig, formatter="mpl") \ No newline at end of file diff --git a/columnflow/tasks/framework/mixins.py b/columnflow/tasks/framework/mixins.py index 0de908b80..35549393b 100644 --- a/columnflow/tasks/framework/mixins.py +++ b/columnflow/tasks/framework/mixins.py @@ -2447,12 +2447,12 @@ class HistHookMixin(ConfigTask): "default: empty", ) - def invoke_hist_hooks(self, hists: dict) -> dict: + def invoke_hist_hooks(self, hists: dict, category_inst: od.Category) -> dict: """ Invoke hooks to update histograms before plotting. 
""" if not self.hist_hooks: - return hists + return hists[category_inst.name] for hook in self.hist_hooks: if hook in (None, "", law.NO_STR): @@ -2470,7 +2470,7 @@ def invoke_hist_hooks(self, hists: dict) -> dict: # invoke it self.publish_message(f"invoking hist hook '{hook}'") - hists = func(self, hists) + hists = func(self, hists, category_inst) return hists diff --git a/columnflow/tasks/histograms.py b/columnflow/tasks/histograms.py index 070e9c49d..d7603112c 100644 --- a/columnflow/tasks/histograms.py +++ b/columnflow/tasks/histograms.py @@ -22,7 +22,6 @@ from columnflow.util import dev_sandbox from columnflow.hist_util import create_hist_from_variables - class CreateHistograms( VariablesMixin, WeightProducerMixin, @@ -58,7 +57,7 @@ class CreateHistograms( @law.util.classproperty def mandatory_columns(cls) -> set[str]: - return set(cls.category_id_columns) | {"process_id"} + return set(cls.category_id_columns) | {"process_id", "ff_weight*"} def workflow_requires(self): reqs = super().workflow_requires() @@ -143,6 +142,9 @@ def run(self): read_columns = {Route("process_id")} read_columns |= set(map(Route, self.category_id_columns)) read_columns |= set(self.weight_producer_inst.used_columns) + read_columns |= set(map(Route, ['_'.join((the_name,the_shift)) + for the_name in self.config_inst.x.fake_factor_method.columns + for the_shift in self.config_inst.x.fake_factor_method.shifts])) read_columns |= set(map(Route, aliases.values())) read_columns |= { Route(inp) @@ -201,72 +203,85 @@ def run(self): # attach coffea behavior aiding functional variable expressions events = attach_coffea_behavior(events) - # build the full event weight if hasattr(self.weight_producer_inst, "skip_func") and not self.weight_producer_inst.skip_func(): events, weight = self.weight_producer_inst(events) else: weight = ak.Array(np.ones(len(events), dtype=np.float32)) + categories = self.config_inst.categories.names() + sr_names = [the_cat for the_cat in categories if 'sr' in the_cat] # define and fill histograms, taking into account multiple axes - for var_key, var_names in self.variable_tuples.items(): - # get variable instances - variable_insts = [self.config_inst.get_variable(var_name) for var_name in var_names] - - if var_key not in histograms: - # create the histogram in the first chunk - histograms[var_key] = create_hist_from_variables( - *variable_insts, - int_cat_axes=("category", "process", "shift"), - ) - - # mask events and weights when selection expressions are found - masked_events = events - masked_weights = weight - for variable_inst in variable_insts: - sel = variable_inst.selection - if sel == "1": - continue - if not callable(sel): - raise ValueError( - f"invalid selection '{sel}', for now only callables are supported", + for sr_name in sr_names: + #iterate over the regions needed for calculation of the ff_method + the_sr = self.config_inst.get_category(sr_name) + regions = [sr_name] + if the_sr.aux: + for the_key in the_sr.aux.keys(): + if (the_key == 'abcd_regs') or (the_key == 'ff_regs'): + regions += list(the_sr.aux[the_key].values()) + for region in regions: + #by accessing the list of categories we check if the category with this name exists + cat = self.config_inst.get_category(region) + if cat.name not in histograms.keys(): histograms[cat.name] = {} + for var_key, var_names in self.variable_tuples.items(): + # get variable instances + variable_insts = [self.config_inst.get_variable(var_name) for var_name in var_names] + + if var_key not in histograms[cat.name].keys(): + # create the 
histogram in the first chunk + histograms[cat.name][var_key] = create_hist_from_variables( + *variable_insts, + int_cat_axes=("process", "shift"), + ) + # mask events and weights when selection expressions are found + masked_events = events + + if 'apply_ff' in cat.aux.keys(): + if cat.aux['apply_ff'] == 'wj': + self.publish_message(f"applying FF weights: ff_weight_wj_nominal, category: {cat.name}") + masked_weights = weight * events.ff_weight_wj_nominal + elif cat.aux['apply_ff'] == 'qcd': + self.publish_message(f"applying FF weights: ff_weight_qcd_nominal, category: {cat.name}") + masked_weights = weight * events.ff_weight_qcd_nominal + else: + masked_weights = weight + else: + masked_weights = weight + + category_ids = ak.concatenate( + [Route(c).apply(masked_events) for c in self.category_id_columns], + axis=-1, + ) + mask = ak.any(category_ids == cat.id, axis = 1) + masked_events = masked_events[mask] + masked_weights = masked_weights[mask] + # broadcast arrays so that each event can be filled for all its categories + fill_data = { + "process": masked_events.process_id, + "shift": np.ones(len(masked_events), dtype=np.int32) * self.global_shift_inst.id, + "weight": masked_weights, + } + for variable_inst in variable_insts: + # prepare the expression + expr = variable_inst.expression + if isinstance(expr, str): + route = Route(expr) + def expr(masked_events, *args, **kwargs): + if len(masked_events) == 0 and not has_ak_column(masked_events, route): + return empty_f32 + return route.apply(masked_events, null_value=variable_inst.null_value) + # apply it + if variable_inst.name == "event": + fill_data[variable_inst.name] = np.sign(masked_events.event) + else: + fill_data[variable_inst.name] = expr(masked_events) + # fill it + fill_hist( + histograms[cat.name][var_key], + fill_data, + last_edge_inclusive=self.last_edge_inclusive, ) - mask = sel(masked_events) - masked_events = masked_events[mask] - masked_weights = masked_weights[mask] - - # merge category ids - category_ids = ak.concatenate( - [Route(c).apply(masked_events) for c in self.category_id_columns], - axis=-1, - ) - - # broadcast arrays so that each event can be filled for all its categories - fill_data = { - "category": category_ids, - "process": masked_events.process_id, - "shift": np.ones(len(masked_events), dtype=np.int32) * self.global_shift_inst.id, - "weight": masked_weights, - } - for variable_inst in variable_insts: - # prepare the expression - expr = variable_inst.expression - if isinstance(expr, str): - route = Route(expr) - def expr(events, *args, **kwargs): - if len(events) == 0 and not has_ak_column(events, route): - return empty_f32 - return route.apply(events, null_value=variable_inst.null_value) - # apply it - fill_data[variable_inst.name] = expr(masked_events) - - # fill it - fill_hist( - histograms[var_key], - fill_data, - last_edge_inclusive=self.last_edge_inclusive, - ) - # merge output files self.output()["hists"].dump(histograms, formatter="pickle") @@ -383,21 +398,21 @@ def run(self): inp["hists"].load(formatter="pickle") for inp in self.iter_progress(inputs.targets.values(), len(inputs), reach=(0, 50)) ] - + cats = list(hists[0].keys()) + variable_names = list(hists[0][cats[0]].keys()) + get_hists = lambda hists, cat, var : [h[cat][var] for h in hists] # create a separate file per output variable - variable_names = list(hists[0].keys()) for variable_name in self.iter_progress(variable_names, len(variable_names), reach=(50, 100)): - self.publish_message(f"merging histograms for '{variable_name}'") - - 
variable_hists = [h[variable_name] for h in hists] - merged = sum(variable_hists[1:], variable_hists[0].copy()) - outputs["hists"][variable_name].dump(merged, formatter="pickle") - + merged_hists = {} + for the_cat in cats: + self.publish_message(f"merging histograms for {variable_name}, category: {the_cat}") + variable_hists = get_hists(hists, the_cat, variable_name) + merged_hists[the_cat] = sum(variable_hists[1:], variable_hists[0].copy()) + outputs["hists"][variable_name].dump(merged_hists, formatter="pickle") # optionally remove inputs if self.remove_previous: inputs.remove() - MergeHistogramsWrapper = wrapper_factory( base_cls=AnalysisTask, require_cls=MergeHistograms, @@ -474,13 +489,18 @@ def run(self): self.publish_message(f"merging histograms for '{variable_name}'") # load hists + + variable_hists = [ coll["hists"].targets[variable_name].load(formatter="pickle") for coll in inputs.values() ] - - # merge and write the output - merged = sum(variable_hists[1:], variable_hists[0].copy()) + merged = {} + get_hists = lambda hists, cat : [h[cat] for h in hists] + for the_cat in variable_hists[0].keys(): + single_cat_hists = get_hists(variable_hists, the_cat) + merged[the_cat] = sum(single_cat_hists[1:], single_cat_hists[0].copy()) + outp.dump(merged, formatter="pickle") diff --git a/columnflow/tasks/plotting.py b/columnflow/tasks/plotting.py index b922684a8..71cbc5f27 100644 --- a/columnflow/tasks/plotting.py +++ b/columnflow/tasks/plotting.py @@ -111,44 +111,58 @@ def run(self): for dataset, inp in self.input().items(): dataset_inst = self.config_inst.get_dataset(dataset) h_in = inp["collection"][0]["hists"].targets[self.branch_data.variable].load(formatter="pickle") - + # loop and extract one histogram per process - for process_inst in process_insts: - # skip when the dataset is already known to not contain any sub process - if not any( - dataset_inst.has_process(sub_process_inst.name) - for sub_process_inst in sub_process_insts[process_inst] - ): - continue - - # select processes and reduce axis - h = h_in.copy() - h = h[{ - "process": [ - hist.loc(p.id) - for p in sub_process_insts[process_inst] - if p.id in h.axes["process"] - ], - }] - h = h[{"process": sum}] - - # add the histogram - if process_inst in hists: - hists[process_inst] += h - else: - hists[process_inst] = h - + for region in h_in.keys(): + if region not in hists: hists[region] = {} + for process_inst in process_insts: + # skip when the dataset is already known to not contain any sub process + if not any( + dataset_inst.has_process(sub_process_inst.name) + for sub_process_inst in sub_process_insts[process_inst] + ): + continue + + # select processes and reduce axis + h = h_in[region].copy() + h = h[{ + "process": [ + hist.loc(p.id) + for p in sub_process_insts[process_inst] + if p.id in h.axes["process"] + ], + }] + h = h[{"process": sum}] + + # add the histogram + if process_inst in hists[region]: + hists[region][process_inst] += h + else: + hists[region][process_inst] = h + + # there should be hists to plot + if not hists: raise Exception( "no histograms found to plot; possible reasons:\n" " - requested variable requires columns that were missing during histogramming\n" " - selected --processes did not match any value on the process axis of the input histogram", ) - - # update histograms using custom hooks - hists = self.invoke_hist_hooks(hists) - + if category_inst.aux: #Assume that aux exists only for signal regions since it contains the information about application and determination regions + if self.hist_hooks: 
+ hists = self.invoke_hist_hooks(hists,category_inst) + else: + hists = hists[category_inst.name] + else: + if 'dr' in category_inst.name: + hists = self.invoke_hist_hooks(hists,category_inst) + elif category_inst.name in hists.keys(): + hists = hists[category_inst.name] + else: + raise Exception( + f"no histograms found to plot for {category_inst.name}" + ) # add new processes to the end of the list for process_inst in hists: if process_inst not in process_insts: @@ -160,11 +174,6 @@ def run(self): h = hists[process_inst] # selections h = h[{ - "category": [ - hist.loc(c.id) - for c in leaf_category_insts - if c.id in h.axes["category"] - ], "shift": [ hist.loc(s.id) for s in plot_shifts @@ -172,11 +181,9 @@ def run(self): ], }] # reductions - h = h[{"category": sum}] # store _hists[process_inst] = h hists = _hists - # call the plot function fig, _ = self.call_plot_func( self.plot_function, @@ -213,6 +220,11 @@ def create_branch_map(self): def workflow_requires(self): reqs = super().workflow_requires() + + # no need to require merged histograms since each branch already requires them as a workflow + # if self.workflow == "local": + # reqs.pop("merged_hists", None) + return reqs def requires(self): diff --git a/columnflow/tasks/yields.py b/columnflow/tasks/yields.py index e7d26ca57..01ba92079 100644 --- a/columnflow/tasks/yields.py +++ b/columnflow/tasks/yields.py @@ -21,6 +21,245 @@ from columnflow.util import dev_sandbox, try_int +# class CreateYieldTable( +# DatasetsProcessesMixin, +# CategoriesMixin, +# WeightProducerMixin, +# ProducersMixin, +# SelectorStepsMixin, +# CalibratorsMixin, +# law.LocalWorkflow, +# RemoteWorkflow, +# ): +# sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + +# table_format = luigi.Parameter( +# default="fancy_grid", +# significant=False, +# description="format of the yield table; accepts all formats of the tabulate package; " +# "default: fancy_grid", +# ) +# number_format = luigi.Parameter( +# default="pdg", +# significant=False, +# description="rounding format of each number in the yield table; accepts all formats " +# "understood by scinum.Number.str(), e.g. 
'pdg', 'publication', '%.1f' or an integer " +# "(number of signficant digits); default: pdg", +# ) +# skip_uncertainties = luigi.BoolParameter( +# default=False, +# significant=False, +# description="when True, uncertainties are not displayed in the table; default: False", +# ) +# normalize_yields = luigi.ChoiceParameter( +# choices=(law.NO_STR, "per_process", "per_category", "all"), +# default=law.NO_STR, +# significant=False, +# description="string parameter to define the normalization of the yields; " +# "choices: '', per_process, per_category, all; empty default", +# ) +# output_suffix = luigi.Parameter( +# default=law.NO_STR, +# description="Adds a suffix to the output name of the yields table; empty default", +# ) + +# # upstream requirements +# reqs = Requirements( +# RemoteWorkflow.reqs, +# MergeHistograms=MergeHistograms, +# ) + +# # dummy branch map +# def create_branch_map(self): +# return [0] + +# def requires(self): +# return { +# d: self.reqs.MergeHistograms.req( +# self, +# dataset=d, +# variables=("event",), +# _prefer_cli={"variables"}, +# ) +# for d in self.datasets +# } + +# def workflow_requires(self): +# reqs = super().workflow_requires() + +# reqs["merged_hists"] = [ +# self.reqs.MergeHistograms.req( +# self, +# dataset=d, +# variables=("event",), +# _exclude={"branches"}, +# ) +# for d in self.datasets +# ] + +# return reqs + +# @classmethod +# def resolve_param_values(cls, params): +# params = super().resolve_param_values(params) + +# if "number_format" in params and try_int(params["number_format"]): +# # convert 'number_format' in integer if possible +# params["number_format"] = int(params["number_format"]) + +# return params + +# def output(self): +# suffix = "" +# if self.output_suffix and self.output_suffix != law.NO_STR: +# suffix = f"__{self.output_suffix}" + +# return { +# "table": self.target(f"table__proc_{self.processes_repr}__cat_{self.categories_repr}{suffix}.txt"), +# "yields": self.target(f"yields__proc_{self.processes_repr}__cat_{self.categories_repr}{suffix}.json"), +# } + +# @law.decorator.notify +# @law.decorator.log +# def run(self): +# import hist +# from tabulate import tabulate + +# inputs = self.input() +# outputs = self.output() + +# category_insts = list(map(self.config_inst.get_category, self.categories)) +# process_insts = list(map(self.config_inst.get_process, self.processes)) +# sub_process_insts = { +# proc: [sub for sub, _, _ in proc.walk_processes(include_self=True)] +# for proc in process_insts +# } + +# # histogram data per process +# hists = {} + +# with self.publish_step(f"Creating yields for processes {self.processes}, categories {self.categories}"): +# for dataset, inp in inputs.items(): +# dataset_inst = self.config_inst.get_dataset(dataset) + +# # load the histogram of the variable named "event" +# input_hists = inp["hists"]["event"].load(formatter="pickle") + +# # loop and extract one histogram per process +# for process_inst in process_insts: +# # skip when the dataset is already known to not contain any sub process +# if not any(map(dataset_inst.has_process, sub_process_insts[process_inst])): +# continue + +# # work on a copy +# h = h_in.copy() + +# # axis selections +# h = h[{ +# "process": [ +# hist.loc(p.id) +# for p in sub_process_insts[process_inst] +# if p.id in h.axes["process"] +# ], +# }] + +# # axis reductions +# h = h[{"process": sum, "shift": sum, "event": sum}] + +# # add the histogram +# if process_inst in hists: +# hists[process_inst] += h +# else: +# hists[process_inst] = h + +# # there should be hists to 
plot +# if not hists: +# raise Exception("no histograms found to plot") + +# # sort hists by process order +# hists = OrderedDict( +# (process_inst, hists[process_inst]) +# for process_inst in sorted(hists, key=process_insts.index) +# ) + +# yields, processes = defaultdict(list), [] + +# # read out yields per category and per process +# for process_inst, h in hists.items(): +# processes.append(process_inst) + +# for category_inst in category_insts: +# leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] + +# h_cat = h[{"category": [ +# hist.loc(c.id) +# for c in leaf_category_insts +# if c.id in h.axes["category"] +# ]}] +# h_cat = h_cat[{"category": sum}] + +# value = Number(h_cat.value) +# if not self.skip_uncertainties: +# # set a unique uncertainty name for correct propagation below +# value.set_uncertainty( +# f"mcstat_{process_inst.name}_{category_inst.name}", +# math.sqrt(h_cat.variance), +# ) +# yields[category_inst].append(value) + +# # obtain normalizaton factors +# norm_factors = 1 +# if self.normalize_yields == "all": +# norm_factors = sum( +# sum(category_yields) +# for category_yields in yields.values() +# ) +# elif self.normalize_yields == "per_process": +# norm_factors = [ +# sum(yields[category][i] for category in yields.keys()) +# for i in range(len(yields[category_insts[0]])) +# ] +# elif self.normalize_yields == "per_category": +# norm_factors = { +# category: sum(category_yields) +# for category, category_yields in yields.items() +# } + +# # initialize dicts +# yields_str = defaultdict(list, {"Process": [proc.label for proc in processes]}) +# raw_yields = defaultdict(dict, {}) + +# # apply normalization and format +# for category, category_yields in yields.items(): +# for i, value in enumerate(category_yields): +# # get correct norm factor per category and process +# if self.normalize_yields == "per_process": +# norm_factor = norm_factors[i] +# elif self.normalize_yields == "per_category": +# norm_factor = norm_factors[category] +# else: +# norm_factor = norm_factors + +# raw_yield = (value / norm_factor).nominal +# raw_yields[category.name][processes[i].name] = raw_yield + +# # format yields into strings +# yield_str = (value / norm_factor).str( +# combine_uncs="all", +# format=self.number_format, +# style="latex" if "latex" in self.table_format else "plain", +# ) +# if "latex" in self.table_format: +# yield_str = f"${yield_str}$" +# yields_str[category.label].append(yield_str) + +# # create, print and save the yield table +# yield_table = tabulate(yields_str, headers="keys", tablefmt=self.table_format) +# self.publish_message(yield_table) + +# outputs["table"].dump(yield_table, formatter="text") +# outputs["yields"].dump(raw_yields, formatter="json") + class CreateYieldTable( DatasetsProcessesMixin, CategoriesMixin, @@ -128,7 +367,7 @@ def run(self): inputs = self.input() outputs = self.output() - category_insts = list(map(self.config_inst.get_category, self.categories)) + category_insts = list(self.categories) process_insts = list(map(self.config_inst.get_process, self.processes)) sub_process_insts = { proc: [sub for sub, _, _ in proc.walk_processes(include_self=True)] @@ -136,123 +375,92 @@ def run(self): } # histogram data per process - hists = {} - + merged_hists = {} with self.publish_step(f"Creating yields for processes {self.processes}, categories {self.categories}"): for dataset, inp in inputs.items(): dataset_inst = self.config_inst.get_dataset(dataset) # load the histogram of the variable named "event" - h_in = 
inp["hists"]["event"].load(formatter="pickle") - - # loop and extract one histogram per process - for process_inst in process_insts: - # skip when the dataset is already known to not contain any sub process - if not any(map(dataset_inst.has_process, sub_process_insts[process_inst])): - continue - - # work on a copy - h = h_in.copy() - - # axis selections - h = h[{ - "process": [ - hist.loc(p.id) - for p in sub_process_insts[process_inst] - if p.id in h.axes["process"] - ], - }] - - # axis reductions - h = h[{"process": sum, "shift": sum, "event": sum}] - - # add the histogram - if process_inst in hists: - hists[process_inst] += h + input_hists = inp["hists"]["event"].load(formatter="pickle") + + for the_cat in category_insts: + the_hist = input_hists[the_cat] + if the_cat not in merged_hists.keys(): + merged_hists[the_cat] = [] + merged_hists[the_cat].append(the_hist) else: - hists[process_inst] = h - + merged_hists[the_cat].append(the_hist) + #merge histograms + + merged_hists_ = {} + for the_cat, h in merged_hists.items(): + if len(h) > 1: merged_hists_[the_cat] = sum(h[1:],h[0].copy()) + else: + merged_hists_[the_cat] = h[0].copy() + + hists_per_proc = {} + for the_cat, the_hist in merged_hists_.items(): + hists_per_proc[the_cat] = {} + for proc in process_insts: + leaf_procs = proc.get_leaf_processes() + if len(leaf_procs) == 0 : leaf_procs = [proc] + for leaf_proc in leaf_procs: + if leaf_proc.id in the_hist.axes["process"]: + h = the_hist.copy() + h = h[{"process": hist.loc(leaf_proc.id)}] + + if proc in hists_per_proc[the_cat]: + hists_per_proc[the_cat][proc] += h + else: + hists_per_proc[the_cat][proc] = h + # there should be hists to plot - if not hists: + if not hists_per_proc: raise Exception("no histograms found to plot") - # sort hists by process order - hists = OrderedDict( - (process_inst, hists[process_inst]) - for process_inst in sorted(hists, key=process_insts.index) + hists = {} + for the_cat in hists_per_proc.keys(): + single_cat_hists = hists_per_proc[the_cat] + hists[the_cat] = OrderedDict( + (process_inst, single_cat_hists[process_inst]) + for process_inst in sorted(single_cat_hists, key=process_insts.index) ) - - yields, processes = defaultdict(list), [] - - # read out yields per category and per process - for process_inst, h in hists.items(): - processes.append(process_inst) - - for category_inst in category_insts: - leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] - - h_cat = h[{"category": [ - hist.loc(c.id) - for c in leaf_category_insts - if c.id in h.axes["category"] - ]}] - h_cat = h_cat[{"category": sum}] - - value = Number(h_cat.value) - if not self.skip_uncertainties: + #Calculate yields + yields = {} + for the_cat in hists.keys(): + tmp = {} + for the_proc in hists[the_cat].keys(): + val = Number(hists[the_cat][the_proc].sum().value) + + if not self.skip_uncertainties and not the_proc.is_data: # set a unique uncertainty name for correct propagation below - value.set_uncertainty( - f"mcstat_{process_inst.name}_{category_inst.name}", - math.sqrt(h_cat.variance), + val.set_uncertainty( + f"mcstat_{the_proc.name}_{the_cat}", + math.sqrt(hists[the_cat][the_proc].sum().variance), ) - yields[category_inst].append(value) - - # obtain normalizaton factors - norm_factors = 1 - if self.normalize_yields == "all": - norm_factors = sum( - sum(category_yields) - for category_yields in yields.values() - ) - elif self.normalize_yields == "per_process": - norm_factors = [ - sum(yields[category][i] for category in yields.keys()) - for i in 
range(len(yields[category_insts[0]])) - ] - elif self.normalize_yields == "per_category": - norm_factors = { - category: sum(category_yields) - for category, category_yields in yields.items() - } - + tmp[the_proc]=val + yields[the_cat] = OrderedDict(tmp) # initialize dicts - yields_str = defaultdict(list, {"Process": [proc.label for proc in processes]}) + yields_str = defaultdict(list, {"Process" : [proc.label for proc in process_insts]}) raw_yields = defaultdict(dict, {}) - # apply normalization and format - for category, category_yields in yields.items(): - for i, value in enumerate(category_yields): - # get correct norm factor per category and process - if self.normalize_yields == "per_process": - norm_factor = norm_factors[i] - elif self.normalize_yields == "per_category": - norm_factor = norm_factors[category] - else: - norm_factor = norm_factors - - raw_yield = (value / norm_factor).nominal - raw_yields[category.name][processes[i].name] = raw_yield - - # format yields into strings - yield_str = (value / norm_factor).str( - combine_uncs="all", - format=self.number_format, - style="latex" if "latex" in self.table_format else "plain", - ) + for cat in yields.keys(): + yields_per_cat = yields[cat] + for proc in process_insts: + if proc in yields_per_cat: + raw_yield = yields_per_cat[proc].nominal + yield_str = (yields_per_cat[proc]).str( + combine_uncs="all", + format=self.number_format, + style="latex" if "latex" in self.table_format else "plain", + ) + else: + raw_yield = Number(-1).nominal + yield_str = str(-1) + raw_yields[cat][proc.name] = raw_yield if "latex" in self.table_format: yield_str = f"${yield_str}$" - yields_str[category.label].append(yield_str) - + yields_str[cat].append(yield_str) # create, print and save the yield table yield_table = tabulate(yields_str, headers="keys", tablefmt=self.table_format) self.publish_message(yield_table) diff --git a/law.cfg b/law.cfg index 86b667a76..5d01d5d05 100644 --- a/law.cfg +++ b/law.cfg @@ -8,6 +8,7 @@ columnflow.tasks.reduction columnflow.tasks.production columnflow.tasks.ml columnflow.tasks.union +columnflow.tasks.data_driven_methods columnflow.tasks.histograms columnflow.tasks.plotting columnflow.tasks.yields @@ -59,7 +60,7 @@ slurm_flavor: $CF_SLURM_FLAVOR slurm_partition: $CF_SLURM_PARTITION # ChunkedIOHandler defaults -chunked_io_chunk_size: 100000 +chunked_io_chunk_size: 50000 chunked_io_pool_size: 2 chunked_io_debug: False diff --git a/sandboxes/cmssw_columnar.sh b/sandboxes/cmssw_columnar.sh index 350a954be..5dd283d28 100644 --- a/sandboxes/cmssw_columnar.sh +++ b/sandboxes/cmssw_columnar.sh @@ -10,8 +10,10 @@ action() { # set variables and source the generic CMSSW setup export CF_SANDBOX_FILE="${CF_SANDBOX_FILE:-${this_file}}" - export CF_SCRAM_ARCH="el9_amd64_gcc11" - export CF_CMSSW_VERSION="CMSSW_13_0_19" + export CF_SCRAM_ARCH="el9_amd64_gcc12" + export CF_CMSSW_VERSION="CMSSW_14_1_0_pre4" + # export CF_SCRAM_ARCH="$( [ "${os_version}" = "8" ] && echo "el8" || echo "slc7" )_amd64_gcc10" + # export CF_CMSSW_VERSION="CMSSW_12_6_2" export CF_CMSSW_ENV_NAME="$( basename "${this_file%.sh}" )" export CF_CMSSW_FLAG="1" # increment when content changed diff --git a/sandboxes/cmssw_default.sh b/sandboxes/cmssw_default.sh index d2e31eb15..95dcaf592 100644 --- a/sandboxes/cmssw_default.sh +++ b/sandboxes/cmssw_default.sh @@ -10,8 +10,8 @@ action() { # set variables and source the generic CMSSW setup export CF_SANDBOX_FILE="${CF_SANDBOX_FILE:-${this_file}}" - export CF_SCRAM_ARCH="el9_amd64_gcc11" - export 
CF_CMSSW_VERSION="CMSSW_13_0_19" + export CF_SCRAM_ARCH="el9_amd64_gcc12" + export CF_CMSSW_VERSION="CMSSW_14_1_0_pre4" export CF_CMSSW_ENV_NAME="$( basename "${this_file%.sh}" )" export CF_CMSSW_FLAG="1" # increment when content changed
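To close the loop, a short, non-authoritative sketch of how the fake-factor JSON written by ComputeFakeFactors could be evaluated downstream, for instance in a producer that fills the `ff_weight_*` columns read by the modified CreateHistograms. The file name follows the `fake_factors_<channel>_<year>_<tag>.json` pattern of the task output; the concrete file name, column name and input values below are hypothetical, and the snippet assumes the dumped payload is plain correctionlib JSON:

    import numpy as np
    import correctionlib

    # open the correction set produced by ComputeFakeFactors (hypothetical file name)
    cset = correctionlib.CorrectionSet.from_file("fake_factors_mutau_2022_preEE.json")

    # the set contains one correction per fake-factor type, named as in the task ("ff_wjets", "ff_qcd")
    ff_wjets = cset["ff_wjets"]

    # inputs follow the declared order: tau_pt (real), tau_dm_pnet (int), n_jets (int)
    tau_pt = np.array([35.0, 80.0, 120.0])
    tau_dm_pnet = np.array([1, 10, 0], dtype=np.int64)
    n_jets = np.array([0, 2, 1], dtype=np.int64)

    # per-event multiplicative weights, e.g. to be stored as a ff_weight_wj_nominal column
    ff_weight_wj_nominal = ff_wjets.evaluate(tau_pt, tau_dm_pnet, n_jets)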