Source code for pybambi.manager

"""BAMBI management object.

Author: Pat Scott (p.scott@imperial.ac.uk)
Date: Feb 2019
"""

import numpy as np

from pybambi.neuralnetworks.kerasnet import KerasNetInterpolation
from pybambi.neuralnetworks.nearestneighbour \
        import NearestNeighbourInterpolation
import keras.models


[docs]class BambiManager(object): """Does all the talking for BAMBI. Takes a new set of training data from the dumper and trains (or retrains) a neural net, and assesses whether or not it can be used for a given parameter combination. Parameters ---------- ntrain: int Number of training points to use """ def __init__(self, loglikelihood, learner, proxy_tolerance, failure_tolerance, ntrain): """Construct bambi object.""" self.proxy_tolerance = proxy_tolerance self._loglikelihood = loglikelihood self._learner = learner self._proxy_tolerance = proxy_tolerance self._failure_tolerance = failure_tolerance self._ntrain = ntrain self._proxy_trained = False self.old_learners = []
[docs] def make_learner(self, params, loglikes): """Construct a Predictor.""" if self._learner == 'keras': return KerasNetInterpolation(params, loglikes) elif self._learner == 'nearestneighbour': return NearestNeighbourInterpolation(params, loglikes) elif issubclass(type(self._learner), keras.models.Model): return KerasNetInterpolation(params, loglikes, model=self._learner) else: raise NotImplementedError('learner %s is not implemented.' % self._learner)
[docs] def dumper(self, live_params, live_loglks, dead_params, dead_loglks): """Respond to signal from nested sampler.""" if not self._proxy_trained: params = np.concatenate((live_params, dead_params)) loglikes = np.concatenate((live_loglks, dead_loglks)) self.train_new_learner(params[:self._ntrain, :], loglikes[:self._ntrain]) if self._proxy_trained: print("Using trained proxy") else: print("Unable to use proxy")
[docs] def loglikelihood(self, params): """Bambi Proxy wrapper for original loglikelihood.""" # Short circuit to the full likelihood if proxy not yet fully trained if not self._proxy_trained: return self._loglikelihood(params) # Call the learner candidate_loglikelihood = self._current_learner(params) # If the learner can be trusted, use its estimate, # otherwise use the original like and update the failure status if self._current_learner.valid(candidate_loglikelihood): return candidate_loglikelihood else: self._rolling_failure_fraction = (1.0 + (self._ntrain - 1.0) * self._rolling_failure_fraction ) / self._ntrain if self._rolling_failure_fraction > self._failure_tolerance: self._proxy_trained = False return self._loglikelihood(params)
[docs] def train_new_learner(self, params, loglikes): """Train a new Predictor.""" try: self.old_learners.append(self._current_learner) except AttributeError: pass self._current_learner = self.make_learner(params, loglikes) sigma = self._current_learner.uncertainty() print("Current uncertainty in network log-likelihood predictions: %s" % sigma) if sigma < self._proxy_tolerance: self._proxy_trained = True self._rolling_failure_fraction = 0.0