diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeTotoIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeTotoIT.java new file mode 100644 index 0000000000000..eb29dd945d95f --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeTotoIT.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeTotoIT { + + private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = + "SELECT * FROM FORECAST(" + + "model_id=>'toto', " + + "targets=>(SELECT time, s%d FROM db.AI WHERE time<%d ORDER BY time DESC LIMIT %d) ORDER BY time, " + + "output_start_time=>%d, " + + "output_length=>%d, " + + "output_interval=>%d, " + + "timecol=>'%s'" + + ")"; + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().initClusterEnvironment(1, 1); + AINodeTestUtils.prepareDataInTable(); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void forecastTableFunctionWithTotoTest() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + for (int i = 0; i < 4; i++) { + final String sql = + String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, i, 5760, 2880, 5760, 96, 1, "time"); + try (ResultSet resultSet = statement.executeQuery(sql)) { + int count = 0; + while (resultSet.next()) { + count++; + } + Assert.assertTrue(count > 0); + } + } + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java 
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index 5a4dce53666d3..f4de1d78f2744 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -58,7 +58,9 @@ public class AINodeTestUtils { new AbstractMap.SimpleEntry<>( "chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active"))) + "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "toto", new FakeModelInfo("toto", "toto", "builtin", "active"))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); public static final Map BUILTIN_MODEL_MAP; diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index f253fb1e56f60..0cb6105a52c11 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -158,4 +158,14 @@ def __repr__(self): }, transformers_registered=True, ), + "toto": ModelInfo( + model_id="toto", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="toto", + pipeline_cls="pipeline_toto.TotoPipeline", + repo_id="Datadog/Toto-Open-Base-1.0", + auto_map=None, + transformers_registered=False, + ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index 1da07cb9fef9e..be7fd27347fb0 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -39,26 +39,33 @@ from iotdb.ainode.core.model.model_info import ModelInfo from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model from iotdb.ainode.core.model.utils import 
import_class_from_path, temporary_sys_path +from iotdb.ainode.core.model.toto.inference.forecaster import TotoForecaster +from iotdb.ainode.core.model.toto.model.toto import Toto + logger = Logger() BACKEND = DeviceManager() def load_model(model_info: ModelInfo, **model_kwargs) -> Any: - if model_info.auto_map is not None: + if model_info.model_type == "toto": + model = load_toto_model(model_info, **model_kwargs) + elif model_info.auto_map is not None: model = load_model_from_transformers(model_info, **model_kwargs) else: if model_info.model_type == "sktime": model = create_sktime_model(model_info.model_id) else: model = load_model_from_pt(model_info, **model_kwargs) - + + model_device = getattr(model, "device", "cpu") logger.info( - f"Model {model_info.model_id} loaded to device {model.device if model_info.model_type != 'sktime' else 'cpu'} successfully." + f"Model {model_info.model_id} loaded to device {model_device if model_info.model_type != 'sktime' else 'cpu'} successfully." ) return model + def load_model_from_transformers(model_info: ModelInfo, **model_kwargs): device_map = model_kwargs.get("device_map", "cpu") train_from_scratch = model_kwargs.get("train_from_scratch", False) @@ -135,6 +142,19 @@ def load_model_from_pt(model_info: ModelInfo, **kwargs): logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}") return BACKEND.move_model(model, device_map) +def load_toto_model(model_info: ModelInfo, **model_kwargs): + device_map = model_kwargs.get("device_map", "cpu") + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + + model = Toto.from_pretrained(model_path) + model = BACKEND.move_model(model, device_map) + return TotoForecaster(model.model) + def load_model_for_efficient_inference(): # TODO: An efficient model loading method for inference based on model_arguments @@ -146,5 +166,6 @@ def load_model_for_powerful_finetune(): pass + def 
unload_model(): pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py new file mode 100644 index 0000000000000..a933a0f512219 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py @@ -0,0 +1,543 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. + +from dataclasses import dataclass +from typing import cast + +import numpy as np +import torch +from einops import rearrange, repeat +from gluonts.torch.distributions import AffineTransformed +from jaxtyping import Bool, Float, Int +from torch.distributions import Distribution + +from ..data.util.dataset import ( + MaskedTimeseries, + pad_array, + pad_id_mask, + replace_extreme_values, +) +from ..model.backbone import TotoBackbone + + +@dataclass(frozen=True) +class Forecast: + mean: Float[torch.Tensor, "batch variate future_time_steps"] + samples: Float[torch.Tensor, "batch variate future_time_steps samples"] | None = None + + def quantile(self, q: float | torch.Tensor) -> Float[torch.Tensor, "batch variate future_time_steps"]: + """ + Compute the quantile of the forecast samples. 
+ """ + assert self.samples is not None, "samples must be provided to compute quantiles" + assert isinstance(q, float) or isinstance(q, torch.Tensor), "q must be a float or a tensor" + if isinstance(q, float): + q = torch.tensor(q, device=self.samples.device, dtype=self.samples.dtype) + return self.samples.quantile(q, dim=-1) + + @property + def median(self) -> Float[torch.Tensor, "batch variate future_time_steps"]: + """ + The median of the forecast samples. + """ + return self.quantile(0.5) + + @property + def std(self) -> Float[torch.Tensor, "batch variate future_time_steps"]: + """ + Compute the standard deviation of the forecast samples. + """ + assert self.samples is not None, "samples must be provided to compute standard deviation" + return self.samples.std(dim=-1) + + +class TotoForecaster: + """ + A forecaster class for the Toto model that handles autoregressive decoding for time series forecasting. + + This class wraps a TotoBackbone model and provides methods to generate forecasts for time series data. + The forecasting process uses an autoregressive decoding algorithm: + + 1. The model first processes the entire input context (historical data) + 2. 
For each future time step: + - The model generates a distribution over possible values + - Either the mean or random samples are drawn from this distribution + - The generated value(s) are appended to the input sequence + - The process repeats with this extended sequence + + When generating multiple samples (num_samples > 1), the model creates separate trajectories for each sample: + - Each trajectory starts with the same historical context + - As sampling progresses, each trajectory evolves independently + - This results in num_samples different possible future paths + - Samples can be processed in batches (samples_per_batch) to manage memory usage + + The forecaster efficiently reuses computation from the context processing phase using a key-value cache, + which stores intermediate transformer attention states to avoid redundant computation. + + The forecaster handles data preprocessing, including padding to match the model's patch size, + and postprocessing to format the outputs as a Forecast object containing means and optional samples. + + Exogenous Variables: + When using exogenous variables, they MUST be placed at the END of the variates/channels dimension + in the input tensor. For example, if you have 3 target variates and 2 exogenous variables, + the input shape should be (batch, 5, time_steps) where: + - indices 0, 1, 2 correspond to target variates + - indices 3, 4 correspond to exogenous variables + + This convention allows the forecaster to correctly inject known future values of exogenous + variables during autoregressive decoding by replacing the last `num_exogenous_variables` + channels with the provided `future_exogenous_variables`. 
+ """ + + model: TotoBackbone + + def __init__( + self, + model: TotoBackbone, + ): + self.model = model + # set the model to evaluation mode + self.model.eval() + + def forecast( + self, + inputs: MaskedTimeseries, + prediction_length: int, + num_samples: int | None = None, + samples_per_batch: int = 10, + use_kv_cache: bool = True, + future_exogenous_variables: Float[torch.Tensor, "batch exogenous_variables future_time_steps"] | None = None, + ) -> Forecast: + """ + Generate a forecast for a batch of time series. This method works autoregressively, + i.e. it feeds the model's predictions back into itself. The decoding process is as follows: + + 1. The model first processes the entire input context (historical data) + 2. For each future time step: + - The model generates a distribution over possible values + - Either the mean or random samples are drawn from this distribution + - If known future exogenous variables are provided, they are injected into the samples to replace the predicted values + - The generated value(s) are appended to the input sequence + - The process repeats with this extended sequence + + There are two modes of operation: + 1. num_samples is None: generate a single mean prediction + 2. num_samples is not None: generate num_samples random samples + + When num_samples is not None, the model creates num_samples separate trajectories for each sample: + - Each trajectory starts with the same historical context + - As sampling progresses, each trajectory evolves independently + - This results in num_samples different possible future paths + - Samples can be processed in batches (samples_per_batch) to manage memory usage + + When using samples_per_batch, this batch size compounds with the optional batch dimension of the input. + For example, if you have a batch of 10 time series, and you set samples_per_batch to 10, + the effective batch size is 100. 
For the best performance, set samples_per_batch + as high as possible, subject to memory constraints. + + Args: + inputs: A MaskedTimeseries object containing the input time series. + prediction_length: The number of future time steps to predict. + num_samples: + The number of samples to generate. + If None, a single mean prediction is generated. However, + the mean point forecast tends to be less accurate than the + median or mean of the samples (provided enough samples are generated). + It's recommended to use at least 128 samples for reliable forecasts. + samples_per_batch: + The number of samples to generate per batch. + In most cases, this should be as high as possible, subject to memory constraints. + When the inputs have a batch dimension, the effective batch size is samples_per_batch * batch_size. + use_kv_cache: + Whether to use a key-value cache for the model. In most cases, this should be True, + as it significantly speeds up inference. + future_exogenous_variables: + If known future exogenous variables are provided, they are injected into the samples + to replace the predicted values after each prediction step. + + IMPORTANT: Exogenous variables MUST be placed at the END of the variates/channels + dimension in the input tensor. The forecaster assumes that the last `num_exogenous_variables` + channels in `inputs.series` correspond to exogenous variables. During autoregressive + decoding, predictions for these channels are replaced with the known future values + from `future_exogenous_variables`. 
+ """ + if len(inputs.series.shape) == 2: + # unbatched input, variates x time_steps + batch = cast(MaskedTimeseries, torch.utils.data.default_collate([inputs])) + else: + # input is already batched + batch = inputs + + if future_exogenous_variables is not None and len(future_exogenous_variables.shape) == 2: + future_exogenous_variables = future_exogenous_variables.unsqueeze(0) + + # pad the input to the nearest multiple of the patch size + series = pad_array(batch.series, self.model.patch_embed.stride) + padding_mask = pad_array(batch.padding_mask, self.model.patch_embed.stride) + id_mask = batch.id_mask + if id_mask is not None: + id_mask = pad_id_mask(batch.id_mask, self.model.patch_embed.stride) + timestamp_seconds = pad_array(batch.timestamp_seconds, self.model.patch_embed.stride) + time_interval_seconds: Int[torch.Tensor, "batch variate series_len"] = torch.as_tensor( + batch.time_interval_seconds, device=series.device, dtype=torch.int + ) + + if num_samples is not None: + samples = self.generate_samples( + inputs=series, + prediction_length=prediction_length, + num_samples=num_samples, + timestamp_seconds=timestamp_seconds, + time_interval_seconds=time_interval_seconds, + input_padding_mask=padding_mask, + id_mask=id_mask, + sampling_batch_size=samples_per_batch, + use_kv_cache=use_kv_cache, + future_exogenous_variables=future_exogenous_variables, + num_exogenous_variables=batch.num_exogenous_variables, + ) + mean = samples.mean(dim=-1) + else: + mean = self.generate_mean( + inputs=series, + prediction_length=prediction_length, + timestamp_seconds=timestamp_seconds, + time_interval_seconds=time_interval_seconds, + input_padding_mask=padding_mask, + id_mask=id_mask, + use_kv_cache=use_kv_cache, + future_exogenous_variables=future_exogenous_variables, + num_exogenous_variables=batch.num_exogenous_variables, + ) + samples = None + + return Forecast(mean=mean, samples=samples) + + def assert_ev_compatibility( + self, + inputs: Float[torch.Tensor, "batch 
total_variate patch_time_steps"], + future_exogenous_variables: Float[torch.Tensor, "batch exogenous_variables future_time_steps"], + prediction_length: int, + num_exogenous_variables: int, + ) -> None: + """ + Assert the compatibility of the future exogenous variables with the input. + """ + assert ( + future_exogenous_variables.shape[-1] == prediction_length + ), "The future exogenous variables must have the same length as the prediction length" + assert ( + future_exogenous_variables.shape[0] == inputs.shape[0] + ), "The future exogenous variables must have the same batch size as the input" + assert ( + num_exogenous_variables == future_exogenous_variables.shape[-2] + ), "The number of exogenous variables must match the number of exogenous variables in the future_exogenous_variables" + + def round_ft_ev( + self, + future_exogenous_variables: Float[torch.Tensor, "batch exogenous_variables future_time_steps"], + T_rounded: int, + ) -> Float[torch.Tensor, "batch exogenous_variables rounded_steps"]: + # add padding to the future exogenous variables to the nearest multiple of the patch size + B, V_ev, T_future = future_exogenous_variables.shape + dtype = future_exogenous_variables.dtype + device = future_exogenous_variables.device + padding = torch.zeros(B, V_ev, T_rounded - T_future, device=device, dtype=dtype) + padded_future_exogenous_variables = torch.cat([future_exogenous_variables, padding], dim=-1) + return padded_future_exogenous_variables + + @torch.no_grad() + def generate_mean( + self, + inputs: Float[torch.Tensor, "batch variate time_steps"], + prediction_length: int, + timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"], + time_interval_seconds: Int[torch.Tensor, "batch variate"], + input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"] | None = None, + id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None = None, + use_kv_cache: bool = False, + future_exogenous_variables: Float[torch.Tensor, "batch exogenous_variables 
future_time_steps"] | None = None, + num_exogenous_variables: int = 0, + ) -> Float[torch.Tensor, "batch variate time_steps"]: + """ + Generate a point prediction by taking the mean of the output distribution at each step. + This method works autoregressively, i.e. it feeds the model's predictions back into itself + to generate the next prediction. + + If future exogenous variables are provided, they are injected into the samples to replace + the predicted values after each prediction step. + + Note: + Exogenous variables MUST be placed at the END of the variates/channels dimension. + The last `num_exogenous_variables` channels are assumed to be exogenous variables + and will be replaced with values from `future_exogenous_variables` during decoding. + """ + if input_padding_mask is None: + input_padding_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device) + if id_mask is None: + id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device) + + # Assert the compatibility of the future exogenous variables with the input + if future_exogenous_variables is not None: + self.assert_ev_compatibility(inputs, future_exogenous_variables, prediction_length, num_exogenous_variables) + + ## round up the prediction length to the nearest multiple of the patch size + patch_size = self.model.patch_embed.stride + rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size) + if rounded_steps > prediction_length and future_exogenous_variables is not None: + future_exogenous_variables = self.round_ft_ev(future_exogenous_variables, rounded_steps) + start_index = inputs.shape[-1] + end_index = start_index + prediction_length + + # TODO: maybe pass in future masks, rather than making assumptions here? 
+ dummy_padding = torch.ones( + (input_padding_mask.shape[0], input_padding_mask.shape[1], patch_size), + device=inputs.device, + dtype=torch.bool, + ) + dummy_id_mask = repeat( + id_mask[:, :, -1:], + "batch variates 1 -> batch variates patch_size", + patch_size=patch_size, + ) + if use_kv_cache: + kv_cache = self.model.allocate_kv_cache( + batch_size=inputs.shape[0], + num_variates=inputs.shape[1], + max_time_steps=inputs.shape[2] + rounded_steps, + device=inputs.device, + dtype=inputs.dtype, + ) + else: + kv_cache = None + + scaling_prefix_length = inputs.shape[-1] + + for idx in range(rounded_steps // patch_size): + base_distr, loc, scale = self.model( + inputs=inputs, + input_padding_mask=input_padding_mask, + id_mask=id_mask, + kv_cache=kv_cache, + scaling_prefix_length=scaling_prefix_length, + num_exogenous_variables=num_exogenous_variables, + ) + distr = self.create_affine_transformed(base_distr, loc, scale) + + # We remove extreme values that can occur early in training + # and cause validation metrics to be NaN + samples = replace_extreme_values(distr.mean[:, :, -patch_size:]) + + # If future exogenous variables are provided, inject them into the samples to replace + # the predicted values. Note: exogenous variables are assumed to be the LAST channels + # in the variates dimension, hence we use `[:, -num_exogenous_variables:]` indexing. 
+ if future_exogenous_variables is not None: + start, stop = idx * patch_size, (idx + 1) * patch_size + samples[:, -num_exogenous_variables:] = future_exogenous_variables[:, :, start:stop] + + inputs = torch.cat([inputs, samples], dim=-1) + id_mask = torch.cat([id_mask, dummy_id_mask], dim=-1) + input_padding_mask = torch.cat([input_padding_mask, dummy_padding], dim=-1) + for _ in range(patch_size): + next_timestamp: Int[torch.Tensor, "batch variate"] = timestamp_seconds[:, :, -1] + time_interval_seconds + timestamp_seconds = torch.cat([timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1) + + return inputs.detach()[:, :, start_index:end_index] + + @torch.no_grad() + def generate_samples( + self, + inputs: Float[torch.Tensor, "batch variate time_steps"], + prediction_length: int, + num_samples: int, + timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"], + time_interval_seconds: Int[torch.Tensor, "batch variate"], + input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"] | None = None, + id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None = None, + sampling_batch_size: int = 10, + use_kv_cache: bool = False, + future_exogenous_variables: Float[torch.Tensor, "batch exogenous_variables future_time_steps"] | None = None, + num_exogenous_variables: int = 0, + ) -> Float[torch.Tensor, "batch variate time_steps samples"]: + """ + Generate samples from the output distribution. + This method works autoregressively, i.e. it feeds the model's predictions back into itself. + It works by creating num_samples chains. Each chain is a separate sequence of predictions. + At each time step, for each chain we take a single sample from the output distribution and append + it to the end of the sequence. + + If future exogenous variables are provided, they are injected into the samples to replace + the predicted values after each prediction step. + + Note: + Exogenous variables MUST be placed at the END of the variates/channels dimension. 
+ The last `num_exogenous_variables` channels are assumed to be exogenous variables + and will be replaced with values from `future_exogenous_variables` during decoding. + """ + if input_padding_mask is None: + input_padding_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device) + if id_mask is None: + id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device) + + if future_exogenous_variables is not None: + self.assert_ev_compatibility(inputs, future_exogenous_variables, prediction_length, num_exogenous_variables) + + assert num_samples % sampling_batch_size == 0, "num_samples must be divisible by sampling_batch_size" + num_batches = num_samples // sampling_batch_size + + # round up the prediction length to the nearest multiple of the patch size + patch_size = self.model.patch_embed.patch_size + rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size) + if rounded_steps > prediction_length and future_exogenous_variables is not None: + future_exogenous_variables = self.round_ft_ev(future_exogenous_variables, rounded_steps) + start_index = inputs.shape[-1] + end_index = start_index + prediction_length + + dummy_padding = torch.ones( + ( + input_padding_mask.shape[0] * sampling_batch_size, + input_padding_mask.shape[1], + patch_size, + ), + dtype=torch.bool, + device=inputs.device, + ) + + dummy_id_mask = repeat( + id_mask[:, :, -1:], + "batch variates 1 -> (sampling_batch_size batch) variates patch_size", + sampling_batch_size=sampling_batch_size, + patch_size=patch_size, + ) + inputs = repeat( + inputs, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + if future_exogenous_variables is not None: + future_exogenous_variables = repeat( + future_exogenous_variables, + "batch exogenous_variables future_time_steps -> (sampling_batch_size batch) exogenous_variables future_time_steps", + sampling_batch_size=sampling_batch_size, + ) + input_padding_mask 
= repeat( + input_padding_mask, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + id_mask = repeat( + id_mask, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + timestamp_seconds = repeat( + timestamp_seconds, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + time_interval_seconds = repeat( + time_interval_seconds, + "batch variates -> (sampling_batch_size batch) variates", + sampling_batch_size=sampling_batch_size, + ) + + all_samples = [] + if use_kv_cache: + kv_cache = self.model.allocate_kv_cache( + batch_size=inputs.shape[0], + num_variates=inputs.shape[1], + max_time_steps=inputs.shape[2] + rounded_steps, + device=inputs.device, + dtype=inputs.dtype, + ) + else: + kv_cache = None + + scaling_prefix_length = inputs.shape[-1] + + for _ in range(num_batches): + batch_inputs = torch.clone(inputs) + batch_input_padding_mask = torch.clone(input_padding_mask) + batch_id_mask = torch.clone(id_mask) + batch_timestamp_seconds = torch.clone(timestamp_seconds) + + for idx in range(rounded_steps // patch_size): + base_distr, loc, scale = self.model( + inputs=batch_inputs, + input_padding_mask=batch_input_padding_mask, + id_mask=batch_id_mask, + kv_cache=kv_cache, + scaling_prefix_length=scaling_prefix_length, + num_exogenous_variables=num_exogenous_variables, + ) + distr = self.create_affine_transformed(base_distr, loc, scale) + + sample = distr.sample() + assert sample is not None + + # We remove extreme values that can occur early in training + # and cause validation metrics to be NaN + samples = replace_extreme_values(sample[:, :, -patch_size:]) + + # If future exogenous variables are provided, inject them into the samples to replace + # the predicted values. 
Note: exogenous variables are assumed to be the LAST channels + # in the variates dimension, hence we use `[:, -num_exogenous_variables:]` indexing. + if future_exogenous_variables is not None: + start, stop = idx * patch_size, (idx + 1) * patch_size + samples[:, -num_exogenous_variables:] = future_exogenous_variables[:, :, start:stop] + batch_inputs = torch.cat([batch_inputs, samples], dim=-1) + batch_id_mask = torch.cat([batch_id_mask, dummy_id_mask], dim=-1) + batch_input_padding_mask = torch.cat([batch_input_padding_mask, dummy_padding], dim=-1) + for _ in range(patch_size): + next_timestamp = batch_timestamp_seconds[:, :, -1] + time_interval_seconds + batch_timestamp_seconds = torch.cat([batch_timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1) + all_samples.append(batch_inputs) + if kv_cache is not None: + kv_cache.reset() + + outputs = torch.cat(all_samples, dim=0) + unfolded_outputs = rearrange( + outputs, + "(samples batch) variates seq_len -> batch variates seq_len samples", + samples=num_samples, + ).detach() + + trimmed_predictions = unfolded_outputs[:, :, start_index:end_index, :] + return trimmed_predictions + + @staticmethod + def create_affine_transformed(base_distr: Distribution, loc: torch.Tensor, scale: torch.Tensor) -> Distribution: + """ + Creates an AffineTransformed distribution with correctly matched shapes. + + Handles three cases: + 1. When loc/scale are per-timestep (from CausalStdMeanScaler) + 2. When base_distr only contains the distribution for the latest patch + while loc/scale contain values for the entire sequence + 3. 
When loc/scale have a single time step (from StdMeanScaler/StdMinScaler) + and need to be broadcast to match a multi-step base distribution + + Args: + base_distr: The base distribution to transform + loc: Location parameter + scale: Scale parameter + + Returns: + An AffineTransformed distribution with properly handled shapes + """ + # Get the shape of the base distribution + # We'll use this to match the time dimension of loc/scale + base_shape = base_distr.mean.shape + + base_time_dim = base_shape[-1] # Time dimension of base distribution + loc_time_dim = loc.shape[-1] # Time dimension of loc + + if loc_time_dim == 1: + # Case 1: If loc/scale have time dimension 1 (standard scalers), PyTorch broadcasting will handle it + return AffineTransformed(base_distr, loc=loc, scale=scale) + + # Case 2: If loc/scale have time dimension > 1 (causal scaler with history) + # We need to extract only the suffix that matches the base distribution + return AffineTransformed(base_distr, loc=loc[:, :, -base_time_dim:], scale=scale[:, :, -base_time_dim:]) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py new file mode 100644 index 0000000000000..a9996121149ec --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py @@ -0,0 +1,241 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. 
+ +import logging +import warnings +from enum import Enum +from typing import TYPE_CHECKING, Optional, Union + +import torch +from einops import rearrange +from jaxtyping import Bool, Float, Int + +from ..model.rope import TimeAwareRotaryEmbedding + +if TYPE_CHECKING: + from ..model.util import KVCache # Import only for type checking + +log = logging.getLogger(__name__) + +try: + from xformers.ops import LowerTriangularMask, memory_efficient_attention + + XFORMERS_AVAILABLE = True + log.info("xFormers Memory-Efficient Attention available.") +except ImportError: + warnings.warn( + "xFormers Memory-Efficient Attention not available. " + "Falling back to native PyTorch scaled_dot_product_attention.", + ImportWarning, + ) + + XFORMERS_AVAILABLE = False + +from torch.nn.functional import scaled_dot_product_attention + + +class AttentionAxis(Enum): + TIME = 1 + SPACE = 2 + + +class BaseMultiheadAttention(torch.nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float, + rotary_emb: Optional[TimeAwareRotaryEmbedding], + use_memory_efficient_attention: bool, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads." + self.head_dim = embed_dim // num_heads + self.rotary_emb = rotary_emb + + # We allocate a single tensor for the q, k, and v projection matrices, + # multiply them with the inputs, and then split the projected tensors into q, k, and v using unbind. + # This reduces overhead a bit vs. having multiple separate Linear layers, + # which need to be initialized, tracked by the optimizer, etc. 
+ self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3) + self.dropout = dropout + self.use_memory_efficient_attention = use_memory_efficient_attention + self.wO = torch.nn.Linear(embed_dim, embed_dim) + + assert not ( + not XFORMERS_AVAILABLE and self.use_memory_efficient_attention + ), "XFORMERS_AVAILABLE is False, so use_memory_efficient_attention must be False" + + if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE): + raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.") + + def rearrange_inputs( + self, inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"] + ) -> Float[torch.Tensor, "... embed_dim"]: + + pattern = ( + "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim" + if self.attention_axis == AttentionAxis.TIME + else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim" + ) + + return rearrange(inputs, pattern) + + def get_qkv( + self, + inputs: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + + if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention: + pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim" + elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention: + pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim" + elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention: + pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim" + elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention: + pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim" + + qkv = self.wQKV(inputs.contiguous()) + return rearrange(qkv, pattern, qkv=3, head_dim=self.head_dim, 
n_heads=self.num_heads).unbind(dim=0) + + def positional_embedding(self, q, k, v, kv_cache, layer_idx): + + # Apply the rotary embeddings + seq_pos_offset = 0 + if self.rotary_emb is not None and self.attention_axis == AttentionAxis.TIME: + + if kv_cache is not None: + seq_pos_offset = kv_cache.seq_len(layer_idx) + + # We need to permute because rotary embeddings expect the sequence dimension to be the second-to-last dimension + q, k = self.rotary_emb.rotate_queries_and_keys(q, k, seq_pos_offset=seq_pos_offset) + + if kv_cache is not None and self.attention_axis == AttentionAxis.TIME: + # First, we append the current input key and value tensors to the cache. + # This concatenates the current key and value tensors to the existing key and value tensors + kv_cache.append(layer_idx, (k, v)) + # Then, we retrieve the key and value tensors from the cache. + # This includes all the key and value tensors from previous time steps + # as well as the current time step. + k, v = kv_cache[layer_idx] + + q = q.contiguous() + k = k.contiguous().to(q.dtype) # Ensure k is the same dtype as q; this is necessary when using mixed precision + v = v.contiguous().to(q.dtype) # Ensure v is the same dtype as q; this is necessary when using mixed precision + + return q, k, v, seq_pos_offset + + def rearrange_output( + self, output: torch.Tensor, batch: int, variate: int, seq_len: int + ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: + if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention: + pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)" + elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention: + pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)" + elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention: + pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len 
(n_heads head_dim)" + elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention: + pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)" + + return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len) + + def run_attention(self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate): + # Determine dimension ranges for attention + # Ensure the last query vector index is used from the cache + q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len + kv_dim_start, kv_dim_end = 0, v.shape[1] if self.use_memory_efficient_attention else v.shape[2] + if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention: + attention_mask = ( + attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end] + if torch.is_tensor(attention_mask) + else LowerTriangularMask() if seq_pos_offset == 0 else None + ) + return memory_efficient_attention(q, k, v, attn_bias=attention_mask, p=dropout) + elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention: + attention_mask = ( + attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end] + if torch.is_tensor(attention_mask) + else None + ) + return scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + dropout_p=dropout, + is_causal=(attention_mask is None and seq_pos_offset == 0), + ) + elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention: + # We don't use causal masking for space-wise attention + attention_mask = ( + attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end] + if torch.is_tensor(attention_mask) + else None + ) + return memory_efficient_attention(q, k, v, attn_bias=attention_mask, p=dropout) + elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention: + # We don't use causal masking for space-wise attention + attention_mask = ( + 
attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end] + if torch.is_tensor(attention_mask) + else None + ) + return scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False) + + def forward( + self, + layer_idx: int, + inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"], + attention_mask: Optional[ + Union[ + Bool[torch.Tensor, "batch_X_variate n_heads seq_len seq_len"], # Time-wise mask + Bool[torch.Tensor, "batch_X_seq_len n_heads variate variate"], # Space-wise mask + ] + ] = None, + kv_cache: Optional["KVCache"] = None, + ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: + batch_size, variate, seq_len, _ = inputs.shape + dropout = self.dropout if self.training else 0.0 + + rearranged_inputs = self.rearrange_inputs(inputs) + q, k, v = self.get_qkv(rearranged_inputs) + + q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx) + + output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate) + + output = self.rearrange_output(output, batch_size, variate, seq_len) + return self.wO(output) + + +class TimeWiseMultiheadAttention(BaseMultiheadAttention): + """ + Computes standard multihead causal attention over the time axis. + It does this by flattening out the variates along the batch dimension. + It also applies rotary position embeddings to the query and key matrices + in order to incorporate relative positional information. + """ + + attention_axis = AttentionAxis.TIME + + +class SpaceWiseMultiheadAttention(BaseMultiheadAttention): + """ + Computes bidirectional multihead attention over the space axis (i.e. across variates within + a multi-variate time series). This is done by flattening out the time axis along the batch dimension. + This allows the model to attend to different variates at the same time point. 
By alternating + between time-wise and space-wise attention, the model can learn both temporal and cross-variate + dependencies in the data. + + Unlike with time-wise attention, don't apply rotary embeddings here + because we want cross-variate attention to be invariant to the order of the variates. + """ + + attention_axis = AttentionAxis.SPACE + + +MultiHeadAttention = TimeWiseMultiheadAttention | SpaceWiseMultiheadAttention diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py new file mode 100644 index 0000000000000..ba4ded9a14020 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py @@ -0,0 +1,300 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. + +from math import ceil +from typing import NamedTuple, Optional, Type, cast + +import torch +from einops import rearrange, repeat +from jaxtyping import Bool, Float, Int + +from ..model.distribution import DISTRIBUTION_CLASSES_LOOKUP, DistributionOutput +from ..model.embedding import PatchEmbedding +from ..model.scaler import scaler_types +from ..model.transformer import Transformer +from ..model.util import KVCache +from .fusion import Fusion + + +class TotoOutput(NamedTuple): + """ + Output of the Toto model. Contains the output distribution, the location parameters, + and the scale parameters. + """ + + distribution: torch.distributions.Distribution + loc: Float[torch.Tensor, "batch variate"] + scale: Float[torch.Tensor, "batch variate"] + + +class TotoBackbone(torch.nn.Module): + """ + Toto (Timeseries-Optimized Transformer for Observability) is a transformer-based model for multivariate + time series forecasting. 
It applies a patch embedding to the input data, followed by a transformer + that alternates between time-wise and space-wise attention. The transformer is followed by a linear projection + that maps the transformer output to the output distribution. + + The output distribution can be a single distribution (e.g. Gaussian) or a mixture of distributions. + If a mixture of distributions is used, the model will learn to predict the mixture weights + as well as the parameters of the individual distributions. + + Parameters + ---------- + patch_size + Size of the patch to use for the patch embedding. + stride + Stride to use for the patch embedding. + embed_dim + Dimension of the model's latent space. + num_layers + Number of transformer layers to use. + num_heads + Number of attention heads to use in each self-attention layer. + mlp_hidden_dim + Dimension of the hidden layer in the feedforward network. + dropout + Dropout rate to use in the model. + spacewise_every_n_layers + How many time-wise transformer layers to apply between each space-wise transformer layer. + spacewise_first + Whether to apply space-wise attention before time-wise attention. + scaler_cls + Class to use for scaling the input data. + output_distribution_classes + List of classes to use for the output distribution. If a single class is provided, the model + will output a single distribution. If multiple classes are provided, the model will output a + learned mixture of distributions. + output_distribution_kwargs + Keyword arguments to pass to the output distribution class. Note: this currently only works + with a single output distribution class. + use_memory_efficient_attention: + Whether to use memory-efficient attention. If True, the model will use the memory-efficient from xFormers. + stabilize_with_global: + Whether to use global statistics to stabilize causal statistics by clamping extreme values. Only applies to causal scalers. 
+ scale_factor_exponent: + Exponent that controls the allowed range of deviation from global scale for causal scalers. + """ + + def __init__( + self, + patch_size: int, + stride: int, + embed_dim: int, + num_layers: int, + num_heads: int, + mlp_hidden_dim: int, + dropout: float, + spacewise_every_n_layers: int, + scaler_cls: str, + output_distribution_classes: list[str], + spacewise_first: bool = True, + output_distribution_kwargs: dict | None = None, + use_memory_efficient_attention: bool = True, + stabilize_with_global: bool = True, + scale_factor_exponent: float = 10.0, + ): + super().__init__() + self.embed_dim = embed_dim + # Attributes for variate-label fusion (initialized when enable_variate_labels is called) + self.fusion: Optional[Fusion] = None + self.num_prepended_tokens: int = 0 + self.target_variate_label: Optional[torch.nn.Parameter] = None + self.exogenous_variate_label: Optional[torch.nn.Parameter] = None + # strings are used when loading a safetensors checkpoint + # Initialize patch-based scalers with the correct patch_size + if scaler_cls == "": + self.scaler = scaler_types[scaler_cls]( + patch_size=patch_size, + stabilize_with_global=stabilize_with_global, + scale_factor_exponent=scale_factor_exponent, + ) + else: + self.scaler = scaler_types[scaler_cls]() + + self.patch_embed = PatchEmbedding(patch_size, stride, embed_dim) + self.dropout = dropout + self.num_layers = num_layers + self.use_memory_efficient_attention = use_memory_efficient_attention + self.transformer = Transformer( + embed_dim=embed_dim, + num_heads=num_heads, + num_layers=self.num_layers, + mlp_hidden_dim=mlp_hidden_dim, + dropout=dropout, + spacewise_every_n_layers=spacewise_every_n_layers, + spacewise_first=spacewise_first, + use_memory_efficient_attention=self.use_memory_efficient_attention, + fusion=self.fusion, + ) + self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size) + + # TODO[BEN] this doesn't need to be a list + output_distribution_classes_ = 
[DISTRIBUTION_CLASSES_LOOKUP[c] for c in output_distribution_classes] + self.output_distribution = output_distribution_classes_[0](embed_dim, **(output_distribution_kwargs or {})) + + def allocate_kv_cache( + self, + batch_size: int, + num_variates: int, + max_time_steps: int, + device: torch.device, + dtype: torch.dtype, + ) -> KVCache: + return KVCache( + batch_size=batch_size, + num_variates=num_variates, + transformer_layers=list(self.transformer.layers), + num_layers=self.num_layers, + embed_dim=self.embed_dim, + num_heads=cast(int, self.transformer.layers[0].num_heads), + max_seq_len=ceil(max_time_steps / self.patch_embed.stride), + device=device, + dtype=dtype, + use_memory_efficient_attention=self.use_memory_efficient_attention, + ) + + def backbone( + self, + inputs: Float[torch.Tensor, "batch variate time_steps"], + input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"], + id_mask: Float[torch.Tensor, "batch #variate time_steps"], + kv_cache: Optional[KVCache] = None, + scaling_prefix_length: Optional[int] = None, + num_exogenous_variables: int = 0, + ) -> tuple[ + Float[torch.Tensor, "batch variates time_steps embed_dim"], + Float[torch.Tensor, "batch variates time_steps"], + Float[torch.Tensor, "batch variates time_steps"], + ]: + scaled_inputs: Float[torch.Tensor, "batch variate time_steps"] + loc: Float[torch.Tensor, "batch variate time_steps"] + scale: Float[torch.Tensor, "batch variate time_steps"] + + # Standard scaling operation, same API but without ID mask. + scaled_inputs, loc, scale = self.scaler( + inputs, + weights=torch.ones_like(inputs, device=inputs.device), + padding_mask=input_padding_mask, + prefix_length=scaling_prefix_length, + ) + + if kv_cache is not None: + # Account for prepended condition tokens when using KV cache. + # Cached length counts prepended tokens; do not overcount when computing time-series prefix. 
+ kv_cache_len_tensor = kv_cache.current_len(0) + kv_cache_len = ( + int(kv_cache_len_tensor) if isinstance(kv_cache_len_tensor, torch.Tensor) else kv_cache_len_tensor + ) + prefix_len = max(0, self.patch_embed.stride * (kv_cache_len - self.num_prepended_tokens)) + + # Truncate inputs so that the transformer only processes + # the last patch in the sequence. We'll use the KVCache + # for the earlier patches. + scaled_inputs = scaled_inputs[:, :, prefix_len:] + + # As a simplification, when using kv cache we only allow decoding + # one step at a time after the initial forward pass. + assert (prefix_len == 0) or ( + scaled_inputs.shape[-1] == self.patch_embed.stride + ), "Must decode one step at a time." + + input_padding_mask = input_padding_mask[:, :, prefix_len:] + id_mask = id_mask[:, :, prefix_len:] + + embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"] + reduced_id_mask: Float[torch.Tensor, "batch variate seq_len"] + + embeddings, reduced_id_mask = self.patch_embed(scaled_inputs, id_mask) + + # Build variate label embeddings (one per variate) if enabled + variate_label_embeds = self.build_variate_label_embeds(num_exogenous_variables, embeddings) + + # Apply the transformer on the embeddings (fusion handles prepending at layer 0) + original_seq_len = embeddings.shape[2] + transformed: Float[torch.Tensor, "batch variates seq_len embed_dim"] = self.transformer( # type: ignore[assignment] + embeddings, reduced_id_mask, kv_cache, variate_label_embeds=variate_label_embeds + ) + # Crop out the prepended tokens before unembedding + added_tokens = transformed.shape[2] - original_seq_len + if added_tokens > 0: + transformed = transformed[:, :, added_tokens:] + + # Unembed and flatten the sequence + flattened: Float[torch.Tensor, "batch variates new_seq_len embed_dim"] = rearrange( + self.unembed(transformed), + "batch variates seq_len (patch_size embed_dim) -> batch variates (seq_len patch_size) embed_dim", + embed_dim=self.embed_dim, + ) + return 
flattened, loc, scale + + def forward( + self, + inputs: Float[torch.Tensor, "batch variate time_steps"], + input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"], + id_mask: Float[torch.Tensor, "batch #variate time_steps"], + kv_cache: Optional[KVCache] = None, + scaling_prefix_length: Optional[int] = None, + num_exogenous_variables: int = 0, + ) -> TotoOutput: + flattened, loc, scale = self.backbone( + inputs, + input_padding_mask, + id_mask, + kv_cache, + scaling_prefix_length, + num_exogenous_variables, + ) + + return TotoOutput(self.output_distribution(flattened), loc, scale) + + @property + def device(self): + return next(self.parameters()).device + + def enable_variate_labels(self) -> None: + """ + Enable variate labels for exogenous feature differentiation. + Called automatically when using exogenous features during finetuning. + - Creates trainable label parameters for target and exogenous variates + - Enables fusion by installing a Fusion module + """ + self.fusion = Fusion() + self.num_prepended_tokens = 1 + self.target_variate_label = torch.nn.Parameter(torch.randn(self.embed_dim)) + self.exogenous_variate_label = torch.nn.Parameter(torch.randn(self.embed_dim)) + # If transformer already exists (e.g., loaded from checkpoint), update it as well + if hasattr(self, "transformer") and self.transformer is not None: + self.transformer.fusion = self.fusion + + def build_variate_label_embeds( + self, + num_exogenous_variables: int, + embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"], + ) -> Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]]: + """ + Build per-variate label embeddings for fusion. + The last num_exogenous_variables variates are treated as exogenous and receive the exogenous label. + Returns None when variate labels are not enabled. 
+ """ + if self.fusion is None: + return None + + assert self.target_variate_label is not None + assert self.exogenous_variate_label is not None + + batch_size, num_variates, _, _ = embeddings.shape + + target_variate_label = repeat(self.target_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates).to( + device=embeddings.device, dtype=embeddings.dtype + ) + exogenous_variate_label = repeat(self.exogenous_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates).to( + device=embeddings.device, dtype=embeddings.dtype + ) + # Build exog_mask from num_exogenous_variables: last num_exogenous_variables variates are exogenous + exog_mask = torch.zeros(1, num_variates, 1, 1, dtype=torch.bool, device=embeddings.device) + if num_exogenous_variables > 0: + exog_mask[:, -num_exogenous_variables:] = True + # Select per-variate label: target label for genuine targets, exogenous label for EV channels + return torch.where(exog_mask, exogenous_variate_label, target_variate_label) # (B, V, 1, D) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py new file mode 100644 index 0000000000000..5237afd8b34bf --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py @@ -0,0 +1,76 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. 
+ +from abc import ABC + +import torch +import torch.nn.functional as F # noqa: flake8 N812 lowercase 'torch.nn.functional' imported as non lowercase 'F' +from gluonts.torch.distributions import AffineTransformed +from gluonts.torch.distributions.studentT import StudentT + + +class DistributionOutput(ABC, torch.nn.Module): + pass + + +class StudentTOutput(DistributionOutput): + def __init__(self, embed_dim): + super().__init__() + self.embed_dim = embed_dim + self.df = torch.nn.Linear(embed_dim, 1) + self.loc_proj = torch.nn.Linear(embed_dim, 1) + self.scale_proj = torch.nn.Linear(embed_dim, 1) + + def forward(self, inputs, loc=None, scale=None): + eps = torch.finfo(inputs.dtype).eps + df = 2.0 + F.softplus(self.df(inputs)).clamp_min(eps).squeeze(-1) + base_loc = self.loc_proj(inputs).squeeze(-1) + base_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps).squeeze(-1) + + base_dist = torch.distributions.StudentT(df, base_loc, base_scale, validate_args=False) + + if loc is not None and scale is not None: + return AffineTransformed( + base_dist, + loc=loc, + scale=scale, + ) + return base_dist + + +class MixtureOfStudentTsOutput(DistributionOutput): + def __init__( + self, + embed_dim, + k_components, + ): + super().__init__() + self.embed_dim = embed_dim + self.k_components = k_components + + self.df = torch.nn.Linear(embed_dim, k_components) + self.loc_proj = torch.nn.Linear(embed_dim, k_components) + self.scale_proj = torch.nn.Linear(embed_dim, k_components) + self.mixture_weights = torch.nn.Linear(embed_dim, k_components) + + def forward(self, inputs, loc=None, scale=None): + df = 2.0 + F.softplus(self.df(inputs)).clamp_min(torch.finfo(inputs.dtype).eps) + loc = self.loc_proj(inputs) + scale = F.softplus(self.scale_proj(inputs)).clamp_min(torch.finfo(inputs.dtype).eps) + logits = self.mixture_weights(inputs) + probs = F.softmax(logits, dim=-1) + components = StudentT(df, loc, scale) + mixture_distribution = torch.distributions.Categorical(probs=probs) + + 
return torch.distributions.MixtureSameFamily( + mixture_distribution, + components, + ) + + +DISTRIBUTION_CLASSES_LOOKUP = { + "": StudentTOutput, + "": MixtureOfStudentTsOutput, +} diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py new file mode 100644 index 0000000000000..3ab063955223d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py @@ -0,0 +1,61 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. + +from typing import Optional + +import torch +from jaxtyping import Float, Int, Num + + +def patchify_id_mask( + id_mask: Int[torch.Tensor, "batch variate time_steps"], patch_size: int +) -> Int[torch.Tensor, "batch variate seq_len patch_size"]: + patched_id_mask = id_mask.unfold(dimension=-1, size=patch_size, step=patch_size) + patched_id_mask_min = patched_id_mask.min(-1).values + patched_id_mask_max = patched_id_mask.max(-1).values + assert torch.eq(patched_id_mask_min, patched_id_mask_max).all(), "Patches cannot span multiple datasets" + return patched_id_mask_min + + +class PatchEmbedding(torch.nn.Module): + """ + Multivariate time series patch embedding. + Patchifies each variate separately. 
+ """ + + def __init__(self, patch_size: int, stride: int, embed_dim: int): + super().__init__() + self.patch_size = patch_size + self.embed_dim = embed_dim + self.stride = stride + self.projection = torch.nn.Linear(self.patch_size, self.embed_dim) + + def _patchify( + self, x: Num[torch.Tensor, "batch variate time_steps"] + ) -> Num[torch.Tensor, "batch variate seq_len patch_size"]: + return x.unfold(dimension=-1, size=self.patch_size, step=self.stride) + + def forward( + self, + x: Float[torch.Tensor, "batch #variate time_steps"], + id_mask: Float[torch.Tensor, "batch time_steps"], + ) -> tuple[ + Float[torch.Tensor, "batch variate seq_len embed_dim"], + Int[torch.Tensor, "batch seq_len"], + ]: + assert ( + x.shape[-1] % self.patch_size == 0 + ), f"Series length ({x.shape=}) must be divisible by ({self.patch_size=})" + x_patched: Float[torch.Tensor, "batch variate seq_len patch_size"] = self._patchify(x) + id_mask_patched: Int[torch.Tensor, "batch variate seq_len patch_size"] = self._patchify(id_mask) + + assert torch.eq( + id_mask_patched.min(-1).values, id_mask_patched.max(-1).values + ).all(), "Patches cannot span multiple datasets" + + return ( + self.projection(x_patched), + id_mask_patched.min(-1).values, + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py new file mode 100644 index 0000000000000..27ea6da21f975 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py @@ -0,0 +1,19 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. 
+
+import torch
+import torch.nn.functional as F
+
+
+class SwiGLU(torch.nn.Module):
+    """
+    SwiGLU activation (https://arxiv.org/abs/2002.05202): chunks the last dim into
+    (gate, x) halves and returns silu(gate) * x.
+    NOTE: x should be 2x the size you want — the last dimension is halved on output.
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Note this ordering is unusual, but is done so to match xFormers
+        gate, x = x.chunk(2, dim=-1)
+        return F.silu(gate) * x
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py
new file mode 100644
index 0000000000000..216b5e061f75b
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py
@@ -0,0 +1,37 @@
+"""
+This module implements the fusion of variate label embeddings with input embeddings in the TOTO model.
+It prepends trainable variate label embeddings, allowing the model to distinguish between target and exogenous input features.
+"""
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float
+
+
+class Fusion(torch.nn.Module):
+    """
+    Prepends variate label embeddings to the input embeddings along the sequence dimension.
+ """ + + def __init__(self) -> None: + super().__init__() + + def forward( + self, + embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"], + variate_label_embeds: Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]] = None, + ) -> Float[torch.Tensor, "batch variate new_seq_len embed_dim"]: + + # Nothing to fuse + if variate_label_embeds is None: + return embeddings + + processed_embeddings = F.normalize(variate_label_embeds, p=2, dim=-1) + + # Prepend along sequence dimension + return torch.cat( + [processed_embeddings.to(dtype=embeddings.dtype, device=embeddings.device, non_blocking=True), embeddings], + dim=2, + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py new file mode 100644 index 0000000000000..7b99e7fd8500a --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py @@ -0,0 +1,98 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. + +from typing import Optional + +import torch +from einops import rearrange +from jaxtyping import Int +from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb +from rotary_embedding_torch.rotary_embedding_torch import default + + +def exists(val): + return val is not None + + +class TimeAwareRotaryEmbedding(RotaryEmbedding): + """ + A variant of the rotary position embedding that (optionally) uses the time index + to compute the sinusoidal and cosine embeddings. This is useful for + time series data, where the time index is the most important positional + information. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # If the parent stored `freqs` as a Parameter, remove it and register as a buffer + # Register buffer is needed for sharding with FSDP + if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter): + # Extract the underlying Tensor + freqs_data = self.freqs.data + + # Remove `freqs` from the module's parameters + self._parameters.pop("freqs") + + # Register as non-persistent buffer + self.register_buffer("freqs", freqs_data, persistent=False) + + def rotate_queries_and_keys( + self, + q: torch.Tensor, + k: torch.Tensor, + seq_dim: Optional[int] = None, + seq_pos: Optional[Int[torch.Tensor, "... seq_len"]] = None, + seq_pos_offset: int = 0, + ): + """ + This method is the same as the one on the base class, except it allows you to override + the sequence position tensor with a custom one. It also removes the ability + to cache the position encodings, since we have to compute them dynamically + based on the timesteps in the input data. 
+ """ + if seq_dim is None: + seq_dim = self.default_seq_dim + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = default(seq_pos, self.get_seq_pos(seq_len, dtype=dtype, device=device)) + seq = seq + seq_pos_offset + + freqs = self.forward(seq) + + scale = self.get_scale(seq).to(dtype) + + # used for xformers + if seq_dim == -3: + num_heads = q.shape[-2] + freqs = freqs.unsqueeze(1).expand(-1, num_heads, -1) + scale = scale.unsqueeze(1).expand(-1, num_heads, -1) + + rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def get_scale( + self, + t: torch.Tensor, + ): + """ + This method is adapted closely from the base class, but it knows how to handle + when `t` has more than 1 dim (as is the case when we're using time-aware RoPE, and have a different + sequence position vector for each time series). + """ + assert self.use_xpos + + power = (t - t.max(-1).values.unsqueeze(-1) // 2) / self.scale_base + + scale = self.scale ** rearrange(power, "... n -> ... n 1") + scale = torch.cat((scale, scale), dim=-1) + + return scale diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py new file mode 100644 index 0000000000000..9740cd7c3b45a --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py @@ -0,0 +1,464 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. 
+ +import warnings +from typing import Tuple, Type + +import torch +from einops import reduce, repeat +from gluonts.core.component import validated +from gluonts.torch.scaler import Scaler + + +class StdMeanScaler(Scaler): + """ + Scales data to have zero mean and unit variance along a given dimension. + + Parameters + ---------- + dim + dimension along which to compute the scale + keepdim + controls whether to retain dimension ``dim`` (of length 1) in the + scale tensor, or suppress it. + minimum_scale + default scale that is used for elements that are constantly zero + along dimension `dim`. + """ + + @validated() + def __init__( + self, + dim: int = -1, + keepdim: bool = True, + minimum_scale: float = 1e-3, + ) -> None: + self.dim = dim + self.keepdim = keepdim + self.minimum_scale = minimum_scale + + def __call__( + self, + data: torch.Tensor, + padding_mask: torch.Tensor, + weights: torch.Tensor, + prefix_length: int | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert data.shape == weights.shape, "data and weights must have same shape" + with torch.no_grad(): + if prefix_length is not None: + # Create a prefix mask that is 1 for positions within prefix, 0 elsewhere + prefix_mask = torch.zeros_like(weights) + prefix_mask[..., :prefix_length] = 1.0 + # Apply prefix mask to weights instead of slicing + weights = weights * prefix_mask + + weights = weights * padding_mask + + # We need to calculate the standard deviation using double-precision floats + # to avoid overflow for extreme values. + try: + high_precision_data = data.to(torch.float64) + except TypeError: + # Certain backends (particularly MacOS/MPS) don't support float64. + # In this case, we might be doing inference in float16 + # so it's still worthwhile to cast to float32 to avoid + # some (but not all) overflow issues. + warnings.warn( + f"Float64 is not supported by device {data.device}. " + "Using float32 instead for accumulating denominator in input scaler. 
" + "This may lead to overflow issues if the data contains extreme values.", + RuntimeWarning, + ) + high_precision_data = data.to(torch.float32) + + denominator = weights.sum(self.dim, keepdim=self.keepdim).clamp_min(1.0).to(high_precision_data.dtype) + means = (high_precision_data * weights).sum(self.dim, keepdim=self.keepdim) / denominator + means = torch.nan_to_num(means) + + variance = (((high_precision_data - means) * weights) ** 2).sum( + self.dim, keepdim=self.keepdim + ) / denominator + scale = torch.sqrt(variance + self.minimum_scale).to(data.dtype) + loc = means.to(data.dtype) + + return (data - loc) / scale, loc, scale + + +def compute_causal_statistics( + data: torch.Tensor, + weights: torch.Tensor, + padding_mask: torch.Tensor, + dim: int, + minimum_scale: float, + use_bessel_correction: bool = True, + stabilize_with_global: bool = False, + scale_factor_exponent: float = 10.0, + prefix_length: int | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute causal mean and scale statistics along a specified dimension using + a vectorized implementation of Welford's algorithm for numerical stability. + + This implementation avoids explicit loops while maintaining the numerical stability + of Welford's algorithm, achieving better performance with the same robustness + against overflow issues. + + + Can optionally use global statistics to stabilize causal statistics by clamping + extreme values, preventing instability while preserving a relaxed version of the + causal property. This allows a controlled amount of future information leakage, + introducing an explicit tradeoff between causality and stability. + extreme values, preventing instability while preserving the causal property. 
+ + Parameters + ---------- + data + The input data tensor + weights + The weight tensor (same shape as data) + padding_mask + The padding mask tensor (same shape as data) + dim + The dimension along which to compute statistics (must be -1, the time dimension) + minimum_scale + Minimum scale value to use + use_bessel_correction + Whether to use Bessel's correction to get an unbiased estimator + stabilize_with_global + Whether to use global statistics to stabilize the causal statistics by clamping + extreme values + scale_factor_exponent + Exponent that controls the allowed range of deviation from global scale. + For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale. + With exponent=2.0, the range would be 0.01x to 100x. + prefix_length + If specified, the global statistics will be computed using only the prefix length + requested. This is used for multistep decoding, where we only want to use the + initial historical data to compute the global statistics. If stabilize_with_global + is False, this parameter is ignored. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Causal mean and scale tensors, potentially stabilized with global statistics + """ + # Assert that dim is -1 (last dimension) + assert dim == -1, "compute_causal_statistics only supports dim=-1 (last dimension)" + + with torch.no_grad(): + # Apply padding mask to weights + weights = weights * padding_mask + + # Try to use higher precision for numerical stability + try: + high_precision_data = data.to(torch.float64) + high_precision_weights = weights.to(torch.float64) + except TypeError: + # Fallback for devices that don't support float64 + warnings.warn( + f"Float64 is not supported by device {data.device}. " + "Using float32 instead for causal scaler calculations. 
" + "This may lead to numerical issues if the data contains extreme values.", + RuntimeWarning, + ) + high_precision_data = data.to(torch.float32) + high_precision_weights = weights.to(torch.float32) + + # Check if deterministic algorithms are enabled and we're using CUDA. + # Cumsum operations do not support deterministic mode in CUDA, + # so we need to disable it for just this section. + prev_deterministic = torch.are_deterministic_algorithms_enabled() + if prev_deterministic and data.device.type == "cuda": + # Disable deterministic algorithms for operations + torch.use_deterministic_algorithms(False) + + try: + # Create weighted data + weighted_data = high_precision_weights * high_precision_data + + # Compute cumulative sum of weights and weighted data along time dimension + cum_weights = torch.cumsum(high_precision_weights, dim=dim) + cum_values = torch.cumsum(weighted_data, dim=dim) + + # Avoid division by zero for the first time step or when no valid values + denominator = cum_weights.clamp_min(1.0) + + # Compute causal means at each time step + causal_means = cum_values / denominator + + # For Welford's algorithm, we need to compute the correction term + # using the difference between the current value and the current mean + + # Create shifted version of causal means to compute delta efficiently + # First item in shifted_means will be zero + shifted_means = torch.zeros_like(causal_means) + shifted_means[..., 1:] = causal_means[..., :-1] + + # Compute delta between current data point and previous mean + # For t=0, this is just the first data point + delta = high_precision_data - shifted_means + + # Compute the increment term for Welford's algorithm. + # This is defined as the product of the delta and the difference between the current data point and the causal mean. 
+ # This is where we avoid the traditional E[X²] - E[X]² computation + increment = delta * (high_precision_data - causal_means) * high_precision_weights + + # The Welford algorithm uses the term m_2, which is the cumulative sum of the increment term. + # This is an accumulator that helps us compute the second moment (hence m_2) of the distribution. + # Compute cumulative sum of the increment term + m_2 = torch.cumsum(increment, dim=dim) + + # Compute variance according to Welford's algorithm + if use_bessel_correction: + causal_variance = m_2 / torch.clamp(denominator - 1.0, min=1.0) + else: + causal_variance = m_2 / denominator + + # Add minimum scale but keep in high precision for now + causal_scale = torch.sqrt(causal_variance + minimum_scale) + + # Apply stabilization with global statistics if requested + if stabilize_with_global: + if prefix_length is not None: + # Create a prefix mask for global statistics computation + prefix_mask = torch.zeros_like(weights) + prefix_mask[..., :prefix_length] = 1.0 + + # Apply prefix mask to restrict computation to prefix + weighted_data = weighted_data * prefix_mask + weights = weights * prefix_mask + padding_mask = padding_mask * prefix_mask + + # Calculate scale factors from the exponent + scale_factor_min = 10.0 ** (-scale_factor_exponent) + scale_factor_max = 10.0**scale_factor_exponent + + global_denominator = (weights * padding_mask).sum(dim, keepdim=True).clamp_min(1.0) + global_means = (weighted_data).sum(dim, keepdim=True) / global_denominator + global_means = torch.nan_to_num(global_means) + + global_variance = (((high_precision_data - global_means) * weights * padding_mask) ** 2).sum( + dim, keepdim=True + ) / global_denominator + global_scale = torch.sqrt(global_variance + minimum_scale) + + # Expand global statistics to match the time dimension + expanded_global_scale = global_scale.expand_as(causal_scale) + + # Define bounds using scale factors + min_allowed_scale = expanded_global_scale * scale_factor_min + 
max_allowed_scale = expanded_global_scale * scale_factor_max + + # Clamp the causal scale between min_allowed_scale and max_allowed_scale + causal_scale = torch.clamp( + causal_scale, + min=torch.max(torch.tensor(minimum_scale, device=causal_scale.device), min_allowed_scale), + max=max_allowed_scale, + ) + + # Now convert means and scale to original dtype after all numerical operations + causal_means = causal_means.to(data.dtype) + causal_scale = causal_scale.to(data.dtype) + + finally: + # Restore original deterministic setting if it was changed + if prev_deterministic and data.device.type == "cuda": + torch.use_deterministic_algorithms(True) + + return causal_means, causal_scale + + +class CausalStdMeanScaler(Scaler): + """ + Causally scales the data along dimension `dim` which is expected to be the + time dimension. For each position t along this dimension, the mean and + standard deviation are computed using only data from positions up to t. + + Can optionally stabilize causal statistics using global statistics to prevent + extreme values, while preserving the causal property. + + Note: This scaler only works with dim=-1 (the last dimension). + + Parameters + ---------- + dim + dimension along which to compute the causal scale (must be -1, the last dimension) + minimum_scale + default scale that is used if the scale is below this threshold + or for the first time step, since standard deviation cannot be + computed with a single observation + use_bessel_correction + whether to use Bessel's correction to get an unbiased estimator + stabilize_with_global + whether to use global statistics to stabilize extreme causal statistics + scale_factor_exponent + exponent that controls the allowed range of deviation from global scale. + For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale. + With exponent=2.0, the range would be 0.01x to 100x. 
+ """ + + @validated() + def __init__( + self, + dim: int = -1, + minimum_scale: float = 0.1, + use_bessel_correction: bool = True, + stabilize_with_global: bool = False, + scale_factor_exponent: float = 10.0, + ) -> None: + super().__init__() + assert dim == -1, "CausalStdMeanScaler only supports dim=-1 (last dimension)" + self.dim = dim + self.minimum_scale = minimum_scale + self.use_bessel_correction = use_bessel_correction + self.stabilize_with_global = stabilize_with_global + self.scale_factor_exponent = scale_factor_exponent + + def __call__( + self, + data: torch.Tensor, + padding_mask: torch.Tensor, + weights: torch.Tensor, + prefix_length: int | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert data.shape == weights.shape, "data and weights must have same shape" + assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]" + + # Compute causal statistics with optional stabilization + causal_means, causal_scale = compute_causal_statistics( + data, + weights, + padding_mask, + self.dim, + self.minimum_scale, + self.use_bessel_correction, + self.stabilize_with_global, + self.scale_factor_exponent, + prefix_length, + ) + + # Apply the normalization + scaled_data = (data - causal_means) / causal_scale + + return scaled_data, causal_means, causal_scale + + +class CausalPatchStdMeanScaler(Scaler): + """ + Causally scales data in patches, where each patch uses statistics computed + from all data up to and including that patch. Within each patch, all timesteps + use the same scaling values. + + This approach provides more stability than per-timestep causal scaling while + still maintaining the causal property (not using future data). + + Can optionally stabilize causal statistics using global statistics to prevent + extreme values, while preserving the causal property. 
+ + The statistics are computed using Welford's algorithm, which provides better + numerical stability compared to the direct computation of variance, especially + when dealing with large values or a large number of data points. + + Note: This scaler only works with the following constraints: + - The input must have shape [batch, variates, time_steps] + - It only operates on the last dimension (-1) + - The time_steps must be divisible by patch_size + + Parameters + ---------- + dim + dimension along which to compute the causal scale. Must be -1 (the last dimension). + patch_size + number of timesteps in each patch + minimum_scale + default scale that is used for elements that are constantly zero + along dimension `dim` or for the first patch. + use_bessel_correction + whether to use Bessel's correction to get an unbiased estimator + stabilize_with_global + whether to use global statistics to stabilize extreme causal statistics + scale_factor_exponent + exponent that controls the allowed range of deviation from global scale. + For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale. + With exponent=2.0, the range would be 0.01x to 100x. 
+ """ + + @validated() + def __init__( + self, + dim: int = -1, + patch_size: int = 32, + minimum_scale: float = 0.1, + use_bessel_correction: bool = True, + stabilize_with_global: bool = False, + scale_factor_exponent: float = 10.0, + ) -> None: + super().__init__() + assert dim == -1, "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)" + self.dim = dim + self.patch_size = patch_size + self.minimum_scale = minimum_scale + self.use_bessel_correction = use_bessel_correction + self.stabilize_with_global = stabilize_with_global + self.scale_factor_exponent = scale_factor_exponent + + def __call__( + self, + data: torch.Tensor, + padding_mask: torch.Tensor, + weights: torch.Tensor, + prefix_length: int | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert data.shape == weights.shape, "data and weights must have same shape" + assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]" + + with torch.no_grad(): + # Get the number of time steps (last dimension) + time_steps = data.shape[-1] + + # Assert that time_steps is divisible by patch_size + assert ( + time_steps % self.patch_size == 0 + ), f"Time steps ({time_steps}) must be divisible by patch size ({self.patch_size})" + + # First compute causal statistics with optional stabilization + causal_means, causal_scale = compute_causal_statistics( + data, + weights, + padding_mask, + -1, + self.minimum_scale, + self.use_bessel_correction, + self.stabilize_with_global, + self.scale_factor_exponent, + prefix_length, + ) + + # Unfold the causal means and scales to get the patches + means_unfolded = causal_means.unfold(-1, self.patch_size, self.patch_size) + scales_unfolded = causal_scale.unfold(-1, self.patch_size, self.patch_size) + + # Get the last element of each patch (the most recent statistic) + patch_stats_means = means_unfolded[..., -1] + patch_stats_scales = scales_unfolded[..., -1] + + # Tile the patch statistics across time dimension using 
einops.repeat + # With our fixed [batch, variates, num_patches] shape this is much simpler + patch_means = repeat(patch_stats_means, "b v p -> b v (p s)", s=self.patch_size) + patch_scales = repeat(patch_stats_scales, "b v p -> b v (p s)", s=self.patch_size) + + # Apply normalization + scaled_data = (data - patch_means) / patch_scales + + return scaled_data, patch_means, patch_scales + + +# for deserialization of SafeTensors checkpoints +scaler_types = { + "": StdMeanScaler, + "": CausalStdMeanScaler, + "": CausalPatchStdMeanScaler, +} diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py new file mode 100644 index 0000000000000..94d1a10f92a9a --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py @@ -0,0 +1,188 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. + +import json +import os +import re +from pathlib import Path +from typing import Dict, Optional, Union + +import safetensors.torch as safetorch +import torch +from huggingface_hub import ModelHubMixin, constants, hf_hub_download + +from ..model.attention import XFORMERS_AVAILABLE +from ..model.backbone import TotoBackbone +from ..model.transformer import XFORMERS_SWIGLU_AVAILABLE + + +class Toto(torch.nn.Module, ModelHubMixin): + """ + PyTorch module for Toto (Timeseries-Optimized Transformer for Observability). + + Parameters + ---------- + **model_kwargs + Additional keyword arguments to pass to the TotoModule constructor. 
+ """ + + def __init__( + self, + patch_size: int, + stride: int, + embed_dim: int, + num_layers: int, + num_heads: int, + mlp_hidden_dim: int, + dropout: float, + spacewise_every_n_layers: int, + scaler_cls: str, + output_distribution_classes: list[str], + spacewise_first: bool = True, + output_distribution_kwargs: dict | None = None, + use_memory_efficient_attention: bool = True, + stabilize_with_global: bool = True, + scale_factor_exponent: float = 10.0, + **model_kwargs, + ): + super().__init__() + self.model = TotoBackbone( + patch_size=patch_size, + stride=stride, + embed_dim=embed_dim, + num_layers=num_layers, + num_heads=num_heads, + mlp_hidden_dim=mlp_hidden_dim, + dropout=dropout, + spacewise_every_n_layers=spacewise_every_n_layers, + scaler_cls=scaler_cls, + output_distribution_classes=output_distribution_classes, + spacewise_first=spacewise_first, + output_distribution_kwargs=output_distribution_kwargs, + use_memory_efficient_attention=use_memory_efficient_attention, + stabilize_with_global=stabilize_with_global, + scale_factor_exponent=scale_factor_exponent, + **model_kwargs, + ) + self.model_kwargs = model_kwargs + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path, + map_location: str = "cpu", + strict=True, + **model_kwargs, + ): + """ + Custom checkpoint loading. Used to load a local + safetensors checkpoint with an optional config.json file. + """ + if os.path.isdir(checkpoint_path): + safetensors_file = os.path.join(checkpoint_path, "model.safetensors") + else: + safetensors_file = checkpoint_path + + if os.path.exists(safetensors_file): + model_state = safetorch.load_file(safetensors_file, device=map_location) + else: + raise FileNotFoundError(f"Model checkpoint not found at: {safetensors_file}") + + # Load configuration from config.json if it exists. 
+ config_file = os.path.join(checkpoint_path, "config.json") + config = {} + if os.path.exists(config_file): + with open(config_file, "r") as f: + config = json.load(f) + + # Merge any extra kwargs into the configuration. + config.update(model_kwargs) + + remapped_state_dict = cls._map_state_dict_keys( + model_state, XFORMERS_SWIGLU_AVAILABLE and not config.get("pre_xformers_checkpoint", False) + ) + + if not XFORMERS_AVAILABLE and config.get("use_memory_efficient_attention", True): + config["use_memory_efficient_attention"] = False + + instance = cls(**config) + instance.to(map_location) + + # Filter out unexpected keys + filtered_remapped_state_dict = { + k: v + for k, v in remapped_state_dict.items() + if k in instance.state_dict() and not k.endswith("rotary_emb.freqs") + } + + instance.load_state_dict(filtered_remapped_state_dict, strict=strict) + return instance + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: Optional[bool], + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", + strict: bool = False, + **model_kwargs, + ): + """Load Pytorch pretrained weights and return the loaded model.""" + if os.path.isdir(model_id): + print("Loading weights from local directory") + model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE) + return cls.load_from_checkpoint(model_file, map_location, strict, **model_kwargs) + else: + model_file = hf_hub_download( + repo_id=model_id, + filename=constants.SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + return cls.load_from_checkpoint(model_file, map_location, strict, **model_kwargs) + + @staticmethod + def _map_state_dict_keys(state_dict, use_fused_swiglu): + 
""" + Maps the keys of a state_dict to match the current model's state_dict. + Currently this is only used to convert between fused and unfused SwiGLU implementations. + """ + if use_fused_swiglu: + remap_keys = { + "mlp.0.weight": "mlp.0.w12.weight", + "mlp.0.bias": "mlp.0.w12.bias", + "mlp.2.weight": "mlp.0.w3.weight", + "mlp.2.bias": "mlp.0.w3.bias", + } + else: + remap_keys = { + "mlp.0.w12.weight": "mlp.0.weight", + "mlp.0.w12.bias": "mlp.0.bias", + "mlp.0.w3.weight": "mlp.2.weight", + "mlp.0.w3.bias": "mlp.2.bias", + } + + def replace_key(text): + for pattern, replacement in remap_keys.items(): + text = re.sub(pattern, replacement, text) + return text + + return {replace_key(k): v for k, v in state_dict.items()} + + @property + def device(self): + return next(self.model.parameters()).device diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py new file mode 100644 index 0000000000000..a477e65ac5385 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py @@ -0,0 +1,351 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. 
+ +import warnings +from typing import Literal, Optional, Union, cast + +import torch +import torch.nn.functional as F +from einops import rearrange +from jaxtyping import Bool, Float, Int +from rotary_embedding_torch import RotaryEmbedding + +from ..model.attention import ( + AttentionAxis, + MultiHeadAttention, + SpaceWiseMultiheadAttention, + TimeWiseMultiheadAttention, +) +from ..model.feed_forward import SwiGLU +from ..model.rope import TimeAwareRotaryEmbedding +from ..model.util import KVCache, RMSNorm, make_batched_block_mask +from .fusion import Fusion + +try: + from xformers.ops.swiglu_op import SwiGLU as SwiGLU_fused + + XFORMERS_SWIGLU_AVAILABLE = True +except ImportError: + warnings.warn( + "xFormers fused SwiGLU kernel not found. " "Using native PyTorch implementation for feed-forward layers.", + ImportWarning, + ) + XFORMERS_SWIGLU_AVAILABLE = False + + +class TransformerLayer(torch.nn.Module): + """ + A transformer block that applies multihead attention followed by a feedforward network. + + The transformer can be configured to apply time-wise attention (i.e. attention over the time axis) + or space-wise attention (i.e. attention over the variate axis). + + The transformer block uses pre-norm, which is a variant of the transformer architecture where + LayerNorm is applied before each sublayer, rather than after. This is the approach taken in + LLaMA and other recent transformer-based models. + + The transformer block also uses SwiGLU, which is a variant of the Gated Linear Unit (GLU) activation + function. SwiGLU is a variant of the GLU activation that uses the Swish activation function. This + activation function has been used extensively in recent transformer-based models and has been shown + to improve performance. 
+ """ + + embed_dim: int + num_heads: int + mlp_hidden_dim: int + dropout: float + attention_axis: AttentionAxis + + def __init__( + self, + embed_dim: int, + num_heads: int, + mlp_hidden_dim: int, + dropout: float, + rotary_emb: RotaryEmbedding = None, + attention_axis: AttentionAxis = AttentionAxis.TIME, + RMS_norm: bool = True, + use_memory_efficient_attention: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.mlp_hidden_dim = mlp_hidden_dim + self.dropout = dropout + self.attention_axis = attention_axis + + if RMS_norm: + self.norm1: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim) + self.norm2: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim) + + else: + self.norm1 = torch.nn.LayerNorm(embed_dim) + self.norm2 = torch.nn.LayerNorm(embed_dim) + + self.attention: MultiHeadAttention + + if attention_axis == AttentionAxis.TIME: + self.attention = TimeWiseMultiheadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + dropout=dropout, + rotary_emb=rotary_emb, + use_memory_efficient_attention=use_memory_efficient_attention, + ) + elif attention_axis == AttentionAxis.SPACE: + self.attention = SpaceWiseMultiheadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + dropout=dropout, + rotary_emb=None, + use_memory_efficient_attention=use_memory_efficient_attention, + ) + else: + raise ValueError("Invalid attention axis") + + if XFORMERS_SWIGLU_AVAILABLE: + self.mlp = torch.nn.Sequential( + SwiGLU_fused(in_features=embed_dim, hidden_features=mlp_hidden_dim), + torch.nn.Dropout(dropout), + ) + else: + self.mlp = torch.nn.Sequential( + torch.nn.Linear(embed_dim, 2 * mlp_hidden_dim), + SwiGLU(), + torch.nn.Linear(mlp_hidden_dim, embed_dim), + torch.nn.Dropout(dropout), + ) + + def forward( + self, + layer_idx: int, + inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"], + attention_mask: Optional[ + Union[ + Bool[torch.Tensor, "batch seq_len variate variate"], + Bool[torch.Tensor, "batch 
#variate seq_len seq_len"], + ] + ] = None, + kv_cache: Optional[KVCache] = None, + ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: + pre_norm_1 = self.norm1(inputs) + hidden_state = inputs + self.attention(layer_idx, pre_norm_1, attention_mask, kv_cache).contiguous() + + pre_norm_2 = self.norm2(hidden_state) + return hidden_state + self.mlp(pre_norm_2) + + +class Transformer(torch.nn.Module): + """ + A stack of transformer layers. The transformer alternates between time-wise and space-wise attention + to learn both temporal and cross-variate dependencies in the data. + + Based on the intuition that time-wise attention is more important overall than space-wise attention + (because an individual variate is more likely to be correlated with itself across time than with other variates), + the transformer can be configured to apply space-wise attention less frequently than time-wise attention. + This is controlled by the `spacewise_every_n_layers` parameter, which specifies how many time-wise transformer + layers to apply between every space-wise transformer layer. + + Parameters + ---------- + num_layers + Number of transformer layers to use. + num_heads + Number of attention heads to use in each self-attention layer. + mlp_hidden_dim + Dimension of the hidden layer in the feedforward network. + dropout + Dropout rate to use in the model. + spacewise_every_n_layers + How many time-wise transformer layers to apply between each space-wise transformer layer. + spacewise_first + Whether to apply space-wise attention before time-wise attention. + use_memory_efficient_attention + Whether to use memory-efficient attention. If True, the model will use the memory-efficient from xFormers. 
+ """ + + def __init__( + self, + num_layers: int, + embed_dim: int, + num_heads: int, + mlp_hidden_dim: int, + dropout: float, + spacewise_every_n_layers: int, + spacewise_first: bool, + use_memory_efficient_attention: bool = True, + *, + fusion: Optional[Fusion] = None, + ): + super().__init__() + + assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads." + + self.rotary_emb = TimeAwareRotaryEmbedding( + embed_dim // num_heads, + use_xpos=True, + cache_if_possible=True, + seq_before_head_dim=use_memory_efficient_attention, + ) + attention_axes = self._get_layer_types(num_layers, spacewise_every_n_layers, spacewise_first) + + self.use_memory_efficient_attention = use_memory_efficient_attention + self.fusion = fusion + + self.layers = torch.nn.ModuleList( + [ + TransformerLayer( + embed_dim=embed_dim, + num_heads=num_heads, + mlp_hidden_dim=mlp_hidden_dim, + dropout=dropout, + rotary_emb=self.rotary_emb, + attention_axis=attention_axes[i], + use_memory_efficient_attention=self.use_memory_efficient_attention, + ) + for i in range(num_layers) + ] + ) + + def _get_mask( + self, + num_heads: int, + dtype: torch.dtype, + id_mask: Optional[torch.Tensor] = None, + ) -> Union[ + Bool[torch.Tensor, "batch num_heads seq_len seq_len"], + Float[torch.Tensor, "batch num_heads seq_len seq_len"], + Bool[torch.Tensor, "batch num_heads variate variate"], + Float[torch.Tensor, "batch num_heads variate variate"], + ]: + """ + Unified method to create and process space-wise masks. + + Args: + mask_type: Type of mask to create ('spacewise'). + seq_len: Total sequence length. + num_heads: Number of attention heads. + device: Device where the mask should be created. + dtype: Desired dtype for the bias tensor. + id_mask: Mask for variates (used for spacewise masks). + + Returns: + Processed attention mask tensor with the correct shape for the given mask type. 
+ """ + + if id_mask is None: + raise ValueError("id_mask must be provided for spacewise masks.") + + # Create spacewise mask + mask = make_batched_block_mask(id_mask.transpose(-1, -2)) + + if self.use_memory_efficient_attention: + mask = self._pad_to_multiple(mask) + mask = mask.float().masked_fill(~mask, float("-inf")).masked_fill(mask, 0.0).to(dtype) + + # Rearrange for space-wise attention + mask = rearrange(mask, "batch seq_len variate1 variate2 -> (batch seq_len) 1 variate1 variate2") + # Stack along num_heads dimension + return mask.expand(-1, num_heads, -1, -1).contiguous() + + def _pad_to_multiple( + self, + tensor: torch.Tensor, + multiple: int = 8, + causal: bool = False, # New flag to indicate causal mask extension + ) -> torch.Tensor: + """ + Pads the last two dimensions of a tensor to be divisible by `multiple`. + For causal masks, the padded area is filled with the continued lower-triangular pattern, + rather than with zeros. + """ + pad_amount = (multiple - tensor.shape[-1] % multiple) % multiple + if pad_amount > 0: + new_size = tensor.shape[-1] + pad_amount + if causal: + # Create a full causal mask for the new size. 
+ full_mask = torch.tril(torch.ones((new_size, new_size), dtype=tensor.dtype, device=tensor.device)) + # Preserve any modifications from the original mask (e.g., condition tokens in top-left) + full_mask[: tensor.shape[-1], : tensor.shape[-1]] = tensor + tensor = full_mask + else: + tensor = F.pad(tensor, (0, pad_amount, 0, pad_amount)) + return tensor + + def _get_layer_types( + self, + num_layers: int, + spacewise_every_n_layers: int, + spacewise_first: bool, + ) -> list[AttentionAxis]: + if spacewise_every_n_layers == -1: + return [AttentionAxis.TIME] * num_layers + assert num_layers % spacewise_every_n_layers == 0 + + block = [AttentionAxis.TIME] * (spacewise_every_n_layers - 1) + + if spacewise_first: + block = [AttentionAxis.SPACE] + block + else: + block = block + [AttentionAxis.SPACE] + + layer_types = block * (num_layers // spacewise_every_n_layers) + + return layer_types + + def forward( + self, + inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"], + id_mask: Float[torch.Tensor, "batch #variate seq_len"], + kv_cache: Optional[KVCache] = None, + variate_label_embeds: Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]] = None, + ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: + + # Apply fusion (prepend variate label embeddings) only once at the beginning + # Skip when KV cache indicates we are in incremental decoding steps + if self.fusion is not None and variate_label_embeds is not None: + should_apply_fusion = True + if kv_cache is not None: + kv_len_tensor = kv_cache.current_len(0) + kv_len = int(kv_len_tensor) if isinstance(kv_len_tensor, torch.Tensor) else kv_len_tensor + should_apply_fusion = kv_len == 0 + if should_apply_fusion: + inputs = self.fusion(inputs, variate_label_embeds=variate_label_embeds) + + batch, _, seq_len, _ = inputs.shape + + # If fusion prepended tokens, pad id_mask along seq_len to match + # by repeating the first timestep mask for the number of added tokens. 
+ if id_mask is not None and id_mask.shape[-1] != seq_len: + added = int(seq_len - id_mask.shape[-1]) + if added > 0: + pad_slice = id_mask[..., :1] # (batch, variate, 1) + id_mask = torch.cat([pad_slice.expand(-1, -1, added), id_mask], dim=-1) + # Get the sequence length by looking up a timewise layer in the kv cache. + # Regardless of whether spacewise is first in the stack, the layer + # at index 1 is always a timewise layer. + seq_len = (kv_cache.seq_len(1) if kv_cache else 0) + seq_len + + num_heads: int = cast(int, self.layers[0].num_heads) + + timewise_attention_mask = None + + # We create a space-wise ID mask by creating a block triangular mask from the ID mask + # in the space-wise direction. This ensures that the model can only attend to + # variates in the same group. + spacewise_attention_mask = self._get_mask( + num_heads=num_heads, + dtype=inputs.dtype, + id_mask=id_mask, + ) + + for layer_idx, layer in enumerate(self.layers): + inputs = layer( + layer_idx, + inputs, + (timewise_attention_mask if layer.attention_axis == AttentionAxis.TIME else spacewise_attention_mask), + kv_cache, + ) + return inputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py new file mode 100644 index 0000000000000..d891a0b67d3e4 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py @@ -0,0 +1,213 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. 
try:
    from xformers import _is_triton_available
    from xformers.ops.rmsnorm import rms_norm, rms_norm_add

    XFORMERS_RMSNORM_AVAILABLE = True
except ImportError:

    warnings.warn(
        "xFormers fused RMSNorm implementation not available. Will not use " "optimized kernel for inference.",
        ImportWarning,
    )

    def _is_triton_available():
        return False

    XFORMERS_RMSNORM_AVAILABLE = False


class RMSNorm(torch.nn.Module):
    """
    Root-mean-square layer normalization.

    Uses xFormers' fused ``rms_norm`` kernel when the module is in eval mode
    (or its weight is frozen) AND the kernel is actually importable; otherwise
    falls back to a plain-PyTorch implementation.
    """

    def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-8):
        """
        Args:
            dim: size of the normalized (last) dimension.
            include_weight: learn a per-dimension scale if True; no affine otherwise.
            eps: numerical-stability term added to the mean square.
        """
        super(RMSNorm, self).__init__()
        self.eps = eps
        if include_weight:
            self.scale: Optional[torch.nn.Parameter] = torch.nn.Parameter(torch.ones(dim))
        else:
            self.scale = None

    def _use_fused_kernel(self) -> bool:
        # The fused path applies only in eval mode or when the weight is frozen
        # (requires_grad=False), and only if the xFormers/Triton kernel exists.
        frozen_or_eval = (not self.training) or (self.scale is not None and not self.scale.requires_grad)
        return frozen_or_eval and XFORMERS_RMSNORM_AVAILABLE and _is_triton_available()

    def forward(self, x: torch.Tensor):
        if self._use_fused_kernel():
            return rms_norm(x, self.scale, self.eps)  # xFormers fused kernel

        # Fallback: standard RMS Norm in Python.
        x_normed = x / torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        # Scale the normalized input (no-op when include_weight=False).
        return x_normed if self.scale is None else x_normed * self.scale

    def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor):
        """
        Fused (x + y) followed by RMS norm.

        Bug fix: this previously dispatched to ``rms_norm_add`` based only on
        the train/frozen check, so it raised NameError in eval mode whenever
        xFormers was not installed. It now gates on kernel availability
        exactly like ``forward``.
        """
        if self._use_fused_kernel():
            return rms_norm_add(x, y, self.scale, self.eps)

        # Fallback: x += y; then do RMS Norm in plain PyTorch.
        return self.forward(x + y)
+ """ + if (not self.training) or (self.scale is not None and not self.scale.requires_grad): + return rms_norm_add(x, y, self.scale, self.eps) + + # Fallback: x += y; then do RMS Norm + return self.forward(x + y) + + +def make_batched_block_mask(t: torch.Tensor) -> torch.Tensor: + unsqueezed = rearrange(t, "... d -> ... 1 d") + return unsqueezed == unsqueezed.transpose(-1, -2) + + +K: TypeAlias = Float[torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim"] +V: TypeAlias = Float[torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim"] +KV: TypeAlias = tuple[K, V] + + +@dataclass +class KVCache: + """ + Key/Value cache for storing intermediate attention values + during multistep inference. Only stores KV cache for timewise layers, skipping spacewise layers. + """ + + batch_size: int + num_variates: int + transformer_layers: List["TransformerLayer"] + num_layers: int + embed_dim: int + num_heads: int + max_seq_len: int + device: torch.device = torch.device("cpu") + dtype: torch.dtype = torch.float32 + use_memory_efficient_attention: bool = True + + _keys: Union[ + Float[torch.Tensor, "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim"], + Float[torch.Tensor, "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim"], + ] = field(init=False) + + _values: Union[ + Float[torch.Tensor, "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim"], + Float[torch.Tensor, "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim"], + ] = field(init=False) + + _current_idx: Int[torch.Tensor, "time_layer_count"] = field(init=False) + _layer_cache_map: Int[torch.Tensor, "num_layers"] = field(init=False) + + def __post_init__(self): + """ + - Determine timewise vs. spacewise layers and allocate KV only for timewise. + - Create a fast tensor-based mapping from global layer_idx -> timewise layer_idx. 
+ """ + assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" + head_dim = self.embed_dim // self.num_heads + + # Compute which layers are timewise + time_layer_indices = [ + i + for i in range(self.num_layers) + if isinstance(self.transformer_layers[i].attention, TimeWiseMultiheadAttention) + ] + + time_layer_count = max(1, len(time_layer_indices)) # handle edge case for no timewise layers + # Allocate for only the timewise layers + if self.use_memory_efficient_attention: + shape = ( + time_layer_count, + self.batch_size * self.num_variates, + self.max_seq_len, + self.num_heads, + head_dim, + ) + else: + shape = ( + time_layer_count, + self.batch_size * self.num_variates, + self.num_heads, + self.max_seq_len, + head_dim, + ) + self._keys = torch.zeros(shape, device=self.device, dtype=self.dtype) + self._values = torch.zeros_like(self._keys) + self._current_idx = torch.zeros(time_layer_count, device=self.device, dtype=torch.int) + # Build a tensor lookup for global -> timewise layer index (default to 0) + self._layer_cache_map = torch.zeros((self.num_layers,), dtype=torch.int, device=self.device) + for cache_idx, layer_idx in enumerate(time_layer_indices): + self._layer_cache_map[layer_idx] = int(cache_idx) # Assign correct indices + + def __getitem__(self, layer_idx: int) -> KV: + cache_idx = int(self._layer_cache_map[layer_idx].item()) + end_idx = int(self._current_idx[cache_idx].item()) + + if self.use_memory_efficient_attention: + return self._keys[cache_idx, :, :end_idx, :, :], self._values[cache_idx, :, :end_idx, :, :] + else: + return self._keys[cache_idx, :, :, :end_idx, :], self._values[cache_idx, :, :, :end_idx, :] + + def current_len(self, cache_idx: int) -> int: + return int(self._current_idx[cache_idx].item()) if self._current_idx.numel() > 0 else 0 + + def seq_len(self, layer_idx: int) -> int: + cache_idx = int(self._layer_cache_map[layer_idx].item()) + return self.current_len(cache_idx) + + def append(self, 
    def append(self, layer_idx: int, kv: KV):
        """Append new key/value timesteps to the cache slot of a timewise layer.

        ``kv`` is a (keys, values) pair whose seq_len axis position depends on
        ``use_memory_efficient_attention`` (axis 1 vs axis 2). Advances the
        slot's write cursor by the appended length; asserts on overflow.
        """
        cache_idx = int(self._layer_cache_map[layer_idx].item())
        keys, values = kv

        # Validate dimensions against the expected layout.
        assert keys.shape == values.shape, "keys and values must have the same shape"
        assert (
            keys.shape[0] == self.batch_size * self.num_variates
        ), "keys and values must have batch_size * num_variates as their first dimension"

        if self.use_memory_efficient_attention:
            assert keys.shape[2] == self.num_heads, "keys and values must have num_heads as their third dimension"
        else:
            assert keys.shape[1] == self.num_heads, "keys and values must have num_heads as their second dimension"
        # head_dim is the last axis in both layouts.
        assert (
            keys.shape[3] == self.embed_dim // self.num_heads
        ), "keys and values must have head_dim as their fourth dimension"

        start_idx = self._current_idx[cache_idx]
        # seq_len lives on axis 1 (memory-efficient) or axis 2 (standard).
        if self.use_memory_efficient_attention:
            end_idx = start_idx + keys.shape[1]
        else:
            end_idx = start_idx + keys.shape[2]
        assert (
            end_idx <= self.max_seq_len
        ), f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}"

        if self.use_memory_efficient_attention:
            self._keys[cache_idx, :, start_idx:end_idx, :, :] = keys
            self._values[cache_idx, :, start_idx:end_idx, :, :] = values
        else:
            self._keys[cache_idx, :, :, start_idx:end_idx, :] = keys
            self._values[cache_idx, :, :, start_idx:end_idx, :] = values

        self._current_idx[cache_idx] = end_idx

    def reset(self):
        """Zero all cached keys/values and rewind every slot's write cursor."""
        self._keys.zero_()
        self._values.zero_()
        self._current_idx.zero_()
class TotoPipeline(ForecastPipeline):
    """Forecast pipeline adapter for the Toto time-series foundation model.

    Converts raw target tensors into fully-observed ``MaskedTimeseries``
    inputs, runs the model's probabilistic forecast, and returns per-sample
    mean forecasts of shape (variates, output_length).
    """

    def __init__(self, model_info, **model_kwargs):
        super().__init__(model_info, **model_kwargs)

    @staticmethod
    def _to_masked_timeseries(series):
        """Wrap a (variates, length) tensor into a fully-observed MaskedTimeseries."""
        if series.ndim == 1:
            # Promote a single univariate series to (1, length).
            series = series.unsqueeze(0)

        n_vars, n_steps = series.shape
        dev = series.device

        variate_ids = torch.arange(n_vars, dtype=torch.int64, device=dev)
        ticks = torch.arange(n_steps, dtype=torch.int64, device=dev)

        return MaskedTimeseries(
            series=series,
            # Everything observed: no padding anywhere.
            padding_mask=torch.ones((n_vars, n_steps), dtype=torch.bool, device=dev),
            # Each variate forms its own group along the whole series.
            id_mask=variate_ids.unsqueeze(-1).expand(n_vars, n_steps),
            # Synthetic 1-second timestamps shared across variates.
            timestamp_seconds=ticks.unsqueeze(0).expand(n_vars, n_steps),
            time_interval_seconds=torch.ones(n_vars, dtype=torch.int64, device=dev),
            num_exogenous_variables=0,
        )

    def preprocess(self, inputs, **infer_kwargs):
        """Turn each {"targets": tensor} item into a MaskedTimeseries."""
        super().preprocess(inputs, **infer_kwargs)
        return [self._to_masked_timeseries(sample["targets"]) for sample in inputs]

    def forecast(self, inputs, **infer_kwargs):
        """Forecast each prepared series; returns a list of mean tensors."""
        horizon = infer_kwargs.get("output_length", 96)
        num_samples = infer_kwargs.get("num_samples", None)

        results = []
        for masked_series in inputs:
            prediction = self.model.forecast(
                masked_series,
                prediction_length=horizon,
                num_samples=num_samples,
            )
            mean = prediction.mean
            # Drop a singleton leading batch axis so callers get (variates, horizon).
            if mean.ndim == 3 and mean.shape[0] == 1:
                mean = mean.squeeze(0)
            results.append(mean)
        return results

    def postprocess(self, outputs, **infer_kwargs):
        return super().postprocess(outputs, **infer_kwargs)
= ">=2.0,<2.4.0" diff --git a/iotdb-core/ainode/tests/test_toto_pipeline.py b/iotdb-core/ainode/tests/test_toto_pipeline.py new file mode 100644 index 0000000000000..c04633424d674 --- /dev/null +++ b/iotdb-core/ainode/tests/test_toto_pipeline.py @@ -0,0 +1,50 @@ +import unittest +from types import SimpleNamespace + +import torch + +from iotdb.ainode.core.inference.pipeline import basic_pipeline +from iotdb.ainode.core.model.model_info import BUILTIN_HF_TRANSFORMERS_MODEL_MAP +from iotdb.ainode.core.model.toto.data.util.dataset import MaskedTimeseries +from iotdb.ainode.core.model.toto.pipeline_toto import TotoPipeline + + +class _FakeForecaster: + def forecast(self, item, prediction_length, num_samples=None): + variate_count = item.series.shape[0] + mean = torch.ones((1, variate_count, prediction_length), dtype=torch.float32) + return SimpleNamespace(mean=mean) + + +class TotoPipelineTest(unittest.TestCase): + def setUp(self): + self._original_load_model = basic_pipeline.load_model + basic_pipeline.load_model = lambda *args, **kwargs: _FakeForecaster() + + def tearDown(self): + basic_pipeline.load_model = self._original_load_model + + def test_preprocess_builds_masked_timeseries(self): + pipeline = TotoPipeline(BUILTIN_HF_TRANSFORMERS_MODEL_MAP["toto"]) + + processed = pipeline.preprocess([{"targets": torch.randn(2, 16)}]) + + self.assertEqual(1, len(processed)) + self.assertIsInstance(processed[0], MaskedTimeseries) + self.assertEqual((2, 16), tuple(processed[0].series.shape)) + self.assertEqual((2, 16), tuple(processed[0].padding_mask.shape)) + self.assertEqual((2, 16), tuple(processed[0].id_mask.shape)) + + def test_forecast_returns_2d_outputs_after_postprocess(self): + pipeline = TotoPipeline(BUILTIN_HF_TRANSFORMERS_MODEL_MAP["toto"]) + processed = pipeline.preprocess([{"targets": torch.randn(2, 16)}], output_length=8) + + outputs = pipeline.forecast(processed, output_length=8, num_samples=None) + outputs = pipeline.postprocess(outputs) + + self.assertEqual(1, 
def main():
    """Smoke-check the Toto pipeline end to end on random input data."""
    model_info = BUILTIN_HF_TRANSFORMERS_MODEL_MAP["toto"]
    pipeline = TotoPipeline(model_info)

    # One univariate batch of 128 random points, forecast 8 steps ahead.
    batch = [{"targets": torch.randn(1, 128)}]
    prepared = pipeline.preprocess(batch, output_length=8)
    outputs = pipeline.forecast(prepared, output_length=8, num_samples=None)
    outputs = pipeline.postprocess(outputs)

    print(f"loaded model: {model_info.model_id}")
    print(f"batch size: {len(outputs)}")
    print(f"output shape: {tuple(outputs[0].shape)}")


if __name__ == "__main__":
    main()