"Add S3 data loader support to DBTableManager and data formulator" #159
Changes from all commits: 2fb5016, 9363049, 6433640, 0d3e6c0, 48f4ee2
Changes to the data loader registry module:

```diff
@@ -1,10 +1,12 @@
 from data_formulator.data_loader.external_data_loader import ExternalDataLoader
 from data_formulator.data_loader.mysql_data_loader import MySQLDataLoader
 from data_formulator.data_loader.kusto_data_loader import KustoDataLoader
+from data_formulator.data_loader.s3_data_loader import S3DataLoader

 DATA_LOADERS = {
     "mysql": MySQLDataLoader,
-    "kusto": KustoDataLoader
+    "kusto": KustoDataLoader,
+    "s3": S3DataLoader,
 }

-__all__ = ["ExternalDataLoader", "MySQLDataLoader", "KustoDataLoader", "DATA_LOADERS"]
+__all__ = ["ExternalDataLoader", "MySQLDataLoader", "KustoDataLoader", "S3DataLoader", "DATA_LOADERS"]
```
New file: the S3 data loader (188 lines added); the excerpt shown covers `list_params`, `__init__`, and `list_tables`:

```python
import json
import pandas as pd
import duckdb
import os

from data_formulator.data_loader.external_data_loader import ExternalDataLoader, sanitize_table_name
from typing import Dict, Any, List


class S3DataLoader(ExternalDataLoader):

    @staticmethod
    def list_params() -> List[Dict[str, Any]]:
        params_list = [
            {"name": "aws_access_key_id", "type": "string", "required": True, "default": "", "description": "AWS access key ID"},
            {"name": "aws_secret_access_key", "type": "string", "required": True, "default": "", "description": "AWS secret access key"},
            {"name": "aws_session_token", "type": "string", "required": False, "default": "", "description": "AWS session token (required for temporary credentials)"},
            {"name": "region_name", "type": "string", "required": True, "default": "us-east-1", "description": "AWS region name"},
            {"name": "bucket", "type": "string", "required": True, "default": "", "description": "S3 bucket name"}
        ]
        return params_list

    def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnection):
        self.params = params
        self.duck_db_conn = duck_db_conn

        # Extract parameters
        self.aws_access_key_id = params.get("aws_access_key_id", "")
        self.aws_secret_access_key = params.get("aws_secret_access_key", "")
        self.aws_session_token = params.get("aws_session_token", "")
        self.region_name = params.get("region_name", "us-east-1")
        self.bucket = params.get("bucket", "")

        # Install and load the httpfs extension for S3 access
        self.duck_db_conn.install_extension("httpfs")
        self.duck_db_conn.load_extension("httpfs")

        # Set AWS credentials for DuckDB
        self.duck_db_conn.execute(f"SET s3_region='{self.region_name}'")
        self.duck_db_conn.execute(f"SET s3_access_key_id='{self.aws_access_key_id}'")
        self.duck_db_conn.execute(f"SET s3_secret_access_key='{self.aws_secret_access_key}'")
        if self.aws_session_token:  # only needed for temporary credentials
            self.duck_db_conn.execute(f"SET s3_session_token='{self.aws_session_token}'")

    def list_tables(self) -> List[Dict[str, Any]]:
        # Use boto3 to list objects in the bucket
        import boto3

        s3_client = boto3.client(
            's3',
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token if self.aws_session_token else None,
            region_name=self.region_name
        )

        # List objects in the bucket
        response = s3_client.list_objects_v2(Bucket=self.bucket)

        results = []

        if 'Contents' in response:
            for obj in response['Contents']:
                key = obj['Key']

                # Skip directories and non-data files
                if key.endswith('/') or not self._is_supported_file(key):
                    continue

                # Create S3 URL
                s3_url = f"s3://{self.bucket}/{key}"

                try:
                    # Choose the appropriate read function based on file extension
                    if s3_url.lower().endswith('.parquet'):
                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_parquet('{s3_url}') LIMIT 10").df()
                    elif s3_url.lower().endswith('.json') or s3_url.lower().endswith('.jsonl'):
                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_json_auto('{s3_url}') LIMIT 10").df()
                    elif s3_url.lower().endswith('.csv'):
                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_csv_auto('{s3_url}') LIMIT 10").df()

                    # Get column information
                    columns = [{
                        'name': col,
                        'type': str(sample_df[col].dtype)
                    } for col in sample_df.columns]

                    # Get sample data
                    sample_rows = json.loads(sample_df.to_json(orient="records"))

                    # Estimate row count (this is approximate for CSV files)
                    row_count = self._estimate_row_count(s3_url)

                    table_metadata = {
                        "row_count": row_count,
                        "columns": columns,
                        "sample_rows": sample_rows
                    }

                    results.append({
                        "name": s3_url,
                        "metadata": table_metadata
                    })
                except Exception as e:
                    # Skip files that can't be read
                    print(f"Error reading {s3_url}: {e}")
```
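A note on the credential setup: the `SET s3_*` options used in `__init__` are DuckDB's legacy configuration path; DuckDB 0.10+ also offers a Secrets Manager that keeps credentials in a named secret rather than global settings. A hedged sketch of that alternative inside `__init__` (not part of the PR; syntax per DuckDB's `CREATE SECRET` for `TYPE S3`):

```python
# Sketch only: assumes DuckDB 0.10+ with the httpfs extension already loaded.
self.duck_db_conn.execute(f"""
    CREATE OR REPLACE SECRET s3_loader_secret (
        TYPE S3,
        KEY_ID '{self.aws_access_key_id}',
        SECRET '{self.aws_secret_access_key}',
        REGION '{self.region_name}'
    )
""")
# A SESSION_TOKEN '<token>' entry can be added when temporary credentials are used.
```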
One review suggestion routes the error through logging instead of print:

```diff
-                    print(f"Error reading {s3_url}: {e}")
+                    logging.error(f"Error reading {s3_url}: {e}")
```
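The suggested `logging.error` call only needs an `import logging` at the top of the file; a slightly more idiomatic variant uses a module-level logger (the logger setup below is illustrative, not part of the PR):

```python
import logging

# Illustrative module-level logger; name and configuration are assumptions.
logger = logging.getLogger(__name__)

# Inside the except block, instead of print():
logger.error(f"Error reading {s3_url}: {e}")
```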
Copilot (May 29, 2025):
Using a full COUNT(*) scan on large Parquet files may be slow; consider reading row-count metadata from the file footer instead of scanning all rows.
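If the existing estimate runs a full COUNT(*), one way to avoid the scan for Parquet files is to read the count from the file footer. A sketch, assuming a DuckDB build that ships the `parquet_file_metadata()` table function (recent releases); the helper name is illustrative:

```python
def _parquet_row_count(self, s3_url: str) -> int:
    # Try the Parquet footer first: num_rows is stored in the file metadata,
    # so no data pages need to be scanned.
    try:
        row = self.duck_db_conn.execute(
            f"SELECT num_rows FROM parquet_file_metadata('{s3_url}')"
        ).fetchone()
        if row is not None:
            return int(row[0])
    except Exception:
        pass  # older DuckDB or unreadable footer
    # Fall back to the full scan.
    return self.duck_db_conn.execute(
        f"SELECT COUNT(*) FROM read_parquet('{s3_url}')"
    ).fetchone()[0]
```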
Copilot:
[nitpick] The boto3 client is created in multiple methods; consider extracting common initialization into a private helper to reduce duplication and simplify updates.
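If that refactor were applied, a minimal sketch of the private helper (the method name `_s3_client` is illustrative, not part of the PR):

```python
import boto3

# Illustrative helper: would live on S3DataLoader so that list_tables and any
# other method that talks to S3 share one client construction path.
def _s3_client(self):
    return boto3.client(
        's3',
        aws_access_key_id=self.aws_access_key_id,
        aws_secret_access_key=self.aws_secret_access_key,
        aws_session_token=self.aws_session_token or None,
        region_name=self.region_name,
    )
```

`list_tables` would then start with `s3_client = self._s3_client()` instead of building the client inline.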