Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ public class AINodeTestUtils {
new AbstractMap.SimpleEntry<>(
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")))
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"toto", new FakeModelInfo("toto", "toto", "builtin", "active")))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;
Expand Down
10 changes: 10 additions & 0 deletions iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,14 @@ def __repr__(self):
},
transformers_registered=True,
),
"toto": ModelInfo(
model_id="toto",
category=ModelCategory.BUILTIN,
state=ModelStates.INACTIVE,
model_type="toto",
pipeline_cls="pipeline_toto.TotoPipeline",
repo_id="Datadog/Toto-Open-Base-1.0",
auto_map=None,
transformers_registered=False,
),
}
27 changes: 24 additions & 3 deletions iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,26 +39,33 @@
from iotdb.ainode.core.model.model_info import ModelInfo
from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model
from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path
from iotdb.ainode.core.model.toto.inference.forecaster import TotoForecaster
from iotdb.ainode.core.model.toto.model.toto import Toto


logger = Logger()
BACKEND = DeviceManager()


def load_model(model_info: ModelInfo, **model_kwargs) -> Any:
    """Load the model described by ``model_info`` and return it.

    The loader is chosen in this order:

    1. ``model_type == "toto"``  -> ``load_toto_model``
    2. ``auto_map is not None``  -> ``load_model_from_transformers``
    3. ``model_type == "sktime"`` -> ``create_sktime_model``
    4. anything else             -> ``load_model_from_pt``

    The Toto check comes first so that it is taken regardless of the value
    of ``auto_map``.

    Args:
        model_info: Metadata of the model to load.
        **model_kwargs: Forwarded unchanged to the selected loader (the
            sktime factory only receives ``model_info.model_id``).

    Returns:
        Whatever object the selected loader produced.
    """
    if model_info.model_type == "toto":
        model = load_toto_model(model_info, **model_kwargs)
    elif model_info.auto_map is not None:
        model = load_model_from_transformers(model_info, **model_kwargs)
    elif model_info.model_type == "sktime":
        model = create_sktime_model(model_info.model_id)
    else:
        model = load_model_from_pt(model_info, **model_kwargs)

    # sktime models are always reported as "cpu". For everything else the
    # returned object is not guaranteed to expose `.device`, so fall back to
    # "cpu" instead of raising AttributeError while building the log message.
    if model_info.model_type == "sktime":
        model_device = "cpu"
    else:
        model_device = getattr(model, "device", "cpu")
    logger.info(
        f"Model {model_info.model_id} loaded to device {model_device} successfully."
    )
    return model



def load_model_from_transformers(model_info: ModelInfo, **model_kwargs):
device_map = model_kwargs.get("device_map", "cpu")
train_from_scratch = model_kwargs.get("train_from_scratch", False)
Expand Down Expand Up @@ -135,6 +142,19 @@ def load_model_from_pt(model_info: ModelInfo, **kwargs):
logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}")
return BACKEND.move_model(model, device_map)

def load_toto_model(model_info: ModelInfo, **model_kwargs):
    """Load a Toto checkpoint from the local models directory.

    The checkpoint directory is resolved as
    ``<cwd>/<ain_models_dir>/<category value>/<model_id>``.

    Args:
        model_info: Metadata of the model; ``category.value`` and
            ``model_id`` select the checkpoint directory.
        **model_kwargs: Only ``device_map`` is read (defaults to ``"cpu"``).

    Returns:
        A ``TotoForecaster`` constructed from the ``model`` attribute of the
        object returned by ``BACKEND.move_model``.
    """
    target_device = model_kwargs.get("device_map", "cpu")

    working_dir = os.getcwd()
    models_root = AINodeDescriptor().get_config().get_ain_models_dir()
    checkpoint_dir = os.path.join(
        working_dir,
        models_root,
        model_info.category.value,
        model_info.model_id,
    )

    toto = BACKEND.move_model(Toto.from_pretrained(checkpoint_dir), target_device)
    # The forecaster is built from the inner `.model`, not the Toto wrapper.
    return TotoForecaster(toto.model)


def load_model_for_efficient_inference():
# TODO: An efficient model loading method for inference based on model_arguments
Expand All @@ -146,5 +166,6 @@ def load_model_for_powerful_finetune():
pass



def unload_model():
    """Placeholder for model unloading; currently does nothing."""
Empty file.
Empty file.
Loading