|
import io
import sys

import boto3
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from awsglue.context import GlueContext
from awsglue.job import Job
from awsglue.utils import getResolvedOptions
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
from pyspark.sql import SparkSession
| 15 | + |
| 16 | + |
# Required job parameter: --JOB_NAME is injected by the Glue job runner.
args = getResolvedOptions(sys.argv, ['JOB_NAME'])

# Spark session configured for an Iceberg table registered in the AWS Glue
# Data Catalog (catalog name "glue_catalog", warehouse on S3).
spark = (
    SparkSession.builder
    .appName("FeatureSelection")
    .config("spark.sql.parquet.enableVectorizedReader", "true")
    # mergeSchema reconciles parquet files written with evolving schemas.
    .config("spark.sql.parquet.mergeSchema", "true")
    .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
    .config("spark.sql.catalog.glue_catalog", "org.apache.iceberg.spark.SparkCatalog")
    .config("spark.sql.catalog.glue_catalog.warehouse", "s3://bdp-scaled-features/")
    .config("spark.sql.catalog.glue_catalog.catalog-impl", "org.apache.iceberg.aws.glue.GlueCatalog")
    .config("spark.sql.catalog.glue_catalog.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
    # glue.id pins the AWS account whose Glue catalog backs the Iceberg tables.
    .config("spark.sql.catalog.glue_catalog.glue.id", "982534349340")
    # Adaptive query execution lets Spark re-plan shuffles at runtime.
    .config("spark.sql.adaptive.enabled", "true")
    .getOrCreate()
)

# Wrap the Spark session in a GlueContext and initialize job bookkeeping;
# job.init(...) must precede any Glue reads, and job.commit() ends the run.
glueContext = GlueContext(spark)
job = Job(glueContext)
job.init(args['JOB_NAME'], args)
| 37 | + |
| 38 | + |
# Load the pre-scaled feature table from the Glue Data Catalog as a native
# Spark DataFrame (useSparkDataSource=True skips the DynamicFrame layer;
# useCatalogSchema=True takes column types from the catalog, not inference).
features_df = glueContext.create_data_frame.from_catalog(
    database="bdp",
    table_name="scaled_features",
    additional_options = {
        "useCatalogSchema": True,
        "useSparkDataSource": True
    }
)
| 47 | + |
# All table columns are treated as features. NOTE(review): VectorAssembler
# requires every input column to be numeric — assumes the scaled_features
# table contains no id/string columns; verify against the catalog schema.
feature_cols = list(features_df.columns)

# Assemble one vector column and compute the pairwise Spearman correlation
# matrix on the driver. Cache while the (wide) stat job runs, then release.
df_vectorized = VectorAssembler(inputCols=feature_cols, outputCol="features").transform(features_df)
df_vectorized.cache()
# DenseMatrix.toArray() already returns a numpy ndarray; no extra copy needed.
correlation_matrix_np = Correlation.corr(df_vectorized, "features", method="spearman").head()[0].toArray()
correlation_matrix_df = pd.DataFrame(correlation_matrix_np, index=feature_cols, columns=feature_cols)
df_vectorized.unpersist()
| 56 | + |
# Destination bucket for all artifacts produced by this job.
output_bucket = "bdp-feature-selection"

# Save correlation matrix to S3 as CSV.
# The matrix is driver-local pandas, so it is serialized in memory and
# uploaded with a single put_object call (no temp file needed).
s3_client = boto3.client('s3')
correlation_csv_buffer = io.StringIO()
correlation_matrix_df.to_csv(correlation_csv_buffer)
s3_client.put_object(
    Bucket=output_bucket,
    Key="data/correlation_matrix.csv",
    Body=correlation_csv_buffer.getvalue()
)
| 68 | + |
| 69 | + |
# Greedy multicollinearity filter: scan the upper triangle and, for any pair
# with |Spearman rho| above the threshold, drop the later column of the pair.
threshold = 0.9
to_remove = set()
n_features = len(feature_cols)
for i in range(n_features):
    # Fix: a feature that has already been dropped must not cause further
    # removals — otherwise features correlated only with removed features
    # are discarded unnecessarily (e.g. A~B and B~C used to drop both B and C;
    # C should survive once B is gone).
    if feature_cols[i] in to_remove:
        continue
    for j in range(i + 1, n_features):
        if abs(correlation_matrix_np[i, j]) > threshold:
            to_remove.add(feature_cols[j])

# Keep original column order for the surviving features.
selected_columns = [col for col in feature_cols if col not in to_remove]
| 78 | + |
# Persist the surviving feature names to S3 as a plain-text file,
# one column name per line (no intermediate buffer required).
s3_client.put_object(
    Bucket=output_bucket,
    Key="data/selected_columns.txt",
    Body="\n".join(selected_columns),
    ContentType="text/plain"
)
| 88 | + |
# Generate the correlation heatmap and upload it as PNG.
# vmin/vmax pin the color scale to the full [-1, 1] correlation range so
# runs are visually comparable; annot=False because the matrix is large.
fig = plt.figure(figsize=(12, 8))
sns.heatmap(correlation_matrix_df, cmap="coolwarm", vmin=-1, vmax=1, annot=False)
plt.title("Correlation Heatmap")
heatmap_buffer = io.BytesIO()
plt.savefig(heatmap_buffer, format='png', bbox_inches='tight')
# Fix: release the figure explicitly — matplotlib keeps every open figure
# alive, which leaks memory in a long-running driver process.
plt.close(fig)
heatmap_buffer.seek(0)
s3_client.put_object(
    Bucket=output_bucket,
    Key="data/correlation_heatmap.png",
    Body=heatmap_buffer,
    ContentType='image/png'
)
| 102 | + |
# Mark the Glue job run as successful (required for bookmarks/metrics),
# then shut down the Spark session.
job.commit()
spark.stop()
0 commit comments