-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathClassify.java
More file actions
414 lines (348 loc) · 14.3 KB
/
Classify.java
File metadata and controls
414 lines (348 loc) · 14.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
package cs475;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.*;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionBuilder;
/**
 * Command-line driver for training and testing a set of classifiers and
 * clustering algorithms (majority, even/odd, logistic regression, margin
 * perceptron, MIRA, dual/kernel perceptrons, SKA, lambda-means).
 *
 * <p>Usage: {@code -mode train|test} plus, for train: {@code data, algorithm,
 * model_file}; for test: {@code data, predictions_file, model_file}.
 * Trained models are persisted via Java serialization.
 */
public class Classify {
    /** Options registered by {@link #registerOption}; consumed by CommandLineUtilities. */
    static public LinkedList<Option> options = new LinkedList<Option>();

    // Hyper-parameter defaults; each may be overridden from the command line in main().
    static int gd_iterations = 20;                  // gradient-descent iterations
    static double gd_eta = .01;                     // gradient-descent step size
    static int num_features = 0;                    // number of features to select (0 = all)
    static double online_learning_rate = 1.0;       // perceptron learning rate
    static int online_training_iterations = 1;      // iterations for online learners
    static double polynomial_kernel_exponent = 2;   // exponent d of the polynomial kernel
    static double cluster_lambda = 0.0;             // lambda for lambda-means clustering
    static int clustering_training_iterations = 10; // clustering iterations
    static int num_clusters = 3;                    // K for SKA

    /**
     * Entry point. Parses the command line, then either trains a model and
     * serializes it ("train" mode) or deserializes a model and writes
     * predictions ("test" mode). Total wall-clock time is printed on exit.
     *
     * @param args command-line arguments; only "mode" is mandatory up front
     * @throws IOException if data cannot be read or predictions cannot be written
     */
    public static void main(String[] args) throws IOException {
        long startTime = System.currentTimeMillis();

        // Parse the command line; "mode" is the only argument required for all modes.
        String[] mandatory_args = { "mode" };
        createCommandLineOptions();
        CommandLineUtilities.initCommandLineParameters(args, Classify.options, mandatory_args);

        // Override hyper-parameter defaults with any user-supplied values.
        if (CommandLineUtilities.hasArg("gd_iterations"))
            gd_iterations = CommandLineUtilities.getOptionValueAsInt("gd_iterations");
        if (CommandLineUtilities.hasArg("online_training_iterations"))
            online_training_iterations = CommandLineUtilities.getOptionValueAsInt("online_training_iterations");
        if (CommandLineUtilities.hasArg("gd_eta"))
            gd_eta = CommandLineUtilities.getOptionValueAsFloat("gd_eta");
        if (CommandLineUtilities.hasArg("polynomial_kernel_exponent"))
            polynomial_kernel_exponent = CommandLineUtilities.getOptionValueAsFloat("polynomial_kernel_exponent");
        if (CommandLineUtilities.hasArg("num_features_to_select"))
            num_features = CommandLineUtilities.getOptionValueAsInt("num_features_to_select");
        if (CommandLineUtilities.hasArg("online_learning_rate"))
            online_learning_rate = CommandLineUtilities.getOptionValueAsFloat("online_learning_rate");
        if (CommandLineUtilities.hasArg("cluster_lambda"))
            cluster_lambda = CommandLineUtilities.getOptionValueAsFloat("cluster_lambda");
        if (CommandLineUtilities.hasArg("clustering_training_iterations"))
            clustering_training_iterations = CommandLineUtilities.getOptionValueAsInt("clustering_training_iterations");
        if (CommandLineUtilities.hasArg("num_clusters"))
            num_clusters = CommandLineUtilities.getOptionValueAsInt("num_clusters");

        String mode = CommandLineUtilities.getOptionValue("mode");
        String data = CommandLineUtilities.getOptionValue("data");
        String predictions_file = CommandLineUtilities.getOptionValue("predictions_file");
        String algorithm = CommandLineUtilities.getOptionValue("algorithm");
        String model_file = CommandLineUtilities.getOptionValue("model_file");

        if (mode.equalsIgnoreCase("train")) {
            if (data == null || algorithm == null || model_file == null) {
                System.out.println("Train requires the following arguments: data, algorithm, model_file");
                System.exit(1); // FIX: usage errors previously exited with status 0
            }
            // Load the training data (second arg: labels are expected).
            DataReader data_reader = new DataReader(data, true);
            List<Instance> instances = data_reader.readData();
            data_reader.close();
            // Train the model and serialize it for later testing.
            Predictor predictor = train(instances, algorithm);
            saveObject(predictor, model_file);
        } else if (mode.equalsIgnoreCase("test")) {
            if (data == null || predictions_file == null || model_file == null) {
                // FIX: this message previously said "Train requires ..." in the test branch.
                System.out.println("Test requires the following arguments: data, predictions_file, model_file");
                System.exit(1);
            }
            // Load the test data.
            DataReader data_reader = new DataReader(data, true);
            List<Instance> instances = data_reader.readData();
            data_reader.close();
            // Load the serialized model and write predictions.
            Predictor predictor = (Predictor) loadObject(model_file);
            evaluateAndSavePredictions(predictor, instances, predictions_file);
        } else {
            // FIX: mode is mandatory so it IS present here; the old message
            // ("Requires mode argument.") misdiagnosed an unrecognized value.
            System.out.println("Unknown mode: " + mode + " (expected train or test).");
        }

        long endTime = System.currentTimeMillis();
        System.out.println("Time: " + (endTime - startTime) + "ms");
    }

