"""
Implement simple version of TxPipe NZDir summarizer
"""
import numpy as np
from ceci.config import StageParameter as Param
from rail.estimation.estimator import CatEstimator, CatInformer
from rail.core.data import QPHandle
import qp
import scipy.spatial
import pandas as pd
class NZDir(CatEstimator):
    """Direct-calibration (DIR) style N(z) summarizer.

    For every chunk of photometric data a ``scipy.spatial.KDTree`` is built
    over the magnitude columns listed in ``usecols``.  Each spectroscopic
    object stored in the model is then assigned a weight equal to the summed
    photometric weight of all photometric objects that lie within that
    object's stored search radius (``model['distances']``).  Those weights,
    multiplied by the spectroscopic weights, are used to build a weighted
    redshift histogram on the configured grid.

    NOTE(review): the per-object radii are read from the model and are
    presumably produced by the matching informer stage (the in-code comment
    refers to "the distance to Kth speczNN") -- confirm against the informer.

    Parameters
    ----------
    zmin: float
        min redshift for z grid
    zmax: float
        max redshift for z grid
    nzbins: int
        number of bins in z grid (the grid has ``nzbins + 1`` edges)
    seed: int
        seed of the random generator used to draw the bootstrap indices
    usecols: list
        columns of the photometric data used for the neighbour search
    leafsize: int
        leaf size handed to the KDTree built on the photometric data
    hdf5_groupname: str
        name of the hdf5 group holding the input data
    phot_weightcol: str
        name of the photometric weight column; an empty string means that
        every photometric object gets weight 1.0
    nsamples: int
        number of bootstrap samples to generate

    Returns
    -------
    qp_ens: qp Ensemble
        histogram Ensemble describing N(z) estimate.  Two outputs are
        written: ``single_NZ`` holds the single estimate and ``output``
        holds the ``nsamples`` bootstrap realisations.
    """

    bands = ['u', 'g', 'r', 'i', 'z', 'y']
    default_usecols = [f"mag_{band}_lsst" for band in bands]

    name = 'NZDir'
    config_options = CatEstimator.config_options.copy()
    config_options.update(
        zmin=Param(float, 0.0, msg="The minimum redshift of the z grid"),
        zmax=Param(float, 3.0, msg="The maximum redshift of the z grid"),
        nzbins=Param(int, 301, msg="The number of gridpoints in the z grid"),
        seed=Param(int, 87, msg="random seed"),
        usecols=Param(list, default_usecols, msg="columns from sz_date for Neighor calculation"),
        leafsize=Param(int, 40, msg="leaf size for testdata KDTree"),
        hdf5_groupname=Param(str, "photometry", msg="name of hdf5 group for data, if None, then set to ''"),
        phot_weightcol=Param(str, "", msg="name of photometry weight, if present"),
        nsamples=Param(int, 20, msg="number of bootstrap samples to generate"),
    )
    outputs = [('output', QPHandle),
               ('single_NZ', QPHandle)]

    def __init__(self, args, comm=None):
        """Declare the instance state, then run the base-class constructor.

        The attributes are set *before* ``CatEstimator.__init__`` is called
        so that anything the base constructor loads (e.g. through
        ``open_model``) is not overwritten with ``None`` afterwards.
        """
        self.zgrid = None
        self.model = None
        self.distances = None
        self.szusecols = None
        self.szweights = None
        self.szvec = None
        self.sz_mag_data = None
        self.bincents = None
        # Output handles are created lazily while the first chunk is processed.
        self._single_handle = None
        self._sample_handle = None
        CatEstimator.__init__(self, args, comm=comm)

    def open_model(self, **kwargs):
        """Load the model and unpack the entries used by this stage.

        The model is expected to provide the keys ``distances``,
        ``szusecols``, ``szweights``, ``szvec`` and ``sz_mag_data``.
        """
        CatEstimator.open_model(self, **kwargs)
        if self.model is None:  # pragma: no cover
            # Nothing was loaded (presumably no model was supplied yet);
            # leave the attributes at None instead of failing with an
            # opaque "'NoneType' object is not subscriptable" error.
            return
        self.distances = self.model['distances']
        self.szusecols = self.model['szusecols']
        self.szweights = self.model['szweights']
        self.szvec = self.model['szvec']
        self.sz_mag_data = self.model['sz_mag_data']

    def run(self):
        """Process the input chunk by chunk and write both N(z) outputs.

        The bootstrap index matrix is drawn once on rank 0 and broadcast to
        the other processes, so that every chunk (on every process) uses the
        same resampling of the spectroscopic objects.  After all chunks are
        written, rank 0 combines the per-chunk histograms.
        """
        rng = np.random.default_rng(seed=self.config.seed)
        self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins + 1)
        self.bincents = 0.5 * (self.zgrid[1:] + self.zgrid[:-1])

        # Creating the indices for the bootstrap sampling, and broadcasting
        # them to the other processes if we are running in parallel
        ngal = len(self.szweights)
        nsamp = self.config.nsamples
        if self.rank == 0:
            bootstrap_matrix = rng.integers(low=0, high=ngal, size=(ngal, nsamp))
        else:  # pragma: no cover
            bootstrap_matrix = None
        if self.comm is not None:  # pragma: no cover
            bootstrap_matrix = self.comm.bcast(bootstrap_matrix, root=0)

        iterator = self.input_iterator('input')
        first = True
        self._initialize_run()
        self._output_handle = None
        # The number of chunks does not change while iterating: compute it once.
        total_chunks = int(np.ceil(self._input_length / self.config.chunk_size))
        for s, e, test_data in iterator:
            print(f"Process {self.rank} running estimator on chunk {s} - {e}")
            chunk_number = s // self.config.chunk_size
            self._process_chunk(first, total_chunks, chunk_number, test_data, bootstrap_matrix)
            first = False

        self._single_handle.finalize_write()
        self._sample_handle.finalize_write()
        if self.comm is not None:  # pragma: no cover
            self.comm.Barrier()
        if self.rank == 0:
            # Joining the histograms and updating the data handles
            self.join_histograms()

    def _process_chunk(self, first, number_of_chunks, chunk_number, test_data, bootstrap_matrix):
        """Compute and write the N(z) histograms for one chunk of photometric data.

        Parameters
        ----------
        first: bool
            True for the first chunk handled by this process; triggers the
            creation of the two output handles
        number_of_chunks: int
            total number of chunks, used to size the output files
        chunk_number: int
            index of this chunk, used as the row offset in the outputs
        test_data: table-like
            the chunk of photometric data, indexed by column name
        bootstrap_matrix: array
            integer indices; column ``i`` selects the spectroscopic objects
            of bootstrap sample ``i``

        Raises
        ------
        KeyError
            if ``phot_weightcol`` is set but missing from ``test_data``
        """
        # assign weight vecs if present, else set all to 1.0
        # tested in example notebook, so just put a pragma no cover for if present
        if self.config.phot_weightcol == "":
            pweight = np.ones(len(test_data[self.config.usecols[0]]))
        elif self.config.phot_weightcol in test_data.keys():  # pragma: no cover
            pweight = np.array(test_data[self.config.phot_weightcol])
        else:
            raise KeyError(f"photometric weight column {self.config.phot_weightcol} not present in data!")

        # calculate weights for the test data
        weights = np.zeros(len(self.szvec))
        tmpdf = pd.DataFrame(test_data)
        phot_mag_data = np.array([tmpdf[band] for band in self.config.usecols]).T
        # non-finite magnitudes are replaced by the placeholder value 40
        phot_mag_data[~np.isfinite(phot_mag_data)] = 40.

        # create a tree for the photometric data, for each specz object find all the
        # tomo objects within the distance to Kth speczNN from before
        tree = scipy.spatial.KDTree(phot_mag_data, leafsize=self.config.leafsize)
        indices = tree.query_ball_point(self.sz_mag_data, self.distances)

        # for each of the indexed galaxies within the distance, add the weights to the
        # appropriate tomographic bin
        for j, index in enumerate(indices):
            weights[j] += pweight[index].sum()

        # make weighted histograms
        hist_data = np.histogram(
            self.szvec,
            bins=self.zgrid,
            weights=weights * self.szweights,
        )
        # The raw histogram sum is stored as ancillary data so that the
        # per-chunk results can be rescaled when they are joined later.
        normalization_factor = hist_data[0].sum()
        qp_dstn = qp.Ensemble(
            qp.hist,
            data=dict(bins=self.zgrid, pdfs=hist_data[0]),
            ancil=dict(normalization=np.array([normalization_factor])),
        )
        if first:
            self._single_handle = self.initialize_handle('single_NZ', qp_dstn, number_of_chunks)
        self._do_chunk_output(qp_dstn, chunk_number, chunk_number + 1, self._single_handle)

        # The way things are set up, it is easier and faster to bootstrap the spec-z gals
        # and weights, but if we wanted to be more like the other bootstraps we should really
        # bootstrap the photometric data and re-run the ball tree query N times.
        nsamp = self.config.nsamples
        hist_vals = np.empty((nsamp, self.config.nzbins))
        for i in range(nsamp):
            bootstrap_indices = bootstrap_matrix[:, i]
            zarr = self.szvec[bootstrap_indices]
            tmpweight = self.szweights[bootstrap_indices] * weights[bootstrap_indices]
            hist_vals[i] = np.histogram(zarr, bins=self.zgrid, weights=tmpweight)[0]

        normalization_factor = hist_vals.sum(axis=1)
        sample_ens = qp.Ensemble(
            qp.hist,
            data=dict(bins=self.zgrid, pdfs=np.atleast_2d(hist_vals)),
            ancil=dict(normalization=normalization_factor),
        )
        # every chunk contributes nsamp rows to the bootstrap output
        npdf = nsamp * number_of_chunks
        if first:
            self._sample_handle = self.initialize_handle('output', sample_ens, npdf)
        self._do_chunk_output(sample_ens, chunk_number * nsamp, (chunk_number + 1) * nsamp, self._sample_handle)

    def initialize_handle(self, tag, data, npdf):
        """Create the data handle for ``tag`` and prepare it for chunked writing.

        Parameters
        ----------
        tag: str
            name of the output the handle is attached to
        data: qp Ensemble
            ensemble used to set up the handle
        npdf: int
            total number of pdfs that will be written to the output

        Returns
        -------
        handle
            the handle returned by ``add_handle``, after ``initialize_write``
        """
        handle = self.add_handle(tag, data=data)
        handle.initialize_write(npdf, communicator=self.comm)
        return handle

    def _do_chunk_output(self, qp_dstn, start, end, handle):
        """Write ``qp_dstn`` through ``handle`` into the rows ``start`` to ``end``."""
        handle.set_data(qp_dstn, partial=True)
        handle.write_chunk(start, end)

    def join_histograms(self):
        """Combine the per-chunk histograms into the final N(z) outputs.

        Each stored pdf is multiplied by the normalization saved alongside
        it and the rescaled rows are summed over chunks.  The sums replace
        the chunked data in both output handles, which are then rewritten.
        """
        qp_d_spread = self._single_handle.read(force=True)
        tab_single = qp_d_spread.build_tables()
        hist_single = np.zeros(self.config.nzbins)
        normalization = tab_single['ancil']['normalization']
        total_chunks = len(normalization)
        for i in range(total_chunks):
            # A chunk whose histogram summed to zero contributes nothing.
            # NOTE(review): such a chunk is presumably stored as NaN by qp
            # (0/0 when the histogram is normalized) and NaN * 0 would
            # poison the total, hence the explicit mask.  This assumes
            # non-negative weights -- confirm.
            hist_single += np.where(
                normalization[i] != 0,
                tab_single['data']['pdfs'][i] * normalization[i],
                0.0,
            )

        sample_ens_spread = self._sample_handle.read(force=True)
        tab_samples = sample_ens_spread.build_tables()
        nsamp = self.config.nsamples
        hist_samples = np.zeros((nsamp, self.config.nzbins))
        normalization = tab_samples['ancil']['normalization']
        scale = normalization[:, None]
        # same zero-normalization mask as for the single estimate
        hist_samples_spread = np.where(scale != 0, tab_samples['data']['pdfs'] * scale, 0.0)
        for i in range(total_chunks):
            # rows [i*nsamp, (i+1)*nsamp) hold the bootstrap samples of chunk i
            hist_samples += hist_samples_spread[i * nsamp:(i + 1) * nsamp]

        sample_ens = qp.Ensemble(qp.hist, data=dict(bins=self.zgrid, pdfs=np.atleast_2d(hist_samples)))
        qp_d = qp.Ensemble(qp.hist, data=dict(bins=self.zgrid, pdfs=np.atleast_2d(hist_single)))
        self._single_handle.set_data(qp_d)
        self._sample_handle.set_data(sample_ens)
        self._single_handle.write()
        self._sample_handle.write()