diff --git a/Data_Generator.bat b/Data_Generator.bat new file mode 100644 index 0000000..bb95164 --- /dev/null +++ b/Data_Generator.bat @@ -0,0 +1,2 @@ +python generate_regression_data.py -N 100 -m 2.5 1.5 -b 0.5 -scale 1.0 -rnge -10 10 -seed 42 -output_file train_data.csv +python generate_regression_data.py -N 30 -m 2.5 1.5 -b 0.5 -scale 1.0 -rnge -10 10 -seed 42 -output_file test_data.csv diff --git a/Data_Generator.sh b/Data_Generator.sh new file mode 100644 index 0000000..bb95164 --- /dev/null +++ b/Data_Generator.sh @@ -0,0 +1,2 @@ +python generate_regression_data.py -N 100 -m 2.5 1.5 -b 0.5 -scale 1.0 -rnge -10 10 -seed 42 -output_file train_data.csv +python generate_regression_data.py -N 30 -m 2.5 1.5 -b 0.5 -scale 1.0 -rnge -10 10 -seed 42 -output_file test_data.csv diff --git a/README.md b/README.md index c1e8359..a7b08a1 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,41 @@ -# Project 1 +## Group Member(s) +ZIRUI OU A20516756 -Put your README here. Answer the following questions. +### What does the model do and when should it be used? -* What does the model you have implemented do and when should it be used? -* How did you test your model to determine if it is working reasonably correctly? -* What parameters have you exposed to users of your implementation in order to tune performance? (Also perhaps provide some basic usage examples.) -* Are there specific inputs that your implementation has trouble with? Given more time, could you work around these or is it fundamental? +The ElasticNetModel is a type of linear regression that combines L1 and L2 regularization. It's ideal for predicting continuous outcomes, especially useful in datasets with irrelevant features or highly correlated features. + +This model should be used when you suspect that your data contains irrelevant features or when the features are highly correlated. It's particularly useful in scenarios where you need a model that's robust against issues like multicollinearity (where independent variables are correlated) and when you want to prevent overfitting in your predictive model. + +### How did you test your model to determine if it is working reasonably correctly? +The script tests the ElasticNet model by training it on a set of training data, making predictions on a separate test dataset, and then calculating the Mean Squared Error (MSE) between the predicted and actual values to assess accuracy. It ensures the MSE is below a threshold of 1 to verify the model's performance. + + +### What parameters are exposed to users to tune performance? + +- **`lr` (Learning Rate):** Controls the update magnitude of model coefficients. +- **`n_iter` (Number of Iterations):** Determines how many times the model will process the entire dataset. +- **`l1_ratio` (L1 Ratio):** Balances between L1 and L2 regularization. +- **`alpha` (Regularization Strength):** Adjusts the overall strength of the regularization. + +#### Basic Usage Example +```python + +model = ElasticNetModel(lr=0.01, n_iter=1000, l1_ratio=0.5, alpha=1.0) +model.fit(X_train, y_train) +predictions = model.predict(X_test) +print(predictions) + +``` + +### Are there specific inputs that your implementation has trouble with? +Yes, the model struggles with non-numeric data, missing values, due to its basic implementation. + +### Given more time, could these issues be worked around? +Yes, with more time, enhancements like automatic handling of non-numeric data and missing values, could be implemented to make the model more robust and efficient. + + +### Before you RUN: +1. please using `pip install numba numpy` to install numba and numpy before run it. +1. And make sure `test_data.csv` and `train_data.csv` are in the correct location, if not there, use one of the `Data_Generator` scripts to generate it according to the platform you are using.. +2. Now you should ready to run the test program using `python elasticnet\tests\test_ElasticNetModel.py`. diff --git a/elasticnet/models/ElasticNet.py b/elasticnet/models/ElasticNet.py index 017e925..5184b57 100644 --- a/elasticnet/models/ElasticNet.py +++ b/elasticnet/models/ElasticNet.py @@ -1,17 +1,67 @@ +import numpy as np +from numba import jit +from typing import Tuple +class ElasticNetModel: + def __init__( + self, + learning_rate: float = 0.01, + iterations: int = 1000, + l1_ratio: float = 0.5, + alpha: float = 1.0) -> None: + + self.learning_rate = learning_rate + self.iterations = iterations + self.l1_ratio = l1_ratio + self.alpha = alpha + self.weights = np.empty(0) + self.bias = 0.0 -class ElasticNetModel(): - def __init__(self): - pass + def fit( + self, + features: np.ndarray, + target: np.ndarray) -> None: + + num_samples, num_features = features.shape + self.weights = np.zeros(num_features) + self.bias = 0.0 + self.weights, self.bias = self._optimize( + features, + target, + self.weights, + self.bias, + self.learning_rate, + self.iterations, + self.alpha, + self.l1_ratio, + num_samples) - def fit(self, X, y): - return ElasticNetModelResults() + @staticmethod + @jit(nopython=True, nogil=True) + def _optimize( + features: np.ndarray, + target: np.ndarray, + weights: np.ndarray, + bias: float, + learning_rate: float, + iterations: int, + alpha: float, + l1_ratio: float, + num_samples: int) -> Tuple[np.ndarray, float]: + for _ in range(iterations): + predictions = np.dot(features, weights) + bias + errors = predictions - target -class ElasticNetModelResults(): - def __init__(self): - pass + l2_gradient = 2 * weights + l1_gradient = np.sign(weights) - def predict(self, x): - return 0.5 + weights -= learning_rate * ((1 / num_samples) * np.dot(features.T, errors) + alpha * ((1 - l1_ratio) * l2_gradient + l1_ratio * l1_gradient)) + bias -= learning_rate * (1 / num_samples) * np.sum(errors) + + return weights, bias + + def predict(self, features: np.ndarray) -> np.ndarray: + predictions = np.dot(features, self.weights) + self.bias + return predictions diff --git a/elasticnet/tests/small_test.csv b/elasticnet/tests/small_test.csv deleted file mode 100644 index bf8442e..0000000 --- a/elasticnet/tests/small_test.csv +++ /dev/null @@ -1,51 +0,0 @@ -x_0,x_1,x_2,y --2.421348566501347,6.290215260063935,2.516304163087373,10.240119830146476 -8.13465811997068,-6.975968662410185,-3.2810945459842866,-6.8962940548446845 --0.4531238994261493,0.05889462611191654,-3.592293253611172,14.10428803155231 -3.979832584128687,-8.129001764124755,9.202914789330517,-43.788867687445624 --4.354231825431758,2.4724749171156333,8.45972163584499,-12.067617018047834 -8.726620980175113,-9.607722575405269,-5.092837184080405,-8.265643240683891 --0.29136484802189955,8.224663789274086,-3.8193339707565555,32.98185595386334 -1.4118708853910462,6.003042800612462,3.9968255952773095,0.7267789346532836 -0.21525181834957507,-3.321041549359367,-5.352746248495515,11.93444109619503 -4.80226153299567,9.818246112545182,4.936296097738831,3.5995719453822046 -9.71733974143089,0.1440918710436101,8.74993701189404,-34.917122745540794 -4.098687611436789,-9.75205878861841,7.980744101999381,-43.32805584620358 --2.398060521804659,2.8278192128541733,-1.626174948927721,16.91539285950553 -5.398272903061114,7.583046908728093,2.758295974535457,4.437457748228852 -3.371527871466675,-5.430064318728407,2.1915998058530857,-16.03565826569788 -2.0863644528269365,0.10824916542728857,8.144465640869694,-25.094326089867696 -2.8255940202840684,-2.286321234798363,4.771241059098381,-18.000440202657604 --8.150227640024978,-4.259315052105519,1.8923353680502952,-1.3930242667026356 --6.067265316809651,3.6776254617776942,8.4817269440159,-10.278522746897893 -8.64017362219969,9.717801217085075,4.980672567111553,-0.9266647796977245 --4.636910653452324,0.9373715699813872,4.978170771263397,-3.8217233698137143 --7.940395120999431,2.953441321061362,-0.9370552302607145,21.291726783530805 -7.692709298116139,-5.485844206553388,-6.019643260327971,2.1873435652525455 --6.485086441297707,7.06589989184231,-8.842925435171665,50.35981404591074 -5.036321300769028,2.0420739888497152,-4.368234397412891,15.435100617505809 --2.203566631709222,-6.141030616852454,-1.822186931753599,-0.5890454529472771 -3.2620868350599768,7.851306022896178,-4.479265977335616,27.896949611024628 -6.402611257683294,-4.018677430646336,0.48600102750762986,-12.289355696825485 -5.378501224056757,4.355667003325474,-7.565417868242747,31.017195148404717 -2.0486633392332614,8.253411759540757,-3.966950647644751,29.555547834722987 -2.626017326894857,3.314924154867276,9.810418858378235,-22.85112181951592 --0.04750452520510429,5.935777040113393,-0.3470621837504506,16.516617979443822 --6.775500897482147,-0.8747563332852692,-2.758815934335188,16.55155644731519 --5.130765599150095,8.959898235120185,1.1701541118251235,22.753375944830324 -9.607901921761815,-9.108821424255002,5.524296399378377,-41.93781490943017 --2.9201254899877434,5.134928295361929,-9.896226148902585,43.58829658171542 -6.956501039100711,0.8359369151964895,-6.1636372998431295,16.225403196517274 -7.725179239543149,-4.913104095867496,-1.110476120153832,-9.936035489824537 --6.142683379729563,1.4244393989902058,1.8529074318076262,5.554396424524908 --2.0474061706133977,-1.2170618863263076,8.899325908803291,-23.596187786238964 -9.359523403637155,3.4124788823300065,-1.4222946765509725,2.4507844709064064 --8.642800876507275,-9.508822574677566,2.9901775243378577,-16.775543378589024 --2.470992582133973,5.1672327675732195,-8.753045094764744,40.855147394263106 --7.756097982925145,5.227601844332813,-3.179199348468109,30.739018818654756 -5.393783291304004,-1.5186710515725927,-7.469139234639499,17.503383657767756 --7.644671911438172,1.8115363641056241,-6.167155079348694,33.57677356652164 -6.557442460132911,-4.44188855380612,-6.368621306151785,7.435670420087931 -0.21009363927752744,-2.719754693698011,1.0885820356480096,-6.289562485886653 --8.571672299069252,8.890348599509473,5.468260371802332,15.412904086362603 -7.872454219630789,-3.9905860234116357,0.9068940749874717,-16.017543419998542 diff --git a/elasticnet/tests/test_ElasticNetModel.py b/elasticnet/tests/test_ElasticNetModel.py index 5022c3c..de6cd69 100644 --- a/elasticnet/tests/test_ElasticNetModel.py +++ b/elasticnet/tests/test_ElasticNetModel.py @@ -1,19 +1,45 @@ import csv +import numpy as np +import sys +import os -import numpy +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(os.path.dirname(current_dir)) +sys.path.append(project_root) from elasticnet.models.ElasticNet import ElasticNetModel -def test_predict(): - model = ElasticNetModel() +def load_data(filepath): data = [] - with open("small_test.csv", "r") as file: + with open(filepath, "r") as file: reader = csv.DictReader(file) for row in reader: data.append(row) + X = np.array([[float(v) for k, v in datum.items() if k.startswith('x')] for datum in data]) + y = np.array([float(datum['y']) for datum in data]) + return X, y + +def test_predict(): + model = ElasticNetModel() + + train_X, train_y = load_data(os.path.join(project_root, 'train_data.csv')) + + test_X, test_y = load_data(os.path.join(project_root, 'test_data.csv')) + + model.fit(train_X, train_y) + + preds = model.predict(test_X) + #print(f"prediction:\n {preds}") + + mse = np.mean((preds - test_y) ** 2) + #print(mse) + assert mse < 1 + + print("Actual\tPredicted\tAbs Error") + + [print(f"{a}\t{p}\t{abs(a-p)}") for a, p in zip(test_y, preds)] + + print(f"MSE : {mse}") - X = numpy.array([[v for k,v in datum.items() if k.startswith('x')] for datum in data]) - y = numpy.array([[v for k,v in datum.items() if k=='y'] for datum in data]) - results = model.fit(X,y) - preds = results.predict(X) - assert preds == 0.5 +if __name__ == "__main__": + test_predict() diff --git a/regularized_discriminant_analysis/models/RegularizedDiscriminantAnalysis.py b/regularized_discriminant_analysis/models/RegularizedDiscriminantAnalysis.py deleted file mode 100644 index 089f9ad..0000000 --- a/regularized_discriminant_analysis/models/RegularizedDiscriminantAnalysis.py +++ /dev/null @@ -1,17 +0,0 @@ - - -class RDAModel(): - def __init__(self): - pass - - - def fit(self, X, y): - return RDAModelResults() - - -class RDAModelResults(): - def __init__(self): - pass - - def predict(self, x): - return 0.5 diff --git a/regularized_discriminant_analysis/test_rdamodel.py b/regularized_discriminant_analysis/test_rdamodel.py deleted file mode 100644 index 095725b..0000000 --- a/regularized_discriminant_analysis/test_rdamodel.py +++ /dev/null @@ -1,19 +0,0 @@ -import csv - -import numpy - -from regularized_discriminant_analysis.models.RegularizedDiscriminantAnalysis import RDAModel - -def test_predict(): - model = ElasticNetModel() - data = [] - with open("small_sample.csv", "r") as file: - reader = csv.DictReader(file) - for row in reader: - data.append(row) - - X = numpy.array([[v for k,v in datum.items() if k.startswith('x')] for datum in data]) - y = numpy.array([[v for k,v in datum.items() if k=='y'] for datum in data]) - results = model.fit(X,y) - preds = results.predict(X) - assert preds == 0.5 diff --git a/requirements.txt b/requirements.txt index 18af45d..fdd2127 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ numpy pytest ipython +numba +scipy \ No newline at end of file diff --git a/test_data.csv b/test_data.csv new file mode 100644 index 0000000..e27a49e --- /dev/null +++ b/test_data.csv @@ -0,0 +1,31 @@ +x_0,x_1,y +5.479120971119267,-1.2224312049589532,11.498324504666494 +7.171958398227648,3.9473605811872776,25.319215221941516 +-8.11645304224701,9.512447032735118,-7.205331828130651 +5.222794039807059,5.721286105539075,21.804029227840488 +-7.4377273464890825,-0.9922812420886569,-19.41998716425069 +-2.5840395153483753,8.535299776972035,7.429073208446392 +2.8773024016132904,6.455232265416598,18.087330981950977 +-1.131716023453377,-5.455225564304462,-9.718781169890208 +1.0916957403166965,-8.723654877916493,-10.204968038331437 +6.55262343985164,2.6332879824412974,20.36913878062648 +5.1617548017074775,-2.9094806374026323,9.8981419294219 +9.413960487898066,7.862422426443953,35.63723053452948 +5.567669941475238,-6.1072258429606485,3.9826497659091995 +-0.6655799254593155,-9.123924684255424,-15.983124054034906 +-6.914210158649043,3.6609790648490925,-12.21350908535058 +4.895243118156342,9.350194648684202,27.260560512470924 +-3.4834928372369607,-2.5908058793026223,-11.952515175975769 +-0.6088837744838411,-6.2105728183142865,-9.647583309613264 +-7.401569893290567,-0.48590147548132556,-19.16002959278494 +-5.461813018982317,3.396279893650206,-7.901573015903769 +-1.2569616225533853,6.653563921156749,7.9635322193189975 +4.005302040044983,-3.7526671723591782,4.574907801853451 +6.645196027904021,6.095287149936038,26.71269603222152 +-2.2504324193965104,-4.233437921395118,-12.138163871650603 +3.6499100794995094,-7.204950327813804,-1.5457041395370035 +-6.001835950497833,-9.85275460497989,-29.665459677712747 +5.738487550042768,3.2970171318406436,18.595904927278845 +4.1033075725267025,5.614580620439359,19.667112342761378 +-0.8216844892332009,1.3748239190578744,0.0386223153010854 +-7.204060037446851,-7.709398529280531,-29.061753768810235 diff --git a/train_data.csv b/train_data.csv new file mode 100644 index 0000000..f9990e4 --- /dev/null +++ b/train_data.csv @@ -0,0 +1,101 @@ +x_0,x_1,y +5.479120971119267,-1.2224312049589532,13.00748241504878 +7.171958398227648,3.9473605811872776,23.956331744490445 +-8.11645304224701,9.512447032735118,-5.527583923234919 +5.222794039807059,5.721286105539075,21.97547135930175 +-7.4377273464890825,-0.9922812420886569,-19.24516568055676 +-2.5840395153483753,8.535299776972035,8.25033273840083 +2.8773024016132904,6.455232265416598,17.46668930905987 +-1.131716023453377,-5.455225564304462,-9.868189611813277 +1.0916957403166965,-8.723654877916493,-11.906415067114022 +6.55262343985164,2.6332879824412974,20.78277217136093 +5.1617548017074775,-2.9094806374026323,8.196935777871875 +9.413960487898066,7.862422426443953,34.609721798987465 +5.567669941475238,-6.1072258429606485,4.380183722318371 +-0.6655799254593155,-9.123924684255424,-15.183960280732238 +-6.914210158649043,3.6609790648490925,-10.378154256992955 +4.895243118156342,9.350194648684202,25.437007050677593 +-3.4834928372369607,-2.5908058793026223,-12.064309419451916 +-0.6088837744838411,-6.2105728183142865,-10.82223809701461 +-7.401569893290567,-0.48590147548132556,-19.06045004081037 +-5.461813018982317,3.396279893650206,-7.057354881675879 +-1.2569616225533853,6.653563921156749,7.8760572623555865 +4.005302040044983,-3.7526671723591782,6.2216524490164336 +6.645196027904021,6.095287149936038,26.101415115414206 +-2.2504324193965104,-4.233437921395118,-12.172180542254655 +3.6499100794995094,-7.204950327813804,-1.4065091098524314 +-6.001835950497833,-9.85275460497989,-29.041224992443194 +5.738487550042768,3.2970171318406436,19.968317931321597 +4.1033075725267025,5.614580620439359,18.09575178974243 +-0.8216844892332009,1.3748239190578744,0.5985144371316835 +-7.204060037446851,-7.709398529280531,-28.846019557399018 +3.368059235809433,-0.5780758771373495,10.57050831135148 +1.3047221296237765,5.299977148320512,13.58861565782188 +2.694366400011816,1.071588013159916,7.990054669210593 +1.1841432149082713,-3.920998038747756,-2.708522382400132 +-9.383643308641211,-1.2656522153527519,-26.32102859646916 +-5.708306543609416,-1.8294271255072765,-17.10561406124794 +7.068061465363321,-5.321210282693185,10.503943242958867 +-8.83394516621868,-4.372322159560069,-26.93749253399857 +-4.128124844666328,3.238330294537901,-5.691900507602576 +1.1406430468255664,5.677964182128271,11.214407450188526 +3.2862708065477513,-1.8722627711985886,3.7599938298328404 +6.280407693320694,-6.660541601845922,6.047540909987943 +-9.54575853732279,-8.199042784487165,-36.725374931897285 +4.447187011929007,-0.7624553949722532,9.944845009998064 +-6.774564419327964,0.020895502067270755,-17.281928573386594 +-6.953757945736632,3.9264075015547206,-11.089046166262056 +-1.0768744885193868,-2.3795754780703504,-7.5192778297606235 +-3.96975821704247,2.605651862377769,-6.982962994430512 +-2.763747788932191,-8.24700161367798,-16.65062478081915 +-7.639881957589694,9.23795329099029,-6.030197538762829 +8.171613814152142,3.9941426762149916,25.8234629712482 +-4.682600770809609,9.383527546954479,4.70570292172901 +5.5750180793158925,4.337803783179911,23.849318042300006 +-1.0127699571242275,-4.55516876309682,-10.03624466628114 +-8.072180756930013,8.052047930876832,-7.970628952797816 +-0.8844742033277786,-5.952732704095394,-10.298729013522028 +-3.8808675169869495,1.5843913788379194,-5.096884079804902 +-6.464544341215365,7.1322856818475096,-5.949789408695387 +5.170390596704202,4.389259119018735,19.764587324346497 +-1.3581392044979257,2.546176814048863,1.7012547858902243 +1.681959378254712,2.9969320310963994,9.635062566747997 +-8.311113577202217,-1.683851956587807,-23.179717949117354 +-9.16771652276215,-0.12018361510962094,-22.73338969408585 +-3.4027757533442937,-7.109516222679062,-20.04610952574931 +-7.931940645548967,1.7528914435542404,-16.938688192515723 +-6.588140629262278,8.502402367535943,-3.4831355118607354 +1.621222794007899,-3.0626039093032587,0.19132101069111573 +1.818309829628335,-9.54392257940605,-9.825436513857255 +9.171184264828906,-0.35393126114199447,23.368602292904413 +5.654704545005725,-8.345400001551228,3.1313771780072974 +-0.2668333832367935,-0.18586011290958204,-0.29044429978789077 +8.756529099499659,1.4345610475215071,24.894920728430613 +-0.5302119788609243,-4.660486738162128,-7.763104706817636 +-3.3686200531489563,0.4134480494307553,-7.3012936656348435 +-1.2217707938990667,-9.567758402393391,-17.6276226218806 +6.525838483887156,7.923215436795335,29.015913626584204 +-7.195018220027785,1.0807228707809884,-15.96374784231147 +-7.828485177291129,3.4448018607962343,-11.810841843102876 +-4.375324323219834,3.1884526938380358,-4.082276864817289 +4.539892285737652,5.37294983835314,20.295002024430342 +-7.845181080882069,8.320236902752157,-7.395654557771705 +-5.395720182102384,-9.251748876476405,-27.979335241953986 +1.0970493878296672,-2.5815543227512254,0.561434938536195 +6.595794862648262,6.165029441286038,26.49978054369685 +-3.6572221435456935,9.057987901394899,5.424069896619725 +-4.181643237197628,0.30114258463429167,-11.246980203016824 +-4.880698188647945,8.720871400979266,2.307000111407218 +-6.707843648359637,-9.10178761215342,-29.46787020091486 +-1.2980587999392412,9.847511281116741,10.915689237412531 +7.833545325098278,4.972160389138985,27.070579089004223 +7.8158498175704985,7.8689327939572635,32.10674093921866 +0.3771672077289807,-3.6814189633841394,-4.026743627397514 +5.4402486422197605,3.233225263355221,18.65828831477868 +-2.526845422525799,-8.110666638769695,-18.086601782535 +4.935792226980521,-4.750789681542706,5.4613186669303575 +8.736263010675586,-5.180588499886305,14.722337288961976 +-7.5448413517702795,6.622253442498124,-6.957231242685356 +-6.93431366751012,-6.414633836845218,-29.024393364974426 +1.9876558304168697,7.49124081674929,16.469150536656425 +-6.071306685708535,-3.793526541998105,-20.19204410589605