# Source code for timesmith.forecasters.synthetic_control

"""Synthetic control forecaster for causal inference and counterfactual analysis."""

import logging
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import numpy as np
import pandas as pd

from timesmith.core.base import BaseForecaster
from timesmith.core.tags import set_tags
from timesmith.results.forecast import Forecast

if TYPE_CHECKING:
    from timesmith.typing import SeriesLike, TableLike

logger = logging.getLogger(__name__)

# scipy is an optional dependency: it provides the constrained optimizer used
# to fit the synthetic-control weights. Record whether it is available rather
# than failing at import time; the forecaster raises ImportError when it
# actually needs the optimizer.
try:
    from scipy.optimize import minimize
except ImportError:
    minimize = None
    HAS_SCIPY = False
    logger.warning(
        "scipy not installed. SyntheticControlForecaster requires scipy. "
        "Install with: pip install scipy or pip install timesmith[scipy]"
    )
else:
    HAS_SCIPY = True


class SyntheticControlForecaster(BaseForecaster):
    """Synthetic control forecaster for counterfactual analysis.

    Creates a synthetic control unit as a weighted combination of control
    units to estimate what would have happened in the absence of treatment.
    The weights are bounded to ``[0, 1]``, sum to one, and are chosen to
    minimise the squared error between the treated unit and the synthetic
    unit over the pre-treatment period.
    """

    def __init__(
        self,
        treatment_start: Optional[int] = None,
        pre_period_min: int = 5,
    ):
        """Initialize synthetic control forecaster.

        Args:
            treatment_start: Index where treatment begins (None = end of data).
            pre_period_min: Minimum number of pre-treatment periods required.
        """
        super().__init__()
        self.treatment_start = treatment_start
        self.pre_period_min = pre_period_min
        self.weights_ = None
        # Kept for backward compatibility; nothing else in this class sets it.
        self.control_indices_ = None
        set_tags(
            self,
            scitype_input="SeriesLike",
            scitype_output="ForecastLike",
            handles_missing=False,
            requires_sorted_index=True,
        )

    def fit(
        self,
        y: Union["SeriesLike", Any],
        X: Optional[Union["TableLike", Any]] = None,
        **fit_params: Any,
    ) -> "SyntheticControlForecaster":
        """Fit synthetic control model.

        Args:
            y: Target time series (treated unit): a pandas Series, a
                single-column DataFrame, or a 1-D array-like (an ``(n, 1)``
                column vector is also accepted).
            X: Control units as DataFrame (each column is a control unit).
                A 1-D array-like is treated as a single control unit.
            **fit_params: Additional fit parameters (unused).

        Returns:
            Self for method chaining.

        Raises:
            ValueError: If ``X`` is missing or has no columns, ``y`` is not
                one-dimensional, ``y`` and ``X`` differ in length, the
                treatment start is outside ``[pre_period_min, len(y)]`` or
                leaves no pre-treatment data, or the pre-treatment data
                contains NaN/inf.
            ImportError: If scipy is not installed.
        """
        if X is None:
            raise ValueError(
                "X (control units) is required for synthetic control. "
                "Provide control units as DataFrame with each column as a control unit."
            )

        self.y_, self.index_ = self._coerce_target(y)
        self.X_, self.control_names_ = self._coerce_controls(X)

        if len(self.y_) != len(self.X_):
            raise ValueError(
                "y and X must have same length. "
                f"Got {len(self.y_)} and {len(self.X_)}"
            )

        # None means every observation is pre-treatment (pure forecasting).
        if self.treatment_start is None:
            self.treatment_start_ = len(self.y_)
        else:
            self.treatment_start_ = self.treatment_start

        if self.treatment_start_ < self.pre_period_min:
            raise ValueError(
                f"treatment_start ({self.treatment_start_}) must be >= "
                f"pre_period_min ({self.pre_period_min})"
            )
        if self.treatment_start_ > len(self.y_):
            raise ValueError(
                f"treatment_start ({self.treatment_start_}) must be <= "
                f"the number of observations ({len(self.y_)})"
            )
        if self.treatment_start_ <= 0:
            raise ValueError("At least one pre-treatment period is required.")

        treated_pre = self.y_[: self.treatment_start_]
        control_pre = self.X_[: self.treatment_start_, :]

        # NaN/inf would make the optimizer fail and silently fall back to
        # uniform weights, so reject it up front (handles_missing=False).
        if not (np.isfinite(treated_pre).all() and np.isfinite(control_pre).all()):
            raise ValueError(
                "Pre-treatment data must not contain NaN or infinite values "
                "(SyntheticControlForecaster does not handle missing data)."
            )

        self.weights_ = self._find_weights(treated_pre, control_pre)

        # Pre-treatment fit quality of the synthetic unit.
        synthetic_pre = control_pre @ self.weights_
        self.pre_rmse_ = float(np.sqrt(np.mean((treated_pre - synthetic_pre) ** 2)))

        logger.info(
            "Synthetic control fitted: %d control units, pre-treatment RMSE: %.6f",
            len(self.control_names_),
            self.pre_rmse_,
        )

        self._is_fitted = True
        return self

    @staticmethod
    def _coerce_target(y: Any) -> "tuple[np.ndarray, Any]":
        """Convert the treated unit to a 1-D float array plus its index.

        Pandas inputs keep their own index; other inputs get ``0..n-1``.
        """
        if isinstance(y, pd.DataFrame):
            if y.shape[1] != 1:
                raise ValueError(
                    "y must be a single treated series; got a DataFrame with "
                    f"{y.shape[1]} columns."
                )
            y = y.iloc[:, 0]
        if isinstance(y, pd.Series):
            return y.to_numpy(dtype=float), y.index

        values = np.asarray(y, dtype=float)
        # Accept an (n, 1) column vector; any other 2-D shape would broadcast
        # against the 1-D synthetic series and corrupt the objective.
        if values.ndim == 2 and values.shape[1] == 1:
            values = values[:, 0]
        if values.ndim != 1:
            raise ValueError(
                f"y must be one-dimensional; got an array of shape {values.shape}."
            )
        return values, np.arange(len(values))

    @staticmethod
    def _coerce_controls(X: Any) -> "tuple[np.ndarray, list]":
        """Convert control units to a 2-D float array plus a list of names."""
        if isinstance(X, pd.DataFrame):
            values = X.to_numpy(dtype=float)
            names = X.columns.tolist()
        else:
            values = np.asarray(X, dtype=float)
            if values.ndim == 1:
                # A single control unit given as a flat sequence.
                values = values.reshape(-1, 1)
            if values.ndim != 2:
                raise ValueError(
                    "X must be two-dimensional (n_periods, n_controls); "
                    f"got an array of shape {values.shape}."
                )
            names = [f"control_{i}" for i in range(values.shape[1])]
        if values.shape[1] == 0:
            raise ValueError("X must contain at least one control unit.")
        return values, names

    def _find_weights(
        self, treated_pre: np.ndarray, control_pre: np.ndarray
    ) -> np.ndarray:
        """Find optimal weights for synthetic control.

        Solves ``min ||treated_pre - control_pre @ w||^2`` subject to
        ``0 <= w_i <= 1`` and ``sum(w) == 1`` with SLSQP.

        Args:
            treated_pre: Pre-treatment values of treated unit.
            control_pre: Pre-treatment values of control units (n_periods, n_controls).

        Returns:
            Optimal weights (n_controls,). Uniform weights if the optimizer
            does not report success.

        Raises:
            ImportError: If scipy is not installed.
        """
        if not HAS_SCIPY or minimize is None:
            raise ImportError(
                "SyntheticControlForecaster requires scipy for optimization. "
                "Install with: pip install scipy or pip install timesmith[scipy]"
            )

        n_controls = control_pre.shape[1]
        initial = np.full(n_controls, 1.0 / n_controls)

        def objective(weights: np.ndarray) -> float:
            residual = treated_pre - control_pre @ weights
            return float(residual @ residual)

        def gradient(weights: np.ndarray) -> np.ndarray:
            # d/dw ||t - C w||^2 = -2 C^T (t - C w); exact gradient avoids
            # finite-difference noise in SLSQP.
            residual = treated_pre - control_pre @ weights
            return -2.0 * (control_pre.T @ residual)

        # Constraints: weights sum to 1, each weight between 0 and 1.
        constraints = {
            "type": "eq",
            "fun": lambda w: np.sum(w) - 1.0,
            "jac": lambda w: np.ones_like(w),
        }
        bounds = [(0.0, 1.0)] * n_controls

        result = minimize(
            objective,
            initial,
            jac=gradient,
            method="SLSQP",
            bounds=bounds,
            constraints=constraints,
            options={"maxiter": 1000},
        )

        if not result.success:
            logger.warning(
                "Weight optimization did not converge: %s. Using uniform weights.",
                result.message,
            )
            return initial

        # SLSQP may return values marginally outside the feasible set
        # (e.g. -1e-17); project back onto non-negative weights summing to 1.
        weights = np.clip(result.x, 0.0, 1.0)
        total = weights.sum()
        return weights / total if total > 0 else initial

    def _post_controls(self, X: Any) -> "tuple[np.ndarray, Any]":
        """Prepare caller-supplied post-treatment controls and their index."""
        if isinstance(X, pd.DataFrame) and all(
            name in X.columns for name in self.control_names_
        ):
            # Align by name so a frame with reordered columns still maps
            # onto the fitted weights correctly.
            X = X[self.control_names_]
        control_post, _ = self._coerce_controls(X)
        if control_post.shape[1] != len(self.weights_):
            raise ValueError(
                f"X must have {len(self.weights_)} control columns to match "
                f"the fitted weights; got {control_post.shape[1]}."
            )
        if isinstance(X, (pd.DataFrame, pd.Series)):
            return control_post, X.index
        start = self.treatment_start_
        return control_post, np.arange(start, start + len(control_post))

    def predict(
        self,
        fh: Union[int, list, Any],
        X: Optional[Union["TableLike", Any]] = None,
        **predict_params: Any,
    ) -> Forecast:
        """Generate counterfactual forecast (what would have happened without treatment).

        Args:
            fh: Forecast horizon (ignored; the horizon is the post-treatment
                period of the fit data, or the rows of ``X`` when given).
            X: Optional control units for the post-treatment period. When
                None, the control data seen in ``fit`` from ``treatment_start``
                onward is used. When given, it must have one column per
                fitted control unit; a DataFrame containing every fitted
                control name is reordered to the fitted column order.
            **predict_params: Additional prediction parameters (unused).

        Returns:
            Forecast results with counterfactual predictions. The index is the
            fit-data index from ``treatment_start`` onward (or ``X``'s index,
            or positions starting at ``treatment_start`` for array ``X``).
            Treatment-effect statistics are NaN when ``X`` is given (treated
            outcomes for those periods are unknown) or the post period is empty.

        Raises:
            ValueError: If ``X`` does not have one column per control unit.
        """
        self._check_is_fitted()

        if X is None:
            control_post = self.X_[self.treatment_start_ :, :]
            treated_post = self.y_[self.treatment_start_ :]
            # Slicing the stored index keeps datetime/period/custom labels;
            # for array input the stored index is an arange.
            post_index = self.index_[self.treatment_start_ :]
        else:
            control_post, post_index = self._post_controls(X)
            treated_post = None

        # Synthetic control = counterfactual outcome of the treated unit.
        synthetic_post = control_post @ self.weights_

        if treated_post is None or treated_post.size == 0:
            effect_mean = effect_std = float("nan")
        else:
            treatment_effect = treated_post - synthetic_post
            effect_mean = float(np.mean(treatment_effect))
            effect_std = float(np.std(treatment_effect))

        # Largest contributors first (stable sort keeps column order on ties).
        top_controls = sorted(
            (
                (name, float(weight))
                for name, weight in zip(self.control_names_, self.weights_)
                if weight > 0.01
            ),
            key=lambda pair: pair[1],
            reverse=True,
        )[:5]

        y_pred_series = pd.Series(synthetic_post, index=post_index)

        return Forecast(
            y_pred=y_pred_series,
            fh=len(synthetic_post),
            metadata={
                "method": "synthetic_control",
                "treatment_effect_mean": effect_mean,
                "treatment_effect_std": effect_std,
                "pre_rmse": self.pre_rmse_,
                "n_controls": len(self.control_names_),
                "top_controls": top_controls,
            },
        )

    def get_weights(self, threshold: float = 0.01) -> Dict[str, float]:
        """Get synthetic control weights.

        Args:
            threshold: Only controls with weight strictly greater than this
                are returned (default 0.01, i.e. significant contributors).

        Returns:
            Dictionary mapping control unit names to weights.
        """
        self._check_is_fitted()
        return {
            name: float(weight)
            for name, weight in zip(self.control_names_, self.weights_)
            if weight > threshold
        }