"""Causal inference tools for time series.
Transfer entropy and related causal measures.
"""
import logging
from typing import Any, Optional
import numpy as np
import pandas as pd
from timesmith.core.base import BaseDetector
from timesmith.core.tags import set_tags
logger = logging.getLogger(__name__)
# Try to import numba for JIT compilation (optional)
try:
from numba import njit
HAS_NUMBA = True
except ImportError:
HAS_NUMBA = False
def njit(*args, **kwargs):
def decorator(func):
return func
if args and callable(args[0]):
return args[0]
return decorator
@njit(cache=True)
def _count_joint_2d(
indices1: np.ndarray, indices2: np.ndarray, bins: int
) -> np.ndarray:
"""Count joint 2D histogram using Numba JIT (fast path).
Args:
indices1: First dimension indices (clipped to [0, bins-1]).
indices2: Second dimension indices (clipped to [0, bins-1]).
bins: Number of bins.
Returns:
Joint count matrix (bins, bins).
"""
counts = np.zeros((bins, bins), dtype=np.int32)
n = len(indices1)
for i in range(n):
idx1 = int(indices1[i])
idx2 = int(indices2[i])
if 0 <= idx1 < bins and 0 <= idx2 < bins:
counts[idx1, idx2] += 1
return counts
@njit(cache=True)
def _count_joint_3d(
indices1: np.ndarray, indices2: np.ndarray, indices3: np.ndarray, bins: int
) -> np.ndarray:
"""Count joint 3D histogram using Numba JIT (fast path).
Args:
indices1: First dimension indices (clipped to [0, bins-1]).
indices2: Second dimension indices (clipped to [0, bins-1]).
indices3: Third dimension indices (clipped to [0, bins-1]).
bins: Number of bins.
Returns:
Joint count tensor (bins, bins, bins).
"""
counts = np.zeros((bins, bins, bins), dtype=np.int32)
n = len(indices1)
for i in range(n):
idx1 = int(indices1[i])
idx2 = int(indices2[i])
idx3 = int(indices3[i])
if 0 <= idx1 < bins and 0 <= idx2 < bins and 0 <= idx3 < bins:
counts[idx1, idx2, idx3] += 1
return counts
@njit(cache=True)
def _count_joint_4d(
indices1: np.ndarray,
indices2: np.ndarray,
indices3: np.ndarray,
indices4: np.ndarray,
bins: int,
) -> np.ndarray:
"""Count joint 4D histogram using Numba JIT (fast path).
Args:
indices1: First dimension indices (clipped to [0, bins-1]).
indices2: Second dimension indices (clipped to [0, bins-1]).
indices3: Third dimension indices (clipped to [0, bins-1]).
indices4: Fourth dimension indices (clipped to [0, bins-1]).
bins: Number of bins.
Returns:
Joint count tensor (bins, bins, bins, bins).
"""
counts = np.zeros((bins, bins, bins, bins), dtype=np.int32)
n = len(indices1)
for i in range(n):
idx1 = int(indices1[i])
idx2 = int(indices2[i])
idx3 = int(indices3[i])
idx4 = int(indices4[i])
if (
0 <= idx1 < bins
and 0 <= idx2 < bins
and 0 <= idx3 < bins
and 0 <= idx4 < bins
):
counts[idx1, idx2, idx3, idx4] += 1
return counts
@njit(cache=True, fastmath=True)
def _compute_entropy_numba(probs: np.ndarray) -> float:
"""Compute Shannon entropy from probability distribution (JIT-compiled).
Args:
probs: Probability array (flattened).
Returns:
Entropy in bits.
"""
entropy = 0.0
for i in range(len(probs)):
if probs[i] > 0:
entropy -= probs[i] * np.log2(probs[i])
return entropy
[docs]
def conditional_transfer_entropy(
x: np.ndarray,
y: np.ndarray,
z: np.ndarray,
lag: int = 1,
bins: int = 10,
) -> float:
"""Compute conditional transfer entropy from X to Y given Z.
Conditional transfer entropy accounts for confounding variables Z,
measuring the direct causal influence from X to Y.
Args:
x: Source time series.
y: Target time series.
z: Conditioning time series (confounding variable).
lag: Time lag for past values.
bins: Number of bins for discretization.
Returns:
Conditional transfer entropy from X to Y given Z (non-negative, bits).
"""
if not (len(x) == len(y) == len(z)):
raise ValueError("All series must have same length")
if lag < 1:
raise ValueError(f"lag must be >= 1, got {lag}")
if len(x) < lag + 1:
return 0.0
# Discretize series
x_min, x_max = np.nanmin(x), np.nanmax(x)
y_min, y_max = np.nanmin(y), np.nanmax(y)
z_min, z_max = np.nanmin(z), np.nanmax(z)
if x_min == x_max or y_min == y_max or z_min == z_max:
return 0.0
x_edges = np.linspace(x_min, x_max, bins + 1)
y_edges = np.linspace(y_min, y_max, bins + 1)
z_edges = np.linspace(z_min, z_max, bins + 1)
x_edges[-1] += 1e-10
y_edges[-1] += 1e-10
z_edges[-1] += 1e-10
x_disc = np.clip(np.digitize(x, x_edges) - 1, 0, bins - 1)
y_disc = np.clip(np.digitize(y, y_edges) - 1, 0, bins - 1)
z_disc = np.clip(np.digitize(z, z_edges) - 1, 0, bins - 1)
# Compute conditional entropies
y_t = y_disc[lag:]
y_past = y_disc[: len(y) - lag]
x_past = x_disc[: len(x) - lag]
z_past = z_disc[: len(z) - lag]
min_len = min(len(y_t), len(y_past), len(x_past), len(z_past))
y_t = y_t[:min_len]
y_past = y_past[:min_len]
x_past = x_past[:min_len]
z_past = z_past[:min_len]
# H(Y_t | Y_{t-lag}, Z_{t-lag}) - use JIT if available
# Clip indices to valid range
y_t_clipped = np.clip(y_t, 0, bins - 1)
y_past_clipped = np.clip(y_past, 0, bins - 1)
z_past_clipped = np.clip(z_past, 0, bins - 1)
# Use JIT-compiled counting if available
if HAS_NUMBA and len(y_t) > 100:
try:
joint_counts_yz = _count_joint_3d(
y_t_clipped, y_past_clipped, z_past_clipped, bins
)
marginal_counts_yz = _count_joint_2d(y_past_clipped, z_past_clipped, bins)
except Exception:
# Fallback to vectorized counting
joint_counts_yz = np.zeros((bins, bins, bins), dtype=np.int32)
np.add.at(joint_counts_yz, (y_t_clipped, y_past_clipped, z_past_clipped), 1)
marginal_counts_yz = np.zeros((bins, bins), dtype=np.int32)
np.add.at(marginal_counts_yz, (y_past_clipped, z_past_clipped), 1)
else:
# Use advanced indexing for counting (vectorized)
joint_counts_yz = np.zeros((bins, bins, bins), dtype=np.int32)
np.add.at(joint_counts_yz, (y_t_clipped, y_past_clipped, z_past_clipped), 1)
marginal_counts_yz = np.zeros((bins, bins), dtype=np.int32)
np.add.at(marginal_counts_yz, (y_past_clipped, z_past_clipped), 1)
total_yz = np.sum(joint_counts_yz)
if total_yz == 0:
return 0.0
joint_probs_yz = joint_counts_yz / total_yz
if HAS_NUMBA:
joint_entropy_yz = _compute_entropy_numba(joint_probs_yz.flatten())
else:
joint_entropy_yz = -np.sum(
joint_probs_yz[joint_probs_yz > 0]
* np.log2(joint_probs_yz[joint_probs_yz > 0])
)
marginal_probs_yz = marginal_counts_yz / (np.sum(marginal_counts_yz) + 1e-10)
if HAS_NUMBA:
marginal_entropy_yz = _compute_entropy_numba(marginal_probs_yz.flatten())
else:
marginal_entropy_yz = -np.sum(
marginal_probs_yz[marginal_probs_yz > 0]
* np.log2(marginal_probs_yz[marginal_probs_yz > 0])
)
h_y_given_yz = joint_entropy_yz - marginal_entropy_yz
# H(Y_t | Y_{t-lag}, X_{t-lag}, Z_{t-lag}) - use JIT if available
x_past_clipped = np.clip(x_past, 0, bins - 1)
# Use JIT-compiled counting if available
if HAS_NUMBA and len(y_t) > 100:
try:
joint_counts_4d = _count_joint_4d(
y_t_clipped, y_past_clipped, x_past_clipped, z_past_clipped, bins
)
marginal_counts_xyz = _count_joint_3d(
y_past_clipped, x_past_clipped, z_past_clipped, bins
)
except Exception:
# Fallback to vectorized counting
joint_counts_4d = np.zeros((bins, bins, bins, bins), dtype=np.int32)
np.add.at(
joint_counts_4d,
(y_t_clipped, y_past_clipped, x_past_clipped, z_past_clipped),
1,
)
marginal_counts_xyz = np.zeros((bins, bins, bins), dtype=np.int32)
np.add.at(
marginal_counts_xyz, (y_past_clipped, x_past_clipped, z_past_clipped), 1
)
else:
# Use advanced indexing for counting (vectorized)
joint_counts_4d = np.zeros((bins, bins, bins, bins), dtype=np.int32)
np.add.at(
joint_counts_4d,
(y_t_clipped, y_past_clipped, x_past_clipped, z_past_clipped),
1,
)
marginal_counts_xyz = np.zeros((bins, bins, bins), dtype=np.int32)
np.add.at(
marginal_counts_xyz, (y_past_clipped, x_past_clipped, z_past_clipped), 1
)
total_4d = np.sum(joint_counts_4d)
if total_4d == 0:
return 0.0
joint_probs_4d = joint_counts_4d / total_4d
if HAS_NUMBA:
joint_entropy_4d = _compute_entropy_numba(joint_probs_4d.flatten())
else:
joint_entropy_4d = -np.sum(
joint_probs_4d[joint_probs_4d > 0]
* np.log2(joint_probs_4d[joint_probs_4d > 0])
)
marginal_probs_xyz = marginal_counts_xyz / (np.sum(marginal_counts_xyz) + 1e-10)
if HAS_NUMBA:
marginal_entropy_xyz = _compute_entropy_numba(marginal_probs_xyz.flatten())
else:
marginal_entropy_xyz = -np.sum(
marginal_probs_xyz[marginal_probs_xyz > 0]
* np.log2(marginal_probs_xyz[marginal_probs_xyz > 0])
)
h_y_given_xyz = joint_entropy_4d - marginal_entropy_xyz
# Conditional transfer entropy
cte = h_y_given_yz - h_y_given_xyz
return float(max(0.0, cte))
[docs]
def transfer_entropy_network(
X: list,
lag: int = 1,
bins: int = 10,
threshold: Optional[float] = None,
series_names: Optional[list] = None,
):
"""Construct a directed network based on transfer entropy between time series.
Each edge (i, j) represents causal influence from series i to series j,
weighted by the transfer entropy value.
Args:
X: List of time series arrays to analyze.
lag: Time lag for transfer entropy computation.
bins: Number of bins for discretization.
threshold: Minimum transfer entropy threshold for edges (if None, include all edges).
series_names: Names for each series (default: "Series_0", "Series_1", ...).
Returns:
Tuple of (NetworkX DiGraph, transfer entropy matrix, statistics dictionary).
"""
import networkx as nx
n_series = len(X)
series_names = series_names or [f"Series_{i}" for i in range(n_series)]
if len(series_names) != n_series:
raise ValueError(
f"series_names length ({len(series_names)}) must match "
f"number of series ({n_series})"
)
te_matrix = np.zeros((n_series, n_series))
for i in range(n_series):
for j in range(n_series):
if i != j:
te_matrix[i, j] = transfer_entropy(X[i], X[j], lag=lag, bins=bins)
G = nx.DiGraph()
G.add_nodes_from(range(n_series))
for i in range(n_series):
for j in range(n_series):
if i != j:
te_val = te_matrix[i, j]
if threshold is None or te_val >= threshold:
G.add_edge(i, j, weight=te_val)
for i, name in enumerate(series_names):
G.nodes[i]["name"] = name
stats = {
"mean_te": float(np.mean(te_matrix[te_matrix > 0]))
if np.any(te_matrix > 0)
else 0.0,
"max_te": float(np.max(te_matrix)),
"min_te": float(np.min(te_matrix[te_matrix > 0]))
if np.any(te_matrix > 0)
else 0.0,
"std_te": float(np.std(te_matrix[te_matrix > 0]))
if np.any(te_matrix > 0)
else 0.0,
"n_edges": G.number_of_edges(),
"density": G.number_of_edges() / (n_series * (n_series - 1))
if n_series > 1
else 0.0,
}
return G, te_matrix, stats
[docs]
def transfer_entropy(
x: np.ndarray,
y: np.ndarray,
lag: int = 1,
bins: int = 10,
) -> float:
"""Compute transfer entropy from X to Y.
Transfer entropy measures the amount of information transferred from X to Y,
quantifying causal influence.
Args:
x: Source time series.
y: Target time series (must have same length as x).
lag: Time lag for past values.
bins: Number of bins for discretization.
Returns:
Transfer entropy from X to Y (non-negative, bits).
"""
if len(x) != len(y):
raise ValueError(f"Series must have same length: {len(x)} != {len(y)}")
if bins < 2:
raise ValueError(f"bins must be >= 2, got {bins}")
# Discretize series
x_min, x_max = np.nanmin(x), np.nanmax(x)
y_min, y_max = np.nanmin(y), np.nanmax(y)
if x_min == x_max or y_min == y_max:
return 0.0
x_edges = np.linspace(x_min, x_max, bins + 1)
y_edges = np.linspace(y_min, y_max, bins + 1)
x_edges[-1] += 1e-10
y_edges[-1] += 1e-10
x_disc = np.digitize(x, x_edges) - 1
y_disc = np.digitize(y, y_edges) - 1
# Clip to valid range
x_disc = np.clip(x_disc, 0, bins - 1)
y_disc = np.clip(y_disc, 0, bins - 1)
# Compute conditional entropies
# H(Y_t | Y_{t-lag})
n = len(y)
if n <= lag:
return 0.0
y_t = y_disc[lag:]
y_past = y_disc[: n - lag]
# H(Y_t | Y_{t-lag}) - use JIT if available
y_t_clipped = np.clip(y_t, 0, bins - 1)
y_past_clipped = np.clip(y_past, 0, bins - 1)
# Use JIT-compiled counting if available
if HAS_NUMBA and len(y_t) > 100:
try:
joint_counts = _count_joint_2d(y_t_clipped, y_past_clipped, bins)
except Exception:
# Fallback to vectorized counting
joint_counts = np.zeros((bins, bins), dtype=np.int32)
np.add.at(joint_counts, (y_t_clipped, y_past_clipped), 1)
else:
# Use advanced indexing for counting (vectorized)
joint_counts = np.zeros((bins, bins), dtype=np.int32)
np.add.at(joint_counts, (y_t_clipped, y_past_clipped), 1)
total = np.sum(joint_counts)
if total == 0:
return 0.0
joint_probs = joint_counts / total
if HAS_NUMBA:
joint_entropy = _compute_entropy_numba(joint_probs.flatten())
else:
joint_entropy = -np.sum(
joint_probs[joint_probs > 0] * np.log2(joint_probs[joint_probs > 0])
)
# Marginal counts (vectorized using bincount)
y_past_counts = np.bincount(y_past_clipped, minlength=bins)
y_past_probs = y_past_counts / (np.sum(y_past_counts) + 1e-10)
if HAS_NUMBA:
y_past_entropy = _compute_entropy_numba(y_past_probs)
else:
y_past_entropy = -np.sum(
y_past_probs[y_past_probs > 0] * np.log2(y_past_probs[y_past_probs > 0])
)
h_y_given_y_past = joint_entropy - y_past_entropy
# H(Y_t | Y_{t-lag}, X_{t-lag}) - use JIT if available
x_past = x_disc[: n - lag]
x_past_clipped = np.clip(x_past, 0, bins - 1)
# Use JIT-compiled counting if available
if HAS_NUMBA and len(y_t) > 100:
try:
joint_counts_3d = _count_joint_3d(
y_t_clipped, y_past_clipped, x_past_clipped, bins
)
marginal_counts = _count_joint_2d(y_past_clipped, x_past_clipped, bins)
except Exception:
# Fallback to vectorized counting
joint_counts_3d = np.zeros((bins, bins, bins), dtype=np.int32)
np.add.at(joint_counts_3d, (y_t_clipped, y_past_clipped, x_past_clipped), 1)
marginal_counts = np.zeros((bins, bins), dtype=np.int32)
np.add.at(marginal_counts, (y_past_clipped, x_past_clipped), 1)
else:
# Use advanced indexing for counting (vectorized)
joint_counts_3d = np.zeros((bins, bins, bins), dtype=np.int32)
np.add.at(joint_counts_3d, (y_t_clipped, y_past_clipped, x_past_clipped), 1)
marginal_counts = np.zeros((bins, bins), dtype=np.int32)
np.add.at(marginal_counts, (y_past_clipped, x_past_clipped), 1)
total_3d = np.sum(joint_counts_3d)
if total_3d == 0:
return 0.0
joint_probs_3d = joint_counts_3d / total_3d
if HAS_NUMBA:
joint_entropy_3d = _compute_entropy_numba(joint_probs_3d.flatten())
else:
joint_entropy_3d = -np.sum(
joint_probs_3d[joint_probs_3d > 0]
* np.log2(joint_probs_3d[joint_probs_3d > 0])
)
marginal_probs = marginal_counts / (np.sum(marginal_counts) + 1e-10)
if HAS_NUMBA:
marginal_entropy = _compute_entropy_numba(marginal_probs.flatten())
else:
marginal_entropy = -np.sum(
marginal_probs[marginal_probs > 0]
* np.log2(marginal_probs[marginal_probs > 0])
)
h_y_given_y_past_x_past = joint_entropy_3d - marginal_entropy
# Transfer entropy
te = h_y_given_y_past - h_y_given_y_past_x_past
return float(max(0.0, te)) # Ensure non-negative
[docs]
class TransferEntropyDetector(BaseDetector):
"""Detector using transfer entropy for causal inference.
Uses transfer entropy to detect causal relationships and anomalies.
"""
[docs]
def __init__(self, lag: int = 1, bins: int = 10, threshold: Optional[float] = None):
"""Initialize transfer entropy detector.
Args:
lag: Time lag for past values.
bins: Number of bins for discretization.
threshold: Optional threshold for binary classification.
"""
super().__init__()
self.lag = lag
self.bins = bins
self.threshold = threshold
set_tags(
self,
scitype_input="SeriesLike",
scitype_output="SeriesLike",
handles_missing=False,
requires_sorted_index=True,
)
[docs]
def fit(
self, y: Any, X: Optional[Any] = None, **fit_params: Any
) -> "TransferEntropyDetector":
"""Fit the detector (computes transfer entropy if X provided).
Args:
y: Target time series.
X: Optional source time series for causal inference.
**fit_params: Additional fit parameters.
Returns:
Self for method chaining.
"""
if isinstance(y, pd.Series):
self.y_ = y.values
elif isinstance(y, pd.DataFrame) and y.shape[1] == 1:
self.y_ = y.iloc[:, 0].values
else:
self.y_ = np.asarray(y)
if X is not None:
if isinstance(X, pd.Series):
self.X_ = X.values
elif isinstance(X, pd.DataFrame) and X.shape[1] == 1:
self.X_ = X.iloc[:, 0].values
else:
self.X_ = np.asarray(X)
# Compute transfer entropy
self.te_ = transfer_entropy(self.X_, self.y_, lag=self.lag, bins=self.bins)
else:
self.X_ = None
self.te_ = None
self._is_fitted = True
return self
[docs]
def score(self, y: Any, X: Optional[Any] = None) -> Any:
"""Compute transfer entropy scores.
Args:
y: Target time series.
X: Optional source time series.
Returns:
Transfer entropy score(s).
"""
self._check_is_fitted()
if X is None:
if self.X_ is None:
raise ValueError("X must be provided for scoring")
X = self.X_
if isinstance(y, pd.Series):
y_vals = y.values
elif isinstance(y, pd.DataFrame) and y.shape[1] == 1:
y_vals = y.iloc[:, 0].values
else:
y_vals = np.asarray(y)
if isinstance(X, pd.Series):
x_vals = X.values
elif isinstance(X, pd.DataFrame) and X.shape[1] == 1:
x_vals = X.iloc[:, 0].values
else:
x_vals = np.asarray(X)
te = transfer_entropy(x_vals, y_vals, lag=self.lag, bins=self.bins)
return te
[docs]
def predict(
self, y: Any, X: Optional[Any] = None, threshold: Optional[float] = None
) -> Any:
"""Predict causal relationships based on transfer entropy.
Args:
y: Target time series.
X: Optional source time series.
threshold: Optional threshold for binary classification.
Returns:
Boolean array indicating causal relationship (if threshold provided),
or transfer entropy value.
"""
self._check_is_fitted()
threshold = threshold or self.threshold
score = self.score(y, X)
if threshold is not None:
return score >= threshold
else:
return score