Source code for rail.estimation.algos.gpz_v1
"""
RAIL wrapping of Peter Hatfield's version of GPz, which
can be found at:
https://github.com/pwhatfield/GPz_py3
"""
import numpy as np
from ceci.config import StageParameter as Param
from rail.core.common_params import SHARED_PARAMS
from rail.estimation.estimator import CatEstimator, CatInformer
from .GPz import GP, getOmega
import qp
def _prepare_data(data_dict, bands, err_bands, nondet_val, maglims, logflag):
    """Put data in the 2D np array expected by GPz.

    The returned array has one row per object and ``2 * len(bands)``
    columns: columns ``0 .. len(bands) - 1`` hold the magnitudes and
    columns ``len(bands) .. 2 * len(bands) - 1`` hold the matching
    magnitude errors (natural log of the errors when ``logflag`` is set).

    Non-detections (magnitude close to ``nondet_val`` or NaN) have their
    magnitude replaced by the band's limiting magnitude and their error
    column entry replaced by 1.0.

    Parameters
    ----------
    data_dict : mapping
        Maps column name to a 1D array-like of per-object values.
    bands : sequence of str
        Names of the magnitude columns.
    err_bands : sequence of str
        Names of the magnitude error columns, parallel to ``bands``.
    nondet_val : float
        Sentinel value marking a non-detection.
    maglims : mapping
        Limiting magnitude for each band.  When every entry of ``bands``
        is a key of ``maglims`` the limit is looked up by band name, so
        the ordering of ``maglims`` (and any extra keys) does not matter.
        Otherwise the values are paired with ``bands`` by position, which
        is the historical behaviour.
    logflag : bool
        If True, store the natural log of the magnitude errors.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_objects, 2 * len(bands))``.
    """
    numbands = len(bands)
    totrows = len(data_dict[bands[0]])
    data = np.empty([totrows, 2 * numbands])

    # Pair each band with its own limit by name; relying on the iteration
    # order of ``maglims`` silently mismatches limits whenever the dict is
    # ordered differently from ``bands`` or carries additional bands.
    if all(band in maglims for band in bands):
        limits = [maglims[band] for band in bands]
    else:
        limits = list(maglims.values())

    for i, (band, eband, lim) in enumerate(zip(bands, err_bands, limits)):
        mags = np.asarray(data_dict[band])
        errs = np.asarray(data_dict[eband])
        nondetect = np.isclose(mags, nondet_val) | np.isnan(mags)

        data[:, i] = mags
        data[nondetect, i] = lim

        if logflag:
            data[:, numbands + i] = np.log(errs)
        else:  # pragma: no cover
            data[:, numbands + i] = errs
        # Non-detections get a fixed placeholder in the error column.
        data[nondetect, numbands + i] = 1.0
    return data