    /**
     * Trains a predictor of the requested kind on the given instances.
     * Several branches also run the matching evaluator on the training data
     * (the evaluators presumably report accuracy themselves — behavior kept as-is).
     *
     * @param instances labeled training data
     * @param algorithm algorithm name (case-insensitive)
     * @return the trained predictor, or null if the algorithm name is unknown
     */
    private static Predictor train(List<Instance> instances, String algorithm) {
        if (algorithm.equalsIgnoreCase("majority")) {
            MajorityClassification mc = new MajorityClassification();
            mc.train(instances);
            // Report training accuracy.
            MajorityEvaluator mae = new MajorityEvaluator();
            System.out.println(mae.evaluate(instances, mc));
            return mc;
        }
        if (algorithm.equalsIgnoreCase("even_odd")) {
            EvenOddClassification eoc = new EvenOddClassification();
            eoc.train(instances);
            // Report training accuracy.
            EvenOddEvaluator eoe = new EvenOddEvaluator();
            System.out.println(eoe.evaluate(instances, eoc));
            return eoc;
        }
        if (algorithm.equalsIgnoreCase("logistic_regression")) {
            LogisticRegression lr = new LogisticRegression(gd_iterations, gd_eta, num_features);
            lr.train(instances);
            lr.printWeight();
            // Evaluate on the training data (result intentionally not printed here).
            new LogisticRegressionEvaluator().evaluate(instances, lr);
            return lr;
        }
        if (algorithm.equalsIgnoreCase("margin_perceptron")) {
            MarginPerceptron mp = new MarginPerceptron(online_training_iterations, online_learning_rate);
            mp.train(instances);
            new MarginPerceptronEvaluator().evaluate(instances, mp);
            return mp;
        }
        if (algorithm.equalsIgnoreCase("mira")) {
            MarginInfusedRelaxation mira = new MarginInfusedRelaxation(online_training_iterations);
            mira.train(instances);
            new MiraEvaluator().evaluate(instances, mira);
            return mira;
        }
        if (algorithm.equalsIgnoreCase("perceptron_linear_kernel")) {
            // Dual perceptron with a linear kernel (polynomial exponent unused -> 0).
            DualPerceptron dp = new DualPerceptron(online_training_iterations,
                    online_learning_rate, true, false, 0, instances);
            dp.train(instances);
            new DualPerceptronEvaluator().evaluate(instances, dp);
            return dp;
        }
        if (algorithm.equalsIgnoreCase("perceptron_polynomial_kernel")) {
            // Dual perceptron with a polynomial kernel of exponent d.
            DualPerceptron dp = new DualPerceptron(online_training_iterations,
                    online_learning_rate, false, true, polynomial_kernel_exponent, instances);
            System.out.println("d = " + polynomial_kernel_exponent);
            dp.train(instances);
            new DualPerceptronEvaluator().evaluate(instances, dp);
            return dp;
        }
        if (algorithm.equalsIgnoreCase("ska")) {
            SKA ska = new SKA(num_clusters, clustering_training_iterations);
            ska.train(instances);
            return ska;
        }
        if (algorithm.equalsIgnoreCase("lambda_means")) {
            LambdaMeans lm = new LambdaMeans(cluster_lambda, clustering_training_iterations);
            lm.train(instances);
            System.out.println("cluster: " + lm.numCluster);
            return lm;
        }
        // FIX: previously fell through and returned null silently, which then
        // serialized a null model; at least report the unknown name.
        System.out.println("Unknown algorithm: " + algorithm);
        return null;
    }

