Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
513 changes: 513 additions & 0 deletions sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md

Large diffs are not rendered by default.

125 changes: 125 additions & 0 deletions sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0
"""SageMaker FeatureStore V3 - powered by sagemaker-core."""

# Resources from core
from sagemaker.core.resources import FeatureGroup, FeatureMetadata

# Shapes from core (Pydantic - no to_dict() needed)
from sagemaker.core.shapes import (
DataCatalogConfig,
FeatureParameter,
FeatureValue,
Filter,
OfflineStoreConfig,
OnlineStoreConfig,
OnlineStoreSecurityConfig,
S3StorageConfig,
SearchExpression,
ThroughputConfig,
TtlDuration,
)

# Enums (local - core uses strings)
from sagemaker.mlops.feature_store.inputs import (
DeletionModeEnum,
ExpirationTimeResponseEnum,
FilterOperatorEnum,
OnlineStoreStorageTypeEnum,
ResourceEnum,
SearchOperatorEnum,
SortOrderEnum,
TableFormatEnum,
TargetStoreEnum,
ThroughputModeEnum,
)

# Feature Definition helpers (local)
from sagemaker.mlops.feature_store.feature_definition import (
FeatureDefinition,
FeatureTypeEnum,
CollectionTypeEnum,
FractionalFeatureDefinition,
IntegralFeatureDefinition,
StringFeatureDefinition,
ListCollectionType,
SetCollectionType,
VectorCollectionType,
)

# Utility functions (local)
from sagemaker.mlops.feature_store.feature_utils import (
as_hive_ddl,
create_athena_query,
get_session_from_role,
ingest_dataframe,
load_feature_definitions_from_dataframe,
)

# Classes (local)
from sagemaker.mlops.feature_store.athena_query import AthenaQuery
from sagemaker.mlops.feature_store.dataset_builder import (
DatasetBuilder,
FeatureGroupToBeMerged,
JoinComparatorEnum,
JoinTypeEnum,
TableType,
)
from sagemaker.mlops.feature_store.ingestion_manager_pandas import (
IngestionError,
IngestionManagerPandas,
)

__all__ = [
# Resources
"FeatureGroup",
"FeatureMetadata",
# Shapes
"DataCatalogConfig",
"FeatureParameter",
"FeatureValue",
"Filter",
"OfflineStoreConfig",
"OnlineStoreConfig",
"OnlineStoreSecurityConfig",
"S3StorageConfig",
"SearchExpression",
"ThroughputConfig",
"TtlDuration",
# Enums
"DeletionModeEnum",
"ExpirationTimeResponseEnum",
"FilterOperatorEnum",
"OnlineStoreStorageTypeEnum",
"ResourceEnum",
"SearchOperatorEnum",
"SortOrderEnum",
"TableFormatEnum",
"TargetStoreEnum",
"ThroughputModeEnum",
# Feature Definitions
"FeatureDefinition",
"FeatureTypeEnum",
"CollectionTypeEnum",
"FractionalFeatureDefinition",
"IntegralFeatureDefinition",
"StringFeatureDefinition",
"ListCollectionType",
"SetCollectionType",
"VectorCollectionType",
# Utility functions
"as_hive_ddl",
"create_athena_query",
"get_session_from_role",
"ingest_dataframe",
"load_feature_definitions_from_dataframe",
# Classes
"AthenaQuery",
"DatasetBuilder",
"FeatureGroupToBeMerged",
"IngestionError",
"IngestionManagerPandas",
"JoinComparatorEnum",
"JoinTypeEnum",
"TableType",
]
112 changes: 112 additions & 0 deletions sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
import tempfile
from dataclasses import dataclass, field
from typing import Any, Dict
from urllib.parse import urlparse
import pandas as pd
from pandas import DataFrame

from sagemaker.mlops.feature_store.feature_utils import (
start_query_execution,
get_query_execution,
wait_for_athena_query,
download_athena_query_result,
)

from sagemaker.core.helper.session_helper import Session

@dataclass
class AthenaQuery:
"""Class to manage querying of feature store data with AWS Athena.

This class instantiates a AthenaQuery object that is used to retrieve data from feature store
via standard SQL queries.

Attributes:
catalog (str): name of the data catalog.
database (str): name of the database.
table_name (str): name of the table.
sagemaker_session (Session): instance of the Session class to perform boto calls.
"""

catalog: str
database: str
table_name: str
sagemaker_session: Session
_current_query_execution_id: str = field(default=None, init=False)
_result_bucket: str = field(default=None, init=False)
_result_file_prefix: str = field(default=None, init=False)

def run(
self, query_string: str, output_location: str, kms_key: str = None, workgroup: str = None
) -> str:
"""Execute a SQL query given a query string, output location and kms key.

This method executes the SQL query using Athena and outputs the results to output_location
and returns the execution id of the query.

Args:
query_string: SQL query string.
output_location: S3 URI of the query result.
kms_key: KMS key id. If set, will be used to encrypt the query result file.
workgroup (str): The name of the workgroup in which the query is being started.

Returns:
Execution id of the query.
"""
response = start_query_execution(
session=self.sagemaker_session,
catalog=self.catalog,
database=self.database,
query_string=query_string,
output_location=output_location,
kms_key=kms_key,
workgroup=workgroup,
)

self._current_query_execution_id = response["QueryExecutionId"]
parsed_result = urlparse(output_location, allow_fragments=False)
self._result_bucket = parsed_result.netloc
self._result_file_prefix = parsed_result.path.strip("/")
return self._current_query_execution_id

def wait(self):
"""Wait for the current query to finish."""
wait_for_athena_query(self.sagemaker_session, self._current_query_execution_id)

def get_query_execution(self) -> Dict[str, Any]:
"""Get execution status of the current query.

Returns:
Response dict from Athena.
"""
return get_query_execution(self.sagemaker_session, self._current_query_execution_id)

def as_dataframe(self, **kwargs) -> DataFrame:
"""Download the result of the current query and load it into a DataFrame.

Args:
**kwargs (object): key arguments used for the method pandas.read_csv to be able to
have a better tuning on data. For more info read:
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html

Returns:
A pandas DataFrame contains the query result.
"""
state = self.get_query_execution()["QueryExecution"]["Status"]["State"]
if state != "SUCCEEDED":
if state in ("QUEUED", "RUNNING"):
raise RuntimeError(f"Query {self._current_query_execution_id} still executing.")
raise RuntimeError(f"Query {self._current_query_execution_id} failed.")

output_file = os.path.join(tempfile.gettempdir(), f"{self._current_query_execution_id}.csv")
download_athena_query_result(
session=self.sagemaker_session,
bucket=self._result_bucket,
prefix=self._result_file_prefix,
query_execution_id=self._current_query_execution_id,
filename=output_file,
)
kwargs.pop("delimiter", None)
return pd.read_csv(output_file, delimiter=",", **kwargs)

Loading
Loading