Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,16 @@ def __repr__(self):
},
transformers_registered=True,
),
"patchtst_fm": ModelInfo(
model_id = "patchtst_fm",
category=ModelCategory.BUILTIN,
state=ModelStates.INACTIVE,
model_type="patchtst_fm",
pipeline_cls="pipeline_patchtst_fm.PatchTSTFMPipeline",
repo_id="ibm-research/patchtst-fm-r1",
auto_map={
"AutoConfig": "configuration_patchtst_fm.PatchTSTFMConfig",
"AutoModelForCausalLM": "modeling_patchtst_fm.PatchTSTFMForPrediction",
},
),
}
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright contributors to the TSFM project
#
"""PatchTST-FM model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


# Module-level logger following the transformers convention.
logger = logging.get_logger(__name__)

# Empty on purpose: no hosted pretrained-config archive is registered for
# PatchTST-FM yet; kept for structural parity with other transformers-style
# configuration modules. NOTE(review): modern transformers no longer uses
# these maps — confirm whether it can be dropped.
PATCHTSTFM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class PatchTSTFMConfig(PretrainedConfig):
    """Configuration class for the PatchTST-FM model.

    Holds the architectural hyperparameters (patching, transformer sizes) and
    derives the quantile levels used for probabilistic forecasting.

    Args:
        context_length: Length of the input context window.
        prediction_length: Number of future steps the model predicts.
        d_patch: Size of each input patch.
        d_model: Transformer hidden dimension.
        n_head: Number of attention heads.
        n_layer: Number of transformer layers.
        norm_first: Whether layer norm is applied before the sub-layers.
        pretrain_mask_ratio: Fraction of patches masked during pretraining.
        pretrain_mask_cont: Length of contiguous masked runs in pretraining.
        num_quantile: Number of quantile levels to emit.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "patchtst_fm"
    # Map standard Hugging Face attribute names onto this config's field names.
    attribute_map = {
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        context_length: int = 8192,
        prediction_length: int = 64,
        d_patch: int = 16,
        d_model: int = 384,
        n_head: int = 6,
        n_layer: int = 6,
        norm_first: bool = True,
        pretrain_mask_ratio: float = 0.4,
        pretrain_mask_cont: int = 8,
        num_quantile: int = 99,
        **kwargs,
    ):
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.d_patch = d_patch
        # Number of patches the context window is split into.
        self.n_patch = int(context_length // d_patch)
        self.d_model = d_model
        self.n_head = n_head
        self.n_layer = n_layer
        self.norm_first = norm_first
        self.pretrain_mask_ratio = pretrain_mask_ratio
        self.pretrain_mask_cont = pretrain_mask_cont
        self.num_quantile = num_quantile

        # Derive exactly `num_quantile` quantile levels.  When num_quantile is
        # a multiple of 9, the levels are evenly spaced in (0, 1); otherwise
        # the interior levels are evenly spaced and the endpoints are pinned
        # at 0.01 and 0.99.  NOTE(review): the `% 9` branch condition looks
        # arbitrary — confirm against the upstream TSFM implementation.
        if num_quantile % 9 == 0:
            denom = num_quantile + 1
            levels = [step / denom for step in range(1, num_quantile + 1)]
        else:
            denom = num_quantile - 1
            interior = [step / denom for step in range(1, num_quantile - 1)]
            levels = [0.01, *interior, 0.99]
        self.quantile_levels = levels

        super().__init__(**kwargs)
Loading
Loading