Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ env/
sagemaker_train/src/**/container_drivers/sm_train.sh
sagemaker_train/src/**/container_drivers/sourcecode.json
sagemaker_train/src/**/container_drivers/distributed.json
.kiro
22 changes: 12 additions & 10 deletions sagemaker-core/src/sagemaker/core/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -35788,7 +35788,7 @@ def stop(self) -> None:
ResourceNotFound: Resource being access is not found.
"""

client = SageMakerClient().client
client = SageMakerClient().sagemaker_client

operation_input_args = {
"TrainingJobName": self.training_job_name,
Expand Down Expand Up @@ -35833,15 +35833,17 @@ def wait(
progress.add_task("Waiting for TrainingJob...")
status = Status("Current status:")

instance_count = (
sum(
instance_group.instance_count
for instance_group in self.resource_config.instance_groups
)
if self.resource_config.instance_groups
and not isinstance(self.resource_config.instance_groups, Unassigned)
else self.resource_config.instance_count
)
instance_count = 1 # Default
if not isinstance(self.resource_config, Unassigned):
if (hasattr(self.resource_config, 'instance_groups') and
self.resource_config.instance_groups and
not isinstance(self.resource_config.instance_groups, Unassigned)):
instance_count = sum(
instance_group.instance_count
for instance_group in self.resource_config.instance_groups
)
elif hasattr(self.resource_config, 'instance_count'):
instance_count = self.resource_config.instance_count
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will need to update the autogen engine for these changes.


if logs:
multi_stream_logger = MultiLogStreamHandler(
Expand Down
19 changes: 13 additions & 6 deletions sagemaker-core/src/sagemaker/core/tools/resources_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,12 +1664,19 @@ def _get_instance_count_ref(self, resource_name: str) -> str:
"""

if resource_name == "TrainingJob":
return """(
sum(instance_group.instance_count for instance_group in self.resource_config.instance_groups)
if self.resource_config.instance_groups and not isinstance(self.resource_config.instance_groups, Unassigned)
else self.resource_config.instance_count
)
"""
return """1 # Default
if not isinstance(self.resource_config, Unassigned):
if (
hasattr(self.resource_config, "instance_groups")
and self.resource_config.instance_groups
and not isinstance(self.resource_config.instance_groups, Unassigned)
):
instance_count = sum(
instance_group.instance_count
for instance_group in self.resource_config.instance_groups
)
elif hasattr(self.resource_config, "instance_count"):
instance_count = self.resource_config.instance_count"""
elif resource_name == "TransformJob":
return "self.transform_resources.instance_count"
elif resource_name == "ProcessingJob":
Expand Down
2 changes: 1 addition & 1 deletion sagemaker-core/src/sagemaker/core/tools/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def delete(
@Base.add_validate_call
def stop(self) -> None:
{docstring}
client = SageMakerClient().client
client = SageMakerClient().sagemaker_client

operation_input_args = {{
{operation_input_args}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "py3.10.14",
"language": "python",
"name": "python3"
},
Expand All @@ -456,7 +456,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
5 changes: 5 additions & 0 deletions sagemaker-train/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ test = [
"graphene",
"IPython"
]
notebook = [
"ipywidgets>=8.0.0",
"rich>=13.0.0",
"matplotlib>=3.5.0",
]

[tool.setuptools.packages.find]
where = ["src/"]
Expand Down
12 changes: 12 additions & 0 deletions sagemaker-train/src/sagemaker/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,16 @@ def __getattr__(name):
elif name == "get_builtin_metrics":
from sagemaker.train.evaluate import get_builtin_metrics
return get_builtin_metrics
elif name == "plot_training_metrics":
from sagemaker.train.common_utils.metrics_visualizer import plot_training_metrics
return plot_training_metrics
elif name == "get_available_metrics":
from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics
return get_available_metrics
elif name == "get_studio_url":
from sagemaker.train.common_utils.metrics_visualizer import get_studio_url
return get_studio_url
elif name == "get_mlflow_url":
from sagemaker.train.common_utils.trainer_wait import get_mlflow_url
return get_mlflow_url
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class _MLflowConstants:

# Metric names
TOTAL_LOSS_METRIC = 'total_loss'
LOSS_METRIC_KEYWORDS = ('loss',)
EPOCH_KEYWORD = 'epoch'

# MLflow run tags
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni

except Exception as e:
logger.error("Exception getting fine-tuning options: %s", e)
raise


def _create_input_channels(dataset: str, content_type: Optional[str] = None,
Expand Down
Loading
Loading