Source code for lightweight_mmm.models
# Copyright 2023 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing the different models available in the lightweightMMM lib.
Currently this file contains a main model with three possible options for
processing the media data. Which essentially grants the possibility of building
three different models.
- Adstock
- Hill-Adstock
- Carryover
"""
import sys
# pylint: disable=g-import-not-at-top
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
from typing import Any, Dict, Mapping, MutableMapping, Optional, Sequence, Union
import immutabledict
import jax.numpy as jnp
import numpyro
from numpyro import distributions as dist
from lightweight_mmm import media_transforms
Prior = Union[
dist.Distribution,
Dict[str, float],
Sequence[float],
float
]
class TransformFunction(Protocol):
def __call__(
self,
media_data: jnp.ndarray,
custom_priors: MutableMapping[str, Prior],
**kwargs: Any) -> jnp.ndarray:
...
_INTERCEPT = "intercept"
_COEF_TREND = "coef_trend"
_EXPO_TREND = "expo_trend"
_SIGMA = "sigma"
_GAMMA_SEASONALITY = "gamma_seasonality"
_WEEKDAY = "weekday"
_COEF_EXTRA_FEATURES = "coef_extra_features"
_COEF_SEASONALITY = "coef_seasonality"
MODEL_PRIORS_NAMES = frozenset((
_INTERCEPT,
_COEF_TREND,
_EXPO_TREND,
_SIGMA,
_GAMMA_SEASONALITY,
_WEEKDAY,
_COEF_EXTRA_FEATURES,
_COEF_SEASONALITY))
_EXPONENT = "exponent"
_LAG_WEIGHT = "lag_weight"
_HALF_MAX_EFFECTIVE_CONCENTRATION = "half_max_effective_concentration"
_SLOPE = "slope"
_AD_EFFECT_RETENTION_RATE = "ad_effect_retention_rate"
_PEAK_EFFECT_DELAY = "peak_effect_delay"
TRANSFORM_PRIORS_NAMES = immutabledict.immutabledict({
"carryover":
frozenset((_AD_EFFECT_RETENTION_RATE, _PEAK_EFFECT_DELAY, _EXPONENT)),
"adstock":
frozenset((_EXPONENT, _LAG_WEIGHT)),
"hill_adstock":
frozenset((_LAG_WEIGHT, _HALF_MAX_EFFECTIVE_CONCENTRATION, _SLOPE))
})
GEO_ONLY_PRIORS = frozenset((_COEF_SEASONALITY,))
def _get_default_priors() -> Mapping[str, Prior]:
# Since JAX cannot be called before absl.app.run in tests we get default
# priors from a function.
return immutabledict.immutabledict({
_INTERCEPT: dist.HalfNormal(scale=2.),
_COEF_TREND: dist.Normal(loc=0., scale=1.),
_EXPO_TREND: dist.Uniform(low=0.5, high=1.5),
_SIGMA: dist.Gamma(concentration=1., rate=1.),
_GAMMA_SEASONALITY: dist.Normal(loc=0., scale=1.),
_WEEKDAY: dist.Normal(loc=0., scale=.5),
_COEF_EXTRA_FEATURES: dist.Normal(loc=0., scale=1.),
_COEF_SEASONALITY: dist.HalfNormal(scale=.5)
})
def _get_transform_default_priors() -> Mapping[str, Prior]:
# Since JAX cannot be called before absl.app.run in tests we get default
# priors from a function.
return immutabledict.immutabledict({
"carryover":
immutabledict.immutabledict({
_AD_EFFECT_RETENTION_RATE:
dist.Beta(concentration1=1., concentration0=1.),
_PEAK_EFFECT_DELAY:
dist.HalfNormal(scale=2.),
_EXPONENT:
dist.Beta(concentration1=9., concentration0=1.)
}),
"adstock":
immutabledict.immutabledict({
_EXPONENT: dist.Beta(concentration1=9., concentration0=1.),
_LAG_WEIGHT: dist.Beta(concentration1=2., concentration0=1.)
}),
"hill_adstock":
immutabledict.immutabledict({
_LAG_WEIGHT:
dist.Beta(concentration1=2., concentration0=1.),
_HALF_MAX_EFFECTIVE_CONCENTRATION:
dist.Gamma(concentration=1., rate=1.),
_SLOPE:
dist.Gamma(concentration=1., rate=1.)
})
})
[docs]def transform_adstock(media_data: jnp.ndarray,
custom_priors: MutableMapping[str, Prior],
normalise: bool = True) -> jnp.ndarray:
"""Transforms the input data with the adstock function and exponent.
Args:
media_data: Media data to be transformed. It is expected to have 2 dims for
national models and 3 for geo models.
custom_priors: The custom priors we want the model to take instead of the
default ones. The possible names of parameters for adstock and exponent
are "lag_weight" and "exponent".
normalise: Whether to normalise the output values.
Returns:
The transformed media data.
"""
transform_default_priors = _get_transform_default_priors()["adstock"]
with numpyro.plate(name=f"{_LAG_WEIGHT}_plate",
size=media_data.shape[1]):
lag_weight = numpyro.sample(
name=_LAG_WEIGHT,
fn=custom_priors.get(_LAG_WEIGHT,
transform_default_priors[_LAG_WEIGHT]))
with numpyro.plate(name=f"{_EXPONENT}_plate",
size=media_data.shape[1]):
exponent = numpyro.sample(
name=_EXPONENT,
fn=custom_priors.get(_EXPONENT,
transform_default_priors[_EXPONENT]))
if media_data.ndim == 3:
lag_weight = jnp.expand_dims(lag_weight, axis=-1)
exponent = jnp.expand_dims(exponent, axis=-1)
adstock = media_transforms.adstock(
data=media_data, lag_weight=lag_weight, normalise=normalise)
return media_transforms.apply_exponent_safe(data=adstock, exponent=exponent)
[docs]def transform_hill_adstock(media_data: jnp.ndarray,
custom_priors: MutableMapping[str, Prior],
normalise: bool = True) -> jnp.ndarray:
"""Transforms the input data with the adstock and hill functions.
Args:
media_data: Media data to be transformed. It is expected to have 2 dims for
national models and 3 for geo models.
custom_priors: The custom priors we want the model to take instead of the
default ones. The possible names of parameters for hill_adstock and
exponent are "lag_weight", "half_max_effective_concentration" and "slope".
normalise: Whether to normalise the output values.
Returns:
The transformed media data.
"""
transform_default_priors = _get_transform_default_priors()["hill_adstock"]
with numpyro.plate(name=f"{_LAG_WEIGHT}_plate",
size=media_data.shape[1]):
lag_weight = numpyro.sample(
name=_LAG_WEIGHT,
fn=custom_priors.get(_LAG_WEIGHT,
transform_default_priors[_LAG_WEIGHT]))
with numpyro.plate(name=f"{_HALF_MAX_EFFECTIVE_CONCENTRATION}_plate",
size=media_data.shape[1]):
half_max_effective_concentration = numpyro.sample(
name=_HALF_MAX_EFFECTIVE_CONCENTRATION,
fn=custom_priors.get(
_HALF_MAX_EFFECTIVE_CONCENTRATION,
transform_default_priors[_HALF_MAX_EFFECTIVE_CONCENTRATION]))
with numpyro.plate(name=f"{_SLOPE}_plate",
size=media_data.shape[1]):
slope = numpyro.sample(
name=_SLOPE,
fn=custom_priors.get(_SLOPE, transform_default_priors[_SLOPE]))
if media_data.ndim == 3:
lag_weight = jnp.expand_dims(lag_weight, axis=-1)
half_max_effective_concentration = jnp.expand_dims(
half_max_effective_concentration, axis=-1)
slope = jnp.expand_dims(slope, axis=-1)
return media_transforms.hill(
data=media_transforms.adstock(
data=media_data, lag_weight=lag_weight, normalise=normalise),
half_max_effective_concentration=half_max_effective_concentration,
slope=slope)
[docs]def transform_carryover(media_data: jnp.ndarray,
custom_priors: MutableMapping[str, Prior],
number_lags: int = 13) -> jnp.ndarray:
"""Transforms the input data with the carryover function and exponent.
Args:
media_data: Media data to be transformed. It is expected to have 2 dims for
national models and 3 for geo models.
custom_priors: The custom priors we want the model to take instead of the
default ones. The possible names of parameters for carryover and exponent
are "ad_effect_retention_rate_plate", "peak_effect_delay_plate" and
"exponent".
number_lags: Number of lags for the carryover function.
Returns:
The transformed media data.
"""
transform_default_priors = _get_transform_default_priors()["carryover"]
with numpyro.plate(name=f"{_AD_EFFECT_RETENTION_RATE}_plate",
size=media_data.shape[1]):
ad_effect_retention_rate = numpyro.sample(
name=_AD_EFFECT_RETENTION_RATE,
fn=custom_priors.get(
_AD_EFFECT_RETENTION_RATE,
transform_default_priors[_AD_EFFECT_RETENTION_RATE]))
with numpyro.plate(name=f"{_PEAK_EFFECT_DELAY}_plate",
size=media_data.shape[1]):
peak_effect_delay = numpyro.sample(
name=_PEAK_EFFECT_DELAY,
fn=custom_priors.get(
_PEAK_EFFECT_DELAY, transform_default_priors[_PEAK_EFFECT_DELAY]))
with numpyro.plate(name=f"{_EXPONENT}_plate",
size=media_data.shape[1]):
exponent = numpyro.sample(
name=_EXPONENT,
fn=custom_priors.get(_EXPONENT,
transform_default_priors[_EXPONENT]))
carryover = media_transforms.carryover(
data=media_data,
ad_effect_retention_rate=ad_effect_retention_rate,
peak_effect_delay=peak_effect_delay,
number_lags=number_lags)
if media_data.ndim == 3:
exponent = jnp.expand_dims(exponent, axis=-1)
return media_transforms.apply_exponent_safe(data=carryover, exponent=exponent)
[docs]def media_mix_model(
media_data: jnp.ndarray,
target_data: jnp.ndarray,
media_prior: jnp.ndarray,
degrees_seasonality: int,
frequency: int,
transform_function: TransformFunction,
custom_priors: MutableMapping[str, Prior],
transform_kwargs: Optional[MutableMapping[str, Any]] = None,
weekday_seasonality: bool = False,
extra_features: Optional[jnp.ndarray] = None,
) -> None:
"""Media mix model.
Args:
media_data: Media data to be be used in the model.
target_data: Target data for the model.
media_prior: Cost prior for each of the media channels.
degrees_seasonality: Number of degrees of seasonality to use.
frequency: Frequency of the time span which was used to aggregate the data.
Eg. if weekly data then frequency is 52.
transform_function: Function to use to transform the media data in the
model. Currently the following are supported: 'transform_adstock',
'transform_carryover' and 'transform_hill_adstock'.
custom_priors: The custom priors we want the model to take instead of the
default ones. See our custom_priors documentation for details about the
API and possible options.
transform_kwargs: Any extra keyword arguments to pass to the transform
function. For example the adstock function can take a boolean to noramlise
output or not.
weekday_seasonality: In case of daily data you can estimate a weekday (7)
parameter.
extra_features: Extra features data to include in the model.
"""
default_priors = _get_default_priors()
data_size = media_data.shape[0]
n_channels = media_data.shape[1]
geo_shape = (media_data.shape[2],) if media_data.ndim == 3 else ()
n_geos = media_data.shape[2] if media_data.ndim == 3 else 1
with numpyro.plate(name=f"{_INTERCEPT}_plate", size=n_geos):
intercept = numpyro.sample(
name=_INTERCEPT,
fn=custom_priors.get(_INTERCEPT, default_priors[_INTERCEPT]))
with numpyro.plate(name=f"{_SIGMA}_plate", size=n_geos):
sigma = numpyro.sample(
name=_SIGMA,
fn=custom_priors.get(_SIGMA, default_priors[_SIGMA]))
# TODO(): Force all geos to have the same trend sign.
with numpyro.plate(name=f"{_COEF_TREND}_plate", size=n_geos):
coef_trend = numpyro.sample(
name=_COEF_TREND,
fn=custom_priors.get(_COEF_TREND, default_priors[_COEF_TREND]))
expo_trend = numpyro.sample(
name=_EXPO_TREND,
fn=custom_priors.get(
_EXPO_TREND, default_priors[_EXPO_TREND]))
with numpyro.plate(
name="channel_media_plate",
size=n_channels,
dim=-2 if media_data.ndim == 3 else -1):
coef_media = numpyro.sample(
name="channel_coef_media" if media_data.ndim == 3 else "coef_media",
fn=dist.HalfNormal(scale=media_prior))
if media_data.ndim == 3:
with numpyro.plate(
name="geo_media_plate",
size=n_geos,
dim=-1):
# Corrects the mean to be the same as in the channel only case.
normalisation_factor = jnp.sqrt(2.0 / jnp.pi)
coef_media = numpyro.sample(
name="coef_media",
fn=dist.HalfNormal(scale=coef_media * normalisation_factor)
)
with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_sin_cos_plate", size=2):
with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_plate",
size=degrees_seasonality):
gamma_seasonality = numpyro.sample(
name=_GAMMA_SEASONALITY,
fn=custom_priors.get(
_GAMMA_SEASONALITY, default_priors[_GAMMA_SEASONALITY]))
if weekday_seasonality:
with numpyro.plate(name=f"{_WEEKDAY}_plate", size=7):
weekday = numpyro.sample(
name=_WEEKDAY,
fn=custom_priors.get(_WEEKDAY, default_priors[_WEEKDAY]))
weekday_series = weekday[jnp.arange(data_size) % 7]
# In case of daily data, number of lags should be 13*7.
if transform_function == "carryover" and transform_kwargs and "number_lags" not in transform_kwargs:
transform_kwargs["number_lags"] = 13 * 7
elif transform_function == "carryover" and not transform_kwargs:
transform_kwargs = {"number_lags": 13 * 7}
media_transformed = numpyro.deterministic(
name="media_transformed",
value=transform_function(media_data,
custom_priors=custom_priors,
**transform_kwargs if transform_kwargs else {}))
seasonality = media_transforms.calculate_seasonality(
number_periods=data_size,
degrees=degrees_seasonality,
frequency=frequency,
gamma_seasonality=gamma_seasonality)
# For national model's case
trend = jnp.arange(data_size)
media_einsum = "tc, c -> t" # t = time, c = channel
coef_seasonality = 1
# TODO(): Add conversion of prior for HalfNormal distribution.
if media_data.ndim == 3: # For geo model's case
trend = jnp.expand_dims(trend, axis=-1)
seasonality = jnp.expand_dims(seasonality, axis=-1)
media_einsum = "tcg, cg -> tg" # t = time, c = channel, g = geo
if weekday_seasonality:
weekday_series = jnp.expand_dims(weekday_series, axis=-1)
with numpyro.plate(name="seasonality_plate", size=n_geos):
coef_seasonality = numpyro.sample(
name=_COEF_SEASONALITY,
fn=custom_priors.get(
_COEF_SEASONALITY, default_priors[_COEF_SEASONALITY]))
# expo_trend is B(1, 1) so that the exponent on time is in [.5, 1.5].
prediction = (
intercept + coef_trend * trend ** expo_trend +
seasonality * coef_seasonality +
jnp.einsum(media_einsum, media_transformed, coef_media))
if extra_features is not None:
plate_prefixes = ("extra_feature",)
extra_features_einsum = "tf, f -> t" # t = time, f = feature
extra_features_plates_shape = (extra_features.shape[1],)
if extra_features.ndim == 3:
plate_prefixes = ("extra_feature", "geo")
extra_features_einsum = "tfg, fg -> tg" # t = time, f = feature, g = geo
extra_features_plates_shape = (extra_features.shape[1], *geo_shape)
with numpyro.plate_stack(plate_prefixes,
sizes=extra_features_plates_shape):
coef_extra_features = numpyro.sample(
name=_COEF_EXTRA_FEATURES,
fn=custom_priors.get(
_COEF_EXTRA_FEATURES, default_priors[_COEF_EXTRA_FEATURES]))
extra_features_effect = jnp.einsum(extra_features_einsum,
extra_features,
coef_extra_features)
prediction += extra_features_effect
if weekday_seasonality:
prediction += weekday_series
mu = numpyro.deterministic(name="mu", value=prediction)
numpyro.sample(
name="target", fn=dist.Normal(loc=mu, scale=sigma), obs=target_data)