diff --git a/README.md b/README.md index c1e8359..1232af4 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,181 @@ -# Project 1 -Put your README here. Answer the following questions. -* What does the model you have implemented do and when should it be used? -* How did you test your model to determine if it is working reasonably correctly? -* What parameters have you exposed to users of your implementation in order to tune performance? (Also perhaps provide some basic usage examples.) -* Are there specific inputs that your implementation has trouble with? Given more time, could you work around these or is it fundamental? +# Project 1: ElasticNet + +## Overview + +This project implements an **ElasticNet** regression model from scratch, using **NumPy** for all the numerical calculations. The ElasticNet model is a type of linear regression that uses both **L1 (Lasso)** and **L2 (Ridge)** penalties. It’s especially useful when dealing with datasets that have lots of features, some of which may be irrelevant or highly correlated with others. + +Unlike prebuilt libraries like Scikit-Learn, this implementation relies on manually coded gradient descent to optimize the model’s weights. + +The `ElasticNetModel` class includes two main methods: +- `fit(X, y)`: This trains the model using the dataset. +- `predict(X)`: This makes predictions based on the trained model. + +## Requirements + +This project meets the following requirements: + +1. **Algorithm Implementation**: The ElasticNet regression is implemented from scratch, combining both L1 and L2 regularization penalties with gradient descent. +2. **From First Principles**: The model uses **NumPy** for matrix calculations, with no prebuilt machine learning libraries (like Scikit-Learn or Statsmodels). +3. **Testing the Model**: We tested the model using a custom script that runs it on a dataset and evaluates performance using metrics like Mean Squared Error (MSE), Mean Absolute Error (MAE), and R-squared (R²). 
We also generated visualizations to assess the model's performance. +4. **Flexible Input**: The model can handle any numerical dataset with proper preprocessing and normalization. The test script works with a provided `test.csv` dataset generated by a separate script. + + +## What does the model you have implemented do and when should it be used? + +The **ElasticNet** model we’ve implemented is a type of linear regression that combines two types of regularization: **L1 (Lasso)** and **L2 (Ridge)**. This makes it useful when you have a dataset with many features, especially when some of those features might not be important or are highly correlated with each other. + +ElasticNet helps prevent overfitting by penalizing large coefficients and can also automatically select important features by driving some coefficients to zero. It’s a good choice when you’re dealing with high-dimensional data or when you want a balance between selecting features and generalizing well. + + +## How did you test your model to determine if it is working reasonably correctly? + +We tested the model by writing a script that loads a dataset, trains the ElasticNet model on it, and then evaluates how well it performs using standard metrics like **Mean Squared Error (MSE)**, **Mean Absolute Error (MAE)**, and **R-squared (R²)**. These metrics give a good sense of how accurate the model’s predictions are compared to the actual values. + +We also generated visualizations, including a plot of actual vs. predicted values, a residuals plot, and a bar plot showing the importance of each feature (based on the learned coefficients). This helped visually confirm that the model is working as expected. + + + +## What parameters have you exposed to users of your implementation in order to tune performance? + +We’ve exposed several parameters that allow users to tweak the model’s behavior: + +1. **l1_penalty**: Controls the strength of L1 regularization, which helps with feature selection. +2. 
**l2_penalty**: Controls the strength of L2 regularization, which helps prevent overfitting. +3. **learning_rate**: Adjusts the step size during the training process. A smaller value means slower but more precise updates. +4. **max_iterations**: Sets the maximum number of gradient descent updates performed during training. +5. **tolerance**: Stops training early once every component of the weight gradient is smaller in absolute value than this threshold. + +Here’s an example of how to use these parameters. Note that `fit()` reads `X.shape`, so the inputs must be NumPy arrays rather than plain lists: + +```python +import numpy as np +from elasticnet.models.ElasticNet import ElasticNetModel + +# Example usage of ElasticNet +model = ElasticNetModel(l1_penalty=0.5, l2_penalty=0.3, learning_rate=0.01, max_iterations=5000) +X = np.array([[1, 2], [2, 3], [3, 4], [4, 5]]) +y = np.array([3, 4, 5, 6]) + +model.fit(X, y) +predictions = model.predict(X) +print(predictions) +``` + + +## Are there specific inputs that your implementation has trouble with? Given more time, could you work around these or is it fundamental? + +The model works well with numerical data, but there are a few things it doesn’t handle well: + +1. **Categorical data**: If you have non-numeric data (like "male" or "female"), you’ll need to convert these into numbers before passing them into the model. Right now, this needs to be done manually. + + With more time, we could add a feature to automatically handle categorical data. + +2. **Missing data**: The model expects all input values to be valid numbers. If there are missing or `NaN` values, they need to be filled in before training the model. + + This could be improved by adding automatic handling for missing data (e.g., filling with mean values). + + +## Model Description + +### When Should You Use ElasticNet? + +ElasticNet is particularly useful in these scenarios: +- If you want to encourage **sparsity** in your model (meaning some of the less important features will be ignored), ElasticNet’s L1 regularization helps with that. 
+- If you’re working with a **large number of features**, or some features are highly correlated, ElasticNet's combination of L1 and L2 regularization helps handle this better than basic linear regression or Lasso alone. +- If you're worried about **overfitting**, the L2 penalty helps keep the model generalized by preventing large coefficients. + + +## How to Run the Model + +### 1. Install Dependencies + +First, make sure you have **NumPy** and **Matplotlib** installed. The test script also imports **scikit-learn** for its metrics, and `generate_test_CSV.py` imports **pandas** and **pydataset**. You can install them using `pip`: + +pip install numpy matplotlib scikit-learn pandas pydataset + + +### 2. Set Up the Environment + +Before running any scripts, make sure your Python environment is set up to find the project's modules. You can do this by setting the `PYTHONPATH` environment variable to your current directory: + +export PYTHONPATH=$PWD + +If you are using the Windows Command Prompt, use: + +set PYTHONPATH=%cd% + +This tells Python where to look for the project's files. + +### 3. Generate the Dataset + +You’ll need to generate the dataset before running the model. Use the `generate_test_CSV.py` script to create the `test.csv` file: + +python generate_test_CSV.py + +This will generate a dataset that the model can use for training and testing. + +### 4. Run the Test Script + +Once the dataset has been generated, you can run the ElasticNet model by using the test script: + +python elasticnet/tests/test_ElasticNetModel.py + +This will: +- Load the `test.csv` dataset. +- Preprocess the data (if necessary, convert categorical features to numerical form). +- Train the model using `fit()`. +- Predict and evaluate the results using `predict()`. +- Generate visualizations like "Actual vs Predicted" and "Residuals." + +![alt text]() + + +## Evaluation Metrics + +We used the following metrics to evaluate how well the model performed: + +- **Mean Squared Error (MSE)**: The average of the squared differences between the predicted and actual values, so large errors are penalized more heavily. +- **Mean Absolute Error (MAE)**: Similar to MSE but based on the absolute differences, which makes it less sensitive to outliers. 
+- **R-squared (R²)**: This tells you how well the model fits the data. The closer to 1, the better. + +Here’s an example of the output: + +Mean Squared Error (MSE): 1.85 +Mean Absolute Error (MAE): 1.2 +R-squared (R²): 0.75 + +![alt text]() + + +## Visualizations + +The test script will generate several plots to help visualize the model’s performance: + +1. **Actual vs Predicted**: A scatter plot that compares the actual target values to the predicted values. +2. **Residuals**: A plot showing the differences between the actual and predicted values. +3. **Distribution of Target Values**: A histogram that shows the spread of the target variable. +4. **Feature Weights**: A bar chart showing the learned importance of each feature in the model. + +![alt text]() + + +## Known Limitations + +The current implementation handles numerical datasets well, but there are a few things to keep in mind: +- **Categorical Data**: If your dataset has non-numeric columns (like 'male' and 'female'), you’ll need to convert them to numbers before using the model. +- **Missing Values**: If your dataset has any missing values (`NaN`), you should handle those first by either filling them in or dropping the rows. + + +## Future Work + +With more time, the following improvements could be made: +- Automating the preprocessing of **categorical data** so that users don’t need to manually encode it. +- Adding better handling for **missing data** by automatically filling or dropping missing values. +- Implementing **cross-validation** to automatically tune the regularization parameters (l1_penalty, l2_penalty) for better performance. 
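The categorical-data and missing-data items listed under Future Work could be prototyped with NumPy alone. The sketch below is only an illustration under stated assumptions: the helper names `encode_binary_column` and `fill_missing_with_mean` are hypothetical and not part of this project, the 0/1 encoding follows the convention the test script uses for `sex`, and mean imputation is one of several possible strategies.

```python
import numpy as np


def encode_binary_column(values, positive_label):
    """Map string labels to 1.0 where they equal positive_label, else 0.0.

    Hypothetical helper: not part of the ElasticNetModel class.
    """
    return np.array([1.0 if v == positive_label else 0.0 for v in values])


def fill_missing_with_mean(X):
    """Return a float copy of X with each NaN replaced by its column mean.

    Assumes every column has at least one non-NaN value; a column that is
    entirely NaN would stay NaN.
    """
    X = np.array(X, dtype=float)            # np.array copies, so the input is untouched
    col_means = np.nanmean(X, axis=0)       # per-column mean, ignoring NaN
    rows, cols = np.where(np.isnan(X))      # positions of the missing entries
    X[rows, cols] = col_means[cols]
    return X


# Tiny illustrative dataset: one categorical column and one numeric column with a gap.
sex = ["male", "female", "female"]
age = [37.0, np.nan, 32.0]

X = np.column_stack([encode_binary_column(sex, "female"), age])
X_clean = fill_missing_with_mean(X)
print(X_clean)
```

The cleaned array can then be passed to `fit()` like any other numeric matrix. Whether mean imputation is appropriate depends on the dataset, so treat it as a starting point rather than a recommendation.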
+ + +### Group Members +Krishna Manideep Malladi (A20550891) +Udaya Sree Vankdavath (A20552992) +Manvitha Byrineni (A20550783) diff --git a/Screenshot 2024-10-10 194317.png b/Screenshot 2024-10-10 194317.png new file mode 100644 index 0000000..55e83d3 Binary files /dev/null and b/Screenshot 2024-10-10 194317.png differ diff --git a/Screenshot 2024-10-10 194334.png b/Screenshot 2024-10-10 194334.png new file mode 100644 index 0000000..9318e40 Binary files /dev/null and b/Screenshot 2024-10-10 194334.png differ diff --git a/Screenshot 2024-10-10 194352.png b/Screenshot 2024-10-10 194352.png new file mode 100644 index 0000000..f38bfd6 Binary files /dev/null and b/Screenshot 2024-10-10 194352.png differ diff --git a/elasticnet/models/ElasticNet.py b/elasticnet/models/ElasticNet.py index 017e925..8085351 100644 --- a/elasticnet/models/ElasticNet.py +++ b/elasticnet/models/ElasticNet.py @@ -1,17 +1,43 @@ +import numpy as np +class ElasticNetModel: + def __init__(self, l1_penalty=1.0, l2_penalty=1.0, learning_rate=0.001, max_iterations=1000, tolerance=1e-5): + self.l1_penalty = l1_penalty + self.l2_penalty = l2_penalty + self.learning_rate = learning_rate + self.max_iterations = max_iterations + self.tolerance = tolerance + self.weights = None + self.bias = None -class ElasticNetModel(): - def __init__(self): - pass + def fit(self, X, y): + num_samples, num_features = X.shape + self.weights = np.zeros(num_features) + self.bias = 0 + for i in range(self.max_iterations): + predictions = self.predict(X) + errors = predictions - y - def fit(self, X, y): - return ElasticNetModelResults() + gradient_weights = (X.T @ errors) / num_samples + self.l1_penalty * np.sign(self.weights) + self.l2_penalty * self.weights + gradient_bias = np.mean(errors) + + self.weights -= self.learning_rate * gradient_weights + self.bias -= self.learning_rate * gradient_bias + + if np.all(np.abs(gradient_weights) < self.tolerance): + break + + return ElasticNetModelResults(self.weights, self.bias) + + 
def predict(self, X): + return X @ self.weights + self.bias -class ElasticNetModelResults(): - def __init__(self): - pass +class ElasticNetModelResults: + def __init__(self, weights, bias): + self.weights = weights + self.bias = bias - def predict(self, x): - return 0.5 + def predict(self, X): + return X @ self.weights + self.bias diff --git a/elasticnet/tests/test.csv b/elasticnet/tests/test.csv new file mode 100644 index 0000000..d77959a --- /dev/null +++ b/elasticnet/tests/test.csv @@ -0,0 +1,602 @@ +sex,age,ym,child,religious,education,occupation,rate,nbaffairs +male,37.0,10.0,no,3,18,7,4,0 +female,27.0,4.0,no,4,14,6,4,0 +female,32.0,15.0,yes,1,12,1,4,0 +male,57.0,15.0,yes,5,18,6,5,0 +male,22.0,0.75,no,2,17,6,3,0 +female,32.0,1.5,no,2,17,5,5,0 +female,22.0,0.75,no,2,12,1,3,0 +male,57.0,15.0,yes,2,14,4,4,0 +female,32.0,15.0,yes,4,16,1,2,0 +male,22.0,1.5,no,4,14,4,5,0 +male,37.0,15.0,yes,2,20,7,2,0 +male,27.0,4.0,yes,4,18,6,4,0 +male,47.0,15.0,yes,5,17,6,4,0 +female,22.0,1.5,no,2,17,5,4,0 +female,27.0,4.0,no,4,14,5,4,0 +female,37.0,15.0,yes,1,17,5,5,0 +female,37.0,15.0,yes,2,18,4,3,0 +female,22.0,0.75,no,3,16,5,4,0 +female,22.0,1.5,no,2,16,5,5,0 +female,27.0,10.0,yes,2,14,1,5,0 +female,22.0,1.5,no,2,16,5,5,0 +female,22.0,1.5,no,2,16,5,5,0 +female,27.0,10.0,yes,4,16,5,4,0 +female,32.0,10.0,yes,3,14,1,5,0 +male,37.0,4.0,yes,2,20,6,4,0 +female,22.0,1.5,no,2,18,5,5,0 +female,27.0,7.0,no,4,16,1,5,0 +male,42.0,15.0,yes,5,20,6,4,0 +male,27.0,4.0,yes,3,16,5,5,0 +female,27.0,4.0,yes,3,17,5,4,0 +male,42.0,15.0,yes,4,20,6,3,0 +female,22.0,1.5,no,3,16,5,5,0 +male,27.0,0.417,no,4,17,6,4,0 +female,42.0,15.0,yes,5,14,5,4,0 +male,32.0,4.0,yes,1,18,6,4,0 +female,22.0,1.5,no,4,16,5,3,0 +female,42.0,15.0,yes,3,12,1,4,0 +female,22.0,4.0,no,4,17,5,5,0 +male,22.0,1.5,yes,1,14,3,5,0 +female,22.0,0.75,no,3,16,1,5,0 +male,32.0,10.0,yes,5,20,6,5,0 +male,52.0,15.0,yes,5,18,6,3,0 +female,22.0,0.417,no,5,14,1,4,0 +female,27.0,4.0,yes,2,18,6,1,0 +female,32.0,7.0,yes,5,17,5,3,0 
+male,22.0,4.0,no,3,16,5,5,0 +female,27.0,7.0,yes,4,18,6,5,0 +female,42.0,15.0,yes,2,18,5,4,0 +male,27.0,1.5,yes,4,16,3,5,0 +male,42.0,15.0,yes,2,20,6,4,0 +female,22.0,0.75,no,5,14,3,5,0 +male,32.0,7.0,yes,2,20,6,4,0 +male,27.0,4.0,yes,5,20,6,5,0 +male,27.0,10.0,yes,4,20,6,4,0 +male,22.0,4.0,no,1,18,5,5,0 +female,37.0,15.0,yes,4,14,3,1,0 +male,22.0,1.5,yes,5,16,4,4,0 +female,37.0,15.0,yes,4,17,1,5,0 +female,27.0,0.75,no,4,17,5,4,0 +male,32.0,10.0,yes,4,20,6,4,0 +female,47.0,15.0,yes,5,14,7,2,0 +male,37.0,10.0,yes,3,20,6,4,0 +female,22.0,0.75,no,2,16,5,5,0 +male,27.0,4.0,no,2,18,4,5,0 +male,32.0,7.0,no,4,20,6,4,0 +male,42.0,15.0,yes,2,17,3,5,0 +male,37.0,10.0,yes,4,20,6,4,0 +female,47.0,15.0,yes,3,17,6,5,0 +female,22.0,1.5,no,5,16,5,5,0 +female,27.0,1.5,no,2,16,6,4,0 +female,27.0,4.0,no,3,17,5,5,0 +female,32.0,10.0,yes,5,14,4,5,0 +female,22.0,0.125,no,2,12,5,5,0 +male,47.0,15.0,yes,4,14,4,3,0 +male,32.0,15.0,yes,1,14,5,5,0 +male,27.0,7.0,yes,4,16,5,5,0 +female,22.0,1.5,yes,3,16,5,5,0 +male,27.0,4.0,yes,3,17,6,5,0 +female,22.0,1.5,no,3,16,5,5,0 +male,57.0,15.0,yes,2,14,7,2,0 +male,17.5,1.5,yes,3,18,6,5,0 +male,57.0,15.0,yes,4,20,6,5,0 +female,22.0,0.75,no,2,16,3,4,0 +male,42.0,4.0,no,4,17,3,3,0 +female,22.0,1.5,yes,4,12,1,5,0 +female,22.0,0.417,no,1,17,6,4,0 +female,32.0,15.0,yes,4,17,5,5,0 +female,27.0,1.5,no,3,18,5,2,0 +female,22.0,1.5,yes,3,14,1,5,0 +female,37.0,15.0,yes,3,14,1,4,0 +female,32.0,15.0,yes,4,14,3,4,0 +male,37.0,10.0,yes,2,14,5,3,0 +male,37.0,10.0,yes,4,16,5,4,0 +male,57.0,15.0,yes,5,20,5,3,0 +male,27.0,0.417,no,1,16,3,4,0 +female,42.0,15.0,yes,5,14,1,5,0 +male,57.0,15.0,yes,3,16,6,1,0 +male,37.0,10.0,yes,1,16,6,4,0 +male,37.0,15.0,yes,3,17,5,5,0 +male,37.0,15.0,yes,4,20,6,5,0 +female,27.0,10.0,yes,5,14,1,5,0 +male,37.0,10.0,yes,2,18,6,4,0 +female,22.0,0.125,no,4,12,4,5,0 +male,57.0,15.0,yes,5,20,6,5,0 +female,37.0,15.0,yes,4,18,6,4,0 +male,22.0,4.0,yes,4,14,6,4,0 +male,27.0,7.0,yes,4,18,5,4,0 +male,57.0,15.0,yes,4,20,5,4,0 
+male,32.0,15.0,yes,3,14,6,3,0 +female,22.0,1.5,no,2,14,5,4,0 +female,32.0,7.0,yes,4,17,1,5,0 +female,37.0,15.0,yes,4,17,6,5,0 +female,32.0,1.5,no,5,18,5,5,0 +male,42.0,10.0,yes,5,20,7,4,0 +female,27.0,7.0,no,3,16,5,4,0 +male,37.0,15.0,no,4,20,6,5,0 +male,37.0,15.0,yes,4,14,3,2,0 +male,32.0,10.0,no,5,18,6,4,0 +female,22.0,0.75,no,4,16,1,5,0 +female,27.0,7.0,yes,4,12,2,4,0 +female,27.0,7.0,yes,2,16,2,5,0 +female,42.0,15.0,yes,5,18,5,4,0 +male,42.0,15.0,yes,4,17,5,3,0 +female,27.0,7.0,yes,2,16,1,2,0 +female,22.0,1.5,no,3,16,5,5,0 +male,37.0,15.0,yes,5,20,6,5,0 +female,22.0,0.125,no,2,14,4,5,0 +male,27.0,1.5,no,4,16,5,5,0 +male,32.0,1.5,no,2,18,6,5,0 +male,27.0,1.5,no,2,17,6,5,0 +female,27.0,10.0,yes,4,16,1,3,0 +male,42.0,15.0,yes,4,18,6,5,0 +female,27.0,1.5,no,2,16,6,5,0 +male,27.0,4.0,no,2,18,6,3,0 +female,32.0,10.0,yes,3,14,5,3,0 +female,32.0,15.0,yes,3,18,5,4,0 +female,22.0,0.75,no,2,18,6,5,0 +female,37.0,15.0,yes,2,16,1,4,0 +male,27.0,4.0,yes,4,20,5,5,0 +male,27.0,4.0,no,1,20,5,4,0 +female,27.0,10.0,yes,2,12,1,4,0 +female,32.0,15.0,yes,5,18,6,4,0 +male,27.0,7.0,yes,5,12,5,3,0 +male,52.0,15.0,yes,2,18,5,4,0 +male,27.0,4.0,no,3,20,6,3,0 +male,37.0,4.0,yes,1,18,5,4,0 +male,27.0,4.0,yes,4,14,5,4,0 +female,52.0,15.0,yes,5,12,1,3,0 +female,57.0,15.0,yes,4,16,6,4,0 +male,27.0,7.0,yes,1,16,5,4,0 +male,37.0,7.0,yes,4,20,6,3,0 +male,22.0,0.75,no,2,14,4,3,0 +male,32.0,4.0,yes,2,18,5,3,0 +male,37.0,15.0,yes,4,20,6,3,0 +male,22.0,0.75,yes,2,14,4,3,0 +male,42.0,15.0,yes,4,20,6,3,0 +female,52.0,15.0,yes,5,17,1,1,0 +female,37.0,15.0,yes,4,14,1,2,0 +male,27.0,7.0,yes,4,14,5,3,0 +male,32.0,4.0,yes,2,16,5,5,0 +female,27.0,4.0,yes,2,18,6,5,0 +female,27.0,4.0,yes,2,18,5,5,0 +male,37.0,15.0,yes,5,18,6,5,0 +female,47.0,15.0,yes,5,12,5,4,0 +female,32.0,10.0,yes,3,17,1,4,0 +female,27.0,1.5,yes,4,17,1,2,0 +female,57.0,15.0,yes,2,18,5,2,0 +female,22.0,1.5,no,4,14,5,4,0 +male,42.0,15.0,yes,3,14,3,4,0 +male,57.0,15.0,yes,4,9,2,2,0 +male,57.0,15.0,yes,4,20,6,5,0 
+female,22.0,0.125,no,4,14,4,5,0 +female,32.0,10.0,yes,4,14,1,5,0 +female,42.0,15.0,yes,3,18,5,4,0 +female,27.0,1.5,no,2,18,6,5,0 +male,32.0,0.125,yes,2,18,5,2,0 +female,27.0,4.0,no,3,16,5,4,0 +female,27.0,10.0,yes,2,16,1,4,0 +female,32.0,7.0,yes,4,16,1,3,0 +female,37.0,15.0,yes,4,14,5,4,0 +female,42.0,15.0,yes,5,17,6,2,0 +male,32.0,1.5,yes,4,14,6,5,0 +female,32.0,4.0,yes,3,17,5,3,0 +female,37.0,7.0,no,4,18,5,5,0 +female,22.0,0.417,yes,3,14,3,5,0 +female,27.0,7.0,yes,4,14,1,5,0 +male,27.0,0.75,no,3,16,5,5,0 +male,27.0,4.0,yes,2,20,5,5,0 +male,32.0,10.0,yes,4,16,4,5,0 +male,32.0,15.0,yes,1,14,5,5,0 +male,22.0,0.75,no,3,17,4,5,0 +female,27.0,7.0,yes,4,17,1,4,0 +male,27.0,0.417,yes,4,20,5,4,0 +male,37.0,15.0,yes,4,20,5,4,0 +female,37.0,15.0,yes,2,14,1,3,0 +male,22.0,4.0,yes,1,18,5,4,0 +male,37.0,15.0,yes,4,17,5,3,0 +female,22.0,1.5,no,2,14,4,5,0 +male,52.0,15.0,yes,4,14,6,2,0 +female,22.0,1.5,no,4,17,5,5,0 +male,32.0,4.0,yes,5,14,3,5,0 +male,32.0,4.0,yes,2,14,3,5,0 +female,22.0,1.5,no,3,16,6,5,0 +male,27.0,0.75,no,2,18,3,3,0 +female,22.0,7.0,yes,2,14,5,2,0 +female,27.0,0.75,no,2,17,5,3,0 +female,37.0,15.0,yes,4,12,1,2,0 +female,22.0,1.5,no,1,14,1,5,0 +female,37.0,10.0,no,2,12,4,4,0 +female,37.0,15.0,yes,4,18,5,3,0 +female,42.0,15.0,yes,3,12,3,3,0 +male,22.0,4.0,no,2,18,5,5,0 +male,52.0,7.0,yes,2,20,6,2,0 +male,27.0,0.75,no,2,17,5,5,0 +female,27.0,4.0,no,2,17,4,5,0 +male,42.0,1.5,no,5,20,6,5,0 +male,22.0,1.5,no,4,17,6,5,0 +male,22.0,4.0,no,4,17,5,3,0 +female,22.0,4.0,yes,1,14,5,4,0 +male,37.0,15.0,yes,5,20,4,5,0 +female,37.0,10.0,yes,3,16,6,3,0 +male,42.0,15.0,yes,4,17,6,5,0 +female,47.0,15.0,yes,4,17,5,5,0 +male,22.0,1.5,no,4,16,5,4,0 +female,32.0,10.0,yes,3,12,1,4,0 +female,22.0,7.0,yes,1,14,3,5,0 +female,32.0,10.0,yes,4,17,5,4,0 +male,27.0,1.5,yes,2,16,2,4,0 +male,37.0,15.0,yes,4,14,5,5,0 +male,42.0,4.0,yes,3,14,4,5,0 +female,37.0,15.0,yes,5,14,5,4,0 +female,32.0,7.0,yes,4,17,5,5,0 +female,42.0,15.0,yes,4,18,6,5,0 +male,27.0,4.0,no,4,18,6,4,0 
+male,22.0,0.75,no,4,18,6,5,0 +male,27.0,4.0,yes,4,14,5,3,0 +female,22.0,0.75,no,5,18,1,5,0 +female,52.0,15.0,yes,5,9,5,5,0 +male,32.0,10.0,yes,3,14,5,5,0 +female,37.0,15.0,yes,4,16,4,4,0 +male,32.0,7.0,yes,2,20,5,4,0 +female,42.0,15.0,yes,3,18,1,4,0 +male,32.0,15.0,yes,1,16,5,5,0 +male,27.0,4.0,yes,3,18,5,5,0 +female,32.0,15.0,yes,4,12,3,4,0 +male,22.0,0.75,yes,3,14,2,4,0 +female,22.0,1.5,no,3,16,5,3,0 +female,42.0,15.0,yes,4,14,3,5,0 +female,52.0,15.0,yes,3,16,5,4,0 +male,37.0,15.0,yes,5,20,6,4,0 +female,47.0,15.0,yes,4,12,2,3,0 +male,57.0,15.0,yes,2,20,6,4,0 +male,32.0,7.0,yes,4,17,5,5,0 +female,27.0,7.0,yes,4,17,1,4,0 +male,22.0,1.5,no,1,18,6,5,0 +female,22.0,4.0,yes,3,9,1,4,0 +female,22.0,1.5,no,2,14,1,5,0 +male,42.0,15.0,yes,2,20,6,4,0 +male,57.0,15.0,yes,4,9,2,4,0 +female,27.0,7.0,yes,2,18,1,5,0 +female,22.0,4.0,yes,3,14,1,5,0 +male,37.0,15.0,yes,4,14,5,3,0 +male,32.0,7.0,yes,1,18,6,4,0 +female,22.0,1.5,no,2,14,5,5,0 +female,22.0,1.5,yes,3,12,1,3,0 +male,52.0,15.0,yes,2,14,5,5,0 +female,37.0,15.0,yes,2,14,1,1,0 +female,32.0,10.0,yes,2,14,5,5,0 +male,42.0,15.0,yes,4,20,4,5,0 +female,27.0,4.0,yes,3,18,4,5,0 +male,37.0,15.0,yes,4,20,6,5,0 +male,27.0,1.5,no,3,18,5,5,0 +female,22.0,0.125,no,2,16,6,3,0 +male,32.0,10.0,yes,2,20,6,3,0 +female,27.0,4.0,no,4,18,5,4,0 +female,27.0,7.0,yes,2,12,5,1,0 +male,32.0,4.0,yes,5,18,6,3,0 +female,37.0,15.0,yes,2,17,5,5,0 +male,47.0,15.0,no,4,20,6,4,0 +male,27.0,1.5,no,1,18,5,5,0 +male,37.0,15.0,yes,4,20,6,4,0 +female,32.0,15.0,yes,4,18,1,4,0 +female,32.0,7.0,yes,4,17,5,4,0 +female,42.0,15.0,yes,3,14,1,3,0 +female,27.0,7.0,yes,3,16,1,4,0 +male,27.0,1.5,no,3,16,4,2,0 +male,22.0,1.5,no,3,16,3,5,0 +male,27.0,4.0,yes,3,16,4,2,0 +female,27.0,7.0,yes,3,12,1,2,0 +female,37.0,15.0,yes,2,18,5,4,0 +female,37.0,7.0,yes,3,14,4,4,0 +male,22.0,1.5,no,2,16,5,5,0 +male,37.0,15.0,yes,5,20,5,4,0 +female,22.0,1.5,no,4,16,5,3,0 +female,32.0,10.0,yes,4,16,1,5,0 +male,27.0,4.0,no,2,17,5,3,0 +female,22.0,0.417,no,4,14,5,5,0 
+female,27.0,4.0,no,2,18,5,5,0 +male,37.0,15.0,yes,4,18,5,3,0 +male,37.0,10.0,yes,5,20,7,4,0 +female,27.0,7.0,yes,2,14,4,2,0 +male,32.0,4.0,yes,2,16,5,5,0 +male,32.0,4.0,yes,2,16,6,4,0 +male,22.0,1.5,no,3,18,4,5,0 +female,22.0,4.0,yes,4,14,3,4,0 +female,17.5,0.75,no,2,18,5,4,0 +male,32.0,10.0,yes,4,20,4,5,0 +female,32.0,0.75,no,5,14,3,3,0 +male,37.0,15.0,yes,4,17,5,3,0 +male,32.0,4.0,no,3,14,4,5,0 +female,27.0,1.5,no,2,17,3,2,0 +female,22.0,7.0,yes,4,14,1,5,0 +male,47.0,15.0,yes,5,14,6,5,0 +male,27.0,4.0,yes,1,16,4,4,0 +female,37.0,15.0,yes,5,14,1,3,0 +male,42.0,4.0,yes,4,18,5,5,0 +female,32.0,4.0,yes,2,14,1,5,0 +male,52.0,15.0,yes,2,14,7,4,0 +female,22.0,1.5,no,2,16,1,4,0 +male,52.0,15.0,yes,4,12,2,4,0 +female,22.0,0.417,no,3,17,1,5,0 +female,22.0,1.5,no,2,16,5,5,0 +male,27.0,4.0,yes,4,20,6,4,0 +female,32.0,15.0,yes,4,14,1,5,0 +female,27.0,1.5,no,2,16,3,5,0 +male,32.0,4.0,no,1,20,6,5,0 +male,37.0,15.0,yes,3,20,6,4,0 +female,32.0,10.0,no,2,16,6,5,0 +female,32.0,10.0,yes,5,14,5,5,0 +male,37.0,1.5,yes,4,18,5,3,0 +male,32.0,1.5,no,2,18,4,4,0 +female,32.0,10.0,yes,4,14,1,4,0 +female,47.0,15.0,yes,4,18,5,4,0 +female,27.0,10.0,yes,5,12,1,5,0 +male,27.0,4.0,yes,3,16,4,5,0 +female,37.0,15.0,yes,4,12,4,2,0 +female,27.0,0.75,no,4,16,5,5,0 +female,37.0,15.0,yes,4,16,1,5,0 +female,32.0,15.0,yes,3,16,1,5,0 +female,27.0,10.0,yes,2,16,1,5,0 +male,27.0,7.0,no,2,20,6,5,0 +female,37.0,15.0,yes,2,14,1,3,0 +male,27.0,1.5,yes,2,17,4,4,0 +female,22.0,0.75,yes,2,14,1,5,0 +male,22.0,4.0,yes,4,14,2,4,0 +male,42.0,0.125,no,4,17,6,4,0 +male,27.0,1.5,yes,4,18,6,5,0 +male,27.0,7.0,yes,3,16,6,3,0 +female,52.0,15.0,yes,4,14,1,3,0 +male,27.0,1.5,no,5,20,5,2,0 +female,27.0,1.5,no,2,16,5,5,0 +female,27.0,1.5,no,3,17,5,5,0 +male,22.0,0.125,no,5,16,4,4,0 +female,27.0,4.0,yes,4,16,1,5,0 +female,27.0,4.0,yes,4,12,1,5,0 +female,47.0,15.0,yes,2,14,5,5,0 +female,32.0,15.0,yes,3,14,5,3,0 +male,42.0,7.0,yes,2,16,5,5,0 +male,22.0,0.75,no,4,16,6,4,0 +male,27.0,0.125,no,3,20,6,5,0 
+male,32.0,10.0,yes,3,20,6,5,0 +female,22.0,0.417,no,5,14,4,5,0 +female,47.0,15.0,yes,5,14,1,4,0 +female,32.0,10.0,yes,3,14,1,5,0 +male,57.0,15.0,yes,4,17,5,5,0 +male,27.0,4.0,yes,3,20,6,5,0 +female,32.0,7.0,yes,4,17,1,5,0 +female,37.0,10.0,yes,4,16,1,5,0 +female,32.0,10.0,yes,1,18,1,4,0 +female,22.0,4.0,no,3,14,1,4,0 +female,27.0,7.0,yes,4,14,3,2,0 +male,57.0,15.0,yes,5,18,5,2,0 +male,32.0,7.0,yes,2,18,5,5,0 +female,27.0,1.5,no,4,17,1,3,0 +male,22.0,1.5,no,4,14,5,5,0 +female,22.0,1.5,yes,4,14,5,4,0 +female,32.0,7.0,yes,3,16,1,5,0 +female,47.0,15.0,yes,3,16,5,4,0 +female,22.0,0.75,no,3,16,1,5,0 +female,22.0,1.5,yes,2,14,5,5,0 +female,27.0,4.0,yes,1,16,5,5,0 +male,52.0,15.0,yes,4,16,5,5,0 +male,32.0,10.0,yes,4,20,6,5,0 +male,47.0,15.0,yes,4,16,6,4,0 +female,27.0,7.0,yes,2,14,1,2,0 +female,22.0,1.5,no,4,14,4,5,0 +female,32.0,10.0,yes,2,16,5,4,0 +female,22.0,0.75,no,2,16,5,4,0 +female,22.0,1.5,no,2,16,5,5,0 +female,42.0,15.0,yes,3,18,6,4,0 +female,27.0,7.0,yes,5,14,4,5,0 +male,42.0,15.0,yes,4,16,4,4,0 +female,57.0,15.0,yes,3,18,5,2,0 +male,42.0,15.0,yes,3,18,6,2,0 +female,32.0,7.0,yes,2,14,1,2,0 +male,22.0,4.0,no,5,12,4,5,0 +female,22.0,1.5,no,1,16,6,5,0 +female,22.0,0.75,no,1,14,4,5,0 +female,32.0,15.0,yes,4,12,1,5,0 +male,22.0,1.5,no,2,18,5,3,0 +male,27.0,4.0,yes,5,17,2,5,0 +female,27.0,4.0,yes,4,12,1,5,0 +male,42.0,15.0,yes,5,18,5,4,0 +male,32.0,1.5,no,2,20,7,3,0 +male,57.0,15.0,no,4,9,3,1,0 +male,37.0,7.0,no,4,18,5,5,0 +male,52.0,15.0,yes,2,17,5,4,0 +male,47.0,15.0,yes,4,17,6,5,0 +female,27.0,7.0,no,2,17,5,4,0 +female,27.0,7.0,yes,4,14,5,5,0 +female,22.0,4.0,no,2,14,3,3,0 +male,37.0,7.0,yes,2,20,6,5,0 +male,27.0,7.0,no,4,12,4,3,0 +male,42.0,10.0,yes,4,18,6,4,0 +female,22.0,1.5,no,3,14,1,5,0 +female,22.0,4.0,yes,2,14,1,3,0 +female,57.0,15.0,no,4,20,6,5,0 +male,37.0,15.0,yes,4,14,4,3,0 +female,27.0,7.0,yes,3,18,5,5,0 +female,17.5,10.0,no,4,14,4,5,0 +male,22.0,4.0,yes,4,16,5,5,0 +female,27.0,4.0,yes,2,16,1,4,0 +female,37.0,15.0,yes,2,14,5,1,0 
+female,22.0,1.5,no,5,14,1,4,0 +male,27.0,7.0,yes,2,20,5,4,0 +male,27.0,4.0,yes,4,14,5,5,0 +male,22.0,0.125,no,1,16,3,5,0 +female,27.0,7.0,yes,4,14,1,4,0 +female,32.0,15.0,yes,5,16,5,3,0 +male,32.0,10.0,yes,4,18,5,4,0 +female,32.0,15.0,yes,2,14,3,4,0 +female,22.0,1.5,no,3,17,5,5,0 +male,27.0,4.0,yes,4,17,4,4,0 +female,52.0,15.0,yes,5,14,1,5,0 +female,27.0,7.0,yes,2,12,1,2,0 +female,27.0,7.0,yes,3,12,1,4,0 +female,42.0,15.0,yes,2,14,1,4,0 +female,42.0,15.0,yes,4,14,5,4,0 +male,27.0,7.0,yes,4,14,3,3,0 +male,27.0,7.0,yes,2,20,6,2,0 +female,42.0,15.0,yes,3,12,3,3,0 +male,27.0,4.0,yes,3,16,3,5,0 +female,27.0,7.0,yes,3,14,1,4,0 +female,22.0,1.5,no,2,14,4,5,0 +female,27.0,4.0,yes,4,14,1,4,0 +female,22.0,4.0,no,4,14,5,5,0 +female,22.0,1.5,no,2,16,4,5,0 +male,47.0,15.0,no,4,14,5,4,0 +male,37.0,10.0,yes,2,18,6,2,0 +male,37.0,15.0,yes,3,17,5,4,0 +female,27.0,4.0,yes,2,16,1,4,0 +male,27.0,1.5,no,3,18,4,4,3 +female,27.0,4.0,yes,3,17,1,5,3 +male,37.0,15.0,yes,5,18,6,2,7 +female,32.0,10.0,yes,3,17,5,2,12 +male,22.0,0.125,no,4,16,5,5,1 +female,22.0,1.5,yes,2,14,1,5,1 +male,37.0,15.0,yes,4,14,5,2,12 +female,22.0,1.5,no,2,14,3,4,7 +male,37.0,15.0,yes,2,18,6,4,2 +female,32.0,15.0,yes,4,12,3,2,3 +female,37.0,15.0,yes,4,14,4,2,1 +female,42.0,15.0,yes,3,17,1,4,7 +female,42.0,15.0,yes,5,9,4,1,12 +male,37.0,10.0,yes,2,20,6,2,12 +female,32.0,15.0,yes,3,14,1,2,12 +male,27.0,4.0,no,1,18,6,5,3 +male,37.0,10.0,yes,2,18,7,3,7 +female,27.0,4.0,no,3,17,5,5,7 +male,42.0,15.0,yes,4,16,5,5,1 +female,47.0,15.0,yes,5,14,4,5,1 +female,27.0,4.0,yes,3,18,5,4,7 +female,27.0,7.0,yes,5,14,1,4,1 +male,27.0,1.5,yes,3,17,5,4,12 +female,27.0,7.0,yes,4,14,6,2,12 +female,42.0,15.0,yes,4,16,5,4,3 +female,27.0,10.0,yes,4,12,7,3,7 +male,27.0,1.5,no,2,18,5,2,1 +male,32.0,4.0,no,4,20,6,4,1 +female,27.0,7.0,yes,3,14,1,3,1 +female,32.0,10.0,yes,4,14,1,4,3 +male,27.0,4.0,yes,2,18,7,2,3 +female,17.5,0.75,no,5,14,4,5,1 +female,32.0,10.0,yes,4,18,1,5,1 +female,32.0,7.0,yes,2,17,6,4,7 +male,37.0,15.0,yes,2,20,6,4,7 
+female,37.0,10.0,no,1,20,5,3,7 +female,32.0,10.0,yes,2,16,5,5,12 +male,52.0,15.0,yes,2,20,6,4,7 +female,42.0,15.0,yes,1,12,1,3,7 +male,52.0,15.0,yes,2,20,6,3,1 +male,37.0,15.0,yes,3,18,6,5,2 +female,22.0,4.0,no,3,12,3,4,12 +male,27.0,7.0,yes,1,18,6,2,12 +male,27.0,4.0,yes,3,18,5,5,1 +male,47.0,15.0,yes,4,17,6,5,12 +female,42.0,15.0,yes,4,12,1,1,12 +male,27.0,4.0,no,3,14,3,4,7 +female,32.0,7.0,yes,4,18,4,5,7 +male,32.0,0.417,yes,3,12,3,4,1 +male,47.0,15.0,yes,5,16,5,4,3 +male,37.0,15.0,yes,2,20,5,4,12 +male,22.0,4.0,yes,2,17,6,4,7 +male,27.0,4.0,no,2,14,4,5,1 +female,52.0,15.0,yes,5,16,1,3,7 +male,27.0,4.0,no,3,14,3,3,1 +female,27.0,10.0,yes,4,16,1,4,1 +male,32.0,7.0,yes,3,14,7,4,1 +male,32.0,7.0,yes,2,18,4,1,7 +male,22.0,1.5,no,1,14,3,2,3 +male,22.0,4.0,yes,3,18,6,4,7 +male,42.0,15.0,yes,4,20,6,4,7 +female,57.0,15.0,yes,1,18,5,4,2 +female,32.0,4.0,yes,3,18,5,2,7 +male,27.0,4.0,yes,1,16,4,4,1 +male,32.0,7.0,yes,4,16,1,4,7 +male,57.0,15.0,yes,1,17,4,4,2 +female,42.0,15.0,yes,4,14,5,2,7 +male,37.0,10.0,yes,1,18,5,3,7 +male,42.0,15.0,yes,3,17,6,1,3 +female,52.0,15.0,yes,3,14,4,4,1 +female,27.0,7.0,yes,3,17,5,3,2 +male,32.0,7.0,yes,2,12,4,2,12 +male,22.0,4.0,no,4,14,2,5,1 +male,27.0,7.0,yes,3,18,6,4,3 +female,37.0,15.0,yes,1,18,5,5,12 +female,32.0,15.0,yes,3,17,1,3,7 +female,27.0,7.0,no,2,17,5,5,7 +female,32.0,7.0,yes,3,17,5,3,1 +male,32.0,1.5,yes,2,14,2,4,1 +female,42.0,15.0,yes,4,14,1,2,12 +male,32.0,10.0,yes,3,14,5,4,7 +male,37.0,4.0,yes,1,20,6,3,7 +female,27.0,4.0,yes,2,16,5,3,1 +female,42.0,15.0,yes,3,14,4,3,12 +male,27.0,10.0,yes,5,20,6,5,1 +male,37.0,10.0,yes,2,20,6,2,12 +female,27.0,7.0,yes,1,14,3,3,12 +female,27.0,7.0,yes,4,12,1,2,3 +male,32.0,10.0,yes,2,14,4,4,3 +female,17.5,0.75,yes,2,12,1,3,12 +female,32.0,15.0,yes,3,18,5,4,12 +female,22.0,7.0,no,4,14,4,3,2 +male,32.0,7.0,yes,4,20,6,5,1 +male,27.0,4.0,yes,2,18,6,2,7 +female,22.0,1.5,yes,5,14,5,3,1 +female,32.0,15.0,no,3,17,5,1,12 +female,42.0,15.0,yes,2,12,1,2,12 +male,42.0,15.0,yes,3,20,5,4,7 
+male,32.0,10.0,no,2,18,4,2,12 +female,32.0,15.0,yes,3,9,1,1,12 +male,57.0,15.0,yes,5,20,4,5,7 +male,47.0,15.0,yes,4,20,6,4,12 +female,42.0,15.0,yes,2,17,6,3,2 +male,37.0,15.0,yes,3,17,6,3,12 +male,37.0,15.0,yes,5,17,5,2,12 +male,27.0,10.0,yes,2,20,6,4,7 +male,37.0,15.0,yes,2,16,5,4,2 +female,32.0,15.0,yes,1,14,5,2,12 +male,32.0,10.0,yes,3,17,6,3,7 +male,37.0,15.0,yes,4,18,5,1,2 +female,27.0,1.5,no,2,17,5,5,7 +female,47.0,15.0,yes,2,17,5,2,3 +male,37.0,15.0,yes,2,17,5,4,12 +female,27.0,4.0,no,2,14,5,5,12 +female,27.0,10.0,yes,4,14,1,5,2 +female,22.0,4.0,yes,3,16,1,3,1 +male,52.0,7.0,no,4,16,5,5,12 +female,27.0,4.0,yes,1,16,3,5,2 +female,37.0,15.0,yes,2,17,6,4,7 +female,27.0,4.0,no,1,17,3,1,2 +female,17.5,0.75,yes,2,12,3,5,12 +female,32.0,15.0,yes,5,18,5,4,7 +female,22.0,4.0,no,1,16,3,5,7 +male,32.0,4.0,yes,4,18,6,4,2 +female,22.0,1.5,yes,3,18,5,2,1 +female,42.0,15.0,yes,2,17,5,4,3 +male,32.0,7.0,yes,4,16,4,4,1 +male,37.0,15.0,no,3,14,6,2,12 +male,42.0,15.0,yes,3,16,6,3,1 +male,27.0,4.0,yes,1,18,5,4,1 +male,37.0,15.0,yes,4,20,7,3,2 +male,37.0,15.0,yes,3,20,6,4,7 +male,22.0,1.5,no,2,12,3,3,3 +male,32.0,4.0,yes,3,20,6,2,3 +male,32.0,15.0,yes,5,20,6,5,2 +female,52.0,15.0,yes,1,18,5,5,12 +male,47.0,15.0,no,1,18,6,5,12 +female,32.0,15.0,yes,4,16,4,4,3 +female,32.0,15.0,yes,3,14,3,2,7 +female,27.0,7.0,yes,4,16,1,2,7 +male,42.0,15.0,yes,3,18,6,2,12 +female,42.0,15.0,yes,2,14,3,2,7 +male,27.0,7.0,yes,2,17,5,4,12 +male,32.0,10.0,yes,4,14,4,3,3 +male,47.0,15.0,yes,3,16,4,2,7 +male,22.0,1.5,yes,1,12,2,5,1 +female,32.0,10.0,yes,2,18,5,4,7 +male,32.0,10.0,yes,2,17,6,5,2 +male,22.0,7.0,yes,3,18,6,2,2 +female,32.0,15.0,yes,3,14,1,5,1 diff --git a/elasticnet/tests/test_ElasticNetModel.py b/elasticnet/tests/test_ElasticNetModel.py index 5022c3c..85051d3 100644 --- a/elasticnet/tests/test_ElasticNetModel.py +++ b/elasticnet/tests/test_ElasticNetModel.py @@ -1,19 +1,89 @@ import csv +import numpy as np +from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score 
+from elasticnet.models.ElasticNet import ElasticNetModel +import matplotlib.pyplot as plt -import numpy +def preprocess_data(data): + # Convert 'sex' and 'child' columns to numeric + for row in data: + row['sex'] = 0 if row['sex'] == 'male' else 1 + row['child'] = 0 if row['child'] == 'no' else 1 -from elasticnet.models.ElasticNet import ElasticNetModel + # Prepare X (features) and y (target 'rate') + X = np.array([[float(row[k]) for k in row.keys() if k != 'rate'] for row in data]) + y = np.array([float(row['rate']) for row in data]) + + return X, y def test_predict(): model = ElasticNetModel() data = [] - with open("small_test.csv", "r") as file: + + # Load data from test.csv + with open("elasticnet/tests/test.csv", "r") as file: reader = csv.DictReader(file) for row in reader: data.append(row) - X = numpy.array([[v for k,v in datum.items() if k.startswith('x')] for datum in data]) - y = numpy.array([[v for k,v in datum.items() if k=='y'] for datum in data]) - results = model.fit(X,y) + # Preprocess data to handle categorical values + X, y = preprocess_data(data) + + # Print the fields (column names) being used for training + print("Fields used for training:", [k for k in data[0].keys() if k != 'rate']) # Excluding the target ('rate') + + # Fit the model and make predictions + results = model.fit(X, y) preds = results.predict(X) - assert preds == 0.5 + + # Print all training data and predictions + print("\n---- Training Data (all rows) ----") + print(X) + print("\n---- Predicted Values (all rows) ----") + print(preds) + + # Calculate and print evaluation metrics + mse = mean_squared_error(y, preds) + mae = mean_absolute_error(y, preds) + r2 = r2_score(y, preds) + + print("\n---- Evaluation Metrics ----") + print(f"Mean Squared Error (MSE): {mse}") + print(f"Mean Absolute Error (MAE): {mae}") + print(f"R-squared (R²): {r2}") + + # Visualizations + fig, axs = plt.subplots(2, 2, figsize=(16, 12)) + + # Actual vs Predicted + axs[0, 0].scatter(y, preds, color='blue', alpha=0.6) + 
axs[0, 0].plot([min(y), max(y)], [min(y), max(y)], color='red', linewidth=2) + axs[0, 0].set_title("Actual vs Predicted") + axs[0, 0].set_xlabel('Actual') + axs[0, 0].set_ylabel('Predicted') + + # Residuals + residuals = y - preds + axs[0, 1].scatter(preds, residuals, color='purple', alpha=0.6) + axs[0, 1].axhline(0, color='red', linestyle='--', linewidth=2) + axs[0, 1].set_title("Residuals") + axs[0, 1].set_xlabel('Predicted') + axs[0, 1].set_ylabel('Residuals') + + # Distribution of Target Values + axs[1, 0].hist(y, bins=50, color='green', alpha=0.7) + axs[1, 0].set_title("Distribution of 'rate' (Target)") + axs[1, 0].set_xlabel('rate') + axs[1, 0].set_ylabel('Frequency') + + # Feature Weights + axs[1, 1].bar(range(len(model.weights)), model.weights, color='orange') + axs[1, 1].set_title("Feature Weights") + axs[1, 1].set_xlabel('Feature Index') + axs[1, 1].set_ylabel('Weight') + + plt.tight_layout(pad=5.0) + plt.show() + +if __name__ == "__main__": + test_predict() diff --git a/generate_test_CSV.py b/generate_test_CSV.py new file mode 100644 index 0000000..40a14c1 --- /dev/null +++ b/generate_test_CSV.py @@ -0,0 +1,12 @@ +import pandas as pd +from pydataset import data + +# Load the 'Fair' dataset +fair_data = data('Fair') + +# Save it as 'test.csv' inside the 'elasticnet/tests/' directory +fair_data.to_csv('elasticnet/tests/test.csv', index=False) + +print("test.csv has been generated and saved in elasticnet/tests/") + +
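The test script in this diff imports `mean_squared_error`, `mean_absolute_error`, and `r2_score` from scikit-learn. If the goal is to keep the whole project NumPy-only, as the README describes for the model itself, the same three metrics could be sketched as below. The function names `mse`, `mae`, and `r2` are hypothetical and do not appear in the diff; the formulas are the standard definitions.

```python
import numpy as np


def mse(y_true, y_pred):
    """Mean of the squared differences between actual and predicted values."""
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true - y_pred) ** 2))


def mae(y_true, y_pred):
    """Mean of the absolute differences between actual and predicted values."""
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    return float(np.mean(np.abs(y_true - y_pred)))


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot.

    Assumes y_true is not constant; otherwise SS_tot is zero and the
    ratio is undefined.
    """
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return float(1.0 - ss_res / ss_tot)


# Small made-up example, not output from the real dataset.
y_true = [3.0, 4.0, 5.0, 6.0]
y_pred = [2.5, 4.0, 5.5, 6.0]
print(mse(y_true, y_pred), mae(y_true, y_pred), r2(y_true, y_pred))
```

Swapping these in for the scikit-learn imports would remove that dependency from the test script, at the cost of maintaining the formulas by hand.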