from network_security.entity.artifact_entity import DataIngestionArtifact, DataValidationArtifact
from network_security.entity.config_entity import DataValidationConfig
from network_security.exceptions.exception import NetworkSecurityException
from network_security.logging.logger import logging
from network_security.utils.main_utils.utils import read_yaml_file, write_yaml_file
from network_security.constants.training_pipeline import SCHEMA_FILE_PATH

from scipy.stats import ks_2samp  # two-sample Kolmogorov-Smirnov test, used to detect data drift
import pandas as pd
import numpy as np
import os, sys


class DataValidation:
    """
    Validates the ingested train and test data against the schema and
    checks for data drift between the two splits.
    """

    # ---- static helpers ----
    @staticmethod
    def read_data(file_path: str) -> pd.DataFrame:
        try:
            return pd.read_csv(file_path)
        except Exception as e:
            raise NetworkSecurityException(e, sys)

    # ---- instance methods ----
    def __init__(self,
                 data_ingestion_artifact: DataIngestionArtifact,
                 data_validation_config: DataValidationConfig):
        try:
            self.data_ingestion_artifact = data_ingestion_artifact
            self.data_validation_config = data_validation_config
            self._schema_config = read_yaml_file(SCHEMA_FILE_PATH)
        except Exception as e:
            raise NetworkSecurityException(e, sys)

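    # Assumed layout of the schema file at SCHEMA_FILE_PATH (illustrative only:
    # the lookups below rely on the 'columns' and 'numerical_columns' keys, but
    # the real schema.yaml uses the project's actual column names and may carry
    # more metadata):
    #
    #   columns:
    #     - column_a: int64
    #     - column_b: int64
    #   numerical_columns:
    #     - column_a
    #     - column_b
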
    def initiate_data_validation(self) -> DataValidationArtifact:
        try:
            train_file_path = self.data_ingestion_artifact.train_file_path
            test_file_path = self.data_ingestion_artifact.test_file_path
            logging.info("Reading train and test data for validation")

            # reading train and test data
            train_df = DataValidation.read_data(train_file_path)
            test_df = DataValidation.read_data(test_file_path)

            # validating the number of columns in the train dataframe
            status = self.validate_number_of_columns(train_df)
            if not status:
                logging.info("Number of columns in the train dataframe is not as per the schema")

            # validating the number of columns in the test dataframe
            status = self.validate_number_of_columns(test_df)
            if not status:
                logging.info("Number of columns in the test dataframe is not as per the schema")

            # checking for data drift
            status = self.detect_data_drift(base_df=train_df, current_df=test_df)
            dir_path = os.path.dirname(self.data_validation_config.valid_train_file_path)
            os.makedirs(dir_path, exist_ok=True)

            # saving the validated train and test data to their respective paths
            train_df.to_csv(self.data_validation_config.valid_train_file_path, index=False, header=True)
            test_df.to_csv(self.data_validation_config.valid_test_file_path, index=False, header=True)

            # the artifact points at the validated copies written just above
            data_validation_artifact = DataValidationArtifact(
                validation_status=status,
                valid_train_file_path=self.data_validation_config.valid_train_file_path,
                valid_test_file_path=self.data_validation_config.valid_test_file_path,
                invalid_train_file_path=None,
                invalid_test_file_path=None,
                drift_report_file_path=self.data_validation_config.drift_report_file_path
            )

            return data_validation_artifact
        except Exception as e:
            raise NetworkSecurityException(e, sys)

    def validate_number_of_columns(self, dataframe: pd.DataFrame) -> bool:
        try:
            num_of_cols = len(self._schema_config['columns'])
            logging.info(f"Required number of columns: {num_of_cols}")
            logging.info(f"Dataframe has {dataframe.shape[1]} columns")
            return dataframe.shape[1] == num_of_cols
        except Exception as e:
            raise NetworkSecurityException(e, sys)

    def validate_number_of_numeric_columns(self, dataframe: pd.DataFrame) -> bool:
        try:
            # get the expected numerical columns from the schema
            numerical_columns = self._schema_config['numerical_columns']
            dataframe_columns = dataframe.columns.tolist()

            # check which numerical columns are present in the dataframe
            present_numerical_cols = [col for col in numerical_columns if col in dataframe_columns]
            missing_numerical_cols = [col for col in numerical_columns if col not in dataframe_columns]

            logging.info(f"Required number of numerical columns: {len(numerical_columns)}")
            logging.info(f"Dataframe has {len(present_numerical_cols)} of them")

            if missing_numerical_cols:
                logging.warning(f"Missing numerical columns: {missing_numerical_cols}")
                return False

            return True

        except Exception as e:
            raise NetworkSecurityException(e, sys)

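    # How the drift check below works: ks_2samp runs a two-sample
    # Kolmogorov-Smirnov test whose null hypothesis is that both samples come
    # from the same continuous distribution. A p-value below the threshold
    # rejects that hypothesis, i.e. the column has drifted. A quick synthetic
    # illustration (for intuition only, not part of the pipeline):
    #
    #   rng = np.random.default_rng(0)
    #   ks_2samp(rng.normal(0, 1, 500), rng.normal(0, 1, 500)).pvalue  # large p -> no drift
    #   ks_2samp(rng.normal(0, 1, 500), rng.normal(3, 1, 500)).pvalue  # ~0 -> drift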
    def detect_data_drift(self,
                          base_df: pd.DataFrame,
                          current_df: pd.DataFrame,
                          threshold: float = 0.05) -> bool:
        try:
            status = True
            report = {}
            for col in base_df.columns:
                d1 = base_df[col]
                d2 = current_df[col]

                # two-sample KS test: a p-value below the threshold means drift
                ks_result = ks_2samp(d1, d2)
                if threshold <= ks_result.pvalue:
                    is_found = False
                else:
                    is_found = True
                    status = False

                report.update({
                    col: {
                        "p_value": float(ks_result.pvalue),
                        "drift_status": is_found
                    }
                })

            # creating the directory for the drift report file
            drift_report_file_path = self.data_validation_config.drift_report_file_path
            dir_path = os.path.dirname(drift_report_file_path)
            os.makedirs(dir_path, exist_ok=True)

            # writing the drift report to a yaml file
            write_yaml_file(
                file_path=drift_report_file_path,
                content=report
            )

            return status  # True when no column drifted
        except Exception as e:
            raise NetworkSecurityException(e, sys)
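

# ---------------------------------------------------------------------------
# Minimal usage sketch. Illustrative assumptions: the DataIngestionArtifact
# field values and the no-argument DataValidationConfig constructor below are
# guesses; wire in the real pipeline objects in practice.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ingestion_artifact = DataIngestionArtifact(
        train_file_path="artifacts/data_ingestion/train.csv",  # hypothetical path
        test_file_path="artifacts/data_ingestion/test.csv",    # hypothetical path
    )
    validation_config = DataValidationConfig()  # assumed constructor signature
    data_validation = DataValidation(ingestion_artifact, validation_config)
    validation_artifact = data_validation.initiate_data_validation()
    logging.info(f"Data validation artifact: {validation_artifact}")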