orbit.estimators package
Submodules
orbit.estimators.base_estimator module
- class orbit.estimators.base_estimator.BaseEstimator(seed=8888, verbose=True)
Bases: object
Base Estimator class for both Stan and Pyro estimators
- Parameters:
seed (int) – seed used to generate initial random values
verbose (bool) – If True (default), output all diagnostic messages from the estimators
- abstract fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
- Parameters:
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (str) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of model_name to look up the model object
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns:
posteriors (dict) – key-value pairs where the key is the model parameter name and the value is num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process
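The fit contract above can be illustrated without Orbit itself. A minimal sketch, assuming only the documented signature and return values: SketchBaseEstimator and ToyNormalMeanEstimator are hypothetical classes (not part of orbit), and the toy model, a normal mean with a flat prior and unit noise, stands in for a real Stan or Pyro model.

```python
from abc import ABC, abstractmethod
import random


class SketchBaseEstimator(ABC):
    """Hypothetical stand-in mirroring the documented BaseEstimator interface."""

    def __init__(self, seed=8888, verbose=True):
        self.seed = seed
        self.verbose = verbose

    @abstractmethod
    def fit(self, model_name, model_param_names, data_input, fitter=None, init_values=None):
        """Return a (posteriors, training_metrics) pair."""


class ToyNormalMeanEstimator(SketchBaseEstimator):
    """Hypothetical child: samples the mean of y under a flat prior and unit noise."""

    def __init__(self, num_sample=100, **kwargs):
        super().__init__(**kwargs)  # seed and verbose go to the base class
        self.num_sample = num_sample

    def fit(self, model_name, model_param_names, data_input, fitter=None, init_values=None):
        # fitter and init_values are accepted for signature parity but ignored here.
        rng = random.Random(self.seed)
        y = data_input["y"]  # hypothetical data key
        n = len(y)
        # Conjugate posterior of the mean: N(mean(y), 1 / n).
        post_mean, post_sd = sum(y) / n, (1.0 / n) ** 0.5
        all_draws = {"mu": [rng.gauss(post_mean, post_sd) for _ in range(self.num_sample)]}
        # Keep only the requested parameters; each value has num_sample entries.
        posteriors = {name: all_draws[name] for name in model_param_names}
        training_metrics = {"model_name": model_name, "num_sample": self.num_sample}
        if self.verbose:
            print(f"{model_name}: drew {self.num_sample} samples")
        return posteriors, training_metrics


estimator = ToyNormalMeanEstimator(num_sample=50, seed=1, verbose=False)
posteriors, metrics = estimator.fit("toy", ["mu"], {"y": [1.0, 2.0, 3.0]})
```

Because fit is abstract, the base class cannot be instantiated directly; only a child class that implements fit can be used.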
orbit.estimators.pyro_estimator module
- class orbit.estimators.pyro_estimator.PyroEstimator(num_steps=301, learning_rate=0.1, learning_rate_total_decay=1.0, message=100, **kwargs)
Bases: BaseEstimator
Abstract PyroEstimator with shared args for all PyroEstimator child classes
- Parameters:
num_steps (int) – Number of estimator steps in optimization
learning_rate (float) – Estimator learning rate
learning_rate_total_decay (float) – A config re-parameterized from lrd in ClippedAdam. For example, 0.1 means the learning rate at the final step is reduced by 90% relative to the original learning rate, with the decay applied progressively over the steps. With 1.0, no decay is applied and all steps use the constant learning rate specified by learning_rate.
seed (int) – Random seed
message (int) – Print to console every message number of steps
kwargs – Additional BaseEstimator args
Notes
See http://docs.pyro.ai/en/stable/_modules/pyro/optim/clipped_adam.html for optimizer details
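Pyro's ClippedAdam multiplies its learning rate by lrd on every step. A minimal sketch of how learning_rate_total_decay can map onto lrd, assuming the per-step factor is lrd = learning_rate_total_decay ** (1 / num_steps) so that the learning rate after num_steps updates equals learning_rate * learning_rate_total_decay; that formula and the helper names are assumptions for illustration, not taken from this page.

```python
import math


def per_step_lrd(learning_rate_total_decay, num_steps):
    """Assumed re-parameterization: per-step multiplicative decay factor."""
    return learning_rate_total_decay ** (1.0 / num_steps)


def learning_rate_schedule(learning_rate, learning_rate_total_decay, num_steps):
    """Learning rate after k = 0..num_steps multiplicative decay updates."""
    lrd = per_step_lrd(learning_rate_total_decay, num_steps)
    return [learning_rate * lrd ** k for k in range(num_steps + 1)]


decayed = learning_rate_schedule(0.1, 0.1, 301)   # 90% total reduction
constant = learning_rate_schedule(0.1, 1.0, 301)  # no decay
print(decayed[0], decayed[-1])  # 0.1 at the start, approximately 0.01 at the end
```

With learning_rate_total_decay=1.0 the factor is exactly 1.0, so every step keeps the original learning rate.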
- abstract fit(model_name, model_param_names, data_input, fitter=None, init_values=None)¶
- Parameters:
model_name (str) – name of model - used in mapping the right sampling file (stan/pyro/…)
model_param_names (list) – list of strings of model parameters names to extract
data_input (dict) – key-value pairs of data input as required by definition in samplers (stan/pyro/…)
fitter – model object used for fitting; this will be used instead of model_name if supplied to search for model object
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns:
posteriors (dict) – key value pairs where key is the model parameter name and value is num_sample x posterior values
training_metrics (dict) – metrics and meta data related to the training process
- class orbit.estimators.pyro_estimator.PyroEstimatorSVI(num_sample=100, num_particles=100, init_scale=0.1, **kwargs)
Bases: PyroEstimator
Pyro Estimator for VI Sampling
- Parameters:
num_sample (int) – Number of samples to draw for inference, default 100
num_particles (int) – Number of particles used in pyro.infer.Trace_ELBO for SVI optimization
init_scale (float) – Parameter used in pyro.infer.autoguide; a larger value is recommended for small datasets
kwargs – Additional PyroEstimator class args
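num_particles sets how many draws from the guide Trace_ELBO averages to estimate the ELBO at each step, so more particles give a lower-variance but more expensive estimate. A stdlib-only sketch of that trade-off, assuming a toy prior p(z) = N(0, 1) and guide q(z) = N(mu, sigma^2); this is not Pyro code and not how Orbit builds its guide.

```python
import math
import random
import statistics


def elbo_estimate(mu, sigma, num_particles, rng):
    """Monte Carlo ELBO: average of log p(z) - log q(z) over num_particles draws z ~ q."""
    total = 0.0
    for _ in range(num_particles):
        z = rng.gauss(mu, sigma)
        log_p = -0.5 * math.log(2 * math.pi) - 0.5 * z * z
        log_q = (-0.5 * math.log(2 * math.pi) - math.log(sigma)
                 - 0.5 * ((z - mu) / sigma) ** 2)
        total += log_p - log_q
    return total / num_particles


def exact_elbo(mu, sigma):
    """Closed form for this toy case: ELBO = -KL(N(mu, sigma^2) || N(0, 1))."""
    return -(-math.log(sigma) + 0.5 * (sigma ** 2 + mu ** 2) - 0.5)


rng = random.Random(0)
# Spread of 200 repeated ELBO estimates with 1 particle versus 100 particles.
spread = {
    n: statistics.stdev([elbo_estimate(0.5, 0.8, n, rng) for _ in range(200)])
    for n in (1, 100)
}
print(exact_elbo(0.5, 0.8), spread)
```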
- fit(model_name, model_param_names, data_input, sampling_temperature, fitter=None, init_values=None)
- Parameters:
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (str) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of model_name to look up the model object
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns:
posteriors (dict) – key-value pairs where the key is the model parameter name and the value is num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process
orbit.estimators.stan_estimator module
- class orbit.estimators.stan_estimator.StanEstimator(num_warmup=900, num_sample=100, chains=4, cores=8, algorithm=None, suppress_stan_log=True, **kwargs)
Bases: BaseEstimator
Abstract StanEstimator with shared args for all StanEstimator child classes
- Parameters:
num_warmup (int) – Number of warm-up samples, which are discarded, default 900
num_sample (int) – Number of samples to return, default 100
chains (int) – Number of chains in the Stan sampler, default 4
cores (int) – Number of cores for parallel processing, default 8; the effective value is min(cores, multiprocessing.cpu_count())
algorithm (str) – Stan algorithm to use; if None, Stan's default is used
suppress_stan_log (bool) – If True (default), suppress the cmdstanpy logger output
kwargs – Additional BaseEstimator class args
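The sampler settings above interact: cores is capped by the machine, and the warm-up and sample counts have to be distributed over chains. A sketch of one way to derive per-chain settings, where stan_run_config is a hypothetical helper and the even split of num_warmup and num_sample across chains (rounded up) is an assumption; this page does not state how the split is done.

```python
import math
import multiprocessing


def stan_run_config(num_warmup=900, num_sample=100, chains=4, cores=8):
    """Hypothetical helper: derive per-chain settings from StanEstimator-style args.

    Assumes num_warmup and num_sample are totals split evenly across chains,
    rounded up so that at least num_sample draws are returned in total.
    """
    effective_cores = min(cores, multiprocessing.cpu_count())  # never exceed the device
    warmup_per_chain = math.ceil(num_warmup / chains)
    sample_per_chain = math.ceil(num_sample / chains)
    return {
        "cores": effective_cores,
        "warmup_per_chain": warmup_per_chain,
        "sample_per_chain": sample_per_chain,
        "total_samples": sample_per_chain * chains,
    }


default_cfg = stan_run_config()                    # 225 warm-up and 25 draws per chain
uneven_cfg = stan_run_config(num_sample=10, chains=4)  # rounds up to 3 draws per chain
```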
- abstract fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
- Parameters:
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (str) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of model_name to look up the model object
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns:
posteriors (dict) – key-value pairs where the key is the model parameter name and the value is num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process
- class orbit.estimators.stan_estimator.StanEstimatorMAP(stan_map_args=None, **kwargs)
Bases: StanEstimator
Stan Estimator for MAP Posteriors
- fit(model_name, model_param_names, data_input, fitter=None, init_values=None)
- Parameters:
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (str) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of model_name to look up the model object
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns:
posteriors (dict) – key-value pairs where the key is the model parameter name and the value is num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process
- class orbit.estimators.stan_estimator.StanEstimatorMCMC(stan_mcmc_args=None, **kwargs)
Bases: StanEstimator
Stan Estimator for MCMC Sampling
- Parameters:
stan_mcmc_args (dict) – Supplemental Stan MCMC args to pass to CmdStanPy's CmdStanModel.sample()
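stan_mcmc_args supplements the arguments the estimator already sets for sampling. A sketch, assuming the supplemental dict is merged over the estimator's own settings before calling cmdstanpy's CmdStanModel.sample(); the helper name build_sample_kwargs and the merge precedence are hypothetical, while data, chains, parallel_chains, iter_warmup, iter_sampling, seed, inits, adapt_delta and max_treedepth are real sample() keyword arguments.

```python
def build_sample_kwargs(data_input, iter_warmup, iter_sampling, chains, cores,
                        seed, init_values=None, stan_mcmc_args=None):
    """Hypothetical helper: assemble keyword arguments for
    cmdstanpy.CmdStanModel.sample(), letting stan_mcmc_args override defaults."""
    kwargs = {
        "data": data_input,
        "chains": chains,
        "parallel_chains": cores,
        "iter_warmup": iter_warmup,      # warm-up iterations per chain
        "iter_sampling": iter_sampling,  # kept draws per chain
        "seed": seed,
    }
    if init_values is not None:
        kwargs["inits"] = init_values
    # Assumed precedence: supplemental args win on key collisions.
    kwargs.update(stan_mcmc_args or {})
    return kwargs


kwargs = build_sample_kwargs(
    {"N": 3, "y": [1.0, 2.0, 3.0]}, iter_warmup=225, iter_sampling=25,
    chains=4, cores=4, seed=8888,
    stan_mcmc_args={"adapt_delta": 0.95, "max_treedepth": 12},
)
# With a compiled cmdstanpy.CmdStanModel `model`: fit = model.sample(**kwargs)
```

Merging last keeps the estimator's defaults in one place while still letting a caller tune sampler internals such as adapt_delta without a new constructor argument.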
- fit(model_name, model_param_names, sampling_temperature, data_input, fitter=None, init_values=None)
- Parameters:
model_name (str) – name of the model; used to map to the right sampling file (stan/pyro/…)
model_param_names (list) – list of model parameter names (str) to extract
data_input (dict) – key-value pairs of data input, as required by the definitions in the samplers (stan/pyro/…)
fitter – model object used for fitting; if supplied, it is used instead of model_name to look up the model object
init_values (float or np.array) – initial sampler value. If None, ‘random’ is used
- Returns:
posteriors (dict) – key-value pairs where the key is the model parameter name and the value is num_sample x posterior values
training_metrics (dict) – metrics and metadata related to the training process