Source code for rail.creation.degradation.quantityCut

"""Degrader that applies a cut to given columns."""

from numbers import Number

import numpy as np
from rail.creation.degrader import Degrader


[docs]class QuantityCut(Degrader): """Degrader that applies a cut to the given columns. Note if a galaxy fails any of the cuts on any one of its columns, that galaxy is removed from the sample. """ name = "QuantityCut" config_options = Degrader.config_options.copy() config_options.update(cuts=dict) def __init__(self, args, comm=None): """ Constructor Does standard Degrader initialization and also gets defines the cuts to be applied """ Degrader.__init__(self, args, comm=comm) self.cuts = None self.set_cuts(self.config["cuts"])
[docs] def set_cuts(self, cuts: dict): """ Parameters ---------- cuts : dict A dictionary of cuts to make on the data. Notes ----- The cut keys should be the names of columns you wish to make cuts on. The cut values should be either: - a number, which is the maximum value. I.e. if the dictionary contains "i": 25, then values of i > 25 are cut from the sample. - an iterable, which is the range of acceptable values. I.e. if the dictionary contains "redshift": (1.5, 2.3), then redshifts outside that range are cut from the sample. """ # check that cuts is a dictionary if not isinstance(cuts, dict): # pragma: no cover raise TypeError("cuts must be a dictionary.") # validate all the cuts and standardize format in dictionary self.cuts = {} for quantity, cut in cuts.items(): bad_cut_msg = ( f"Cut for {quantity} must be a number or an iterable of (min, max)" ) # if single number is provided, save that as the maximum if isinstance(cut, Number): self.cuts[quantity] = (-np.inf, cut) # else, if it's an iterable... elif hasattr(cut, "__iter__"): # if the iterable is a string or dict, raise error if isinstance(cut, (dict, str)): raise TypeError(bad_cut_msg) # if the length of the iterable isn't 2, raise error if len(cut) != 2: raise ValueError(bad_cut_msg) # check that both of the cut values are a number for c in cut: if not isinstance(c, Number): raise TypeError(bad_cut_msg) # check that the cuts are (min, max) if cut[0] > cut[1]: raise ValueError( f"The max for {quantity} must be greater than the min." ) # if all of these checks passed, save the cuts self.cuts[quantity] = (cut[0], cut[1]) # if the cut isn't a number or iterable, it's the wrong type else: raise TypeError(bad_cut_msg)
[docs] def run(self): """Run method Applies cuts Notes ----- Get the input data from the data store under this stages 'input' tag Puts the data into the data store under this stages 'output' tag """ data = self.get_data("input") # get overlap of columns from data and columns on which to make cuts columns = set(self.cuts.keys()).intersection(data.columns) if len(columns) == 0: # pragma: no cover self.add_data("output", data) else: # generate a pandas query from the cuts query = [ f"{col} > {self.cuts[col][0]} & {col} < {self.cuts[col][1]}" for col in columns ] query = " & ".join(query) out_data = data.query(query) self.add_data("output", out_data)
def __repr__(self): # pragma: no cover """Pretty print this object""" printMsg = "Degrader that applies the following cuts to a pandas DataFrame:\n" printMsg += "{column: (min, max), ...}\n" printMsg += self.cuts.__str__() return printMsg