class Inform_GPz_v1(CatInformer):
    """Inform stage for GPz_v1.

    Builds the GPz input array from the training catalogue, splits it
    randomly into training and validation subsets, trains a ``GP`` model
    and stores it as the stage's ``model`` output.

    Parameters
    ----------
    See ``config_options`` for the configurable parameters of this stage.

    Returns
    -------
    gpz_model: model
        model file containing the trained GPz model to be used in estimate
        stage
    """
    name = "Inform_GPz_v1"
    config_options = CatInformer.config_options.copy()
    config_options.update(
        nondetect_val=SHARED_PARAMS,
        mag_limits=SHARED_PARAMS,
        trainfrac=Param(float, 0.75,
                        msg="fraction of training data used to train the model, rest used for validation"),
        seed=Param(int, 87, msg="random seed"),
        bands=SHARED_PARAMS,
        err_bands=SHARED_PARAMS,
        redshift_col=SHARED_PARAMS,
        gpz_method=Param(str, "VC", msg="method to be used in GPz, options are 'GL', 'VL', 'GD', 'VD', 'GC', and 'VC'"),
        n_basis=Param(int, 50, msg="number of basis functions used"),
        learn_jointly=Param(bool, True, msg="if True, jointly learns prior linear mean function"),
        hetero_noise=Param(bool, True, msg="if True, learns heteroscedastic noise process, set False for point est."),
        csl_method=Param(str, "normal", msg="cost sensitive learning type, 'balanced', 'normalized', or 'normal'"),
        csl_binwidth=Param(float, 0.1, msg="width of bin for 'balanced' cost sensitive learning"),
        pca_decorrelate=Param(bool, True, msg="if True, decorrelate data using PCA as preprocessing stage"),
        max_iter=Param(int, 200, msg="max number of iterations"),
        max_attempt=Param(int, 100, msg="max iterations if no progress on validation"),
        log_errors=Param(bool, True, msg="if true, take log of magnitude errors"),
    )

    def __init__(self, args, comm=None):
        """Constructor.

        Do CatInformer specific initialization.
        """
        CatInformer.__init__(self, args, comm=comm)
        self.zgrid = None

    def run(self):
        """Train the GPz model after splitting train data into train/validation."""
        if self.config.hdf5_groupname:
            training_data = self.get_data('input')[self.config.hdf5_groupname]
        else:  # pragma: no cover
            training_data = self.get_data('input')

        input_array = _prepare_data(training_data, self.config.bands, self.config.err_bands,
                                    self.config.nondetect_val, self.config.mag_limits,
                                    self.config.log_errors)

        # GPz is handed the redshifts as a column vector (n_objects, 1).
        sz = np.expand_dims(training_data[self.config.redshift_col], -1)

        # Boolean masks defining the training vs validation subsets.  The
        # shuffle is drawn from a generator seeded with the configured seed
        # so that the split is reproducible from run to run.
        ngal = input_array.shape[0]
        print(f"ngal: {ngal}")
        ntrain = int(ngal * self.config.trainfrac)
        rng = np.random.default_rng(self.config.seed)
        randvec = rng.permutation(ngal)
        train_mask = np.zeros(ngal, dtype=bool)
        train_mask[randvec[:ntrain]] = True
        val_mask = ~train_mask

        # Weights for cost sensitive learning.
        # NOTE(review): ``csl_binwidth`` is declared in config_options but is
        # not forwarded here; confirm whether getOmega accepts a bin width
        # argument for the 'balanced' method and pass it if so.
        omega_weights = getOmega(sz, method=self.config.csl_method)

        # Initialize model.
        model = GP(self.config.n_basis,
                   method=self.config.gpz_method,
                   joint=self.config.learn_jointly,
                   heteroscedastic=self.config.hetero_noise,
                   decorrelate=self.config.pca_decorrelate,
                   seed=self.config.seed)

        print("training model...")
        model.train(input_array, sz, omega=omega_weights, training=train_mask,
                    validation=val_mask, maxIter=self.config.max_iter,
                    maxAttempts=self.config.max_attempt)
        self.model = model
        self.add_data('model', self.model)
class GPz_v1(CatEstimator):
    """GPz_v1 estimator.

    For each chunk of input data, builds the GPz input array, asks the
    trained model for a predictive mean and variance per object, and
    writes out a ``qp`` ensemble of Gaussian PDFs together with the mode
    of each PDF evaluated on the configured redshift grid.
    """
    name = "GPz_v1"
    config_options = CatEstimator.config_options.copy()
    config_options.update(
        zmin=SHARED_PARAMS,
        zmax=SHARED_PARAMS,
        nzbins=SHARED_PARAMS,
        nondetect_val=SHARED_PARAMS,
        mag_limits=SHARED_PARAMS,
        bands=SHARED_PARAMS,
        err_bands=SHARED_PARAMS,
        ref_band=SHARED_PARAMS,
        log_errors=Param(bool, True, msg="if true, take log of magnitude errors"),
    )

    def __init__(self, args, comm=None):
        """Constructor.

        Do CatEstimator specific initialization.
        """
        CatEstimator.__init__(self, args, comm=comm)
        self.zgrid = None

    def _process_chunk(self, start, end, data, first):
        """Estimate Gaussian redshift PDFs for rows ``start``..``end``.

        Parameters
        ----------
        start, end : int
            Row range of this chunk; used for the progress message and
            forwarded to ``_do_chunk_output``.
        data : mapping
            Column name -> per-object values for this chunk.
        first : bool
            Forwarded unchanged to ``_do_chunk_output``.
        """
        print(f"Process {self.rank} estimating GPz PZ PDF for rows {start:,} - {end:,}")
        test_array = _prepare_data(data, self.config.bands, self.config.err_bands,
                                   self.config.nondetect_val, self.config.mag_limits,
                                   self.config.log_errors)

        # NOTE(review): self.model is not assigned in this class; presumably
        # the CatEstimator base class loads it before chunks are processed.
        mu, totalV, _, _, _ = self.model.predict(test_array)

        # ``totalV`` is the total predictive *variance*, while the ``scale``
        # of a normal distribution is its standard deviation, hence the
        # square root.
        ens = qp.Ensemble(qp.stats.norm, data=dict(loc=mu, scale=np.sqrt(totalV)))

        # The grid depends only on the configuration, so build it once and
        # reuse it for every subsequent chunk.
        if self.zgrid is None:
            self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)
        zmode = ens.mode(grid=self.zgrid)
        ens.set_ancil(dict(zmode=zmode))
        self._do_chunk_output(ens, start, end, first)