90 changes: 84 additions & 6 deletions README.md
@@ -1,8 +1,86 @@
# Project 1

Put your README here. Answer the following questions.

* What does the model you have implemented do and when should it be used?
* How did you test your model to determine if it is working reasonably correctly?
* What parameters have you exposed to users of your implementation in order to tune performance? (Also perhaps provide some basic usage examples.)
* Are there specific inputs that your implementation has trouble with? Given more time, could you work around these or is it fundamental?

# Project 1 - ElasticNet Model

Name: Gurjot Singh Kalsi CWID: A20550984

Name: Siva Vamsi Kolli CWID: A20560901

Name: Sai Teja Reddy Janga CWID: A20554588

This project implements an ElasticNet regression model, a linear regression method that combines both L1 (Lasso) and L2 (Ridge) penalties. The model is particularly useful when dealing with datasets that have multicollinearity (high correlation among features) and when you need to perform feature selection by shrinking some coefficients to zero.

## 1. What does the model do and when should it be used?

### What:
The ElasticNet regression model estimates the relationship between a dependent variable (target) and multiple independent variables (features) while applying regularization to prevent overfitting. The model combines the strengths of Lasso and Ridge regression (the combined objective is written out after this list):
- **Lasso** (L1) for feature selection, as it can shrink coefficients of some features to zero.
- **Ridge** (L2) for minimizing the impact of multicollinearity and stabilizing the model.
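
For reference, with the parameter names exposed by this implementation (`alpha` for the overall regularization strength and `rho` for the L1/L2 mix), the penalized least-squares objective being minimized can be written, up to the exact handling of the intercept and feature standardization, roughly as:

```math
\min_{\beta_0,\,\beta}\;\frac{1}{2}\sum_{i=1}^{n}\bigl(y_i-\beta_0-x_i^{\top}\beta\bigr)^2\;+\;\alpha\,\rho\,\lVert\beta\rVert_1\;+\;\frac{\alpha\,(1-\rho)}{2}\,\lVert\beta\rVert_2^2
```

Setting `rho=1` recovers the pure Lasso penalty and `rho=0` the pure Ridge penalty.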

### When to Use:
ElasticNet is best used in scenarios where:
- You have a large number of features, many of which might be irrelevant or redundant.
- There is high multicollinearity (i.e., correlation between input features).
- You want a model that can perform both feature selection and regularization.

It is particularly effective for datasets where neither pure Lasso nor Ridge performs optimally due to the characteristics of the data.

## 2. How did you test your model to determine if it is working reasonably correctly?

### How:
The model was tested using the following methods:
1. **Synthetic Data**: Data was generated using a separate script (`generate_regression_data.py`) to create a dataset with known coefficients and noise. The model was then trained on this synthetic data, and the results were compared to the expected coefficients and predictions.
2. **Mean Squared Error (MSE) and Mean Absolute Error (MAE)**: These metrics were computed after training the model, and thresholds were used to validate its performance; for example, the test script checks that the MSE and MAE stay below a set value (a sketch of such a check appears after this list).
3. **Visual Validation**: The predictions of the model were plotted against actual values to visually inspect how well the model fits the data. Residuals were also plotted to check if they are evenly distributed, indicating a well-fitting model.
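
Below is a minimal sketch of the kind of metric check described above. The import paths and thresholds are assumptions for illustration; the actual test lives in `test_ElasticNetModel.py`.

```python
import numpy as np
from elasticnet.models.ElasticNet import ElasticNetModel   # assumed import path
from elasticnet.models.metrics import mean_squared_error, mean_absolute_error

# Synthetic data with known coefficients: y = 3*x0 - 2*x1 + 5 + noise
rng = np.random.default_rng(42)
X = rng.uniform(-10, 10, size=(100, 2))
y = 3 * X[:, 0] - 2 * X[:, 1] + 5 + rng.normal(scale=0.1, size=100)

model = ElasticNetModel(alpha=0.1, rho=0.5, max_iter=1000, tol=1e-4)
results = model.fit(X, y)
y_pred = results.predict(X)

# Illustrative sanity thresholds; the values used in the real test may differ
assert mean_squared_error(y, y_pred) < 1.0
assert mean_absolute_error(y, y_pred) < 1.0
```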

## 3. What parameters have you exposed to users of your implementation in order to tune performance?

### Exposed Parameters:
The following parameters can be tuned by users to control the model's performance (a basic usage example follows the list):
- **alpha**: This controls the overall strength of regularization. A higher value applies more regularization (both L1 and L2).
- **rho**: This balances the ratio between L1 (Lasso) and L2 (Ridge) penalties. `rho=0` applies only L2 regularization, `rho=1` applies only L1, and values in between apply both.
- **max_iter**: This defines the maximum number of iterations the model will take to converge during optimization.
- **tol**: The tolerance for optimization. Smaller values can lead to more accurate solutions but may take longer to compute.
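
A basic usage example (the data and parameter values below are illustrative, not tuned):

```python
import numpy as np
from elasticnet.models.ElasticNet import ElasticNetModel   # assumed import path

X = np.array([[1.0, 2.0], [2.0, 1.0], [3.0, 4.0], [4.0, 3.0]])
y = np.array([3.0, 2.5, 6.0, 5.5])

# Larger alpha -> stronger shrinkage; rho closer to 1 -> more L1, i.e. sparser coefficients
model = ElasticNetModel(alpha=0.5, rho=0.7, max_iter=5000, tol=1e-6)
results = model.fit(X, y)

print("Intercept:", results.intercept_)
print("Coefficients:", results.coef_)
print("Predictions:", results.predict(X))
```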

## 4. Are there specific inputs that your implementation has trouble with?

### Problematic Inputs:
The model might struggle with:

- **Nonlinear relationships**: Since the model is fundamentally linear, datasets with nonlinear patterns cannot be accurately modeled without feature engineering (e.g., polynomial features or transformations).
- **Highly imbalanced data**: The model assumes that the noise in the data is homoscedastic (having the same variance), and imbalanced data can affect its performance by distorting this assumption.

### Potential Workarounds:
- **Nonlinearity**: To handle nonlinearity, feature engineering could be applied (e.g., adding polynomial terms, as sketched below), or a more complex model such as decision trees or a neural network could be used.
- **Imbalanced Data**: Techniques like resampling (oversampling/undersampling) or adjusting model weights could help in mitigating the effects of imbalanced data.
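
As a rough illustration of the first workaround, polynomial terms can be appended to the feature matrix before fitting; the helper below is hypothetical and uses plain NumPy:

```python
import numpy as np

def add_polynomial_features(X, degree=2):
    """Append element-wise powers of each column so the linear model
    can capture simple nonlinear trends."""
    features = [X]
    for d in range(2, degree + 1):
        features.append(X ** d)
    return np.hstack(features)

X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 5.0]])
X_poly = add_polynomial_features(X, degree=2)   # columns: x_0, x_1, x_0**2, x_1**2
print(X_poly.shape)   # (3, 4)
```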

## 5. Steps to run the code.

The project includes a script for generating synthetic regression data, which can be used to test the ElasticNet model. The script generates data based on a user-specified linear equation with noise.

### Command to Generate Data:
```bash
python3 Project1/generate_regression_data.py -N 100 -m 3 -2 -b 5 -scale 0.1 -rnge -10 10 -seed 42 -output_file Project1/elasticnet/models/small_test.csv
```

### Explanation of Arguments:
- `-N 100`: Specifies the number of samples to generate.
- `-m 3 -2`: Defines the slope coefficients for the linear relationship (in this case, two features with slopes of 3 and -2).
- `-b 5`: Sets the intercept (offset) of the linear equation.
- `-scale 0.1`: Adds Gaussian noise with a standard deviation of 0.1.
- `-rnge -10 10`: Specifies the range from which feature values (X) are uniformly sampled.
- `-seed 42`: Sets the random seed for reproducibility.
- `-output_file`: Specifies the path to save the generated dataset as a CSV file.

### Example Output:
The command will generate a CSV file (`small_test.csv`) with columns for each feature (`x_0`, `x_1`) and a target value (`y`). This file can then be used to train and test the ElasticNet model.
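
A hedged sketch of loading the generated CSV and fitting the model on it (the column names follow the description above; the exact workflow in `test_ElasticNetModel.py` may differ):

```python
import csv
import numpy as np
from elasticnet.models.ElasticNet import ElasticNetModel   # assumed import path

rows = []
with open("Project1/elasticnet/models/small_test.csv") as f:
    for row in csv.DictReader(f):
        rows.append([float(row["x_0"]), float(row["x_1"]), float(row["y"])])

data = np.array(rows)
X, y = data[:, :2], data[:, 2]

model = ElasticNetModel(alpha=0.1, rho=0.5)
results = model.fit(X, y)
print("Intercept:", results.intercept_, "Coefficients:", results.coef_)
```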

Run `ElasticNet.py` to train the model, then run `test_ElasticNetModel.py` to test it.

The results show the MSE and MAE values together with regression plots (a sketch of the plotting step follows).
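
One possible way to produce such plots with matplotlib, reusing `X`, `y`, and `results` from the loading sketch above (the committed `predictions.png` and `residuals.png` may have been generated differently):

```python
import matplotlib.pyplot as plt

y_pred = results.predict(X)

# Predicted vs. actual values
plt.figure()
plt.scatter(y, y_pred)
plt.xlabel("Actual y")
plt.ylabel("Predicted y")
plt.title("ElasticNet predictions")
plt.savefig("predictions.png")

# Residuals should scatter roughly evenly around zero for a well-fitting model
plt.figure()
plt.scatter(y_pred, y - y_pred)
plt.axhline(0, linestyle="--")
plt.xlabel("Predicted y")
plt.ylabel("Residual")
plt.title("Residuals")
plt.savefig("residuals.png")
```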

## 6. Project Structure
- **ElasticNet.py**: Contains the implementation of the ElasticNet regression model.
- **metrics.py**: Provides functions for evaluating model performance (MSE, MAE).
- **generate_regression_data.py**: A script to generate synthetic linear data with noise for testing the model.
- **test_ElasticNetModel.py**: A test script to evaluate the model.
1 change: 1 addition & 0 deletions elasticnet/__init__.py
@@ -0,0 +1 @@
from .models import ElasticNetModel
122 changes: 112 additions & 10 deletions elasticnet/models/ElasticNet.py
@@ -1,17 +1,119 @@
import numpy as np

# Previous placeholder implementation (removed by this change):

class ElasticNetModel():
    def __init__(self):
        pass

    def fit(self, X, y):
        return ElasticNetModelResults()


class ElasticNetModelResults():
    def __init__(self):
        pass

    def predict(self, x):
        return 0.5


# New implementation (added by this change):

class ElasticNetModel:
    def __init__(self, alpha=1.0, rho=0.5, max_iter=1000, tol=1e-4):
        """
        Initialize the ElasticNet model.

        Parameters:
        - alpha: float, regularization strength.
        - rho: float, mixing parameter between L1 and L2 (0 <= rho <= 1).
        - max_iter: int, maximum number of iterations.
        - tol: float, tolerance for convergence.
        """
        self.alpha = alpha
        self.rho = rho
        self.max_iter = max_iter
        self.tol = tol
        self.coef_ = None
        self.intercept_ = None
        self.mean_ = None
        self.scale_ = None

    def fit(self, X, y):
        """
        Fit the ElasticNet model to the data.

        Parameters:
        - X: ndarray of shape (n_samples, n_features)
        - y: ndarray of shape (n_samples,)

        Returns:
        - ElasticNetModelResults: fitted model results containing intercept and coefficients.
        """
        # Standardize features
        self.mean_ = np.mean(X, axis=0)
        self.scale_ = np.std(X, axis=0)
        # To avoid division by zero
        self.scale_[self.scale_ == 0] = 1
        X_std = (X - self.mean_) / self.scale_

        n_samples, n_features = X_std.shape
        X_aug = np.hstack((np.ones((n_samples, 1)), X_std))  # Add intercept
        n_features += 1  # Account for intercept

        self.coef_ = np.zeros(n_features)
        X_squared_sum = np.sum(X_aug ** 2, axis=0)

        for iteration in range(self.max_iter):
            coef_old = self.coef_.copy()
            for j in range(n_features):
                # Compute residual excluding feature j
                residual = y - X_aug @ self.coef_ + self.coef_[j] * X_aug[:, j]

                if j == 0:
                    # Update intercept (no regularization)
                    self.coef_[j] = np.mean(residual)
                else:
                    # Compute rho_alpha and denominator with L2 term
                    rho_alpha = self.alpha * self.rho
                    denominator = X_squared_sum[j] + self.alpha * (1 - self.rho)

                    # Compute raw update
                    ro = np.dot(X_aug[:, j], residual)

                    # Apply soft-thresholding
                    if ro < -rho_alpha:
                        self.coef_[j] = (ro + rho_alpha) / denominator
                    elif ro > rho_alpha:
                        self.coef_[j] = (ro - rho_alpha) / denominator
                    else:
                        self.coef_[j] = 0.0

            # Check for convergence
            if np.sum(np.abs(self.coef_ - coef_old)) < self.tol:
                print(f"Converged in {iteration + 1} iterations.")
                break
        else:
            print(f"Did not converge within {self.max_iter} iterations.")

        # Separate intercept and coefficients
        self.intercept_ = self.coef_[0] - np.sum((self.coef_[1:] * self.mean_) / self.scale_)
        self.coef_ = self.coef_[1:] / self.scale_
        return ElasticNetModelResults(self.intercept_, self.coef_)

    def predict(self, X):
        """
        Predict using the ElasticNet model.

        Parameters:
        - X: ndarray of shape (n_samples, n_features)

        Returns:
        - y_pred: ndarray of shape (n_samples,)
        """
        return X @ self.coef_ + self.intercept_


class ElasticNetModelResults:
    def __init__(self, intercept, coef):
        """
        Store the intercept and coefficients.

        Parameters:
        - intercept: float
        - coef: ndarray of shape (n_features,)
        """
        self.intercept_ = intercept
        self.coef_ = coef

    def predict(self, X):
        """
        Predict using the ElasticNet model.

        Parameters:
        - X: ndarray of shape (n_samples, n_features)

        Returns:
        - y_pred: ndarray of shape (n_samples,)
        """
        return X @ self.coef_ + self.intercept_
1 change: 1 addition & 0 deletions elasticnet/models/__init__.py
@@ -0,0 +1 @@
from .ElasticNet import ElasticNetModel
7 changes: 7 additions & 0 deletions elasticnet/models/metrics.py
@@ -0,0 +1,7 @@
import numpy as np

def mean_squared_error(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

def mean_absolute_error(y_true, y_pred):
    return np.mean(np.abs(y_true - y_pred))
Binary file added elasticnet/models/predictions.png
Binary file added elasticnet/models/residuals.png
101 changes: 101 additions & 0 deletions elasticnet/models/small_test.csv
@@ -0,0 +1,101 @@
x_0,x_1,y
5.479120971119267,-1.2224312049589536,23.946558002744613
7.171958398227649,3.947360581187278,18.581693520022434
-8.11645304224701,9.512447032735118,-38.37476537888327
5.222794039807059,5.721286105539076,9.209465618490574
-7.4377273464890825,-0.9922812420886573,-15.29486210041004
-2.5840395153483753,8.535299776972035,-19.681969913857824
2.877302401613291,6.4552322654165994,0.7305011646968487
-1.1317160234533774,-5.455225564304462,12.579696937576477
1.091695740316696,-8.723654877916493,25.517379766679973
6.552623439851642,2.633287982441297,19.386422514479317
5.161754801707477,-2.9094806374026327,26.219902652898405
9.413960487898066,7.8624224264439535,17.395155304763932
5.567669941475238,-6.1072258429606485,33.829646273654134
-0.6655799254593164,-9.123924684255424,21.217697248062816
-6.914210158649043,3.6609790648490925,-22.972998351409714
4.895243118156342,9.3501946486842,0.8527007853266708
-3.483492837236961,-2.5908058793026223,-0.2658036038461975
-0.6088837744838416,-6.2105728183142865,15.54607736984369
-7.401569893290567,-0.4859014754813251,-16.265674038345246
-5.461813018982317,3.396279893650207,-18.077723061716902
-1.2569616225533853,6.653563921156749,-12.024201166273263
4.005302040044983,-3.7526671723591787,24.65498027559758
6.645196027904021,6.095287149936038,12.729563215914997
-2.250432419396511,-4.233437921395118,6.645984323433632
3.64991007949951,-7.204950327813804,30.33724501243809
-6.001835950497833,-9.85275460497989,6.724251037593403
5.738487550042768,3.2970171318406427,15.639085722292387
4.1033075725267025,5.614580620439358,5.972322669478055
-0.8216844892332009,1.3748239190578748,-0.20565232765256497
-7.204060037446851,-7.709398529280531,-1.1705602207656014
3.368059235809433,-0.5780758771373495,16.51207686545639
1.3047221296237765,5.299977148320512,-1.4981034466415266
2.694366400011816,1.0715880131599165,10.854598838659735
1.184143214908271,-3.920998038747756,16.36568738606541
-9.383643308641211,-1.2656522153527527,-20.765969695401832
-5.708306543609416,-1.829427125507277,-8.525136081210041
7.068061465363322,-5.321210282693185,36.87816546183537
-8.83394516621868,-4.372322159560069,-12.636605817447078
-4.128124844666328,3.238330294537901,-13.933943506849147
1.1406430468255668,5.67796418212827,-2.9994138677866187
3.286270806547751,-1.8722627711985895,18.388609059066567
6.280407693320694,-6.660541601845922,37.146039691599434
-9.54575853732279,-8.199042784487165,-7.345431484179995
4.447187011929006,-0.7624553949722523,19.813527882994915
-6.774564419327964,0.020895502067270755,-15.453170339935191
-6.953757945736632,3.9264075015547206,-23.72351509574459
-1.076874488519386,-2.3795754780703504,6.35275465144688
-3.9697582170424695,2.605651862377769,-12.267282900422046
-2.7637477889321915,-8.24700161367798,13.415684571762215
-7.639881957589694,9.23795329099029,-36.524294712877065
8.171613814152142,3.994142676214991,21.41687753218098
-4.682600770809609,9.383527546954477,-27.63116605350565
5.5750180793158925,4.337803783179912,13.339953388511894
-1.0127699571242266,-4.55516876309682,10.954870991938426
-8.072180756930013,8.052047930876833,-35.35746302822251
-0.8844742033277786,-5.952732704095394,14.286198353301502
-3.88086751698695,1.5843913788379194,-9.63851554419613
-6.464544341215365,7.1322856818475096,-28.756890095183937
5.170390596704202,4.389259119018735,11.708125767480924
-1.3581392044979257,2.546176814048864,-4.08903748398533
1.6819593782547115,2.9969320310964,4.095490680017996
-8.311113577202217,-1.6838519565878074,-16.603252425554047
-9.16771652276215,-0.12018361510962139,-22.276164634518814
-3.4027757533442937,-7.109516222679062,8.873215604488244
-7.931940645548967,1.7528914435542409,-22.32542219815285
-6.588140629262278,8.502402367535943,-31.79586537185962
1.6212227940078994,-3.0626039093032587,16.012093189592843
1.818309829628335,-9.54392257940605,29.487241925815205
9.171184264828906,-0.35393126114199536,33.26856916902523
5.654704545005725,-8.345400001551228,38.75618521990161
-0.26683338323679306,-0.1858601129095816,4.58676300887563
8.756529099499657,1.4345610475215076,28.435640844295875
-0.5302119788609239,-4.660486738162128,12.735653074499272
-3.3686200531489563,0.41344804943075575,-5.932747818999239
-1.2217707938990663,-9.567758402393391,20.3980486197353
6.525838483887156,7.923215436795335,8.76273400423813
-7.195018220027785,1.080722870780988,-18.756229061486682
-7.828485177291129,3.444801860796234,-25.165742422572798
-4.375324323219834,3.1884526938380353,-14.345542867088048
4.539892285737652,5.37294983835314,7.912361835762327
-7.845181080882069,8.320236902752159,-35.252322769120006
-5.395720182102384,-9.251748876476405,7.2050960594473175
1.0970493878296672,-2.581554322751225,13.573371104300339
6.595794862648264,6.165029441286036,12.483600627887435
-3.657222143545693,9.057987901394899,-24.039627893087715
-4.181643237197628,0.3011425846342908,-8.321673479558886
-4.880698188647945,8.720871400979266,-26.99109351974655
-6.707843648359637,-9.10178761215342,3.1254863130493633
-1.2980587999392412,9.847511281116741,-18.700242030492653
7.833545325098279,4.972160389138985,18.509162716271874
7.815849817570497,7.8689327939572635,12.736055585232616
0.37716720772898116,-3.68141896338414,13.499586229790847
5.4402486422197605,3.2332252633552216,14.825078281368484
-2.5268454225257986,-8.110666638769695,13.6304481831554
4.93579222698052,-4.750789681542706,29.28375830620628
8.736263010675584,-5.180588499886305,41.58522228300961
-7.5448413517702795,6.622253442498122,-30.731881743007765
-6.93431366751012,-6.414633836845218,-3.2303391729330535
1.98765583041687,7.49124081674929,-4.043199168698937
-6.071306685708535,-3.7935265419981046,-5.609215730992151