# Source code for timesmith.forecasters.lstm

"""LSTM forecaster implementation using Darts (PyTorch RNN only; no Prophet)."""

import logging
import os
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import pandas as pd

from timesmith.core.base import BaseForecaster
from timesmith.core.tags import set_tags
from timesmith.results.forecast import Forecast
from timesmith.utils.ts_utils import ensure_datetime_index

if TYPE_CHECKING:
    from timesmith.typing import SeriesLike, TableLike

logger = logging.getLogger(__name__)

# Darts is imported lazily by ``_ensure_darts()`` rather than at module load:
# importing it eagerly can pull in optional model backends and emit noisy logs
# about unrelated optional dependencies. Only ``TimeSeries``, ``Scaler`` and
# the RNN model are loaded; until then these placeholders stay ``None``.
DartsTimeSeries: Any
RNNModel: Any
Scaler: Any
DartsTimeSeries = RNNModel = Scaler = None

# Cached probe result: None until ``_ensure_darts()`` first runs, then the
# True/False outcome of that import attempt.
HAS_DARTS: Optional[bool] = None


def _ensure_darts() -> bool:
    """Import the Darts pieces this module needs, once, and report availability.

    The first call attempts to import ``TimeSeries``, ``Scaler`` and
    ``RNNModel``. On success they are published through the module-level
    globals ``DartsTimeSeries``, ``Scaler`` and ``RNNModel``. The outcome is
    cached in ``HAS_DARTS``, so later calls return it without re-importing.

    Returns:
        True if all imports succeeded, False if any raised ``ImportError``.
    """
    global DartsTimeSeries, RNNModel, Scaler, HAS_DARTS

    if HAS_DARTS is None:
        try:
            # Only sets the variable if the caller has not already chosen a
            # value. Presumably read by darts to quiet its logging -- verify.
            os.environ.setdefault("DISABLE_DARTS_LOGGING", "1")
            from darts import TimeSeries as timeseries_cls
            from darts.dataprocessing.transformers import Scaler as scaler_cls

            try:
                # Import from the concrete submodule so the ``darts.models``
                # package ``__getattr__`` does not probe unrelated optional models.
                from darts.models.forecasting.rnn_model import RNNModel as rnn_cls
            except ImportError:
                # Fall back to the public package path if that submodule is absent.
                from darts.models import RNNModel as rnn_cls
        except ImportError:
            HAS_DARTS = False
        else:
            DartsTimeSeries = timeseries_cls
            Scaler = scaler_cls
            RNNModel = rnn_cls
            HAS_DARTS = True
    return HAS_DARTS


