Custom priors in LightweightMMM

LightweightMMM allows you to pass your own prior to any of the parameters of the model. This notebook explains how this can be done.

The media prior has a dedicated argument (media_prior) in the fit method and is not part of the custom priors, because it is required rather than optional. In this notebook we focus on the optional custom priors for the rest of the model parameters.

# Please note that the values given here are only meant to demonstrate the
# usage of the API; they are by no means intended to resemble good prior
# values.

Initial setup.

Refer to the end-to-end examples for general information about the model workflow.

# Import jax.numpy and any other library we might need.
import jax.numpy as jnp
import numpyro
from lightweight_mmm import lightweight_mmm
from lightweight_mmm import preprocessing
from lightweight_mmm import utils
data_size = 104
media_data, extra_features, target, costs = utils.simulate_dummy_data(
    data_size=data_size + 13,
    n_media_channels=3,
    n_extra_features=1)
# Split and scale data.
split_point = data_size - 13
# Media data
media_data_train = media_data[:split_point, ...]
media_data_test = media_data[split_point:, ...]
# Extra features
extra_features_train = extra_features[:split_point, ...]
extra_features_test = extra_features[split_point:, ...]
# Target
target_train = target[:split_point]
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

media_data_train = media_scaler.fit_transform(media_data_train)
extra_features_train = extra_features_scaler.fit_transform(extra_features_train)
target_train = target_scaler.fit_transform(target_train)
costs = cost_scaler.fit_transform(costs)
mmm = lightweight_mmm.LightweightMMM(model_name="carryover")
# Use a small number of warmup steps and samples, since the following examples
# only show how different custom priors are passed to the `fit` method.
number_warmup = 10
number_samples = 10

What parameters does the model have? And what are their names?

Although we go over all parameters here, please refer to the model documentation for full details on the model. The simplified model formulation is the following:

\[ kpi_{t} = \alpha + trend_{t} + seasonality_{t} + media\_channels_{t} + other\_factors_{t} \]

Intercept:

  • \(\alpha \sim HalfNormal(2)\)

  • Prior name: “intercept”

  • Default prior: numpyro.distributions.HalfNormal(scale=2)

  • Final shape:

    • National: ()

    • Geo: (g,) where g is the number of geos

Trend:

  • \(trend_{t} = \mu t^{\kappa}\)

  • Where \(t\) is a linear trend input

  • \(\mu \sim Normal(0,1)\)

    • Prior name: “coef_trend”

    • Default prior: numpyro.distributions.Normal(loc=0., scale=1.)

    • Final shape:

      • National: ()

      • Geo: (g,) where g is the number of geos

  • \(\kappa \sim Uniform(0.5,1.5)\)

    • Prior name: “expo_trend”

    • Default prior: numpyro.distributions.Uniform(low=0.5, high=1.5)

    • Final shape:

      • National: ()

      • Geo: ()
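A minimal sketch of the trend term, showing how the exponent \(\kappa\) bends an otherwise linear trend. The function name trend and the input t = 1, ..., 104 are illustrative assumptions, not the library's internals.

```python
import jax.numpy as jnp


def trend(t, coef_trend, expo_trend):
    """trend_t = mu * t ** kappa, with mu = coef_trend and kappa = expo_trend."""
    return coef_trend * t ** expo_trend


# Hypothetical linear trend input t = 1, ..., 104 (two years of weekly data).
t = jnp.arange(1, 105, dtype=jnp.float32)
linear = trend(t, coef_trend=0.5, expo_trend=1.0)    # kappa = 1: a straight line
concave = trend(t, coef_trend=0.5, expo_trend=0.5)   # kappa < 1: growth slows down
convex = trend(t, coef_trend=0.5, expo_trend=1.5)    # kappa > 1: growth speeds up
```

With the default Uniform(0.5, 1.5) prior on expo_trend, the model can therefore represent trends anywhere between these concave and convex shapes.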

Seasonality:

Seasonality (for models using weekly observations)

  • \(seasonality_{t} = \displaystyle\sum_{d=1}^{2} (\gamma_{1,d} \cos(\frac{2 \pi d t}{52}) + \gamma_{2,d} \sin(\frac{2 \pi d t}{52}))\)

  • \(\gamma_{1,d}, \gamma_{2,d} \sim Normal(0,1)\)

  • Prior name: “gamma_seasonality”

  • Default prior: numpyro.distributions.Normal(loc=0., scale=1.)

  • Final shape:

    • National: (2, d) where d is the number of degrees of seasonality

    • Geo: (2, d) where d is the number of degrees of seasonality
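The weekly seasonality is a sum of Fourier terms. A minimal sketch, assuming the coefficients are stored with the (2, d) shape listed above (row 0 for cosines, row 1 for sines) and that t counts weeks from 0; the function name fourier_seasonality is hypothetical.

```python
import jax.numpy as jnp


def fourier_seasonality(gamma_seasonality, n_periods, frequency=52):
    """Sum over d of gamma[0, d] cos(2 pi d t / f) + gamma[1, d] sin(2 pi d t / f).

    gamma_seasonality: (2, degrees) array; row 0 holds the cosine
    coefficients and row 1 the sine coefficients. Returns shape (n_periods,).
    """
    degrees = gamma_seasonality.shape[1]
    t = jnp.arange(n_periods)[:, None]            # (T, 1) time index
    d = jnp.arange(1, degrees + 1)[None, :]       # (1, D) harmonic index
    angle = 2 * jnp.pi * d * t / frequency        # (T, D)
    terms = (gamma_seasonality[0] * jnp.cos(angle)
             + gamma_seasonality[1] * jnp.sin(angle))
    return terms.sum(axis=1)


gamma = jnp.array([[1.0, 0.0],    # cosine coefficients for d = 1, 2
                   [0.0, 0.5]])   # sine coefficients for d = 1, 2
seasonality = fourier_seasonality(gamma, n_periods=104)
```

With frequency=365 the same sketch covers the yearly part of the daily model.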

Seasonality (for models using daily observations)

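  • \(seasonality_{t} = \displaystyle\sum_{d=1}^{2} (\gamma_{1,d} \cos(\frac{2 \pi d t}{365}) + \gamma_{2,d} \sin(\frac{2 \pi d t}{365})) + \displaystyle\sum_{i=1}^{7} \delta_{i} \mathbb{1}[weekday(t) = i]\), where \(\mathbb{1}[weekday(t) = i]\) is 1 if day \(t\) falls on weekday \(i\) and 0 otherwise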

  • \(\gamma_{1,d}, \gamma_{2,d} \sim Normal(0,1)\)

    • Prior name: “gamma_seasonality”

    • Default prior: numpyro.distributions.Normal(loc=0., scale=1.)

    • Final shape:

      • National: (2, d) where d is the number of degrees of seasonality

      • Geo: (2, d) where d is the number of degrees of seasonality

  • \(\delta_{i} \sim Normal(0,0.5)\)

    • Prior name: “weekday”

    • Default prior: numpyro.distributions.Normal(loc=0., scale=0.5)

    • Final shape:

      • National: (7,)

      • Geo: (7,)
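Exactly one weekday coefficient applies on any given day. A sketch of that lookup, assuming the first observation falls on the weekday whose coefficient is delta[0] (in practice this depends on the start date of your data). The function name weekday_effect and the example coefficients are hypothetical.

```python
import jax.numpy as jnp


def weekday_effect(weekday, n_days):
    """Returns the weekday coefficient that applies to each of n_days days.

    weekday: (7,) array of delta coefficients. Day t uses weekday[t % 7], so
    the pattern repeats every 7 days.
    """
    return weekday[jnp.arange(n_days) % 7]


delta = jnp.array([0.1, 0.0, 0.0, 0.0, 0.2, -0.3, -0.4])
effect = weekday_effect(delta, n_days=14)
```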

Other Factors (extra features)

  • \(other\_factors_{t} = \displaystyle\sum_{i=1}^{N} \lambda_{i}Z_{i}\)

  • \(\lambda_{i} \sim Normal(0,1)\)

  • Where \(Z_{i}\) are other factors and \(N\) is the number of other factors.

    • Prior name: “coef_extra_features”

    • Default prior: numpyro.distributions.Normal(loc=0., scale=1.)

    • Final shape:

      • National: (f,) where f is the number of extra features

      • Geo: (f,) where f is the number of extra features
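For a national model, the extra-feature contribution is a matrix-vector product between the (T, f) features and the (f,) coefficients. A minimal sketch; the function name other_factors and the toy values are illustrative.

```python
import jax.numpy as jnp


def other_factors(extra_features, coef_extra_features):
    """other_factors_t = sum_i lambda_i * Z_{t,i} (national model).

    extra_features: (T, f) array of Z; coef_extra_features: (f,) array of
    lambda. Returns shape (T,).
    """
    return extra_features @ coef_extra_features


Z = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])        # T = 2 time steps, f = 2 extra features
lam = jnp.array([0.5, -1.0])
contribution = other_factors(Z, lam)  # [1*0.5 - 2*1, 3*0.5 - 4*1] = [-1.5, -2.5]
```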

Geo model only priors:

The geo model (hierarchical model) has the following additional parameter, which the national model does not have:

\( kpi = ... + \ \tau \ seasonality_{t} \ + ... \)

  • \(\tau \sim HalfNormal(0.5)\)

  • Prior name: “coef_seasonality”

  • Default prior: numpyro.distributions.HalfNormal(scale=.5)

  • Final shape:

    • Geo: (g,) where g is the number of geos

Other priors

The target is modeled as:

  • \(target \sim N(\mu, \sigma) \)

  • \(\sigma \sim Gamma(1, 1)\)

  • Prior name: “sigma”

  • Default prior: numpyro.distributions.Gamma(concentration=1., rate=1.)

  • Final shape:

    • National: ()

    • Geo: (g,) where g is the number of geos

Media transformation priors:

Saturation

Hill:

  • \(media\_channels_{t} = \frac{1}{1+(x_{t,m}^{*} / K_{m})^{-S_{m}}}\)

  • \(K_{m} \sim Gamma(1,1)\)

    • It must be strictly positive

    • Prior name: “half_max_effective_concentration”

    • Default prior: numpyro.distributions.Gamma(concentration=1., rate=1.)

    • Final shape:

      • National: (c,) where c is the number of media channels

      • Geo: (c,) where c is the number of media channels

  • \(S_{m} \sim Gamma(1,1)\)

    • Prior name: “slope”

    • Default prior: numpyro.distributions.Gamma(concentration=1., rate=1.)

    • Final shape:

      • National: (c,) where c is the number of media channels

      • Geo: (c,) where c is the number of media channels
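A minimal sketch of the Hill saturation applied per channel, with the (c,) parameters broadcasting over time. It also shows why \(K_m\) must be strictly positive, since it divides the media input. The function name hill and the example values are hypothetical.

```python
import jax.numpy as jnp


def hill(x, half_max_effective_concentration, slope):
    """Hill saturation 1 / (1 + (x / K) ** (-S)) applied per channel.

    x: (T, c) lagged media; K and S: (c,) arrays that broadcast over time.
    x and K are assumed to be strictly positive.
    """
    return 1.0 / (1.0 + (x / half_max_effective_concentration) ** (-slope))


x = jnp.array([[1.0, 2.0],
               [4.0, 8.0]])
K = jnp.array([1.0, 2.0])
S = jnp.array([1.0, 2.0])
saturated = hill(x, K, S)
# Row 0 has x == K, so both channels are at half saturation (0.5).
```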

Exponent:

  • \(media\_channels_{t} = x_{t,m}^{*\rho_{m}}\)

  • \(\rho_{m} \sim Beta(9,1)\)

  • Prior name: “exponent”

  • Default prior: numpyro.distributions.Beta(concentration1=9., concentration0=1.)

  • Final shape:

    • National: (c,) where c is the number of media channels

    • Geo: (c,) where c is the number of media channels

Lagging

Adstock

  • \(media\_channels_{t} = x_{t,m}^{*} = x_{t,m} + \lambda_{m} x_{t-1,m}^{*}\) where \(t=2,\ldots,N\) and \(x_{1,m}^{*} = x_{1,m}\)

  • \(\lambda_{m} \sim Beta(2,1)\)

  • Prior name: “lag_weight”

  • Default prior: numpyro.distributions.Beta(concentration1=2., concentration0=1.)

  • Final shape:

    • National: (c,) where c is the number of media channels

    • Geo: (c,) where c is the number of media channels
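The adstock recursion maps naturally onto jax.lax.scan, which carries the previous adstocked value forward. This sketch follows the recursion above only; the library's implementation may additionally normalise the output. The function name adstock is hypothetical.

```python
import jax
import jax.numpy as jnp


def adstock(x, lag_weight):
    """Geometric adstock x*_t = x_t + lambda * x*_{t-1}, with x*_1 = x_1.

    x: (T, c) media; lag_weight: (c,). Follows the recursion above only; the
    library's implementation may additionally normalise the output.
    """
    def step(previous, x_t):
        current = x_t + lag_weight * previous
        return current, current

    # Starting from zeros makes the first output equal to x_1.
    _, out = jax.lax.scan(step, jnp.zeros_like(x[0]), x)
    return out


spend = jnp.array([[1.0], [0.0], [0.0], [0.0]])   # one channel, a single burst
adstocked = adstock(spend, lag_weight=jnp.array([0.5]))  # 1, 0.5, 0.25, 0.125
```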

Carryover

  • \(media\_channels_{t} = \frac{\displaystyle\sum_{l=0}^{L} \tau_{m}^{(l-\theta_{m})^2}x_{t-l,m}}{\displaystyle\sum_{l=0}^{L}\tau_{m}^{(l-\theta_{m})^2}}\)

  • where \(L=13\) for weekly data and \(L=13 \times 7\) for daily data

  • \(\tau_{m} \sim Beta(1,1)\)

    • Prior name: “ad_effect_retention_rate”

    • Default prior: numpyro.distributions.Beta(concentration1=1., concentration0=1.)

    • Final shape:

      • National: (c,) where c is the number of media channels

      • Geo: (c,) where c is the number of media channels

  • \(\theta_{m} \sim HalfNormal(2)\)

    • Prior name: “peak_effect_delay”

    • Default prior: numpyro.distributions.HalfNormal(scale=2.)

    • Final shape:

      • National: (c,) where c is the number of media channels

      • Geo: (c,) where c is the number of media channels
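A minimal sketch of the carryover weighted average. It assumes that lags reaching back before the start of the series count as zero spend; that edge handling is a choice of this sketch, not necessarily the library's. The function name carryover is hypothetical.

```python
import jax.numpy as jnp


def carryover(x, ad_effect_retention_rate, peak_effect_delay, number_lags=13):
    """Weighted average of current and lagged media, as in the formula above.

    x: (T, c) media; retention rate tau and peak delay theta: (c,) arrays.
    Lags before the start of the series are treated as zero spend (an
    assumption of this sketch), while the denominator sums all L + 1 weights.
    """
    n_steps, n_channels = x.shape
    lags = jnp.arange(number_lags + 1)[:, None]                  # (L+1, 1)
    weights = ad_effect_retention_rate ** ((lags - peak_effect_delay) ** 2)
    padded = jnp.concatenate([jnp.zeros((number_lags, n_channels)), x])
    # lagged[lag, t] holds x_{t-lag}, read from the zero-padded series.
    lagged = jnp.stack([
        padded[number_lags - lag: number_lags - lag + n_steps]
        for lag in range(number_lags + 1)
    ])                                                           # (L+1, T, c)
    weighted_sum = (weights[:, None, :] * lagged).sum(axis=0)
    return weighted_sum / weights.sum(axis=0)


constant = jnp.ones((20, 2))
smoothed = carryover(constant,
                     ad_effect_retention_rate=jnp.array([0.5, 0.8]),
                     peak_effect_delay=jnp.array([0.0, 2.0]))
```

Once t reaches L, a constant input comes back unchanged, because the weights are normalised to sum to one. Earlier values are smaller because part of their lag window is padding.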

What does the API to change them look like?

The media prior is always required and has a dedicated argument, media_prior, in the fit method. Below we focus on the remaining parameters of the model. Their priors are optional and are passed through the custom_priors dictionary, which maps each prior name to its value or distribution.

There are two main ways of passing your own priors for the model’s parameters:

  • Passing a new distribution object.

  • Passing the values of the constructor for the default prior.

mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"intercept": numpyro.distributions.HalfNormal(5)})
# Even if you want the same prior as our default one, you can always
# build it from scratch and provide the whole distribution object.
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"intercept": numpyro.distributions.HalfNormal(scale=2.)})

You can also find the full list of available distributions in the NumPyro documentation.

Our default “intercept” prior is numpyro.distributions.HalfNormal(scale=2.). See the NumPyro documentation for this distribution.

Since the HalfNormal constructor takes a single argument, scale, which can be given either positionally or by keyword, you can provide your desired value in the following ways:

As keyword arguments:

mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"intercept": {"scale": 4.}})

As positional arguments:

mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"intercept": (4.,)})
# A bare number is another way of giving the first positional argument (for
# distributions with several arguments, the rest keep their constructor defaults).
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"intercept": 2.})

When the NumPyro distribution has no default values

For example, the Beta distribution has no default values for its parameters, so both of them must be passed when altering a default Beta prior.

The default prior of the exponent is a Beta distribution. Let’s see how this works:

# The following will fail, since Beta requires both concentration1 and
# concentration0 but only one value is given:
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"exponent": 0.65})

We need to specify all arguments, either positionally or by keyword:

mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"exponent": (0.5, 1.5)})
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"exponent": {"concentration1": 0.7, "concentration0": 1.7}})

Final shape of priors

Prior values are in many cases arrays rather than single values. We list the final (target) shape of each parameter above so that you can tailor your prior to that shape if desired.

Providing a single value will just broadcast the given value to the target shape.

For example, our “weekday” prior has a target shape of (7,). Let’s see the options for passing a custom prior:

# The simple case is where we provide single values as we have seen before.
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"weekday": numpyro.distributions.Normal(loc=0., scale=.5)})
# But we can also provide parameters with a shape that is broadcastable to (7,).
weekday_prior = numpyro.distributions.Normal(loc=jnp.arange(7), scale=.5)

mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"weekday": weekday_prior})
# For one or all the parameters:
weekday_prior = numpyro.distributions.Normal(
    loc=jnp.arange(7), scale=jnp.array([2, 2, 2, 2, 2, 4, 4]))

mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    custom_priors={"weekday": weekday_prior})
# We can also pass just the values, without the distribution object.

mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
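    # Scales must be strictly positive, so use 1..7 rather than jnp.arange(7),
    # whose first element is 0.
    custom_priors={"weekday": {"loc": jnp.arange(7), "scale": jnp.arange(1., 8.)}})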
# Or pass the values as positional arguments.

mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    # The scale (second argument) must be strictly positive.
    custom_priors={"weekday": (jnp.arange(7), jnp.arange(1., 8.))})

The only exception is when you pass a single positional value that is itself a sequence (such as an array) without wrapping it in another sequence.

# This does not work
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    # This does not work since we tried to pass the different values of the 
    # sequence to the prior distribution constructor which in this case only 
    # has 2 positional arguments.
    custom_priors={"weekday": jnp.arange(7)})
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    # This works since the prior is within the tuple
    custom_priors={"weekday": (jnp.arange(7),)})

Final notes

To avoid silent errors or unexpected behaviour, we also raise exceptions (in most cases) when a user passes a prior for a key that does not exist in the model.
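A minimal sketch of such a check, comparing the keys of custom_priors with the prior names a model uses. The function check_custom_prior_names is hypothetical. The name set is assumed from the prior names listed in this notebook for a carryover model and is not read from the library.

```python
def check_custom_prior_names(custom_priors, model_prior_names):
    """Hypothetical check: raise if a custom prior names an unknown parameter."""
    unknown = set(custom_priors) - set(model_prior_names)
    if unknown:
        raise ValueError(
            f"Custom priors {sorted(unknown)} do not match any parameter of the "
            f"model. Valid names are: {sorted(model_prior_names)}.")


# Prior names listed in this notebook for a carryover model (assumed set).
CARRYOVER_PRIOR_NAMES = {
    "intercept", "coef_trend", "expo_trend", "sigma", "gamma_seasonality",
    "weekday", "coef_extra_features", "coef_seasonality",
    "ad_effect_retention_rate", "peak_effect_delay", "exponent",
}

check_custom_prior_names({"intercept": 2.}, CARRYOVER_PRIOR_NAMES)  # passes
try:
    check_custom_prior_names({"intercpt": 2.}, CARRYOVER_PRIOR_NAMES)  # typo
except ValueError as error:
    print(error)
```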