Source code for rail.estimation.algos.sklearn_nn

"""
Example code that implements a simple Neural Net predictor
for z_mode, and Gaussian centered at z_mode with base_width
read in fromfile and pdf width set to base_width*(1+zmode).
"""

import numpy as np
# from numpy import inf
from ceci.config import StageParameter as Param
from rail.estimation.estimator import CatEstimator, CatInformer
import qp
from rail.core.common_params import SHARED_PARAMS


def_filt = ['u', 'g', 'r', 'i', 'z', 'y']
def_bands = [f"mag_{band}_lsst" for band in def_filt]

[docs]def make_color_data(data_dict, bands, ref_band, nondet_val): """ make a dataset consisting of the i-band mag and the five colors Returns -------- input_data: `ndarray` array of imag and 5 colors """ input_data = data_dict[ref_band] # make colors and append to input data for i in range(len(bands)-1): # replace the non-detect 99s with 28.0 just arbitrarily for now band1 = data_dict[bands[i]] # band1err = data_dict[f'mag_err_{bands[i]}_lsst'] band2 = data_dict[bands[i+1]] # band2err = data_dict[f'mag_err_{bands[i+1]}_lsst'] # for j,xx in enumerate(band1): # if np.isclose(xx,99.,atol=.01): # band1[j] = band1err[j] # band1err[j] = 1.0 # for j,xx in enumerate(band2): # if np.isclose(xx,99.,atol=0.01): # band2[j] = band2err[j] # band2err[j] = 1.0 for band in [band1, band2]: if np.isnan(nondet_val): # pragma: no cover nondetmask = np.isnan(band) else: # pragma: no cover nondetmask = np.isclose(band, nondet_val) band[nondetmask] = 28.0 input_data = np.vstack((np.nan_to_num(input_data, 28), band1-band2)) return input_data.T
[docs]def regularize_data(data): """Utility function to prepare data for sklearn""" from sklearn.preprocessing import StandardScaler scaler = StandardScaler() scaler.fit(data) regularized_data = scaler.transform(data) return regularized_data
[docs]class Inform_SimpleNN(CatInformer): """ Subclass to train a simple point estimate Neural Net photoz rather than actually predict PDF, for now just predict point zb and then put an error of width*(1+zb). We'll do a "real" NN photo-z later. """ name = 'Inform_SimpleNN' config_options = CatInformer.config_options.copy() config_options.update(zmin=SHARED_PARAMS, zmax=SHARED_PARAMS, nzbins=SHARED_PARAMS, nondetect_val=SHARED_PARAMS, mag_limits=SHARED_PARAMS, bands=SHARED_PARAMS, ref_band=SHARED_PARAMS, redshift_col=SHARED_PARAMS, hdf5_groupname=SHARED_PARAMS, width=Param(float, 0.05, msg="The ad hoc base width of the PDFs"), max_iter=Param(int, 500, msg="max number of iterations while " "training the neural net. Too low a value will cause an " "error to be printed (though the code will still work, just" "not optimally)")) def __init__(self, args, comm=None): """ Constructor: Do CatInformer specific initialization """ CatInformer.__init__(self, args, comm=comm) if self.config.ref_band not in self.config.bands: raise ValueError("ref_band not present in bands list! ")
[docs] def run(self): """Train the NN model """ import sklearn.neural_network as sknn if self.config.hdf5_groupname: training_data = self.get_data('input')[self.config.hdf5_groupname] else: #pragma: no cover training_data = self.get_data('input') speczs = training_data[self.config.redshift_col] print("stacking some data...") color_data = make_color_data(training_data, self.config.bands, self.config.ref_band, self.config.nondetect_val) input_data = regularize_data(color_data) simplenn = sknn.MLPRegressor(hidden_layer_sizes=(12, 12), activation='tanh', solver='lbfgs', max_iter=self.config.max_iter) simplenn.fit(input_data, speczs) self.model = simplenn self.add_data('model', self.model)
[docs]class SimpleNN(CatEstimator): """ Subclass to implement a simple point estimate Neural Net photoz rather than actually predict PDF, for now just predict point zb and then put an error of width*(1+zb). We'll do a "real" NN photo-z later. """ name = 'SimpleNN' config_options = CatEstimator.config_options.copy() config_options.update(width=Param(float, 0.05, msg="The ad hoc base width of the PDFs"), ref_band=SHARED_PARAMS, nondetect_val=SHARED_PARAMS, bands=SHARED_PARAMS) def __init__(self, args, comm=None): """ Constructor: Do CatEstimator specific initialization """ CatEstimator.__init__(self, args, comm=comm) if self.config.ref_band not in self.config.bands: raise ValueError("ref_band is not in list of bands!") def _process_chunk(self, start, end, data, first): color_data = make_color_data(data, self.config.bands, self.config.ref_band, self.config.nondetect_val) input_data = regularize_data(color_data) zmode = np.round(self.model.predict(input_data), 3) widths = self.config.width * (1.0+zmode) qp_dstn = qp.Ensemble(qp.stats.norm, data=dict(loc=np.expand_dims(zmode, -1), #pylint: disable=no-member scale=np.expand_dims(widths, -1))) qp_dstn.set_ancil(dict(zmode=zmode)) self._do_chunk_output(qp_dstn, start, end, first)