""" Marketing Mix Model Diffusion (MMM-Diffusion) ============================================== A generative diffusion model for Marketing Mix Modeling, adapted from NVIDIA's Kimodo dual-denoiser architecture (GMD/MDM family). Architecture mapping: Kimodo → MMM-Diffusion ------- --------------- Text prompts → Media spend, non-marketing vars, total sales Motion/position constraints → Sign constraints (β_media ≥ 0) + prior constraints Root denoiser (trajectory) → Campaign/Geo-level denoiser (aggregate patterns) Body denoiser (joint rotations)→ Channel-level denoiser (per-channel coefficients) Skeleton positions/rotations → Time-varying coefficients for sales decomposition References: - GMD (arxiv:2305.12577) — Two-stage trajectory + body diffusion - MDM (arxiv:2209.14916) — Transformer denoiser, x₀-prediction, geometric losses - PhysDiff (arxiv:2212.02500) — Projection during denoising for constraints - PDM (arxiv:2402.03559) — Projected diffusion for hard constraint satisfaction - NNN (arxiv:2504.06212) — Neural network MMM architecture from Google """ import math import json import os import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from scipy.signal import lfilter import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt # ============================================================================= # 1. SYNTHETIC MMM DATA GENERATOR # ============================================================================= class MMMDataGenerator: """ Generate synthetic Marketing Mix Model data with known ground truth. Model: Sales_t = β_base_t + Σ_m β_m_t * Hill(Adstock(spend_m,t; α_m); ec50_m, k_m) + Σ_c β_c * ctrl_c,t + ε_t Channels: 5 media channels + 3 non-marketing (control) variables Based on NNN paper (arxiv:2504.06212) simulation recipe and Meridian framework. 
""" MEDIA_CHANNELS = ['TV', 'Digital', 'Social', 'Print', 'Radio'] CONTROL_VARS = ['Seasonality', 'Trend', 'Competitor_Price'] def __init__(self, n_weeks=104, n_geos=1, seed=None): self.n_weeks = n_weeks self.n_geos = n_geos self.n_media = 5 self.n_ctrl = 3 self.rng = np.random.RandomState(seed) def _generate_media_spend(self): """Generate realistic media spend patterns with seasonality and campaigns.""" spend = np.zeros((self.n_weeks, self.n_media)) t = np.arange(self.n_weeks) # Base spend levels (different per channel) base_levels = self.rng.uniform(50, 500, size=self.n_media) for m in range(self.n_media): # Base pattern with weekly variation base = base_levels[m] * (1 + 0.2 * np.sin(2 * np.pi * t / 52)) # Random campaign bursts (3-8 per year) n_campaigns = self.rng.randint(3, 9) for _ in range(n_campaigns): start = self.rng.randint(0, self.n_weeks - 4) duration = self.rng.randint(1, 6) intensity = self.rng.uniform(1.5, 4.0) end = min(start + duration, self.n_weeks) base[start:end] *= intensity # Add noise spend[:, m] = np.maximum(base + self.rng.normal(0, base_levels[m] * 0.1, self.n_weeks), 0) return spend def _adstock(self, x, alpha): """Geometric adstock transformation. α ∈ [0,1] is retention rate.""" result = np.zeros_like(x) result[0] = x[0] for t in range(1, len(x)): result[t] = x[t] + alpha * result[t-1] return result def _hill(self, x, ec50, slope): """Hill saturation function. 
ec50>0, slope>0.""" x_safe = np.maximum(x, 0) return x_safe**slope / (x_safe**slope + ec50**slope + 1e-10) def _generate_controls(self): """Generate control (non-marketing) variables.""" t = np.arange(self.n_weeks) controls = np.zeros((self.n_weeks, self.n_ctrl)) # Seasonality: annual cycle with harmonics controls[:, 0] = (np.sin(2 * np.pi * t / 52) + 0.5 * np.sin(4 * np.pi * t / 52) + 0.3 * np.cos(2 * np.pi * t / 52)) # Trend: slow linear + mild quadratic trend = t / self.n_weeks controls[:, 1] = trend + 0.5 * trend**2 # Competitor price: random walk with mean reversion price = np.zeros(self.n_weeks) price[0] = 1.0 for i in range(1, self.n_weeks): price[i] = 0.95 * price[i-1] + 0.05 * 1.0 + self.rng.normal(0, 0.05) controls[:, 2] = price return controls def _sample_true_params(self): """Sample ground truth MMM parameters from realistic priors.""" params = {} # Media coefficients (MUST BE POSITIVE — this is the key constraint) # β_m ~ HalfNormal(0.5) — always positive params['beta_media'] = np.abs(self.rng.normal(0, 0.5, self.n_media)) + 0.05 # Adstock retention rates α ∈ [0, 1] # α ~ Beta(2, 2) — most mass in [0.2, 0.8] params['adstock_alpha'] = self.rng.beta(2, 2, self.n_media) params['adstock_alpha'] = np.clip(params['adstock_alpha'], 0.1, 0.95) # Hill EC50 (half-saturation) — must be positive # Sample relative to median spend params['hill_ec50'] = np.abs(self.rng.lognormal(0, 0.5, self.n_media)) + 0.1 # Hill slope k ∈ [0.5, 3] — controls steepness params['hill_slope'] = self.rng.uniform(0.5, 3.0, self.n_media) # Base sales (intercept) — positive params['beta_base'] = self.rng.uniform(500, 2000) # Control coefficients (can be positive or negative) params['beta_ctrl'] = self.rng.normal(0, 50, self.n_ctrl) # Noise level params['noise_std'] = self.rng.uniform(20, 100) return params def _make_time_varying(self, base_coeff, n_weeks, volatility=0.1): """ Make a coefficient time-varying via a random walk with mean reversion. 
β_t = β * exp(z_t) where z_t follows an OU process. """ z = np.zeros(n_weeks) for t in range(1, n_weeks): z[t] = 0.9 * z[t-1] + self.rng.normal(0, volatility) return base_coeff * np.exp(z) def generate_single(self): """ Generate a single MMM dataset with known ground truth. Returns: dict with keys: - media_spend: (T, 5) raw media spend - controls: (T, 3) control variables - total_sales: (T,) total sales - true_coefficients: (T, 8) time-varying coefficients [5 media + 3 ctrl] - true_contributions: (T, 8) sales contribution per variable - true_params: dict of ground truth parameters """ spend = self._generate_media_spend() controls = self._generate_controls() params = self._sample_true_params() # Apply adstock and Hill transformation to media transformed_media = np.zeros_like(spend) for m in range(self.n_media): adstocked = self._adstock(spend[:, m], params['adstock_alpha'][m]) # Normalize before Hill adstocked_norm = adstocked / (np.percentile(adstocked, 90) + 1e-10) transformed_media[:, m] = self._hill( adstocked_norm, params['hill_ec50'][m], params['hill_slope'][m] ) # Generate time-varying coefficients tv_coeffs = np.zeros((self.n_weeks, self.n_media + self.n_ctrl)) # Media coefficients — time-varying but ALWAYS POSITIVE for m in range(self.n_media): tv_coeffs[:, m] = self._make_time_varying( params['beta_media'][m], self.n_weeks, volatility=0.05 ) tv_coeffs[:, m] = np.maximum(tv_coeffs[:, m], 0.01) # enforce positivity # Control coefficients — mild time variation, can be negative for c in range(self.n_ctrl): tv_coeffs[:, self.n_media + c] = self._make_time_varying( params['beta_ctrl'][c], self.n_weeks, volatility=0.03 ) # Compute contributions contributions = np.zeros((self.n_weeks, self.n_media + self.n_ctrl)) for m in range(self.n_media): contributions[:, m] = tv_coeffs[:, m] * transformed_media[:, m] for c in range(self.n_ctrl): contributions[:, self.n_media + c] = tv_coeffs[:, self.n_media + c] * controls[:, c] # Total sales base = params['beta_base'] noise 
= self.rng.normal(0, params['noise_std'], self.n_weeks) total_sales = base + contributions.sum(axis=1) + noise total_sales = np.maximum(total_sales, 0) # sales can't be negative return { 'media_spend': spend, 'controls': controls, 'total_sales': total_sales, 'true_coefficients': tv_coeffs, 'true_contributions': contributions, 'base_sales': np.full(self.n_weeks, base), 'true_params': params } def generate_dataset(self, n_samples): """Generate n_samples MMM instances.""" samples = [] for i in range(n_samples): self.rng = np.random.RandomState(self.rng.randint(0, 2**31)) samples.append(self.generate_single()) return samples # ============================================================================= # 2. DATASET CLASS FOR TRAINING # ============================================================================= class MMMDiffusionDataset(Dataset): """ Wraps generated MMM data for diffusion training. Each sample provides: - conditioning: (T, n_media + n_ctrl + 1) — [media_spend, controls, total_sales] - target: (T, n_media + n_ctrl) — time-varying coefficients to denoise - media_mask: boolean mask for media channels (positivity constraint) """ def __init__(self, samples, normalize=True): self.samples = samples self.normalize = normalize self.n_media = 5 self.n_ctrl = 3 self.n_channels = self.n_media + self.n_ctrl # 8 total # Compute normalization statistics if normalize: all_cond = np.stack([ np.concatenate([s['media_spend'], s['controls'], s['total_sales'][:, None]], axis=1) for s in samples ]) all_coeff = np.stack([s['true_coefficients'] for s in samples]) self.cond_mean = all_cond.mean(axis=(0, 1)) self.cond_std = all_cond.std(axis=(0, 1)) + 1e-8 self.coeff_mean = all_coeff.mean(axis=(0, 1)) self.coeff_std = all_coeff.std(axis=(0, 1)) + 1e-8 # For media coefficients, use log-space normalization (ensures positivity) # For ctrl coefficients, use standard z-score media_coeffs = all_coeff[:, :, :self.n_media] self.media_log_mean = np.log(media_coeffs + 
1e-8).mean(axis=(0, 1)) self.media_log_std = np.log(media_coeffs + 1e-8).std(axis=(0, 1)) + 1e-8 def __len__(self): return len(self.samples) def __getitem__(self, idx): s = self.samples[idx] # Conditioning: [media_spend | controls | total_sales] cond = np.concatenate([ s['media_spend'], s['controls'], s['total_sales'][:, None] ], axis=1).astype(np.float32) # (T, 9) # Target: time-varying coefficients coeffs = s['true_coefficients'].astype(np.float32) # (T, 8) if self.normalize: cond = (cond - self.cond_mean) / self.cond_std # Log-space for media (positive) coefficients log_media = np.log(coeffs[:, :self.n_media] + 1e-8) log_media = (log_media - self.media_log_mean) / self.media_log_std # Z-score for control coefficients ctrl = (coeffs[:, self.n_media:] - self.coeff_mean[self.n_media:]) / self.coeff_std[self.n_media:] coeffs = np.concatenate([log_media, ctrl], axis=1) return { 'conditioning': torch.tensor(cond, dtype=torch.float32), 'coefficients': torch.tensor(coeffs, dtype=torch.float32), } def decode_coefficients(self, coeffs_normalized): """ Inverse-transform normalized coefficients back to original scale. Applies exp() to media channels to enforce positivity. 
Args: coeffs_normalized: (batch, T, 8) normalized coefficients Returns: coeffs: (batch, T, 8) original-scale coefficients (media ≥ 0) """ if not self.normalize: return coeffs_normalized coeffs = coeffs_normalized.clone() # Media channels: inverse log-space → guaranteed positive via exp() media_log_mean = torch.tensor(self.media_log_mean, device=coeffs.device, dtype=coeffs.dtype) media_log_std = torch.tensor(self.media_log_std, device=coeffs.device, dtype=coeffs.dtype) coeffs[:, :, :self.n_media] = torch.exp( coeffs[:, :, :self.n_media] * media_log_std + media_log_mean ) # Control channels: inverse z-score coeff_mean = torch.tensor(self.coeff_mean[self.n_media:], device=coeffs.device, dtype=coeffs.dtype) coeff_std = torch.tensor(self.coeff_std[self.n_media:], device=coeffs.device, dtype=coeffs.dtype) coeffs[:, :, self.n_media:] = coeffs[:, :, self.n_media:] * coeff_std + coeff_mean return coeffs # ============================================================================= # 3. DIFFUSION NOISE SCHEDULE # ============================================================================= def cosine_beta_schedule(T, s=0.008): """Cosine noise schedule from 'Improved DDPM' (Nichol & Dhariwal, 2021).""" t = torch.arange(T + 1, dtype=torch.float64) f = torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2 alphas_cumprod = f / f[0] betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1] return torch.clamp(betas, 0, 0.999).float() class DiffusionSchedule: """DDPM diffusion schedule with cosine noise.""" def __init__(self, T=1000): self.T = T self.betas = cosine_beta_schedule(T) self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas) # Posterior variance for q(x_{t-1} | x_t, x_0) self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0) 
self.posterior_variance = ( self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_log_variance_clipped = torch.log( torch.clamp(self.posterior_variance, min=1e-20) ) self.posterior_mean_coef1 = ( self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod) ) def to(self, device): """Move all tensors to device.""" for attr in ['betas', 'alphas', 'alphas_cumprod', 'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod', 'sqrt_recip_alphas', 'alphas_cumprod_prev', 'posterior_variance', 'posterior_log_variance_clipped', 'posterior_mean_coef1', 'posterior_mean_coef2']: setattr(self, attr, getattr(self, attr).to(device)) return self def q_sample(self, x_0, t, noise=None): """Forward diffusion: q(x_t | x_0).""" if noise is None: noise = torch.randn_like(x_0) sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1) sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) return sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise def posterior_mean(self, x_0_pred, x_t, t): """Compute posterior mean q(x_{t-1} | x_t, x_0_pred).""" coef1 = self.posterior_mean_coef1[t].view(-1, 1, 1) coef2 = self.posterior_mean_coef2[t].view(-1, 1, 1) return coef1 * x_0_pred + coef2 * x_t # ============================================================================= # 4. 
# DENOISER NETWORKS (Kimodo-adapted dual denoiser)
# =============================================================================

class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embeddings for diffusion timestep t.

    ``forward`` maps a 1-D tensor of timesteps (B,) to a (B, dim) tensor whose
    first half holds sines and second half cosines. For odd ``dim`` the last
    column is zero padding so the output width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # max(..., 1) avoids a division by zero when dim is 2 or 3
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        angles = t[:, None].float() * freqs[None, :]
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if emb.shape[-1] < self.dim:
            # Odd dim: sin/cos halves give dim - 1 columns; pad to the declared width
            emb = F.pad(emb, (0, self.dim - emb.shape[-1]))
        return emb


class TemporalTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention over time, then a GELU MLP."""

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention over temporal dimension (residual around normed input)
        normed = self.norm1(x)
        h = x + self.attn(normed, normed, normed)[0]
        # Feed-forward with its own residual
        return h + self.ff(self.norm2(h))


