"""Synthetic control forecaster for causal inference and counterfactual analysis."""
import logging
from typing import Any, Dict, Optional
import numpy as np
import pandas as pd
from scipy.optimize import minimize
from timesmith.core.base import BaseForecaster
from timesmith.core.tags import set_tags
from timesmith.results.forecast import Forecast
logger = logging.getLogger(__name__)
[docs]
class SyntheticControlForecaster(BaseForecaster):
"""Synthetic control forecaster for counterfactual analysis.
Creates a synthetic control unit as a weighted combination of control units
to estimate what would have happened in the absence of treatment.
"""
[docs]
def __init__(
self,
treatment_start: Optional[int] = None,
pre_period_min: int = 5,
):
"""Initialize synthetic control forecaster.
Args:
treatment_start: Index where treatment begins (None = end of data).
pre_period_min: Minimum number of pre-treatment periods required.
"""
super().__init__()
self.treatment_start = treatment_start
self.pre_period_min = pre_period_min
self.weights_ = None
self.control_indices_ = None
set_tags(
self,
scitype_input="SeriesLike",
scitype_output="ForecastLike",
handles_missing=False,
requires_sorted_index=True,
)
[docs]
def fit(
self, y: Any, X: Optional[Any] = None, **fit_params: Any
) -> "SyntheticControlForecaster":
"""Fit synthetic control model.
Args:
y: Target time series (treated unit).
X: Control units as DataFrame (each column is a control unit).
**fit_params: Additional fit parameters.
Returns:
Self for method chaining.
"""
if X is None:
raise ValueError(
"X (control units) is required for synthetic control. "
"Provide control units as DataFrame with each column as a control unit."
)
if isinstance(y, pd.Series):
self.y_ = y.values
self.index_ = y.index
elif isinstance(y, pd.DataFrame) and y.shape[1] == 1:
self.y_ = y.iloc[:, 0].values
self.index_ = y.index
else:
self.y_ = np.asarray(y, dtype=float)
self.index_ = np.arange(len(self.y_))
if isinstance(X, pd.DataFrame):
self.X_ = X.values
self.control_names_ = X.columns.tolist()
else:
self.X_ = np.asarray(X)
if self.X_.ndim == 1:
self.X_ = self.X_.reshape(-1, 1)
self.control_names_ = [f"control_{i}" for i in range(self.X_.shape[1])]
if len(self.y_) != len(self.X_):
raise ValueError(
f"y and X must have same length. Got {len(self.y_)} and {len(self.X_)}"
)
# Determine treatment start
if self.treatment_start is None:
self.treatment_start_ = len(self.y_)
else:
self.treatment_start_ = self.treatment_start
if self.treatment_start_ < self.pre_period_min:
raise ValueError(
f"treatment_start ({self.treatment_start_}) must be >= "
f"pre_period_min ({self.pre_period_min})"
)
# Split pre/post treatment
treated_pre = self.y_[: self.treatment_start_]
control_pre = self.X_[: self.treatment_start_, :]
# Find optimal weights
self.weights_ = self._find_weights(treated_pre, control_pre)
# Calculate pre-treatment fit quality
synthetic_pre = control_pre @ self.weights_
self.pre_rmse_ = float(np.sqrt(np.mean((treated_pre - synthetic_pre) ** 2)))
logger.info(
f"Synthetic control fitted: {len(self.control_names_)} control units, "
f"pre-treatment RMSE: {self.pre_rmse_:.6f}"
)
self._is_fitted = True
return self
def _find_weights(
self, treated_pre: np.ndarray, control_pre: np.ndarray
) -> np.ndarray:
"""Find optimal weights for synthetic control.
Args:
treated_pre: Pre-treatment values of treated unit.
control_pre: Pre-treatment values of control units (n_periods, n_controls).
Returns:
Optimal weights (n_controls,).
"""
n_controls = control_pre.shape[1]
def objective(weights: np.ndarray) -> float:
synthetic = control_pre @ weights
return float(np.sum((treated_pre - synthetic) ** 2))
# Constraints: weights sum to 1, each weight between 0 and 1
constraints = {"type": "eq", "fun": lambda w: np.sum(w) - 1}
bounds = [(0, 1) for _ in range(n_controls)]
initial = np.ones(n_controls) / n_controls
result = minimize(
objective,
initial,
method="SLSQP",
bounds=bounds,
constraints=constraints,
options={"maxiter": 1000},
)
if not result.success:
logger.warning(
f"Weight optimization did not converge: {result.message}. "
"Using uniform weights."
)
return initial
return result.x
[docs]
def predict(
self, fh: Any, X: Optional[Any] = None, **predict_params: Any
) -> Forecast:
"""Generate counterfactual forecast (what would have happened without treatment).
Args:
fh: Forecast horizon (ignored, uses post-treatment period).
X: Optional control units for post-treatment (uses fit data if None).
**predict_params: Additional prediction parameters.
Returns:
Forecast results with counterfactual predictions.
"""
self._check_is_fitted()
# Use post-treatment period
treated_post = self.y_[self.treatment_start_ :]
control_post = self.X_[self.treatment_start_ :, :]
# Generate synthetic control (counterfactual)
synthetic_post = control_post @ self.weights_
# Calculate treatment effect
treatment_effect = treated_post - synthetic_post
# Create forecast index
if isinstance(self.index_, pd.DatetimeIndex):
post_index = self.index_[self.treatment_start_ :]
else:
post_index = np.arange(
self.treatment_start_, self.treatment_start_ + len(synthetic_post)
)
y_pred_series = pd.Series(synthetic_post, index=post_index)
return Forecast(
y_pred=y_pred_series,
fh=len(synthetic_post),
metadata={
"method": "synthetic_control",
"treatment_effect_mean": float(np.mean(treatment_effect)),
"treatment_effect_std": float(np.std(treatment_effect)),
"pre_rmse": self.pre_rmse_,
"n_controls": len(self.control_names_),
"top_controls": [
(name, float(weight))
for name, weight in zip(self.control_names_, self.weights_)
if weight > 0.01
][:5],
},
)
[docs]
def get_weights(self) -> Dict[str, float]:
"""Get synthetic control weights.
Returns:
Dictionary mapping control unit names to weights.
"""
self._check_is_fitted()
return {
name: float(weight)
for name, weight in zip(self.control_names_, self.weights_)
if weight > 0.01 # Only return significant contributors
}