orbit.estimators package
Submodules
orbit.estimators.base_estimator module
- class orbit.estimators.base_estimator.BaseEstimator(seed=8888, verbose=True)
Bases: object
Base estimator class for both the Stan and Pyro estimators
- Parameters
seed (int) – seed number for initial random values
verbose (bool) – If True (default), output all diagnostic messages from estimators
- abstract fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
- Parameters
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (strings) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of searching for the model object by model_name
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns
posteriors (dict) – key-value pairs where the key is the model parameter name and the value holds num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process
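The fit contract above can be illustrated with a minimal sketch. SketchBaseEstimator and ToyNormalMeanEstimator are hypothetical classes written only for illustration; they do not import orbit and are not its implementation. The toy subclass draws num_sample values from the conjugate posterior of a normal mean (prior N(0, 1), unit-variance observations) and returns the documented (posteriors, training_metrics) pair, keeping only the names listed in model_param_names. The data_input key "y" is likewise an assumption of the sketch.

```python
from abc import ABC, abstractmethod
import random


class SketchBaseEstimator(ABC):
    """Hypothetical stand-in mirroring the documented BaseEstimator interface."""

    def __init__(self, seed=8888, verbose=True):
        self.seed = seed
        self.verbose = verbose

    @abstractmethod
    def fit(self, model_name, model_param_names, data_input,
            fitter=None, init_values=None):
        """Return a (posteriors, training_metrics) tuple."""


class ToyNormalMeanEstimator(SketchBaseEstimator):
    """Hypothetical estimator: samples the conjugate posterior of a normal mean."""

    def __init__(self, num_sample=100, **kwargs):
        super().__init__(**kwargs)
        self.num_sample = num_sample

    def fit(self, model_name, model_param_names, data_input,
            fitter=None, init_values=None):
        rng = random.Random(self.seed)
        y = data_input["y"]  # hypothetical data_input key
        # Conjugate update for mu ~ N(0, 1), y_i ~ N(mu, 1)
        post_var = 1.0 / (1.0 + len(y))
        post_mean = post_var * sum(y)
        draws = [rng.gauss(post_mean, post_var ** 0.5)
                 for _ in range(self.num_sample)]
        # Keep only the requested parameters; each value holds num_sample draws
        all_params = {"mu": draws}
        posteriors = {name: all_params[name] for name in model_param_names}
        training_metrics = {"model_name": model_name,
                            "num_sample": self.num_sample}
        return posteriors, training_metrics


est = ToyNormalMeanEstimator(num_sample=50, seed=1)
posteriors, training_metrics = est.fit("toy", ["mu"], {"y": [1.0, 2.0, 3.0]})
print(len(posteriors["mu"]))  # 50
```

Because fit is declared abstract, instantiating SketchBaseEstimator directly raises TypeError, which mirrors how the abstract estimators above leave fit to their child classes.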
orbit.estimators.pyro_estimator module
- class orbit.estimators.pyro_estimator.PyroEstimator(num_steps=301, learning_rate=0.1, learning_rate_total_decay=1.0, message=100, **kwargs)
Bases: orbit.estimators.base_estimator.BaseEstimator
Abstract PyroEstimator with shared args for all PyroEstimator child classes
- Parameters
num_steps (int) – Number of estimator steps in optimization
learning_rate (float) – Estimator learning rate
learning_rate_total_decay (float) – A config re-parameterized from lrd in ClippedAdam. For example, 0.1 means the learning rate at the final step is reduced by 90% from the original learning rate, with linear decay implied along the steps. A value of 1.0 applies no decay, so all steps use the constant learning rate specified by learning_rate.
seed (int) – seed number for initial random values
message (int) – Print to console once every message steps
kwargs – Additional BaseEstimator args
Notes
See http://docs.pyro.ai/en/stable/_modules/pyro/optim/clipped_adam.html for optimizer details
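In Pyro's ClippedAdam, lrd is a factor multiplied into the learning rate on each step. One plausible mapping, assumed here and not confirmed by this page, is lrd = learning_rate_total_decay ** (1 / num_steps), so that the learning rate reaches learning_rate * learning_rate_total_decay after num_steps steps. Under that assumption the decay is geometric, which is linear on a log scale. per_step_decay and lr_schedule are hypothetical helpers for illustration only.

```python
def per_step_decay(learning_rate_total_decay, num_steps):
    """Per-step multiplicative factor (ClippedAdam's lrd) under the assumed mapping."""
    # Assumption: the total decay is spread evenly in log space over num_steps
    return learning_rate_total_decay ** (1.0 / num_steps)


def lr_schedule(learning_rate, learning_rate_total_decay, num_steps):
    """Learning rate after each of the num_steps decay updates."""
    lrd = per_step_decay(learning_rate_total_decay, num_steps)
    return [learning_rate * lrd ** k for k in range(1, num_steps + 1)]


schedule = lr_schedule(learning_rate=0.1, learning_rate_total_decay=0.1, num_steps=301)
print(schedule[0], schedule[-1])  # last value is approximately 0.01

constant = lr_schedule(learning_rate=0.1, learning_rate_total_decay=1.0, num_steps=5)
print(constant)  # 1.0 means no decay: every entry equals 0.1
```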
- abstract fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
- Parameters
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (strings) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of searching for the model object by model_name
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns
posteriors (dict) – key-value pairs where the key is the model parameter name and the value holds num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process
- class orbit.estimators.pyro_estimator.PyroEstimatorSVI(num_sample=100, num_particles=100, init_scale=0.1, **kwargs)
Bases: orbit.estimators.pyro_estimator.PyroEstimator
Pyro Estimator for VI Sampling
- Parameters
num_sample (int) – Number of samples to draw for inference, default 100
num_particles (int) – Number of particles used in pyro.infer.Trace_ELBO for SVI optimization
init_scale (float) – Parameter used in pyro.infer.autoguide; a larger value is recommended for small datasets
kwargs – Additional PyroEstimator class args
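num_particles sets how many samples from the guide are averaged when the ELBO is estimated. The plain-Python sketch below illustrates that Monte Carlo average for a toy model (z ~ N(0, 1), x ~ N(z, 1)) with a Normal guide; it is not Pyro's Trace_ELBO implementation, and all function names are hypothetical. A closed-form ELBO for the same toy model is included to check the estimate.

```python
import math
import random


def log_normal(x, mean, scale):
    """Log density of N(mean, scale**2) at x."""
    return (-0.5 * math.log(2 * math.pi) - math.log(scale)
            - 0.5 * ((x - mean) / scale) ** 2)


def elbo_estimate(x, q_loc, q_scale, num_particles, rng):
    """Average of log p(x, z) - log q(z) over num_particles draws z ~ q."""
    total = 0.0
    for _ in range(num_particles):
        z = rng.gauss(q_loc, q_scale)
        log_p = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        log_q = log_normal(z, q_loc, q_scale)
        total += log_p - log_q
    return total / num_particles


def elbo_exact(x, q_loc, q_scale):
    """Closed-form ELBO for the toy model, used only for comparison."""
    v = q_scale ** 2
    e_log_prior = -0.5 * math.log(2 * math.pi) - 0.5 * (q_loc ** 2 + v)
    e_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - q_loc) ** 2 + v)
    entropy = 0.5 * math.log(2 * math.pi * math.e * v)
    return e_log_prior + e_log_lik + entropy


rng = random.Random(0)
for n in (1, 100):
    estimates = [elbo_estimate(1.0, 0.0, 1.0, n, rng) for _ in range(200)]
    mean = sum(estimates) / len(estimates)
    spread = (sum((e - mean) ** 2 for e in estimates) / len(estimates)) ** 0.5
    print(n, round(mean, 3), round(spread, 3))
print("exact:", round(elbo_exact(1.0, 0.0, 1.0), 3))
```

With q_loc=0.0 and q_scale=1.0 the guide is not the exact posterior, so individual estimates vary; averaging over more particles shrinks that spread roughly as 1/sqrt(num_particles), at a proportionally higher cost per optimization step.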
- fit(model_name, model_param_names, data_input, sampling_temperature, fitter=None, init_values=None)
- Parameters
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (strings) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of searching for the model object by model_name
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns
posteriors (dict) – key-value pairs where the key is the model parameter name and the value holds num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process
orbit.estimators.stan_estimator module
- class orbit.estimators.stan_estimator.StanEstimator(num_warmup=900, num_sample=100, chains=4, cores=8, algorithm=None, suppress_stan_log=True, **kwargs)
Bases: orbit.estimators.base_estimator.BaseEstimator
Abstract StanEstimator with shared args for all StanEstimator child classes
- Parameters
num_warmup (int) – Number of samples to warm up and to be discarded, default 900
num_sample (int) – Number of samples to return, default 100
chains (int) – Number of chains in stan sampler, default 4
cores (int) – Number of cores for parallel processing, default max(cores, multiprocessing.cpu_count())
algorithm (str) – If None, defaults to Stan's default algorithm
suppress_stan_log (bool) – If True (default), turn off the cmdstanpy logger.
kwargs – Additional BaseEstimator class args
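suppress_stan_log controls whether the cmdstanpy logger is silenced. A minimal standard-library sketch of that idea follows, assuming the library logs through a logger named "cmdstanpy"; quiet_logger is a hypothetical helper, not orbit's code. It disables the named logger for the duration of a block and restores its previous state afterwards.

```python
import logging
from contextlib import contextmanager


@contextmanager
def quiet_logger(name="cmdstanpy", suppress=True):
    """Hypothetical helper: temporarily disable a named logger when suppress is True."""
    logger = logging.getLogger(name)  # assumed logger name
    previous = logger.disabled
    if suppress:
        # A disabled logger drops the records logged directly on it
        logger.disabled = True
    try:
        yield logger
    finally:
        # Restore whatever state the logger had before entering the block
        logger.disabled = previous


with quiet_logger("cmdstanpy", suppress=True) as log:
    log.info("this message is dropped while the logger is disabled")
```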
- abstract fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
- Parameters
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (strings) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of searching for the model object by model_name
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns
posteriors (dict) – key-value pairs where the key is the model parameter name and the value holds num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process
- class orbit.estimators.stan_estimator.StanEstimatorMAP(stan_map_args=None, **kwargs)
Bases: orbit.estimators.stan_estimator.StanEstimator
Stan Estimator for MAP Posteriors
- Parameters
stan_map_args (dict) – Supplemental Stan MAP (optimization) args to pass to CmdStanPy's CmdStanModel.optimize()
- fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
- Parameters
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (strings) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of searching for the model object by model_name
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns
posteriors (dict) – key-value pairs where the key is the model parameter name and the value holds num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process
- class orbit.estimators.stan_estimator.StanEstimatorMCMC(stan_mcmc_args=None, **kwargs)
Bases: orbit.estimators.stan_estimator.StanEstimator
Stan Estimator for MCMC Sampling
- Parameters
stan_mcmc_args (dict) – Supplemental Stan MCMC args to pass to CmdStanPy's CmdStanModel.sample()
- fit(model_name, model_param_names, sampling_temperature, data_input, fitter=None, init_values=None)
- Parameters
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (strings) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of searching for the model object by model_name
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns
posteriors (dict) – key-value pairs where the key is the model parameter name and the value holds num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process
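The StanEstimator parameters num_warmup, num_sample and chains interact when MCMC draws are turned into posteriors. The sketch below shows one plausible way to do that: drop the first num_warmup draws of each chain, take an equal share from every chain, and pool them. It assumes num_sample is a total across chains split with ceiling division and then trimmed; that split is an assumption for illustration, not confirmed by this page, and pool_chains is a hypothetical helper.

```python
import math


def pool_chains(chain_draws, num_warmup, num_sample):
    """Hypothetical helper: discard warmup, then pool chains into num_sample draws.

    chain_draws is a list of per-chain lists of draws, warmup draws first.
    Assumption: num_sample is the total across chains, split evenly.
    """
    per_chain = math.ceil(num_sample / len(chain_draws))
    pooled = []
    for draws in chain_draws:
        kept = draws[num_warmup:num_warmup + per_chain]
        if len(kept) < per_chain:
            raise ValueError("chain has fewer post-warmup draws than required")
        pooled.extend(kept)
    # Ceiling division can overshoot, so trim to exactly num_sample draws
    return pooled[:num_sample]


# Three toy chains of 12 draws; values encode (chain, index) as chain * 1000 + index
chains = [[c * 1000 + i for i in range(12)] for c in range(3)]
print(pool_chains(chains, num_warmup=5, num_sample=10))
# [5, 6, 7, 8, 1005, 1006, 1007, 1008, 2005, 2006]
```

Under this assumption the defaults (num_warmup=900, num_sample=100, chains=4) would keep 25 post-warmup draws from each chain.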