class CampaignDenoiser(nn.Module):
    """
    Stage 1: Campaign/Geo-level Denoiser

    Analogous to GMD's trajectory DPM / Kimodo's root denoiser.
    Denoises aggregate-level patterns conditioned on non-marketing vars + total sales.
    Predicts x_0 directly (not noise ε), enabling constraint projection at each step.

    Input:  x_t (B, T, n_agg) — noisy aggregate coefficients
    Cond:   (B, T, cond_dim) — non-marketing vars + total sales
    Output: x_0_hat (B, T, n_agg) — predicted clean aggregate coefficients

    Args:
        n_agg_channels: width of the aggregate representation.
        cond_dim: width of the conditioning features.
        d_model, nhead, n_layers: transformer sizes.
        T_diff: accepted for signature compatibility; not used by this module.
        max_len: longest sequence supported by the learned positional table.
    """

    def __init__(self, n_agg_channels=3, cond_dim=4, d_model=256, nhead=4,
                 n_layers=4, T_diff=1000, max_len=256):
        super().__init__()
        self.d_model = d_model

        # Timestep embedding
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

        # Conditioning projection: non-marketing vars + total sales
        self.cond_proj = nn.Linear(cond_dim, d_model)

        # Input projection: noisy aggregate coefficients
        self.input_proj = nn.Linear(n_agg_channels, d_model)

        # Learnable temporal positional encoding
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)

        # Transformer encoder for temporal attention
        self.blocks = nn.ModuleList([
            TemporalTransformerBlock(d_model, nhead) for _ in range(n_layers)
        ])

        # Output projection
        self.output_proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, n_agg_channels),
        )

    def forward(self, x_t, t, cond):
        """Predict clean aggregates from (x_t, timestep t, conditioning).

        Raises:
            ValueError: if the sequence is longer than the positional table.
        """
        _, T_seq, _ = x_t.shape
        max_len = self.pos_embed.shape[1]
        if T_seq > max_len:
            raise ValueError(
                f"Sequence length {T_seq} exceeds positional table size {max_len}"
            )

        # Embed diffusion timestep
        t_emb = self.time_embed(t)  # (B, d_model)

        # Combine: input + conditioning + time + position (all broadcast to (B, T, d_model))
        h = (self.input_proj(x_t) + self.cond_proj(cond)
             + t_emb.unsqueeze(1) + self.pos_embed[:, :T_seq, :])

        # Temporal transformer
        for block in self.blocks:
            h = block(h)

        return self.output_proj(h)  # (B, T, n_agg)


