Source code for pybambi.bambi

"""Driving routine for pyBAMBI.

Author: Will Handley (wh260@cam.ac.uk)
Date: November 2018
"""
import os
from pybambi.manager import BambiManager


def run_pyBAMBI(loglikelihood, prior, nDims, **kwargs):
    """Run pyBAMBI.

    Wraps ``loglikelihood`` in a ``BambiManager`` and runs the chosen nested
    sampler on the manager's ``loglikelihood``, with its ``dumper``.

    Parameters
    ----------
    loglikelihood: callable
        Likelihood to sample. It is handed to ``BambiManager``. The sampler
        calls the manager's ``loglikelihood`` attribute, not this object
        directly.
    prior: callable
        Prior, forwarded unchanged to the nested sampler.
    nDims: int
        Number of parameters. Also used to derive several defaults below.
    nested_sampler: str
        Choice of nested sampler.
        Options: `['multinest', 'polychord']`.
        Default `'polychord'`.
    nlive: int
        Number of live points.
        Default `nDims*25`
    root: str
        root of filename.
        Default `os.path.join('chains', nested_sampler)`
    num_repeats: int
        number of repeats for polychord.
        Default `nDims*5`
    eff: float
        efficiency for multinest.
        Default `0.5**nDims`
    learner: object
        information indicating what learning algorithm to use for
        approximating the likelihood. Can be the string `'keras'`, or a
        `keras.models.Model`
        Default `'keras'`
    ntrain: int
        Number of training points to use. ``ntrain//2`` is also passed to
        the sampler wrapper (see the note in the body).
        Default `nlive`
    proxy_tolerance: float
        Required accuracy of proxy.
        Default `0.1`
    failure_tolerance: float
        Forwarded to ``BambiManager``. NOTE(review): exact meaning is
        defined there -- confirm.
        Default `0.5`
    seed: int
        Seed forwarded to the nested sampler wrapper.
        Default `-1`

    Raises
    ------
    TypeError
        If any unrecognised keyword argument is supplied.
    NotImplementedError
        If ``nested_sampler`` is neither ``'polychord'`` nor
        ``'multinest'``.
    """
    # Process kwargs
    nested_sampler = kwargs.pop('nested_sampler', 'polychord')
    nlive = kwargs.pop('nlive', nDims*25)
    root = kwargs.pop('root', os.path.join('chains', nested_sampler))
    num_repeats = kwargs.pop('num_repeats', nDims*5)
    eff = kwargs.pop('eff', 0.5**nDims)
    learner = kwargs.pop('learner', 'keras')
    proxy_tolerance = kwargs.pop('proxy_tolerance', 0.1)
    failure_tolerance = kwargs.pop('failure_tolerance', 0.5)
    ntrain = kwargs.pop('ntrain', nlive)
    seed = kwargs.pop('seed', -1)

    if kwargs:
        raise TypeError('Unexpected **kwargs: %r' % kwargs)

    # Reject an unsupported sampler before any BambiManager set-up happens,
    # so the error surfaces immediately. The one-element tuple keeps the
    # %-formatting safe even if nested_sampler is itself a tuple.
    if nested_sampler not in ('polychord', 'multinest'):
        raise NotImplementedError('nested sampler %s is not implemented'
                                  % (nested_sampler,))

    # Set up the global manager of the BAMBI session.
    thumper = BambiManager(loglikelihood, learner, proxy_tolerance,
                           failure_tolerance, ntrain)

    # Passed to the sampler wrappers positionally, right after `root`.
    # NOTE(review): presumably how often the dumper (and hence training) is
    # triggered -- confirm against pybambi.polychord / pybambi.multinest.
    ndump = ntrain//2

    # Sampler wrappers are imported lazily, presumably so that only the
    # chosen sampler needs to be installed.
    if nested_sampler == 'polychord':
        from pybambi.polychord import run_polychord
        run_polychord(thumper.loglikelihood, prior, thumper.dumper, nDims,
                      nlive, root, ndump, num_repeats, seed)
    else:  # 'multinest' -- guaranteed by the validation above
        from pybambi.multinest import run_multinest
        run_multinest(thumper.loglikelihood, prior, thumper.dumper, nDims,
                      nlive, root, ndump, eff, seed)