"""ARIMA forecaster implementation."""
import logging
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
import pandas as pd
from timesmith.core.base import BaseForecaster
from timesmith.core.tags import set_tags
from timesmith.results.forecast import Forecast
from timesmith.utils.ts_utils import detect_frequency, ensure_datetime_index
if TYPE_CHECKING:
    # Only needed by static type checkers: the annotations in this module
    # refer to these names as quoted strings.
    from timesmith.typing import SeriesLike, TableLike
# Module-level logger; it must be created before the optional import below
# because the ImportError branch logs through it.
logger = logging.getLogger(__name__)
# pmdarima is treated as an optional dependency. When it cannot be imported,
# ``auto_arima`` is set to ``None`` and a warning is logged at import time.
try:
    from pmdarima import auto_arima
except ImportError:
    auto_arima = None
    logger.warning(
        "pmdarima not installed. ARIMAForecaster will not work. "
        "Install with: pip install pmdarima"
    )
class ARIMAForecaster(BaseForecaster):
    """ARIMA forecaster using auto_arima for automatic order selection.

    Wraps pmdarima.auto_arima to provide a BaseForecaster interface.

    Attributes set by ``fit``:
        train_index_: Index of the training series after it has been passed
            through ``ensure_datetime_index``.
        model_: Fitted model object returned by ``auto_arima``.
        order_: The ``order`` attribute of the fitted model.
    """

    def __init__(
        self,
        start_p: int = 0,
        start_q: int = 0,
        max_p: int = 5,
        max_q: int = 5,
        seasonal: bool = False,
        stepwise: bool = True,
        suppress_warnings: bool = True,
        error_action: str = "ignore",
        **kwargs,
    ):
        """Initialize ARIMA forecaster.

        Args:
            start_p: Starting value for p parameter.
            start_q: Starting value for q parameter.
            max_p: Maximum value for p parameter.
            max_q: Maximum value for q parameter.
            seasonal: Whether to include seasonal component.
            stepwise: Whether to use stepwise selection.
            suppress_warnings: Whether to suppress warnings.
            error_action: Action on error ('ignore', 'warn', 'raise').
            **kwargs: Additional arguments passed to auto_arima.

        Raises:
            ImportError: If pmdarima could not be imported.
        """
        if auto_arima is None:
            raise ImportError(
                "pmdarima is required for ARIMAForecaster. "
                "Install with: pip install pmdarima"
            )
        super().__init__()
        self.start_p = start_p
        self.start_q = start_q
        self.max_p = max_p
        self.max_q = max_q
        self.seasonal = seasonal
        self.stepwise = stepwise
        self.suppress_warnings = suppress_warnings
        self.error_action = error_action
        self.kwargs = kwargs
        set_tags(
            self,
            scitype_input="SeriesLike",
            scitype_output="SeriesLike",
            handles_missing=False,
            requires_sorted_index=True,
            supports_panel=False,
            requires_fh=True,
        )

    def fit(
        self,
        y: Union["SeriesLike", Any],
        X: Optional[Union["TableLike", Any]] = None,
        **fit_params: Any,
    ) -> "ARIMAForecaster":
        """Fit ARIMA model to time series.

        Args:
            y: Target time series (a Series or a single-column DataFrame).
            X: Optional exogenous data (not yet supported; a warning is
                logged and the data is ignored).
            **fit_params: Additional keyword arguments for ``auto_arima``.
                They take precedence over the constructor settings.

        Returns:
            Self for method chaining.

        Raises:
            ValueError: If ``y`` is neither a Series nor a single-column
                DataFrame.
        """
        if X is not None:
            logger.warning(
                "Exogenous variables (X) not yet supported for ARIMAForecaster"
            )

        if isinstance(y, pd.Series):
            series = y
        elif isinstance(y, pd.DataFrame) and y.shape[1] == 1:
            series = y.iloc[:, 0]
        else:
            raise ValueError("y must be SeriesLike (Series or single-column DataFrame)")

        series = ensure_datetime_index(series)
        # Kept so that predict() can extend the index past the training data.
        self.train_index_ = series.index

        # Later entries win: constructor kwargs override the named defaults,
        # and per-call fit_params override both.
        fit_kwargs = {
            "start_p": self.start_p,
            "start_q": self.start_q,
            "max_p": self.max_p,
            "max_q": self.max_q,
            "seasonal": self.seasonal,
            "stepwise": self.stepwise,
            "suppress_warnings": self.suppress_warnings,
            "error_action": self.error_action,
            **self.kwargs,
            **fit_params,
        }

        # Fit model (use values only, not index)
        self.model_ = auto_arima(series.values, **fit_kwargs)
        self.order_ = self.model_.order
        self._is_fitted = True
        return self

    def predict(
        self,
        fh: Union[int, list, Any],
        X: Optional[Union["TableLike", Any]] = None,
        **predict_params: Any,
    ) -> Forecast:
        """Generate forecast.

        Args:
            fh: Forecast horizon. An integer (or anything ``int()`` accepts)
                is the number of steps to forecast. For a list or NumPy
                array only its length is used.
            X: Optional exogenous data (not yet supported; a warning is
                logged and the data is ignored).
            **predict_params: Additional prediction parameters. Only
                ``alpha`` is used: the significance level of the returned
                interval (default 0.05, i.e. a 95% interval). Other keys
                are ignored.

        Returns:
            Forecast object with point predictions and a ``lower``/``upper``
            interval frame sharing the same index.
        """
        self._check_is_fitted()
        if X is not None:
            logger.warning(
                "Exogenous variables (X) not yet supported for ARIMAForecaster"
            )

        # Convert fh to a step count.
        # NOTE(review): for list/array input only len(fh) is used, so a
        # non-contiguous horizon such as [2, 4] yields steps 1..2 -- confirm
        # this is the intended contract.
        if isinstance(fh, (list, np.ndarray)):
            n_periods = len(fh)
        elif isinstance(fh, int):
            n_periods = fh
        else:
            n_periods = int(fh)

        # Honour a caller-supplied alpha. This used to be hard-coded to 0.05,
        # which made predict_interval(coverage=...) silently return 95%
        # intervals regardless of the requested coverage.
        alpha = predict_params.get("alpha", 0.05)

        forecast, conf_int = self.model_.predict(
            n_periods=n_periods,
            return_conf_int=True,
            alpha=alpha,
        )

        forecast_index = self._build_forecast_index(n_periods)

        # np.asarray drops any labels on the model output so values are
        # assigned positionally instead of being re-aligned to the new index.
        y_pred = pd.Series(np.asarray(forecast), index=forecast_index)
        y_int = pd.DataFrame(
            np.asarray(conf_int),
            index=forecast_index,
            columns=["lower", "upper"],
        )
        return Forecast(y_pred=y_pred, fh=fh, y_int=y_int)

    def _build_forecast_index(self, n_periods: int) -> pd.DatetimeIndex:
        """Build the datetime index for ``n_periods`` steps after training."""
        last_date = pd.Timestamp(self.train_index_[-1])
        freq = detect_frequency(pd.Series(index=self.train_index_))

        if isinstance(freq, str):
            next_date = last_date + pd.tseries.frequencies.to_offset(freq)
            return pd.date_range(start=next_date, periods=n_periods, freq=freq)

        # Fallback: reuse the spacing between the last two training points.
        if len(self.train_index_) > 1:
            last_delta = self.train_index_[-1] - self.train_index_[-2]
            return pd.date_range(
                start=last_date + last_delta, periods=n_periods, freq=last_delta
            )

        # A single observation gives no spacing to work from: assume daily
        # steps and drop the first entry, which is the training point itself.
        return pd.date_range(start=last_date, periods=n_periods + 1, freq="D")[1:]

    def predict_interval(
        self,
        fh: Any,
        X: Optional[Any] = None,
        coverage: float = 0.9,
        **predict_params: Any,
    ) -> Forecast:
        """Generate forecast with prediction intervals.

        Args:
            fh: Forecast horizon.
            X: Optional exogenous data.
            coverage: Coverage level (e.g., 0.9 for 90%). Must be strictly
                between 0 and 1.
            **predict_params: Additional prediction parameters. Any
                ``alpha`` given here is replaced by ``1 - coverage``.

        Returns:
            Forecast with intervals at the requested coverage.

        Raises:
            ValueError: If ``coverage`` is not strictly between 0 and 1.
        """
        if not 0.0 < coverage < 1.0:
            raise ValueError(
                f"coverage must be strictly between 0 and 1, got {coverage!r}"
            )
        # ARIMA already returns intervals, just adjust alpha
        predict_params["alpha"] = 1.0 - coverage
        return self.predict(fh, X, **predict_params)

    def get_order(self) -> tuple:
        """Get ARIMA order (p, d, q).

        Returns:
            Tuple of (p, d, q) order, as stored by ``fit``.
        """
        self._check_is_fitted()
        return self.order_