Source code for rail.core.algo_utils

"""Utility functions to test algorithms"""
import os
from rail.core.stage import RailStage
from rail.core.utils import RAILDIR
from rail.core.data import TableHandle
import scipy.special
# Components of the installed SciPy version string, split on '.'.
# The entries are strings, not integers.
sci_ver_str = scipy.__version__.split('.')


# Paths to the example HDF5 tables under RAILDIR; one_algo reads them as its
# training and validation inputs.
traindata = os.path.join(RAILDIR, 'rail/examples_data/testdata/training_100gal.hdf5')
validdata = os.path.join(RAILDIR, 'rail/examples_data/testdata/validation_10gal.hdf5')
# Data store object exposed by RailStage; one_algo clears and refills it.
DS = RailStage.data_store
# Import-time side effect: the flag is set on the data store's *class*, not on
# this one instance.
# NOTE(review): presumably this lets repeated runs re-register the same data
# tags without an error -- confirm against the data store implementation.
DS.__class__.allow_overwrite = True


def one_algo(key, single_trainer, single_estimator, train_kwargs, estim_kwargs):
    """Run a basic end-to-end test of an estimator stage.

    The data store is cleared, the training and validation tables are read
    from ``traindata`` and ``validdata``, and the estimator is run up to
    three times on the validation data:

    1. with ``estim_kwargs`` as given (stage name ``key``);
    2. again under the name ``"<stage name>_copy"``, only when
       ``estim_kwargs`` carries a ``'model'`` entry other than the string
       ``'None'``;
    3. again under the name ``"<stage name>_copy3"``, with ``'model'``
       replaced by the trainer's ``'model'`` handle, only when a trainer was
       supplied and ``'model'`` is among its output tags.

    Afterwards the output file of every estimator stage that ran is deleted,
    and so is the file named by ``estim_kwargs['model']`` when that entry is
    present.

    Parameters
    ----------
    key : str
        Name given to the first estimator stage.
    single_trainer : class or None
        Object providing ``make_stage`` and ``output_tags``. When ``None``
        the inform step and the third estimation run are skipped.
    single_estimator : class
        Object providing ``make_stage``; the resulting stage must provide
        ``estimate``, ``get_output`` and ``get_aliased_tag``.
    train_kwargs : dict
        Keyword arguments forwarded to ``single_trainer.make_stage``.
        Unused when ``single_trainer`` is ``None``.
    estim_kwargs : dict
        Keyword arguments forwarded to ``single_estimator.make_stage``.
        It is not modified.

    Returns
    -------
    tuple
        ``(estim.data, estim_2.data, estim_3.data)`` for the three runs.
        When the second or third run is skipped, the corresponding entry is
        the first run's data.

    Notes
    -----
    A model file that no longer exists at cleanup time is ignored. The
    estimator output files are removed unconditionally, so a missing output
    file raises ``FileNotFoundError``.
    """
    DS.clear()
    training_data = DS.read_file('training_data', TableHandle, traindata)
    validation_data = DS.read_file('validation_data', TableHandle, validdata)

    if single_trainer is not None:
        train_pz = single_trainer.make_stage(**train_kwargs)
        train_pz.inform(training_data)

    pz = single_estimator.make_stage(name=key, **estim_kwargs)
    estim = pz.estimate(validation_data)

    # The optional reruns default to the first result so that the return
    # value always has three entries.
    pz_2 = None
    estim_2 = estim
    pz_3 = None
    estim_3 = estim

    # The string 'None' (not the None object) is the "no model given" sentinel.
    model_file = estim_kwargs.get('model', 'None')

    if model_file != 'None':
        pz_2 = single_estimator.make_stage(name=f"{pz.name}_copy", **estim_kwargs)
        estim_2 = pz_2.estimate(validation_data)

    # train_pz is only bound when a trainer was supplied, hence the first test.
    if single_trainer is not None and 'model' in single_trainer.output_tags():
        copy3_estim_kwargs = estim_kwargs.copy()
        copy3_estim_kwargs['model'] = train_pz.get_handle('model')
        pz_3 = single_estimator.make_stage(name=f"{pz.name}_copy3", **copy3_estim_kwargs)
        estim_3 = pz_3.estimate(validation_data)

    # Delete the output file of every estimator stage that actually ran.
    for stage in (pz, pz_2, pz_3):
        if stage is not None:
            os.remove(stage.get_output(stage.get_aliased_tag('output'), final_name=True))

    if model_file != 'None':
        try:
            os.remove(model_file)
        except FileNotFoundError:  # pragma: no cover
            # Best-effort cleanup: the model file may already be gone.
            pass

    return estim.data, estim_2.data, estim_3.data