# Source code for lightweight_mmm.lightweight_mmm

# Copyright 2023 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A simple and lightweight library for Media Mix Modelling.

Simple usage of this class is as follows:

```
mmm = lightweight_mmm.LightweightMMM()
mmm.fit(media=media_data,
        extra_features=extra_features,
        media_prior=costs,
        target=target,
        number_samples=1000,
        number_chains=2)

# For obtaining media contribution percentage and ROI
media_contribution_hat_pct, roi_hat = mmm.get_posterior_metrics()

# For running predictions on unseen data
mmm.predict(media=media_data_test, extra_features=extra_features_test)
```
"""

import collections
import dataclasses
import functools
import itertools
import logging
import numbers
from typing import Any, Callable, Dict, Mapping, MutableMapping, Optional, Sequence, Tuple, Union

from absl import logging
import immutabledict
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from numpyro import distributions as dist
from numpyro import infer

from lightweight_mmm import models
from lightweight_mmm import preprocessing
from lightweight_mmm import utils

# A custom prior as accepted by `LightweightMMM.fit`: either a numpyro
# distribution used as-is, or constructor arguments for the default
# distribution of that prior (keyword arguments as a dict, positional
# arguments as a sequence, or a single number).
Prior = Union[dist.Distribution, Dict[str, float], Sequence[float], float]

# Supported model names mapped to the media transform each one applies.
_NAMES_TO_MODEL_TRANSFORMS = immutabledict.immutabledict(
    dict(
        hill_adstock=models.transform_hill_adstock,
        adstock=models.transform_adstock,
        carryover=models.transform_carryover,
    ))

# Numpyro model function shared by every transform.
_MODEL_FUNCTION = models.media_mix_model


def _compare_equality_for_lmmm(item_1: Any, item_2: Any) -> bool:
  """Compares two items for equality.

  Helper function for the __eq__ method of LightweightmMM. First checks if items
  are strings or lists of strings (it's okay if empty lists compare True), then
  uses jnp.array_equal if the items are jax.numpy.DeviceArray or other related
  sequences, and uses items' __eq__ otherwise.

  Note: this implementation does not cover every possible data structure, but
  it does cover all the data structures seen in attributes used by
  LightweightMMM. Sometimes the DeviceArray is hidden in the value of a
  MutableMapping, hence the recursion.

  Args:
    item_1: First item to be compared.
    item_2: Second item to be compared.

  Returns:
    Boolean for whether item_1 equals item_2.
  """

  # This is pretty strict but LMMM classes don't need to compare equal unless
  # they are exact copies.
  if type(item_1) != type(item_2):
    is_equal = False
  elif isinstance(item_1, str):
    is_equal = item_1 == item_2
  elif isinstance(item_1, (jax.Array, np.ndarray)) or (
      isinstance(item_1, Sequence)
      and not all(isinstance(x, str) for x in item_1)
  ):
    is_equal = np.array_equal(item_1, item_2, equal_nan=True)
  elif isinstance(item_1, MutableMapping):
    is_equal = all(
        [
            _compare_equality_for_lmmm(item_1[x], item_2[x])
            for x in item_1.keys() | item_2.keys()
        ]
    )
  else:
    is_equal = item_1 == item_2

  return is_equal


class NotFittedModelError(Exception):
  """Raised when a method that needs a fitted model runs before `fit`."""


@dataclasses.dataclass(unsafe_hash=True, eq=False)
class LightweightMMM:
  """Lightweight Media Mix Modelling wrapper for bayesian models.

  The currently available models are the following:
   - hill_adstock
   - adstock
   - carryover

  It also offers the necessary utilities for calculating media contribution and
  media ROI based on models' results.

  Attributes:
    trace: Sampling trace of the bayesian model once fitted.
    n_media_channels: Number of media channels the model was trained with.
    n_geos: Number of geos for geo models or 1 for national models.
    model_name: Name of the model.
    media: The media data the model is trained on. Useful for a variety of
      insights post model fitting.
    media_names: Names of the media channels passed at fitting time.
    custom_priors: The set of custom priors the model was trained with. An
      empty dictionary if none were passed.
  """
  model_name: str = "hill_adstock"
  # All remaining fields are only populated by `fit`.
  n_media_channels: int = dataclasses.field(init=False, repr=False)
  n_geos: int = dataclasses.field(init=False, repr=False)
  media: jax.Array = dataclasses.field(
      init=False, repr=False, hash=False, compare=True)
  media_names: Sequence[str] = dataclasses.field(
      init=False, repr=False, hash=False, compare=True)
  trace: Dict[str, jax.Array] = dataclasses.field(
      init=False, repr=False, hash=False, compare=False)
  custom_priors: MutableMapping[str, Prior] = dataclasses.field(
      init=False, repr=False, hash=False, compare=True)
  _degrees_seasonality: int = dataclasses.field(init=False, repr=False)
  _weekday_seasonality: bool = dataclasses.field(init=False, repr=False)
  _media_prior: jax.Array = dataclasses.field(
      init=False, repr=False, hash=False, compare=True)
  _extra_features: jax.Array = dataclasses.field(
      init=False, repr=False, hash=False, compare=True)
  _target: jax.Array = dataclasses.field(
      init=False, repr=False, hash=False, compare=True)
  _train_media_size: int = dataclasses.field(
      init=False, repr=False, hash=True, compare=False)
  _mcmc: numpyro.infer.MCMC = dataclasses.field(
      init=False, repr=False, hash=False, compare=False)

  def __post_init__(self):
    """Validates `model_name` and resolves the model/transform functions.

    Raises:
      ValueError: If `model_name` is not one of the supported models.
    """
    if self.model_name not in _NAMES_TO_MODEL_TRANSFORMS:
      raise ValueError("Model name passed not valid. Please use any of the "
                       "following: 'hill_adstock', 'adstock', 'carryover'.")
    self._model_function = _MODEL_FUNCTION
    self._model_transform_function = _NAMES_TO_MODEL_TRANSFORMS[self.model_name]
    # Names of every prior that can be overridden through `custom_priors`.
    self._prior_names = models.MODEL_PRIORS_NAMES.union(
        models.TRANSFORM_PRIORS_NAMES[self.model_name])

  def __eq__(self, other: Any) -> bool:
    """Equality method for LightweightMMM.

    We need a special method here to handle a couple of issues. First, some of
    the attributes for LightweightMMM are arrays, which contain multiple values
    and cannot be evaluated with the default __eq__ method. Second, some
    attributes are initially undefined and only get values after fitting a
    model. The latter is dealt with within this function, and the former within
    the helper function _compare_equality_for_lmmm().

    Args:
      other: Dataclass to compare against.

    Returns:
      Boolean for whether self == other; NotImplemented if other is not a
      LightweightMMM. Two instances whose sets of populated comparable
      attributes differ (e.g. a fitted and an unfitted model) are unequal.
    """
    if not isinstance(other, LightweightMMM):
      return NotImplemented

    def _create_list_of_attributes_to_compare(
        mmm_instance: Any) -> Sequence[str]:
      all_attributes_that_can_be_compared = sorted(
          [x.name for x in dataclasses.fields(mmm_instance) if x.compare])
      attributes_which_have_been_instantiated = [
          x for x in all_attributes_that_can_be_compared
          if hasattr(mmm_instance, x)
      ]
      return attributes_which_have_been_instantiated

    self_attributes = _create_list_of_attributes_to_compare(self)
    other_attributes = _create_list_of_attributes_to_compare(other)
    # Different populated attributes means the instances cannot be equal; it
    # also guarantees the loop below compares the same attribute on both sides.
    if self_attributes != other_attributes:
      return False
    return all(
        _compare_equality_for_lmmm(getattr(self, name), getattr(other, name))
        for name in self_attributes)

  def _preprocess_custom_priors(
      self,
      custom_priors: Dict[str, Prior]) -> MutableMapping[str, Prior]:
    """Preprocesses the user input custom priors to Numpyro distributions.

    If numpyro distributions are given they remain untouched, however if any
    other option is passed, it is passed to the constructor of the default
    distribution for that prior to alter its constructor values.

    Args:
      custom_priors: Mapping of the name of the prior to its custom value. The
        mapping is not modified; a new one is returned.

    Returns:
      A mapping of names to numpyro distributions based on user input and
      default values.

    Raises:
      ValueError: If a prior is neither a numpyro distribution nor a number,
        sequence or dict of constructor arguments.
    """
    default_priors = {
        **models._get_default_priors(),
        **models._get_transform_default_priors()[self.model_name]
    }
    # Write into a copy so the caller's dictionary is not mutated in place.
    processed_priors = dict(custom_priors)
    # Checking that the key is contained in custom_priors has already been done
    # at this point in the fit function.
    for prior_name, prior_value in custom_priors.items():
      if isinstance(prior_value, numbers.Number):
        processed_priors[prior_name] = default_priors[prior_name].__class__(
            prior_value)
      elif (isinstance(prior_value, collections.abc.Sequence) and
            not isinstance(prior_value, str)):
        processed_priors[prior_name] = default_priors[prior_name].__class__(
            *prior_value)
      elif isinstance(prior_value, dict):
        processed_priors[prior_name] = default_priors[prior_name].__class__(
            **prior_value)
      elif not isinstance(prior_value, dist.Distribution):
        raise ValueError(
            "Priors given must be a Numpyro distribution or one of the "
            "following to fit in the constructor of our default Numpyro "
            "distribution. It could be given as args or kwargs as long as it "
            "is the correct format for such object. Please refer to our "
            "documentation on custom priors to know more.")
    return processed_priors

  def fit(
      self,
      media: jnp.ndarray,
      media_prior: jnp.ndarray,
      target: jnp.ndarray,
      extra_features: Optional[jnp.ndarray] = None,
      degrees_seasonality: int = 2,
      seasonality_frequency: int = 52,
      weekday_seasonality: bool = False,
      media_names: Optional[Sequence[str]] = None,
      number_warmup: int = 1000,
      number_samples: int = 1000,
      number_chains: int = 2,
      target_accept_prob: float = .85,
      init_strategy: Callable[[Mapping[Any, Any], Any],
                              jnp.ndarray] = numpyro.infer.init_to_median,
      custom_priors: Optional[Dict[str, Prior]] = None,
      seed: Optional[int] = None) -> None:
    """Fits MMM given the media data, extra features, costs and sales/KPI.

    For detailed information on the selected model please refer to its
    respective function in the models.py file.

    Args:
      media: Media input data. Media data must have either 2 dims for national
        model or 3 for geo models.
      media_prior: Costs of each media channel. The number of cost values must
        be equal to the number of media channels.
      target: Target KPI to use, like for example sales.
      extra_features: Other variables to add to the model.
      degrees_seasonality: Number of degrees to use for seasonality. Default is
        2.
      seasonality_frequency: Frequency of the time period used. Default is 52
        as in 52 weeks per year.
      weekday_seasonality: In case of daily data, also estimate seven weekday
        parameters.
      media_names: Names of the media channels passed. Defaults to
        "channel_0", "channel_1", ... when not given.
      number_warmup: Number of warm up samples. Default is 1000.
      number_samples: Number of samples during sampling. Default is 1000.
      number_chains: Number of chains to sample. Default is 2.
      target_accept_prob: Target acceptance probability for step size in the
        NUTS sampler. Default is .85.
      init_strategy: Initialization function for numpyro NUTS. The available
        options can be found in
        https://num.pyro.ai/en/stable/utilities.html#initialization-strategies.
        Default is numpyro.infer.init_to_median.
      custom_priors: The custom priors we want the model to take instead of the
        default ones. Refer to the full documentation on custom priors for
        details. The passed dictionary is not modified.
      seed: Seed to use for PRNGKey during training. For better replicability
        run all different trainings with the same seed.

    Raises:
      ValueError: If the media has neither 2 nor 3 dimensions, the number of
        cost values does not match the number of channels, media contains
        negative values, a custom prior name is unknown or invalid, or geo-only
        priors are given for national data.
    """
    if media.ndim not in (2, 3):
      raise ValueError(
          "Media data must have either 2 dims for national model or 3 for geo "
          "models.")
    if media.ndim == 3 and media_prior.ndim == 1:
      # Geo models expect costs shaped (channel, 1) so they broadcast over geos.
      media_prior = jnp.expand_dims(media_prior, axis=-1)

    if media.shape[1] != len(media_prior):
      raise ValueError("The number of data channels provided must match the "
                       "number of cost values.")
    if media.min() < 0:
      raise ValueError("Media values must be greater or equal to zero.")

    if custom_priors:
      not_used_custom_priors = set(custom_priors.keys()).difference(
          self._prior_names)
      if not_used_custom_priors:
        raise ValueError(
            "The following passed custom priors dont have a match in the model."
            " Please double check the names have been written correctly: %s" %
            not_used_custom_priors)
      custom_priors = self._preprocess_custom_priors(
          custom_priors=custom_priors)
      geo_custom_priors = set(custom_priors.keys()).intersection(
          models.GEO_ONLY_PRIORS)
      if media.ndim == 2 and geo_custom_priors:
        raise ValueError(
            "The given data is for national models but custom_prior contains "
            "priors for the geo version of the model. Please either remove geo "
            "priors for national model or pass media data with geo dimension.")
    else:
      custom_priors = {}

    if weekday_seasonality and seasonality_frequency == 52:
      logging.warning("You have chosen daily seasonality and frequency 52 "
                      "(weekly), please check you made the right seasonality "
                      "choices.")

    if extra_features is not None:
      extra_features = jnp.array(extra_features)

    if seed is None:
      seed = utils.get_time_seed()

    train_media_size = media.shape[0]
    kernel = numpyro.infer.NUTS(
        model=self._model_function,
        target_accept_prob=target_accept_prob,
        init_strategy=init_strategy)

    mcmc = numpyro.infer.MCMC(
        sampler=kernel,
        num_warmup=number_warmup,
        num_samples=number_samples,
        num_chains=number_chains)
    mcmc.run(
        rng_key=jax.random.PRNGKey(seed),
        media_data=jnp.array(media),
        extra_features=extra_features,
        target_data=jnp.array(target),
        media_prior=jnp.array(media_prior),
        degrees_seasonality=degrees_seasonality,
        frequency=seasonality_frequency,
        transform_function=self._model_transform_function,
        weekday_seasonality=weekday_seasonality,
        custom_priors=custom_priors)

    self.custom_priors = custom_priors
    if media_names is not None:
      self.media_names = list(media_names)
    else:
      self.media_names = [f"channel_{i}" for i in range(media.shape[1])]
    self.n_media_channels = media.shape[1]
    self.n_geos = media.shape[2] if media.ndim == 3 else 1
    self._media_prior = media_prior
    self.trace = mcmc.get_samples()
    self._number_warmup = number_warmup
    self._number_samples = number_samples
    self._number_chains = number_chains
    self._target = target
    self._train_media_size = train_media_size
    self._degrees_seasonality = degrees_seasonality
    self._seasonality_frequency = seasonality_frequency
    self._weekday_seasonality = weekday_seasonality
    self.media = media
    self._extra_features = extra_features  # jax-devicearray
    self._mcmc = mcmc
    logging.info("Model has been fitted")

  def print_summary(self) -> None:
    """Calls print_summary function from numpyro to print parameters summary.

    Requires a prior call to `fit`, which is what sets the MCMC object used
    here.
    """
    # TODO(): add name selection for print.
    self._mcmc.print_summary()

  @functools.partial(
      jax.jit,
      static_argnums=(0,),
      static_argnames=("degrees_seasonality", "weekday_seasonality",
                       "transform_function", "model"))
  def _predict(
      self,
      rng_key: jnp.ndarray,
      media_data: jnp.ndarray,
      extra_features: Optional[jnp.ndarray],
      media_prior: jnp.ndarray,
      degrees_seasonality: int,
      frequency: int,
      transform_function: Callable[[Any], jnp.ndarray],
      weekday_seasonality: bool,
      model: Callable[[Any], None],
      posterior_samples: Dict[str, jnp.ndarray],
      custom_priors: Dict[str, Prior]
  ) -> Dict[str, jnp.ndarray]:
    """Encapsulates the numpyro.infer.Predictive function for predict method.

    It serves as a helper jitted function for running predictions. `self` is a
    static argument, so the instance must be hashable when this is called.

    Args:
      rng_key: A jax.random.PRNGKey.
      media_data: Media array needed for the model to run predictions.
      extra_features: Extra features needed for the model to run.
      media_prior: Cost prior used for training the model.
      degrees_seasonality: Number of degrees for the seasonality.
      frequency: Frequency of the seasonality.
      transform_function: Media transform function to use within the model.
      weekday_seasonality: Allow daily weekday estimation.
      model: Numpyro model to use for numpyro.infer.Predictive.
      posterior_samples: Mapping of the posterior samples.
      custom_priors: The custom priors we want the model to take instead of the
        default ones. Refer to the full documentation on custom priors for
        details.

    Returns:
      The predictions for the given data.
    """
    return infer.Predictive(
        model=model, posterior_samples=posterior_samples)(
            rng_key=rng_key,
            media_data=media_data,
            extra_features=extra_features,
            media_prior=media_prior,
            target_data=None,
            degrees_seasonality=degrees_seasonality,
            frequency=frequency,
            transform_function=transform_function,
            custom_priors=custom_priors,
            weekday_seasonality=weekday_seasonality)

  def predict(
      self,
      media: jnp.ndarray,
      extra_features: Optional[jnp.ndarray] = None,
      media_gap: Optional[jnp.ndarray] = None,
      target_scaler: Optional[preprocessing.CustomScaler] = None,
      seed: Optional[int] = None
  ) -> jnp.ndarray:
    """Runs the model to obtain predictions for the given input data.

    Predictions returned are distributions, if point estimates are desired one
    can calculate those based on the given distribution.

    Args:
      media: Media array needed for the model to run predictions.
      extra_features: Extra features needed for the model to run.
      media_gap: Media data gap between the end of training data and the start
        of the out of sample media given. Eg. if 100 weeks of data were used
        for training and prediction starts 2 months after training data
        finished we need to provide the 8 weeks missing between the training
        data and the prediction data so data transformations (adstock,
        carryover, ...) can take place correctly. Extra features for the gap
        period are filled with zeros.
      target_scaler: Scaler that was used to scale the target before training.
      seed: Seed to use for PRNGKey during sampling. For replicability run
        this function and any other function that utilises predictions with the
        same seed.

    Returns:
      Predictions for the given media and extra features at a given date index.

    Raises:
      NotFittedModelError: When the model has not been fitted before running
        predict.
      ValueError: When `media_gap` does not have the same number of dimensions
        or media channels as `media`.
    """
    if not hasattr(self, "trace"):
      raise NotFittedModelError("Need to fit the model before running "
                                "predictions.")
    if media_gap is not None:
      if media.ndim != media_gap.ndim:
        raise ValueError("Original media data and media gap must have the same "
                         "number of dimensions.")
      if media.ndim > 1 and media.shape[1] != media_gap.shape[1]:
        raise ValueError("Media gap must have the same number of media channels "
                         "as the original media data.")
      previous_media = jnp.concatenate(arrays=[self.media, media_gap], axis=0)
      if extra_features is not None:
        previous_extra_features = jnp.concatenate(
            arrays=[
                self._extra_features,
                jnp.zeros((media_gap.shape[0], *self._extra_features.shape[1:]))
            ],
            axis=0)
    else:
      previous_media = self.media
      previous_extra_features = self._extra_features

    # The model runs over training (+ gap) + new data so that the carryover
    # transforms see the full history; only the new periods are returned.
    full_media = jnp.concatenate(arrays=[previous_media, media], axis=0)
    if extra_features is not None:
      full_extra_features = jnp.concatenate(
          arrays=[previous_extra_features, extra_features], axis=0)
    else:
      full_extra_features = None
    if seed is None:
      seed = utils.get_time_seed()
    prediction = self._predict(
        rng_key=jax.random.PRNGKey(seed=seed),
        media_data=full_media,
        extra_features=full_extra_features,
        media_prior=jnp.array(self._media_prior),
        degrees_seasonality=self._degrees_seasonality,
        frequency=self._seasonality_frequency,
        weekday_seasonality=self._weekday_seasonality,
        transform_function=self._model_transform_function,
        model=self._model_function,
        custom_priors=self.custom_priors,
        posterior_samples=self.trace)["mu"][:, previous_media.shape[0]:]
    if target_scaler:
      prediction = target_scaler.inverse_transform(prediction)

    return prediction

  def reduce_trace(self, nsample: int = 100, seed: int = 0) -> None:
    """Reduces the samples in `trace` to speed up `predict` and optimize.

    Please note this step is not reversible. Only do this after you have
    investigated convergence of the model.

    Args:
      nsample: Target number of samples.
      seed: Random seed for down sampling.

    Raises:
      ValueError: if `nsample` is too big.
    """
    ntrace = len(self.trace["sigma"])
    if ntrace < nsample:
      raise ValueError("nsample is bigger than the actual posterior samples")
    key = jax.random.PRNGKey(seed)
    # The same sample indices are applied to every site so draws stay aligned.
    samples = jax.random.choice(key, ntrace, (nsample,), replace=False)
    for name in self.trace.keys():
      self.trace[name] = self.trace[name][samples]
    logging.info("Reduction is complete")

  def get_posterior_metrics(
      self,
      unscaled_costs: Optional[jnp.ndarray] = None,
      cost_scaler: Optional[preprocessing.CustomScaler] = None,
      target_scaler: Optional[preprocessing.CustomScaler] = None
  ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """It estimates the media contribution percentage and ROI of each channel.

    If data was scaled prior to training then the target and costs scalers need
    to be passed to this function to correctly calculate media contribution
    percentage and ROI in the unscaled space.

    Args:
      unscaled_costs: Optionally you can pass new costs to get these set of
        metrics. If None, the costs used for training will be used for
        calculating ROI.
      cost_scaler: Scaler that was used to scale the cost data before training.
        It is ignored if 'unscaled_costs' is provided.
      target_scaler: Scaler that was used to scale the target before training.

    Returns:
      A tuple (media_contribution_hat_pct, roi_hat):
        media_contribution_hat_pct: The average media contribution percentage
          for each channel.
        roi_hat: The return on investment of each channel calculated as its
          contribution divided by the cost.

    Raises:
      NotFittedModelError: When this method is called without the model
        being trained previously.
    """
    if not hasattr(self, "trace"):
      raise NotFittedModelError(
          "LightweightMMM has not been fitted and cannot run estimations. "
          "Please first fit the model.")
    if unscaled_costs is None and not cost_scaler:
      logging.warning(
          "Unscaled cost data or cost scaler were not given and "
          "therefore unscaling wont be applied to calculate contribution"
          " and ROI. If data was not scaled prior to training "
          "please ignore this warning.")
    if not target_scaler:
      logging.warning("Target scaler was not given and unscaling of the target "
                      "will not occur. If your target was not scaled prior to "
                      "training you can ignore this warning.")
    if unscaled_costs is None:
      if cost_scaler:
        unscaled_costs = cost_scaler.inverse_transform(self._media_prior)
      else:
        unscaled_costs = self._media_prior

    if self.media.ndim == 3:
      # cost shape (channel, geo) -> add a new axis: (channel, geo, 1)
      unscaled_costs = unscaled_costs[:, :, jnp.newaxis]
      # reorder to (1, channel, geo) so it broadcasts over samples
      unscaled_costs = jnp.einsum("cgs->scg", unscaled_costs)

    # get the scaled posterior prediction
    posterior_pred = self.trace["mu"]
    if target_scaler:
      unscaled_posterior_pred = target_scaler.inverse_transform(posterior_pred)
    else:
      unscaled_posterior_pred = posterior_pred

    if self.media.ndim == 2:
      # s for samples, t for time, c for media channels
      einsum_str = "stc, sc -> sc"
    elif self.media.ndim == 3:
      # s for samples, t for time, c for media channels, g for geo
      einsum_str = "stcg, scg -> scg"

    # NOTE(review): jnp.squeeze drops every size-1 axis, so a model with a
    # single channel (or a single geo) may lose an axis this einsum expects —
    # confirm against the coef_media shapes produced in models.py.
    media_contribution = jnp.einsum(einsum_str,
                                    self.trace["media_transformed"],
                                    jnp.squeeze(self.trace["coef_media"]))

    # aggregate posterior_pred across time:
    sum_scaled_prediction = jnp.sum(posterior_pred, axis=1)
    # aggregate unscaled_posterior_pred across time:
    sum_unscaled_prediction = jnp.sum(unscaled_posterior_pred, axis=1)

    if self.media.ndim == 2:
      # add a new axis to represent channel: (sample,) -> (sample, 1)
      sum_scaled_prediction = sum_scaled_prediction[:, jnp.newaxis]
      sum_unscaled_prediction = sum_unscaled_prediction[:, jnp.newaxis]

    elif self.media.ndim == 3:
      # add a new axis to represent channel: (sample, geo) -> (sample, 1, geo)
      # note: the total prediction value stays the same for all channels
      sum_scaled_prediction = sum_scaled_prediction[:, jnp.newaxis, :]
      # add a new axis to represent channel: (sample, geo) -> (sample, geo, 1)
      # note: the total prediction value stays the same for all channels
      sum_unscaled_prediction = sum_unscaled_prediction[:, :, jnp.newaxis]
      # reshape the array (sample, geo, 1) -> (sample, 1, geo)
      sum_unscaled_prediction = jnp.einsum("sgc->scg", sum_unscaled_prediction)

    # media contribution pct = media contribution / prediction
    # for geo level model:
    # media_contribution shape (sample, channel, geo)
    # sum_scaled_prediction shape (sample, 1, geo), broadcast over channels
    # -> media_contribution_hat shape (sample, channel, geo)
    media_contribution_hat = media_contribution / sum_scaled_prediction

    # media roi = unscaled prediction * media contribution pct / unscaled costs
    # for geo level model:
    # sum_unscaled_prediction shape (sample, 1, geo)
    # media_contribution_hat shape (sample, channel, geo)
    # unscaled_costs shape (1, channel, geo)
    # -> roi_hat shape (sample, channel, geo)
    roi_hat = sum_unscaled_prediction * media_contribution_hat / unscaled_costs

    return media_contribution_hat, roi_hat