This project uses an Artificial Neural Network (ANN) to predict customer churn: given a customer's demographic and account information, the model predicts whether that customer is likely to leave the bank.
The project is organized into the following directories:
ANN/
├── .github/
│ └── workflows/
│ └── ci.yml # GitHub Actions CI pipeline
├── data/ # Data files
│ └── Churn_Modelling.csv # Training dataset
├── images/ # Images and diagrams
│ └── pipeline_diagram_detailed.png
├── models/ # Trained models and preprocessors
│ ├── model.h5 # Trained ANN model
│ ├── scaler.pkl # StandardScaler for feature scaling
│ ├── label_encoder_gender.pkl # Label encoder for Gender
│ └── onehot_encoder_geo.pkl # One-hot encoder for Geography
├── notebooks/ # Jupyter notebooks
│ ├── ANN.ipynb # Main notebook for building and training the ANN
│ └── prediction.ipynb # Notebook for making predictions
├── src/ # Source code
│ └── app.py # Streamlit web application
├── .gitignore # Git ignore rules
├── requirements.txt # Python dependencies
└── README.md # This file
- Python 3.8 or higher
- pip (Python package manager)
Install all required packages using pip:
```
pip install -r requirements.txt
```

Note: If you encounter `ModuleNotFoundError: No module named 'sklearn'`, make sure scikit-learn is installed:

```
pip install scikit-learn
```

Ensure all model files exist in the `models/` directory:

- `model.h5`
- `scaler.pkl`
- `label_encoder_gender.pkl`
- `onehot_encoder_geo.pkl`
If these files are missing, you'll need to train the model first (see Model Training section).
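A quick way to verify is a short script. The `missing_model_files` helper below is hypothetical (not part of the project); only the file names are taken from this README:

```python
from pathlib import Path

# Required artifacts, as listed in this README
REQUIRED = ['model.h5', 'scaler.pkl', 'label_encoder_gender.pkl', 'onehot_encoder_geo.pkl']

def missing_model_files(model_dir='models'):
    """Return the required model files that are not present in model_dir."""
    return [f for f in REQUIRED if not (Path(model_dir) / f).exists()]
```

Run `missing_model_files()` from the project root; an empty list means all artifacts are in place.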
- Navigate to the project root directory:

  ```
  cd ANN
  ```

- Run the Streamlit app:

  ```
  streamlit run src/app.py
  ```

- Access the application:
  - The app will automatically open in your default web browser
  - If it doesn't, navigate to `http://localhost:8501` in your browser

- Use the application:
  - Fill in the customer information using the form controls
  - Click the "Predict Churn" button to see the prediction
  - The prediction shows both the probability and a binary classification
- Start Jupyter Notebook:

  ```
  jupyter notebook
  ```

- Open the desired notebook:
  - `notebooks/ANN.ipynb` - for training the model
  - `notebooks/prediction.ipynb` - for making predictions programmatically
To retrain the model from scratch:
- Open `notebooks/ANN.ipynb` in Jupyter Notebook
- Run all cells sequentially
- Ensure the generated files are saved to the `models/` directory:
  - `model.h5` - the trained model
  - `scaler.pkl` - the fitted StandardScaler
  - `label_encoder_gender.pkl` - the fitted LabelEncoder for Gender
  - `onehot_encoder_geo.pkl` - the fitted OneHotEncoder for Geography
Important: After training, move all generated .pkl and .h5 files to the models/ directory.
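As a sketch of how the notebook's fitted preprocessors can be persisted, assuming pickle is used (the `scaler` here is a stand-in fitted on toy data; the Keras model itself is saved separately with `model.save`):

```python
import pickle
from pathlib import Path

from sklearn.preprocessing import StandardScaler

# Stand-in for the scaler fitted in notebooks/ANN.ipynb (toy data, two features)
scaler = StandardScaler().fit([[600, 35], [700, 45], [800, 55]])

# Save it where the app expects it
Path('models').mkdir(exist_ok=True)
with open('models/scaler.pkl', 'wb') as f:
    pickle.dump(scaler, f)

# Reload and confirm the fitted statistics survived the round trip
with open('models/scaler.pkl', 'rb') as f:
    reloaded = pickle.load(f)

# The Keras model is saved separately in the notebook:
# model.save('models/model.h5')
```

The other encoders (`label_encoder_gender.pkl`, `onehot_encoder_geo.pkl`) are pickled the same way.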
For real-time predictions, your input data should follow this CSV format:
- `CreditScore` (integer): Credit score of the customer (typically 350-850)
- `Geography` (string): Country of residence; must be one of `France`, `Germany`, or `Spain`
- `Gender` (string): Gender; must be either `Male` or `Female`
- `Age` (integer): Age of the customer (typically 18-92)
- `Tenure` (integer): Number of years the customer has been with the bank (0-10)
- `Balance` (float): Account balance
- `NumOfProducts` (integer): Number of bank products the customer uses (1-4)
- `HasCrCard` (integer): Whether the customer has a credit card (0 or 1)
- `IsActiveMember` (integer): Whether the customer is an active member (0 or 1)
- `EstimatedSalary` (float): Estimated salary of the customer
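If you assemble this CSV by hand, the constraints above are easy to get wrong. The `validate_row` helper below is hypothetical (not part of the project) and checks a row dictionary against them:

```python
# Hypothetical validator for one row of the CSV format described above
VALID_GEOGRAPHY = {'France', 'Germany', 'Spain'}
VALID_GENDER = {'Male', 'Female'}

def validate_row(row):
    """Return a list of problems with a row dict; an empty list means valid."""
    problems = []
    if row['Geography'] not in VALID_GEOGRAPHY:
        problems.append('Geography must be France, Germany, or Spain')
    if row['Gender'] not in VALID_GENDER:
        problems.append('Gender must be Male or Female')
    if not 0 <= row['Tenure'] <= 10:
        problems.append('Tenure must be 0-10')
    if not 1 <= row['NumOfProducts'] <= 4:
        problems.append('NumOfProducts must be 1-4')
    if row['HasCrCard'] not in (0, 1):
        problems.append('HasCrCard must be 0 or 1')
    if row['IsActiveMember'] not in (0, 1):
        problems.append('IsActiveMember must be 0 or 1')
    return problems

row = {'CreditScore': 650, 'Geography': 'France', 'Gender': 'Male', 'Age': 35,
       'Tenure': 5, 'Balance': 125000.50, 'NumOfProducts': 2, 'HasCrCard': 1,
       'IsActiveMember': 1, 'EstimatedSalary': 75000.00}
print(validate_row(row))  # -> [] (the row is valid)
```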
```
CreditScore,Geography,Gender,Age,Tenure,Balance,NumOfProducts,HasCrCard,IsActiveMember,EstimatedSalary
650,France,Male,35,5,125000.50,2,1,1,75000.00
720,Germany,Female,42,3,85000.25,1,1,0,95000.00
580,Spain,Male,28,7,50000.00,3,0,1,55000.00
```

You can modify `notebooks/prediction.ipynb` to read from a CSV file and make batch predictions:
```python
import pandas as pd

# Load your real-time data
realtime_data = pd.read_csv('your_data.csv')

# Process and predict for each row
```

- Launch the app using `streamlit run src/app.py`
- Fill in the customer information:
- Geography: Select from dropdown (France, Germany, Spain)
- Gender: Select from dropdown (Male, Female)
- Age: Use slider (18-92)
- Balance: Enter account balance
- Credit Score: Enter credit score
- Estimated Salary: Enter estimated salary
- Tenure: Use slider (0-10 years)
- Number of Products: Use slider (1-4)
- Has Credit Card: Select 0 (No) or 1 (Yes)
- Is Active Member: Select 0 (No) or 1 (Yes)
- View the prediction:
- Churn Probability: A value between 0 and 1
- Prediction: Binary classification (likely to churn or not)
See notebooks/prediction.ipynb for examples of making predictions programmatically.
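In outline, a programmatic prediction mirrors the app's preprocessing: label-encode `Gender`, one-hot encode `Geography`, scale the assembled feature vector, and call the model. The sketch below fits fresh encoders on toy data so it is self-contained; the notebook instead loads the fitted objects from `models/`, and the feature order shown is an assumption that must match the order used during training:

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

# Stand-ins for the fitted objects in models/ (the notebook loads the pickled ones)
gender_enc = LabelEncoder().fit(['Female', 'Male'])
geo_enc = OneHotEncoder().fit([['France'], ['Germany'], ['Spain']])

# One input row in the documented CSV format
row = {'CreditScore': 650, 'Geography': 'France', 'Gender': 'Male', 'Age': 35,
       'Tenure': 5, 'Balance': 125000.50, 'NumOfProducts': 2, 'HasCrCard': 1,
       'IsActiveMember': 1, 'EstimatedSalary': 75000.00}

gender = gender_enc.transform([row['Gender']])[0]           # 'Male' -> 1
geo = geo_enc.transform([[row['Geography']]]).toarray()[0]  # 'France' -> [1, 0, 0]

# Assemble the feature vector (this order is an assumption; it must match
# the column order used when the model was trained)
numeric = [row['CreditScore'], gender, row['Age'], row['Tenure'], row['Balance'],
           row['NumOfProducts'], row['HasCrCard'], row['IsActiveMember'],
           row['EstimatedSalary']]
features = np.concatenate([numeric, geo]).reshape(1, -1)

# With the fitted scaler and model loaded from models/, prediction would be:
# prob = model.predict(scaler.transform(features))[0][0]
```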
Solution: Install scikit-learn:
```
pip install scikit-learn
```

Solution: Ensure all model files exist in the `models/` directory. If missing, train the model using `notebooks/ANN.ipynb`.
Solution: Make sure all dependencies are installed:
```
pip install -r requirements.txt
```

- tensorflow (>=2.20.0): Deep learning framework
- pandas (>=2.0.0): Data manipulation
- numpy (>=1.24.0): Numerical computing
- scikit-learn (>=1.3.0): Machine learning utilities
- streamlit (>=1.40.0): Web application framework
- matplotlib (>=3.7.0): Plotting library
- tensorboard (>=2.20.0): TensorFlow visualization
This project includes a simple GitHub Actions CI pipeline that automatically:
- Verifies Python syntax - Checks that the code is syntactically correct
- Verifies dependencies - Ensures all required packages can be imported
- Verifies app import - Tests that the application can be imported
- Checks model files - Verifies all required model files exist
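The checks above can also be approximated locally. The `run_checks` helper below is hypothetical (not part of the repository); the file and package names are taken from this README:

```python
import importlib.util
import py_compile
from pathlib import Path

def run_checks(app_path='src/app.py',
               packages=('tensorflow', 'pandas', 'numpy', 'sklearn', 'streamlit'),
               model_dir='models'):
    """Approximate the CI pipeline's checks locally."""
    results = {}
    # Syntax check: compile the app without executing it
    try:
        py_compile.compile(app_path, doraise=True)
        results['syntax'] = True
    except (py_compile.PyCompileError, FileNotFoundError):
        results['syntax'] = False
    # Dependency check: can each required package be located?
    results['deps'] = {p: importlib.util.find_spec(p) is not None for p in packages}
    # Model file check
    required = ['model.h5', 'scaler.pkl', 'label_encoder_gender.pkl', 'onehot_encoder_geo.pkl']
    results['models'] = {f: (Path(model_dir) / f).exists() for f in required}
    return results
```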
The CI pipeline runs automatically on:
- Push to the `main` branch
- Pull requests to the `main` branch
View the workflow file at `.github/workflows/ci.yml`.
The pipeline is lightweight and fast, perfect for a public repository.
This project is for educational purposes.
