LightweightMMM
LightweightMMM object
- LightweightMMM: Lightweight Media Mix Modelling wrapper for bayesian models.
- class lightweight_mmm.lightweight_mmm.LightweightMMM(model_name: str = 'hill_adstock')
Lightweight Media Mix Modelling wrapper for bayesian models.
- The currently available models are the following:
  - hill_adstock
  - adstock
  - carryover
It also offers the necessary utilities for calculating media contribution and media ROI based on models’ results.
- trace
Sampling trace of the bayesian model once fitted.
- Type
Dict[str, jax.Array]
- n_media_channels
Number of media channels the model was trained with.
- Type
int
- n_geos
Number of geos for geo models or 1 for national models.
- Type
int
- model_name
Name of the model.
- Type
str
- media
The media data the model was trained on. Useful for a variety of insights after model fitting.
- Type
jax.Array
- media_names
Names of the media channels passed at fitting time.
- Type
Sequence[str]
- custom_priors
The set of custom priors the model was trained with. An empty dictionary if none were passed.
- Type
MutableMapping[str, Union[numpyro.distributions.distribution.Distribution, Dict[str, float], Sequence[float], float]]
- fit(media: jax.Array, media_prior: jax.Array, target: jax.Array, extra_features: typing.Optional[jax.Array] = None, degrees_seasonality: int = 2, seasonality_frequency: int = 52, weekday_seasonality: bool = False, media_names: typing.Optional[typing.Sequence[str]] = None, number_warmup: int = 1000, number_samples: int = 1000, number_chains: int = 2, target_accept_prob: float = 0.85, init_strategy: typing.Callable[[typing.Mapping[typing.Any, typing.Any], typing.Any], jax.Array] = <function init_to_median>, custom_priors: typing.Optional[typing.Dict[str, typing.Union[numpyro.distributions.distribution.Distribution, typing.Dict[str, float], typing.Sequence[float], float]]] = None, seed: typing.Optional[int] = None) -> None
Fits MMM given the media data, extra features, costs and sales/KPI.
For detailed information on the selected model please refer to its respective function in the models.py file.
- Parameters
media – Media input data. Media data must have either 2 dimensions for national models or 3 for geo models.
media_prior – Costs of each media channel. The number of cost values must be equal to the number of media channels.
target – Target KPI to use, for example sales.
extra_features – Other variables to add to the model.
degrees_seasonality – Number of degrees to use for seasonality. Default is 2.
seasonality_frequency – Frequency of the time period used. Default is 52 as in 52 weeks per year.
weekday_seasonality – In case of daily data, also estimate seven weekday parameters.
media_names – Names of the media channels passed.
number_warmup – Number of warm up samples. Default is 1000.
number_samples – Number of samples during sampling. Default is 1000.
number_chains – Number of chains to sample. Default is 2.
target_accept_prob – Target acceptance probability for step size in the NUTS sampler. Default is 0.85.
init_strategy – Initialization function for numpyro NUTS. The available options can be found in https://num.pyro.ai/en/stable/utilities.html#initialization-strategies. Default is numpyro.infer.init_to_median.
custom_priors – The custom priors we want the model to take instead of the default ones. Refer to the full documentation on custom priors for details.
seed – Seed to use for PRNGKey during training. For replicability, run all trainings with the same seed.
- get_posterior_metrics(unscaled_costs: Optional[jax.Array] = None, cost_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None) -> Tuple[jax.Array, jax.Array]
Estimates the media contribution percentage and ROI of each channel.
If data was scaled prior to training, the target and cost scalers need to be passed to this function to correctly calculate media contribution percentage and ROI in the unscaled space.
- Parameters
unscaled_costs – Optionally, pass new costs to compute this set of metrics with. If None, the costs used for training are used to calculate ROI.
cost_scaler – Scaler that was used to scale the cost data before training. It is ignored if ‘unscaled_costs’ is provided.
target_scaler – Scaler that was used to scale the target before training.
- Returns
media_contribution_hat_pct: The average media contribution percentage for each channel.
roi_hat: The return on investment of each channel, calculated as its contribution divided by its cost.
- Return type
Tuple[jax.Array, jax.Array]
- Raises
NotFittedModelError – When this method is called before the model has been fitted.
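Once per-channel contributions are known, both metrics reduce to simple ratios. The plain-Python sketch below assumes you already have each channel's total contribution and the total target in the unscaled space. `contribution_pct_and_roi` is a hypothetical helper and is not part of LightweightMMM. The library works with posterior samples, whereas this sketch uses single values for clarity.

```python
from typing import List, Tuple


def contribution_pct_and_roi(
    channel_contributions: List[float],
    total_target: float,
    costs: List[float],
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: contribution share and ROI per channel."""
    if len(channel_contributions) != len(costs):
        raise ValueError("Need one cost value per media channel.")
    # Share of the total target attributed to each channel.
    contribution_pct = [c / total_target for c in channel_contributions]
    # ROI as described above: contribution divided by cost.
    roi = [c / cost for c, cost in zip(channel_contributions, costs)]
    return contribution_pct, roi


pct, roi = contribution_pct_and_roi([200.0, 50.0], total_target=1000.0, costs=[100.0, 100.0])
print(pct, roi)  # [0.2, 0.05] [2.0, 0.5]
```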
- predict(media: jax.Array, extra_features: Optional[jax.Array] = None, media_gap: Optional[jax.Array] = None, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, seed: Optional[int] = None) -> jax.Array
Runs the model to obtain predictions for the given input data.
Predictions are returned as distributions. If point estimates are desired, calculate them from the returned distribution (for example its mean or median).
- Parameters
media – Media array needed for the model to run predictions.
extra_features – Extra features needed for the model to run.
media_gap – Media data gap between the end of the training data and the start of the out-of-sample media given. E.g. if 100 weeks of data were used for training and prediction starts 2 months (8 weeks) after the training data ends, provide the 8 missing weeks between the training data and the prediction data so that data transformations (adstock, carryover, …) are applied correctly.
target_scaler – Scaler that was used to scale the target before training.
seed – Seed to use for PRNGKey during sampling. For replicability run this function and any other function that utilises predictions with the same seed.
- Returns
Predictions for the given media and extra features at a given date index.
- Raises
NotFittedModelError – When the model has not been fitted before running predict.
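predict returns a distribution rather than a single value, so a point estimate and an interval are usually derived from the samples. The plain-Python sketch below handles one time step and assumes that step's samples have already been extracted as a list. The helper names are hypothetical. The quantile rule is the linear interpolation that NumPy uses by default. The 0.9 default follows how interval_mid_range is described for the plotting functions, where 0.9 means the 0.05 and 0.95 quantiles.

```python
from typing import List, Tuple


def quantile(samples: List[float], q: float) -> float:
    """Linear-interpolation quantile (the same rule as NumPy's default)."""
    ordered = sorted(samples)
    pos = q * (len(ordered) - 1)
    lower = int(pos)
    upper = min(lower + 1, len(ordered) - 1)
    frac = pos - lower
    return ordered[lower] + frac * (ordered[upper] - ordered[lower])


def summarize_predictions(
    samples: List[float], interval_mid_range: float = 0.9
) -> Tuple[float, float, float]:
    """Median as point estimate plus a central interval (0.9 -> 5th/95th percentiles)."""
    tail = (1.0 - interval_mid_range) / 2.0
    return quantile(samples, 0.5), quantile(samples, tail), quantile(samples, 1.0 - tail)


median, low, high = summarize_predictions([5.0, 1.0, 3.0, 2.0, 4.0])
print(median, round(low, 6), round(high, 6))  # 3.0 1.2 4.8
```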
- print_summary() -> None
Calls numpyro's print_summary function to print a summary of the model parameters.
- reduce_trace(nsample: int = 100, seed: int = 0) -> None
Reduces the number of samples in the trace to speed up predict and optimize.
Please note this step is not reversible. Only do this after you have investigated convergence of the model.
- Parameters
nsample – Target number of samples.
seed – Random seed for down sampling.
- Raises
ValueError – if nsample is too big.
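The key property of downsampling a trace is that every parameter must keep the same sample indices, so draws stay aligned across parameters. The sketch below shows this on a dict of plain lists. `reduce_trace_sketch` is hypothetical. Two things are assumptions about the library: whether it samples without replacement, and that "too big" means larger than the available sample count.

```python
import random
from typing import Dict, List


def reduce_trace_sketch(
    trace: Dict[str, List[float]], nsample: int = 100, seed: int = 0
) -> Dict[str, List[float]]:
    """Hypothetical stand-in for reduce_trace on a dict of per-sample lists."""
    n_total = len(next(iter(trace.values())))
    if nsample > n_total:
        raise ValueError(f"nsample={nsample} exceeds the {n_total} samples in the trace.")
    rng = random.Random(seed)
    # Draw one shared set of indices (without replacement) for all parameters.
    keep = sorted(rng.sample(range(n_total), nsample))
    return {name: [values[i] for i in keep] for name, values in trace.items()}


small = reduce_trace_sketch({"a": list(range(10)), "b": [x * 10 for x in range(10)]}, nsample=4)
print(small)
```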
Preprocessing / Scaling
- CustomScaler: Class to scale your data based on multiplications and divisions.
- class lightweight_mmm.preprocessing.CustomScaler(divide_operation: Optional[Callable[[jax.Array], jax.numpy.float32]] = None, divide_by: Optional[Union[float, int, jax.Array]] = 1, multiply_operation: Optional[Callable[[jax.Array], jax.numpy.float32]] = None, multiply_by: Optional[Union[float, int, jax.Array]] = 1.0)
Class to scale your data based on multiplications and divisions.
This scaler can be used in two ways, for both the multiplication and the division operation:
  - By specifying a value to use for the scaling operation.
  - By specifying an operation, applied at column level, that calculates the value used for the actual scaling operation.
E.g. to scale the dataset by multiplying by 100, pass multiply_by=100 directly. The value can also be an array with one value per column of the data being scaled. To multiply by the mean value of each column instead, pass multiply_operation=jnp.mean (or any other desired operation).
Operation parameters take precedence: when both a value and an operation are passed, the value is ignored.
The scaler must be fitted before the transform method can be called.
- Attributes
- divide_operation: Operation to apply over axis 0 of the fitting data to obtain the value that will be used for division during scaling.
- divide_by: Number(s) by which to divide data in the scaling process. Since the scaler is applied to axis 0 of the data, the shape of divide_by must be consistent with division into the data. For example, if data.shape = (100, 3, 5) then divide_by.shape can be (3, 5) or (5,) or a number. If divide_operation is given, this divide_by value will be ignored.
- multiply_operation: Operation to apply over axis 0 of the fitting data to obtain the value that will be used for multiplication during scaling.
- multiply_by: Number(s) by which to multiply data in the scaling process. Since the scaler is applied to axis 0 of the data, the shape of multiply_by must be consistent with multiplication into the data. For example, if data.shape = (100, 3, 5) then multiply_by.shape can be (3, 5) or (5,) or a number. If multiply_operation is given, this multiply_by value will be ignored.
- fit(data: jax.Array) -> None
Figures out values for transformations based on the specified operations.
- Parameters
data – Input dataset to use for fitting.
- fit_transform(data: jax.Array) -> jax.Array
Fits the values and applies transformation to the input data.
- Parameters
data – Input dataset.
- Returns
Transformed array.
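The value-versus-operation rules above can be mimicked in a few lines. The sketch below is a hypothetical, list-based class for 2-D data only (rows are time periods along axis 0, columns are channels). It is not CustomScaler itself: it accepts only scalar fixed values and does not support the 3-D broadcasting the real class supports. The transform order (multiply, then divide) is also an assumption.

```python
from typing import Callable, List, Optional, Sequence

Column = List[float]


class ColumnScalerSketch:
    """Hypothetical list-based mimic of CustomScaler's value/operation rules (2-D data only)."""

    def __init__(
        self,
        divide_operation: Optional[Callable[[Column], float]] = None,
        divide_by: float = 1.0,
        multiply_operation: Optional[Callable[[Column], float]] = None,
        multiply_by: float = 1.0,
    ) -> None:
        self.divide_operation = divide_operation
        self.divide_by = divide_by
        self.multiply_operation = multiply_operation
        self.multiply_by = multiply_by
        self._divide: Optional[List[float]] = None
        self._multiply: Optional[List[float]] = None

    def fit(self, data: Sequence[Sequence[float]]) -> None:
        # Rows are time periods (axis 0), so each column holds one channel.
        columns = [list(col) for col in zip(*data)]
        # An operation, when given, takes precedence over the fixed value.
        if self.divide_operation is not None:
            self._divide = [self.divide_operation(col) for col in columns]
        else:
            self._divide = [float(self.divide_by)] * len(columns)
        if self.multiply_operation is not None:
            self._multiply = [self.multiply_operation(col) for col in columns]
        else:
            self._multiply = [float(self.multiply_by)] * len(columns)

    def transform(self, data: Sequence[Sequence[float]]) -> List[List[float]]:
        if self._divide is None or self._multiply is None:
            raise RuntimeError("Call fit before transform.")
        return [
            [self._multiply[j] * x / self._divide[j] for j, x in enumerate(row)]
            for row in data
        ]

    def fit_transform(self, data: Sequence[Sequence[float]]) -> List[List[float]]:
        self.fit(data)
        return self.transform(data)
```

For example, `ColumnScalerSketch(divide_operation=lambda col: sum(col) / len(col))` divides every channel by its own mean, so each scaled column averages 1. Any divide_by value passed alongside that operation is ignored.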
Optimize Media
- find_optimal_budgets: Finds the best media allocation based on a media mix model, prices and a budget.
- lightweight_mmm.optimize_media.find_optimal_budgets(n_time_periods: int, media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, budget: Union[float, int], prices: jax.Array, extra_features: Optional[jax.Array] = None, media_gap: Optional[jax.Array] = None, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, media_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, bounds_lower_pct: Union[float, jax.Array] = 0.2, bounds_upper_pct: Union[float, jax.Array] = 0.2, max_iterations: int = 200, solver_func_tolerance: float = 1e-06, solver_step_size: float = 1.4901161193847656e-08, seed: Optional[int] = None) -> scipy.optimize._optimize.OptimizeResult
Finds the best media allocation based on a media mix model, prices and a budget.
- Parameters
n_time_periods – Number of time periods to optimize for. If model is built on weekly data, this would be the number of weeks ahead to optimize.
media_mix_model – Media mix model to use for the optimization.
budget – Total budget to allocate during the optimization time.
prices – An array with shape (n_media_channels,) for the cost of each media channel unit.
extra_features – Extra features needed for the model to predict.
media_gap – Media data gap between the end of the training data and the start of the out-of-sample media given. E.g. if 100 weeks of data were used for training and prediction starts 8 weeks after the training data ends, provide the 8 missing weeks between the training data and the prediction data so that data transformations (adstock, carryover, …) are applied correctly.
target_scaler – Scaler that was used to scale the target before training.
media_scaler – Scaler that was used to scale the media data before training.
bounds_lower_pct – Relative percentage decrease from the mean value to use as the new lower bound.
bounds_upper_pct – Relative percentage increase from the mean value to use as the new upper bound.
max_iterations – Number of max iterations to use for the SLSQP scipy optimizer. Default is 200.
solver_func_tolerance – Precision goal for the value of the prediction in the stopping criterion. Maps directly to scipy’s ftol. Intended only for advanced users. For more details see: https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html#optimize-minimize-slsqp.
solver_step_size – Step size used for numerical approximation of the Jacobian. Maps directly to scipy’s eps. Intended only for advanced users. For more details see: https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html#optimize-minimize-slsqp.
seed – Seed to use for PRNGKey during sampling. For replicability run this function and any other function that gets predictions with the same seed.
- Returns
solution: OptimizeResult object containing the results of the optimization.
kpi_without_optim: Predicted target based on the original allocation proportion among channels from the historical data.
starting_values: Budget allocation based on the original allocation proportion and the given total budget.
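The bound and starting-value descriptions above can be made concrete with a little arithmetic. Both helpers below are hypothetical and only mirror the wording of bounds_lower_pct, bounds_upper_pct and starting_values. Three details are assumptions, not confirmed library behaviour: clipping the lower bound at zero, scaling by n_time_periods, and taking the historical mean in media units per period.

```python
from typing import List, Tuple


def budget_bounds_sketch(
    mean_units_per_period: List[float],
    n_time_periods: int,
    bounds_lower_pct: float = 0.2,
    bounds_upper_pct: float = 0.2,
) -> Tuple[List[float], List[float]]:
    """Hypothetical: per-channel media bounds relative to the historical mean."""
    lower = [max(m * (1 - bounds_lower_pct), 0.0) * n_time_periods for m in mean_units_per_period]
    upper = [m * (1 + bounds_upper_pct) * n_time_periods for m in mean_units_per_period]
    return lower, upper


def starting_values_sketch(
    mean_units_per_period: List[float], prices: List[float], budget: float
) -> List[float]:
    """Hypothetical: split the budget in historical spend proportions, then convert to units."""
    historical_spend = [m * p for m, p in zip(mean_units_per_period, prices)]
    total = sum(historical_spend)
    return [budget * s / total / p for s, p in zip(historical_spend, prices)]


print(budget_bounds_sketch([10.0, 20.0], n_time_periods=4))
print(starting_values_sketch([10.0, 20.0], prices=[1.0, 2.0], budget=100.0))  # [20.0, 40.0]
```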
Plot
- plot_response_curves: Plots the response curves of each media channel based on the model.
- plot_cross_correlate: Plots the cross correlation coefficients between 2 vectors.
- plot_var_cost: Plots a chart of the coefficient of variation against cost.
- plot_model_fit: Plots the ground truth, predicted value and interval for the training data.
- plot_out_of_sample_model_fit: Plots the ground truth, predicted value and interval for the test data.
- plot_media_channel_posteriors: Plots the posterior distributions of estimated media channel effect.
- plot_prior_and_posterior: Plots prior and posterior distributions for parameters in media_mix_model.
- plot_bars_media_metrics: Plots a bar chart of estimated media effects with their percentile interval.
- plot_pre_post_budget_allocation_comparison: Plots bar charts to compare pre & post budget allocation.
- plot_media_baseline_contribution_area_plot: Plots an area chart to visualize weekly media & baseline contribution.
- create_media_baseline_contribution_df: Creates a dataframe for weekly media channels & baseline contribution.
- lightweight_mmm.plot.plot_response_curves(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, media_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, prices: Optional[jax.Array] = None, optimal_allocation_per_timeunit: Optional[jax.Array] = None, steps: int = 50, percentage_add: float = 0.2, apply_log_scale: bool = False, figure_size: Tuple[int, int] = (8, 10), n_columns: int = 3, marker_size: int = 8, legend_fontsize: int = 8, seed: Optional[int] = None) -> matplotlib.figure.Figure
Plots the response curves of each media channel based on the model.
It plots an individual subplot for each media channel. If optimal_allocation_per_timeunit is given, it is used to add markers for the historic average spend and the given optimal spend on each of the individual subplots.
It then plots a combined plot with all the response curves, which can be switched to log scale by setting apply_log_scale=True.
- Parameters
media_mix_model – Media mix model to use for plotting the response curves.
media_scaler – Scaler that was used to scale the media data before training.
target_scaler – Scaler used for scaling the target, used to unscale values and plot them in the original scale.
prices – Prices to translate the media units to spend. If all your data is already in spend numbers, you can leave this as None. If some channels are in media spend and others in media units, set the price of the spend channels to 1 and provide the unit price for the media-unit channels.
optimal_allocation_per_timeunit – Optimal allocation per time unit per media channel. This can be obtained by running the optimization provided by LightweightMMM.
steps – Number of steps to simulate.
percentage_add – Percentage by which to exceed the maximum historic spend when simulating the response curve.
apply_log_scale – Whether to apply the log scale to the predictions (Y axis). When some media channels have a much larger scale than others, apply_log_scale=True might be useful. Default is False.
figure_size – Size of the plot figure.
n_columns – Number of columns to display in the subplots grid. Modifying this parameter might require adjusting figure_size accordingly so the plot keeps a reasonable layout.
marker_size – Size of the marker for the optimization annotations. Only useful if optimal_allocation_per_timeunit is not None. Default is 8.
legend_fontsize – Legend font size for individual subplots.
seed – Seed to use for PRNGKey during sampling. For replicability run this function and any other function that gets predictions with the same seed.
- Returns
Plots of response curves.
- lightweight_mmm.plot.plot_cross_correlate(feature: jax.Array, target: jax.Array, maxlags: int = 10) -> Tuple[int, float]
Plots the cross correlation coefficients between 2 vectors.
In the chart, look for positive peaks: these show how the lags of the feature lead the target.
- Parameters
feature – Vector, the lags of which predict target.
target – Vector, what is predicted.
maxlags – Maximum number of lags.
- Returns
Lag index and corresponding correlation of the peak correlation.
- Raises
ValueError – If inputs don’t have same length.
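The plain-Python sketch below makes the lead/lag reading concrete. It correlates feature[t] with target[t + lag] for lag = 0..maxlags and reports the peak, mirroring the documented return of a lag index and its correlation. `peak_lead_lag` is hypothetical. The library's plotting function may normalise coefficients differently and may also consider negative lags.

```python
import math
from typing import Sequence, Tuple


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation of two equally long sequences."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def peak_lead_lag(
    feature: Sequence[float], target: Sequence[float], maxlags: int = 10
) -> Tuple[int, float]:
    """Correlate feature[t] with target[t + lag] for lag = 0..maxlags; return the peak."""
    if len(feature) != len(target):
        raise ValueError("feature and target must have the same length.")
    best_lag, best_corr = 0, -math.inf
    for lag in range(maxlags + 1):
        corr = pearson(feature[: len(feature) - lag], target[lag:])
        if corr > best_corr:
            best_lag, best_corr = lag, corr
    return best_lag, best_corr


feature = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0, 3.0, 5.0, 8.0]
target = [0.0, 0.0] + feature[:-2]  # target follows feature two periods later
print(peak_lead_lag(feature, target, maxlags=4))  # peak at lag 2, correlation ~1.0
```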
- lightweight_mmm.plot.plot_var_cost(media: jax.Array, costs: jax.Array, names: List[str]) -> matplotlib.figure.Figure
Plots a chart of the coefficient of variation against cost.
- Parameters
media – Media matrix.
costs – Cost vector.
names – List of variable names.
- Returns
Plot of coefficient of variation and cost.
- Raises
ValueError – If inputs don't have the same length.
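The quantity on one axis of this chart is each channel's coefficient of variation (standard deviation divided by mean), paired with its cost. The sketch below computes those points from a time × channel matrix. `var_cost_points` is hypothetical. Using the population standard deviation rather than the sample one is an assumption.

```python
import statistics
from typing import List, Sequence, Tuple


def var_cost_points(
    media: Sequence[Sequence[float]], costs: Sequence[float], names: Sequence[str]
) -> List[Tuple[str, float, float]]:
    """Hypothetical: (name, cost, coefficient of variation) per channel."""
    columns = list(zip(*media))  # rows are time periods, columns are channels
    if not (len(columns) == len(costs) == len(names)):
        raise ValueError("media columns, costs and names must have the same length.")
    # Coefficient of variation = population standard deviation / mean.
    cvs = [statistics.pstdev(col) / statistics.mean(col) for col in columns]
    return list(zip(names, costs, cvs))


print(var_cost_points([[1.0, 2.0], [3.0, 2.0]], costs=[10.0, 5.0], names=["tv", "search"]))
# [('tv', 10.0, 0.5), ('search', 5.0, 0.0)]
```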
- lightweight_mmm.plot.plot_model_fit(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, interval_mid_range: float = 0.9, digits: int = 3) -> matplotlib.figure.Figure
Plots the ground truth, predicted value and interval for the training data.
Model needs to be fit before calling this function to plot.
- Parameters
media_mix_model – Media mix model.
target_scaler – Scaler used for scaling the target, used to unscale values and plot them in the original scale.
interval_mid_range – Mid range interval to take for plotting. E.g. 0.9 will use 0.05 and 0.95 as the lower and upper quantiles. Must be a float between 0 and 1.
digits – Number of decimals to display on metrics in the plot.
- Returns
Plot of model fit.
- lightweight_mmm.plot.plot_out_of_sample_model_fit(out_of_sample_predictions: jax.Array, out_of_sample_target: jax.Array, interval_mid_range: float = 0.9, digits: int = 3) -> matplotlib.figure.Figure
Plots the ground truth, predicted value and interval for the test data.
- Parameters
out_of_sample_predictions – Predictions for the out-of-sample period, as derived from mmm.predict.
out_of_sample_target – Target for the out-of-sample period. Needs to be on the same scale as out_of_sample_predictions.
interval_mid_range – Mid range interval to take for plotting. E.g. 0.9 will use 0.05 and 0.95 as the lower and upper quantiles. Must be a float between 0 and 1.
digits – Number of decimals to display on metrics in the plot.
- Returns
Plot of model fit.
- lightweight_mmm.plot.plot_media_channel_posteriors(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, channel_names: Optional[Sequence[Any]] = None, quantiles: Sequence[float] = (0.05, 0.5, 0.95), fig_size: Optional[Tuple[int, int]] = None) -> matplotlib.figure.Figure
Plots the posterior distributions of estimated media channel effect.
Model needs to be fit before calling this function to plot.
- Parameters
media_mix_model – Media mix model.
channel_names – Names of media channels to be added to plot.
quantiles – Quantiles to draw on the distribution.
fig_size – Size of the figure to plot as used by matplotlib. If not specified it will be determined dynamically based on the number of media channels and geos the model was trained on.
- Returns
Plot of posterior distributions.
- lightweight_mmm.plot.plot_prior_and_posterior(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, fig_size: Optional[Tuple[int, int]] = None, selected_features: Optional[List[str]] = None, number_of_samples_for_prior: int = 5000, kde_bandwidth_adjust_for_posterior: float = 1, seed: Optional[int] = None) -> matplotlib.figure.Figure
Plots prior and posterior distributions for parameters in media_mix_model.
- Parameters
media_mix_model – Fitted media mix model.
fig_size – Size of the figure to plot as used by matplotlib. Default is a width of 8 and a height of 1.5 for each subplot.
selected_features – Optional list of feature names to select. If not specified (the default), all features are selected.
number_of_samples_for_prior – Controls the level of smoothing for the plotted version of the prior distribution. The default should be fine unless you want to decrease it to speed up runtime.
kde_bandwidth_adjust_for_posterior – Multiplicative factor to adjust the bandwidth of the kernel density estimator, to control the level of smoothing for the posterior distribution. Passed to seaborn.kdeplot as the bw_adjust parameter there.
seed – Seed to use for PRNGKey during sampling. For replicability run this function and any other function that utilises predictions with the same seed.
- Returns
Plot with Kernel density estimate smoothing showing prior and posterior distributions for every parameter in the given media_mix_model.
- Raises
NotFittedModelError – media_mix_model has not yet been fit.
ValueError – A feature has been created without a well-defined prior.
- lightweight_mmm.plot.plot_bars_media_metrics(metric: jax.Array, metric_name: str = 'metric', channel_names: Optional[Tuple[Any]] = None, interval_mid_range: float = 0.9) -> matplotlib.figure.Figure
Plots a bar chart of estimated media effects with their percentile interval.
The lower and upper percentiles need to be between 0 and 1.
- Parameters
metric – Estimated media metric as returned by LightweightMMM.get_posterior_metrics(). Can be either contribution percentage or ROI.
metric_name – Name of the media metric, e.g. contribution percentage or ROI.
channel_names – Names of media channels to be added to plot.
interval_mid_range – Mid range interval to take for plotting. E.g. 0.9 will use 0.05 and 0.95 as the lower and upper quantiles. Must be a float between 0 and 1.
- Returns
Barplot of estimated media effects with defined percentile-bars.
- lightweight_mmm.plot.plot_pre_post_budget_allocation_comparison(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, kpi_with_optim: jax.Array, kpi_without_optim: jax.Array, optimal_buget_allocation: jax.Array, previous_budget_allocation: jax.Array, channel_names: Optional[Sequence[Any]] = None, figure_size: Tuple[int, int] = (20, 10)) -> matplotlib.figure.Figure
Plots bar charts to compare pre & post budget allocation.
- Parameters
media_mix_model – Media mix model to use for the optimization.
kpi_with_optim – Negative predicted target variable with optimized budget allocation.
kpi_without_optim – Negative predicted target variable with the original budget allocation proportion based on the historical data.
optimal_buget_allocation – Optimized budget allocation.
previous_budget_allocation – Starting budget allocation based on original budget allocation proportion.
channel_names – Names of media channels to be added to plot.
figure_size – Size of the plot.
- Returns
Barplots of budget allocation across media channels pre & post optimization.
- lightweight_mmm.plot.plot_media_baseline_contribution_area_plot(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, channel_names: Optional[Sequence[Any]] = None, fig_size: Optional[Tuple[int, int]] = (20, 7), legend_outside: Optional[bool] = False) -> matplotlib.figure.Figure
Plots an area chart to visualize weekly media & baseline contribution.
- Parameters
media_mix_model – Media mix model.
target_scaler – Scaler used for scaling the target.
channel_names – Names of media channels.
fig_size – Size of the figure to plot as used by matplotlib.
legend_outside – Put the legend outside of the chart, center-right.
- Returns
Stacked area chart of weekly baseline & media contribution.
- lightweight_mmm.plot.create_media_baseline_contribution_df(media_mix_model: lightweight_mmm.lightweight_mmm.LightweightMMM, target_scaler: Optional[lightweight_mmm.preprocessing.CustomScaler] = None, channel_names: Optional[Sequence[str]] = None) -> pandas.core.frame.DataFrame
Creates a dataframe for weekly media channels & baseline contribution.
The output dataframe will be used to create a stacked area plot to visualize the contribution of each media channel & the baseline.
- Parameters
media_mix_model – Media mix model.
target_scaler – Scaler used for scaling the target.
channel_names – Names of media channels.
- Returns
contribution_df: DataFrame of weekly channels & baseline contribution percentage & volume.
- Return type
pandas.core.frame.DataFrame
Models
- transform_adstock: Transforms the input data with the adstock function and exponent.
- transform_hill_adstock: Transforms the input data with the adstock and hill functions.
- transform_carryover: Transforms the input data with the carryover function and exponent.
- media_mix_model: Media mix model.
- lightweight_mmm.models.transform_adstock(media_data: jax.Array, custom_priors: MutableMapping[str, Union[numpyro.distributions.distribution.Distribution, Dict[str, float], Sequence[float], float]], normalise: bool = True) -> jax.Array
Transforms the input data with the adstock function and exponent.
- Parameters
media_data – Media data to be transformed. It is expected to have 2 dims for national models and 3 for geo models.
custom_priors – The custom priors we want the model to take instead of the default ones. The possible names of parameters for adstock and exponent are “lag_weight” and “exponent”.
normalise – Whether to normalise the output values.
- Returns
The transformed media data.
- lightweight_mmm.models.transform_hill_adstock(media_data: jax.Array, custom_priors: MutableMapping[str, Union[numpyro.distributions.distribution.Distribution, Dict[str, float], Sequence[float], float]], normalise: bool = True) -> jax.Array
Transforms the input data with the adstock and hill functions.
- Parameters
media_data – Media data to be transformed. It is expected to have 2 dims for national models and 3 for geo models.
custom_priors – The custom priors we want the model to take instead of the default ones. The possible names of parameters for hill_adstock are “lag_weight”, “half_max_effective_concentration” and “slope”.
normalise – Whether to normalise the output values.
- Returns
The transformed media data.
- lightweight_mmm.models.transform_carryover(media_data: jax.Array, custom_priors: MutableMapping[str, Union[numpyro.distributions.distribution.Distribution, Dict[str, float], Sequence[float], float]], number_lags: int = 13) -> jax.Array
Transforms the input data with the carryover function and exponent.
- Parameters
media_data – Media data to be transformed. It is expected to have 2 dims for national models and 3 for geo models.
custom_priors – The custom priors we want the model to take instead of the default ones. The possible names of parameters for carryover and exponent are “ad_effect_retention_rate_plate”, “peak_effect_delay_plate” and “exponent”.
number_lags – Number of lags for the carryover function.
- Returns
The transformed media data.
- lightweight_mmm.models.media_mix_model(media_data: jax.Array, target_data: jax.Array, media_prior: jax.Array, degrees_seasonality: int, frequency: int, transform_function: lightweight_mmm.models.TransformFunction, custom_priors: MutableMapping[str, Union[numpyro.distributions.distribution.Distribution, Dict[str, float], Sequence[float], float]], transform_kwargs: Optional[MutableMapping[str, Any]] = None, weekday_seasonality: bool = False, extra_features: Optional[jax.Array] = None) -> None
Media mix model.
- Parameters
media_data – Media data to be used in the model.
target_data – Target data for the model.
media_prior – Cost prior for each of the media channels.
degrees_seasonality – Number of degrees of seasonality to use.
frequency – Frequency of the time span which was used to aggregate the data. Eg. if weekly data then frequency is 52.
transform_function – Function to use to transform the media data in the model. Currently the following are supported: ‘transform_adstock’, ‘transform_carryover’ and ‘transform_hill_adstock’.
custom_priors – The custom priors we want the model to take instead of the default ones. See our custom_priors documentation for details about the API and possible options.
transform_kwargs – Any extra keyword arguments to pass to the transform function. For example, the adstock function can take a boolean to normalise output or not.
weekday_seasonality – In case of daily data, also estimate seven weekday parameters.
extra_features – Extra features data to include in the model.
Media Transforms
- calculate_seasonality: Calculates cyclic variation seasonality using Fourier terms.
- adstock: Calculates the adstock value of a given array.
- hill: Calculates the hill function for a given array of values.
- carryover: Calculates media carryover.
- apply_exponent_safe: Applies an exponent to given data in a gradient safe way.
- lightweight_mmm.media_transforms.calculate_seasonality(number_periods: int, degrees: int, gamma_seasonality: Union[int, float, jax.Array], frequency: int = 52) -> jax.Array
Calculates cyclic variation seasonality using Fourier terms.
- For detailed info check:
- Parameters
number_periods – Number of seasonal periods in the data. E.g. for 1 year of weekly data it will be 52, and for 3 years of the same kind, 156.
degrees – Number of degrees to use. Must be greater than or equal to 1.
gamma_seasonality – Factor to multiply to each degree calculation. Shape must be aligned with the number of degrees.
frequency – Frequency of the seasonality being computed. Default is 52 for weekly data (52 weeks in a year).
- Returns
An array with the seasonality values.
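The Fourier seasonality sums a sine and a cosine term per degree, with a period of frequency time steps:

seasonality[t] = sum over d = 1..degrees of ( gamma_sin[d] * sin(2πdt / frequency) + gamma_cos[d] * cos(2πdt / frequency) ).

The plain-Python sketch below implements that formula. Two points are assumptions: that gamma_seasonality holds one (sine, cosine) coefficient pair per degree, and the ordering within each pair. `fourier_seasonality` is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple


def fourier_seasonality(
    number_periods: int,
    degrees: int,
    gamma_seasonality: Sequence[Tuple[float, float]],
    frequency: int = 52,
) -> List[float]:
    """Sum over degrees of gamma_sin*sin(2*pi*d*t/f) + gamma_cos*cos(2*pi*d*t/f)."""
    if degrees < 1:
        raise ValueError("degrees must be >= 1.")
    if len(gamma_seasonality) != degrees:
        raise ValueError("Need one (sin, cos) coefficient pair per degree.")
    values = []
    for t in range(number_periods):
        total = 0.0
        for d in range(1, degrees + 1):
            angle = 2 * math.pi * d * t / frequency
            g_sin, g_cos = gamma_seasonality[d - 1]
            total += g_sin * math.sin(angle) + g_cos * math.cos(angle)
        values.append(total)
    return values


weekly = fourier_seasonality(number_periods=104, degrees=1, gamma_seasonality=[(1.0, 0.0)])
print(round(weekly[13], 6))  # 1.0: a quarter of the 52-week cycle
```

With frequency=52, each term repeats every 52 periods, so two years of weekly data show the same pattern twice.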
- lightweight_mmm.media_transforms.adstock(data: jax.Array, lag_weight: float = 0.9, normalise: bool = True) -> jax.Array
Calculates the adstock value of a given array.
To learn more about advertising lag: https://en.wikipedia.org/wiki/Advertising_adstock
- Parameters
data – Input array.
lag_weight – Lag weight of the adstock function. Default is 0.9.
normalise – Whether to normalise the output value. This normalization will divide the output values by (1 / (1 - lag_weight)).
- Returns
The adstock output of the input array.
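Geometric adstock carries a fraction lag_weight of the previous output into each period: out[t] = data[t] + lag_weight * out[t - 1], starting from out[0] = data[0]. The normalisation described above divides by 1 / (1 - lag_weight), which is the same as multiplying by (1 - lag_weight). `adstock_sketch` is a hypothetical single-series version. The library function works on 2-D or 3-D arrays along axis 0.

```python
from typing import List, Sequence


def adstock_sketch(
    data: Sequence[float], lag_weight: float = 0.9, normalise: bool = True
) -> List[float]:
    """Geometric adstock on one series: out[t] = data[t] + lag_weight * out[t - 1]."""
    out: List[float] = []
    previous = 0.0  # so that out[0] == data[0]
    for x in data:
        previous = x + lag_weight * previous
        out.append(previous)
    if normalise:
        # Divide by 1 / (1 - lag_weight), as stated in the parameter description.
        out = [v / (1.0 / (1.0 - lag_weight)) for v in out]
    return out


print(adstock_sketch([1.0, 0.0, 0.0], lag_weight=0.5))  # [0.5, 0.25, 0.125]
```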
- lightweight_mmm.media_transforms.hill(data: jax.Array, half_max_effective_concentration: jax.Array, slope: jax.Array) -> jax.Array
Calculates the hill function for a given array of values.
- Refer to the following link for detailed information on this equation:
- Parameters
data – Input data.
half_max_effective_concentration – ec50 value for the hill function.
slope – Slope of the hill function.
- Returns
The hill values for the respective input data.
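The sketch below uses the standard Hill saturation curve, 1 / (1 + (x / ec50) ** (-slope)). This is algebraically equal to x**slope / (x**slope + ec50**slope). It equals 0.5 at x = half_max_effective_concentration and approaches 1 as x grows. We assume the library uses this form. Mapping x = 0 to 0, which avoids a division by zero from the negative exponent, is also an assumption about how that case is treated. `hill_sketch` is a hypothetical scalar version.

```python
def hill_sketch(x: float, half_max_effective_concentration: float, slope: float) -> float:
    """Hill saturation: 1 / (1 + (x / ec50) ** (-slope)), with x == 0 mapped to 0."""
    if x == 0:
        return 0.0
    return 1.0 / (1.0 + (x / half_max_effective_concentration) ** (-slope))


print(hill_sketch(2.0, half_max_effective_concentration=2.0, slope=1.0))  # 0.5
```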
- lightweight_mmm.media_transforms.carryover(data: jax.Array, ad_effect_retention_rate: jax.Array, peak_effect_delay: jax.Array, number_lags: int = 13) -> jax.Array
Calculates media carryover.
More details about this function can be found in: https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46001.pdf
- Parameters
data – Input data. It is expected that data has either 2 dimensions for national models or 3 for geo models.
ad_effect_retention_rate – Retention rate of the advertisement effect. Default is 0.5.
peak_effect_delay – Delay of the peak effect in the carryover function. Default is 1.
number_lags – Number of lags to include in the carryover calculation. Default is 13.
- Returns
The carryover values for the given data with the given parameters.
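Carryover can be read as a weighted moving average over the last number_lags periods. Lag l gets weight ad_effect_retention_rate ** ((l - peak_effect_delay) ** 2), so the effect peaks peak_effect_delay periods after exposure. The sketch below follows that delayed-adstock weighting from the paper linked above, on a single series. Normalising by the sum of weights, and the function name `carryover_sketch`, are assumptions rather than the library's confirmed implementation.

```python
from typing import List, Sequence


def carryover_sketch(
    data: Sequence[float],
    ad_effect_retention_rate: float,
    peak_effect_delay: float,
    number_lags: int = 13,
) -> List[float]:
    """Weighted average of the last number_lags values with a delayed peak."""
    # Lag 0 is the current period; the weight is largest at lag == peak_effect_delay.
    weights = [
        ad_effect_retention_rate ** ((lag - peak_effect_delay) ** 2)
        for lag in range(number_lags)
    ]
    total_weight = sum(weights)
    out = []
    for t in range(len(data)):
        acc = 0.0
        for lag, w in enumerate(weights):
            if t - lag >= 0:  # ignore lags that reach before the start of the series
                acc += w * data[t - lag]
        out.append(acc / total_weight)
    return out


print(carryover_sketch([4.0, 0.0, 0.0, 0.0], 0.5, 1, number_lags=3))  # [1.0, 2.0, 1.0, 0.0]
```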
- lightweight_mmm.media_transforms.apply_exponent_safe(data: jax.Array, exponent: jax.Array) -> jax.Array
Applies an exponent to given data in a gradient safe way.
More information on the double jnp.where can be found at: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
- Parameters
data – Input data to use.
exponent – Exponent required for the operations.
- Returns
The result of the exponent operation with the inputs provided.
Utils
- save_model: Saves the given model in the given path.
- load_model: Loads a model given a string path.
- simulate_dummy_data: Simulates dummy data needed for media mix modelling.
- get_halfnormal_mean_from_scale: Returns the mean of the half-normal distribution.
- get_halfnormal_scale_from_mean: Returns the scale of the half-normal distribution.
- get_beta_params_from_mu_sigma: Deterministically estimates (a, b) from (mu, sigma) of a beta variable.
- Quantifies the distance between two distributions.
- Overwrites outliers in x with interpolated values.
- Converts a pandas dataframe to the right data format for the media mix model.
- lightweight_mmm.utils.save_model(media_mix_model: Any, file_path: str) None [source]
Saves the given model in the given path.
- Parameters
media_mix_model – Model to save on disk.
file_path – File path where the model should be placed.
- lightweight_mmm.utils.load_model(file_path: str) Any [source]
Loads a model given a string path.
- Parameters
file_path – Path of the file containing the model.
- Returns
The LightweightMMM object that was stored in the given path.
- lightweight_mmm.utils.simulate_dummy_data(data_size: int, n_media_channels: int, n_extra_features: int, geos: int = 1, seed: int = 5) Tuple[jax.Array, jax.Array, jax.Array, jax.Array] [source]
Simulates dummy data needed for media mix modelling.
This function is deliberately simple, with few parameters. It does not generate a fully realistic dataset and is only meant for demos and tutorials. It uses carryover for lagging but has no saturation and no trend.
The simulated data includes the media data, extra features, a target/KPI and costs.
- Parameters
data_size – Number of rows to generate.
n_media_channels – Number of media channels to generate.
n_extra_features – Number of extra features to generate.
geos – Number of geos for geo level data (default = 1 for national).
seed – Random seed.
- Returns
The simulated media, extra features, target and costs.
- lightweight_mmm.utils.get_halfnormal_mean_from_scale(scale: float) float [source]
Returns the mean of the half-normal distribution.
- lightweight_mmm.utils.get_halfnormal_scale_from_mean(mean: float) float [source]
Returns the scale of the half-normal distribution.
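For a half-normal distribution with scale σ, the standard identity is mean = σ · sqrt(2 / π), so the two conversions are each other's inverse. The sketch below assumes the library uses this identity; the function names are hypothetical.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)


def halfnormal_mean_from_scale(scale: float) -> float:
    """Mean of HalfNormal(scale), via the identity mean = scale * sqrt(2 / pi)."""
    return scale * SQRT_2_OVER_PI


def halfnormal_scale_from_mean(mean: float) -> float:
    """Inverse of halfnormal_mean_from_scale: scale = mean / sqrt(2 / pi)."""
    return mean / SQRT_2_OVER_PI
```

This is handy when setting a half-normal prior from a target mean: pick the mean, convert it to a scale, and pass that scale to the distribution.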
- lightweight_mmm.utils.get_beta_params_from_mu_sigma(mu: float, sigma: float, bracket: Tuple[float, float] = (0.5, 100.0)) Tuple[float, float] [source]
Deterministically estimates (a, b) from (mu, sigma) of a beta variable.
https://en.wikipedia.org/wiki/Beta_distribution
- Parameters
mu – The sample mean of the beta distributed variable.
sigma – The sample standard deviation of the beta distributed variable.
bracket – Search bracket for b.
- Returns
Tuple of the (a, b) parameters.
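The `bracket` parameter suggests the library finds b by a numerical search. The sketch below swaps in the closed-form method-of-moments solution instead. It uses a + b = mu(1 - mu) / sigma² - 1, which requires 0 < sigma² < mu(1 - mu). The function name is hypothetical.

```python
from typing import Tuple


def beta_params_from_mu_sigma(mu: float, sigma: float) -> Tuple[float, float]:
    """Method-of-moments (a, b) for a beta variable with mean mu and std sigma."""
    variance = sigma ** 2
    if not 0.0 < mu < 1.0 or not 0.0 < variance < mu * (1.0 - mu):
        raise ValueError("Need 0 < mu < 1 and 0 < sigma**2 < mu * (1 - mu).")
    concentration = mu * (1.0 - mu) / variance - 1.0  # equals a + b
    return mu * concentration, (1.0 - mu) * concentration
```

For example, Beta(2, 5) has mean 2/7 and variance 10/392, and those two moments map back to (2, 5).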
- lightweight_mmm.utils.distance_pior_posterior(p: jax.Array, q: jax.Array, method: str = 'KS', discrete: bool = True) float [source]
Quantifies the distance between two distributions.
KL divergence is not used because it is undefined (infinite) when one distribution assigns zero probability to a value that the other does not.
https://en.wikipedia.org/wiki/Hellinger_distance
- Parameters
p – Samples for distribution 1.
q – Samples for distribution 2.
method – One of four methods: KS, Hellinger, JS or min.
discrete – Whether input data is discrete or continuous.
- Returns
The distance metric (between 0 and 1).
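Two of the listed methods can be sketched directly from samples. The sketch covers the Kolmogorov-Smirnov statistic (largest gap between the two empirical CDFs) and the Hellinger distance between empirical frequencies of discrete samples. Both lie in [0, 1]. The helper names are hypothetical, JS and min are not shown, and the library's binning or normalisation may differ.

```python
import numpy as np


def ks_distance(p, q) -> float:
    """Max absolute difference between the empirical CDFs of samples p and q."""
    p = np.sort(np.asarray(p, dtype=float))
    q = np.sort(np.asarray(q, dtype=float))
    grid = np.concatenate([p, q])  # the CDF gap is maximised at a sample point
    cdf_p = np.searchsorted(p, grid, side="right") / p.size
    cdf_q = np.searchsorted(q, grid, side="right") / q.size
    return float(np.max(np.abs(cdf_p - cdf_q)))


def hellinger_discrete(p, q) -> float:
    """Hellinger distance between the empirical frequencies of discrete samples."""
    p, q = np.asarray(p), np.asarray(q)
    support = np.union1d(p, q)
    p_freq = np.array([np.mean(p == value) for value in support])
    q_freq = np.array([np.mean(q == value) for value in support])
    return float(np.sqrt(0.5 * np.sum((np.sqrt(p_freq) - np.sqrt(q_freq)) ** 2)))
```

Unlike KL divergence, both stay finite when one sample contains values the other never takes. Fully disjoint samples simply score 1.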
- lightweight_mmm.utils.interpolate_outliers(x: jax.Array, outlier_idx: jax.Array) jax.Array [source]
Overwrites outliers in x with interpolated values.
- Parameters
x – The original univariate variable with outliers.
outlier_idx – Indices of the outliers in x.
- Returns
A cleaned x with outliers overwritten.
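A minimal sketch, assuming linear interpolation between the nearest non-outlier neighbours; the library may use a different interpolation scheme, and the name `interpolate_outliers_sketch` is hypothetical.

```python
import numpy as np


def interpolate_outliers_sketch(x, outlier_idx):
    """Replaces x[outlier_idx] by linear interpolation over the remaining points."""
    x = np.asarray(x, dtype=float).copy()
    outlier_idx = np.asarray(outlier_idx, dtype=int)
    keep = np.ones(x.size, dtype=bool)
    keep[outlier_idx] = False
    positions = np.arange(x.size)
    # np.interp clamps to the nearest kept value for outliers at either end.
    x[outlier_idx] = np.interp(outlier_idx, positions[keep], x[keep])
    return x


print(interpolate_outliers_sketch([1.0, 2.0, 100.0, 4.0, 5.0], [2]))
```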
- lightweight_mmm.utils.dataframe_to_jax(dataframe: pandas.core.frame.DataFrame, media_features: List[str], extra_features: List[str], date_feature: str, target: str, geo_feature: Optional[str] = None, cost_features: Optional[List[str]] = None) Tuple[jax.Array, jax.Array, jax.Array, jax.Array] [source]
Converts a pandas dataframe to the data format expected by the media mix model.
This function converts a dataframe, the format most familiar to data scientists, into JAX arrays, so that users less familiar with arrays can use the LightweightMMM library more easily.
- Parameters
dataframe – Dataframe with geo, KPI, media and non-media features.
media_features – List of media feature names.
extra_features – List of non media feature names.
date_feature – Date feature name.
target – Target variable name.
geo_feature – Geo feature name. Optional if the data is at national level.
cost_features – List of media cost feature names. Optional if the actual media costs are used as the media features in the model.
- Returns
Media, extra features, target and costs arrays.
- Raises
ValueError – If the geos have unequal numbers of weeks or there is only one value in the geo feature.
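The core reshaping can be sketched with pandas and NumPy. The sketch assumes geo data is laid out as (time, channels, geos) for media and (time, geos) for the target, and it mirrors the documented ValueError checks. It is a hypothetical simplification (`dataframe_to_arrays_sketch`): it returns NumPy rather than JAX arrays and omits `extra_features` and `cost_features`.

```python
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd


def dataframe_to_arrays_sketch(
    dataframe: pd.DataFrame,
    media_features: List[str],
    date_feature: str,
    target: str,
    geo_feature: Optional[str] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Returns (media, target) as (time, channels) and (time,) for national data,
    or (time, channels, geos) and (time, geos) for geo data."""
    if geo_feature is None:
        ordered = dataframe.sort_values(date_feature)
        return (ordered[media_features].to_numpy(dtype=float),
                ordered[target].to_numpy(dtype=float))
    # Mirror the documented checks: equal weeks per geo and more than one geo.
    weeks_per_geo = dataframe.groupby(geo_feature)[date_feature].nunique()
    if weeks_per_geo.nunique() != 1 or len(weeks_per_geo) == 1:
        raise ValueError("Geos need equal numbers of weeks and more than one geo.")
    ordered = dataframe.sort_values([geo_feature, date_feature])
    per_geo = [ordered[ordered[geo_feature] == geo] for geo in ordered[geo_feature].unique()]
    # Stack each geo's (time, channels) block along a new trailing geo axis.
    media = np.stack([g[media_features].to_numpy(dtype=float) for g in per_geo], axis=-1)
    target_array = np.stack([g[target].to_numpy(dtype=float) for g in per_geo], axis=-1)
    return media, target_array
```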