    /**
     * Runs the evaluator matching the predictor's type, then writes one
     * prediction per instance to predictions_file.
     *
     * @param predictor        a trained (deserialized) predictor
     * @param instances        test instances
     * @param predictions_file path of the predictions file to create
     * @throws IOException if the predictions file cannot be written
     */
    private static void evaluateAndSavePredictions(Predictor predictor,
            List<Instance> instances, String predictions_file) throws IOException {
        PredictionsWriter writer = new PredictionsWriter(predictions_file);
        // try/finally so the writer is closed even if an evaluator throws
        // (the original leaked the writer on exception).
        try {
            String name = predictor.getpreName();
            if (name.equalsIgnoreCase("majority")) {
                System.out.println("majority test");
                System.out.println(new MajorityEvaluator().evaluate(instances, predictor));
                writePredictions(predictor, instances, writer);
            } else if (name.equalsIgnoreCase("even_odd")) {
                System.out.println("even_odd test");
                System.out.println(new EvenOddEvaluator().evaluate(instances, predictor));
                writePredictions(predictor, instances, writer);
            } else if (name.equalsIgnoreCase("logistic_regression")) {
                System.out.println("LR test");
                System.out.println(new LogisticRegressionEvaluator().evaluate(instances, predictor));
                writePredictions(predictor, instances, writer);
            } else if (name.equalsIgnoreCase("margin_perceptron")) {
                System.out.println("margin_perceptron test");
                System.out.println(new MarginPerceptronEvaluator().evaluate(instances, predictor));
                writePredictions(predictor, instances, writer);
            } else if (name.equalsIgnoreCase("mira")) {
                // FIX: this banner previously said "margin_perceptron test".
                System.out.println("mira test");
                System.out.println(new MiraEvaluator().evaluate(instances, predictor));
                writePredictions(predictor, instances, writer);
            } else if (name.equalsIgnoreCase("perceptron_linear_kernel")) {
                System.out.println("perceptron linear test");
                System.out.println(new DualPerceptronEvaluator().evaluate(instances, predictor));
                writeKernelPredictions((DualPerceptron) predictor, instances, writer);
            } else if (name.equalsIgnoreCase("perceptron_polynomial_kernel")) {
                System.out.println("perceptron polynomial test");
                System.out.println(new DualPerceptronEvaluator().evaluate(instances, predictor));
                writeKernelPredictions((DualPerceptron) predictor, instances, writer);
            } else if (name.equalsIgnoreCase("lambda_means")) {
                System.out.println("lambda test");
                // Evaluation result intentionally not printed (matches original behavior).
                new LambdaMeansEvaluator().evaluate(instances, predictor);
                writePredictions(predictor, instances, writer);
            } else if (name.equalsIgnoreCase("ska")) {
                System.out.println("ska test");
                // Evaluation result intentionally not printed (matches original behavior).
                new SKAEvaluator().evaluate(instances, predictor);
                writePredictions(predictor, instances, writer);
            }
        } finally {
            writer.close();
        }
    }

