From 534df91349710aa7e20a858ddd68a71ae6450995 Mon Sep 17 00:00:00 2001 From: adishaa Date: Fri, 16 Jan 2026 07:00:13 -0800 Subject: [PATCH 1/2] feat: Add Feature Store Support to V3 --- .../sagemaker/mlops/feature_store/__init__.py | 129 ++++ .../mlops/feature_store/athena_query.py | 112 +++ .../mlops/feature_store/dataset_builder.py | 725 ++++++++++++++++++ .../mlops/feature_store/feature_definition.py | 107 +++ .../mlops/feature_store/feature_utils.py | 488 ++++++++++++ .../feature_store/ingestion_manager_pandas.py | 321 ++++++++ .../sagemaker/mlops/feature_store/inputs.py | 60 ++ 7 files changed, 1942 insertions(+) create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py new file mode 100644 index 0000000000..ee8cd7d1a3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""SageMaker FeatureStore V3 - powered by sagemaker-core.""" + +# Resources from core +from sagemaker_core.main.resources import FeatureGroup, FeatureMetadata +from sagemaker_core.main.resources import FeatureStore + +# Shapes from core (Pydantic - no to_dict() needed) +from sagemaker_core.main.shapes import ( + DataCatalogConfig, + FeatureParameter, + FeatureValue, + Filter, + OfflineStoreConfig, + OnlineStoreConfig, + OnlineStoreSecurityConfig, + S3StorageConfig, + SearchExpression, + ThroughputConfig, + TtlDuration, +) + +# Enums (local - core uses strings) +from sagemaker.mlops.feature_store.inputs import ( + DeletionModeEnum, + ExpirationTimeResponseEnum, + FilterOperatorEnum, + OnlineStoreStorageTypeEnum, + ResourceEnum, + SearchOperatorEnum, + SortOrderEnum, + TableFormatEnum, + TargetStoreEnum, + ThroughputModeEnum, +) + +# Feature Definition helpers (local) +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FeatureTypeEnum, + CollectionTypeEnum, + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + ListCollectionType, + SetCollectionType, + VectorCollectionType, +) + +# Utility functions (local) +from sagemaker.mlops.feature_store.feature_utils import ( + as_hive_ddl, + create_athena_query, + create_dataset, + get_session_from_role, + ingest_dataframe, + load_feature_definitions_from_dataframe, +) + +# Classes (local) +from sagemaker.mlops.feature_store.athena_query import AthenaQuery +from sagemaker.mlops.feature_store.dataset_builder import ( + DatasetBuilder, + FeatureGroupToBeMerged, + JoinComparatorEnum, + JoinTypeEnum, + TableType, +) +from sagemaker.mlops.feature_store.ingestion_manager_pandas import ( + IngestionError, + IngestionManagerPandas, +) + +__all__ = [ + # Resources + "FeatureGroup", + "FeatureMetadata", + "FeatureStore", + # Shapes + "DataCatalogConfig", + "FeatureParameter", + 
"FeatureValue", + "Filter", + "OfflineStoreConfig", + "OnlineStoreConfig", + "OnlineStoreSecurityConfig", + "S3StorageConfig", + "SearchExpression", + "ThroughputConfig", + "TtlDuration", + # Enums + "DeletionModeEnum", + "ExpirationTimeResponseEnum", + "FilterOperatorEnum", + "OnlineStoreStorageTypeEnum", + "ResourceEnum", + "SearchOperatorEnum", + "SortOrderEnum", + "TableFormatEnum", + "TargetStoreEnum", + "ThroughputModeEnum", + # Feature Definitions + "FeatureDefinition", + "FeatureTypeEnum", + "CollectionTypeEnum", + "FractionalFeatureDefinition", + "IntegralFeatureDefinition", + "StringFeatureDefinition", + "ListCollectionType", + "SetCollectionType", + "VectorCollectionType", + # Utility functions + "as_hive_ddl", + "create_athena_query", + "create_dataset", + "get_session_from_role", + "ingest_dataframe", + "load_feature_definitions_from_dataframe", + # Classes + "AthenaQuery", + "DatasetBuilder", + "FeatureGroupToBeMerged", + "IngestionError", + "IngestionManagerPandas", + "JoinComparatorEnum", + "JoinTypeEnum", + "TableType", +] diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py new file mode 100644 index 0000000000..123b3c4305 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py @@ -0,0 +1,112 @@ +import os +import tempfile +from dataclasses import dataclass, field +from typing import Any, Dict +from urllib.parse import urlparse +import pandas as pd +from pandas import DataFrame + +from sagemaker.mlops.feature_store.feature_utils import ( + start_query_execution, + get_query_execution, + wait_for_athena_query, + download_athena_query_result, +) + +from sagemaker.core.helper.session_helper import Session + +@dataclass +class AthenaQuery: + """Class to manage querying of feature store data with AWS Athena. + + This class instantiates a AthenaQuery object that is used to retrieve data from feature store + via standard SQL queries. + + Attributes: + catalog (str): name of the data catalog. + database (str): name of the database. + table_name (str): name of the table. + sagemaker_session (Session): instance of the Session class to perform boto calls. + """ + + catalog: str + database: str + table_name: str + sagemaker_session: Session + _current_query_execution_id: str = field(default=None, init=False) + _result_bucket: str = field(default=None, init=False) + _result_file_prefix: str = field(default=None, init=False) + + def run( + self, query_string: str, output_location: str, kms_key: str = None, workgroup: str = None + ) -> str: + """Execute a SQL query given a query string, output location and kms key. + + This method executes the SQL query using Athena and outputs the results to output_location + and returns the execution id of the query. + + Args: + query_string: SQL query string. + output_location: S3 URI of the query result. + kms_key: KMS key id. If set, will be used to encrypt the query result file. + workgroup (str): The name of the workgroup in which the query is being started. + + Returns: + Execution id of the query. 
+ """ + response = start_query_execution( + session=self.sagemaker_session, + catalog=self.catalog, + database=self.database, + query_string=query_string, + output_location=output_location, + kms_key=kms_key, + workgroup=workgroup, + ) + + self._current_query_execution_id = response["QueryExecutionId"] + parsed_result = urlparse(output_location, allow_fragments=False) + self._result_bucket = parsed_result.netloc + self._result_file_prefix = parsed_result.path.strip("/") + return self._current_query_execution_id + + def wait(self): + """Wait for the current query to finish.""" + wait_for_athena_query(self.sagemaker_session, self._current_query_execution_id) + + def get_query_execution(self) -> Dict[str, Any]: + """Get execution status of the current query. + + Returns: + Response dict from Athena. + """ + return get_query_execution(self.sagemaker_session, self._current_query_execution_id) + + def as_dataframe(self, **kwargs) -> DataFrame: + """Download the result of the current query and load it into a DataFrame. + + Args: + **kwargs (object): key arguments used for the method pandas.read_csv to be able to + have a better tuning on data. For more info read: + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html + + Returns: + A pandas DataFrame contains the query result. + """ + state = self.get_query_execution()["QueryExecution"]["Status"]["State"] + if state != "SUCCEEDED": + if state in ("QUEUED", "RUNNING"): + raise RuntimeError(f"Query {self._current_query_execution_id} still executing.") + raise RuntimeError(f"Query {self._current_query_execution_id} failed.") + + output_file = os.path.join(tempfile.gettempdir(), f"{self._current_query_execution_id}.csv") + download_athena_query_result( + session=self.sagemaker_session, + bucket=self._result_bucket, + prefix=self._result_file_prefix, + query_execution_id=self._current_query_execution_id, + filename=output_file, + ) + kwargs.pop("delimiter", None) + return pd.read_csv(output_file, delimiter=",", **kwargs) + diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py new file mode 100644 index 0000000000..f5450663a6 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py @@ -0,0 +1,725 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""Dataset Builder for FeatureStore.""" +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Union +import datetime + +import pandas as pd + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store import FeatureGroup +from sagemaker.mlops.feature_store.feature_definition import FeatureDefinition, FeatureTypeEnum +from sagemaker.mlops.feature_store.feature_utils import ( + upload_dataframe_to_s3, + download_csv_from_s3, + run_athena_query, +) + +_DEFAULT_CATALOG = "AwsDataCatalog" +_DEFAULT_DATABASE = "sagemaker_featurestore" + +_DTYPE_TO_FEATURE_TYPE = { + "object": "String", "string": "String", + "int64": "Integral", "int32": "Integral", + "float64": "Fractional", "float32": "Fractional", +} + +_DTYPE_TO_ATHENA_TYPE = { + "object": "STRING", "int64": "INT", "float64": "DOUBLE", + "bool": "BOOLEAN", "datetime64[ns]": "TIMESTAMP", +} + + +class TableType(Enum): + FEATURE_GROUP = "FeatureGroup" + DATA_FRAME = "DataFrame" + + +class JoinTypeEnum(Enum): + INNER_JOIN = "JOIN" + LEFT_JOIN = "LEFT JOIN" + RIGHT_JOIN = "RIGHT JOIN" + FULL_JOIN = "FULL JOIN" + CROSS_JOIN = "CROSS JOIN" + + +class JoinComparatorEnum(Enum): + EQUALS = "=" + GREATER_THAN = ">" + GREATER_THAN_OR_EQUAL_TO = ">=" + LESS_THAN = "<" + LESS_THAN_OR_EQUAL_TO = "<=" + NOT_EQUAL_TO = "<>" + + +@dataclass +class FeatureGroupToBeMerged: + """FeatureGroup metadata which will be used for SQL join. + + This class instantiates a FeatureGroupToBeMerged object that comprises a list of feature names, + a list of feature names which will be included in SQL query, a database, an Athena table name, + a feature name of record identifier, a feature name of event time identifier and a feature name + of base which is the target join key. + + Attributes: + features (List[str]): A list of strings representing feature names of this FeatureGroup. + included_feature_names (List[str]): A list of strings representing features to be + included in the SQL join. + projected_feature_names (List[str]): A list of strings representing features to be + included for final projection in output. + catalog (str): A string representing the catalog. + database (str): A string representing the database. + table_name (str): A string representing the Athena table name of this FeatureGroup. + record_identifier_feature_name (str): A string representing the record identifier feature. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + target_feature_name_in_base (str): A string representing the feature name in base which will + be used as target join key (default: None). + table_type (TableType): A TableType representing the type of table if it is Feature Group or + Panda Data Frame (default: None). + feature_name_in_target (str): A string representing the feature name in the target feature + group that will be compared to the target feature in the base feature group. + If None is provided, the record identifier feature will be used in the + SQL join. (default: None). + join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator + used when joining the target feature in the base feature group and the feature + in the target feature group. (default: JoinComparatorEnum.EQUALS). + join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between + the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN). 
+ """ + features: List[str] + included_feature_names: List[str] + projected_feature_names: List[str] + catalog: str + database: str + table_name: str + record_identifier_feature_name: str + event_time_identifier_feature: FeatureDefinition + target_feature_name_in_base: str = None + table_type: TableType = None + feature_name_in_target: str = None + join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS + join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN + + +def construct_feature_group_to_be_merged( + target_feature_group: FeatureGroup, + included_feature_names: List[str], + target_feature_name_in_base: str = None, + feature_name_in_target: str = None, + join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS, + join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN, +) -> FeatureGroupToBeMerged: + """Construct a FeatureGroupToBeMerged object by provided parameters. + + Args: + target_feature_group (FeatureGroup): A FeatureGroup object. + included_feature_names (List[str]): A list of strings representing features to be + included in the output. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as target join key (default: None). + feature_name_in_target (str): A string representing the feature name in the target feature + group that will be compared to the target feature in the base feature group. + If None is provided, the record identifier feature will be used in the + SQL join. (default: None). + join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator + used when joining the target feature in the base feature group and the feature + in the target feature group. (default: JoinComparatorEnum.EQUALS). + join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between + the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN). + + Returns: + A FeatureGroupToBeMerged object. + + Raises: + RuntimeError: No metastore is configured with the FeatureGroup. + ValueError: Invalid feature name(s) in included_feature_names. 
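+
+    Example:
+        A minimal sketch; ``customers_fg`` is an illustrative FeatureGroup, not one
+        defined in this module:
+
+            merged = construct_feature_group_to_be_merged(
+                target_feature_group=customers_fg,
+                included_feature_names=["customer_id", "age"],
+                target_feature_name_in_base="customer_id",
+            )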
+ """ + fg = FeatureGroup.get(feature_group_name=target_feature_group.feature_group_name) + + if not fg.offline_store_config or not fg.offline_store_config.data_catalog_config: + raise RuntimeError(f"No metastore configured for FeatureGroup {fg.feature_group_name}.") + + catalog_config = fg.offline_store_config.data_catalog_config + disable_glue = catalog_config.disable_glue_table_creation or False + + features = [fd.feature_name for fd in fg.feature_definitions] + record_id = fg.record_identifier_feature_name + event_time_name = fg.event_time_feature_name + event_time_type = next( + (fd.feature_type for fd in fg.feature_definitions if fd.feature_name == event_time_name), + None + ) + + if feature_name_in_target and feature_name_in_target not in features: + raise ValueError(f"Feature {feature_name_in_target} not found in {fg.feature_group_name}") + + for feat in included_feature_names or []: + if feat not in features: + raise ValueError(f"Feature {feat} not found in {fg.feature_group_name}") + + if not included_feature_names: + included_feature_names = features.copy() + projected_feature_names = features.copy() + else: + projected_feature_names = included_feature_names.copy() + if record_id not in included_feature_names: + included_feature_names.append(record_id) + if event_time_name not in included_feature_names: + included_feature_names.append(event_time_name) + + return FeatureGroupToBeMerged( + features=features, + included_feature_names=included_feature_names, + projected_feature_names=projected_feature_names, + catalog=catalog_config.catalog if disable_glue else _DEFAULT_CATALOG, + database=catalog_config.database, + table_name=catalog_config.table_name, + record_identifier_feature_name=record_id, + event_time_identifier_feature=FeatureDefinition(event_time_name, FeatureTypeEnum(event_time_type)), + target_feature_name_in_base=target_feature_name_in_base, + table_type=TableType.FEATURE_GROUP, + feature_name_in_target=feature_name_in_target, + join_comparator=join_comparator, + join_type=join_type, + ) + + +@dataclass +class DatasetBuilder: + """DatasetBuilder definition. + + This class instantiates a DatasetBuilder object that comprises a base, a list of feature names, + an output path and a KMS key ID. + + Attributes: + _sagemaker_session (Session): Session instance to perform boto calls. + _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a + pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset. + _output_path (str): An S3 URI which stores the output .csv file. + _record_identifier_feature_name (str): A string representing the record identifier feature + if base is a DataFrame (default: None). + _event_time_identifier_feature_name (str): A string representing the event time identifier + feature if base is a DataFrame (default: None). + _included_feature_names (List[str]): A list of strings representing features to be + included in the output. If not set, all features will be included in the output. + (default: None). + _kms_key_id (str): A KMS key id. If set, will be used to encrypt the result file + (default: None). + _point_in_time_accurate_join (bool): A boolean representing if point-in-time join + is applied to the resulting dataframe when calling "to_dataframe". + When set to True, users can retrieve data using "row-level time travel" + according to the event times provided to the DatasetBuilder. 
This requires that the + entity dataframe with event times is submitted as the base in the constructor + (default: False). + _include_duplicated_records (bool): A boolean representing whether the resulting dataframe + when calling "to_dataframe" should include duplicated records (default: False). + _include_deleted_records (bool): A boolean representing whether the resulting + dataframe when calling "to_dataframe" should include deleted records (default: False). + _number_of_recent_records (int): An integer representing how many records will be + returned for each record identifier (default: 1). + _number_of_records (int): An integer representing the number of records that should be + returned in the resulting dataframe when calling "to_dataframe" (default: None). + _write_time_ending_timestamp (datetime.datetime): A datetime that represents the latest + write time for a record to be included in the resulting dataset. Records with a + newer write time will be omitted from the resulting dataset. (default: None). + _event_time_starting_timestamp (datetime.datetime): A datetime that represents the earliest + event time for a record to be included in the resulting dataset. Records + with an older event time will be omitted from the resulting dataset. (default: None). + _event_time_ending_timestamp (datetime.datetime): A datetime that represents the latest + event time for a record to be included in the resulting dataset. Records + with a newer event time will be omitted from the resulting dataset. (default: None). + _feature_groups_to_be_merged (List[FeatureGroupToBeMerged]): A list of + FeatureGroupToBeMerged which will be joined to base (default: []). + _event_time_identifier_feature_type (FeatureTypeEnum): A FeatureTypeEnum representing the + type of event time identifier feature (default: None). + """ + + _sagemaker_session: Session + _base: Union[FeatureGroup, pd.DataFrame] + _output_path: str + _record_identifier_feature_name: str = None + _event_time_identifier_feature_name: str = None + _included_feature_names: List[str] = None + _kms_key_id: str = None + _event_time_identifier_feature_type: FeatureTypeEnum = None + + _point_in_time_accurate_join: bool = field(default=False, init=False) + _include_duplicated_records: bool = field(default=False, init=False) + _include_deleted_records: bool = field(default=False, init=False) + _number_of_recent_records: int = field(default=None, init=False) + _number_of_records: int = field(default=None, init=False) + _write_time_ending_timestamp: datetime.datetime = field(default=None, init=False) + _event_time_starting_timestamp: datetime.datetime = field(default=None, init=False) + _event_time_ending_timestamp: datetime.datetime = field(default=None, init=False) + _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = field(default_factory=list, init=False) + + def with_feature_group( + self, + feature_group: FeatureGroup, + target_feature_name_in_base: str = None, + included_feature_names: List[str] = None, + feature_name_in_target: str = None, + join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS, + join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN, + ) -> "DatasetBuilder": + """Join FeatureGroup with base. + + Args: + feature_group (FeatureGroup): A target FeatureGroup which will be joined to base. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as a join key (default: None). 
+ included_feature_names (List[str]): A list of strings representing features to be + included in the output (default: None). + feature_name_in_target (str): A string representing the feature name in the target + feature group that will be compared to the target feature in the base feature group. + If None is provided, the record identifier feature will be used in the + SQL join. (default: None). + join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator + used when joining the target feature in the base feature group and the feature + in the target feature group. (default: JoinComparatorEnum.EQUALS). + join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between + the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN). + + Returns: + This DatasetBuilder object. + """ + self._feature_groups_to_be_merged.append( + construct_feature_group_to_be_merged( + feature_group, included_feature_names, target_feature_name_in_base, + feature_name_in_target, join_comparator, join_type, + ) + ) + return self + + def point_in_time_accurate_join(self) -> "DatasetBuilder": + """Enable point-in-time accurate join. + + Returns: + This DatasetBuilder object. + """ + self._point_in_time_accurate_join = True + return self + + def include_duplicated_records(self) -> "DatasetBuilder": + """Include duplicated records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_duplicated_records = True + return self + + def include_deleted_records(self) -> "DatasetBuilder": + """Include deleted records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_deleted_records = True + return self + + def with_number_of_recent_records_by_record_identifier(self, n: int) -> "DatasetBuilder": + """Set number_of_recent_records field with provided input. + + Args: + n (int): An int that how many recent records will be returned for + each record identifier. + + Returns: + This DatasetBuilder object. + """ + self._number_of_recent_records = n + return self + + def with_number_of_records_from_query_results(self, n: int) -> "DatasetBuilder": + """Set number_of_records field with provided input. + + Args: + n (int): An int that how many records will be returned. + + Returns: + This DatasetBuilder object. + """ + self._number_of_records = n + return self + + def as_of(self, timestamp: datetime.datetime) -> "DatasetBuilder": + """Set write_time_ending_timestamp field with provided input. + + Args: + timestamp (datetime.datetime): A datetime that all records' write time in dataset will + be before it. + + Returns: + This DatasetBuilder object. + """ + self._write_time_ending_timestamp = timestamp + return self + + def with_event_time_range( + self, + starting_timestamp: datetime.datetime = None, + ending_timestamp: datetime.datetime = None, + ) -> "DatasetBuilder": + """Set event_time_starting_timestamp and event_time_ending_timestamp with provided inputs. + + Args: + starting_timestamp (datetime.datetime): A datetime that all records' event time in + dataset will be after it (default: None). + ending_timestamp (datetime.datetime): A datetime that all records' event time in dataset + will be before it (default: None). + + Returns: + This DatasetBuilder object. + """ + self._event_time_starting_timestamp = starting_timestamp + self._event_time_ending_timestamp = ending_timestamp + return self + + def to_csv_file(self) -> tuple[str, str]: + """Get query string and result in .csv format file. 
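+
+        Example:
+            A minimal sketch of the end-to-end flow; the feature groups, S3 path and
+            ``session`` are illustrative, and ``create_dataset`` is the helper
+            defined in ``feature_utils``:
+
+                builder = create_dataset(
+                    base=orders_fg,
+                    output_path="s3://my-bucket/datasets",
+                    session=session,
+                ).with_feature_group(customers_fg, "customer_id")
+                csv_path, query_string = builder.to_csv_file()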
+ + Returns: + The S3 path of the .csv file. + The query string executed. + """ + if isinstance(self._base, pd.DataFrame): + return self._to_csv_from_dataframe() + if isinstance(self._base, FeatureGroup): + return self._to_csv_from_feature_group() + raise ValueError("Base must be either a FeatureGroup or a DataFrame.") + + def to_dataframe(self) -> tuple[pd.DataFrame, str]: + """Get query string and result in pandas.DataFrame. + + Returns: + The pandas.DataFrame object. + The query string executed. + """ + csv_file, query_string = self.to_csv_file() + df = download_csv_from_s3(csv_file, self._sagemaker_session, self._kms_key_id) + if "row_recent" in df.columns: + df = df.drop("row_recent", axis="columns") + return df, query_string + + + def _to_csv_from_dataframe(self) -> tuple[str, str]: + s3_folder, temp_table_name = upload_dataframe_to_s3( + self._base, self._output_path, self._sagemaker_session, self._kms_key_id + ) + self._create_temp_table(temp_table_name, s3_folder) + + base_features = list(self._base.columns) + event_time_dtype = str(self._base[self._event_time_identifier_feature_name].dtypes) + self._event_time_identifier_feature_type = FeatureTypeEnum( + _DTYPE_TO_FEATURE_TYPE.get(event_time_dtype, "String") + ) + + included = self._included_feature_names or base_features + fg_to_merge = FeatureGroupToBeMerged( + features=base_features, + included_feature_names=included, + projected_feature_names=included, + catalog=_DEFAULT_CATALOG, + database=_DEFAULT_DATABASE, + table_name=temp_table_name, + record_identifier_feature_name=self._record_identifier_feature_name, + event_time_identifier_feature=FeatureDefinition( + self._event_time_identifier_feature_name, + self._event_time_identifier_feature_type, + ), + table_type=TableType.DATA_FRAME, + ) + + query_string = self._construct_query_string(fg_to_merge) + result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + return self._extract_result(result) + + def _to_csv_from_feature_group(self) -> tuple[str, str]: + base_fg = construct_feature_group_to_be_merged(self._base, self._included_feature_names) + self._record_identifier_feature_name = base_fg.record_identifier_feature_name + self._event_time_identifier_feature_name = base_fg.event_time_identifier_feature.feature_name + self._event_time_identifier_feature_type = base_fg.event_time_identifier_feature.feature_type + + query_string = self._construct_query_string(base_fg) + result = self._run_query(query_string, base_fg.catalog, base_fg.database) + return self._extract_result(result) + + def _extract_result(self, query_result: dict) -> tuple[str, str]: + execution = query_result.get("QueryExecution", {}) + return ( + execution.get("ResultConfiguration", {}).get("OutputLocation"), + execution.get("Query"), + ) + + def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]: + return run_athena_query( + session=self._sagemaker_session, + catalog=catalog, + database=database, + query_string=query_string, + output_location=self._output_path, + kms_key=self._kms_key_id, + ) + + def _create_temp_table(self, temp_table_name: str, s3_folder: str): + columns = ", ".join( + f"{col} {_DTYPE_TO_ATHENA_TYPE.get(str(self._base[col].dtypes), 'STRING')}" + for col in self._base.columns + ) + serde = '"separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\"' + query = ( + f"CREATE EXTERNAL TABLE {temp_table_name} ({columns}) " + f"ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' " + f"WITH SERDEPROPERTIES ({serde}) LOCATION 
'{s3_folder}';" + ) + self._run_query(query, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + + + def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str: + base_query = self._construct_table_query(base, "base") + query = f"WITH fg_base AS ({base_query})" + + for i, fg in enumerate(self._feature_groups_to_be_merged): + fg_query = self._construct_table_query(fg, str(i)) + query += f",\nfg_{i} AS ({fg_query})" + + selected = ", ".join(f"fg_base.{f}" for f in base.projected_feature_names) + selected_final = ", ".join(base.projected_feature_names) + + for i, fg in enumerate(self._feature_groups_to_be_merged): + selected += ", " + ", ".join( + f'fg_{i}."{f}" as "{f}.{i+1}"' for f in fg.projected_feature_names + ) + selected_final += ", " + ", ".join( + f'"{f}.{i+1}"' for f in fg.projected_feature_names + ) + + query += ( + f"\nSELECT {selected_final}\nFROM (\n" + f"SELECT {selected}, row_number() OVER (\n" + f'PARTITION BY fg_base."{base.record_identifier_feature_name}"\n' + f'ORDER BY fg_base."{base.event_time_identifier_feature.feature_name}" DESC' + ) + + join_strings = [] + for i, fg in enumerate(self._feature_groups_to_be_merged): + if not fg.target_feature_name_in_base: + fg.target_feature_name_in_base = self._record_identifier_feature_name + elif fg.target_feature_name_in_base not in base.features: + raise ValueError(f"Feature {fg.target_feature_name_in_base} not found in base") + query += f', fg_{i}."{fg.event_time_identifier_feature.feature_name}" DESC' + join_strings.append(self._construct_join_condition(fg, str(i))) + + recent_where = "" + if self._number_of_recent_records is not None and self._number_of_recent_records >= 0: + recent_where = f"WHERE row_recent <= {self._number_of_recent_records}" + + query += f"\n) AS row_recent\nFROM fg_base{''.join(join_strings)}\n)\n{recent_where}" + + if self._number_of_records is not None and self._number_of_records >= 0: + query += f"\nLIMIT {self._number_of_records}" + + return query + + def _construct_table_query(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + included = ", ".join(f'table_{suffix}."{f}"' for f in fg.included_feature_names) + included_with_write = included + if fg.table_type is TableType.FEATURE_GROUP: + included_with_write += f', table_{suffix}."write_time"' + + record_id = fg.record_identifier_feature_name + event_time = fg.event_time_identifier_feature.feature_name + + if self._include_duplicated_records and self._include_deleted_records: + return ( + f"SELECT {included}\n" + f'FROM "{fg.database}"."{fg.table_name}" table_{suffix}\n' + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, ["NOT is_deleted"]) + ) + + if fg.table_type is TableType.FEATURE_GROUP and self._include_deleted_records: + rank = f'ORDER BY origin_{suffix}."api_invocation_time" DESC, origin_{suffix}."write_time" DESC\n' + return ( + f"SELECT {included}\nFROM (\n" + f"SELECT *, row_number() OVER (\n" + f'PARTITION BY origin_{suffix}."{record_id}", origin_{suffix}."{event_time}"\n' + f"{rank}) AS row_{suffix}\n" + f'FROM "{fg.database}"."{fg.table_name}" origin_{suffix}\n' + f"WHERE NOT is_deleted) AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, [f"row_{suffix} = 1"]) + ) + + if fg.table_type is TableType.FEATURE_GROUP: + dedup = self._construct_dedup_query(fg, suffix) + deleted = self._construct_deleted_query(fg, suffix) + rank_cond = ( + f'OR (table_{suffix}."{event_time}" = deleted_{suffix}."{event_time}" ' + f'AND table_{suffix}."api_invocation_time" > 
deleted_{suffix}."api_invocation_time")\n' + f'OR (table_{suffix}."{event_time}" = deleted_{suffix}."{event_time}" ' + f'AND table_{suffix}."api_invocation_time" = deleted_{suffix}."api_invocation_time" ' + f'AND table_{suffix}."write_time" > deleted_{suffix}."write_time")\n' + ) + + if self._include_duplicated_records: + return ( + f"WITH {deleted}\n" + f"SELECT {included}\nFROM (\n" + f"SELECT {included_with_write}\n" + f'FROM "{fg.database}"."{fg.table_name}" table_{suffix}\n' + f"LEFT JOIN deleted_{suffix} ON table_{suffix}.\"{record_id}\" = deleted_{suffix}.\"{record_id}\"\n" + f'WHERE deleted_{suffix}."{record_id}" IS NULL\n' + f"UNION ALL\n" + f"SELECT {included_with_write}\nFROM deleted_{suffix}\n" + f'JOIN "{fg.database}"."{fg.table_name}" table_{suffix}\n' + f'ON table_{suffix}."{record_id}" = deleted_{suffix}."{record_id}"\n' + f'AND (table_{suffix}."{event_time}" > deleted_{suffix}."{event_time}"\n{rank_cond})\n' + f") AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, []) + ) + + return ( + f"WITH {dedup},\n{deleted}\n" + f"SELECT {included}\nFROM (\n" + f"SELECT {included_with_write}\nFROM table_{suffix}\n" + f"LEFT JOIN deleted_{suffix} ON table_{suffix}.\"{record_id}\" = deleted_{suffix}.\"{record_id}\"\n" + f'WHERE deleted_{suffix}."{record_id}" IS NULL\n' + f"UNION ALL\n" + f"SELECT {included_with_write}\nFROM deleted_{suffix}\n" + f"JOIN table_{suffix} ON table_{suffix}.\"{record_id}\" = deleted_{suffix}.\"{record_id}\"\n" + f'AND (table_{suffix}."{event_time}" > deleted_{suffix}."{event_time}"\n{rank_cond})\n' + f") AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, []) + ) + + dedup = self._construct_dedup_query(fg, suffix) + return ( + f"WITH {dedup}\n" + f"SELECT {included}\nFROM (\n" + f"SELECT {included_with_write}\nFROM table_{suffix}\n" + f") AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, []) + ) + + def _construct_dedup_query(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + record_id = fg.record_identifier_feature_name + event_time = fg.event_time_identifier_feature.feature_name + rank = "" + is_fg = fg.table_type is TableType.FEATURE_GROUP + + if is_fg: + rank = f'ORDER BY origin_{suffix}."api_invocation_time" DESC, origin_{suffix}."write_time" DESC\n' + + where_conds = [] + if is_fg and self._write_time_ending_timestamp: + where_conds.append(self._construct_write_time_condition(f"origin_{suffix}")) + where_conds.extend(self._construct_event_time_conditions(f"origin_{suffix}", fg.event_time_identifier_feature)) + where_str = f"WHERE {' AND '.join(where_conds)}\n" if where_conds else "" + + dedup_where = f"WHERE dedup_row_{suffix} = 1\n" if is_fg else "" + + return ( + f"table_{suffix} AS (\n" + f"SELECT *\nFROM (\n" + f"SELECT *, row_number() OVER (\n" + f'PARTITION BY origin_{suffix}."{record_id}", origin_{suffix}."{event_time}"\n' + f"{rank}) AS dedup_row_{suffix}\n" + f'FROM "{fg.database}"."{fg.table_name}" origin_{suffix}\n' + f"{where_str})\n{dedup_where})" + ) + + def _construct_deleted_query(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + record_id = fg.record_identifier_feature_name + event_time = fg.event_time_identifier_feature.feature_name + rank = f'ORDER BY origin_{suffix}."{event_time}" DESC' + + if fg.table_type is TableType.FEATURE_GROUP: + rank += f', origin_{suffix}."api_invocation_time" DESC, origin_{suffix}."write_time" DESC\n' + + write_cond = "" + if fg.table_type is 
TableType.FEATURE_GROUP and self._write_time_ending_timestamp: + write_cond = f" AND {self._construct_write_time_condition(f'origin_{suffix}')}\n" + + event_conds = "" + if self._event_time_starting_timestamp and self._event_time_ending_timestamp: + conds = self._construct_event_time_conditions(f"origin_{suffix}", fg.event_time_identifier_feature) + event_conds = "".join(f"AND {c}\n" for c in conds) + + return ( + f"deleted_{suffix} AS (\n" + f"SELECT *\nFROM (\n" + f"SELECT *, row_number() OVER (\n" + f'PARTITION BY origin_{suffix}."{record_id}"\n' + f"{rank}) AS deleted_row_{suffix}\n" + f'FROM "{fg.database}"."{fg.table_name}" origin_{suffix}\n' + f"WHERE is_deleted{write_cond}{event_conds})\n" + f"WHERE deleted_row_{suffix} = 1\n)" + ) + + def _construct_where_query_string( + self, suffix: str, event_time_feature: FeatureDefinition, conditions: List[str] + ) -> str: + self._validate_options() + + if isinstance(self._base, FeatureGroup) and self._write_time_ending_timestamp: + conditions.append(self._construct_write_time_condition(f"table_{suffix}")) + + conditions.extend(self._construct_event_time_conditions(f"table_{suffix}", event_time_feature)) + return f"WHERE {' AND '.join(conditions)}" if conditions else "" + + def _validate_options(self): + is_df_base = isinstance(self._base, pd.DataFrame) + no_joins = len(self._feature_groups_to_be_merged) == 0 + + if self._number_of_recent_records is not None and self._number_of_recent_records < 0: + raise ValueError("number_of_recent_records must be non-negative.") + if self._number_of_records is not None and self._number_of_records < 0: + raise ValueError("number_of_records must be non-negative.") + if is_df_base and no_joins: + if self._include_deleted_records: + raise ValueError("include_deleted_records() only works for FeatureGroup if no join.") + if self._include_duplicated_records: + raise ValueError("include_duplicated_records() only works for FeatureGroup if no join.") + if self._write_time_ending_timestamp: + raise ValueError("as_of() only works for FeatureGroup if no join.") + if self._point_in_time_accurate_join and no_joins: + raise ValueError("point_in_time_accurate_join() requires at least one join.") + + def _construct_event_time_conditions(self, table: str, event_time_feature: FeatureDefinition) -> List[str]: + cast_fn = "from_iso8601_timestamp" if event_time_feature.feature_type == FeatureTypeEnum.STRING else "from_unixtime" + conditions = [] + if self._event_time_starting_timestamp: + conditions.append( + f'{cast_fn}({table}."{event_time_feature.feature_name}") >= ' + f"from_unixtime({self._event_time_starting_timestamp.timestamp()})" + ) + if self._event_time_ending_timestamp: + conditions.append( + f'{cast_fn}({table}."{event_time_feature.feature_name}") <= ' + f"from_unixtime({self._event_time_ending_timestamp.timestamp()})" + ) + return conditions + + def _construct_write_time_condition(self, table: str) -> str: + ts = self._write_time_ending_timestamp.replace(microsecond=0) + return f'{table}."write_time" <= to_timestamp(\'{ts}\', \'yyyy-mm-dd hh24:mi:ss\')' + + def _construct_join_condition(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + target_feature = fg.feature_name_in_target or fg.record_identifier_feature_name + join = ( + f"\n{fg.join_type.value} fg_{suffix}\n" + f'ON fg_base."{fg.target_feature_name_in_base}" {fg.join_comparator.value} fg_{suffix}."{target_feature}"' + ) + + if self._point_in_time_accurate_join: + base_cast = "from_iso8601_timestamp" if self._event_time_identifier_feature_type == 
FeatureTypeEnum.STRING else "from_unixtime" + fg_cast = "from_iso8601_timestamp" if fg.event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING else "from_unixtime" + join += ( + f'\nAND {base_cast}(fg_base."{self._event_time_identifier_feature_name}") >= ' + f'{fg_cast}(fg_{suffix}."{fg.event_time_identifier_feature.feature_name}")' + ) + + return join diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py new file mode 100644 index 0000000000..32408e5585 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py @@ -0,0 +1,107 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Feature Definitions for FeatureStore.""" +from __future__ import absolute_import + +from enum import Enum +from typing import Optional, Union + +from sagemaker.core.shapes import ( + FeatureDefinition, + CollectionConfig, + VectorConfig, +) + +class FeatureTypeEnum(Enum): + """Feature data types: Fractional, Integral, or String.""" + + FRACTIONAL = "Fractional" + INTEGRAL = "Integral" + STRING = "String" + +class CollectionTypeEnum(Enum): + """Collection types: List, Set, or Vector.""" + + LIST = "List" + SET = "Set" + VECTOR = "Vector" + +class ListCollectionType: + """List collection type.""" + + collection_type = CollectionTypeEnum.LIST.value + collection_config = None + +class SetCollectionType: + """Set collection type.""" + + collection_type = CollectionTypeEnum.SET.value + collection_config = None + +class VectorCollectionType: + """Vector collection type with dimension.""" + + collection_type = CollectionTypeEnum.VECTOR.value + + def __init__(self, dimension: int): + self.collection_config = CollectionConfig( + vector_config=VectorConfig(dimension=dimension) + ) + +CollectionType = Union[ListCollectionType, SetCollectionType, VectorCollectionType] + +def _create_feature_definition( + feature_name: str, + feature_type: FeatureTypeEnum, + collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Internal helper to create FeatureDefinition from collection type.""" + return FeatureDefinition( + feature_name=feature_name, + feature_type=feature_type.value, + collection_type=collection_type.collection_type if collection_type else None, + collection_config=collection_type.collection_config if collection_type else None, + ) + +def FractionalFeatureDefinition( + feature_name: str, + collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Create a feature definition with Fractional type.""" + return _create_feature_definition(feature_name, FeatureTypeEnum.FRACTIONAL, collection_type) + +def IntegralFeatureDefinition( + feature_name: str, + collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Create a feature definition with Integral type.""" + return _create_feature_definition(feature_name, FeatureTypeEnum.INTEGRAL, collection_type) + +def StringFeatureDefinition( + feature_name: str, 
+ collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Create a feature definition with String type.""" + return _create_feature_definition(feature_name, FeatureTypeEnum.STRING, collection_type) + +__all__ = [ + "FeatureDefinition", + "FeatureTypeEnum", + "CollectionTypeEnum", + "ListCollectionType", + "SetCollectionType", + "VectorCollectionType", + "FractionalFeatureDefinition", + "IntegralFeatureDefinition", + "StringFeatureDefinition", +] diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py new file mode 100644 index 0000000000..f7f6523b8d --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py @@ -0,0 +1,488 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Utilities for working with FeatureGroups and FeatureStores.""" +import logging +import os +import time +from typing import Any, Dict, Sequence, Union + +import boto3 +import pandas as pd +from pandas import DataFrame, Series + +from sagemaker.mlops.feature_store import FeatureGroup as CoreFeatureGroup, FeatureGroup +from sagemaker.core.helper.session_helper import Session +from sagemaker.core.s3.client import S3Uploader, S3Downloader +from sagemaker.mlops.feature_store.dataset_builder import DatasetBuilder +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FractionalFeatureDefinition, + IntegralFeatureDefinition, + ListCollectionType, + StringFeatureDefinition, +) +from sagemaker.mlops.feature_store.ingestion_manager_pandas import IngestionManagerPandas + +from sagemaker import utils + + +logger = logging.getLogger(__name__) + +# --- Constants --- + +_FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP = { + "Integral": "INT", + "Fractional": "FLOAT", + "String": "STRING", +} + +_DTYPE_TO_FEATURE_TYPE_MAP = { + "object": "String", + "string": "String", + "int64": "Integral", + "float64": "Fractional", +} + +_INTEGER_TYPES = {"int_", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"} +_FLOAT_TYPES = {"float_", "float16", "float32", "float64"} + + +def _get_athena_client(session: Session): + """Get Athena client from session.""" + return session.boto_session.client("athena", region_name=session.boto_region_name) + + +def _get_s3_client(session: Session): + """Get S3 client from session.""" + return session.boto_session.client("s3", region_name=session.boto_region_name) + + +def start_query_execution( + session: Session, + catalog: str, + database: str, + query_string: str, + output_location: str, + kms_key: str = None, + workgroup: str = None, +) -> Dict[str, str]: + """Start Athena query execution. + + Args: + session: Session instance for boto calls. + catalog: Name of the data catalog. + database: Name of the database. + query_string: SQL query string. + output_location: S3 URI for query results. + kms_key: KMS key for encryption (default: None). + workgroup: Athena workgroup name (default: None). + + Returns: + Response dict with QueryExecutionId. 
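+
+    Example:
+        A minimal sketch; the database, query string and output bucket are
+        placeholders:
+
+            response = start_query_execution(
+                session=session,
+                catalog="AwsDataCatalog",
+                database="sagemaker_featurestore",
+                query_string='SELECT * FROM "my_fg_table" LIMIT 10',
+                output_location="s3://my-bucket/athena-results/",
+            )
+            query_id = response["QueryExecutionId"]
+            wait_for_athena_query(session, query_id)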
+ """ + kwargs = { + "QueryString": query_string, + "QueryExecutionContext": {"Catalog": catalog, "Database": database}, + "ResultConfiguration": {"OutputLocation": output_location}, + } + if kms_key: + kwargs["ResultConfiguration"]["EncryptionConfiguration"] = { + "EncryptionOption": "SSE_KMS", + "KmsKey": kms_key, + } + if workgroup: + kwargs["WorkGroup"] = workgroup + return _get_athena_client(session).start_query_execution(**kwargs) + + +def get_query_execution(session: Session, query_execution_id: str) -> Dict[str, Any]: + """Get execution status of an Athena query. + + Args: + session: Session instance for boto calls. + query_execution_id: The query execution ID. + + Returns: + Response dict from Athena. + """ + return _get_athena_client(session).get_query_execution(QueryExecutionId=query_execution_id) + + +def wait_for_athena_query(session: Session, query_execution_id: str, poll: int = 5): + """Wait for Athena query to finish. + + Args: + session: Session instance for boto calls. + query_execution_id: The query execution ID. + poll: Polling interval in seconds (default: 5). + """ + while True: + state = get_query_execution(session, query_execution_id)["QueryExecution"]["Status"]["State"] + if state in ("SUCCEEDED", "FAILED"): + logger.info("Query %s %s.", query_execution_id, state.lower()) + break + logger.info("Query %s is being executed.", query_execution_id) + time.sleep(poll) + + +def run_athena_query( + session: Session, + catalog: str, + database: str, + query_string: str, + output_location: str, + kms_key: str = None, +) -> Dict[str, Any]: + """Execute Athena query, wait for completion, and return result. + + Args: + session: Session instance for boto calls. + catalog: Name of the data catalog. + database: Name of the database. + query_string: SQL query string. + output_location: S3 URI for query results. + kms_key: KMS key for encryption (default: None). + + Returns: + Query execution result dict. + + Raises: + RuntimeError: If query fails. + """ + response = start_query_execution( + session=session, + catalog=catalog, + database=database, + query_string=query_string, + output_location=output_location, + kms_key=kms_key, + ) + query_id = response["QueryExecutionId"] + wait_for_athena_query(session, query_id) + + result = get_query_execution(session, query_id) + if result["QueryExecution"]["Status"]["State"] != "SUCCEEDED": + raise RuntimeError(f"Athena query {query_id} failed.") + return result + + +def download_athena_query_result( + session: Session, + bucket: str, + prefix: str, + query_execution_id: str, + filename: str, +): + """Download query result file from S3. + + Args: + session: Session instance for boto calls. + bucket: S3 bucket name. + prefix: S3 key prefix. + query_execution_id: The query execution ID. + filename: Local filename to save to. + """ + _get_s3_client(session).download_file( + Bucket=bucket, + Key=f"{prefix}/{query_execution_id}.csv", + Filename=filename, + ) + + +def upload_dataframe_to_s3( + data_frame: DataFrame, + output_path: str, + session: Session, + kms_key: str = None, +) -> tuple[str, str]: + """Upload DataFrame to S3 as CSV. + + Args: + data_frame: DataFrame to upload. + output_path: S3 URI base path. + session: Session instance for boto calls. + kms_key: KMS key for encryption (default: None). + + Returns: + Tuple of (s3_folder, temp_table_name). 
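+
+    Example:
+        A minimal sketch; the DataFrame ``df`` and the output path are illustrative:
+
+            s3_folder, table_name = upload_dataframe_to_s3(
+                data_frame=df,
+                output_path="s3://my-bucket/feature-store-staging",
+                session=session,
+            )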
+ """ + + temp_id = utils.unique_name_from_base("dataframe-base") + local_file = f"{temp_id}.csv" + s3_folder = os.path.join(output_path, temp_id) + + data_frame.to_csv(local_file, index=False, header=False) + S3Uploader.upload( + local_path=local_file, + desired_s3_uri=s3_folder, + sagemaker_session=session, + kms_key=kms_key, + ) + os.remove(local_file) + + table_name = f'dataframe_{temp_id.replace("-", "_")}' + return s3_folder, table_name + + +def download_csv_from_s3( + s3_uri: str, + session: Session, + kms_key: str = None, +) -> DataFrame: + """Download CSV from S3 and return as DataFrame. + + Args: + s3_uri: S3 URI of the CSV file. + session: Session instance for boto calls. + kms_key: KMS key for decryption (default: None). + + Returns: + DataFrame with CSV contents. + """ + + S3Downloader.download( + s3_uri=s3_uri, + local_path="./", + kms_key=kms_key, + sagemaker_session=session, + ) + + local_file = s3_uri.split("/")[-1] + df = pd.read_csv(local_file) + os.remove(local_file) + + metadata_file = f"{local_file}.metadata" + if os.path.exists(metadata_file): + os.remove(metadata_file) + + return df + + +def get_session_from_role(region: str, assume_role: str = None) -> Session: + """Get a Session from a region and optional IAM role. + + Args: + region: AWS region name. + assume_role: IAM role ARN to assume (default: None). + + Returns: + Session instance. + """ + boto_session = boto3.Session(region_name=region) + + if assume_role: + sts = boto_session.client("sts", region_name=region) + credentials = sts.assume_role( + RoleArn=assume_role, + RoleSessionName="SagemakerExecution", + )["Credentials"] + + boto_session = boto3.Session( + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + return Session( + boto_session=boto_session, + sagemaker_client=boto_session.client("sagemaker"), + sagemaker_runtime_client=boto_session.client("sagemaker-runtime"), + sagemaker_featurestore_runtime_client=boto_session.client("sagemaker-featurestore-runtime"), + ) + + +# --- FeatureDefinition Functions --- + +def _is_collection_column(series: Series, sample_size: int = 1000) -> bool: + """Check if column contains list/set values.""" + sample = series.head(sample_size).dropna() + return sample.apply(lambda x: isinstance(x, (list, set))).any() + + +def _generate_feature_definition( + series: Series, + online_storage_type: str = None, +) -> FeatureDefinition: + """Generate a FeatureDefinition from a pandas Series.""" + dtype = str(series.dtype) + collection_type = None + + if online_storage_type == "InMemory" and _is_collection_column(series): + collection_type = ListCollectionType() + + if dtype in _INTEGER_TYPES: + return IntegralFeatureDefinition(series.name, collection_type) + if dtype in _FLOAT_TYPES: + return FractionalFeatureDefinition(series.name, collection_type) + return StringFeatureDefinition(series.name, collection_type) + + +def load_feature_definitions_from_dataframe( + data_frame: DataFrame, + online_storage_type: str = None, +) -> Sequence[FeatureDefinition]: + """Infer FeatureDefinitions from DataFrame dtypes. + + Column name is used as feature name. Feature type is inferred from the dtype + of the column. Integer dtypes are mapped to Integral feature type. Float dtypes + are mapped to Fractional feature type. All other dtypes are mapped to String. 
+ + For IN_MEMORY online_storage_type, collection type columns within DataFrame + will be inferred as List instead of String. + + Args: + data_frame: DataFrame to infer features from. + online_storage_type: "Standard" or "InMemory" (default: None). + + Returns: + List of FeatureDefinition objects. + """ + return [ + _generate_feature_definition(data_frame[col], online_storage_type) + for col in data_frame.columns + ] + + +# --- FeatureGroup Functions --- + +def create_athena_query(feature_group_name: str, session: Session): + """Create an AthenaQuery for a FeatureGroup. + + Args: + feature_group_name: Name of the FeatureGroup. + session: Session instance for Athena boto calls. + + Returns: + AthenaQuery initialized with data catalog config. + + Raises: + RuntimeError: If no metastore is configured. + """ + from sagemaker.mlops.feature_store.athena_query import AthenaQuery + + fg = CoreFeatureGroup.get(feature_group_name=feature_group_name) + + if not fg.offline_store_config or not fg.offline_store_config.data_catalog_config: + raise RuntimeError("No metastore is configured with this feature group.") + + catalog_config = fg.offline_store_config.data_catalog_config + disable_glue = catalog_config.disable_glue_table_creation or False + + return AthenaQuery( + catalog=catalog_config.catalog if disable_glue else "AwsDataCatalog", + database=catalog_config.database, + table_name=catalog_config.table_name, + sagemaker_session=session, + ) + + +def as_hive_ddl( + feature_group_name: str, + database: str = "sagemaker_featurestore", + table_name: str = None, +) -> str: + """Generate Hive DDL for a FeatureGroup's offline store table. + + Schema of the table is generated based on the feature definitions. Columns are named + after feature name and data-type are inferred based on feature type. Integral feature + type is mapped to INT data-type. Fractional feature type is mapped to FLOAT data-type. + String feature type is mapped to STRING data-type. + + Args: + feature_group_name: Name of the FeatureGroup. + database: Hive database name (default: "sagemaker_featurestore"). + table_name: Hive table name (default: feature_group_name). + + Returns: + CREATE EXTERNAL TABLE DDL string. + """ + fg = CoreFeatureGroup.get(feature_group_name=feature_group_name) + table_name = table_name or feature_group_name + resolved_output_s3_uri = fg.offline_store_config.s3_storage_config.resolved_output_s3_uri + + ddl = f"CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" + for fd in fg.feature_definitions: + ddl += f" {fd.feature_name} {_FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP.get(fd.feature_type)}\n" + ddl += " write_time TIMESTAMP\n" + ddl += " event_time TIMESTAMP\n" + ddl += " is_deleted BOOLEAN\n" + ddl += ")\n" + ddl += ( + "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" + " STORED AS\n" + " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" + " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" + f"LOCATION '{resolved_output_s3_uri}'" + ) + return ddl + + +def ingest_dataframe( + feature_group_name: str, + data_frame: DataFrame, + max_workers: int = 1, + max_processes: int = 1, + wait: bool = True, + timeout: Union[int, float] = None, +): + """Ingest a pandas DataFrame to a FeatureGroup. + + Args: + feature_group_name: Name of the FeatureGroup. + data_frame: DataFrame to ingest. + max_workers: Threads per process (default: 1). + max_processes: Number of processes (default: 1). + wait: Wait for ingestion to complete (default: True). 
+ timeout: Timeout in seconds (default: None). + + Returns: + IngestionManagerPandas instance. + + Raises: + ValueError: If max_workers or max_processes <= 0. + """ + + if max_processes <= 0: + raise ValueError("max_processes must be greater than 0.") + if max_workers <= 0: + raise ValueError("max_workers must be greater than 0.") + + fg = CoreFeatureGroup.get(feature_group_name=feature_group_name) + feature_definitions = {fd.feature_name: fd.feature_type for fd in fg.feature_definitions} + + manager = IngestionManagerPandas( + feature_group_name=feature_group_name, + feature_definitions=feature_definitions, + max_workers=max_workers, + max_processes=max_processes, + ) + manager.run(data_frame=data_frame, wait=wait, timeout=timeout) + return manager + +def create_dataset( + base: Union[FeatureGroup, pd.DataFrame], + output_path: str, + session: Session, + record_identifier_feature_name: str = None, + event_time_identifier_feature_name: str = None, + included_feature_names: Sequence[str] = None, + kms_key_id: str = None, +) -> DatasetBuilder: + """Create a DatasetBuilder for generating a Dataset.""" + if isinstance(base, pd.DataFrame): + if not record_identifier_feature_name or not event_time_identifier_feature_name: + raise ValueError( + "record_identifier_feature_name and event_time_identifier_feature_name " + "are required when base is a DataFrame." + ) + return DatasetBuilder( + _sagemaker_session=session, + _base=base, + _output_path=output_path, + _record_identifier_feature_name=record_identifier_feature_name, + _event_time_identifier_feature_name=event_time_identifier_feature_name, + _included_feature_names=included_feature_names, + _kms_key_id=kms_key_id, + ) + diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py new file mode 100644 index 0000000000..60d022dab1 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py @@ -0,0 +1,321 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Multi-threaded data ingestion for FeatureStore using SageMaker Core.""" +import logging +import math +import signal +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field +from multiprocessing.pool import AsyncResult +from typing import Any, Dict, Iterable, List, Sequence, Union + +import pandas as pd +from pandas import DataFrame +from pandas.api.types import is_list_like +from pathos.multiprocessing import ProcessingPool + +from sagemaker.core.resources import FeatureGroup as CoreFeatureGroup +from sagemaker.core.shapes import FeatureValue + +logger = logging.getLogger(__name__) + + +class IngestionError(Exception): + """Exception raised for errors during ingestion. + + Attributes: + failed_rows: List of row indices that failed to ingest. + message: Error message. + """ + + def __init__(self, failed_rows: List[int], message: str): + self.failed_rows = failed_rows + self.message = message + super().__init__(self.message) + + +@dataclass +class IngestionManagerPandas: + """Class to manage the multi-threaded data ingestion process. + + This class will manage the data ingestion process which is multi-threaded. + + Attributes: + feature_group_name (str): name of the Feature Group. 
+ feature_definitions (Dict[str, Dict[Any, Any]]): dictionary of feature definitions + where the key is the feature name and the value is the FeatureDefinition. + The FeatureDefinition contains the data type of the feature. + max_workers (int): number of threads to create. + max_processes (int): number of processes to create. Each process spawns + ``max_workers`` threads. + """ + + feature_group_name: str + feature_definitions: Dict[str, Dict[Any, Any]] + max_workers: int = 1 + max_processes: int = 1 + _async_result: AsyncResult = field(default=None, init=False) + _processing_pool: ProcessingPool = field(default=None, init=False) + _failed_indices: List[int] = field(default_factory=list, init=False) + + @property + def failed_rows(self) -> List[int]: + """Get rows that failed to ingest. + + Returns: + List of row indices that failed to be ingested. + """ + return self._failed_indices + + def run( + self, + data_frame: DataFrame, + target_stores: List[str] = None, + wait: bool = True, + timeout: Union[int, float] = None, + ): + """Start the ingestion process. + + Args: + data_frame (DataFrame): source DataFrame to be ingested. + target_stores (List[str]): list of target stores ("OnlineStore", "OfflineStore"). + If None, the default target store is used. + wait (bool): whether to wait for the ingestion to finish or not. + timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised + if timeout is reached. + """ + if self.max_workers == 1 and self.max_processes == 1: + self._run_single_process_single_thread(data_frame=data_frame, target_stores=target_stores) + else: + self._run_multi_process(data_frame=data_frame, target_stores=target_stores, wait=wait, timeout=timeout) + + def wait(self, timeout: Union[int, float] = None): + """Wait for the ingestion process to finish. + + Args: + timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised + if timeout is reached. 
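+
+        Raises:
+            IngestionError: If any rows failed to ingest; ``failed_rows`` lists
+                the indices of the rows that could not be written.
+
+        Example (illustrative; ``manager`` and ``df`` are assumed to come from the caller):
+            >>> manager.run(data_frame=df, wait=False)
+            >>> manager.wait(timeout=600)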
+ """ + try: + results = self._async_result.get(timeout=timeout) + except KeyboardInterrupt as e: + self._processing_pool.terminate() + self._processing_pool.close() + self._processing_pool.clear() + raise e + else: + self._processing_pool.close() + self._processing_pool.clear() + + self._failed_indices = [idx for failed in results for idx in failed] + + if self._failed_indices: + raise IngestionError( + self._failed_indices, + f"Failed to ingest some data into FeatureGroup {self.feature_group_name}", + ) + + def _run_single_process_single_thread( + self, + data_frame: DataFrame, + target_stores: List[str] = None, + ): + """Ingest utilizing a single process and a single thread.""" + logger.info("Started single-threaded ingestion for %d rows", len(data_frame)) + failed_rows = [] + + fg = CoreFeatureGroup(feature_group_name=self.feature_group_name) + + for row in data_frame.itertuples(): + self._ingest_row( + data_frame=data_frame, + row=row, + feature_group=fg, + feature_definitions=self.feature_definitions, + failed_rows=failed_rows, + target_stores=target_stores, + ) + + self._failed_indices = failed_rows + if self._failed_indices: + raise IngestionError( + self._failed_indices, + f"Failed to ingest some data into FeatureGroup {self.feature_group_name}", + ) + + def _run_multi_process( + self, + data_frame: DataFrame, + target_stores: List[str] = None, + wait: bool = True, + timeout: Union[int, float] = None, + ): + """Start the ingestion process with the specified number of processes.""" + batch_size = math.ceil(data_frame.shape[0] / self.max_processes) + + args = [] + for i in range(self.max_processes): + start_index = min(i * batch_size, data_frame.shape[0]) + end_index = min(i * batch_size + batch_size, data_frame.shape[0]) + args.append(( + self.max_workers, + self.feature_group_name, + self.feature_definitions, + data_frame[start_index:end_index], + target_stores, + start_index, + timeout, + )) + + def init_worker(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + + self._processing_pool = ProcessingPool(self.max_processes, init_worker) + self._processing_pool.restart(force=True) + + self._async_result = self._processing_pool.amap( + lambda x: IngestionManagerPandas._run_multi_threaded(*x), + args, + ) + + if wait: + self.wait(timeout=timeout) + + @staticmethod + def _run_multi_threaded( + max_workers: int, + feature_group_name: str, + feature_definitions: Dict[str, Dict[Any, Any]], + data_frame: DataFrame, + target_stores: List[str] = None, + row_offset: int = 0, + timeout: Union[int, float] = None, + ) -> List[int]: + """Start multi-threaded ingestion within a single process.""" + executor = ThreadPoolExecutor(max_workers=max_workers) + batch_size = math.ceil(data_frame.shape[0] / max_workers) + + futures = {} + for i in range(max_workers): + start_index = min(i * batch_size, data_frame.shape[0]) + end_index = min(i * batch_size + batch_size, data_frame.shape[0]) + future = executor.submit( + IngestionManagerPandas._ingest_single_batch, + data_frame=data_frame, + feature_group_name=feature_group_name, + feature_definitions=feature_definitions, + start_index=start_index, + end_index=end_index, + target_stores=target_stores, + ) + futures[future] = (start_index + row_offset, end_index + row_offset) + + failed_indices = [] + for future in as_completed(futures, timeout=timeout): + start, end = futures[future] + failed_rows = future.result() + if not failed_rows: + logger.info("Successfully ingested row %d to %d", start, end) + failed_indices.extend(failed_rows) + + 
executor.shutdown(wait=False) + return failed_indices + + @staticmethod + def _ingest_single_batch( + data_frame: DataFrame, + feature_group_name: str, + feature_definitions: Dict[str, Dict[Any, Any]], + start_index: int, + end_index: int, + target_stores: List[str] = None, + ) -> List[int]: + """Ingest a single batch of DataFrame rows into FeatureStore.""" + logger.info("Started ingesting index %d to %d", start_index, end_index) + failed_rows = [] + + fg = CoreFeatureGroup(feature_group_name=feature_group_name) + + for row in data_frame[start_index:end_index].itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=data_frame, + row=row, + feature_group=fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=target_stores, + ) + + return failed_rows + + @staticmethod + def _ingest_row( + data_frame: DataFrame, + row: Iterable, + feature_group: CoreFeatureGroup, + feature_definitions: Dict[str, Dict[Any, Any]], + failed_rows: List[int], + target_stores: List[str] = None, + ): + """Ingest a single DataFrame row into FeatureStore using SageMaker Core.""" + try: + record = [] + for index in range(1, len(row)): + feature_name = data_frame.columns[index - 1] + feature_value = row[index] + + if not IngestionManagerPandas._feature_value_is_not_none(feature_value): + continue + + if IngestionManagerPandas._is_feature_collection_type(feature_name, feature_definitions): + record.append(FeatureValue( + feature_name=feature_name, + value_as_string_list=IngestionManagerPandas._convert_to_string_list(feature_value), + )) + else: + record.append(FeatureValue( + feature_name=feature_name, + value_as_string=str(feature_value), + )) + + # Use SageMaker Core's put_record directly + feature_group.put_record( + record=record, + target_stores=target_stores, + ) + + except Exception as e: + logger.error("Failed to ingest row %d: %s", row[0], e) + failed_rows.append(row[0]) + + @staticmethod + def _is_feature_collection_type( + feature_name: str, + feature_definitions: Dict[str, Dict[Any, Any]], + ) -> bool: + """Check if the feature is a collection type.""" + feature_def = feature_definitions.get(feature_name) + if feature_def: + return feature_def.get("CollectionType") is not None + return False + + @staticmethod + def _feature_value_is_not_none(feature_value: Any) -> bool: + """Check if the feature value is not None. + + For Collection Type features, we check if the value is not None. + For Scalar values, we use pd.notna() to keep the behavior same. 
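+
+        ``pd.notna`` evaluates list-like values element-wise, which is ambiguous in a
+        boolean context, so collection values are only compared against ``None``.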
+ """ + if not is_list_like(feature_value): + return pd.notna(feature_value) + return feature_value is not None + + @staticmethod + def _convert_to_string_list(feature_value: List[Any]) -> List[str]: + """Convert a list of feature values to a list of strings.""" + if not is_list_like(feature_value): + raise ValueError( + f"Invalid feature value: {feature_value} for a collection type feature " + f"must be an Array, but was {type(feature_value)}" + ) + return [str(v) if v is not None else None for v in feature_value] diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py new file mode 100644 index 0000000000..f264059eb3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py @@ -0,0 +1,60 @@ +"""Enums for FeatureStore operations.""" +from enum import Enum + +class TargetStoreEnum(Enum): + """Store types for put_record.""" + ONLINE_STORE = "OnlineStore" + OFFLINE_STORE = "OfflineStore" + +class OnlineStoreStorageTypeEnum(Enum): + """Storage types for online store.""" + STANDARD = "Standard" + IN_MEMORY = "InMemory" + +class TableFormatEnum(Enum): + """Offline store table formats.""" + GLUE = "Glue" + ICEBERG = "Iceberg" + +class ResourceEnum(Enum): + """Resource types for search.""" + FEATURE_GROUP = "FeatureGroup" + FEATURE_METADATA = "FeatureMetadata" + +class SearchOperatorEnum(Enum): + """Search operators.""" + AND = "And" + OR = "Or" + +class SortOrderEnum(Enum): + """Sort orders.""" + ASCENDING = "Ascending" + DESCENDING = "Descending" + +class FilterOperatorEnum(Enum): + """Filter operators.""" + EQUALS = "Equals" + NOT_EQUALS = "NotEquals" + GREATER_THAN = "GreaterThan" + GREATER_THAN_OR_EQUAL_TO = "GreaterThanOrEqualTo" + LESS_THAN = "LessThan" + LESS_THAN_OR_EQUAL_TO = "LessThanOrEqualTo" + CONTAINS = "Contains" + EXISTS = "Exists" + NOT_EXISTS = "NotExists" + IN = "In" + +class DeletionModeEnum(Enum): + """Deletion modes for delete_record.""" + SOFT_DELETE = "SoftDelete" + HARD_DELETE = "HardDelete" + +class ExpirationTimeResponseEnum(Enum): + """ExpiresAt response toggle.""" + DISABLED = "Disabled" + ENABLED = "Enabled" + +class ThroughputModeEnum(Enum): + """Throughput modes for feature group.""" + ON_DEMAND = "OnDemand" + PROVISIONED = "Provisioned" \ No newline at end of file From 193d16fa2c61440ed1cd4d906e64e2726229ce6e Mon Sep 17 00:00:00 2001 From: adishaa Date: Fri, 16 Jan 2026 11:34:01 -0800 Subject: [PATCH 2/2] Add feature store tests --- .../mlops/feature_store/MIGRATION_GUIDE.md | 513 ++++++++++++++++++ .../sagemaker/mlops/feature_store/__init__.py | 8 +- .../mlops/feature_store/dataset_builder.py | 45 +- .../mlops/feature_store/feature_utils.py | 31 +- .../feature_store/ingestion_manager_pandas.py | 19 +- sagemaker-mlops/tests/__init__.py | 0 sagemaker-mlops/tests/unit/__init__.py | 0 .../tests/unit/sagemaker/__init__.py | 0 .../tests/unit/sagemaker/mlops/__init__.py | 0 .../sagemaker/mlops/feature_store/__init__.py | 2 + .../sagemaker/mlops/feature_store/conftest.py | 80 +++ .../mlops/feature_store/test_athena_query.py | 113 ++++ .../feature_store/test_dataset_builder.py | 345 ++++++++++++ .../feature_store/test_feature_definition.py | 126 +++++ .../mlops/feature_store/test_feature_utils.py | 202 +++++++ .../test_ingestion_manager_pandas.py | 256 +++++++++ .../mlops/feature_store/test_inputs.py | 109 ++++ 17 files changed, 1802 insertions(+), 47 deletions(-) create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md create mode 100644 
sagemaker-mlops/tests/__init__.py create mode 100644 sagemaker-mlops/tests/unit/__init__.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/__init__.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/__init__.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md b/sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md new file mode 100644 index 0000000000..40942fa6f3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md @@ -0,0 +1,513 @@ +# SageMaker FeatureStore V2 to V3 Migration Guide + +## Overview + +V3 uses **sagemaker-core** as the foundation, which provides: +- Pydantic-based shapes with automatic serialization +- Resource classes that manage boto clients internally +- No need for explicit Session management in most cases + +## File Mapping + +| V2 File | V3 File | Notes | +|---------|---------|-------| +| `feature_group.py` | Re-exported from `sagemaker_core.main.resources` | No wrapper class needed | +| `feature_store.py` | Re-exported from `sagemaker_core.main.resources` | `FeatureStore.search()` available | +| `feature_definition.py` | `feature_definition.py` | Helper factories retained | +| `feature_utils.py` | `feature_utils.py` | Standalone functions | +| `inputs.py` | `inputs.py` | Enums only (shapes from core) | +| `dataset_builder.py` | `dataset_builder.py` | Converted to dataclass | +| N/A | `athena_query.py` | Extracted from feature_group.py | +| N/A | `ingestion_manager_pandas.py` | Extracted from feature_group.py | + +--- + +## FeatureGroup Operations + +### Create FeatureGroup + +**V2:** +```python +from sagemaker.feature_store.feature_group import FeatureGroup +from sagemaker.session import Session + +session = Session() +fg = FeatureGroup(name="my-fg", sagemaker_session=session) +fg.load_feature_definitions(data_frame=df) +fg.create( + s3_uri="s3://bucket/prefix", + record_identifier_name="id", + event_time_feature_name="ts", + role_arn=role, + enable_online_store=True, +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import ( + FeatureGroup, + OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + load_feature_definitions_from_dataframe, +) + +feature_defs = load_feature_definitions_from_dataframe(df) + +FeatureGroup.create( + feature_group_name="my-fg", + feature_definitions=feature_defs, + record_identifier_feature_name="id", + event_time_feature_name="ts", + role_arn=role, + online_store_config=OnlineStoreConfig(enable_online_store=True), + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/prefix") + ), +) +``` + +### Get/Describe FeatureGroup + +**V2:** +```python +fg = FeatureGroup(name="my-fg", sagemaker_session=session) 
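+# describe() returns the raw DescribeFeatureGroup response as a plain dict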
+response = fg.describe() +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup + +fg = FeatureGroup.get(feature_group_name="my-fg") +# fg is now a typed object with attributes: +# fg.feature_group_name, fg.feature_definitions, fg.offline_store_config, etc. +``` + +### Delete FeatureGroup + +**V2:** +```python +fg.delete() +``` + +**V3:** +```python +FeatureGroup(feature_group_name="my-fg").delete() +# or +fg = FeatureGroup.get(feature_group_name="my-fg") +fg.delete() +``` + +### Update FeatureGroup + +**V2:** +```python +fg.update( + feature_additions=[FeatureDefinition("new_col", FeatureTypeEnum.STRING)], + throughput_config=ThroughputConfigUpdate(mode=ThroughputModeEnum.ON_DEMAND), +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup, ThroughputConfig + +fg = FeatureGroup.get(feature_group_name="my-fg") +fg.update( + feature_additions=[{"FeatureName": "new_col", "FeatureType": "String"}], + throughput_config=ThroughputConfig(throughput_mode="OnDemand"), +) +``` + +--- + +## Record Operations + +### Put Record + +**V2:** +```python +from sagemaker.feature_store.inputs import FeatureValue + +fg.put_record( + record=[ + FeatureValue(feature_name="id", value_as_string="123"), + FeatureValue(feature_name="name", value_as_string="John"), + ], + target_stores=[TargetStoreEnum.ONLINE_STORE], +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup, FeatureValue + +FeatureGroup(feature_group_name="my-fg").put_record( + record=[ + FeatureValue(feature_name="id", value_as_string="123"), + FeatureValue(feature_name="name", value_as_string="John"), + ], + target_stores=["OnlineStore"], # strings, not enums +) +``` + +### Get Record + +**V2:** +```python +response = fg.get_record(record_identifier_value_as_string="123") +``` + +**V3:** +```python +response = FeatureGroup(feature_group_name="my-fg").get_record( + record_identifier_value_as_string="123" +) +``` + +### Delete Record + +**V2:** +```python +fg.delete_record( + record_identifier_value_as_string="123", + event_time="2024-01-15T00:00:00Z", + deletion_mode=DeletionModeEnum.SOFT_DELETE, +) +``` + +**V3:** +```python +FeatureGroup(feature_group_name="my-fg").delete_record( + record_identifier_value_as_string="123", + event_time="2024-01-15T00:00:00Z", + deletion_mode="SoftDelete", # string, not enum +) +``` + +### Batch Get Record + +**V2:** +```python +from sagemaker.feature_store.feature_store import FeatureStore +from sagemaker.feature_store.inputs import Identifier + +fs = FeatureStore(sagemaker_session=session) +response = fs.batch_get_record( + identifiers=[ + Identifier(feature_group_name="my-fg", record_identifiers_value_as_string=["123", "456"]) + ] +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup + +response = FeatureGroup(feature_group_name="my-fg").batch_get_record( + identifiers=[ + {"FeatureGroupName": "my-fg", "RecordIdentifiersValueAsString": ["123", "456"]} + ] +) +``` + +--- + +## DataFrame Ingestion + +**V2:** +```python +fg.ingest(data_frame=df, max_workers=4, max_processes=2, wait=True) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import ingest_dataframe + +manager = ingest_dataframe( + feature_group_name="my-fg", + data_frame=df, + max_workers=4, + max_processes=2, + wait=True, +) +# Access failed rows: manager.failed_rows +``` + +--- + +## Athena Query + +**V2:** +```python +query = fg.athena_query() +query.run(query_string="SELECT * FROM ...", output_location="s3://...") 
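+# run() returns the Athena query execution id; wait() blocks until the query finishes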
+query.wait() +df = query.as_dataframe() +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import create_athena_query + +query = create_athena_query("my-fg", session) +query.run(query_string="SELECT * FROM ...", output_location="s3://...") +query.wait() +df = query.as_dataframe() +``` + +--- + +## Hive DDL Generation + +**V2:** +```python +ddl = fg.as_hive_ddl(database="mydb", table_name="mytable") +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import as_hive_ddl + +ddl = as_hive_ddl("my-fg", database="mydb", table_name="mytable") +``` + +--- + +## Feature Definitions + +**V2:** +```python +fg.load_feature_definitions(data_frame=df) +# Modifies fg.feature_definitions in place +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import load_feature_definitions_from_dataframe + +defs = load_feature_definitions_from_dataframe(df) +# Returns list, doesn't modify any object +``` + +### Using Helper Factories + +**V2 & V3 (same):** +```python +from sagemaker.mlops.feature_store import ( + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + VectorCollectionType, +) + +defs = [ + IntegralFeatureDefinition("id"), + StringFeatureDefinition("name"), + FractionalFeatureDefinition("embedding", VectorCollectionType(128)), +] +``` + +--- + +## Search + +**V2:** +```python +from sagemaker.feature_store.feature_store import FeatureStore +from sagemaker.feature_store.inputs import Filter, ResourceEnum + +fs = FeatureStore(sagemaker_session=session) +response = fs.search( + resource=ResourceEnum.FEATURE_GROUP, + filters=[Filter(name="FeatureGroupName", value="my-prefix", operator=FilterOperatorEnum.CONTAINS)], +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureStore, Filter, SearchExpression + +response = FeatureStore.search( + resource="FeatureGroup", + search_expression=SearchExpression( + filters=[Filter(name="FeatureGroupName", value="my-prefix", operator="Contains")] + ), +) +``` + +--- + +## Feature Metadata + +**V2:** +```python +fg.describe_feature_metadata(feature_name="my-feature") +fg.update_feature_metadata(feature_name="my-feature", description="Updated desc") +fg.list_parameters_for_feature_metadata(feature_name="my-feature") +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureMetadata + +# Get metadata +metadata = FeatureMetadata.get(feature_group_name="my-fg", feature_name="my-feature") +print(metadata.description) +print(metadata.parameters) + +# Update metadata +metadata.update(description="Updated desc") +``` + +--- + +## Dataset Builder + +**V2:** +```python +from sagemaker.feature_store.feature_store import FeatureStore + +fs = FeatureStore(sagemaker_session=session) +builder = fs.create_dataset( + base=fg, + output_path="s3://bucket/output", +) +builder.with_feature_group(other_fg, target_feature_name_in_base="id") +builder.point_in_time_accurate_join() +df, query = builder.to_dataframe() +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import create_dataset, FeatureGroup + +fg = FeatureGroup.get(feature_group_name="my-fg") +other_fg = FeatureGroup.get(feature_group_name="other-fg") + +builder = create_dataset( + base=fg, + output_path="s3://bucket/output", + session=session, +) +builder.with_feature_group(other_fg, target_feature_name_in_base="id") +builder.point_in_time_accurate_join() +df, query = builder.to_dataframe() +``` + +--- + +## Config Objects (Shapes) + +**V2:** +```python +from sagemaker.feature_store.inputs import ( + 
OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + TtlDuration, +) + +config = OnlineStoreConfig(enable_online_store=True, ttl_duration=TtlDuration(unit="Hours", value=24)) +config.to_dict() # Manual serialization required +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import ( + OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + TtlDuration, +) + +config = OnlineStoreConfig(enable_online_store=True, ttl_duration=TtlDuration(unit="Hours", value=24)) +# No to_dict() needed - Pydantic handles serialization automatically +``` + +--- + +## Key Differences Summary + +| Aspect | V2 | V3 | +|--------|----|----| +| **Session** | Required for most operations | Optional - core manages clients | +| **FeatureGroup** | Wrapper class with session | Direct core resource class | +| **Shapes** | `@attr.s` with `to_dict()` | Pydantic with auto-serialization | +| **Enums** | `TargetStoreEnum.ONLINE_STORE.value` | Just use strings: `"OnlineStore"` | +| **Methods** | Instance methods on FeatureGroup | Standalone functions + core methods | +| **Ingestion** | `fg.ingest(df)` | `ingest_dataframe(name, df)` | +| **Athena** | `fg.athena_query()` | `create_athena_query(name, session)` | +| **DDL** | `fg.as_hive_ddl()` | `as_hive_ddl(name)` | +| **Feature Defs** | `fg.load_feature_definitions(df)` | `load_feature_definitions_from_dataframe(df)` | +| **Imports** | Multiple modules | Single `__init__.py` re-exports all | + +--- + +## Missing in V3 (Intentionally) + +These V2 features are **not wrapped** because core provides them directly: + +- `FeatureGroup.create()` - use `FeatureGroup.create()` from core +- `FeatureGroup.delete()` - use `FeatureGroup(...).delete()` from core +- `FeatureGroup.describe()` - use `FeatureGroup.get()` from core (returns typed object) +- `FeatureGroup.update()` - use `FeatureGroup(...).update()` from core +- `FeatureGroup.put_record()` - use `FeatureGroup(...).put_record()` from core +- `FeatureGroup.get_record()` - use `FeatureGroup(...).get_record()` from core +- `FeatureGroup.delete_record()` - use `FeatureGroup(...).delete_record()` from core +- `FeatureGroup.batch_get_record()` - use `FeatureGroup(...).batch_get_record()` from core +- `FeatureStore.search()` - use `FeatureStore.search()` from core +- `FeatureStore.list_feature_groups()` - use `FeatureGroup.get_all()` from core +- All config shapes (`OnlineStoreConfig`, etc.) 
- re-exported from core + +--- + +## Import Cheatsheet + +```python +# V3 - Everything from one place +from sagemaker.mlops.feature_store import ( + # Resources (from core) + FeatureGroup, + FeatureStore, + FeatureMetadata, + + # Shapes (from core) + OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + TtlDuration, + FeatureValue, + FeatureParameter, + ThroughputConfig, + Filter, + SearchExpression, + + # Enums (local) + TargetStoreEnum, + OnlineStoreStorageTypeEnum, + TableFormatEnum, + DeletionModeEnum, + ThroughputModeEnum, + + # Feature Definition helpers (local) + FeatureDefinition, + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + VectorCollectionType, + + # Utility functions (local) + create_athena_query, + as_hive_ddl, + load_feature_definitions_from_dataframe, + ingest_dataframe, + create_dataset, + + # Classes (local) + DatasetBuilder, +) +``` diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py index ee8cd7d1a3..f15d6d3845 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py @@ -3,11 +3,10 @@ """SageMaker FeatureStore V3 - powered by sagemaker-core.""" # Resources from core -from sagemaker_core.main.resources import FeatureGroup, FeatureMetadata -from sagemaker_core.main.resources import FeatureStore +from sagemaker.core.resources import FeatureGroup, FeatureMetadata # Shapes from core (Pydantic - no to_dict() needed) -from sagemaker_core.main.shapes import ( +from sagemaker.core.shapes import ( DataCatalogConfig, FeatureParameter, FeatureValue, @@ -52,7 +51,6 @@ from sagemaker.mlops.feature_store.feature_utils import ( as_hive_ddl, create_athena_query, - create_dataset, get_session_from_role, ingest_dataframe, load_feature_definitions_from_dataframe, @@ -76,7 +74,6 @@ # Resources "FeatureGroup", "FeatureMetadata", - "FeatureStore", # Shapes "DataCatalogConfig", "FeatureParameter", @@ -113,7 +110,6 @@ # Utility functions "as_hive_ddl", "create_athena_query", - "create_dataset", "get_session_from_role", "ingest_dataframe", "load_feature_definitions_from_dataframe", diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py index f5450663a6..72e9535320 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py @@ -178,7 +178,9 @@ def construct_feature_group_to_be_merged( database=catalog_config.database, table_name=catalog_config.table_name, record_identifier_feature_name=record_id, - event_time_identifier_feature=FeatureDefinition(event_time_name, FeatureTypeEnum(event_time_type)), + event_time_identifier_feature=FeatureDefinition( + feature_name=event_time_name, feature_type=FeatureTypeEnum(event_time_type).value + ), target_feature_name_in_base=target_feature_name_in_base, table_type=TableType.FEATURE_GROUP, feature_name_in_target=feature_name_in_target, @@ -256,6 +258,47 @@ class DatasetBuilder: _event_time_ending_timestamp: datetime.datetime = field(default=None, init=False) _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = field(default_factory=list, init=False) + @classmethod + def create( + cls, + base: Union[FeatureGroup, pd.DataFrame], + output_path: str, + session: Session, + record_identifier_feature_name: str = None, + 
event_time_identifier_feature_name: str = None, + included_feature_names: List[str] = None, + kms_key_id: str = None, + ) -> "DatasetBuilder": + """Create a DatasetBuilder for generating a Dataset. + + Args: + base: A FeatureGroup or DataFrame to use as the base. + output_path: S3 URI for output. + session: SageMaker session. + record_identifier_feature_name: Required if base is DataFrame. + event_time_identifier_feature_name: Required if base is DataFrame. + included_feature_names: Features to include in output. + kms_key_id: KMS key for encryption. + + Returns: + DatasetBuilder instance. + """ + if isinstance(base, pd.DataFrame): + if not record_identifier_feature_name or not event_time_identifier_feature_name: + raise ValueError( + "record_identifier_feature_name and event_time_identifier_feature_name " + "are required when base is a DataFrame." + ) + return cls( + _sagemaker_session=session, + _base=base, + _output_path=output_path, + _record_identifier_feature_name=record_identifier_feature_name, + _event_time_identifier_feature_name=event_time_identifier_feature_name, + _included_feature_names=included_feature_names, + _kms_key_id=kms_key_id, + ) + def with_feature_group( self, feature_group: FeatureGroup, diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py index f7f6523b8d..0b7c747515 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py @@ -13,7 +13,6 @@ from sagemaker.mlops.feature_store import FeatureGroup as CoreFeatureGroup, FeatureGroup from sagemaker.core.helper.session_helper import Session from sagemaker.core.s3.client import S3Uploader, S3Downloader -from sagemaker.mlops.feature_store.dataset_builder import DatasetBuilder from sagemaker.mlops.feature_store.feature_definition import ( FeatureDefinition, FractionalFeatureDefinition, @@ -23,7 +22,7 @@ ) from sagemaker.mlops.feature_store.ingestion_manager_pandas import IngestionManagerPandas -from sagemaker import utils +from sagemaker.core.utils import unique_name_from_base logger = logging.getLogger(__name__) @@ -207,7 +206,7 @@ def upload_dataframe_to_s3( Tuple of (s3_folder, temp_table_name). """ - temp_id = utils.unique_name_from_base("dataframe-base") + temp_id = unique_name_from_base("dataframe-base") local_file = f"{temp_id}.csv" s3_folder = os.path.join(output_path, temp_id) @@ -460,29 +459,3 @@ def ingest_dataframe( manager.run(data_frame=data_frame, wait=wait, timeout=timeout) return manager -def create_dataset( - base: Union[FeatureGroup, pd.DataFrame], - output_path: str, - session: Session, - record_identifier_feature_name: str = None, - event_time_identifier_feature_name: str = None, - included_feature_names: Sequence[str] = None, - kms_key_id: str = None, -) -> DatasetBuilder: - """Create a DatasetBuilder for generating a Dataset.""" - if isinstance(base, pd.DataFrame): - if not record_identifier_feature_name or not event_time_identifier_feature_name: - raise ValueError( - "record_identifier_feature_name and event_time_identifier_feature_name " - "are required when base is a DataFrame." 
- ) - return DatasetBuilder( - _sagemaker_session=session, - _base=base, - _output_path=output_path, - _record_identifier_feature_name=record_identifier_feature_name, - _event_time_identifier_feature_name=event_time_identifier_feature_name, - _included_feature_names=included_feature_names, - _kms_key_id=kms_key_id, - ) - diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py index 60d022dab1..4d7b4e5375 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py @@ -6,13 +6,12 @@ import signal from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field -from multiprocessing.pool import AsyncResult +from multiprocessing import Pool from typing import Any, Dict, Iterable, List, Sequence, Union import pandas as pd from pandas import DataFrame from pandas.api.types import is_list_like -from pathos.multiprocessing import ProcessingPool from sagemaker.core.resources import FeatureGroup as CoreFeatureGroup from sagemaker.core.shapes import FeatureValue @@ -54,8 +53,8 @@ class IngestionManagerPandas: feature_definitions: Dict[str, Dict[Any, Any]] max_workers: int = 1 max_processes: int = 1 - _async_result: AsyncResult = field(default=None, init=False) - _processing_pool: ProcessingPool = field(default=None, init=False) + _async_result: Any = field(default=None, init=False) + _processing_pool: Pool = field(default=None, init=False) _failed_indices: List[int] = field(default_factory=list, init=False) @property @@ -100,12 +99,11 @@ def wait(self, timeout: Union[int, float] = None): results = self._async_result.get(timeout=timeout) except KeyboardInterrupt as e: self._processing_pool.terminate() - self._processing_pool.close() - self._processing_pool.clear() + self._processing_pool.join() raise e else: self._processing_pool.close() - self._processing_pool.clear() + self._processing_pool.join() self._failed_indices = [idx for failed in results for idx in failed] @@ -170,11 +168,10 @@ def _run_multi_process( def init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN) - self._processing_pool = ProcessingPool(self.max_processes, init_worker) - self._processing_pool.restart(force=True) + self._processing_pool = Pool(self.max_processes, init_worker) - self._async_result = self._processing_pool.amap( - lambda x: IngestionManagerPandas._run_multi_threaded(*x), + self._async_result = self._processing_pool.starmap_async( + IngestionManagerPandas._run_multi_threaded, args, ) diff --git a/sagemaker-mlops/tests/__init__.py b/sagemaker-mlops/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/__init__.py b/sagemaker-mlops/tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/sagemaker/__init__.py b/sagemaker-mlops/tests/unit/sagemaker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/__init__.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py new file mode 100644 index 0000000000..f34bf7d447 --- /dev/null +++ 
b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py new file mode 100644 index 0000000000..9b2ec55895 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py @@ -0,0 +1,80 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Conftest for feature_store tests.""" +import pytest +from unittest.mock import Mock, MagicMock +import pandas as pd +import numpy as np + + +@pytest.fixture +def mock_session(): + """Create a mock Session.""" + session = Mock() + session.boto_session = Mock() + session.boto_region_name = "us-west-2" + session.sagemaker_client = Mock() + session.sagemaker_runtime_client = Mock() + session.sagemaker_featurestore_runtime_client = Mock() + return session + + +@pytest.fixture +def sample_dataframe(): + """Create a sample DataFrame for testing.""" + return pd.DataFrame({ + "id": pd.Series([1, 2, 3, 4, 5], dtype="int64"), + "value": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5], dtype="float64"), + "name": pd.Series(["a", "b", "c", "d", "e"], dtype="string"), + "event_time": pd.Series( + ["2024-01-01T00:00:00Z"] * 5, + dtype="string" + ), + }) + + +@pytest.fixture +def dataframe_with_collections(): + """Create a DataFrame with collection type columns.""" + return pd.DataFrame({ + "id": pd.Series([1, 2, 3], dtype="int64"), + "tags": pd.Series([["a", "b"], ["c"], ["d", "e", "f"]], dtype="object"), + "scores": pd.Series([[1.0, 2.0], [3.0], [4.0, 5.0]], dtype="object"), + "event_time": pd.Series(["2024-01-01"] * 3, dtype="string"), + }) + + +@pytest.fixture +def feature_definitions_dict(): + """Create a feature definitions dictionary.""" + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "value": {"FeatureName": "value", "FeatureType": "Fractional"}, + "name": {"FeatureName": "name", "FeatureType": "String"}, + "event_time": {"FeatureName": "event_time", "FeatureType": "String"}, + } + + +@pytest.fixture +def mock_feature_group(): + """Create a mock FeatureGroup from core.""" + fg = MagicMock() + fg.feature_group_name = "test-feature-group" + fg.record_identifier_feature_name = "id" + fg.event_time_feature_name = "event_time" + fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + MagicMock(feature_name="value", feature_type="Fractional"), + MagicMock(feature_name="name", feature_type="String"), + MagicMock(feature_name="event_time", feature_type="String"), + ] + fg.offline_store_config = MagicMock() + fg.offline_store_config.s3_storage_config.s3_uri = "s3://bucket/prefix" + fg.offline_store_config.s3_storage_config.resolved_output_s3_uri = "s3://bucket/prefix/resolved" + fg.offline_store_config.data_catalog_config.catalog = "AwsDataCatalog" + fg.offline_store_config.data_catalog_config.database = "sagemaker_featurestore" + fg.offline_store_config.data_catalog_config.table_name = "test_feature_group" + fg.offline_store_config.data_catalog_config.disable_glue_table_creation = False + fg.online_store_config = MagicMock() + fg.online_store_config.enable_online_store = True + return fg diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py 
b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py new file mode 100644 index 0000000000..2fed784208 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py @@ -0,0 +1,113 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for athena_query.py""" +import os +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd + +from sagemaker.mlops.feature_store.athena_query import AthenaQuery + + +class TestAthenaQuery: + @pytest.fixture + def mock_session(self): + session = Mock() + session.boto_session.client.return_value = Mock() + session.boto_region_name = "us-west-2" + return session + + @pytest.fixture + def athena_query(self, mock_session): + return AthenaQuery( + catalog="AwsDataCatalog", + database="sagemaker_featurestore", + table_name="my_feature_group", + sagemaker_session=mock_session, + ) + + def test_initialization(self, athena_query): + assert athena_query.catalog == "AwsDataCatalog" + assert athena_query.database == "sagemaker_featurestore" + assert athena_query.table_name == "my_feature_group" + assert athena_query._current_query_execution_id is None + + @patch("sagemaker.mlops.feature_store.athena_query.start_query_execution") + def test_run_starts_query(self, mock_start, athena_query): + mock_start.return_value = {"QueryExecutionId": "query-123"} + + result = athena_query.run( + query_string="SELECT * FROM table", + output_location="s3://bucket/output", + ) + + assert result == "query-123" + assert athena_query._current_query_execution_id == "query-123" + assert athena_query._result_bucket == "bucket" + assert athena_query._result_file_prefix == "output" + + @patch("sagemaker.mlops.feature_store.athena_query.start_query_execution") + def test_run_with_kms_key(self, mock_start, athena_query): + mock_start.return_value = {"QueryExecutionId": "query-123"} + + athena_query.run( + query_string="SELECT * FROM table", + output_location="s3://bucket/output", + kms_key="arn:aws:kms:us-west-2:123:key/abc", + ) + + mock_start.assert_called_once() + call_kwargs = mock_start.call_args[1] + assert call_kwargs["kms_key"] == "arn:aws:kms:us-west-2:123:key/abc" + + @patch("sagemaker.mlops.feature_store.athena_query.wait_for_athena_query") + def test_wait_calls_helper(self, mock_wait, athena_query): + athena_query._current_query_execution_id = "query-123" + + athena_query.wait() + + mock_wait.assert_called_once_with(athena_query.sagemaker_session, "query-123") + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + def test_get_query_execution(self, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + mock_get.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}} + + result = athena_query.get_query_execution() + + assert result["QueryExecution"]["Status"]["State"] == "SUCCEEDED" + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + @patch("sagemaker.mlops.feature_store.athena_query.download_athena_query_result") + @patch("pandas.read_csv") + @patch("os.path.join") + def test_as_dataframe_success(self, mock_join, mock_read_csv, mock_download, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + athena_query._result_bucket = "bucket" + athena_query._result_file_prefix = "prefix" + + mock_get.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}} + mock_join.return_value = 
"/tmp/query-123.csv" + mock_read_csv.return_value = pd.DataFrame({"col": [1, 2, 3]}) + + with patch("tempfile.gettempdir", return_value="/tmp"): + with patch("os.remove"): + df = athena_query.as_dataframe() + + assert len(df) == 3 + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + def test_as_dataframe_raises_when_running(self, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + mock_get.return_value = {"QueryExecution": {"Status": {"State": "RUNNING"}}} + + with pytest.raises(RuntimeError, match="still executing"): + athena_query.as_dataframe() + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + def test_as_dataframe_raises_when_failed(self, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + mock_get.return_value = {"QueryExecution": {"Status": {"State": "FAILED"}}} + + with pytest.raises(RuntimeError, match="failed"): + athena_query.as_dataframe() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py new file mode 100644 index 0000000000..254fb0e196 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py @@ -0,0 +1,345 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for dataset_builder.py""" +import datetime +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd + +from sagemaker.mlops.feature_store import FeatureGroup +from sagemaker.mlops.feature_store.dataset_builder import ( + DatasetBuilder, + FeatureGroupToBeMerged, + TableType, + JoinTypeEnum, + JoinComparatorEnum, + construct_feature_group_to_be_merged, +) +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FeatureTypeEnum, +) + + +class TestTableType: + def test_feature_group_value(self): + assert TableType.FEATURE_GROUP.value == "FeatureGroup" + + def test_data_frame_value(self): + assert TableType.DATA_FRAME.value == "DataFrame" + + +class TestJoinTypeEnum: + def test_inner_join(self): + assert JoinTypeEnum.INNER_JOIN.value == "JOIN" + + def test_left_join(self): + assert JoinTypeEnum.LEFT_JOIN.value == "LEFT JOIN" + + def test_right_join(self): + assert JoinTypeEnum.RIGHT_JOIN.value == "RIGHT JOIN" + + def test_full_join(self): + assert JoinTypeEnum.FULL_JOIN.value == "FULL JOIN" + + def test_cross_join(self): + assert JoinTypeEnum.CROSS_JOIN.value == "CROSS JOIN" + + +class TestJoinComparatorEnum: + def test_equals(self): + assert JoinComparatorEnum.EQUALS.value == "=" + + def test_greater_than(self): + assert JoinComparatorEnum.GREATER_THAN.value == ">" + + def test_less_than(self): + assert JoinComparatorEnum.LESS_THAN.value == "<" + + +class TestFeatureGroupToBeMerged: + def test_initialization(self): + fg = FeatureGroupToBeMerged( + features=["id", "value"], + included_feature_names=["id", "value"], + projected_feature_names=["id", "value"], + catalog="AwsDataCatalog", + database="sagemaker_featurestore", + table_name="my_table", + record_identifier_feature_name="id", + event_time_identifier_feature=FeatureDefinition( + feature_name="event_time", + feature_type="String", + ), + ) + + assert fg.features == ["id", "value"] + assert fg.catalog == "AwsDataCatalog" + assert fg.table_name == "my_table" + assert fg.join_type == JoinTypeEnum.INNER_JOIN + assert fg.join_comparator == 
JoinComparatorEnum.EQUALS + + def test_custom_join_settings(self): + fg = FeatureGroupToBeMerged( + features=["id"], + included_feature_names=["id"], + projected_feature_names=["id"], + catalog="AwsDataCatalog", + database="db", + table_name="table", + record_identifier_feature_name="id", + event_time_identifier_feature=FeatureDefinition( + feature_name="ts", + feature_type="String", + ), + join_type=JoinTypeEnum.LEFT_JOIN, + join_comparator=JoinComparatorEnum.GREATER_THAN, + ) + + assert fg.join_type == JoinTypeEnum.LEFT_JOIN + assert fg.join_comparator == JoinComparatorEnum.GREATER_THAN + + +class TestConstructFeatureGroupToBeMerged: + @patch("sagemaker.mlops.feature_store.dataset_builder.FeatureGroup") + def test_constructs_from_feature_group(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.feature_group_name = "test-fg" + mock_fg.record_identifier_feature_name = "id" + mock_fg.event_time_feature_name = "event_time" + mock_fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + MagicMock(feature_name="value", feature_type="Fractional"), + MagicMock(feature_name="event_time", feature_type="String"), + ] + mock_fg.offline_store_config.data_catalog_config.catalog = "MyCatalog" + mock_fg.offline_store_config.data_catalog_config.database = "MyDatabase" + mock_fg.offline_store_config.data_catalog_config.table_name = "MyTable" + mock_fg.offline_store_config.data_catalog_config.disable_glue_table_creation = False + mock_fg_class.get.return_value = mock_fg + + target_fg = MagicMock() + target_fg.feature_group_name = "test-fg" + + result = construct_feature_group_to_be_merged( + target_feature_group=target_fg, + included_feature_names=["id", "value"], + ) + + assert result.table_name == "MyTable" + assert result.database == "MyDatabase" + assert result.record_identifier_feature_name == "id" + assert result.table_type == TableType.FEATURE_GROUP + + @patch("sagemaker.mlops.feature_store.dataset_builder.FeatureGroup") + def test_raises_when_no_metastore(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.feature_group_name = "test-fg" + mock_fg.offline_store_config = None + mock_fg_class.get.return_value = mock_fg + + target_fg = MagicMock() + target_fg.feature_group_name = "test-fg" + + with pytest.raises(RuntimeError, match="No metastore"): + construct_feature_group_to_be_merged(target_fg, None) + + +class TestDatasetBuilder: + @pytest.fixture + def mock_session(self): + return Mock() + + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": [1, 2, 3], + "value": [1.1, 2.2, 3.3], + "event_time": ["2024-01-01", "2024-01-02", "2024-01-03"], + }) + + def test_initialization_with_dataframe(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + assert builder._output_path == "s3://bucket/output" + assert builder._record_identifier_feature_name == "id" + + def test_fluent_api_point_in_time(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.point_in_time_accurate_join() + + assert result is builder + assert builder._point_in_time_accurate_join is True + + def test_fluent_api_include_duplicated(self, 
mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.include_duplicated_records() + + assert result is builder + assert builder._include_duplicated_records is True + + def test_fluent_api_include_deleted(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.include_deleted_records() + + assert result is builder + assert builder._include_deleted_records is True + + def test_fluent_api_number_of_recent_records(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.with_number_of_recent_records_by_record_identifier(5) + + assert result is builder + assert builder._number_of_recent_records == 5 + + def test_fluent_api_number_of_records(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.with_number_of_records_from_query_results(100) + + assert result is builder + assert builder._number_of_records == 100 + + def test_fluent_api_as_of(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + timestamp = datetime.datetime(2024, 1, 15, 12, 0, 0) + result = builder.as_of(timestamp) + + assert result is builder + assert builder._write_time_ending_timestamp == timestamp + + def test_fluent_api_event_time_range(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 1, 31) + result = builder.with_event_time_range(start, end) + + assert result is builder + assert builder._event_time_starting_timestamp == start + assert builder._event_time_ending_timestamp == end + + @patch.object(DatasetBuilder, "_run_query") + @patch("sagemaker.mlops.feature_store.dataset_builder.construct_feature_group_to_be_merged") + def test_with_feature_group(self, mock_construct, mock_run, mock_session, sample_dataframe): + mock_fg_to_merge = MagicMock() + mock_construct.return_value = mock_fg_to_merge + + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + mock_fg = MagicMock() + result = builder.with_feature_group(mock_fg, target_feature_name_in_base="id") + + assert result is builder + assert len(builder._feature_groups_to_be_merged) == 1 + + +class TestDatasetBuilderCreate: + 
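+    """Tests for the ``DatasetBuilder.create`` classmethod."""
+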
@pytest.fixture + def mock_session(self): + return Mock() + + def test_create_with_feature_group(self, mock_session): + mock_fg = MagicMock(spec=FeatureGroup) + builder = DatasetBuilder.create( + base=mock_fg, + output_path="s3://bucket/output", + session=mock_session, + ) + assert builder._base == mock_fg + assert builder._output_path == "s3://bucket/output" + + def test_create_with_dataframe(self, mock_session): + df = pd.DataFrame({"id": [1], "value": [10]}) + builder = DatasetBuilder.create( + base=df, + output_path="s3://bucket/output", + session=mock_session, + record_identifier_feature_name="id", + event_time_identifier_feature_name="event_time", + ) + assert builder._record_identifier_feature_name == "id" + + def test_create_with_dataframe_requires_identifiers(self, mock_session): + df = pd.DataFrame({"id": [1], "value": [10]}) + with pytest.raises(ValueError, match="record_identifier_feature_name"): + DatasetBuilder.create( + base=df, + output_path="s3://bucket/output", + session=mock_session, + ) + + +class TestDatasetBuilderValidation: + @pytest.fixture + def mock_session(self): + return Mock() + + def test_to_csv_raises_for_invalid_base(self, mock_session): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base="invalid", # Not DataFrame or FeatureGroup + _output_path="s3://bucket/output", + ) + + with pytest.raises(ValueError, match="must be either"): + builder.to_csv_file() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py new file mode 100644 index 0000000000..299868b5d2 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py @@ -0,0 +1,126 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""Unit tests for feature_definition.py""" +import pytest + +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FeatureTypeEnum, + CollectionTypeEnum, + IntegralFeatureDefinition, + FractionalFeatureDefinition, + StringFeatureDefinition, + VectorCollectionType, + ListCollectionType, + SetCollectionType, +) + + +class TestFeatureTypeEnum: + def test_fractional_value(self): + assert FeatureTypeEnum.FRACTIONAL.value == "Fractional" + + def test_integral_value(self): + assert FeatureTypeEnum.INTEGRAL.value == "Integral" + + def test_string_value(self): + assert FeatureTypeEnum.STRING.value == "String" + + +class TestCollectionTypeEnum: + def test_list_value(self): + assert CollectionTypeEnum.LIST.value == "List" + + def test_set_value(self): + assert CollectionTypeEnum.SET.value == "Set" + + def test_vector_value(self): + assert CollectionTypeEnum.VECTOR.value == "Vector" + + +class TestCollectionTypes: + def test_list_collection_type(self): + collection = ListCollectionType() + assert collection.collection_type == "List" + assert collection.collection_config is None + + def test_set_collection_type(self): + collection = SetCollectionType() + assert collection.collection_type == "Set" + assert collection.collection_config is None + + def test_vector_collection_type(self): + collection = VectorCollectionType(dimension=128) + assert collection.collection_type == "Vector" + assert collection.collection_config is not None + assert collection.collection_config.vector_config.dimension == 128 + + +class TestFeatureDefinitionFactories: + def test_integral_feature_definition(self): + definition = IntegralFeatureDefinition(feature_name="my_int_feature") + assert definition.feature_name == "my_int_feature" + assert definition.feature_type == "Integral" + assert definition.collection_type is None + + def test_fractional_feature_definition(self): + definition = FractionalFeatureDefinition(feature_name="my_float_feature") + assert definition.feature_name == "my_float_feature" + assert definition.feature_type == "Fractional" + assert definition.collection_type is None + + def test_string_feature_definition(self): + definition = StringFeatureDefinition(feature_name="my_string_feature") + assert definition.feature_name == "my_string_feature" + assert definition.feature_type == "String" + assert definition.collection_type is None + + def test_integral_with_list_collection(self): + definition = IntegralFeatureDefinition( + feature_name="my_int_list", + collection_type=ListCollectionType(), + ) + assert definition.feature_name == "my_int_list" + assert definition.feature_type == "Integral" + assert definition.collection_type == "List" + + def test_string_with_set_collection(self): + definition = StringFeatureDefinition( + feature_name="my_string_set", + collection_type=SetCollectionType(), + ) + assert definition.feature_name == "my_string_set" + assert definition.feature_type == "String" + assert definition.collection_type == "Set" + + def test_fractional_with_vector_collection(self): + definition = FractionalFeatureDefinition( + feature_name="my_embedding", + collection_type=VectorCollectionType(dimension=256), + ) + assert definition.feature_name == "my_embedding" + assert definition.feature_type == "Fractional" + assert definition.collection_type == "Vector" + assert definition.collection_config.vector_config.dimension == 256 + + +class TestFeatureDefinitionSerialization: + """Test that FeatureDefinition can be serialized 
(Pydantic model_dump).""" + + def test_simple_definition_serialization(self): + definition = IntegralFeatureDefinition(feature_name="id") + # Pydantic model - use model_dump + data = definition.model_dump(exclude_none=True) + assert data["feature_name"] == "id" + assert data["feature_type"] == "Integral" + + def test_collection_definition_serialization(self): + definition = FractionalFeatureDefinition( + feature_name="vector", + collection_type=VectorCollectionType(dimension=10), + ) + data = definition.model_dump(exclude_none=True) + assert data["feature_name"] == "vector" + assert data["feature_type"] == "Fractional" + assert data["collection_type"] == "Vector" + assert data["collection_config"]["vector_config"]["dimension"] == 10 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py new file mode 100644 index 0000000000..a9d5408bf6 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py @@ -0,0 +1,202 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for feature_utils.py""" +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd +import numpy as np + +from sagemaker.mlops.feature_store.feature_utils import ( + load_feature_definitions_from_dataframe, + as_hive_ddl, + create_athena_query, + ingest_dataframe, + get_session_from_role, + _is_collection_column, + _generate_feature_definition, +) +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + ListCollectionType, +) + + +class TestLoadFeatureDefinitionsFromDataframe: + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": pd.Series([1, 2, 3], dtype="int64"), + "value": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "name": pd.Series(["a", "b", "c"], dtype="string"), + }) + + def test_infers_integral_type(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + id_def = next(d for d in defs if d.feature_name == "id") + assert id_def.feature_type == "Integral" + + def test_infers_fractional_type(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + value_def = next(d for d in defs if d.feature_name == "value") + assert value_def.feature_type == "Fractional" + + def test_infers_string_type(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + name_def = next(d for d in defs if d.feature_name == "name") + assert name_def.feature_type == "String" + + def test_returns_correct_count(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + assert len(defs) == 3 + + def test_collection_type_with_in_memory_storage(self): + df = pd.DataFrame({ + "id": pd.Series([1, 2], dtype="int64"), + "tags": pd.Series([["a", "b"], ["c"]], dtype="object"), + }) + defs = load_feature_definitions_from_dataframe(df, online_storage_type="InMemory") + tags_def = next(d for d in defs if d.feature_name == "tags") + assert tags_def.collection_type == "List" + + +class TestIsCollectionColumn: + def test_list_column_returns_true(self): + series = pd.Series([[1, 2], [3, 4], [5]]) + assert _is_collection_column(series) == True + + def test_scalar_column_returns_false(self): + series = pd.Series([1, 2, 3]) + assert _is_collection_column(series) == False + + def test_empty_series(self): + 
series = pd.Series([], dtype="object") + assert _is_collection_column(series) == False + + +class TestAsHiveDdl: + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_generates_ddl_string(self, mock_fg_class): + # Setup mock + mock_fg = MagicMock() + mock_fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + MagicMock(feature_name="value", feature_type="Fractional"), + MagicMock(feature_name="name", feature_type="String"), + ] + mock_fg.offline_store_config.s3_storage_config.resolved_output_s3_uri = "s3://bucket/prefix" + mock_fg_class.get.return_value = mock_fg + + ddl = as_hive_ddl("my-feature-group") + + assert "CREATE EXTERNAL TABLE" in ddl + assert "my-feature-group" in ddl + assert "id INT" in ddl + assert "value FLOAT" in ddl + assert "name STRING" in ddl + assert "write_time TIMESTAMP" in ddl + assert "event_time TIMESTAMP" in ddl + assert "is_deleted BOOLEAN" in ddl + assert "s3://bucket/prefix" in ddl + + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_custom_database_and_table(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.feature_definitions = [] + mock_fg.offline_store_config.s3_storage_config.resolved_output_s3_uri = "s3://bucket/prefix" + mock_fg_class.get.return_value = mock_fg + + ddl = as_hive_ddl("my-fg", database="custom_db", table_name="custom_table") + + assert "custom_db.custom_table" in ddl + + +class TestCreateAthenaQuery: + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_creates_athena_query(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.offline_store_config.data_catalog_config.catalog = "MyCatalog" + mock_fg.offline_store_config.data_catalog_config.database = "MyDatabase" + mock_fg.offline_store_config.data_catalog_config.table_name = "MyTable" + mock_fg.offline_store_config.data_catalog_config.disable_glue_table_creation = False + mock_fg_class.get.return_value = mock_fg + + session = Mock() + query = create_athena_query("my-fg", session) + + assert query.catalog == "AwsDataCatalog" # disable_glue=False uses default + assert query.database == "MyDatabase" + assert query.table_name == "MyTable" + + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_raises_when_no_metastore(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.offline_store_config = None + mock_fg_class.get.return_value = mock_fg + + session = Mock() + with pytest.raises(RuntimeError, match="No metastore"): + create_athena_query("my-fg", session) + + +class TestIngestDataframe: + @patch("sagemaker.mlops.feature_store.feature_utils.IngestionManagerPandas") + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_creates_manager_and_runs(self, mock_fg_class, mock_manager_class): + mock_fg = MagicMock() + mock_fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + ] + mock_fg_class.get.return_value = mock_fg + + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + df = pd.DataFrame({"id": [1, 2, 3]}) + result = ingest_dataframe("my-fg", df, max_workers=2, max_processes=1) + + mock_manager_class.assert_called_once() + mock_manager.run.assert_called_once() + assert result == mock_manager + + def test_raises_on_invalid_max_workers(self): + df = pd.DataFrame({"id": [1, 2, 3]}) + with pytest.raises(ValueError, match="max_workers"): + ingest_dataframe("my-fg", df, max_workers=0) + + def test_raises_on_invalid_max_processes(self): + df = 
pd.DataFrame({"id": [1, 2, 3]}) + with pytest.raises(ValueError, match="max_processes"): + ingest_dataframe("my-fg", df, max_processes=-1) + + +class TestGetSessionFromRole: + @patch("sagemaker.mlops.feature_store.feature_utils.boto3") + @patch("sagemaker.mlops.feature_store.feature_utils.Session") + def test_creates_session_without_role(self, mock_session_class, mock_boto3): + mock_boto_session = MagicMock() + mock_boto3.Session.return_value = mock_boto_session + + get_session_from_role(region="us-west-2") + + mock_boto3.Session.assert_called_with(region_name="us-west-2") + mock_session_class.assert_called_once() + + @patch("sagemaker.mlops.feature_store.feature_utils.boto3") + @patch("sagemaker.mlops.feature_store.feature_utils.Session") + def test_assumes_role_when_provided(self, mock_session_class, mock_boto3): + mock_boto_session = MagicMock() + mock_sts = MagicMock() + mock_sts.assume_role.return_value = { + "Credentials": { + "AccessKeyId": "key", + "SecretAccessKey": "secret", + "SessionToken": "token", + } + } + mock_boto_session.client.return_value = mock_sts + mock_boto3.Session.return_value = mock_boto_session + + get_session_from_role(region="us-west-2", assume_role="arn:aws:iam::123:role/MyRole") + + mock_sts.assume_role.assert_called_once() + assert mock_boto3.Session.call_count == 2 # Initial + after assume diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py new file mode 100644 index 0000000000..2ecf495967 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py @@ -0,0 +1,256 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""Unit tests for ingestion_manager_pandas.py""" +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd +import numpy as np + +from sagemaker.mlops.feature_store.ingestion_manager_pandas import ( + IngestionManagerPandas, + IngestionError, +) + + +class TestIngestionError: + def test_stores_failed_rows(self): + error = IngestionError([1, 5, 10], "Some rows failed") + assert error.failed_rows == [1, 5, 10] + assert "Some rows failed" in str(error) + + +class TestIngestionManagerPandas: + @pytest.fixture + def feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "value": {"FeatureName": "value", "FeatureType": "Fractional"}, + "name": {"FeatureName": "name", "FeatureType": "String"}, + } + + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": [1, 2, 3], + "value": [1.1, 2.2, 3.3], + "name": ["a", "b", "c"], + }) + + @pytest.fixture + def manager(self, feature_definitions): + return IngestionManagerPandas( + feature_group_name="test-fg", + feature_definitions=feature_definitions, + max_workers=1, + max_processes=1, + ) + + def test_initialization(self, manager): + assert manager.feature_group_name == "test-fg" + assert manager.max_workers == 1 + assert manager.max_processes == 1 + assert manager.failed_rows == [] + + def test_failed_rows_property(self, manager): + manager._failed_indices = [1, 2, 3] + assert manager.failed_rows == [1, 2, 3] + + +class TestIngestionManagerHelpers: + def test_is_feature_collection_type_true(self): + feature_defs = { + "tags": {"FeatureName": "tags", "FeatureType": "String", "CollectionType": "List"}, + } + assert IngestionManagerPandas._is_feature_collection_type("tags", feature_defs) is True + + def test_is_feature_collection_type_false(self): + feature_defs = { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + } + assert IngestionManagerPandas._is_feature_collection_type("id", feature_defs) is False + + def test_is_feature_collection_type_missing(self): + feature_defs = {} + assert IngestionManagerPandas._is_feature_collection_type("unknown", feature_defs) is False + + def test_feature_value_is_not_none_scalar(self): + assert IngestionManagerPandas._feature_value_is_not_none(5) is True + assert IngestionManagerPandas._feature_value_is_not_none(None) is False + assert IngestionManagerPandas._feature_value_is_not_none(np.nan) is False + + def test_feature_value_is_not_none_list(self): + assert IngestionManagerPandas._feature_value_is_not_none([1, 2, 3]) is True + assert IngestionManagerPandas._feature_value_is_not_none([]) is True + assert IngestionManagerPandas._feature_value_is_not_none(None) is False + + def test_convert_to_string_list(self): + result = IngestionManagerPandas._convert_to_string_list([1, 2, 3]) + assert result == ["1", "2", "3"] + + def test_convert_to_string_list_with_none(self): + result = IngestionManagerPandas._convert_to_string_list([1, None, 3]) + assert result == ["1", None, "3"] + + def test_convert_to_string_list_raises_for_non_list(self): + with pytest.raises(ValueError, match="must be an Array"): + IngestionManagerPandas._convert_to_string_list("not a list") + + +class TestIngestionManagerRun: + @pytest.fixture + def feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "value": {"FeatureName": "value", "FeatureType": "Fractional"}, + } + + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": [1, 2, 3], 
+ "value": [1.1, 2.2, 3.3], + }) + + @patch.object(IngestionManagerPandas, "_run_single_process_single_thread") + def test_run_single_thread_mode(self, mock_single, feature_definitions, sample_dataframe): + manager = IngestionManagerPandas( + feature_group_name="test-fg", + feature_definitions=feature_definitions, + max_workers=1, + max_processes=1, + ) + + manager.run(sample_dataframe) + + mock_single.assert_called_once() + + @patch.object(IngestionManagerPandas, "_run_multi_process") + def test_run_multi_process_mode(self, mock_multi, feature_definitions, sample_dataframe): + manager = IngestionManagerPandas( + feature_group_name="test-fg", + feature_definitions=feature_definitions, + max_workers=2, + max_processes=2, + ) + + manager.run(sample_dataframe) + + mock_multi.assert_called_once() + + +class TestIngestionManagerIngestRow: + @pytest.fixture + def feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "name": {"FeatureName": "name", "FeatureType": "String"}, + } + + @pytest.fixture + def collection_feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "tags": {"FeatureName": "tags", "FeatureType": "String", "CollectionType": "List"}, + } + + def test_ingest_row_success(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": ["test"]}) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + mock_fg.put_record.assert_called_once() + assert len(failed_rows) == 0 + + def test_ingest_row_with_collection_type(self, collection_feature_definitions): + df = pd.DataFrame({ + "id": [1], + "tags": [["tag1", "tag2"]], + }) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=collection_feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + mock_fg.put_record.assert_called_once() + call_args = mock_fg.put_record.call_args + record = call_args[1]["record"] + + # Find the tags feature value + tags_value = next(v for v in record if v.feature_name == "tags") + assert tags_value.value_as_string_list == ["tag1", "tag2"] + + def test_ingest_row_failure_appends_to_failed(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": ["test"]}) + mock_fg = MagicMock() + mock_fg.put_record.side_effect = Exception("API Error") + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + assert len(failed_rows) == 1 + assert failed_rows[0] == 0 # Index of failed row + + def test_ingest_row_with_target_stores(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": ["test"]}) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=["OnlineStore"], + ) + + call_args = mock_fg.put_record.call_args + assert call_args[1]["target_stores"] == ["OnlineStore"] + + def test_ingest_row_skips_none_values(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": 
[None]}) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + call_args = mock_fg.put_record.call_args + record = call_args[1]["record"] + # Only id should be in record, name is None + assert len(record) == 1 + assert record[0].feature_name == "id" diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py new file mode 100644 index 0000000000..44e3ec6085 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py @@ -0,0 +1,109 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for inputs.py (enums).""" +import pytest + +from sagemaker.mlops.feature_store.inputs import ( + TargetStoreEnum, + OnlineStoreStorageTypeEnum, + TableFormatEnum, + ResourceEnum, + SearchOperatorEnum, + SortOrderEnum, + FilterOperatorEnum, + DeletionModeEnum, + ExpirationTimeResponseEnum, + ThroughputModeEnum, +) + + +class TestTargetStoreEnum: + def test_online_store(self): + assert TargetStoreEnum.ONLINE_STORE.value == "OnlineStore" + + def test_offline_store(self): + assert TargetStoreEnum.OFFLINE_STORE.value == "OfflineStore" + + +class TestOnlineStoreStorageTypeEnum: + def test_standard(self): + assert OnlineStoreStorageTypeEnum.STANDARD.value == "Standard" + + def test_in_memory(self): + assert OnlineStoreStorageTypeEnum.IN_MEMORY.value == "InMemory" + + +class TestTableFormatEnum: + def test_glue(self): + assert TableFormatEnum.GLUE.value == "Glue" + + def test_iceberg(self): + assert TableFormatEnum.ICEBERG.value == "Iceberg" + + +class TestResourceEnum: + def test_feature_group(self): + assert ResourceEnum.FEATURE_GROUP.value == "FeatureGroup" + + def test_feature_metadata(self): + assert ResourceEnum.FEATURE_METADATA.value == "FeatureMetadata" + + +class TestSearchOperatorEnum: + def test_and(self): + assert SearchOperatorEnum.AND.value == "And" + + def test_or(self): + assert SearchOperatorEnum.OR.value == "Or" + + +class TestSortOrderEnum: + def test_ascending(self): + assert SortOrderEnum.ASCENDING.value == "Ascending" + + def test_descending(self): + assert SortOrderEnum.DESCENDING.value == "Descending" + + +class TestFilterOperatorEnum: + def test_equals(self): + assert FilterOperatorEnum.EQUALS.value == "Equals" + + def test_not_equals(self): + assert FilterOperatorEnum.NOT_EQUALS.value == "NotEquals" + + def test_greater_than(self): + assert FilterOperatorEnum.GREATER_THAN.value == "GreaterThan" + + def test_contains(self): + assert FilterOperatorEnum.CONTAINS.value == "Contains" + + def test_exists(self): + assert FilterOperatorEnum.EXISTS.value == "Exists" + + def test_in(self): + assert FilterOperatorEnum.IN.value == "In" + + +class TestDeletionModeEnum: + def test_soft_delete(self): + assert DeletionModeEnum.SOFT_DELETE.value == "SoftDelete" + + def test_hard_delete(self): + assert DeletionModeEnum.HARD_DELETE.value == "HardDelete" + + +class TestExpirationTimeResponseEnum: + def test_disabled(self): + assert ExpirationTimeResponseEnum.DISABLED.value == "Disabled" + + def test_enabled(self): + assert ExpirationTimeResponseEnum.ENABLED.value == "Enabled" + + +class TestThroughputModeEnum: + def test_on_demand(self): + assert ThroughputModeEnum.ON_DEMAND.value == 
"OnDemand" + + def test_provisioned(self): + assert ThroughputModeEnum.PROVISIONED.value == "Provisioned"