This project demonstrates the use of Artificial Neural Networks (ANNs) for both classification and regression tasks using TensorFlow/Keras. The primary focus is on customer churn prediction with an interactive Streamlit web application.
This is a comprehensive deep learning project that includes:
- Customer Churn Classification: Predict whether a customer will leave a bank
- Salary Regression: Predict estimated salary based on customer features
- Hyperparameter Tuning: Optimize ANN architecture for best performance
- Interactive Web Application: Streamlit-based UI for real-time predictions

    Section-53-annclassification/annclassification/
    ├── app.py                          # Streamlit web application for churn prediction
    ├── Churn_Modelling.csv             # Main dataset
    ├── requirements.txt                # Python dependencies
    ├── model.h5                        # Trained ANN model for churn classification
    ├── regression_model.h5             # Trained ANN model for salary regression
    ├── label_encoder_gender.pkl        # Saved encoder for gender feature
    ├── onehot_encoder_geo.pkl          # Saved encoder for geography feature
    ├── scaler.pkl                      # Saved StandardScaler for feature normalization
    ├── experiments.ipynb               # Data exploration and preprocessing notebook
    ├── hyperparametertuningann.ipynb   # Hyperparameter tuning and model optimization
    ├── prediction.ipynb                # Churn prediction examples
    ├── salaryregression.ipynb          # Salary prediction model development
    ├── logs/                           # TensorBoard logs for churn model training
    │   └── fit/20260203-094807/
    │       ├── train/
    │       └── validation/
    └── regressionlogs/                 # TensorBoard logs for regression model training
        └── fit/20260204-221255/
            ├── train/
            └── validation/

File: Churn_Modelling.csv
The dataset contains customer banking information with the following features:
- CreditScore: Customer's credit score
- Geography: Customer's country (France, Germany, Spain)
- Gender: Customer's gender (Male/Female)
- Age: Customer's age in years
- Tenure: Years of customer relationship with the bank
- Balance: Account balance
- NumOfProducts: Number of products the customer uses
- HasCrCard: Whether customer has a credit card (0/1)
- IsActiveMember: Whether customer is active (0/1)
- EstimatedSalary: Estimated annual salary
- Exited: Target variable - whether customer churned (0/1)
- Install required dependencies:

      pip install -r requirements.txt

- Launch the Streamlit app for real-time churn predictions:

      streamlit run app.py

  The app will open in your browser at http://localhost:8501 with an interactive interface to predict customer churn.
Initial data exploration and preprocessing notebook.
- Loads the Churn_Modelling.csv dataset
- Data cleaning (removes non-feature columns: RowNumber, CustomerId, Surname)
- Label encoding for categorical gender variable
- One-hot encoding for geography feature
- Basic data analysis and visualization
Focuses on finding optimal ANN architecture and hyperparameters.
Key Steps:
- Data loading and preprocessing (same as experiments.ipynb)
- Train-test split (80-20 with random_state=42)
- Feature scaling using StandardScaler
- ANN model creation with configurable hidden layers
- Grid search/Random search for optimal hyperparameters
- Model training with early stopping
- Model evaluation and performance metrics
- Saves trained model as model.h5
Architecture Guidelines:
- Start with 1-2 hidden layers
- Hidden neurons: between input and output layer size
- Uses EarlyStopping callback to prevent overfitting
- Saves encoders and scaler for deployment
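
A model with configurable hidden layers and an EarlyStopping callback might look like the sketch below. The function name `build_classifier`, the `hidden_units` values, the `patience` setting, and the placeholder feature count are all assumptions for illustration; the notebook's actual search space and settings are not shown in this README.

```python
import tensorflow as tf

def build_classifier(n_features, hidden_units=(8, 4)):
    """Build a binary classifier with one Dense layer per entry in hidden_units."""
    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=(n_features,)))
    for units in hidden_units:
        model.add(tf.keras.layers.Dense(units, activation="relu"))
    # Single sigmoid unit outputs the churn probability.
    model.add(tf.keras.layers.Dense(1, activation="sigmoid"))
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return model

# Stop training once validation loss stops improving (patience is an assumed value).
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)

# Placeholder width; in practice use X_train.shape[1].
n_features = 10
model = build_classifier(n_features, hidden_units=(8, 4))

# Training call, assuming X_train, y_train, X_test, y_test already exist:
# model.fit(X_train, y_train, validation_data=(X_test, y_test),
#           epochs=100, callbacks=[early_stopping])
```

Passing `hidden_units` as a tuple makes it straightforward to try one or two hidden layers of different sizes from a search loop.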
Demonstrates how to use the trained churn classification model.
Process:
- Loads pre-trained model and preprocessors (model.h5, scaler, encoders)
- Creates sample input data with customer features
- Encodes categorical variables (Gender, Geography)
- Scales features using saved StandardScaler
- Makes churn predictions with probability scores
- Interprets predictions (> 0.5 = likely to churn, otherwise unlikely)
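
The final interpretation step can be expressed as a small helper. The function name `interpret_churn` and its return strings are hypothetical; only the 0.5 threshold comes from the description above.

```python
def interpret_churn(probability, threshold=0.5):
    """Turn a churn probability into a readable label using a fixed threshold."""
    if not 0.0 <= probability <= 1.0:
        raise ValueError("probability must be between 0 and 1")
    # Strictly greater than the threshold counts as likely churn.
    if probability > threshold:
        return "likely to churn"
    return "unlikely to churn"

label = interpret_churn(0.73)
```

With this convention, a probability of exactly 0.5 is treated as unlikely to churn.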
Develops an ANN model for salary prediction (regression task).
Process:
- Loads and preprocesses Churn_Modelling.csv data
- Drops irrelevant columns and encodes categorical features
- Splits data into training and testing sets
- Scales features using StandardScaler
- Creates ANN regression model architecture
- Trains model with appropriate loss function (MSE for regression)
- Evaluates performance using regression metrics
- Saves trained regression model as regression_model.h5
The interactive web application provides a user-friendly interface for churn prediction.
Features:
- Input Controls:
  - Geography dropdown (from training data)
  - Gender dropdown (Male/Female)
  - Age slider (18-92 years)
  - Credit Score input field
  - Balance input field
  - Estimated Salary input field
  - Tenure slider (0-10 years)
  - Number of Products slider (1-4)
  - Credit Card status (Yes/No)
  - Active Member status (Yes/No)
- Prediction Output:
  - Displays churn probability (0-1 scale)
  - Provides interpretation: "likely to churn" if probability > 0.5
- Clear, intuitive user experience
| Package | Version | Purpose |
|---|---|---|
| TensorFlow | 2.15.0 | Deep learning framework |
| Keras | Latest | Neural network API (included in TensorFlow) |
| Pandas | Latest | Data manipulation and analysis |
| NumPy | Latest | Numerical computations |
| scikit-learn | Latest | Preprocessing and model evaluation |
| Streamlit | Latest | Web application framework |
| TensorBoard | Latest | Model training visualization |
| SciKeras | Latest | Scikit-learn wrapper for Keras |
| Matplotlib | Latest | Data visualization |
| File | Purpose | Format |
|---|---|---|
| model.h5 | Trained ANN classifier for churn prediction | Binary (HDF5) |
| regression_model.h5 | Trained ANN regressor for salary prediction | Binary (HDF5) |
| label_encoder_gender.pkl | LabelEncoder for gender (Female → 0, Male → 1) | Pickle |
| onehot_encoder_geo.pkl | OneHotEncoder for geography (3 countries → 3 features) | Pickle |
| scaler.pkl | StandardScaler for feature normalization | Pickle |
Churn Model:
- Input Features: 12 after preprocessing (9 numeric and binary columns plus 3 one-hot Geography columns)
- Output: Binary (churned or not)
- Activation: ReLU for hidden layers, Sigmoid for output
- Loss Function: Binary Crossentropy
- Optimizer: Adam
- Metrics: Accuracy, Precision, Recall, AUC
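
To make the sigmoid output and binary cross-entropy loss concrete, here is a small NumPy sketch of both. Keras computes these internally; the function names and the `eps` clipping value here are illustrative assumptions, not the library's implementation.

```python
import numpy as np

def sigmoid(z):
    """Squash a raw score into a probability between 0 and 1."""
    return 1.0 / (1.0 + np.exp(-np.asarray(z, dtype=float)))

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip to avoid log(0) for predictions of exactly 0 or 1.
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1.0 - eps)
    losses = -(y_true * np.log(y_prob) + (1.0 - y_true) * np.log(1.0 - y_prob))
    return float(np.mean(losses))

# A confident correct prediction is penalized less than an uncertain one.
confident = binary_crossentropy([1, 0], [0.9, 0.1])
uncertain = binary_crossentropy([1, 0], [0.5, 0.5])
```

The loss grows quickly as a predicted probability moves toward the wrong label, which is why it suits a 0/1 target such as Exited.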
Regression Model:
- Input Features: 11 (after preprocessing)
- Output: Continuous value (salary)
- Activation: ReLU for hidden layers, Linear for output
- Loss Function: Mean Squared Error (MSE)
- Optimizer: Adam
- Metrics: MAE, RMSE, R-squared
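
The three regression metrics listed above can be computed from predictions as in this NumPy sketch. The function name `regression_metrics` and the sample salary values are made up for illustration.

```python
import numpy as np

def regression_metrics(y_true, y_pred):
    """Return MAE, RMSE, and R-squared for a set of predictions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    errors = y_true - y_pred
    mae = np.mean(np.abs(errors))
    rmse = np.sqrt(np.mean(errors ** 2))
    # R-squared compares the model's squared error with that of predicting the mean.
    ss_res = np.sum(errors ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot
    return {"mae": float(mae), "rmse": float(rmse), "r2": float(r2)}

# Illustrative salaries only; not taken from the dataset.
metrics = regression_metrics([50000, 60000, 70000], [50000, 60000, 73000])
```

MAE and RMSE are in the same units as the salary, while R-squared is unitless, so the three are usually read together.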
Two sets of training logs are available for visualization:
Churn Model (logs/fit/20260203-094807/):
- Training and validation metrics from model optimization
Regression Model (regressionlogs/fit/20260204-221255/):
- Training and validation metrics for salary prediction
View logs using:

    tensorboard --logdir=logs/fit
    # or
    tensorboard --logdir=regressionlogs/fit

- Data Cleaning: Remove non-feature columns (RowNumber, CustomerId, Surname)
- Categorical Encoding:
  - Gender: LabelEncoder (Male=1, Female=0)
  - Geography: OneHotEncoder (creates 3 binary features for 3 countries)
- Feature Scaling: StandardScaler (mean=0, std=1) applied to all numerical features
- Train-Test Split: 80% training, 20% testing with fixed random seed
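
The split and scaling steps might be sketched as follows. The random placeholder data stands in for the encoded feature matrix, and fitting the scaler on the training portion only is a common practice assumed here rather than something this README states.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Placeholder data standing in for the encoded features and the target.
rng = np.random.default_rng(0)
X = rng.normal(loc=50.0, scale=10.0, size=(100, 4))
y = rng.integers(0, 2, size=100)

# 80/20 split with a fixed seed so the same rows land in each set on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Fit the scaler on training data only, then apply the same transform to test data.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
```

The fitted `scaler` is the object that would be pickled so that new inputs are scaled with the training statistics.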

    import pickle
    import tensorflow as tf

    # Save and reload a Keras model
    model.save('model_name.h5')
    model = tf.keras.models.load_model('model_name.h5')

    # Save and reload an encoder or scaler with pickle
    with open('encoder_name.pkl', 'wb') as file:
        pickle.dump(encoder, file)
    with open('encoder_name.pkl', 'rb') as file:
        encoder = pickle.load(file)

- Customer Retention: Identify at-risk customers and implement retention strategies
- Resource Planning: Allocate customer service resources based on churn risk
- Salary Negotiation: Estimate salary ranges for compensation planning
- Risk Assessment: Evaluate customer lifetime value and engagement
- All models use consistent preprocessing (same encoders and scaler)
- A fixed random_state (42) makes the train-test split reproducible
- Early stopping prevents overfitting during training
- Predictions require properly encoded and scaled inputs
- The Streamlit app handles all preprocessing automatically
- Implement cross-validation for more robust evaluation
- Experiment with different architectures (deeper networks, dropout layers)
- Add feature importance analysis
- Implement SHAP values for model interpretability
- Add batch prediction capability to the web app
- Create visualization dashboards in Streamlit
This project is part of a deep learning course.
Last Updated: February 5, 2026
Project Type: Deep Learning Classification & Regression
Framework: TensorFlow/Keras
Interface: Streamlit Web App