    /** Writes {@code predictor.predict(instance)} for every instance, in order. */
    private static void writePredictions(Predictor predictor, List<Instance> instances,
            PredictionsWriter writer) throws IOException {
        for (Instance instance : instances) {
            Label label = predictor.predict(instance);
            writer.writePrediction(label);
        }
    }

    /** Writes predictions via the dual perceptron's kernel-aware predict overload. */
    private static void writeKernelPredictions(DualPerceptron dp, List<Instance> instances,
            PredictionsWriter writer) throws IOException {
        double d = dp.dvalue();           // polynomial exponent (unused for linear kernel)
        boolean isLinear = dp.isLinear();
        boolean isPoly = !isLinear;
        for (Instance instance : instances) {
            Label label = dp.predict(instance, isLinear, isPoly, d);
            writer.writePrediction(label);
        }
    }

    /**
     * Serializes a single object to file_name. IO failures are reported to
     * stderr rather than propagated (best-effort, as in the original).
     *
     * @param object    the object to serialize (must be Serializable)
     * @param file_name destination file path
     */
    public static void saveObject(Object object, String file_name) {
        // try-with-resources guarantees the stream closes even when
        // writeObject fails (the original leaked the stream on exception).
        try (ObjectOutputStream oos = new ObjectOutputStream(
                new BufferedOutputStream(new FileOutputStream(new File(file_name))))) {
            oos.writeObject(object);
        } catch (IOException e) {
            System.err.println("Exception writing file " + file_name + ": " + e);
        }
    }

    /**
     * Load a single object from a filename.
     *
     * @param file_name source file path
     * @return the deserialized object, or null on any failure
     */
    public static Object loadObject(String file_name) {
        // try-with-resources closes the stream on all paths
        // (the original leaked it when readObject threw).
        try (ObjectInputStream ois = new ObjectInputStream(
                new BufferedInputStream(new FileInputStream(new File(file_name))))) {
            return ois.readObject();
        } catch (IOException e) {
            e.printStackTrace();
            System.err.println("IO Error loading: " + file_name);
        } catch (ClassNotFoundException e) {
            System.err.println("Not found class in when loading: " + file_name);
        }
        return null;
    }

    /**
     * Registers one command-line option in {@link #options}.
     * NOTE(review): OptionBuilder is static and deprecated in commons-cli 1.3+
     * (not thread-safe); acceptable here since registration is single-threaded.
     *
     * @param option_name the option flag name
     * @param arg_name    display name of the option's argument
     * @param has_arg     whether the option takes an argument
     * @param description help text
     */
    public static void registerOption(String option_name, String arg_name, boolean has_arg, String description) {
        OptionBuilder.withArgName(arg_name);
        OptionBuilder.hasArg(has_arg);
        OptionBuilder.withDescription(description);
        Option option = OptionBuilder.create(option_name);
        Classify.options.add(option);
    }

    /** Declares every option the program understands. */
    private static void createCommandLineOptions() {
        registerOption("data", "String", true, "The data to use.");
        registerOption("mode", "String", true, "Operating mode: train or test.");
        registerOption("predictions_file", "String", true, "The predictions file to create.");
        registerOption("algorithm", "String", true, "The name of the algorithm for training.");
        registerOption("model_file", "String", true, "The name of the model file to create/load.");
        registerOption("gd_eta", "int", true, "The step size parameter for GD.");
        registerOption("gd_iterations", "int", true, "The number of GD iterations.");
        registerOption("num_features_to_select", "int", true, "The number of features to select.");
        registerOption("online_learning_rate", "double", true, "The learning rate for pereceptron.");
        registerOption("online_training_iterations", "int", true, "The number of training iterations for online methods.");
        registerOption("polynomial_kernel_exponent", "double", true, "The exponent of the polynomial kernel.");
        registerOption("cluster_lambda", "double", true, "The value of lambda of lambda-means.");
        registerOption("clustering_training_iterations", "int", true, "The number of clustering iterations");
        registerOption("num_clusters", "int", true, "The number of cluster");
    }
}