diff --git a/README.md b/README.md index c1e8359..bdf3d70 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,86 @@ -# Project 1 +# Project 1 - ElasticNet Model -Put your README here. Answer the following questions. +Name: Gurjot Singh Kalsi CWID: A20550984 -* What does the model you have implemented do and when should it be used? -* How did you test your model to determine if it is working reasonably correctly? -* What parameters have you exposed to users of your implementation in order to tune performance? (Also perhaps provide some basic usage examples.) -* Are there specific inputs that your implementation has trouble with? Given more time, could you work around these or is it fundamental? +Name: Siva Vamsi Kolli CWID: A20560901 + +Name: Sai Teja Reddy Janga CWID: A20554588 + +This project implements an ElasticNet regression model, a linear regression method that combines both L1 (Lasso) and L2 (Ridge) penalties. The model is particularly useful when dealing with datasets that have multicollinearity (high correlation among features) and when you need to perform feature selection by shrinking some coefficients to zero. + +## 1. What does the model do and when should it be used? + +### What: +The ElasticNet regression model estimates the relationships between a dependent variable (target) and multiple independent variables (features) while applying regularization to prevent overfitting. The model combines the strengths of Lasso and Ridge regression: +- **Lasso** (L1) for feature selection, as it can shrink coefficients of some features to zero. +- **Ridge** (L2) for minimizing the impact of multicollinearity and stabilizing the model. + +### When to Use: +ElasticNet is best used in scenarios where: +- You have a large number of features, many of which might be irrelevant or redundant. +- There is high multicollinearity (i.e., correlation between input features). +- You want a model that can perform both feature selection and regularization. + +It is particularly effective for datasets where neither pure Lasso nor Ridge performs optimally due to the characteristics of the data. + +## 2. How did you test your model to determine if it is working reasonably correctly? + +### How: +The model was tested using the following methods: +1. **Synthetic Data**: Data was generated using a separate script (`generate_regression_data.py`) to create a dataset with known coefficients and noise. The model was then trained on this synthetic data, and the results were compared to the expected coefficients and predictions. +2. **Mean Squared Error (MSE) and Mean Absolute Error (MAE)**: These metrics were computed after training the model, and thresholds were used to validate the performance of the model. For example, the test script checks that the MSE and MAE are below a certain value. +3. **Visual Validation**: The predictions of the model were plotted against actual values to visually inspect how well the model fits the data. Residuals were also plotted to check if they are evenly distributed, indicating a well-fitting model. + +## 3. What parameters have you exposed to users of your implementation in order to tune performance? + +### Exposed Parameters: +The following parameters can be tuned by users to control the model's performance: +- **alpha**: This controls the overall strength of regularization. A higher value applies more regularization (both L1 and L2). +- **rho**: This balances the ratio between L1 (Lasso) and L2 (Ridge) penalties. `rho=0` applies only L2 regularization, `rho=1` applies only L1, and values in between apply both. +- **max_iter**: This defines the maximum number of iterations the model will take to converge during optimization. +- **tol**: The tolerance for optimization. Smaller values can lead to more accurate solutions but may take longer to compute. + +## 4. Are there specific inputs that your implementation has trouble with? + +### Problematic Inputs: +The model might struggle with: + +- **Nonlinear relationships**: Since the model is fundamentally linear, datasets with nonlinear patterns cannot be accurately modeled without feature engineering (e.g., polynomial features or transformations). +- **Highly imbalanced data**: The model assumes that the noise in the data is homoscedastic (having the same variance), and imbalanced data can affect its performance by distorting this assumption. + +### Potential Workarounds: +- **Nonlinearity**: To handle nonlinearity, feature engineering could be applied, or a more complex model like decision trees or neural networks might be used. +- **Imbalanced Data**: Techniques like resampling (oversampling/undersampling) or adjusting model weights could help in mitigating the effects of imbalanced data. + +## 5. Steps to run the code. + +The project includes a script for generating synthetic regression data, which can be used to test the ElasticNet model. The script generates data based on a user-specified linear equation with noise. + +### Command to Generate Data: +```bash + +python3 Project1/generate_regression_data.py -N 100 -m 3 -2 -b 5 -scale 0.1 -rnge -10 10 -seed 42 -output_file Project1/elasticnet/models/small_test.csv + +### Explanation of Arguments: +- `-N 100`: Specifies the number of samples to generate. +- `-m 3 -2`: Defines the slope coefficients for the linear relationship (in this case, two features with slopes of 3 and -2). +- `-b 5`: Sets the intercept (offset) of the linear equation. +- `-scale 0.1`: Adds Gaussian noise with a standard deviation of 0.1. +- `-rnge -10 10`: Specifies the range from which feature values (X) are uniformly sampled. +- `-seed 42`: Sets the random seed for reproducibility. +- `-output_file`: Specifies the path to save the generated dataset as a CSV file. + +### Example Output: +The command will generate a CSV file (`small_test.csv`) with columns for each feature (`x_0`, `x_1`) and a target value (`y`). This file can then be used to train and test the ElasticNet model. + +Run the ElasticNet.py file to train the model. + +Then run the test_ElasticNetModel.py to test the model. + +Results shows MSE and MAE values and regression plots. + +### 6. Project Structure +- **ElasticNet.py**: Contains the implementation of the ElasticNet regression model. +- **metrics.py**: Provides functions for evaluating model performance (MSE, MAE). +- **generate_regression_data.py**: A script to generate synthetic linear data with noise for testing the model. +- **test_ElasticNetModel.py**: A test script to evaluate the model. \ No newline at end of file diff --git a/elasticnet/__init__.py b/elasticnet/__init__.py index e69de29..87c314a 100644 --- a/elasticnet/__init__.py +++ b/elasticnet/__init__.py @@ -0,0 +1 @@ +from models import ElasticNetModel \ No newline at end of file diff --git a/elasticnet/models/ElasticNet.py b/elasticnet/models/ElasticNet.py index 017e925..05d5087 100644 --- a/elasticnet/models/ElasticNet.py +++ b/elasticnet/models/ElasticNet.py @@ -1,17 +1,119 @@ +import numpy as np +class ElasticNetModel: + def __init__(self, alpha=1.0, rho=0.5, max_iter=1000, tol=1e-4): + """ + Initialize the ElasticNet model. -class ElasticNetModel(): - def __init__(self): - pass - + Parameters: + - alpha: float, regularization strength. + - rho: float, mixing parameter between L1 and L2 (0 <= rho <= 1). + - max_iter: int, maximum number of iterations. + - tol: float, tolerance for convergence. + """ + self.alpha = alpha + self.rho = rho + self.max_iter = max_iter + self.tol = tol + self.coef_ = None + self.intercept_ = None + self.mean_ = None + self.scale_ = None def fit(self, X, y): - return ElasticNetModelResults() + """ + Fit the ElasticNet model to the data. + + Parameters: + - X: ndarray of shape (n_samples, n_features) + - y: ndarray of shape (n_samples,) + + Returns: + - ElasticNetModelResults: fitted model results containing intercept and coefficients. + """ + # Standardize features + self.mean_ = np.mean(X, axis=0) + self.scale_ = np.std(X, axis=0) + # To avoid division by zero + self.scale_[self.scale_ == 0] = 1 + X_std = (X - self.mean_) / self.scale_ + + n_samples, n_features = X_std.shape + X_aug = np.hstack((np.ones((n_samples, 1)), X_std)) # Add intercept + n_features += 1 # Account for intercept + + self.coef_ = np.zeros(n_features) + X_squared_sum = np.sum(X_aug ** 2, axis=0) + + for iteration in range(self.max_iter): + coef_old = self.coef_.copy() + for j in range(n_features): + # Compute residual excluding feature j + residual = y - X_aug @ self.coef_ + self.coef_[j] * X_aug[:, j] + + if j == 0: + # Update intercept (no regularization) + self.coef_[j] = np.mean(residual) + else: + # Compute rho_alpha and denominator with L2 term + rho_alpha = self.alpha * self.rho + denominator = X_squared_sum[j] + self.alpha * (1 - self.rho) + + # Compute raw update + ro = np.dot(X_aug[:, j], residual) + + # Apply soft-thresholding + if ro < -rho_alpha: + self.coef_[j] = (ro + rho_alpha) / denominator + elif ro > rho_alpha: + self.coef_[j] = (ro - rho_alpha) / denominator + else: + self.coef_[j] = 0.0 + + # Check for convergence + if np.sum(np.abs(self.coef_ - coef_old)) < self.tol: + print(f"Converged in {iteration + 1} iterations.") + break + else: + print(f"Did not converge within {self.max_iter} iterations.") + + # Separate intercept and coefficients + self.intercept_ = self.coef_[0] - np.sum((self.coef_[1:] * self.mean_) / self.scale_) + self.coef_ = self.coef_[1:] / self.scale_ + return ElasticNetModelResults(self.intercept_, self.coef_) + + def predict(self, X): + """ + Predict using the ElasticNet model. + + Parameters: + - X: ndarray of shape (n_samples, n_features) + + Returns: + - y_pred: ndarray of shape (n_samples,) + """ + return X @ self.coef_ + self.intercept_ + +class ElasticNetModelResults: + def __init__(self, intercept, coef): + """ + Store the intercept and coefficients. + + Parameters: + - intercept: float + - coef: ndarray of shape (n_features,) + """ + self.intercept_ = intercept + self.coef_ = coef + def predict(self, X): + """ + Predict using the ElasticNet model. -class ElasticNetModelResults(): - def __init__(self): - pass + Parameters: + - X: ndarray of shape (n_samples, n_features) - def predict(self, x): - return 0.5 + Returns: + - y_pred: ndarray of shape (n_samples,) + """ + return X @ self.coef_ + self.intercept_ \ No newline at end of file diff --git a/elasticnet/models/__init__.py b/elasticnet/models/__init__.py index e69de29..0650f05 100644 --- a/elasticnet/models/__init__.py +++ b/elasticnet/models/__init__.py @@ -0,0 +1 @@ +from ElasticNet import ElasticNetModel \ No newline at end of file diff --git a/elasticnet/models/metrics.py b/elasticnet/models/metrics.py new file mode 100644 index 0000000..a77ce9c --- /dev/null +++ b/elasticnet/models/metrics.py @@ -0,0 +1,7 @@ +import numpy as np + +def mean_squared_error(y_true, y_pred): + return np.mean((y_true - y_pred) ** 2) + +def mean_absolute_error(y_true, y_pred): + return np.mean(np.abs(y_true - y_pred)) \ No newline at end of file diff --git a/elasticnet/models/predictions.png b/elasticnet/models/predictions.png new file mode 100644 index 0000000..64c1320 Binary files /dev/null and b/elasticnet/models/predictions.png differ diff --git a/elasticnet/models/residuals.png b/elasticnet/models/residuals.png new file mode 100644 index 0000000..94ee3e0 Binary files /dev/null and b/elasticnet/models/residuals.png differ diff --git a/elasticnet/models/small_test.csv b/elasticnet/models/small_test.csv new file mode 100644 index 0000000..14fad1d --- /dev/null +++ b/elasticnet/models/small_test.csv @@ -0,0 +1,101 @@ +x_0,x_1,y +5.479120971119267,-1.2224312049589536,23.946558002744613 +7.171958398227649,3.947360581187278,18.581693520022434 +-8.11645304224701,9.512447032735118,-38.37476537888327 +5.222794039807059,5.721286105539076,9.209465618490574 +-7.4377273464890825,-0.9922812420886573,-15.29486210041004 +-2.5840395153483753,8.535299776972035,-19.681969913857824 +2.877302401613291,6.4552322654165994,0.7305011646968487 +-1.1317160234533774,-5.455225564304462,12.579696937576477 +1.091695740316696,-8.723654877916493,25.517379766679973 +6.552623439851642,2.633287982441297,19.386422514479317 +5.161754801707477,-2.9094806374026327,26.219902652898405 +9.413960487898066,7.8624224264439535,17.395155304763932 +5.567669941475238,-6.1072258429606485,33.829646273654134 +-0.6655799254593164,-9.123924684255424,21.217697248062816 +-6.914210158649043,3.6609790648490925,-22.972998351409714 +4.895243118156342,9.3501946486842,0.8527007853266708 +-3.483492837236961,-2.5908058793026223,-0.2658036038461975 +-0.6088837744838416,-6.2105728183142865,15.54607736984369 +-7.401569893290567,-0.4859014754813251,-16.265674038345246 +-5.461813018982317,3.396279893650207,-18.077723061716902 +-1.2569616225533853,6.653563921156749,-12.024201166273263 +4.005302040044983,-3.7526671723591787,24.65498027559758 +6.645196027904021,6.095287149936038,12.729563215914997 +-2.250432419396511,-4.233437921395118,6.645984323433632 +3.64991007949951,-7.204950327813804,30.33724501243809 +-6.001835950497833,-9.85275460497989,6.724251037593403 +5.738487550042768,3.2970171318406427,15.639085722292387 +4.1033075725267025,5.614580620439358,5.972322669478055 +-0.8216844892332009,1.3748239190578748,-0.20565232765256497 +-7.204060037446851,-7.709398529280531,-1.1705602207656014 +3.368059235809433,-0.5780758771373495,16.51207686545639 +1.3047221296237765,5.299977148320512,-1.4981034466415266 +2.694366400011816,1.0715880131599165,10.854598838659735 +1.184143214908271,-3.920998038747756,16.36568738606541 +-9.383643308641211,-1.2656522153527527,-20.765969695401832 +-5.708306543609416,-1.829427125507277,-8.525136081210041 +7.068061465363322,-5.321210282693185,36.87816546183537 +-8.83394516621868,-4.372322159560069,-12.636605817447078 +-4.128124844666328,3.238330294537901,-13.933943506849147 +1.1406430468255668,5.67796418212827,-2.9994138677866187 +3.286270806547751,-1.8722627711985895,18.388609059066567 +6.280407693320694,-6.660541601845922,37.146039691599434 +-9.54575853732279,-8.199042784487165,-7.345431484179995 +4.447187011929006,-0.7624553949722523,19.813527882994915 +-6.774564419327964,0.020895502067270755,-15.453170339935191 +-6.953757945736632,3.9264075015547206,-23.72351509574459 +-1.076874488519386,-2.3795754780703504,6.35275465144688 +-3.9697582170424695,2.605651862377769,-12.267282900422046 +-2.7637477889321915,-8.24700161367798,13.415684571762215 +-7.639881957589694,9.23795329099029,-36.524294712877065 +8.171613814152142,3.994142676214991,21.41687753218098 +-4.682600770809609,9.383527546954477,-27.63116605350565 +5.5750180793158925,4.337803783179912,13.339953388511894 +-1.0127699571242266,-4.55516876309682,10.954870991938426 +-8.072180756930013,8.052047930876833,-35.35746302822251 +-0.8844742033277786,-5.952732704095394,14.286198353301502 +-3.88086751698695,1.5843913788379194,-9.63851554419613 +-6.464544341215365,7.1322856818475096,-28.756890095183937 +5.170390596704202,4.389259119018735,11.708125767480924 +-1.3581392044979257,2.546176814048864,-4.08903748398533 +1.6819593782547115,2.9969320310964,4.095490680017996 +-8.311113577202217,-1.6838519565878074,-16.603252425554047 +-9.16771652276215,-0.12018361510962139,-22.276164634518814 +-3.4027757533442937,-7.109516222679062,8.873215604488244 +-7.931940645548967,1.7528914435542409,-22.32542219815285 +-6.588140629262278,8.502402367535943,-31.79586537185962 +1.6212227940078994,-3.0626039093032587,16.012093189592843 +1.818309829628335,-9.54392257940605,29.487241925815205 +9.171184264828906,-0.35393126114199536,33.26856916902523 +5.654704545005725,-8.345400001551228,38.75618521990161 +-0.26683338323679306,-0.1858601129095816,4.58676300887563 +8.756529099499657,1.4345610475215076,28.435640844295875 +-0.5302119788609239,-4.660486738162128,12.735653074499272 +-3.3686200531489563,0.41344804943075575,-5.932747818999239 +-1.2217707938990663,-9.567758402393391,20.3980486197353 +6.525838483887156,7.923215436795335,8.76273400423813 +-7.195018220027785,1.080722870780988,-18.756229061486682 +-7.828485177291129,3.444801860796234,-25.165742422572798 +-4.375324323219834,3.1884526938380353,-14.345542867088048 +4.539892285737652,5.37294983835314,7.912361835762327 +-7.845181080882069,8.320236902752159,-35.252322769120006 +-5.395720182102384,-9.251748876476405,7.2050960594473175 +1.0970493878296672,-2.581554322751225,13.573371104300339 +6.595794862648264,6.165029441286036,12.483600627887435 +-3.657222143545693,9.057987901394899,-24.039627893087715 +-4.181643237197628,0.3011425846342908,-8.321673479558886 +-4.880698188647945,8.720871400979266,-26.99109351974655 +-6.707843648359637,-9.10178761215342,3.1254863130493633 +-1.2980587999392412,9.847511281116741,-18.700242030492653 +7.833545325098279,4.972160389138985,18.509162716271874 +7.815849817570497,7.8689327939572635,12.736055585232616 +0.37716720772898116,-3.68141896338414,13.499586229790847 +5.4402486422197605,3.2332252633552216,14.825078281368484 +-2.5268454225257986,-8.110666638769695,13.6304481831554 +4.93579222698052,-4.750789681542706,29.28375830620628 +8.736263010675584,-5.180588499886305,41.58522228300961 +-7.5448413517702795,6.622253442498122,-30.731881743007765 +-6.93431366751012,-6.414633836845218,-3.2303391729330535 +1.98765583041687,7.49124081674929,-4.043199168698937 +-6.071306685708535,-3.7935265419981046,-5.609215730992151 diff --git a/elasticnet/models/test_ElasticNetModel.py b/elasticnet/models/test_ElasticNetModel.py new file mode 100644 index 0000000..2b492dd --- /dev/null +++ b/elasticnet/models/test_ElasticNetModel.py @@ -0,0 +1,74 @@ +import csv +import numpy as np +import matplotlib.pyplot as plt +from ElasticNet import ElasticNetModel +from metrics import mean_squared_error, mean_absolute_error + +# Load the data from a CSV file +def load_data(filename): + data = [] + with open(filename, "r") as file: + reader = csv.DictReader(file) + for row in reader: + data.append(row) + X = np.array([[float(v) for k, v in datum.items() if k.startswith('x')] for datum in data]) + y = np.array([float(datum['y']) for datum in data]) + return X, y + +# Plot predictions vs actual values for regression +def plot_predictions(y_true, y_pred, filename='predictions.png'): + plt.figure(figsize=(8, 6)) + colors = np.abs(y_true - y_pred) + plt.scatter(y_true, y_pred, c=colors, cmap='viridis', alpha=0.7) + plt.colorbar(label="Difference (|y_true - y_pred|)") + plt.xlabel('Actual Values') + plt.ylabel('Predicted Values') + plt.title('Actual vs Predicted Values') + plt.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--') + plt.savefig(filename) + plt.close() + +# Plot residuals for regression +def plot_residuals(y_true, y_pred, filename='residuals.png'): + residuals = y_true - y_pred + plt.figure(figsize=(8, 6)) + colors = residuals + plt.scatter(y_pred, residuals, c=colors, cmap='coolwarm', alpha=0.7) + plt.colorbar(label="Residuals") + plt.xlabel('Predicted Values') + plt.ylabel('Residuals') + plt.title('Residuals vs Predicted Values') + plt.axhline(0, color='r', linestyle='--') + plt.savefig(filename) + plt.close() + +# Test the ElasticNet model's prediction +def test_predict(task='regression'): + # Initialize the model with values + model = ElasticNetModel(alpha=0.1, rho=0.5, max_iter=1000, tol=1e-4) + + # Load data + X, y = load_data("/Users/gurjotsinghkalsi/Desktop/Fall2024/MachineLearning/Project1/elasticnet/models/small_test.csv") + + # Fit the model and get predictions + results = model.fit(X, y) + preds = results.predict(X) + + if task == 'regression': + # Calculate metrics for regression + mse = mean_squared_error(y, preds) + mae = mean_absolute_error(y, preds) + print(f"MSE: {mse}") + print(f"MAE: {mae}") + + # Assert that MSE and MAE are below certain thresholds + assert mse < 1.0, f"MSE too high: {mse}" + assert mae < 1.0, f"MAE too high: {mae}" + + # Plot regression results + plot_predictions(y, preds, filename='/Users/gurjotsinghkalsi/Desktop/Fall2024/MachineLearning/Project1/elasticnet/models/predictions.png') + plot_residuals(y, preds, filename='/Users/gurjotsinghkalsi/Desktop/Fall2024/MachineLearning/Project1/elasticnet/models/residuals.png') + print("Regression plots saved: predictions.png and residuals.png") + +if __name__ == "__main__": + test_predict(task='regression') diff --git a/elasticnet/tests/small_test.csv b/elasticnet/tests/small_test.csv deleted file mode 100644 index bf8442e..0000000 --- a/elasticnet/tests/small_test.csv +++ /dev/null @@ -1,51 +0,0 @@ -x_0,x_1,x_2,y --2.421348566501347,6.290215260063935,2.516304163087373,10.240119830146476 -8.13465811997068,-6.975968662410185,-3.2810945459842866,-6.8962940548446845 --0.4531238994261493,0.05889462611191654,-3.592293253611172,14.10428803155231 -3.979832584128687,-8.129001764124755,9.202914789330517,-43.788867687445624 --4.354231825431758,2.4724749171156333,8.45972163584499,-12.067617018047834 -8.726620980175113,-9.607722575405269,-5.092837184080405,-8.265643240683891 --0.29136484802189955,8.224663789274086,-3.8193339707565555,32.98185595386334 -1.4118708853910462,6.003042800612462,3.9968255952773095,0.7267789346532836 -0.21525181834957507,-3.321041549359367,-5.352746248495515,11.93444109619503 -4.80226153299567,9.818246112545182,4.936296097738831,3.5995719453822046 -9.71733974143089,0.1440918710436101,8.74993701189404,-34.917122745540794 -4.098687611436789,-9.75205878861841,7.980744101999381,-43.32805584620358 --2.398060521804659,2.8278192128541733,-1.626174948927721,16.91539285950553 -5.398272903061114,7.583046908728093,2.758295974535457,4.437457748228852 -3.371527871466675,-5.430064318728407,2.1915998058530857,-16.03565826569788 -2.0863644528269365,0.10824916542728857,8.144465640869694,-25.094326089867696 -2.8255940202840684,-2.286321234798363,4.771241059098381,-18.000440202657604 --8.150227640024978,-4.259315052105519,1.8923353680502952,-1.3930242667026356 --6.067265316809651,3.6776254617776942,8.4817269440159,-10.278522746897893 -8.64017362219969,9.717801217085075,4.980672567111553,-0.9266647796977245 --4.636910653452324,0.9373715699813872,4.978170771263397,-3.8217233698137143 --7.940395120999431,2.953441321061362,-0.9370552302607145,21.291726783530805 -7.692709298116139,-5.485844206553388,-6.019643260327971,2.1873435652525455 --6.485086441297707,7.06589989184231,-8.842925435171665,50.35981404591074 -5.036321300769028,2.0420739888497152,-4.368234397412891,15.435100617505809 --2.203566631709222,-6.141030616852454,-1.822186931753599,-0.5890454529472771 -3.2620868350599768,7.851306022896178,-4.479265977335616,27.896949611024628 -6.402611257683294,-4.018677430646336,0.48600102750762986,-12.289355696825485 -5.378501224056757,4.355667003325474,-7.565417868242747,31.017195148404717 -2.0486633392332614,8.253411759540757,-3.966950647644751,29.555547834722987 -2.626017326894857,3.314924154867276,9.810418858378235,-22.85112181951592 --0.04750452520510429,5.935777040113393,-0.3470621837504506,16.516617979443822 --6.775500897482147,-0.8747563332852692,-2.758815934335188,16.55155644731519 --5.130765599150095,8.959898235120185,1.1701541118251235,22.753375944830324 -9.607901921761815,-9.108821424255002,5.524296399378377,-41.93781490943017 --2.9201254899877434,5.134928295361929,-9.896226148902585,43.58829658171542 -6.956501039100711,0.8359369151964895,-6.1636372998431295,16.225403196517274 -7.725179239543149,-4.913104095867496,-1.110476120153832,-9.936035489824537 --6.142683379729563,1.4244393989902058,1.8529074318076262,5.554396424524908 --2.0474061706133977,-1.2170618863263076,8.899325908803291,-23.596187786238964 -9.359523403637155,3.4124788823300065,-1.4222946765509725,2.4507844709064064 --8.642800876507275,-9.508822574677566,2.9901775243378577,-16.775543378589024 --2.470992582133973,5.1672327675732195,-8.753045094764744,40.855147394263106 --7.756097982925145,5.227601844332813,-3.179199348468109,30.739018818654756 -5.393783291304004,-1.5186710515725927,-7.469139234639499,17.503383657767756 --7.644671911438172,1.8115363641056241,-6.167155079348694,33.57677356652164 -6.557442460132911,-4.44188855380612,-6.368621306151785,7.435670420087931 -0.21009363927752744,-2.719754693698011,1.0885820356480096,-6.289562485886653 --8.571672299069252,8.890348599509473,5.468260371802332,15.412904086362603 -7.872454219630789,-3.9905860234116357,0.9068940749874717,-16.017543419998542 diff --git a/elasticnet/tests/test_ElasticNetModel.py b/elasticnet/tests/test_ElasticNetModel.py deleted file mode 100644 index 5022c3c..0000000 --- a/elasticnet/tests/test_ElasticNetModel.py +++ /dev/null @@ -1,19 +0,0 @@ -import csv - -import numpy - -from elasticnet.models.ElasticNet import ElasticNetModel - -def test_predict(): - model = ElasticNetModel() - data = [] - with open("small_test.csv", "r") as file: - reader = csv.DictReader(file) - for row in reader: - data.append(row) - - X = numpy.array([[v for k,v in datum.items() if k.startswith('x')] for datum in data]) - y = numpy.array([[v for k,v in datum.items() if k=='y'] for datum in data]) - results = model.fit(X,y) - preds = results.predict(X) - assert preds == 0.5