class LSTMForecaster(BaseForecaster):
    """LSTM forecaster using Darts ``RNNModel`` (PyTorch only).

    TimeSmith does not provide Facebook Prophet or other Darts Prophet
    bindings; this class only uses Darts' recurrent neural network stack.
    """

    def __init__(
        self,
        input_chunk_length: int = 12,
        output_chunk_length: int = 1,
        n_rnn_layers: int = 2,
        hidden_dim: int = 64,
        n_epochs: int = 100,
        random_state: Optional[int] = None,
        scale: bool = True,
        **darts_params: Any,
    ):
        """Initialize LSTM forecaster.

        Args:
            input_chunk_length: Number of time steps to use as input.
            output_chunk_length: Number of time steps to predict at once.
            n_rnn_layers: Number of RNN layers (default: 2).
            hidden_dim: Hidden dimension size (default: 64).
            n_epochs: Number of training epochs (default: 100).
            random_state: Random seed for reproducibility.
            scale: Whether to scale the data (default: True).
            **darts_params: Additional Darts RNNModel parameters. A
                ``pl_trainer_kwargs`` entry is merged over this class's
                trainer defaults. A ``training_length`` entry replaces the
                default of ``max(input_chunk_length, 24)``.

        Raises:
            ImportError: If darts cannot be imported.
        """
        if not _ensure_darts():
            raise ImportError(
                "darts is required for LSTMForecaster. Install with: pip install darts"
            )
        super().__init__()
        self.input_chunk_length = input_chunk_length
        self.output_chunk_length = output_chunk_length
        self.n_rnn_layers = n_rnn_layers
        self.hidden_dim = hidden_dim
        self.n_epochs = n_epochs
        self.random_state = random_state
        self.scale = scale
        self.darts_params = darts_params
        set_tags(
            self,
            scitype_input="SeriesLike",
            scitype_output="ForecastLike",
            handles_missing=False,
            requires_sorted_index=True,
            supports_panel=False,
            requires_fh=True,
        )

    def fit(
        self,
        y: Union["SeriesLike", Any],
        X: Optional[Union["TableLike", Any]] = None,
        **fit_params: Any,
    ) -> "LSTMForecaster":
        """Fit LSTM model.

        Args:
            y: Target time series (Series or single-column DataFrame).
            X: Optional exogenous data (not yet supported; a warning is logged).
            **fit_params: Additional fit parameters (currently ignored).

        Returns:
            Self for method chaining.

        Raises:
            ValueError: If ``y`` is not a Series or single-column DataFrame.
        """
        if X is not None:
            logger.warning(
                "Exogenous variables (X) not yet supported for LSTMForecaster"
            )

        if isinstance(y, pd.Series):
            series = y
        elif isinstance(y, pd.DataFrame) and y.shape[1] == 1:
            series = y.iloc[:, 0]
        else:
            raise ValueError("y must be SeriesLike (Series or single-column DataFrame)")

        series = ensure_datetime_index(series)
        self.train_index_ = series.index
        # Spread of the training data. predict_interval() falls back to it
        # when the forecast has too few points to estimate a spread itself.
        self.train_std_ = float(series.std())

        # Convert to Darts TimeSeries
        darts_series = DartsTimeSeries.from_series(series)

        # Scale if requested; the scaler is reused to invert forecasts.
        if self.scale:
            self.scaler_ = Scaler()
            darts_series = self.scaler_.fit_transform(darts_series)
        else:
            self.scaler_ = None

        extra_params = dict(self.darts_params)

        # Quiet, CPU-only, single-device training by default. User-supplied
        # pl_trainer_kwargs take precedence; ``or {}`` tolerates an explicit None.
        pl_trainer_kwargs = {
            "enable_progress_bar": False,
            "accelerator": "cpu",
            "devices": 1,
            "logger": False,
        }
        pl_trainer_kwargs.update(extra_params.pop("pl_trainer_kwargs", None) or {})

        # Popped so a user-supplied training_length replaces the default
        # instead of being passed to RNNModel twice (which raises TypeError).
        # NOTE(review): Darts recommends training_length > input_chunk_length;
        # the default equals input_chunk_length once it reaches 24 -- confirm.
        training_length = extra_params.pop(
            "training_length", max(self.input_chunk_length, 24)
        )

        self.model_ = RNNModel(
            model="LSTM",
            input_chunk_length=self.input_chunk_length,
            output_chunk_length=self.output_chunk_length,
            training_length=training_length,
            n_rnn_layers=self.n_rnn_layers,
            hidden_dim=self.hidden_dim,
            n_epochs=self.n_epochs,
            random_state=self.random_state,
            pl_trainer_kwargs=pl_trainer_kwargs,
            **extra_params,
        )
        self.model_.fit(darts_series)
        self._is_fitted = True
        return self

    @staticmethod
    def _horizon_length(fh: Any) -> int:
        """Return the number of steps to forecast for ``fh``.

        Sequences (list, tuple, range, non-scalar ndarray, ``pd.Index``)
        contribute only their length; anything else is converted with ``int()``.

        Raises:
            ValueError: If the resulting horizon is smaller than one step.
        """
        if isinstance(fh, (list, tuple, range, pd.Index)) or (
            isinstance(fh, np.ndarray) and fh.ndim > 0
        ):
            n_periods = len(fh)
        else:
            n_periods = int(fh)
        if n_periods < 1:
            raise ValueError(
                f"Forecast horizon must cover at least one step, got {fh!r}"
            )
        return n_periods

    @staticmethod
    def _future_index(train_index: Any, n_periods: int) -> pd.DatetimeIndex:
        """Build ``n_periods`` timestamps continuing after ``train_index``.

        The step is the index's own ``freq`` when set, otherwise the result of
        ``pd.infer_freq``, falling back to daily ("D") when neither is available.
        """
        from pandas.tseries.frequencies import to_offset

        freq = getattr(train_index, "freq", None)
        if freq is None:
            try:
                freq = pd.infer_freq(train_index)
            except (TypeError, ValueError):
                # infer_freq needs at least 3 datetime-like values.
                freq = None
        if freq is None:
            freq = "D"
        offset = to_offset(freq)
        # Start one *frequency step* after the last observation (not a fixed
        # day), so hourly or multi-day series get correctly placed timestamps.
        return pd.date_range(
            start=train_index[-1] + offset, periods=n_periods, freq=offset
        )

    def predict(
        self,
        fh: Union[int, list, Any],
        X: Optional[Union["TableLike", Any]] = None,
        **predict_params: Any,
    ) -> Forecast:
        """Generate forecast.

        Args:
            fh: Forecast horizon. An integer number of steps, or a sequence
                whose length is the number of steps.
            X: Optional exogenous data (ignored; a warning is logged).
            **predict_params: Additional prediction parameters (currently ignored).

        Returns:
            Forecast object with predictions indexed on timestamps that
            continue the training index.

        Raises:
            ValueError: If ``fh`` covers fewer than one step.
        """
        self._check_is_fitted()
        if X is not None:
            logger.warning(
                "Exogenous variables (X) not yet supported for LSTMForecaster"
            )

        n_periods = self._horizon_length(fh)

        forecast_darts = self.model_.predict(n_periods)

        # Undo the scaling applied in fit().
        if self.scaler_ is not None:
            forecast_darts = self.scaler_.inverse_transform(forecast_darts)

        forecast_series = forecast_darts.to_series()

        # Re-stamp the forecast on the calendar continuing the training data.
        if len(forecast_series) == n_periods:
            forecast_series.index = self._future_index(self.train_index_, n_periods)

        return Forecast(y_pred=forecast_series, fh=fh)

    def predict_interval(
        self,
        fh: Any,
        X: Optional[Any] = None,
        coverage: float = 0.9,
        **predict_params: Any,
    ) -> Forecast:
        """Generate forecast with heuristic prediction intervals.

        The bounds are ``y_pred +/- z * 0.1 * spread``. ``z`` is the standard
        normal quantile for ``coverage``. ``spread`` is the standard deviation of
        the point forecast, or of the training data when the forecast has a
        single step (its standard deviation is then undefined). These bands are
        a placeholder, not calibrated uncertainty estimates.

        Args:
            fh: Forecast horizon.
            X: Optional exogenous data (ignored by ``predict``).
            coverage: Coverage level in the open interval (0, 1), e.g. 0.9.
            **predict_params: Passed through to ``predict``.

        Returns:
            Forecast with ``y_int`` holding ``lower``/``upper`` columns.

        Raises:
            ValueError: If ``coverage`` is not strictly between 0 and 1.
        """
        from statistics import NormalDist

        self._check_is_fitted()
        if not 0.0 < coverage < 1.0:
            raise ValueError(
                f"coverage must be in the open interval (0, 1), got {coverage!r}"
            )

        forecast = self.predict(fh, X, **predict_params)
        y_pred = forecast.y_pred

        # predict() returns a point forecast only, so the width below is a rough
        # heuristic. NOTE(review): Darts torch models can produce probabilistic
        # forecasts via a ``likelihood`` -- consider that for calibrated bands.
        spread = float(y_pred.std())
        if not np.isfinite(spread):
            # A one-step forecast has an undefined std (NaN); fall back to the
            # training-data spread recorded in fit().
            spread = getattr(self, "train_std_", float("nan"))
        forecast_std = spread * 0.1

        z_score = NormalDist().inv_cdf((1 + coverage) / 2)
        margin = z_score * forecast_std

        y_int = pd.DataFrame(
            {
                "lower": y_pred - margin,
                "upper": y_pred + margin,
            },
            index=y_pred.index,
        )
        return Forecast(y_pred=y_pred, fh=fh, y_int=y_int)