Source code for lightweight_mmm.media_transforms

# Copyright 2023 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Media transformations for accounting for lagging or media effects."""

import functools
from typing import Union

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=[0, 1])
def calculate_seasonality(
    number_periods: int,
    degrees: int,
    gamma_seasonality: Union[int, float, jnp.ndarray],
    frequency: int = 52,
) -> jnp.ndarray:
    """Builds a cyclic seasonality signal from weighted Fourier terms.

    Background reading: https://en.wikipedia.org/wiki/Seasonality#Modeling

    Args:
      number_periods: How many time steps to generate. For weekly data, one
        year is 52 steps and three years are 156.
      degrees: Number of Fourier degrees (harmonics) to include. Must be at
        least 1.
      gamma_seasonality: Multiplier applied to every sine/cosine term. It is
        broadcast against an array of shape (number_periods, degrees, 2), where
        the last axis holds the sine term followed by the cosine term.
      frequency: Number of time steps in one full seasonal cycle. Defaults to
        52 (weeks in a year).

    Returns:
      A one dimensional array of length number_periods with the seasonality
      value of each time step.
    """
    # Column vector of time steps so it broadcasts against the harmonics row.
    period_index = jnp.arange(number_periods)[:, jnp.newaxis]
    harmonics = jnp.arange(1, degrees + 1)
    angles = period_index * 2 * jnp.pi * harmonics / frequency
    # Last axis: index 0 is the sine term, index 1 the cosine term.
    fourier_terms = jnp.stack([jnp.sin(angles), jnp.cos(angles)], axis=-1)
    weighted_terms = fourier_terms * gamma_seasonality
    # Collapse the sine/cosine axis first, then the degrees axis.
    return weighted_terms.sum(axis=2).sum(axis=1)
@jax.jit
def adstock(data: jnp.ndarray,
            lag_weight: float = .9,
            normalise: bool = True) -> jnp.ndarray:
    """Computes the adstock (geometrically decayed running sum) of an array.

    Background reading: https://en.wikipedia.org/wiki/Advertising_adstock

    Args:
      data: Input array. The recursion runs along its first axis.
      lag_weight: Fraction of the previous adstock value carried over into the
        next step. Default is 0.9.
      normalise: Whether to normalise the output, which divides every value by
        (1 / (1 - lag_weight)).

    Returns:
      The adstock values, with the same leading length as the input.
    """

    def _decay_step(carried, current):
        # Scan step: the new value is both the next carry and the emitted output.
        updated = carried * lag_weight + current
        return updated, updated

    # The first row seeds the recursion and is emitted unchanged.
    first_row = data[0, ...]
    _, later_rows = jax.lax.scan(_decay_step, first_row, data[1:, ...])
    accumulated = jnp.concatenate([jnp.array([first_row]), later_rows])
    # `normalise` is a traced value under jit, so branch with lax.cond rather
    # than a Python `if`.
    return jax.lax.cond(
        normalise,
        lambda values: values / (1. / (1 - lag_weight)),
        lambda values: values,
        accumulated)
@jax.jit
def hill(data: jnp.ndarray, half_max_effective_concentration: jnp.ndarray,
         slope: jnp.ndarray) -> jnp.ndarray:
    """Calculates the hill function for a given array of values.

    Computes 1 / (1 + (data / ec50) ** -slope) element-wise, returning 0 where
    the ratio data / ec50 is exactly zero.

    Refer to the following link for detailed information on this equation:
      https://en.wikipedia.org/wiki/Hill_equation_(biochemistry)

    Args:
      data: Input data.
      half_max_effective_concentration: ec50 value for the hill function.
      slope: Slope of the hill function.

    Returns:
      The hill values for the respective input data.
    """
    concentration_ratio = data / half_max_effective_concentration
    inverse_term = apply_exponent_safe(
        data=concentration_ratio, exponent=-slope)
    # Test the ratio itself for zero, not the transformed value: for large
    # ratios `ratio ** -slope` underflows to 0, and keying the mask on that
    # would report 0 where the curve has actually saturated at 1.
    return jnp.where(concentration_ratio == 0, 0, 1.0 / (1 + inverse_term))
@functools.partial(jax.vmap, in_axes=(1, 1, None), out_axes=1)
def _carryover_convolve(data: jnp.ndarray,
                        weights: jnp.ndarray,
                        number_lags: int) -> jnp.ndarray:
    """Convolves one slice of data with its carryover weights.

    The decorator maps this function over axis 1 of both `data` and `weights`,
    so the body sees one slice of each at a time; `number_lags` is shared.

    Args:
      data: Input data.
      weights: Window weights for the carryover.
      number_lags: Number of lags the window has.

    Returns:
      The data convolved with the zero-padded window, divided by the sum of the
      weights.
    """
    # Prepending number_lags - 1 zeros shifts the "same"-mode window so that,
    # when `weights` has number_lags entries, each output step only combines
    # the current and earlier inputs.
    leading_zeros = jnp.zeros(number_lags - 1)
    padded_kernel = jnp.concatenate([leading_zeros, weights])
    convolved = jax.scipy.signal.convolve(data, padded_kernel, mode="same")
    normaliser = weights.sum()
    return convolved / normaliser
@functools.partial(jax.jit, static_argnames=("number_lags",))
def carryover(data: jnp.ndarray,
              ad_effect_retention_rate: jnp.ndarray,
              peak_effect_delay: jnp.ndarray,
              number_lags: int = 13) -> jnp.ndarray:
    """Applies the media carryover transformation to the data.

    Reference for the underlying model:
    https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46001.pdf

    Args:
      data: Input data. Expected to have 2 dimensions for national models and
        3 dimensions for geo models.
      ad_effect_retention_rate: Retention rate of the advertisement effect.
      peak_effect_delay: Delay of the peak effect in the carryover function.
      number_lags: Number of lags to include in the carryover calculation.
        Default is 13. It is a static argument of the jitted function.

    Returns:
      The carryover values for the given data with the given parameters.
    """
    # Column of lag indices 0..number_lags-1, broadcast against the parameters.
    lag_index = jnp.arange(number_lags, dtype=jnp.float32)[:, jnp.newaxis]
    squared_offset = (lag_index - peak_effect_delay) ** 2
    weights = ad_effect_retention_rate ** squared_offset

    if data.ndim == 3:
        # _carryover_convolve is already vmapped over axis 1 by its decorator,
        # so one extra vmap over the third data axis handles geo level data.
        # The weights stay two dimensional and are shared across that axis.
        geo_convolve = jax.vmap(
            _carryover_convolve, in_axes=(2, None, None), out_axes=2)
        return geo_convolve(data, weights, number_lags)
    return _carryover_convolve(data, weights, number_lags)
@jax.jit
def apply_exponent_safe(
    data: jnp.ndarray,
    exponent: jnp.ndarray,
) -> jnp.ndarray:
    """Raises data to an exponent while keeping gradients finite at zero.

    Background on the double jnp.where pattern:
    https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf

    Args:
      data: Input data to use.
      exponent: Exponent required for the operations.

    Returns:
      data ** exponent element-wise, with 0 wherever data is exactly 0.
    """
    is_zero = data == 0
    # Inner where: swap zeros for 1 so the power never sees a zero base.
    powered = jnp.where(is_zero, 1, data) ** exponent
    # Outer where: put 0 back at the positions that were originally zero.
    return jnp.where(is_zero, 0, powered)