orbit.estimators package

Submodules

orbit.estimators.base_estimator module

class orbit.estimators.base_estimator.BaseEstimator(seed=8888, verbose=True)

Bases: object

Base estimator class for both the Stan and Pyro estimators

Parameters:
  • seed (int) – seed used to generate initial random values

  • verbose (bool) – If True (default), output all diagnostic messages from the estimators

abstract fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
Parameters:
  • model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)

  • model_param_names (list) – list of model parameter names (strings) to extract

  • data_input (dict) – key-value pairs of data input, as required by the sampler definition (stan/pyro/…)

  • fitter – model object used for fitting; if supplied, it is used instead of looking up the model object by model_name

  • init_values (float or np.array) – initial sampler values. If None, ‘random’ is used

Returns:

  • posteriors (dict) – key-value pairs where the key is the model parameter name and the value is an array of num_sample x posterior values

  • training_metrics (dict) – metrics and metadata related to the training process
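To make the fit() contract concrete, here is a minimal sketch that does not use orbit, Stan or Pyro. EstimatorSketch, ToyEstimator, MODEL_REGISTRY and toy_model are all hypothetical names: the registry stands in for the stan/pyro model files that model_name is mapped to, and toy_model just draws fake samples. The sketch only illustrates the documented behaviour: a supplied fitter takes precedence over model_name, only the requested parameters are returned, and fit() returns the pair (posteriors, training_metrics).

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np


class EstimatorSketch(ABC):
    """Hypothetical stand-in mirroring the documented BaseEstimator interface."""

    def __init__(self, seed: int = 8888, verbose: bool = True):
        self.seed = seed
        self.verbose = verbose

    @abstractmethod
    def fit(self, model_name: str, model_param_names: List[str],
            data_input: Dict[str, Any], fitter: Optional[Callable] = None,
            init_values=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
        """Return (posteriors, training_metrics)."""


# Hypothetical registry standing in for the sampling files looked up by model_name.
MODEL_REGISTRY: Dict[str, Callable] = {}


class ToyEstimator(EstimatorSketch):
    """Toy estimator: draws fake posterior samples instead of running Stan or Pyro."""

    def __init__(self, num_sample: int = 100, **kwargs):
        super().__init__(**kwargs)
        self.num_sample = num_sample

    def fit(self, model_name, model_param_names, data_input, fitter=None, init_values=None):
        # A supplied fitter takes precedence over looking the model up by name.
        model = fitter if fitter is not None else MODEL_REGISTRY[model_name]
        rng = np.random.default_rng(self.seed)
        draws = model(data_input, rng, self.num_sample)  # init_values is ignored in this toy
        # Keep only the requested parameters; each value has num_sample entries.
        posteriors = {name: draws[name] for name in model_param_names}
        training_metrics = {"num_sample": self.num_sample}
        return posteriors, training_metrics


def toy_model(data_input, rng, num_sample):
    # Hypothetical "model": normal draws centred on the sample mean of y.
    y = np.asarray(data_input["y"], dtype=float)
    return {
        "mu": rng.normal(y.mean(), 1.0 / np.sqrt(len(y)), size=num_sample),
        "sigma": np.full(num_sample, y.std()),
    }


MODEL_REGISTRY["toy"] = toy_model

est = ToyEstimator(num_sample=50, seed=1)
posteriors, metrics = est.fit("toy", ["mu"], {"y": [1.0, 2.0, 3.0]})
print(posteriors["mu"].shape)  # (50,)
```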

orbit.estimators.pyro_estimator module

class orbit.estimators.pyro_estimator.PyroEstimator(num_steps=301, learning_rate=0.1, learning_rate_total_decay=1.0, message=100, **kwargs)

Bases: BaseEstimator

Abstract PyroEstimator holding the arguments shared by all PyroEstimator child classes

Parameters:
  • num_steps (int) – Number of optimization steps

  • learning_rate (float) – Initial learning rate of the optimizer

  • learning_rate_total_decay (float) – A reparameterization of the lrd argument of ClippedAdam, giving the total decay of the learning rate over all steps. For example, 0.1 means the learning rate at the final step is 90% lower than the original learning_rate, with the decay spread across the steps (ClippedAdam multiplies the learning rate by lrd at every step, so the decay is geometric rather than linear). A value of 1.0 applies no decay: every step uses the constant learning rate given by learning_rate.

  • seed (int) – Seed for random number generation

  • message (int) – Print progress to the console every message steps

  • kwargs – Additional BaseEstimator args
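The learning_rate_total_decay arithmetic can be sketched in plain Python. This assumes the per-step factor is chosen so that its num_steps-th power equals the total decay, i.e. lrd = learning_rate_total_decay ** (1 / num_steps); that mapping is an assumption, not something stated above. The helper names are hypothetical; the only ClippedAdam behaviour relied on is that the learning rate is multiplied by lrd on every step.

```python
from typing import List


def per_step_lrd(learning_rate_total_decay: float, num_steps: int) -> float:
    """Per-step factor whose num_steps-th power equals the total decay (assumed mapping)."""
    return learning_rate_total_decay ** (1.0 / num_steps)


def learning_rate_schedule(learning_rate: float, learning_rate_total_decay: float,
                           num_steps: int) -> List[float]:
    """Learning rate after k multiplications by lrd, for k = 0..num_steps."""
    lrd = per_step_lrd(learning_rate_total_decay, num_steps)
    rates = [learning_rate]
    for _ in range(num_steps):
        # ClippedAdam multiplies the learning rate by lrd on every step.
        rates.append(rates[-1] * lrd)
    return rates


rates = learning_rate_schedule(0.1, 0.1, 301)
print(round(rates[-1], 6))  # 0.01
```

With the defaults (num_steps=301, learning_rate=0.1) and a total decay of 0.1, the rate ends at 0.01, a 90% reduction; with 1.0 every entry stays at 0.1.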

Notes

See http://docs.pyro.ai/en/stable/_modules/pyro/optim/clipped_adam.html for optimizer details

abstract fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
Parameters:
  • model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)

  • model_param_names (list) – list of model parameter names (strings) to extract

  • data_input (dict) – key-value pairs of data input, as required by the sampler definition (stan/pyro/…)

  • fitter – model object used for fitting; if supplied, it is used instead of looking up the model object by model_name

  • init_values (float or np.array) – initial sampler values. If None, ‘random’ is used

Returns:

  • posteriors (dict) – key-value pairs where the key is the model parameter name and the value is an array of num_sample x posterior values

  • training_metrics (dict) – metrics and metadata related to the training process

class orbit.estimators.pyro_estimator.PyroEstimatorSVI(num_sample=100, num_particles=100, init_scale=0.1, **kwargs)

Bases: PyroEstimator

Pyro estimator for stochastic variational inference (SVI)

Parameters:
  • num_sample (int) – Number of samples to draw for inference, default 100

  • num_particles (int) – Number of particles used in pyro.infer.Trace_ELBO for SVI optimization

  • init_scale (float) – Parameter used in pyro.infer.autoguide; a larger value is recommended for small datasets

  • kwargs – Additional PyroEstimator class args
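To show what num_particles controls, here is a small NumPy illustration; it is not Pyro's Trace_ELBO implementation. The ELBO is estimated by averaging log p(z) - log q(z) over num_particles draws from the guide q, so more particles give a lower-variance estimate (and lower-variance gradients) at higher cost. The guide here is a single Normal whose scale equals init_scale, mirroring how Pyro autoguides such as AutoNormal use init_scale as the initial guide standard deviation; the target p = Normal(0, 1) and all helper names are hypothetical.

```python
import numpy as np


def log_normal_pdf(z, mean, scale):
    return -0.5 * ((z - mean) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def elbo_estimate(mu, scale, num_particles, rng):
    """Monte Carlo ELBO for guide q = Normal(mu, scale) against target p = Normal(0, 1)."""
    z = rng.normal(mu, scale, size=num_particles)  # num_particles draws from the guide
    log_p = log_normal_pdf(z, 0.0, 1.0)
    log_q = log_normal_pdf(z, mu, scale)
    return float(np.mean(log_p - log_q))  # average over particles


def exact_elbo(mu, scale):
    # For a normalized p, ELBO = -KL(q || p), which is closed form for two normals.
    return -(-np.log(scale) + (scale ** 2 + mu ** 2) / 2.0 - 0.5)


rng = np.random.default_rng(8888)
mu, init_scale = 0.5, 0.1  # guide starts with scale equal to init_scale
for n in (1, 100):
    estimates = [elbo_estimate(mu, init_scale, n, rng) for _ in range(200)]
    print(n, np.std(estimates))  # spread shrinks roughly as 1 / sqrt(n)
print(exact_elbo(mu, init_scale))
```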

fit(model_name, model_param_names, data_input, sampling_temperature, fitter=None, init_values=None)
Parameters:
  • model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)

  • model_param_names (list) – list of model parameter names (strings) to extract

  • data_input (dict) – key-value pairs of data input, as required by the sampler definition (stan/pyro/…)

  • fitter – model object used for fitting; if supplied, it is used instead of looking up the model object by model_name

  • init_values (float or np.array) – initial sampler values. If None, ‘random’ is used

Returns:

  • posteriors (dict) – key-value pairs where the key is the model parameter name and the value is an array of num_sample x posterior values

  • training_metrics (dict) – metrics and metadata related to the training process

orbit.estimators.stan_estimator module

class orbit.estimators.stan_estimator.StanEstimator(num_warmup=900, num_sample=100, chains=4, cores=8, algorithm=None, suppress_stan_log=True, **kwargs)

Bases: BaseEstimator

Abstract StanEstimator holding the arguments shared by all StanEstimator child classes

Parameters:
  • num_warmup (int) – Number of warm-up iterations, whose draws are discarded, default 900

  • num_sample (int) – Number of samples to return, default 100

  • chains (int) – Number of chains in the Stan sampler, default 4

  • cores (int) – Number of cores for parallel processing, default 8; capped at the machine's core count, i.e. min(cores, multiprocessing.cpu_count()) cores are used

  • algorithm (str) – Algorithm passed to Stan. If None, Stan's default is used

  • suppress_stan_log (bool) – If True (default), turn off the cmdstanpy logger

  • kwargs – Additional BaseEstimator class args
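A sketch of how the chain and core settings might translate into per-chain sampler settings. It assumes num_warmup and num_sample are totals across all chains, split evenly and rounded up. That split is an assumption not stated above, and per_chain_settings is a hypothetical helper. The core cap follows the min(cores, multiprocessing.cpu_count()) rule.

```python
import math
import multiprocessing


def per_chain_settings(num_warmup: int = 900, num_sample: int = 100,
                       chains: int = 4, cores: int = 8) -> dict:
    """Hypothetical helper: split total warm-up/sample counts across chains."""
    return {
        # Never request more cores than the machine reports.
        "cores": min(cores, multiprocessing.cpu_count()),
        # Assumption: totals are divided evenly across chains, rounding up.
        "num_warmup_per_chain": math.ceil(num_warmup / chains),
        "num_sample_per_chain": math.ceil(num_sample / chains),
    }


settings = per_chain_settings()
print(settings["num_warmup_per_chain"], settings["num_sample_per_chain"])  # 225 25
```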

abstract fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
Parameters:
  • model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)

  • model_param_names (list) – list of model parameter names (strings) to extract

  • data_input (dict) – key-value pairs of data input, as required by the sampler definition (stan/pyro/…)

  • fitter – model object used for fitting; if supplied, it is used instead of looking up the model object by model_name

  • init_values (float or np.array) – initial sampler values. If None, ‘random’ is used

Returns:

  • posteriors (dict) – key-value pairs where the key is the model parameter name and the value is an array of num_sample x posterior values

  • training_metrics (dict) – metrics and metadata related to the training process

class orbit.estimators.stan_estimator.StanEstimatorMAP(stan_map_args=None, **kwargs)

Bases: StanEstimator

Stan estimator for MAP (maximum a posteriori) posterior estimates
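As a toy NumPy illustration of what a MAP fit computes, the sketch below finds the posterior mode by optimization instead of drawing posterior samples. The model is a hypothetical normal likelihood with known sigma and a normal prior on its mean, so the mode has a closed form to check against. Plain gradient ascent is a swapped-in technique for illustration only; Stan's optimizer uses L-BFGS by default.

```python
import numpy as np


def log_posterior_grad(theta, y, sigma, prior_mean, prior_scale):
    # d/dtheta [log N(y | theta, sigma^2) + log N(theta | prior_mean, prior_scale^2)]
    return np.sum(y - theta) / sigma ** 2 + (prior_mean - theta) / prior_scale ** 2


def map_by_gradient_ascent(y, sigma=1.0, prior_mean=0.0, prior_scale=1.0,
                           step=0.01, num_steps=2000):
    theta = 0.0
    for _ in range(num_steps):
        theta += step * log_posterior_grad(theta, y, sigma, prior_mean, prior_scale)
    return theta


def map_closed_form(y, sigma=1.0, prior_mean=0.0, prior_scale=1.0):
    # Normal-normal conjugacy: the posterior mode equals the precision-weighted mean.
    precision = 1.0 / prior_scale ** 2 + len(y) / sigma ** 2
    return (prior_mean / prior_scale ** 2 + np.sum(y) / sigma ** 2) / precision


y = np.array([1.0, 2.0, 3.0])
print(map_by_gradient_ascent(y), map_closed_form(y))
```

Because a MAP fit returns a single point (the mode), it is much cheaper than MCMC but gives no spread of posterior draws.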

fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
Parameters:
  • model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)

  • model_param_names (list) – list of model parameter names (strings) to extract

  • data_input (dict) – key-value pairs of data input, as required by the sampler definition (stan/pyro/…)

  • fitter – model object used for fitting; if supplied, it is used instead of looking up the model object by model_name

  • init_values (float or np.array) – initial sampler values. If None, ‘random’ is used

Returns:

  • posteriors (dict) – key-value pairs where the key is the model parameter name and the value is an array of num_sample x posterior values

  • training_metrics (dict) – metrics and metadata related to the training process

class orbit.estimators.stan_estimator.StanEstimatorMCMC(stan_mcmc_args=None, **kwargs)

Bases: StanEstimator

Stan estimator for MCMC sampling

Parameters:
  • stan_mcmc_args (dict) – Supplemental Stan MCMC args to pass to CmdStanPy's CmdStanModel.sample()
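One way supplemental MCMC args could be combined with the estimator's own settings is sketched below. build_sample_kwargs is hypothetical, and the rule that user-supplied keys override the defaults is an assumption. The keyword names (chains, parallel_chains, iter_warmup, iter_sampling, seed, adapt_delta, max_treedepth) are real CmdStanModel.sample() arguments. The final call is left as a comment because it needs a compiled Stan model.

```python
from typing import Any, Dict, Optional


def build_sample_kwargs(num_warmup_per_chain: int, num_sample_per_chain: int,
                        chains: int, cores: int, seed: int,
                        stan_mcmc_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: defaults from estimator settings, then user overrides."""
    kwargs = {
        "chains": chains,
        "parallel_chains": cores,
        "iter_warmup": num_warmup_per_chain,
        "iter_sampling": num_sample_per_chain,
        "seed": seed,
    }
    # Assumption: supplemental args win over the defaults when keys collide.
    kwargs.update(stan_mcmc_args or {})
    return kwargs


kwargs = build_sample_kwargs(225, 25, chains=4, cores=4, seed=8888,
                             stan_mcmc_args={"adapt_delta": 0.95, "max_treedepth": 12})
print(kwargs)
# fit = model.sample(data=data_input, **kwargs)  # model: a compiled cmdstanpy.CmdStanModel
```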

fit(model_name, model_param_names, sampling_temperature, data_input, fitter=None, init_values=None)
Parameters:
  • model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)

  • model_param_names (list) – list of model parameter names (strings) to extract

  • data_input (dict) – key-value pairs of data input, as required by the sampler definition (stan/pyro/…)

  • fitter – model object used for fitting; if supplied, it is used instead of looking up the model object by model_name

  • init_values (float or np.array) – initial sampler values. If None, ‘random’ is used

Returns:

  • posteriors (dict) – key-value pairs where the key is the model parameter name and the value is an array of num_sample x posterior values

  • training_metrics (dict) – metrics and metadata related to the training process

Module contents