class ChannelDenoiser(nn.Module):
    """
    Stage 2: Channel-level Denoiser

    Analogous to GMD's full-body DPM / Kimodo's body denoiser.
    Denoises per-channel time-varying coefficients, conditioned on:
        - Stage 1 output (aggregate patterns)
        - Media spend data
        - Total sales
    Predicts x_0 directly for constraint projection.

    Input:  x_t (B, T, n_channels) — noisy channel coefficients
    Cond:   campaign_ctx (B, T, n_agg) — from Stage 1
            media_spend (B, T, n_media) — media spend
            total_sales (B, T, 1) — total sales
    Output: x_0_hat (B, T, n_channels) — predicted clean coefficients

    Args:
        T_diff: accepted for signature compatibility; not used by this module.
        max_len: longest sequence supported by the learned positional table.
    """

    def __init__(self, n_channels=8, n_media=5, n_agg=3, d_model=384, nhead=8,
                 n_layers=6, T_diff=1000, max_len=256):
        super().__init__()
        self.d_model = d_model
        self.n_media = n_media
        self.n_channels = n_channels

        # Timestep embedding
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

        # Input projection
        self.input_proj = nn.Linear(n_channels, d_model)

        # Multi-source conditioning
        self.campaign_proj = nn.Linear(n_agg, d_model)   # Stage 1 output
        self.spend_proj = nn.Linear(n_media, d_model)    # Media spend
        self.sales_proj = nn.Linear(1, d_model)          # Total sales

        # Cross-attention: channel features attend to conditioning
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_norm = nn.LayerNorm(d_model)

        # Temporal position encoding
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TemporalTransformerBlock(d_model, nhead) for _ in range(n_layers)
        ])

        # Output head: one MLP shared by all time steps, emitting all channels at once
        self.output_proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, n_channels),
        )

    def forward(self, x_t, t, campaign_ctx, media_spend, total_sales):
        """Predict clean channel coefficients.

        Raises:
            ValueError: if the sequence is longer than the positional table.
        """
        _, T_seq, _ = x_t.shape
        max_len = self.pos_embed.shape[1]
        if T_seq > max_len:
            raise ValueError(
                f"Sequence length {T_seq} exceeds positional table size {max_len}"
            )

        # Embed timestep
        t_emb = self.time_embed(t)  # (B, d_model)

        # Conditioning context: the three projected sources are summed
        # (additive fusion), giving one (B, T, d_model) key/value sequence.
        cond_ctx = (self.campaign_proj(campaign_ctx)
                    + self.spend_proj(media_spend)
                    + self.sales_proj(total_sales))

        # Query stream: noisy input + time + position embeddings
        h = self.input_proj(x_t) + t_emb.unsqueeze(1) + self.pos_embed[:, :T_seq, :]

        # Cross-attention: channel features attend to conditioning
        query = self.cross_norm(h)
        h = h + self.cross_attn(query, cond_ctx, cond_ctx)[0]

        # Self-attention transformer blocks
        for block in self.blocks:
            h = block(h)

        return self.output_proj(h)  # (B, T, n_channels)


