LightweightMMM

LightweightMMM object

LightweightMMM([model_name])

Lightweight Media Mix Modelling wrapper for bayesian models.

class lightweight_mmm.lightweight_mmm.LightweightMMM(model_name: str = 'hill_adstock')

Lightweight Media Mix Modelling wrapper for bayesian models.

The currently available models are the following:
  • hill_adstock

  • adstock

  • carryover

It also offers the necessary utilities for calculating media contribution and media ROI based on models’ results.

trace

Sampling trace of the bayesian model once fitted.

Type

Dict[str, jax.Array]

n_media_channels

Number of media channels the model was trained with.

Type

int

n_geos

Number of geos for geo models or 1 for national models.

Type

int

model_name

Name of the model.

Type

str

media

The media data the model is trained on. Useful for a variety of insights after model fitting.

Type

jax.Array

media_names

Names of the media channels passed at fitting time.

Type

Sequence[str]

custom_priors

The set of custom priors the model was trained with. An empty dictionary if none were passed.

Type

MutableMapping[str, Union[numpyro.distributions.distribution.Distribution, Dict[str, float], Sequence[float], float]]

fit(media: jax.Array, media_prior: jax.Array, target: jax.Array, extra_features: typing.Optional[jax.Array] = None, degrees_seasonality: int = 2, seasonality_frequency: int = 52, weekday_seasonality: bool = False, media_names: typing.Optional[typing.Sequence[str]] = None, number_warmup: int = 1000, number_samples: int = 1000, number_chains: int = 2, target_accept_prob: float = 0.85, init_strategy: typing.Callable[[typing.Mapping[typing.Any, typing.Any], typing.Any], jax.Array] = <function init_to_median>, custom_priors: typing.Optional[typing.Dict[str, typing.Union[numpyro.distributions.distribution.Distribution, typing.Dict[str, float], typing.Sequence[float], float]]] = None, seed: typing.Optional[int] = None) → None

Fits MMM given the media data, extra features, costs and sales/KPI.

For detailed information on the selected model please refer to its respective function in the models.py file.

Parameters
  • media – Media input data. Media data must have either 2 dims for national models or 3 for geo models.

  • media_prior – Costs of each media channel. The number of cost values must be equal to the number of media channels.

  • target – Target KPI to use, for example sales.

  • extra_features – Other variables to add to the model.

  • degrees_seasonality – Number of degrees to use for seasonality. Default is 2.

  • seasonality_frequency – Frequency of the time period used. Default is 52 as in 52 weeks per year.

  • weekday_seasonality – In case of daily data, also estimate seven weekday parameters.

  • media_names – Names of the media channels passed.

  • number_warmup – Number of warm up samples. Default is 1000.

  • number_samples – Number of samples during sampling. Default is 1000.

  • number_chains – Number of chains to sample. Default is 2.

  • target_accept_prob – Target acceptance probability for step size in the NUTS sampler. Default is .85.

  • init_strategy – Initialization function for numpyro NUTS. The available options can be found in https://num.pyro.ai/en/stable/utilities.html#initialization-strategies. Default is numpyro.infer.init_to_median.

  • custom_priors – The custom priors we want the model to take instead of the default ones. Refer to the full documentation on custom priors for details.

  • seed – Seed to use for PRNGKey during training. For better replicability, run all trainings with the same seed.

get_posterior_metrics(unscaled_costs: Optional[jax.Array] = None, cost_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None) → Tuple[jax.Array, jax.Array]

Estimates the media contribution percentage and ROI of each channel.

If the data was scaled before training, the target and cost scalers must be passed to this function to correctly calculate the media contribution percentage and ROI in the unscaled space.

Parameters
  • unscaled_costs – Optionally, new costs can be passed to compute this set of metrics. If None, the costs used for training are used to calculate ROI.

  • cost_scaler – Scaler that was used to scale the cost data before training. It is ignored if ‘unscaled_costs’ is provided.

  • target_scaler – Scaler that was used to scale the target before training.

Returns
  • media_contribution_hat_pct – The average media contribution percentage for each channel.

  • roi_hat – The return on investment of each channel, calculated as its contribution divided by the cost.

Return type

Tuple[jax.Array, jax.Array]

Raises

NotFittedModelError – When this method is called before the model has been trained.
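The arithmetic behind the two returned metrics can be sketched in NumPy. This assumes we already hold posterior draws of each channel's total contribution in unscaled target units; the array names and numbers are hypothetical, and the sketch is not the library's implementation (which also takes care of unscaling).

```python
import numpy as np

# Hypothetical posterior draws of each channel's total contribution to the
# target over the training period: shape (n_samples, n_media_channels).
media_contribution = np.array([[120.0, 40.0, 60.0],
                               [100.0, 50.0, 70.0]])
total_target = 1000.0                   # sum of the unscaled target (hypothetical)
costs = np.array([60.0, 20.0, 40.0])    # total unscaled cost per channel (hypothetical)

# Contribution percentage: share of the target attributed to each channel.
contribution_pct = media_contribution / total_target
# ROI: each channel's contribution divided by its cost, per posterior sample.
roi = media_contribution / costs

# Averaging over the sample axis gives one value per channel.
average_contribution_pct = contribution_pct.mean(axis=0)
average_roi = roi.mean(axis=0)
```

Keeping the per-sample arrays (rather than only their averages) is what allows an uncertainty interval to be drawn around each channel's metric.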

predict(media: jax.Array, extra_features: Optional[jax.Array] = None, media_gap: Optional[jax.Array] = None, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, seed: Optional[int] = None) → jax.Array

Runs the model to obtain predictions for the given input data.

Predictions are returned as distributions; if point estimates are needed, compute them from the returned distribution (for example its mean or median).

Parameters
  • media – Media array needed for the model to run predictions.

  • extra_features – Extra features needed for the model to run.

  • media_gap – Media data gap between the end of the training data and the start of the out-of-sample media given. E.g. if 100 weeks of data were used for training and prediction starts 2 months after the training data finished, the 8 weeks missing between the training data and the prediction data must be provided so that data transformations (adstock, carryover, …) can take place correctly.

  • target_scaler – Scaler that was used to scale the target before training.

  • seed – Seed to use for PRNGKey during sampling. For replicability run this function and any other function that utilises predictions with the same seed.

Returns

Predictions for the given media and extra features at a given date index.

Raises

NotFittedModelError – When the model has not been fitted before running predict.
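Since predict returns a distribution rather than a single forecast, point estimates are summaries along the sample axis. The sketch below assumes a national model whose predictions have shape (n_samples, n_time_periods) and uses simulated values in place of a real predict call.

```python
import numpy as np

# Simulated stand-in for predict() output: one row per posterior sample,
# one column per time period, i.e. shape (n_samples, n_time_periods).
rng = np.random.default_rng(seed=0)
predictions = rng.normal(loc=[100.0, 110.0, 95.0], scale=5.0, size=(1000, 3))

# Point estimates summarise the predictive distribution over the sample axis.
mean_prediction = predictions.mean(axis=0)
median_prediction = np.median(predictions, axis=0)
```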

print_summary() → None

Calls the print_summary function from numpyro to print a summary of the parameters.

reduce_trace(nsample: int = 100, seed: int = 0) → None

Reduces the samples in trace to speed up predict and optimize.

Please note this step is not reversible. Only do this after you have investigated convergence of the model.

Parameters
  • nsample – Target number of samples.

  • seed – Random seed for down sampling.

Raises

ValueError – if nsample is too big.
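A rough NumPy sketch of what downsampling a trace involves, under our own assumptions: the trace is a dict of arrays whose first axis indexes samples, and draws are taken without replacement using one shared set of indices so that joint posterior draws stay aligned across parameters. The helper name and the parameter names are hypothetical; the library's own sampling scheme may differ.

```python
import numpy as np


def reduce_trace_sketch(trace, nsample=100, seed=0):
    """Hypothetical helper: keep `nsample` random posterior draws of every parameter."""
    n_total = next(iter(trace.values())).shape[0]
    if nsample > n_total:
        raise ValueError(f"nsample={nsample} exceeds the {n_total} samples in the trace.")
    rng = np.random.default_rng(seed)
    # One shared set of indices keeps each joint draw intact across parameters.
    keep = rng.choice(n_total, size=nsample, replace=False)
    return {name: values[keep] for name, values in trace.items()}


# Hypothetical trace with 1000 draws: row i of "channel_effect" is [2i, 2i + 1].
trace = {"channel_effect": np.arange(2000.0).reshape(1000, 2),
         "noise_scale": np.arange(1000.0)}
reduced = reduce_trace_sketch(trace, nsample=100)
```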

Preprocessing / Scaling

CustomScaler([divide_operation, divide_by, ...])

Class to scale your data based on multiplications and divisions.

class lightweight_mmm.preprocessing.CustomScaler(divide_operation: Optional[Callable[[jax.Array], jax.numpy.float32]] = None, divide_by: Optional[Union[float, int, jax.Array]] = 1, multiply_operation: Optional[Callable[[jax.Array], jax.numpy.float32]] = None, multiply_by: Optional[Union[float, int, jax.Array]] = 1.0)

Class to scale your data based on multiplications and divisions.

This scaler can be used in two fashions for both the multiplication and the division operation:
  • By specifying a value to use for the scaling operation.

  • By specifying an operation used at column level to calculate the value for the actual scaling operation.

E.g. to scale the dataset by multiplying it by 100, pass multiply_by=100 directly. The value can also be an array with as many values as the data being scaled has columns. To multiply by the mean value of each column instead, pass multiply_operation=jnp.mean (or any other desired operation).

Operation parameters take precedence: where both values and operations are passed, the values are ignored.

Scaler must be fit first in order to call the transform method.

Attributes
  • divide_operation – Operation to apply over axis 0 of the fitting data to obtain the value that will be used for division during scaling.

  • divide_by – Number(s) by which to divide data in the scaling process. Since the scaler is applied to axis 0 of the data, the shape of divide_by must be consistent with division into the data. For example, if data.shape = (100, 3, 5) then divide_by.shape can be (3, 5) or (5,) or a number. If divide_operation is given, this divide_by value will be ignored.

  • multiply_operation – Operation to apply over axis 0 of the fitting data to obtain the value that will be used for multiplication during scaling.

  • multiply_by – Number(s) by which to multiply data in the scaling process. Since the scaler is applied to axis 0 of the data, the shape of multiply_by must be consistent with multiplication into the data. For example, if data.shape = (100, 3, 5) then multiply_by.shape can be (3, 5) or (5,) or a number. If multiply_operation is given, this multiply_by value will be ignored.
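To make the documented semantics concrete, here is a minimal NumPy stand-in. MiniScaler is a hypothetical name, and the sketch assumes that transform computes multiply_by * data / divide_by, that operations are evaluated column-wise over axis 0 with np.apply_along_axis, and that calling transform before fit raises an error (the RuntimeError type is our choice). It is not the library's code.

```python
import numpy as np


class MiniScaler:
    """Hypothetical stand-in for the documented CustomScaler behaviour."""

    def __init__(self, divide_operation=None, divide_by=1,
                 multiply_operation=None, multiply_by=1.0):
        self.divide_operation = divide_operation
        self.divide_by = divide_by
        self.multiply_operation = multiply_operation
        self.multiply_by = multiply_by
        self._fitted = False

    def fit(self, data):
        data = np.asarray(data, dtype=float)
        # Operations take precedence over values: evaluate them over axis 0.
        if self.divide_operation is not None:
            self.divide_by = np.apply_along_axis(self.divide_operation, 0, data)
        if self.multiply_operation is not None:
            self.multiply_by = np.apply_along_axis(self.multiply_operation, 0, data)
        self._fitted = True

    def transform(self, data):
        if not self._fitted:
            raise RuntimeError("Call fit before transform.")  # error type is our choice
        return self.multiply_by * np.asarray(data, dtype=float) / self.divide_by

    def fit_transform(self, data):
        self.fit(data)
        return self.transform(data)

    def inverse_transform(self, data):
        return np.asarray(data, dtype=float) * self.divide_by / self.multiply_by


# Divide each column by its mean, then multiply everything by 100.
media = np.array([[1.0, 10.0],
                  [3.0, 30.0]])
scaler = MiniScaler(divide_operation=np.mean, multiply_by=100)
scaled = scaler.fit_transform(media)
restored = scaler.inverse_transform(scaled)
```

The usage lines mirror passing divide_operation=jnp.mean to the real scaler: each column ends up expressed relative to its own mean.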

fit(data: jax.Array) → None

Figures out values for transformations based on the specified operations.

Parameters

data – Input dataset to use for fitting.

fit_transform(data: jax.Array) → jax.Array

Fits the values and applies transformation to the input data.

Parameters

data – Input dataset.

Returns

Transformed array.

inverse_transform(data: jax.Array) → jax.Array

Runs inverse transformation to get original values.

Parameters

data – Input dataset.

Returns

Dataset with the inverse transformation applied.

transform(data: jax.Array) → jax.Array

Applies transformation based on fitted values.

It can only be called if the scaler has been fit first.

Parameters

data – Input dataset to transform.

Returns

Transformed array.

Optimize Media

find_optimal_budgets(n_time_periods, ...[, ...])

Finds the best media allocation based on MMM model, prices and a budget.

lightweight_mmm.optimize_media.find_optimal_budgets(n_time_periods: int, media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, budget: Union[float, int], prices: jax.Array, extra_features: Optional[jax.Array] = None, media_gap: Optional[jax.Array] = None, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, media_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, bounds_lower_pct: Union[float, jax.Array] = 0.2, bounds_upper_pct: Union[float, jax.Array] = 0.2, max_iterations: int = 200, solver_func_tolerance: float = 1e-06, solver_step_size: float = 1.4901161193847656e-08, seed: Optional[int] = None) → scipy.optimize._optimize.OptimizeResult

Finds the best media allocation based on MMM model, prices and a budget.

Parameters
  • n_time_periods – Number of time periods to optimize for. If the model is built on weekly data, this would be the number of weeks ahead to optimize.

  • media_mix_model – Media mix model to use for the optimization.

  • budget – Total budget to allocate during the optimization time.

  • prices – An array with shape (n_media_channels,) for the cost of each media channel unit.

  • extra_features – Extra features needed for the model to predict.

  • media_gap – Media data gap between the end of the training data and the start of the out-of-sample media given. E.g. if 100 weeks of data were used for training and prediction starts 8 weeks after the training data finished, the 8 weeks missing between the training data and the prediction data must be provided so that data transformations (adstock, carryover, …) can take place correctly.

  • target_scaler – Scaler that was used to scale the target before training.

  • media_scaler – Scaler that was used to scale the media data before training.

  • bounds_lower_pct – Relative percentage decrease from the mean value to consider as the new lower bound.

  • bounds_upper_pct – Relative percentage increase from the mean value to consider as the new upper bound.

  • max_iterations – Maximum number of iterations for the SLSQP scipy optimizer. Default is 200.

  • solver_func_tolerance – Precision goal for the value of the prediction in the stopping criterion. Maps directly to scipy’s ftol. Intended only for advanced users. For more details see: https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html#optimize-minimize-slsqp.

  • solver_step_size – Step size used for numerical approximation of the Jacobian. Maps directly to scipy’s eps. Intended only for advanced users. For more details see: https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html#optimize-minimize-slsqp.

  • seed – Seed to use for PRNGKey during sampling. For replicability run this function and any other function that gets predictions with the same seed.

Returns
  • solution – OptimizeResult object containing the results of the optimization.

  • kpi_without_optim – Predicted target based on the original allocation proportion among channels from the historical data.

  • starting_values – Budget allocation based on the original allocation proportion and the given total budget.
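The kind of problem find_optimal_budgets solves (maximizing a predicted KPI subject to a total-budget equality constraint and per-channel bounds, using SLSQP) can be illustrated with scipy.optimize.minimize on a toy response function. Everything here is an assumption made for the example: the saturating response replaces the model's predictions, the bounds are taken as ±20% around the historical average units over the horizon, and the starting point rescales the historical allocation to the budget. Only the use of SLSQP and the ftol/eps/maxiter mapping come from the parameter list above.

```python
import numpy as np
from scipy import optimize

# Toy setup (all numbers hypothetical): 3 channels, budget for 4 periods.
prices = np.array([2.0, 1.0, 0.5])               # cost per media unit
mean_media_units = np.array([15.0, 50.0, 40.0])  # historical average units per period
n_time_periods = 4
budget = 400.0
bounds_lower_pct, bounds_upper_pct = 0.2, 0.2

# Toy saturating response per channel, standing in for the model's predictions.
max_effect = np.array([50.0, 80.0, 60.0])
half_saturation = np.array([40.0, 150.0, 100.0])


def negative_kpi(media_units):
    # SLSQP minimises, so return the negative of the toy predicted KPI.
    return -np.sum(max_effect * media_units / (media_units + half_saturation))


# Bounds: a relative percentage around the historical average over the horizon.
baseline_units = mean_media_units * n_time_periods
lower = baseline_units * (1 - bounds_lower_pct)
upper = baseline_units * (1 + bounds_upper_pct)

# Start from the historical allocation proportions, rescaled to the budget.
starting_units = baseline_units * budget / np.sum(baseline_units * prices)

# Equality constraint: total spend (units times price) must equal the budget.
budget_constraint = {"type": "eq",
                     "fun": lambda units: np.sum(units * prices) - budget}

solution = optimize.minimize(
    fun=negative_kpi,
    x0=starting_units,
    method="SLSQP",
    bounds=list(zip(lower, upper)),
    constraints=budget_constraint,
    options={"maxiter": 200, "ftol": 1e-06, "eps": 1.4901161193847656e-08},
)
```

Working in media units and multiplying by prices inside the constraint is one way to let channels with different unit costs share a single budget.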

Plot

plot_response_curves(media_mix_model[, ...])

Plots the response curves of each media channel based on the model.

plot_cross_correlate(feature, target[, maxlags])

Plots the cross correlation coefficients between 2 vectors.

plot_var_cost(media, costs, names)

Plots a chart of the coefficient of variation against cost.

plot_model_fit(media_mix_model[, ...])

Plots the ground truth, predicted value and interval for the training data.

plot_out_of_sample_model_fit(...[, ...])

Plots the ground truth, predicted value and interval for the test data.

plot_media_channel_posteriors(media_mix_model)

Plots the posterior distributions of estimated media channel effect.

plot_prior_and_posterior(media_mix_model[, ...])

Plots prior and posterior distributions for parameters in media_mix_model.

plot_bars_media_metrics(metric[, ...])

Plots a barchart of estimated media effects with their percentile interval.

plot_pre_post_budget_allocation_comparison(...)

Plots bar charts to compare pre & post budget allocation.

plot_media_baseline_contribution_area_plot(...)

Plots an area chart to visualize weekly media & baseline contribution.

create_media_baseline_contribution_df(...[, ...])

Creates a dataframe for weekly media channels & baseline contribution.

lightweight_mmm.plot.plot_response_curves(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, media_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, prices: Optional[jax.Array] = None, optimal_allocation_per_timeunit: Optional[jax.Array] = None, steps: int = 50, percentage_add: float = 0.2, apply_log_scale: bool = False, figure_size: Tuple[int, int] = (8, 10), n_columns: int = 3, marker_size: int = 8, legend_fontsize: int = 8, seed: Optional[int] = None) → matplotlib.figure.Figure

Plots the response curves of each media channel based on the model.

It plots an individual subplot for each media channel. If optimal_allocation_per_timeunit is given, it is used to add markers for the historic average spend and the given optimal spend on each of the individual subplots.

It then plots a combined plot with all the response curves which can be changed to log scale if apply_log_scale is True.

Parameters
  • media_mix_model – Media mix model to use for plotting the response curves.

  • media_scaler – Scaler that was used to scale the media data before training.

  • target_scaler – Scaler used for scaling the target, used to unscale the values and plot them in the original scale.

  • prices – Prices to translate the media units to spend. If all your data is already in spend numbers, you can leave this as None. If some of your data is media spend and some is media units, leave the media spend channels with price 1 and add the prices for the media unit channels.

  • optimal_allocation_per_timeunit – Optimal allocation per time unit per media channel. This can be obtained by running the optimization provided by LightweightMMM.

  • steps – Number of steps to simulate.

  • percentage_add – Percentage by which to exceed the maximum historic spend when simulating the response curve.

  • apply_log_scale – Whether to apply the log scale to the predictions (Y axis). When some media channels have a much larger scale than others, it might be useful to use apply_log_scale=True. Default is False.

  • figure_size – Size of the plot figure.

  • n_columns – Number of columns to display in the subplots grid. Modifying this parameter might require adjusting figure_size accordingly for the plot to keep a reasonable structure.

  • marker_size – Size of the marker for the optimization annotations. Only useful if optimal_allocation_per_timeunit is not None. Default is 8.

  • legend_fontsize – Legend font size for individual subplots.

  • seed – Seed to use for PRNGKey during sampling. For replicability run this function and any other function that gets predictions with the same seed.

Returns

Plots of response curves.

lightweight_mmm.plot.plot_cross_correlate(feature: jax.Array, target: jax.Array, maxlags: int = 10) → Tuple[int, float]

Plots the cross correlation coefficients between 2 vectors.

In the chart, look for positive peaks: these show how the lags of the feature lead the target.

Parameters
  • feature – Vector, the lags of which predict target.

  • target – Vector, what is predicted.

  • maxlags – Maximum number of lags.

Returns

Lag index and corresponding correlation of the peak correlation.

Raises

ValueError – If inputs don’t have same length.
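One simple way to compute the quantity the chart summarizes is a Pearson correlation between the feature and the target shifted by each lag. This sketch only looks at non-negative lags (feature leading target); the library's plot may use a different normalisation or also show negative lags, and lagged_correlations is a hypothetical helper.

```python
import numpy as np


def lagged_correlations(feature, target, maxlags=10):
    """Pearson correlation of feature[t] with target[t + lag] for lag = 0..maxlags."""
    feature = np.asarray(feature, dtype=float)
    target = np.asarray(target, dtype=float)
    if feature.shape != target.shape:
        raise ValueError("feature and target must have the same length.")
    n = feature.shape[0]
    return np.array([np.corrcoef(feature[: n - lag], target[lag:])[0, 1]
                     for lag in range(maxlags + 1)])


rng = np.random.default_rng(seed=0)
feature = rng.normal(size=200)
target = np.roll(feature, 3)  # the target repeats the feature 3 periods later
correlations = lagged_correlations(feature, target)
peak_lag = int(np.argmax(correlations))
peak_correlation = float(correlations[peak_lag])
```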

lightweight_mmm.plot.plot_var_cost(media: jax.Array, costs: jax.Array, names: List[str]) → matplotlib.figure.Figure

Plots a chart of the coefficient of variation against cost.

Parameters
  • media – Media matrix.

  • costs – Cost vector.

  • names – List of variable names.

Returns

Plot of coefficient of variation and cost.

Raises

ValueError – If inputs don’t have the same length.
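The coefficient of variation is the standard deviation of a channel's media divided by its mean. A minimal sketch on hypothetical data, assuming the population standard deviation (ddof=0); each channel's value would then be plotted against its cost.

```python
import numpy as np

# Hypothetical media matrix: shape (n_time_periods, n_channels).
media = np.array([[1.0, 2.0],
                  [1.0, 4.0],
                  [1.0, 6.0]])

# Coefficient of variation per channel: standard deviation over mean (ddof=0).
coefficient_of_variation = media.std(axis=0) / media.mean(axis=0)
```

A channel with constant media (the first column) has a coefficient of variation of 0, which signals that the data carries little variation for the model to learn that channel's effect from.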

lightweight_mmm.plot.plot_model_fit(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, interval_mid_range: float = 0.9, digits: int = 3) → matplotlib.figure.Figure

Plots the ground truth, predicted value and interval for the training data.

Model needs to be fit before calling this function to plot.

Parameters
  • media_mix_model – Media mix model.

  • target_scaler – Scaler used for scaling the target, used to unscale the values and plot them in the original scale.

  • interval_mid_range – Mid range interval to take for plotting. E.g. .9 will use .05 and .95 as the lower and upper quantiles. Must be a float between 0 and 1.

  • digits – Number of decimals to display on metrics in the plot.

Returns

Plot of model fit.
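The mapping from interval_mid_range to quantiles, and the resulting interval, can be sketched as follows. The predictions are simulated with shape (n_samples, n_time_periods), and quantile_bounds is a hypothetical helper.

```python
import numpy as np


def quantile_bounds(interval_mid_range=0.9):
    """Hypothetical helper: split the mass outside the mid range equally between the tails."""
    if not 0.0 < interval_mid_range < 1.0:
        raise ValueError("interval_mid_range must be between 0 and 1.")
    tail = (1.0 - interval_mid_range) / 2.0
    return tail, 1.0 - tail


# Simulated predictions: shape (n_samples, n_time_periods).
rng = np.random.default_rng(seed=0)
predictions = rng.normal(loc=100.0, scale=5.0, size=(2000, 4))

lower_q, upper_q = quantile_bounds(0.9)
lower, upper = np.quantile(predictions, [lower_q, upper_q], axis=0)
```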

lightweight_mmm.plot.plot_out_of_sample_model_fit(out_of_sample_predictions: jax.Array, out_of_sample_target: jax.Array, interval_mid_range: float = 0.9, digits: int = 3) → matplotlib.figure.Figure

Plots the ground truth, predicted value and interval for the test data.

Parameters
  • out_of_sample_predictions – Predictions for the out-of-sample period, as derived from mmm.predict.

  • out_of_sample_target – Target for the out-of-sample period. Needs to be on the same scale as out_of_sample_predictions.

  • interval_mid_range – Mid range interval to take for plotting. E.g. .9 will use .05 and .95 as the lower and upper quantiles. Must be a float between 0 and 1.

  • digits – Number of decimals to display on metrics in the plot.

Returns

Plot of model fit.

lightweight_mmm.plot.plot_media_channel_posteriors(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, channel_names: Optional[Sequence[Any]] = None, quantiles: Sequence[float] = (0.05, 0.5, 0.95), fig_size: Optional[Tuple[int, int]] = None) → matplotlib.figure.Figure

Plots the posterior distributions of estimated media channel effect.

Model needs to be fit before calling this function to plot.

Parameters
  • media_mix_model – Media mix model.

  • channel_names – Names of media channels to be added to plot.

  • quantiles – Quantiles to draw on the distribution.

  • fig_size – Size of the figure to plot as used by matplotlib. If not specified it will be determined dynamically based on the number of media channels and geos the model was trained on.

Returns

Plot of posterior distributions.

lightweight_mmm.plot.plot_prior_and_posterior(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, fig_size: Optional[Tuple[int, int]] = None, selected_features: Optional[List[str]] = None, number_of_samples_for_prior: int = 5000, kde_bandwidth_adjust_for_posterior: float = 1, seed: Optional[int] = None) → matplotlib.figure.Figure

Plots prior and posterior distributions for parameters in media_mix_model.

Parameters
  • media_mix_model – Fitted media mix model.

  • fig_size – Size of the figure to plot as used by matplotlib. Default is a width of 8 and a height of 1.5 for each subplot.

  • selected_features – Optional list of feature names to select. If not specified (the default), all features are selected.

  • number_of_samples_for_prior – Controls the level of smoothing for the plotted version of the prior distribution. The default should be fine unless you want to decrease it to speed up runtime.

  • kde_bandwidth_adjust_for_posterior – Multiplicative factor to adjust the bandwidth of the kernel density estimator, to control the level of smoothing for the posterior distribution. Passed to seaborn.kdeplot as the bw_adjust parameter there.

  • seed – Seed to use for PRNGKey during sampling. For replicability run this function and any other function that utilises predictions with the same seed.

Returns

Plot with Kernel density estimate smoothing showing prior and posterior distributions for every parameter in the given media_mix_model.

Raises
  • NotFittedModelError – media_mix_model has not yet been fit.

  • ValueError – A feature has been created without a well-defined prior.

lightweight_mmm.plot.plot_bars_media_metrics(metric: jax.Array, metric_name: str = 'metric', channel_names: Optional[Tuple[Any]] = None, interval_mid_range: float = 0.9) → matplotlib.figure.Figure

Plots a barchart of estimated media effects with their percentile interval.

The lower and upper percentiles need to be between 0 and 1.

Parameters
  • metric – Estimated media metric as returned by LightweightMMM.get_posterior_metrics(). Can be either contribution percentage or ROI.

  • metric_name – Name of the media metric, e.g. contribution percentage or ROI.

  • channel_names – Names of media channels to be added to plot.

  • interval_mid_range – Mid range interval to take for plotting. E.g. .9 will use .05 and .95 as the lower and upper quantiles. Must be a float between 0 and 1.

Returns

Barplot of estimated media effects with defined percentile-bars.

lightweight_mmm.plot.plot_pre_post_budget_allocation_comparison(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, kpi_with_optim: jax.Array, kpi_without_optim: jax.Array, optimal_buget_allocation: jax.Array, previous_budget_allocation: jax.Array, channel_names: Optional[Sequence[Any]] = None, figure_size: Tuple[int, int] = (20, 10)) → matplotlib.figure.Figure

Plots bar charts to compare pre & post budget allocation.

Parameters
  • media_mix_model – Media mix model to use for the optimization.

  • kpi_with_optim – Negative predicted target variable with optimized budget allocation.

  • kpi_without_optim – Negative predicted target variable with the original budget allocation proportion based on the historical data.

  • optimal_buget_allocation – Optimized budget allocation.

  • previous_budget_allocation – Starting budget allocation based on original budget allocation proportion.

  • channel_names – Names of media channels to be added to plot.

  • figure_size – Size of the plot.

Returns

Barplots of budget allocation across media channels pre & post optimization.

lightweight_mmm.plot.plot_media_baseline_contribution_area_plot(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, channel_names: Optional[Sequence[Any]] = None, fig_size: Optional[Tuple[int, int]] = (20, 7), legend_outside: Optional[bool] = False) → matplotlib.figure.Figure

Plots an area chart to visualize weekly media & baseline contribution.

Parameters
  • media_mix_model – Media mix model.

  • target_scaler – Scaler used for scaling the target.

  • channel_names – Names of media channels.

  • fig_size – Size of the figure to plot as used by matplotlib.

  • legend_outside – Put the legend outside of the chart, center-right.

Returns

Stacked area chart of weekly baseline & media contribution.

lightweight_mmm.plot.create_media_baseline_contribution_df(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, channel_names: Optional[Sequence[str]] = None) → pandas.core.frame.DataFrame

Creates a dataframe for weekly media channels & baseline contribution.

The output dataframe will be used to create a stacked area plot visualizing the contribution of each media channel & the baseline.

Parameters
  • media_mix_model – Media mix model.

  • target_scaler – Scaler used for scaling the target.

  • channel_names – Names of media channels.

Returns

contribution_df – DataFrame of weekly channels & baseline contribution percentage & volume.

Return type

pandas.core.frame.DataFrame
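A sketch of the percentage-and-volume idea, assuming each week's percentages are shares of that week's total (baseline plus media). The contributions below are hypothetical, and the actual DataFrame layout and column names produced by the library are not reproduced here.

```python
import numpy as np

# Hypothetical weekly contribution volumes in target units: shape (n_weeks, 1 + n_channels).
# Column 0 is the baseline; the remaining columns are media channels.
contributions = np.array([[70.0, 20.0, 10.0],
                          [60.0, 25.0, 15.0],
                          [80.0, 10.0, 10.0]])

# Percentage: each component's share of that week's total contribution.
weekly_total = contributions.sum(axis=1, keepdims=True)
contribution_pct = contributions / weekly_total
```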

Models

transform_adstock(media_data, custom_priors)

Transforms the input data with the adstock function and exponent.

transform_hill_adstock(media_data, custom_priors)

Transforms the input data with the adstock and hill functions.

transform_carryover(media_data, custom_priors)

Transforms the input data with the carryover function and exponent.

media_mix_model(media_data, target_data, ...)

Media mix model.

lightweight_mmm.models.transform_adstock(media_data: jax.Array, custom_priors: MutableMapping[str, Union[numpyro.distributions.distribution.Distribution, Dict[str, float], Sequence[float], float]], normalise: bool = True) → jax.Array

Transforms the input data with the adstock function and exponent.

Parameters
  • media_data – Media data to be transformed. It is expected to have 2 dims for national models and 3 for geo models.

  • custom_priors – The custom priors we want the model to take instead of the default ones. The possible names of parameters for adstock and exponent are “lag_weight” and “exponent”.

  • normalise – Whether to normalise the output values.

Returns

The transformed media data.

lightweight_mmm.models.transform_hill_adstock(media_data: jax.Array, custom_priors: MutableMapping[str, Union[numpyro.distributions.distribution.Distribution, Dict[str, float], Sequence[float], float]], normalise: bool = True) → jax.Array

Transforms the input data with the adstock and hill functions.

Parameters
  • media_data – Media data to be transformed. It is expected to have 2 dims for national models and 3 for geo models.

  • custom_priors – The custom priors we want the model to take instead of the default ones. The possible names of parameters for hill_adstock are “lag_weight”, “half_max_effective_concentration” and “slope”.

  • normalise – Whether to normalise the output values.

Returns

The transformed media data.

lightweight_mmm.models.transform_carryover(media_data: jax.Array, custom_priors: MutableMapping[str, Union[numpyro.distributions.distribution.Distribution, Dict[str, float], Sequence[float], float]], number_lags: int = 13) → jax.Array

Transforms the input data with the carryover function and exponent.

Parameters
  • media_data – Media data to be transformed. It is expected to have 2 dims for national models and 3 for geo models.

  • custom_priors – The custom priors we want the model to take instead of the default ones. The possible names of parameters for carryover and exponent are “ad_effect_retention_rate_plate”, “peak_effect_delay_plate” and “exponent”.

  • number_lags – Number of lags for the carryover function.

Returns

The transformed media data.

lightweight_mmm.models.media_mix_model(media_data: jax.Array, target_data: jax.Array, media_prior: jax.Array, degrees_seasonality: int, frequency: int, transform_function: lightweight_mmm.models.TransformFunction, custom_priors: MutableMapping[str, Union[numpyro.distributions.distribution.Distribution, Dict[str, float], Sequence[float], float]], transform_kwargs: Optional[MutableMapping[str, Any]] = None, weekday_seasonality: bool = False, extra_features: Optional[jax.Array] = None) → None

Media mix model.

Parameters
  • media_data – Media data to be used in the model.

  • target_data – Target data for the model.

  • media_prior – Cost prior for each of the media channels.

  • degrees_seasonality – Number of degrees of seasonality to use.

  • frequency – Frequency of the time span which was used to aggregate the data. E.g. for weekly data the frequency is 52.

  • transform_function – Function to use to transform the media data in the model. Currently the following are supported: ‘transform_adstock’, ‘transform_carryover’ and ‘transform_hill_adstock’.

  • custom_priors – The custom priors we want the model to take instead of the default ones. See our custom_priors documentation for details about the API and possible options.

  • transform_kwargs – Any extra keyword arguments to pass to the transform function. For example, the adstock function can take a boolean to normalise the output or not.

  • weekday_seasonality – In case of daily data, also estimate seven weekday parameters.

  • extra_features – Extra features data to include in the model.

Media Transforms

calculate_seasonality(number_periods, ...[, ...])

Calculates cyclic variation seasonality using Fourier terms.

adstock(data[, lag_weight, normalise])

Calculates the adstock value of a given array.

hill(data, half_max_effective_concentration, ...)

Calculates the hill function for a given array of values.

carryover(data, ad_effect_retention_rate, ...)

Calculates media carryover.

apply_exponent_safe(data, exponent)

Applies an exponent to given data in a gradient safe way.

lightweight_mmm.media_transforms.calculate_seasonality(number_periods: int, degrees: int, gamma_seasonality: Union[int, float, jax.Array], frequency: int = 52) → jax.Array

Calculates cyclic variation seasonality using Fourier terms.

For detailed info check:

https://en.wikipedia.org/wiki/Seasonality#Modeling

Parameters
  • number_periods – Number of seasonal periods in the data. E.g. for 1 year of weekly data it will be 52, and for 3 years of weekly data 156.

  • degrees – Number of degrees to use. Must be greater than or equal to 1.

  • gamma_seasonality – Factor by which to multiply each degree calculation. Shape must be aligned with the number of degrees.

  • frequency – Frequency of the seasonality being computed. Default is 52 for weekly data (52 weeks in a year).

Returns

An array with the seasonality values.
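The standard Fourier-term formulation can be sketched in NumPy as a sum of one sine and one cosine term per degree. We assume here that gamma_seasonality holds one sine and one cosine coefficient per degree, i.e. shape (degrees, 2); fourier_seasonality is a hypothetical name, not the library function.

```python
import numpy as np


def fourier_seasonality(number_periods, degrees, gamma_seasonality, frequency=52):
    """Sum of sine/cosine terms; gamma_seasonality has shape (degrees, 2) in this sketch."""
    t = np.arange(number_periods)[:, np.newaxis]        # (number_periods, 1)
    d = np.arange(1, degrees + 1)[np.newaxis, :]        # (1, degrees)
    angle = 2.0 * np.pi * d * t / frequency             # (number_periods, degrees)
    gamma = np.asarray(gamma_seasonality, dtype=float)  # column 0: sin, column 1: cos
    return np.sin(angle) @ gamma[:, 0] + np.cos(angle) @ gamma[:, 1]


# Two years of weekly data with two degrees of seasonality.
seasonality = fourier_seasonality(104, degrees=2,
                                  gamma_seasonality=[[1.0, 0.5], [0.2, 0.1]])
```

Because every term completes a whole number of cycles in frequency periods, the output repeats every 52 periods here; higher degrees add faster oscillations that let the curve take less sinusoidal shapes.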

lightweight_mmm.media_transforms.adstock(data: jax.Array, lag_weight: float = 0.9, normalise: bool = True) → jax.Array

Calculates the adstock value of a given array.

To learn more about advertising lag: https://en.wikipedia.org/wiki/Advertising_adstock

Parameters
  • data – Input array.

  • lag_weight – lag_weight effect of the adstock function. Default is 0.9.

  • normalise – Whether to normalise the output value. This normalization divides the output values by (1 / (1 - lag_weight)), i.e. multiplies them by (1 - lag_weight).

Returns

The adstock output of the input array.
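A minimal sketch of the geometric adstock recursion, assuming the first value initialises the recursion and each later value adds lag_weight times the previous adstock value. It applies the normalisation described above; geometric_adstock is a hypothetical name, and the Python loop favours clarity over speed (a JAX implementation would typically use jax.lax.scan instead).

```python
import numpy as np


def geometric_adstock(data, lag_weight=0.9, normalise=True):
    """a[0] = x[0]; a[t] = x[t] + lag_weight * a[t - 1], applied along axis 0."""
    data = np.asarray(data, dtype=float)
    adstocked = np.empty_like(data)
    adstocked[0] = data[0]
    for t in range(1, data.shape[0]):
        adstocked[t] = data[t] + lag_weight * adstocked[t - 1]
    if normalise:
        # Dividing by 1 / (1 - lag_weight) equals multiplying by (1 - lag_weight).
        adstocked = adstocked / (1.0 / (1.0 - lag_weight))
    return adstocked


# A single impulse decays geometrically: [1.0, 0.5, 0.25] before normalisation.
impulse_response = geometric_adstock([1.0, 0.0, 0.0], lag_weight=0.5, normalise=False)
```

Normalising preserves the scale of the input: a constant input x builds up towards x / (1 - lag_weight), and multiplying by (1 - lag_weight) brings that steady state back to x.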

lightweight_mmm.media_transforms.hill(data: jax.Array, half_max_effective_concentration: jax.Array, slope: jax.Array) → jax.Array

Calculates the hill function for a given array of values.

Refer to the following link for detailed information on this equation:

https://en.wikipedia.org/wiki/Hill_equation_(biochemistry)

Parameters
  • data – Input data.

  • half_max_effective_concentration – EC50 (half maximal effective concentration) value for the hill function.

  • slope – Slope of the hill function.

Returns

The hill values for the respective input data.
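
The standard Hill curve y = x^slope / (ec50^slope + x^slope) can be sketched as below. `hill_saturation` is a hypothetical helper, and whether the library evaluates exactly this algebraic form is an assumption.

```python
import jax.numpy as jnp


def hill_saturation(data, half_max_effective_concentration, slope):
    """Hypothetical sketch of the standard Hill curve x^s / (ec50^s + x^s)."""
    data_pow = data ** slope
    # Rises monotonically from 0 towards 1, passing 0.5 at data == ec50.
    return data_pow / (half_max_effective_concentration ** slope + data_pow)
```

Because the curve equals 0.5 at x = ec50, `half_max_effective_concentration` can be read as the media level at which half of the maximum effect is reached, while `slope` controls how sharply the curve saturates.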

lightweight_mmm.media_transforms.carryover(data: jax.Array, ad_effect_retention_rate: jax.Array, peak_effect_delay: jax.Array, number_lags: int = 13) jax.Array[source]

Calculates media carryover.

More details about this function can be found in: https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46001.pdf

Parameters
  • data – Input data. Data is expected to have either 2 dimensions for national models or 3 for geo models.

  • ad_effect_retention_rate – Retention rate of the advertisement effect. Default is 0.5.

  • peak_effect_delay – Delay of the peak effect in the carryover function. Default is 1.

  • number_lags – Number of lags to include in the carryover calculation. Default is 13.

Returns

The carryover values for the given data with the given parameters.
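
A minimal sketch of the delayed-adstock weighting from the linked paper, as we understand it: weights w_l = ad_effect_retention_rate ** ((l - peak_effect_delay) ** 2) for l = 0, ..., number_lags - 1, normalised to sum to one and applied to the current and previous values. It handles national data of shape (time, channels) only and uses an explicit loop over lags; `carryover_sketch` is a hypothetical name, not the library's implementation.

```python
import jax.numpy as jnp


def carryover_sketch(data, ad_effect_retention_rate, peak_effect_delay, number_lags=13):
    """Hypothetical sketch for national data of shape (time, channels).

    ad_effect_retention_rate and peak_effect_delay hold one value per channel.
    """
    data = jnp.asarray(data, dtype=jnp.float32)
    lags = jnp.arange(number_lags, dtype=jnp.float32)[:, None]  # (lags, 1)
    # Weight per lag and channel: retention ** ((lag - delay) ** 2), shape (lags, channels).
    weights = ad_effect_retention_rate ** ((lags - peak_effect_delay) ** 2)
    weights = weights / weights.sum(axis=0)  # normalise per channel
    # Prepend number_lags - 1 rows of zeros so every time step has a full lag window.
    padded = jnp.concatenate([jnp.zeros((number_lags - 1, data.shape[1])), data], axis=0)
    n_time = data.shape[0]
    out = jnp.zeros_like(data)
    for lag in range(number_lags):
        start = number_lags - 1 - lag
        # Row t of this slice is data[t - lag] (zero before the series starts).
        out = out + weights[lag] * padded[start:start + n_time]
    return out
```

With `peak_effect_delay` at 0 the effect of a spend impulse is largest immediately and decays afterwards; larger delays shift the peak to later lags.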

lightweight_mmm.media_transforms.apply_exponent_safe(data: jax.Array, exponent: jax.Array) jax.Array[source]

Applies an exponent to given data in a gradient safe way.

More information on the double jnp.where trick can be found at: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf

Parameters
  • data – Input data to use.

  • exponent – Exponent to apply to the data.

Returns

The result of the exponent operation with the inputs provided.
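
The double-where pattern from the linked note can be sketched as below. `exponent_safe` is a hypothetical name used for illustration. A single `jnp.where` is not enough: the masked-out branch still backpropagates through 0 ** exponent, and an infinite derivative multiplied by a zero cotangent yields NaN.

```python
import jax
import jax.numpy as jnp


def exponent_safe(data, exponent):
    """Hypothetical sketch of the double-where pattern for data ** exponent."""
    # Inner where: replace zeros with ones so 0 ** exponent is never evaluated,
    # neither in the forward pass nor in its derivative.
    safe_base = jnp.where(data == 0, 1.0, data)
    powered = safe_base ** exponent
    # Outer where: restore zeros where the input was originally zero.
    return jnp.where(data == 0, 0.0, powered)
```

Usage: `jax.grad(lambda d: exponent_safe(d, -0.5).sum())(jnp.array([0.0, 4.0]))` stays finite, whereas the single-where version produces NaN at the zero entry.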

Utils

save_model(media_mix_model, file_path)

Saves the given model in the given path.

load_model(file_path)

Loads a model given a string path.

simulate_dummy_data(data_size, ...[, geos, seed])

Simulates dummy data needed for media mix modelling.

get_halfnormal_mean_from_scale(scale)

Returns the mean of the half-normal distribution.

get_halfnormal_scale_from_mean(mean)

Returns the scale of the half-normal distribution.

get_beta_params_from_mu_sigma(mu, sigma[, ...])

Deterministically estimates (a, b) from (mu, sigma) of a beta variable.

distance_pior_posterior(p, q[, method, discrete])

Quantifies the distance between two distributions.

interpolate_outliers(x, outlier_idx)

Overwrites outliers in x with interpolated values.

dataframe_to_jax(dataframe, media_features, ...)

Converts pandas dataframe to right data format for media mix model.

lightweight_mmm.utils.save_model(media_mix_model: Any, file_path: str) None[source]

Saves the given model in the given path.

Parameters
  • media_mix_model – Model to save on disk.

  • file_path – File path where the model should be placed.

lightweight_mmm.utils.load_model(file_path: str) Any[source]

Loads a model given a string path.

Parameters

file_path – Path of the file containing the model.

Returns

The LightweightMMM object that was stored in the given path.

lightweight_mmm.utils.simulate_dummy_data(data_size: int, n_media_channels: int, n_extra_features: int, geos: int = 1, seed: int = 5) Tuple[jax.Array, jax.Array, jax.Array, jax.Array][source]

Simulates dummy data needed for media mix modelling.

This function’s goal is to be very simple, with few parameters. It does not generate a fully realistic dataset and is only meant to be used for demo/tutorial purposes. It uses carryover for lagging but has no saturation and no trend.

The data simulated includes the media data, extra features, a target/KPI and costs.

Parameters
  • data_size – Number of rows to generate.

  • n_media_channels – Number of media channels to generate.

  • n_extra_features – Number of extra features to generate.

  • geos – Number of geos for geo level data (default = 1 for national).

  • seed – Random seed.

Returns

The simulated media, extra features, target and costs.

lightweight_mmm.utils.get_halfnormal_mean_from_scale(scale: float) float[source]

Returns the mean of the half-normal distribution.

lightweight_mmm.utils.get_halfnormal_scale_from_mean(mean: float) float[source]

Returns the scale of the half-normal distribution.
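
A half-normal distribution with scale sigma has mean sigma * sqrt(2 / pi), so the two helpers above are inverses of each other. The sketch below encodes that relationship; the function names are hypothetical and this is not the library's source.

```python
import math


def halfnormal_mean_from_scale(scale):
    """Hypothetical sketch: mean of HalfNormal(scale) is scale * sqrt(2 / pi)."""
    return scale * math.sqrt(2.0 / math.pi)


def halfnormal_scale_from_mean(mean):
    """Hypothetical sketch: inverse of the above, scale = mean * sqrt(pi / 2)."""
    return mean * math.sqrt(math.pi / 2.0)
```

This is handy when a prior is easier to reason about in terms of its mean: a half-normal prior with a desired mean of 1.0 needs a scale of about 1.25.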

lightweight_mmm.utils.get_beta_params_from_mu_sigma(mu: float, sigma: float, bracket: Tuple[float, float] = (0.5, 100.0)) Tuple[float, float][source]

Deterministically estimates (a, b) from (mu, sigma) of a beta variable.

https://en.wikipedia.org/wiki/Beta_distribution

Parameters
  • mu – The sample mean of the beta-distributed variable.

  • sigma – The sample standard deviation of the beta-distributed variable.

  • bracket – Search bracket for b.

Returns

Tuple of the (a, b) parameters.
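
The `bracket` parameter suggests the library searches for b numerically. As a swapped-in alternative, the method-of-moments closed form gives (a, b) directly: with nu = mu * (1 - mu) / sigma**2 - 1, a = mu * nu and b = (1 - mu) * nu, valid when 0 < mu < 1 and sigma**2 < mu * (1 - mu). `beta_params_from_mu_sigma` is a hypothetical name, and agreement with the library's output is not verified here.

```python
def beta_params_from_mu_sigma(mu, sigma):
    """Hypothetical method-of-moments (a, b) for a Beta with mean mu and std sigma."""
    variance = sigma ** 2
    if not 0.0 < mu < 1.0 or variance >= mu * (1.0 - mu):
        raise ValueError("Need 0 < mu < 1 and sigma**2 < mu * (1 - mu).")
    concentration = mu * (1.0 - mu) / variance - 1.0  # equals a + b
    return mu * concentration, (1.0 - mu) * concentration
```

For example, mu = 0.2 and sigma = 0.1 give (a, b) = (3, 12).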

lightweight_mmm.utils.distance_pior_posterior(p: jax.Array, q: jax.Array, method: str = 'KS', discrete: bool = True) float[source]

Quantifies the distance between two distributions.

Note that KL divergence is not used because it is not defined when a probability is 0.

https://en.wikipedia.org/wiki/Hellinger_distance

Parameters
  • p – Samples for distribution 1.

  • q – Samples for distribution 2.

  • method – The distance method to use, one of four: 'KS' (Kolmogorov-Smirnov), 'Hellinger', 'JS' (Jensen-Shannon) or 'min'.

  • discrete – Whether input data is discrete or continuous.

Returns

The distance metric (between 0 and 1).
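
Of the four methods, the two-sample Kolmogorov-Smirnov (KS) distance is the simplest to sketch for continuous samples: the largest absolute gap between the two empirical CDFs, which always lies between 0 and 1. `ks_distance` is a hypothetical helper; it does not cover the `discrete` flag or the Hellinger, JS and min options.

```python
import jax.numpy as jnp


def ks_distance(p, q):
    """Hypothetical sketch of the two-sample KS statistic, max |F_p(x) - F_q(x)|."""
    p = jnp.sort(jnp.asarray(p))
    q = jnp.sort(jnp.asarray(q))
    grid = jnp.concatenate([p, q])  # the maximum gap occurs at a sample point
    # Empirical CDFs: fraction of samples less than or equal to each grid point.
    cdf_p = jnp.searchsorted(p, grid, side="right") / p.shape[0]
    cdf_q = jnp.searchsorted(q, grid, side="right") / q.shape[0]
    return float(jnp.max(jnp.abs(cdf_p - cdf_q)))
```

Identical sample sets give 0, and sample sets with no overlap give 1.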

lightweight_mmm.utils.interpolate_outliers(x: jax.Array, outlier_idx: jax.Array) jax.Array[source]

Overwrites outliers in x with interpolated values.

Parameters
  • x – The original univariate variable with outliers.

  • outlier_idx – Indices of the outliers in x.

Returns

A cleaned x with outliers overwritten.
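
A minimal sketch using linear interpolation between neighbouring inliers via `jnp.interp`. Linear interpolation is an assumption (the scheme is not stated above), and `interpolate_outliers_sketch` is a hypothetical name. Note that `jnp.interp` clamps to the nearest inlier value when an outlier sits at either end of the series.

```python
import jax.numpy as jnp


def interpolate_outliers_sketch(x, outlier_idx):
    """Hypothetical sketch: overwrite x[outlier_idx] by linear interpolation of inliers."""
    x = jnp.asarray(x, dtype=jnp.float32)
    positions = jnp.arange(x.shape[0], dtype=jnp.float32)
    is_outlier = jnp.zeros(x.shape[0], dtype=bool).at[outlier_idx].set(True)
    # Fit the interpolant on inlier positions only. Boolean masking needs concrete
    # arrays, so this sketch is not meant to run under jax.jit.
    filled = jnp.interp(positions, positions[~is_outlier], x[~is_outlier])
    return jnp.where(is_outlier, filled, x)
```

For example, a spike in `[1.0, 2.0, 100.0, 4.0]` at index 2 is replaced by 3.0, halfway between its neighbours.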

lightweight_mmm.utils.dataframe_to_jax(dataframe: pandas.core.frame.DataFrame, media_features: List[str], extra_features: List[str], date_feature: str, target: str, geo_feature: Optional[str] = None, cost_features: Optional[List[str]] = None) Tuple[jax.Array, jax.Array, jax.Array, jax.Array][source]

Converts pandas dataframe to right data format for media mix model.

This function converts a dataframe, the format most familiar to data scientists, into JAX arrays, making the LightweightMMM library easier to use for those who are less familiar with arrays.

Parameters
  • dataframe – Dataframe with geo, KPI, media and non-media features.

  • media_features – List of media feature names.

  • extra_features – List of non media feature names.

  • date_feature – Date feature name.

  • target – Target variable name.

  • geo_feature – Geo feature name. Optional if the data is at national level.

  • cost_features – List of media cost variable names. Optional if the actual media costs are used as the media features in the model.

Returns

Media, extra features, target and costs arrays.

Raises
  • ValueError – If the geos have an unequal number of weeks, or if there is only one value in the geo feature.
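
For geo data, the conversion amounts to grouping rows by geo, sorting each group by date and stacking the groups along a trailing geo axis. The sketch below does this for the media columns only, assuming a (time, channels, geos) layout; `geo_media_to_array` is a hypothetical helper that mirrors the two documented ValueError conditions, not the library's implementation.

```python
import jax.numpy as jnp
import pandas as pd


def geo_media_to_array(df, media_features, date_feature, geo_feature):
    """Hypothetical sketch: stack per-geo media columns into (time, channels, geos)."""
    geos = sorted(df[geo_feature].unique())
    if len(geos) < 2:
        raise ValueError("The geo feature must contain more than one value.")
    # One (time, channels) block per geo, rows ordered by date.
    per_geo = [
        df[df[geo_feature] == geo].sort_values(date_feature)[media_features].to_numpy()
        for geo in geos
    ]
    if len({block.shape[0] for block in per_geo}) != 1:
        raise ValueError("Each geo must have the same number of time periods.")
    return jnp.stack([jnp.asarray(block, dtype=jnp.float32) for block in per_geo], axis=-1)
```

The same grouping could be applied to the extra features and the target; the geo order here is simply the sorted order of the geo values.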