Model Diagnostics

In this section, we introduce to a few recommended diagnostic plots to diagnostic Orbit models. The posterior samples in SVI and Full Bayesian i.e. FullBayesianForecaster and SVIForecaster.

The plots are created by arviz for the plots. ArviZ is a Python package for exploratory analysis of Bayesian models, includes functions for posterior analysis, data storage, model checking, comparison and diagnostics.

  • Trace plot

  • Pair plot

  • Density plot

[1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
import seaborn as sns


%matplotlib inline

import orbit
from orbit.models import LGT, DLT
from orbit.utils.dataset import load_iclaims

import warnings
warnings.filterwarnings('ignore')

from orbit.diagnostics.plot import params_comparison_boxplot
from orbit.constants import palette
[2]:
print(orbit.__version__)
1.1.4.6

Load data

[3]:
df = load_iclaims()
df.dtypes
[3]:
week              datetime64[ns]
claims                   float64
trend.unemploy           float64
trend.filling            float64
trend.job                float64
sp500                    float64
vix                      float64
dtype: object
[4]:
df.head(5)
[4]:
week claims trend.unemploy trend.filling trend.job sp500 vix
0 2010-01-03 13.386595 0.219882 -0.318452 0.117500 -0.417633 0.122654
1 2010-01-10 13.624218 0.219882 -0.194838 0.168794 -0.425480 0.110445
2 2010-01-17 13.398741 0.236143 -0.292477 0.117500 -0.465229 0.532339
3 2010-01-24 13.137549 0.203353 -0.194838 0.106918 -0.481751 0.428645
4 2010-01-31 13.196760 0.134360 -0.242466 0.074483 -0.488929 0.487404

Fit a Model

[5]:
DATE_COL = 'week'
RESPONSE_COL = 'claims'
REGRESSOR_COL = ['trend.unemploy', 'trend.filling', 'trend.job']
[6]:
dlt = DLT(
    response_col=RESPONSE_COL,
    date_col=DATE_COL,
    regressor_col=REGRESSOR_COL,
    seasonality=52,
    num_warmup=2000,
    num_sample=2000,
    chains=4,
    stan_mcmc_args={'show_progress': False},
)
[7]:
dlt.fit(df=df)
2024-03-19 23:39:45 - orbit - INFO - Sampling (CmdStanPy) with chains: 4, cores: 8, temperature: 1.000, warmups (per chain): 500 and samples(per chain): 500.
[7]:
<orbit.forecaster.full_bayes.FullBayesianForecaster at 0x294cf72d0>

We can use .get_posterior_samples() to extract posteriors. Note that we need permute=False to retrieve additional information such as chains when we extract posterior samples for posteriors plotting. For regression, in order to collapse and relabel regression from parameters (usually called as beta), we use relabel=True.

[8]:
ps = dlt.get_posterior_samples(relabel=True, permute=False)
ps.keys()
[8]:
dict_keys(['l', 'b', 'lev_sm', 'slp_sm', 'obs_sigma', 'nu', 'lt_sum', 's', 'sea_sm', 'gt_sum', 'gb', 'gl', 'loglk', 'trend.unemploy', 'trend.filling', 'trend.job'])

Diagnostics Visualization

In the following section, we are going to use the regression coefficients as an example. In practice, you could check other parameters extracted from the model. For now, it only supports 1-D parameter which in generally capture the most important parameters of the model (e.g. obs_sigma, lev_sm etc.)

Convergence Status

Trace plots help us verify the convergence of model. In general, a largely overlapped distribution across samples from different chains indicates the convergence.

[9]:
az.style.use('arviz-darkgrid')
az.plot_trace(
    ps,
    var_names=['trend.unemploy', 'trend.filling', 'trend.job'],
    chain_prop={"color": ['r', 'b', 'g', 'y']},
    figsize=(10, 8),
);
../_images/tutorials_model_diagnostics_15_0.png

Note that this is only applicable for FullBayesianForecaster using sampling method such as MCMC.

Samples Density

We can also check the density of samples by pair plot.

[10]:
az.plot_pair(
    ps,
    var_names=['trend.unemploy', 'trend.filling', 'trend.job'],
    kind=["scatter", "kde"],
    marginals=True,
    point_estimate="median",
    textsize=18.5,
);
../_images/tutorials_model_diagnostics_19_0.png

Compare Models

You can also compare posteriors across different models with the same parameters. You can use plots such as density plot and forest plot to do so.

[11]:
dlt_smaller_prior = DLT(
    response_col=RESPONSE_COL,
    date_col=DATE_COL,
    regressor_col=REGRESSOR_COL,
    regressor_sigma_prior=[0.05, 0.05, 0.05],
    seasonality=52,
    num_warmup=2000,
    num_sample=2000,
    chains=4,
    stan_mcmc_args={'show_progress': False},
)
dlt_smaller_prior.fit(df=df)
ps_smaller_prior = dlt_smaller_prior.get_posterior_samples(relabel=True, permute=False)
2024-03-19 23:39:55 - orbit - INFO - Sampling (CmdStanPy) with chains: 4, cores: 8, temperature: 1.000, warmups (per chain): 500 and samples(per chain): 500.
[12]:
az.plot_density(
    [ps, ps_smaller_prior],
    var_names=['trend.unemploy', 'trend.filling', 'trend.job'],
    data_labels=["Default", "Smaller Sigma"],
    shade=0.1,
    textsize=18.5,
);
../_images/tutorials_model_diagnostics_22_0.png
[13]:
az.plot_forest(
    [ps, ps_smaller_prior],
    var_names=['trend.unemploy'],
    model_names=["Default", "Smaller Sigma"],
);
../_images/tutorials_model_diagnostics_23_0.png
[14]:
params_comparison_boxplot(
        [ps, ps_smaller_prior],
        var_names=['trend.unemploy', 'trend.filling', 'trend.job'],
        model_names=["Default", "Smaller Sigma"],
        box_width = .1, box_distance=0.1,
        showfliers=True
);
../_images/tutorials_model_diagnostics_24_0.png

Conclusion

Orbit models allow multiple visualization to diagnostics models and compare different models. We briefly introduce some basic syntax and usage of arviz. There is an example gallery built by the original team. Users can learn the details and more advance usage there. Meanwhile, the Orbit team aims to continue expand the scope to leverage more work done from the arviz project.