# =============================================================================
# 5. MMM DIFFUSION MODEL (Full Pipeline)
# =============================================================================

class MMMDiffusionModel(nn.Module):
    """
    Full MMM-Diffusion model with dual denoiser architecture.

    Stage 1 (Campaign Denoiser): Denoises aggregate patterns from non-mktg + sales
    Stage 2 (Channel Denoiser): Denoises per-channel coefficients conditioned on Stage 1

    Constraint handling implemented here:
        1. Media coefficients are expected in normalized log-space.
        2. PhysDiff-style projection: the media part of the x_0 prediction is
           clamped every K sampling steps.
        3. Soft loss penalties in training: L_smooth and L_sign.

    Note: ``self.schedule`` is a plain object, not a registered buffer, so
    ``model.to(device)`` does not move it; call ``model.schedule.to(device)``.
    """

    def __init__(self, n_media=5, n_ctrl=3, d_model_campaign=256, d_model_channel=384,
                 n_layers_campaign=4, n_layers_channel=6, T_diff=1000):
        super().__init__()
        self.n_media = n_media
        self.n_ctrl = n_ctrl
        self.n_channels = n_media + n_ctrl
        self.T_diff = T_diff

        # Aggregate channels for Stage 1
        # (we use 3: total_media_effect, seasonality_trend, base_level)
        self.n_agg = 3

        # Stage 1: Campaign/Geo Denoiser
        self.campaign_denoiser = CampaignDenoiser(
            n_agg_channels=self.n_agg,
            cond_dim=n_ctrl + 1,  # non-marketing vars + total sales
            d_model=d_model_campaign,
            nhead=4,
            n_layers=n_layers_campaign,
            T_diff=T_diff,
        )

        # Stage 2: Channel Denoiser
        self.channel_denoiser = ChannelDenoiser(
            n_channels=self.n_channels,
            n_media=n_media,
            n_agg=self.n_agg,
            d_model=d_model_channel,
            nhead=8,
            n_layers=n_layers_channel,
            T_diff=T_diff,
        )

        # Projection from full coefficients to aggregate representation.
        # NOTE: forward_train applies coeff_to_agg under no_grad, so within this
        # class it never receives gradients; agg_to_coeff_init is not used here.
        self.coeff_to_agg = nn.Linear(self.n_channels, self.n_agg)
        self.agg_to_coeff_init = nn.Linear(self.n_agg, self.n_channels)

        # Diffusion schedule
        self.schedule = DiffusionSchedule(T_diff)

    def compute_aggregate(self, coefficients):
        """Project full coefficients to aggregate representation for Stage 1."""
        return self.coeff_to_agg(coefficients)

    def forward_train(self, batch):
        """
        Training forward pass.

        Uses x_0-prediction (predicts clean data, not noise).

        Args:
            batch: dict with 'conditioning' (B, T, 9) and 'coefficients' (B, T, 8).

        Returns:
            dict with 'loss' (tensor, for backward) and float entries
            'loss_campaign', 'loss_channel', 'loss_smooth', 'loss_sign'.
        """
        cond = batch['conditioning']    # (B, T, 9)
        coeffs = batch['coefficients']  # (B, T, 8) — normalized, media in log-space
        B = coeffs.shape[0]
        device = coeffs.device

        # Separate conditioning components
        media_spend = cond[:, :, :self.n_media]                          # (B, T, 5)
        controls = cond[:, :, self.n_media:self.n_media + self.n_ctrl]   # (B, T, 3)
        total_sales = cond[:, :, -1:]                                    # (B, T, 1)
        stage1_cond = torch.cat([controls, total_sales], dim=-1)         # (B, T, 4)

        # Compute aggregate targets for Stage 1 (no gradient into the projection)
        with torch.no_grad():
            agg_target = self.coeff_to_agg(coeffs)  # (B, T, 3)

        # ---- Stage 1: Campaign Denoiser ----
        t1 = torch.randint(0, self.T_diff, (B,), device=device)
        noise1 = torch.randn_like(agg_target)
        agg_noisy = self.schedule.q_sample(agg_target, t1, noise1)
        agg_pred = self.campaign_denoiser(agg_noisy, t1, stage1_cond)

        # Stage 1 loss: x_0 prediction
        loss_campaign = F.mse_loss(agg_pred, agg_target)

        # ---- Stage 2: Channel Denoiser ----
        t2 = torch.randint(0, self.T_diff, (B,), device=device)
        noise2 = torch.randn_like(coeffs)
        coeffs_noisy = self.schedule.q_sample(coeffs, t2, noise2)

        # Use ground truth aggregate as conditioning (teacher forcing during training)
        campaign_ctx = agg_target.detach()
        coeffs_pred = self.channel_denoiser(
            coeffs_noisy, t2, campaign_ctx, media_spend, total_sales
        )

        # Stage 2 loss: x_0 prediction
        loss_channel = F.mse_loss(coeffs_pred, coeffs)

        # ---- Auxiliary losses (geometric losses from MDM) ----
        # L_smooth: match week-to-week differences (analog of velocity loss)
        delta_pred = coeffs_pred[:, 1:, :] - coeffs_pred[:, :-1, :]
        delta_true = coeffs[:, 1:, :] - coeffs[:, :-1, :]
        loss_smooth = F.mse_loss(delta_pred, delta_true)

        # L_sign: mild penalty when a predicted media value (normalized log-space)
        # drops below -5, i.e. towards near-zero coefficients.
        media_pred_log = coeffs_pred[:, :, :self.n_media]
        loss_sign = F.relu(-media_pred_log - 5.0).mean()

        # Total loss with weights. No sales-reconstruction term is included.
        loss = (
            1.0 * loss_campaign +
            1.0 * loss_channel +
            0.1 * loss_smooth +
            0.01 * loss_sign
        )

        return {
            'loss': loss,
            'loss_campaign': loss_campaign.item(),
            'loss_channel': loss_channel.item(),
            'loss_smooth': loss_smooth.item(),
            'loss_sign': loss_sign.item(),
        }

    def _sampling_timesteps(self, n_steps):
        """Return the descending list of schedule indices visited while sampling.

        A falsy ``n_steps`` (None or 0) or a value >= T_diff selects every
        timestep. Smaller values pick an evenly spaced subset that always
        starts at T_diff - 1, so sampling begins at the highest noise level.
        """
        total = self.T_diff
        if not n_steps:
            return list(range(total - 1, -1, -1))
        n = int(n_steps)
        if n < 1:
            raise ValueError(f"n_steps must be positive, got {n_steps}")
        if n >= total:
            return list(range(total - 1, -1, -1))
        if n == 1:
            return [total - 1]
        picked = np.linspace(0, total - 1, n).round().astype(int)
        return sorted(set(picked.tolist()), reverse=True)

    def _reverse_step(self, x_0_pred, x_t, t, t_prev):
        """Sample x_{t_prev} from q(x_{t_prev} | x_t, x_0_pred).

        ``t_prev < 0`` marks the final step and returns ``x_0_pred`` itself.
        Consecutive steps use the precomputed schedule tensors; strided steps
        rebuild the posterior from the cumulative alphas of both timesteps.
        """
        if t_prev < 0:
            return x_0_pred

        if t_prev == t - 1:
            t_batch = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
            mean = self.schedule.posterior_mean(x_0_pred, x_t, t_batch)
            var = self.schedule.posterior_variance[t]
        else:
            ac_t = self.schedule.alphas_cumprod[t]
            ac_prev = self.schedule.alphas_cumprod[t_prev]
            # Effective beta of the merged step t_prev -> t
            beta = 1.0 - ac_t / ac_prev
            coef_x0 = beta * torch.sqrt(ac_prev) / (1.0 - ac_t)
            coef_xt = (1.0 - ac_prev) * torch.sqrt(1.0 - beta) / (1.0 - ac_t)
            mean = coef_x0 * x_0_pred + coef_xt * x_t
            var = beta * (1.0 - ac_prev) / (1.0 - ac_t)

        noise = torch.randn_like(x_t)
        return mean + torch.sqrt(var) * noise

    @torch.no_grad()
    def sample(self, conditioning, n_steps=None, constraint_every_k=10, guidance_scale=1.0):
        """
        Generate time-varying coefficients via dual-denoiser reverse diffusion.

        Uses PhysDiff-style projection every K steps for constraint enforcement.
        This method does not switch the module to eval mode.

        Args:
            conditioning: (B, T, 9) — [media_spend, controls, total_sales]
            n_steps: number of denoising steps (None = full T). Fewer steps
                walk an evenly spaced subset of the full schedule.
            constraint_every_k: clamp the media part of the x_0 prediction when
                the number of remaining steps is a multiple of K (the final
                step is always clamped). Values below 1 are treated as 1.
            guidance_scale: accepted for API compatibility; not used.

        Returns:
            coefficients: (B, T, 8) — predicted time-varying coefficients (normalized)
        """
        B, T_seq, _ = conditioning.shape
        device = conditioning.device

        # Separate conditioning
        media_spend = conditioning[:, :, :self.n_media]
        controls = conditioning[:, :, self.n_media:self.n_media + self.n_ctrl]
        total_sales = conditioning[:, :, -1:]
        stage1_cond = torch.cat([controls, total_sales], dim=-1)

        timesteps = self._sampling_timesteps(n_steps)
        n_total = len(timesteps)
        every_k = max(int(constraint_every_k), 1)

        # ==== Stage 1: Denoise aggregate patterns ====
        z_t = torch.randn(B, T_seq, self.n_agg, device=device)
        for i, t in enumerate(timesteps):
            t_prev = timesteps[i + 1] if i + 1 < n_total else -1
            t_batch = torch.full((B,), t, device=device, dtype=torch.long)

            # Predict x_0, then sample the previous state (or finish)
            z_0_pred = self.campaign_denoiser(z_t, t_batch, stage1_cond)
            z_t = self._reverse_step(z_0_pred, z_t, t, t_prev)

        campaign_ctx = z_t  # (B, T, n_agg) — denoised aggregate

        # ==== Stage 2: Denoise channel coefficients ====
        x_t = torch.randn(B, T_seq, self.n_channels, device=device)
        for i, t in enumerate(timesteps):
            t_prev = timesteps[i + 1] if i + 1 < n_total else -1
            t_batch = torch.full((B,), t, device=device, dtype=torch.long)

            # Predict x_0 (channel coefficients)
            x_0_pred = self.channel_denoiser(
                x_t, t_batch, campaign_ctx, media_spend, total_sales
            )

            # PhysDiff-style constraint projection every K steps.
            # With the full schedule, `remaining` equals the timestep t.
            remaining = n_total - 1 - i
            if remaining % every_k == 0:
                # Clamp media values (normalized log-space) to prevent extremes
                x_0_pred[:, :, :self.n_media] = torch.clamp(
                    x_0_pred[:, :, :self.n_media], min=-8.0, max=8.0
                )

            x_t = self._reverse_step(x_0_pred, x_t, t, t_prev)

        return x_t  # (B, T, n_channels) — normalized coefficients


# =============================================================================
# 6. TRAINING LOOP
# =============================================================================

def train_mmm_diffusion(
    model, dataset, n_epochs=50, batch_size=16, lr=1e-4,
    device='cpu', log_every=50, save_path='mmm_diffusion_model.pt'
):
    """Train the MMM diffusion model.

    Runs AdamW with cosine LR annealing and gradient clipping at norm 1.0,
    prints progress, and saves a checkpoint to ``save_path``.

    Returns:
        history: dict mapping each loss name to a list of per-epoch means
        (plain Python floats).

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("Cannot train on an empty dataset")

    # Drop the ragged last batch only when at least one full batch exists;
    # otherwise a small dataset would yield zero batches per epoch.
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        drop_last=len(dataset) >= batch_size,
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    model = model.to(device)
    # The schedule is not an nn.Module attribute tree, so move it explicitly
    model.schedule = model.schedule.to(device)

    history = {'loss': [], 'loss_campaign': [], 'loss_channel': [],
               'loss_smooth': [], 'loss_sign': []}

    print(f"\nTraining MMM-Diffusion Model")
    print(f"  Device: {device}")
    print(f"  Samples: {len(dataset)}, Batch size: {batch_size}")
    print(f"  Epochs: {n_epochs}, LR: {lr}")
    print(f"  Model params: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  Diffusion steps: {model.T_diff}")
    print("-" * 60)

    step = 0
    for epoch in range(n_epochs):
        model.train()
        epoch_losses = {k: [] for k in history}

        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}

            losses = model.forward_train(batch)

            optimizer.zero_grad()
            losses['loss'].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            for k in history:
                val = losses[k].item() if isinstance(losses[k], torch.Tensor) else losses[k]
                epoch_losses[k].append(val)

            step += 1
            if step % log_every == 0:
                # Running mean over the most recent steps of this epoch
                avg = {k: np.mean(v[-log_every:]) for k, v in epoch_losses.items() if v}
                print(f"  Step {step:5d} | loss={avg['loss']:.4f} "
                      f"camp={avg['loss_campaign']:.4f} chan={avg['loss_channel']:.4f} "
                      f"smooth={avg['loss_smooth']:.4f} sign={avg['loss_sign']:.6f}")

        scheduler.step()

        # Epoch summary
        avg = {k: np.mean(v) for k, v in epoch_losses.items() if v}
        for k, v in avg.items():
            # Plain floats keep the saved history free of NumPy scalar objects
            history[k].append(float(v))

        print(f"Epoch {epoch+1:3d}/{n_epochs} | loss={avg['loss']:.4f} "
              f"camp={avg['loss_campaign']:.4f} chan={avg['loss_channel']:.4f} "
              f"smooth={avg['loss_smooth']:.4f} sign={avg['loss_sign']:.6f} "
              f"lr={scheduler.get_last_lr()[0]:.6f}")

    # Save
    torch.save({
        'model_state_dict': model.state_dict(),
        'history': history,
        'config': {
            'n_media': model.n_media,
            'n_ctrl': model.n_ctrl,
            'T_diff': model.T_diff,
        }
    }, save_path)
    print(f"\nModel saved to {save_path}")

    return history


# =============================================================================
# 7. SALES DECOMPOSITION FROM PREDICTED COEFFICIENTS
# =============================================================================

def decompose_sales(coefficients, media_spend, controls, adstock_alphas=None):
    """
    Given predicted time-varying coefficients, decompose sales into contributions.

    Media contributions use spend divided by its 90th percentile (after the
    optional adstock); no Hill saturation is applied. 'Predicted_Sales' is the
    sum of media and control contributions only — no base/intercept term.

    Args:
        coefficients: (T, 8) — decoded coefficients [5 media + 3 ctrl]
        media_spend: (T, 5) — raw media spend
        controls: (T, 3) — control variables
        adstock_alphas: optional (5,) — adstock retention rates

    Returns:
        contributions: dict keyed by channel / control name, plus
        'Total_Media', 'Total_Controls' and 'Predicted_Sales'.
    """
    media_names = MMMDataGenerator.MEDIA_CHANNELS
    ctrl_names = MMMDataGenerator.CONTROL_VARS
    n_media = len(media_names)
    T = coefficients.shape[0]

    contributions = {}
    total_media = np.zeros(T)

    for m, name in enumerate(media_names):
        spend = media_spend[:, m]

        # Apply geometric adstock only if alphas are provided
        if adstock_alphas is not None:
            adstocked = np.zeros(T)
            adstocked[0] = spend[0]
            for t in range(1, T):
                adstocked[t] = spend[t] + adstock_alphas[m] * adstocked[t-1]
            spend = adstocked

        # Contribution = coefficient × percentile-normalized spend
        # (simplified: the full generative model would apply Hill too)
        contrib = coefficients[:, m] * (spend / (np.percentile(spend, 90) + 1e-10))
        contributions[name] = contrib
        total_media += contrib

    total_ctrl = np.zeros(T)
    for c, name in enumerate(ctrl_names):
        contrib = coefficients[:, n_media + c] * controls[:, c]
        contributions[name] = contrib
        total_ctrl += contrib

    contributions['Total_Media'] = total_media
    contributions['Total_Controls'] = total_ctrl
    contributions['Predicted_Sales'] = total_media + total_ctrl

    return contributions


# =============================================================================
# 8.
# VISUALIZATION
# =============================================================================

def plot_training_history(history, save_path='training_history.png'):
    """Plot training loss curves, one panel per entry of ``history``.

    The grid has two columns and as many rows as needed, so every key is
    drawn; unused panels are hidden. A panel uses a log y-axis only when
    its series is non-empty and strictly positive.

    Args:
        history: dict mapping a loss name to a list of per-epoch values.
        save_path: output image path.
    """
    keys = list(history)
    n_cols = 2
    n_rows = max(1, -(-len(keys) // n_cols))  # ceil division without math.ceil
    # squeeze=False guarantees a 2-D axes array regardless of grid size
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 5 * n_rows), squeeze=False)
    panels = axes.flatten()

    for ax, key in zip(panels, keys):
        values = history[key]
        ax.plot(values, linewidth=1.5)
        ax.set_title(f'{key}', fontsize=12)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.grid(True, alpha=0.3)
        positive = len(values) > 0 and min(values) > 0
        ax.set_yscale('log' if positive else 'linear')

    # Hide panels left over when the number of keys is odd
    for ax in panels[len(keys):]:
        ax.set_visible(False)

    plt.suptitle('MMM-Diffusion Training History', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Training history plot saved to {save_path}")


def plot_coefficient_comparison(true_coeffs, pred_coeffs, channel_names,
                                save_path='coeff_comparison.png', n_media=5):
    """Compare true vs predicted time-varying coefficients.

    Args:
        true_coeffs: (T, C) ground-truth coefficients.
        pred_coeffs: (T, C) predicted coefficients.
        channel_names: names for the C columns; panels are drawn for
            min(C, len(channel_names)) columns.
        save_path: output image path.
        n_media: number of leading columns treated as media channels
            (drawn with a zero reference line and a 'β (≥0)' label).
    """
    n_channels = true_coeffs.shape[1]
    # squeeze=False keeps `axes` indexable even when there is a single channel
    fig, axes = plt.subplots(n_channels, 1, figsize=(14, 2.5 * n_channels), squeeze=False)
    axes = axes[:, 0]

    for i, (ax, name) in enumerate(zip(axes, channel_names)):
        ax.plot(true_coeffs[:, i], 'b-', label='Ground Truth', linewidth=1.5)
        ax.plot(pred_coeffs[:, i], 'r--', label='Predicted', linewidth=1.5, alpha=0.8)
        ax.set_title(f'{name} — Time-Varying Coefficient', fontsize=11)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        if i < n_media:  # Media channels
            ax.axhline(y=0, color='gray', linestyle=':', alpha=0.5)
            ax.set_ylabel('β (≥0)')
        else:
            ax.set_ylabel('β')

    axes[-1].set_xlabel('Week')
    plt.suptitle('MMM-Diffusion: Coefficient Prediction Quality', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Coefficient comparison plot saved to {save_path}")


def plot_sales_decomposition(contributions, total_sales, save_path='sales_decomposition.png'):
    """Plot stacked area chart of sales decomposition.

    Top panel: stacked media contributions (negative values clipped to zero
    for display) with total sales overlaid. Bottom panel: total sales against
    ``contributions['Predicted_Sales']``.

    Args:
        contributions: dict containing one entry per media channel name in
            ``MMMDataGenerator.MEDIA_CHANNELS`` plus 'Predicted_Sales'.
        total_sales: (T,) observed sales.
        save_path: output image path.
    """
    _, (ax_stack, ax_total) = plt.subplots(2, 1, figsize=(14, 10))
    weeks = np.arange(len(total_sales))

    # Top: Stacked area of contributions
    media_names = MMMDataGenerator.MEDIA_CHANNELS
    colors = plt.cm.Set2(np.linspace(0, 1, len(media_names)))
    bottom = np.zeros(len(total_sales))
    for name, color in zip(media_names, colors):
        vals = np.maximum(contributions[name], 0)
        top = bottom + vals
        ax_stack.fill_between(weeks, bottom, top, alpha=0.7, label=name, color=color)
        bottom = top

    ax_stack.plot(weeks, total_sales, 'k-', linewidth=2, label='Total Sales', alpha=0.8)
    ax_stack.set_title('Sales Decomposition: Media Channel Contributions', fontsize=12)
    ax_stack.legend(loc='upper left', fontsize=9)
    ax_stack.set_xlabel('Week')
    ax_stack.set_ylabel('Sales Contribution')
    ax_stack.grid(True, alpha=0.3)

    # Bottom: Total predicted vs actual
    ax_total.plot(weeks, total_sales, 'b-', linewidth=2, label='Actual Sales')
    ax_total.plot(weeks, contributions['Predicted_Sales'], 'r--', linewidth=2,
                  label='Predicted (Media + Controls)', alpha=0.8)
    ax_total.set_title('Total Sales: Actual vs Predicted Decomposition', fontsize=12)
    ax_total.legend(fontsize=10)
    ax_total.set_xlabel('Week')
    ax_total.set_ylabel('Sales')
    ax_total.grid(True, alpha=0.3)

    plt.suptitle('MMM-Diffusion: Sales Decomposition', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Sales decomposition plot saved to {save_path}")


# =============================================================================
# 9.
# MAIN — FULL POC PIPELINE
# =============================================================================

def main(output_dir='/app'):
    """Run the full proof-of-concept pipeline: data, model, training, validation.

    Steps: generate synthetic scenarios, wrap them in datasets, build the
    diffusion model, train it, then sample coefficients for one held-out
    scenario and produce diagnostics and plots.

    Args:
        output_dir: Directory that receives the model checkpoint and the three
            PNG plots. It is created if it does not exist.

    Returns:
        Tuple ``(model, history, train_dataset, val_dataset)``.
    """
    import time

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    banner = "=" * 60
    print(banner)
    print("MMM-DIFFUSION: Marketing Mix Model via Diffusion")
    print("Adapted from Kimodo/GMD dual-denoiser architecture")
    print(f"Device: {device}")
    print(banner)

    # Resolve every artifact path once and make sure the directory exists
    # before any expensive work, so a missing folder cannot cost a training run.
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'mmm_diffusion_model.pt')
    history_plot_path = os.path.join(output_dir, 'training_history.png')
    coeff_plot_path = os.path.join(output_dir, 'coeff_comparison.png')
    decomposition_plot_path = os.path.join(output_dir, 'sales_decomposition.png')

    # ---- Step 1: Generate synthetic data ----
    print("\n[1/5] Generating synthetic MMM data...")
    t0 = time.time()
    gen = MMMDataGenerator(n_weeks=104, seed=42)
    n_media = gen.n_media
    n_ctrl = gen.n_ctrl

    n_train = 500  # 500 training scenarios
    n_val = 50     # 50 validation scenarios

    train_samples = gen.generate_dataset(n_train)
    val_samples = gen.generate_dataset(n_val)

    print(f" Generated {n_train} train + {n_val} val scenarios")
    print(f" Each: {gen.n_weeks} weeks, {n_media} media channels, {n_ctrl} control vars")
    print(f" Time: {time.time()-t0:.1f}s")

    # Quick data audit
    sample = train_samples[0]
    media_true = sample['true_coefficients'][:, :n_media]
    print(f"\n Data audit (sample 0):")
    print(f" Media spend shape: {sample['media_spend'].shape}")
    print(f" Media spend range: [{sample['media_spend'].min():.1f}, {sample['media_spend'].max():.1f}]")
    print(f" Sales range: [{sample['total_sales'].min():.1f}, {sample['total_sales'].max():.1f}]")
    print(f" Media coeff range: [{media_true.min():.4f}, {media_true.max():.4f}]")
    print(f" All media coeffs positive: {(media_true > 0).all()}")

    # ---- Step 2: Create datasets ----
    print("\n[2/5] Creating training datasets...")
    train_dataset = MMMDiffusionDataset(train_samples, normalize=True)
    # NOTE(review): the validation dataset is normalized independently of the
    # training set. If MMMDiffusionDataset derives its statistics from the
    # samples it is given, validation inputs are scaled differently from what
    # the model saw during training — confirm against the dataset class.
    val_dataset = MMMDiffusionDataset(val_samples, normalize=True)

    item = train_dataset[0]
    print(f" Conditioning shape: {item['conditioning'].shape}")
    print(f" Coefficients shape: {item['coefficients'].shape}")

    # ---- Step 3: Build model ----
    print("\n[3/5] Building MMM-Diffusion model...")
    on_cpu = device == 'cpu'
    # For PoC: use smaller model and fewer diffusion steps for faster training on CPU
    T_DIFF = 200 if on_cpu else 500

    model = MMMDiffusionModel(
        n_media=n_media,
        n_ctrl=n_ctrl,
        d_model_campaign=128,  # Smaller for PoC
        d_model_channel=192,   # Smaller for PoC
        n_layers_campaign=3,
        n_layers_channel=4,
        T_diff=T_DIFF,
    )

    total_params = sum(p.numel() for p in model.parameters())
    print(f" Total parameters: {total_params:,}")
    print(f" Campaign denoiser: {sum(p.numel() for p in model.campaign_denoiser.parameters()):,}")
    print(f" Channel denoiser: {sum(p.numel() for p in model.channel_denoiser.parameters()):,}")
    print(f" Diffusion steps: {T_DIFF}")

    # ---- Step 4: Train ----
    print("\n[4/5] Training...")
    N_EPOCHS = 30 if on_cpu else 50
    BATCH_SIZE = 8 if on_cpu else 16

    history = train_mmm_diffusion(
        model, train_dataset,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        lr=3e-4,
        device=device,
        log_every=25,
        save_path=checkpoint_path,
    )

    # Plot training history
    plot_training_history(history, save_path=history_plot_path)

    # ---- Step 5: Validate — Generate coefficients and decompose ----
    print("\n[5/5] Validation: generating coefficients for held-out sample...")
    model.eval()
    model = model.to(device)

    # Take a validation sample
    val_item = val_dataset[0]
    cond = val_item['conditioning'].unsqueeze(0).to(device)  # (1, T, 9)
    true_coeffs_norm = val_item['coefficients'].unsqueeze(0)  # (1, T, 8)

    # Generate coefficients via reverse diffusion
    # NOTE(review): this call is not wrapped in torch.no_grad(); presumably
    # model.sample disables gradient tracking itself — confirm.
    t0 = time.time()
    pred_coeffs_norm = model.sample(
        cond, n_steps=T_DIFF,
        constraint_every_k=10,
    )
    gen_time = time.time() - t0
    print(f" Generation time: {gen_time:.1f}s")

    # Decode to original scale
    pred_coeffs = val_dataset.decode_coefficients(pred_coeffs_norm.cpu())
    true_coeffs = val_dataset.decode_coefficients(true_coeffs_norm)

    pred_np = pred_coeffs[0].numpy()
    true_np = true_coeffs[0].numpy()

    # Check constraints
    media_pred = pred_np[:, :n_media]
    print(f"\n Constraint check:")
    print(f" Media coefficients all positive: {(media_pred > 0).all()}")
    print(f" Media coeff min: {media_pred.min():.6f}")
    print(f" Media coeff max: {media_pred.max():.6f}")

    # Correlation between true and predicted coefficients
    print(f"\n Per-channel correlation (true vs predicted):")
    channel_names = MMMDataGenerator.MEDIA_CHANNELS + MMMDataGenerator.CONTROL_VARS
    for i, name in enumerate(channel_names):
        true_series = true_np[:, i]
        pred_series = pred_np[:, i]
        corr = np.corrcoef(true_series, pred_series)[0, 1]
        rmse = np.sqrt(np.mean((true_series - pred_series)**2))
        print(f" {name:20s}: corr={corr:.3f}, RMSE={rmse:.4f}")

    # Plot coefficient comparison
    plot_coefficient_comparison(
        true_np, pred_np, channel_names,
        save_path=coeff_plot_path
    )

    # Sales decomposition
    val_raw = val_samples[0]
    contributions = decompose_sales(
        pred_np, val_raw['media_spend'], val_raw['controls']
    )
    plot_sales_decomposition(
        contributions, val_raw['total_sales'],
        save_path=decomposition_plot_path
    )

    # ---- Summary ----
    print(f"\n{'='*60}")
    print("MMM-DIFFUSION POC COMPLETE")
    print(f"{'='*60}")
    print(" Architecture: Dual-denoiser (Campaign + Channel) diffusion")
    print(" Based on: Kimodo/GMD pattern with PhysDiff constraint projection")
    print(f" Data: {n_train} synthetic MMM scenarios, {gen.n_weeks} weeks each")
    print(f" Channels: {n_media} media + {n_ctrl} non-marketing")
    print(" Constraints: Log-space media (guaranteed positive) + soft sign loss + PhysDiff projection")
    print(f" Model size: {total_params:,} parameters")
    print(f" Final training loss: {history['loss'][-1]:.4f}")
    print("\nOutputs:")
    print(f" Model checkpoint: {checkpoint_path}")
    print(f" Training history: {history_plot_path}")
    print(f" Coefficient comparison: {coeff_plot_path}")
    print(f" Sales decomposition: {decomposition_plot_path}")

    return model, history, train_dataset, val_dataset


if __name__ == '__main__':
    model, history, train_dataset, val_dataset = main()