diff --git a/.classpath b/.classpath
index 4e833749a..879b51ba4 100644
--- a/.classpath
+++ b/.classpath
@@ -1,12 +1,16 @@
+
+
+
+
+
-
@@ -39,19 +43,57 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/.gitignore b/.gitignore
index 940d04e06..9035a922a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,5 @@
/.settings/
/junitvmwatcher*.properties
.vscode/
+/model/
+Query.jar
diff --git a/src/antlr/Boa.g b/src/antlr/Boa.g
index 7ba829f44..40befbe02 100644
--- a/src/antlr/Boa.g
+++ b/src/antlr/Boa.g
@@ -112,6 +112,7 @@ type returns [AbstractType ast]
| q=queueType { $ast = $q.ast; }
| set=setType { $ast = $set.ast; }
| e=enumType { $ast = $e.ast; }
+ | model=modelType { $ast = $model.ast; }
| id=identifier { $ast = $id.ast; }
;
@@ -202,7 +203,12 @@ outputType returns [OutputType ast]
@after { $ast.setPositions($l, $c, getEndLine(), getEndColumn()); }
: OUTPUT (tk=SET { $ast = new OutputType((Identifier)new Identifier($tk.text).setPositions(getStartLine($tk), getStartColumn($tk), getEndLine($tk), getEndColumn($tk))); } | id=identifier { $ast = new OutputType($id.ast); }) (LPAREN el=expressionList RPAREN { $ast.setArgs($el.list); })? (LBRACKET m=component RBRACKET { $ast.addIndice($m.ast); })* OF m=component { $ast.setType($m.ast); } (WEIGHT m=component { $ast.setWeight($m.ast); })? (FORMAT LPAREN el=expressionList RPAREN)?
;
-
+// Model type: parses `identifier OF component`, builds a ModelType from the
+// identifier, attaches the component as its type, and records source positions.
+modelType returns [ModelType ast]
+ locals [int l, int c]
+ @init { $l = getStartLine(); $c = getStartColumn(); }
+ @after { $ast.setPositions($l, $c, getEndLine(), getEndColumn()); }
+ : id=identifier { $ast = new ModelType($id.ast); } OF m=component { $ast.setType($m.ast); }
+ ;
functionType returns [FunctionType ast]
locals [int l, int c]
@init {
diff --git a/src/java/boa/BoaTup.java b/src/java/boa/BoaTup.java
new file mode 100644
index 000000000..43c630c32
--- /dev/null
+++ b/src/java/boa/BoaTup.java
@@ -0,0 +1,11 @@
+package boa;
+
+import java.io.IOException;
+import java.lang.ClassNotFoundException;
+
+/**
+ * A tuple value whose fields can be enumerated and read by name.
+ */
+public interface BoaTup {
+    /**
+     * @return the tuple's values as strings (presumably one per field; confirm
+     *         the ordering against the implementations)
+     */
+    String[] getValues();
+
+    /**
+     * @return the names of the tuple's fields
+     */
+    String[] getFieldNames();
+
+    /**
+     * Serializes an object into a byte array.
+     *
+     * @param o the object to serialize
+     * @return the serialized bytes
+     * @throws IOException if serialization fails
+     */
+    byte[] serialize(Object o) throws IOException;
+
+    /**
+     * @param f a field name
+     * @return the value stored in field {@code f}
+     */
+    Object getValue(String f);
+}
diff --git a/src/java/boa/aggregators/Aggregator.java b/src/java/boa/aggregators/Aggregator.java
index 619aa1070..e531af17c 100644
--- a/src/java/boa/aggregators/Aggregator.java
+++ b/src/java/boa/aggregators/Aggregator.java
@@ -22,6 +22,7 @@
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer.Context;
+import boa.BoaTup;
import boa.functions.BoaCasts;
import boa.io.EmitKey;
import boa.io.EmitValue;
@@ -34,10 +35,12 @@
*/
public abstract class Aggregator {
private long arg;
+ private String mlarg; //for ML
@SuppressWarnings("rawtypes")
private Context context;
private EmitKey key;
private boolean combining;
+ private int vectorSize; //for ML
/**
* Construct an Aggregator.
@@ -60,6 +63,21 @@ public Aggregator(final long arg) {
this.arg = arg;
}
+ /**
+ * Construct an Aggregator that takes a string argument (used by the ML
+ * aggregators).
+ *
+ * @param arg
+ * A String containing the argument to the table; it is stored
+ * as-is and not parsed here
+ */
+ public Aggregator(final String arg) {
+ this();
+
+ this.mlarg = arg;
+ }
+
+
/**
* Reset this aggregator for a new key.
*
@@ -92,6 +110,8 @@ public void aggregate(final double data, final String metadata) throws IOExcepti
public void aggregate(final double data) throws IOException, InterruptedException, FinishedException {
this.aggregate(BoaCasts.doubleToString(data), null);
}
+
+
@SuppressWarnings("unchecked")
protected void collect(final String data, final String metadata) throws IOException, InterruptedException {
@@ -122,6 +142,13 @@ protected void collect(final double data, final String metadata) throws IOExcept
protected void collect(final double data) throws IOException, InterruptedException {
this.collect(BoaCasts.doubleToString(data), null);
}
+
+ /**
+ * Aggregate a tuple value. The default implementation does nothing, so an
+ * aggregator that does not override this method discards tuple data.
+ * NOTE(review): confirm that silently dropping the data is intended.
+ *
+ * @param data the tuple to aggregate
+ * @param metadata metadata accompanying the tuple, may be null
+ */
+ public void aggregate(final BoaTup data, final String metadata) throws IOException, InterruptedException, FinishedException, IllegalAccessException {
+ }
+
+ /**
+ * Aggregate a tuple value without metadata; delegates to
+ * {@code aggregate(data, null)}.
+ *
+ * @param data the tuple to aggregate
+ */
+ public void aggregate(final BoaTup data) throws IOException, InterruptedException, FinishedException, IllegalAccessException {
+ this.aggregate(data, null);
+ }
public void finish() throws IOException, InterruptedException {
// do nothing by default
@@ -155,4 +182,13 @@ public void setKey(final EmitKey key) {
public EmitKey getKey() {
return this.key;
}
+
+ /**
+ * Get the vector size used by the ML aggregators.
+ *
+ * @return the value last passed to setVectorSize, or 0 if it was never set
+ */
+ public int getVectorSize() {
+ return this.vectorSize;
+ }
+ /**
+ * Set the vector size used by the ML aggregators.
+ *
+ * @param vectorSize the new vector size
+ */
+ public void setVectorSize(int vectorSize) {
+ this.vectorSize = vectorSize;
+ }
}
diff --git a/src/java/boa/aggregators/ml/LinearRegressionAggregator.java b/src/java/boa/aggregators/ml/LinearRegressionAggregator.java
new file mode 100644
index 000000000..7893a9f84
--- /dev/null
+++ b/src/java/boa/aggregators/ml/LinearRegressionAggregator.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright 2014, Hridesh Rajan, Robert Dyer,
+ * and Iowa State University of Science and Technology
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package boa.aggregators.ml;
+
+import boa.BoaTup;
+import boa.aggregators.AggregatorSpec;
+import boa.aggregators.FinishedException;
+import weka.classifiers.functions.LinearRegression;
+
+import java.io.IOException;
+
+/**
+ * A Boa aggregator for training the model using LinearRegression.
+ *
+ * @author ankuraga
+ */
+@AggregatorSpec(name = "linearregression", formalParameters = {"string"})
+public class LinearRegressionAggregator extends MLAggregator {
+    /**
+     * The classifier trained in {@link #finish()}. Initialised at declaration
+     * so that every constructor, including the string one selected by the
+     * aggregator spec, leaves it non-null.
+     */
+    private final LinearRegression model = new LinearRegression();
+
+    /**
+     * Construct a linear regression aggregator without options.
+     */
+    public LinearRegressionAggregator() {
+    }
+
+    /**
+     * Construct a linear regression aggregator.
+     *
+     * @param s the option string given to the output variable; it is handed
+     *          to the superclass constructor
+     */
+    public LinearRegressionAggregator(final String s) {
+        super(s);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void aggregate(String data, String metadata) throws NumberFormatException, IOException, InterruptedException {
+        aggregate(data, metadata, "LinearRegression");
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void aggregate(final BoaTup data, final String metadata) throws IOException, InterruptedException, FinishedException, IllegalAccessException {
+        aggregate(data, metadata, "LinearRegression");
+    }
+
+    /**
+     * Trains the model on the collected training set, saves it and emits its
+     * textual description.
+     *
+     * @throws IOException if no training set was built or training fails; an
+     *                     untrained model is never saved or emitted
+     */
+    @Override
+    public void finish() throws IOException, InterruptedException {
+        if (this.trainingSet == null)
+            throw new IOException("LinearRegression: no training instances were aggregated");
+
+        try {
+            System.out.println("Linearregression working now with: " + this.trainingSet.numInstances());
+            System.out.println(this.trainingSet);
+            this.model.buildClassifier(this.trainingSet);
+        } catch (final Exception e) {
+            // fail the task with the cause instead of saving an untrained model
+            throw new IOException("LinearRegression: failed to build the model", e);
+        }
+        System.out.println("modeling done");
+        this.saveModel(this.model);
+        this.collect(this.model.toString());
+    }
+}
diff --git a/src/java/boa/aggregators/ml/MLAggregator.java b/src/java/boa/aggregators/ml/MLAggregator.java
new file mode 100644
index 000000000..d2e225267
--- /dev/null
+++ b/src/java/boa/aggregators/ml/MLAggregator.java
@@ -0,0 +1,314 @@
+/*
+ * Copyright 2014, Hridesh Rajan, Robert Dyer,
+ * and Iowa State University of Science and Technology
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package boa.aggregators.ml;
+
+import boa.BoaTup;
+import boa.aggregators.Aggregator;
+import boa.datagen.DefaultProperties;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapred.FileOutputFormat;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapreduce.JobContext;
+import weka.classifiers.Classifier;
+import weka.classifiers.Evaluation;
+import weka.core.*;
+import weka.filters.Filter;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.lang.reflect.Array;
+import java.util.ArrayList;
+
+/**
+ * A Boa ML aggregator to train models.
+ *
+ * @author ankuraga
+ */
+public abstract class MLAggregator extends Aggregator {
+    /** Attribute definitions of the training set, in column order. */
+    protected final ArrayList<Attribute> fvAttributes;
+    /** Instances waiting to be run through a filter. */
+    protected Instances unFilteredInstances;
+    /** Values of the feature vector currently being assembled from string emits. */
+    protected ArrayList<String> vector;
+    /** The instances the model is trained on; null until the first instance arrives. */
+    protected Instances trainingSet;
+    /** Number of attributes; the last attribute is used as the class. */
+    protected int NumOfAttributes;
+    /** Options parsed from the aggregator's string argument, if any. */
+    protected String[] options;
+    /** True once the attributes and the training set have been created. */
+    protected boolean flag;
+    /** Number of values collected so far in {@link #vector}. */
+    protected int count;
+
+    /**
+     * Construct an ML aggregator without options.
+     */
+    public MLAggregator() {
+        this.fvAttributes = new ArrayList<Attribute>();
+        this.vector = new ArrayList<String>();
+    }
+
+    /**
+     * Construct an ML aggregator.
+     *
+     * @param s the option string; it is split into {@link #options}. A string
+     *          that cannot be split is reported and leaves the options null.
+     */
+    public MLAggregator(final String s) {
+        super(s);
+        this.fvAttributes = new ArrayList<Attribute>();
+        this.vector = new ArrayList<String>();
+        try {
+            options = Utils.splitOptions(s);
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * Evaluates a model against the given set and emits the summary.
+     * Failures are reported on stderr and otherwise ignored.
+     *
+     * @param model the trained classifier
+     * @param trainingSet the instances to evaluate against
+     */
+    public void evaluate(Classifier model, Instances trainingSet) {
+        try {
+            Evaluation evaluation = new Evaluation(trainingSet);
+            evaluation.evaluateModel(model, trainingSet);
+            this.collect(" Training set evaluation \n " + evaluation.toSummaryString());
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * Serializes the training set to a file under the job's output directory
+     * and emits the file's path. Best effort: failures are reported on stderr.
+     *
+     * @param trainingSet the object to serialize
+     */
+    public void saveTrainingSet(Object trainingSet) {
+        saveObject(trainingSet, "data");
+    }
+
+    /**
+     * Serializes the model to a file under the job's output directory and
+     * emits the file's path. Best effort: failures are reported on stderr.
+     *
+     * @param model the object to serialize
+     */
+    public void saveModel(Object model) {
+        saveObject(model, "ML.model");
+    }
+
+    /**
+     * Shared implementation of {@link #saveModel(Object)} and
+     * {@link #saveTrainingSet(Object)}.
+     *
+     * @param object the object to serialize
+     * @param suffix the file name suffix appended after the key and timestamp
+     */
+    private void saveObject(final Object object, final String suffix) {
+        try {
+            final JobContext context = (JobContext) getContext();
+            final Configuration configuration = context.getConfiguration();
+            final int boaJobId = configuration.getInt("boa.hadoop.jobid", 0);
+            final JobConf job = new JobConf(configuration);
+            final Path outputPath = FileOutputFormat.getOutputPath(job);
+            final FileSystem fileSystem = outputPath.getFileSystem(context.getConfiguration());
+
+            final String base = DefaultProperties.localOutput != null
+                    ? DefaultProperties.localOutput
+                    : DefaultProperties.HADOOP_OUT_LOCATION;
+            fileSystem.mkdirs(new Path(base, new Path("" + boaJobId)));
+            final Path filePath = new Path(base, new Path("" + boaJobId,
+                    new Path(("" + getKey()).split("\\[")[0] + System.currentTimeMillis() + suffix)));
+
+            if (fileSystem.exists(filePath))
+                return;
+
+            // serialize before creating the file so a serialization failure
+            // does not leave an empty file behind
+            final ByteArrayOutputStream byteOutStream = new ByteArrayOutputStream();
+            final ObjectOutputStream objectOut = new ObjectOutputStream(byteOutStream);
+            try {
+                objectOut.writeObject(object);
+            } finally {
+                objectOut.close();
+            }
+
+            final FSDataOutputStream out = fileSystem.create(filePath);
+            try {
+                out.write(byteOutStream.toByteArray());
+            } finally {
+                out.close();
+            }
+
+            // emit the path only after the file is completely written
+            this.collect(filePath.toString());
+        } catch (Exception e) {
+            if (e instanceof InterruptedException)
+                Thread.currentThread().interrupt();
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * Replaces {@link #unFilteredInstances} with the result of applying the filter.
+     */
+    protected void applyFilterToUnfilteredInstances(Filter filter) throws Exception {
+        unFilteredInstances = Filter.useFilter(unFilteredInstances, filter);
+    }
+
+    /**
+     * Applies the filter to {@link #unFilteredInstances} and moves the
+     * results into the given set.
+     */
+    protected void applyFilterToUnfilteredInstances(Filter filter, Instances filteredInstances) throws Exception {
+        unFilteredInstances = Filter.useFilter(unFilteredInstances, filter);
+        moveFromUnFilteredToFiltered(filteredInstances);
+    }
+
+    /**
+     * Appends every instance of {@link #unFilteredInstances} to the given set
+     * and then empties {@link #unFilteredInstances}.
+     */
+    protected void moveFromUnFilteredToFiltered(Instances filteredInstances) {
+        final int total = unFilteredInstances.numInstances();
+        filteredInstances.addAll(unFilteredInstances.subList(0, total));
+        // remove from the back so nothing is shifted on each removal
+        for (int i = total - 1; i >= 0; i--) {
+            unFilteredInstances.remove(i);
+        }
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public abstract void aggregate(final String data, final String metadata) throws NumberFormatException, IOException, InterruptedException;
+
+    /**
+     * Creates the attributes and an empty training set from the shape of a
+     * tuple: an enum field becomes one nominal attribute, an array field one
+     * numeric attribute per element, anything else one numeric attribute.
+     * The last attribute becomes the class. Failures are reported on stderr.
+     *
+     * @param data a tuple whose fields define the attributes
+     * @param name the relation name of the training set
+     */
+    protected void attributeCreation(BoaTup data, final String name) {
+        this.fvAttributes.clear();
+        try {
+            int attrCount = 0;
+            for (final String fieldName : data.getFieldNames()) {
+                final Object value = data.getValue(fieldName);
+                final Class<?> type = value.getClass();
+                if (type.isEnum()) {
+                    final ArrayList<String> nominalValues = new ArrayList<String>();
+                    for (final Object constant : type.getEnumConstants())
+                        nominalValues.add(constant.toString());
+                    this.fvAttributes.add(new Attribute("Nominal" + attrCount, nominalValues));
+                    attrCount++;
+                } else if (type.isArray()) {
+                    final int length = Array.getLength(value);
+                    for (int j = 0; j < length; j++) {
+                        this.fvAttributes.add(new Attribute("Attribute" + attrCount));
+                        attrCount++;
+                    }
+                } else {
+                    this.fvAttributes.add(new Attribute("Attribute" + attrCount));
+                    attrCount++;
+                }
+            }
+            this.NumOfAttributes = attrCount;
+            this.flag = true;
+            this.trainingSet = new Instances(name, this.fvAttributes, 1);
+            this.trainingSet.setClassIndex(this.NumOfAttributes - 1);
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * Adds one instance built from a vector of numeric strings to the
+     * training set. Failures are reported on stderr.
+     *
+     * @param data one value per attribute, each parseable as a double
+     */
+    protected void instanceCreation(ArrayList<String> data) {
+        try {
+            final Instance instance = new DenseInstance(this.NumOfAttributes);
+            for (int i = 0; i < this.NumOfAttributes; i++)
+                instance.setValue(this.fvAttributes.get(i), Double.parseDouble(data.get(i)));
+
+            trainingSet.add(instance);
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * Adds one instance built from a tuple to the training set, walking the
+     * fields in the same order as {@link #attributeCreation(BoaTup, String)}.
+     * Failures are reported on stderr.
+     *
+     * @param data the tuple to convert
+     */
+    protected void instanceCreation(BoaTup data) {
+        try {
+            int attrIndex = 0;
+            final Instance instance = new DenseInstance(this.NumOfAttributes);
+            for (final String fieldName : data.getFieldNames()) {
+                final Object value = data.getValue(fieldName);
+                final Class<?> type = value.getClass();
+                if (type.isEnum()) {
+                    instance.setValue(this.fvAttributes.get(attrIndex), String.valueOf(value));
+                    attrIndex++;
+                } else if (type.isArray()) {
+                    final int length = Array.getLength(value);
+                    for (int j = 0; j < length; j++) {
+                        instance.setValue(this.fvAttributes.get(attrIndex), Double.parseDouble(String.valueOf(Array.get(value, j))));
+                        attrIndex++;
+                    }
+                } else {
+                    final String text = String.valueOf(value);
+                    if (NumberUtils.isNumber(text))
+                        instance.setValue(this.fvAttributes.get(attrIndex), Double.parseDouble(text));
+                    else
+                        instance.setValue(this.fvAttributes.get(attrIndex), text);
+                    attrIndex++;
+                }
+            }
+            trainingSet.add(instance);
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * Creates {@code getVectorSize()} numeric attributes and an empty
+     * training set whose last attribute is the class. Failures are reported
+     * on stderr.
+     *
+     * @param name the relation name of the training set
+     */
+    protected void attributeCreation(String name) {
+        System.out.println(this.getVectorSize());
+        fvAttributes.clear();
+        NumOfAttributes = this.getVectorSize();
+        try {
+            for (int i = 0; i < NumOfAttributes; i++) {
+                fvAttributes.add(new Attribute("Attribute" + i));
+            }
+
+            this.flag = true;
+            trainingSet = new Instances(name, fvAttributes, 1);
+            trainingSet.setClassIndex(NumOfAttributes - 1);
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * Collects one emitted value; once {@code getVectorSize()} values have
+     * been collected they are turned into a training instance and collection
+     * starts over.
+     *
+     * @param data the emitted value
+     * @param metadata unused
+     * @param name the relation name used when the training set is first created
+     */
+    protected void aggregate(final String data, final String metadata, String
+            name) throws IOException, InterruptedException {
+        if (this.count != this.getVectorSize()) {
+            this.vector.add(data);
+            this.count++;
+        }
+
+        if (this.count == this.getVectorSize()) {
+            if (!this.flag)
+                attributeCreation(name);
+            instanceCreation(this.vector);
+            this.vector = new ArrayList<String>();
+            this.count = 0;
+        }
+    }
+
+    /**
+     * Turns one emitted tuple into a training instance, creating the
+     * attributes from the first tuple seen.
+     *
+     * @param data the emitted tuple
+     * @param metadata unused
+     * @param name the relation name used when the training set is first created
+     */
+    protected void aggregate(final BoaTup data, final String metadata, String
+            name) throws IOException, InterruptedException {
+        if (!this.flag)
+            attributeCreation(data, name);
+        instanceCreation(data);
+    }
+}
diff --git a/src/java/boa/compiler/SymbolTable.java b/src/java/boa/compiler/SymbolTable.java
index ea182ba05..039dde962 100644
--- a/src/java/boa/compiler/SymbolTable.java
+++ b/src/java/boa/compiler/SymbolTable.java
@@ -28,9 +28,10 @@
import boa.aggregators.AggregatorSpec;
import boa.functions.FunctionSpec;
import boa.types.*;
+import boa.types.ml.BoaLinearRegression;
+import boa.types.ml.BoaModel;
import boa.types.proto.*;
import boa.types.proto.enums.*;
-
import boa.compiler.ast.Operand;
import boa.compiler.ast.statements.VisitStatement;
@@ -379,12 +380,16 @@ public static void resetTypeMap() {
types.put("float", new BoaFloat());
types.put("time", new BoaTime());
types.put("string", new BoaString());
-
+
for (final BoaType t : dslTupleTypes)
types.put(t.toString(), t);
for (final BoaType t : dslMapTypes)
types.put(t.toString(), t);
+
+ types.put("LinearRegression", new BoaLinearRegression());
+ types.put("Model", new BoaModel());
+ types.put("tuple", new BoaTuple());
}
public SymbolTable cloneNonLocals() throws IOException {
@@ -596,6 +601,7 @@ private static void importLibs(final List urls) throws IOException {
boa.aggregators.UniqueAggregator.class,
boa.aggregators.VarianceAggregator.class,
boa.aggregators.PreconditionAggregator.class,
+ boa.aggregators.ml.LinearRegressionAggregator.class,
};
for (final Class> c : builtinAggs)
importAggregator(c);
@@ -755,4 +761,13 @@ public String toString() {
return r.toString();
}
+
+ /**
+  * Looks up the model type registered under an ML aggregator's name,
+  * ignoring case (e.g. the aggregator "linearregression" maps to the type
+  * "LinearRegression").
+  *
+  * Only model types are returned, so a caller may cast the result to
+  * BoaModel even when the aggregator's name happens to match a non-model
+  * type name.
+  *
+  * @param aggregtorName the aggregator name to look up
+  * @return the matching model type, or null if there is none
+  */
+ public static BoaType getMLAggregatorType(String aggregtorName) {
+     for (final Entry<String, BoaType> e : types.entrySet()) {
+         if (e.getValue() instanceof BoaModel && e.getKey().equalsIgnoreCase(aggregtorName)) {
+             return e.getValue();
+         }
+     }
+     return null;
+ }
}
diff --git a/src/java/boa/compiler/ast/types/ModelType.java b/src/java/boa/compiler/ast/types/ModelType.java
new file mode 100644
index 000000000..4b5e8566e
--- /dev/null
+++ b/src/java/boa/compiler/ast/types/ModelType.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright 2014, Hridesh Rajan, Robert Dyer,
+ * and Iowa State University of Science and Technology
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package boa.compiler.ast.types;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import boa.compiler.ast.Component;
+import boa.compiler.ast.Identifier;
+import boa.compiler.visitors.AbstractVisitor;
+import boa.compiler.visitors.AbstractVisitorNoArgNoRet;
+import boa.compiler.visitors.AbstractVisitorNoReturn;
+
+/**
+ *
+ * @author ankuraga
+ */
+/**
+ * AST node for a model type: an identifier naming the kind of model and a
+ * component describing the type it is defined over.
+ */
+public class ModelType extends AbstractType {
+    protected Identifier id;
+    protected Component t;
+
+    /**
+     * @return the identifier naming the model, may be null
+     */
+    public Identifier getId() {
+        return id;
+    }
+
+    /**
+     * Sets the identifier and makes this node its parent.
+     *
+     * @param id the new identifier, may be null
+     */
+    public void setId(final Identifier id) {
+        if (id != null)
+            id.setParent(this);
+        this.id = id;
+    }
+
+    /**
+     * @return the component type, or null if none has been set
+     */
+    public Component getType() {
+        return t;
+    }
+
+    /**
+     * Sets the component type and makes this node its parent.
+     *
+     * @param t the new component type, may be null
+     */
+    public void setType(final Component t) {
+        if (t != null)
+            t.setParent(this);
+        this.t = t;
+    }
+
+    /**
+     * Creates a model type without a component type.
+     *
+     * @param id the identifier naming the model
+     */
+    public ModelType(final Identifier id) {
+        this(id, null);
+    }
+
+    /**
+     * Creates a model type and makes it the parent of both children.
+     *
+     * @param id the identifier naming the model, may be null
+     * @param t the component type, may be null
+     */
+    public ModelType(final Identifier id, final Component t) {
+        if (id != null)
+            id.setParent(this);
+        if (t != null)
+            t.setParent(this);
+        this.id = id;
+        this.t = t;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public <T, A> T accept(final AbstractVisitor<T, A> v, A arg) {
+        return v.visit(this, arg);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public <A> void accept(final AbstractVisitorNoReturn<A> v, A arg) {
+        v.visit(this, arg);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void accept(final AbstractVisitorNoArgNoRet v) {
+        v.visit(this);
+    }
+
+    /**
+     * Deep-copies this node. A missing identifier or component type (for
+     * example after the one-argument constructor) is copied as null instead
+     * of raising a NullPointerException.
+     *
+     * @return the copy
+     */
+    @Override
+    public ModelType clone() {
+        final ModelType m = new ModelType(id == null ? null : id.clone(), t == null ? null : t.clone());
+        copyFieldsTo(m);
+        return m;
+    }
+}
diff --git a/src/java/boa/compiler/visitors/AbstractVisitor.java b/src/java/boa/compiler/visitors/AbstractVisitor.java
index 85c992b38..8c768e765 100644
--- a/src/java/boa/compiler/visitors/AbstractVisitor.java
+++ b/src/java/boa/compiler/visitors/AbstractVisitor.java
@@ -468,4 +468,10 @@ public ReturnTypeT visit(final TraversalType n, final ArgTypeT arg) {
public ReturnTypeT visit(final FixPType n, final ArgTypeT arg) {
return null;
}
+
+ /**
+  * Visits a model type: first its identifier, then its component type.
+  *
+  * @param n the node to visit
+  * @param arg the argument passed on to both children
+  * @return always null
+  */
+ public ReturnTypeT visit(final ModelType n, final ArgTypeT arg) {
+ n.getId().accept(this, arg);
+ n.getType().accept(this, arg);
+ return null;
+ }
}
diff --git a/src/java/boa/compiler/visitors/AbstractVisitorNoArgNoRet.java b/src/java/boa/compiler/visitors/AbstractVisitorNoArgNoRet.java
index a022a0f3a..e538fe0a7 100644
--- a/src/java/boa/compiler/visitors/AbstractVisitorNoArgNoRet.java
+++ b/src/java/boa/compiler/visitors/AbstractVisitorNoArgNoRet.java
@@ -414,4 +414,9 @@ public void visit(final TraversalType n) {
public void visit(final FixPType n) {
}
+
+ /**
+  * Visits a model type: first its identifier, then its component type.
+  *
+  * @param n the node to visit
+  */
+ public void visit(final ModelType n) {
+ n.getId().accept(this);
+ n.getType().accept(this);
+ }
}
diff --git a/src/java/boa/compiler/visitors/AbstractVisitorNoReturn.java b/src/java/boa/compiler/visitors/AbstractVisitorNoReturn.java
index bed42124e..26c3345be 100644
--- a/src/java/boa/compiler/visitors/AbstractVisitorNoReturn.java
+++ b/src/java/boa/compiler/visitors/AbstractVisitorNoReturn.java
@@ -406,4 +406,9 @@ public void visit(final TraversalType n, final ArgTypeT arg) {
public void visit(final FixPType n, final ArgTypeT arg) {
}
+
+ /**
+  * Visits a model type: first its identifier, then its component type.
+  *
+  * @param n the node to visit
+  * @param arg the argument passed on to both children
+  */
+ public void visit(final ModelType n, final ArgTypeT arg) {
+ n.getId().accept(this, arg);
+ n.getType().accept(this, arg);
+ }
}
diff --git a/src/java/boa/compiler/visitors/CodeGeneratingVisitor.java b/src/java/boa/compiler/visitors/CodeGeneratingVisitor.java
index 16bb1b469..50cc05d93 100644
--- a/src/java/boa/compiler/visitors/CodeGeneratingVisitor.java
+++ b/src/java/boa/compiler/visitors/CodeGeneratingVisitor.java
@@ -37,6 +37,7 @@
import boa.compiler.ast.types.*;
import boa.compiler.visitors.analysis.*;
import boa.types.*;
+import boa.types.ml.BoaModel;
/**
*
@@ -969,10 +970,12 @@ public void visit(final AssignmentStatement n) {
final ST st = stg.getInstanceOf("Assignment");
n.getLhs().accept(this);
+
final String lhs = code.removeLast();
n.getRhs().accept(this);
String rhs = code.removeLast();
+ System.out.println(n.getRhs());
if (n.getLhs().type instanceof BoaTuple && n.getRhs().type instanceof BoaArray) {
final Operand op = n.getRhs().getLhs().getLhs().getLhs().getLhs().getLhs().getOperand();
@@ -1003,10 +1006,17 @@ else if (lhs.charAt(idx) == ')')
code.add(lhs.substring(0, idx - ".get(".length()) + ".put(" + lhs.substring(idx, lhs.lastIndexOf(')')) + ", " + rhs + lhs.substring(lhs.lastIndexOf(')')) + ";");
return;
}
+ String typecast = "";
+ if (rhs.contains(".load(")) {
+
+ rhs = rhs.substring(0,rhs.length()-1) + ", new " +
+ ((BoaModel)n.getLhs().type).getType().toJavaType() + "())";
+ typecast = "(" + (n.getLhs().type + "").split("\\/")[0] + ")";
+ }
st.add("lhs", lhs);
st.add("operator", n.getOp());
- st.add("rhs", rhs);
+ st.add("rhs", typecast + rhs);
code.add(st.render());
}
diff --git a/src/java/boa/compiler/visitors/TypeCheckingVisitor.java b/src/java/boa/compiler/visitors/TypeCheckingVisitor.java
index 7d9f38100..aff37b5e2 100644
--- a/src/java/boa/compiler/visitors/TypeCheckingVisitor.java
+++ b/src/java/boa/compiler/visitors/TypeCheckingVisitor.java
@@ -30,6 +30,9 @@
import boa.compiler.ast.types.*;
import boa.compiler.transforms.VisitorDesugar;
import boa.types.*;
+
+import boa.types.ml.BoaLinearRegression;
+import boa.types.ml.BoaModel;
import boa.types.proto.CodeRepositoryProtoTuple;
/**
@@ -937,6 +940,36 @@ public void visit(final VarDeclStatement n, final SymbolTable env) {
if (rhs != null && !lhs.assigns(rhs) && !env.hasCast(rhs, lhs))
throw new TypeCheckException(n.getInitializer(), "incorrect type '" + rhs + "' for assignment to '" + id + ": " + lhs + "'");
+
+ if (n.getType() instanceof OutputType) {
+ BoaModel model = (BoaModel) SymbolTable.getMLAggregatorType(((OutputType) n.getType()).getId().getToken());
+ if (model != null) {
+ BoaType t = ((OutputType) n.getType()).type;
+ List types = new ArrayList();
+
+ if (t instanceof BoaTable) {
+ t = ((BoaTable) t).getType();
+ }
+ if (t instanceof BoaTuple)
+ types = ((BoaTuple) t).getTypes();
+ else if (t instanceof BoaArray)
+ types.add(((BoaArray) t).getType());
+ else
+ types.add(t);
+
+ if (model instanceof BoaLinearRegression) {
+ if (!(types.get(types.size() - 1) instanceof BoaInt
+ || types.get(types.size() - 1) instanceof BoaFloat || types.get(types.size() - 1) instanceof BoaTime))
+ throw new TypeCheckException(n, "LinearRegression required class to be numeric or date");
+ for (int i = 0; i < types.size() - 1; i++) {
+ if (!(types.get(i) instanceof BoaEnum || types.get(i) instanceof BoaFloat ||
+ types.get(i) instanceof BoaInt || types.get(i) instanceof BoaTime || types.get(i) instanceof BoaArray))
+ throw new TypeCheckException(n, "LinearRegression required attributes to be numeric, nominal or date");
+ }
+ }
+ }
+ }
+
} else {
if (rhs == null)
throw new TypeCheckException(n, "variable declaration requires an explicit type or an initializer");
@@ -1400,6 +1433,29 @@ public void visit(final OutputType n, final SymbolTable env) {
n.getId().accept(this, env);
}
+
+ /**
+  * {@inheritDoc}
+  *
+  * Type checks the component type, then resolves the model's identifier:
+  * as a type name if the environment knows it as one, otherwise as a
+  * symbol. A linear regression model is specialised to the component type.
+  */
+ @Override
+ public void visit(final ModelType n, final SymbolTable env) {
+     n.env = env;
+     n.getType().accept(this, env);
+
+     final String name = n.getId().getToken();
+     if (env.hasType(name)) {
+         n.type = SymbolTable.getType(name);
+     } else {
+         try {
+             n.type = env.get(name);
+         } catch (final RuntimeException e) {
+             throw new TypeCheckException(n, "invalid identifier '" + name + "'", e);
+         }
+     }
+
+     final boolean isLinearRegression = n.type instanceof BoaLinearRegression;
+     if (isLinearRegression)
+         n.type = new BoaLinearRegression(n.getType().type);
+ }
+
/** {@inheritDoc} */
@Override
public void visit(final StackType n, final SymbolTable env) {
diff --git a/src/java/boa/datagen/DefaultProperties.java b/src/java/boa/datagen/DefaultProperties.java
index ae9c65eb6..a854e5c58 100644
--- a/src/java/boa/datagen/DefaultProperties.java
+++ b/src/java/boa/datagen/DefaultProperties.java
@@ -74,6 +74,13 @@ public class DefaultProperties {
public static String localDataPath = null;
+ // Settings used by the ML aggregators.
+ // NOTE(review): HADOOP_SEQ_FILE_LOCATION is not read by the ML aggregator code; presumably the location of the input sequence files — confirm.
+ public static String HADOOP_SEQ_FILE_LOCATION = "";
+ // Base directory under which MLAggregator writes model and training-set files when localOutput is null.
+ public static String HADOOP_OUT_LOCATION = "./model";
+ // When non-null, MLAggregator uses this as the output base directory instead of HADOOP_OUT_LOCATION.
+ public static String localOutput = null;
+
+
@SuppressWarnings("unused")
private static String getRoot() {
File dir = new File(System.getProperty("user.dir"));
diff --git a/src/java/boa/functions/BoaIntrinsics.java b/src/java/boa/functions/BoaIntrinsics.java
index 2413a68a6..f6fa3b19c 100644
--- a/src/java/boa/functions/BoaIntrinsics.java
+++ b/src/java/boa/functions/BoaIntrinsics.java
@@ -16,6 +16,11 @@
*/
package boa.functions;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.ObjectInputStream;
+import java.lang.reflect.Array;
+import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -28,11 +33,24 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
import boa.types.Code.CodeRepository;
import boa.types.Code.Revision;
import boa.types.Diff.ChangedFile;
import boa.types.Shared.ChangeKind;
import boa.types.Toplevel.Project;
+import boa.types.ml.BoaLinearRegression;
+import boa.types.ml.BoaModel;
+import weka.classifiers.Classifier;
+import weka.core.Attribute;
+import weka.core.DenseInstance;
+import weka.core.Instance;
+import weka.core.Instances;
/**
* Boa domain-specific functions.
@@ -80,7 +98,171 @@ private static int getRevisionIndex(final CodeRepository cr, final String id) {
}
return -1;
}
+
+	/**
+	 * Loads a serialized Weka classifier from HDFS and wraps it as a Boa model.
+	 *
+	 * @param URL path of the serialized model; it is appended to "hdfs://master"
+	 * @param o type object stored in the returned model next to the classifier
+	 * @return the wrapped model, or null if toString() of the classifier lacks "Linear Regression"
+	 * @throws Exception if the model cannot be read or deserialized
+	 */
+	@FunctionSpec(name = "load", returnType = "Model", formalParameters = {"string"})
+	public static BoaModel load(final String URL, final Object o) throws Exception {
+		Object unserializedObject = null;
+		FSDataInputStream in = null;
+		ObjectInputStream dataIn = null;
+		final ByteArrayOutputStream bo = new ByteArrayOutputStream();
+		try {
+			final Configuration conf = new Configuration();
+			final FileSystem fileSystem = FileSystem.get(conf);
+			final Path path = new Path("hdfs://master" + URL);
+			in = fileSystem.open(path);
+
+			// copy the whole file through a fixed-size buffer before deserializing it
+			final byte[] b = new byte[8192];
+			int c;
+			while ((c = in.read(b)) != -1) {
+				bo.write(b, 0, c);
+			}
+
+			dataIn = new ObjectInputStream(new ByteArrayInputStream(bo.toByteArray()));
+			unserializedObject = dataIn.readObject();
+		} finally {
+			// close each stream on its own so one failing close() cannot skip the other
+			try {
+				if (dataIn != null) dataIn.close();
+			} catch (final Exception e) { e.printStackTrace(); }
+			try {
+				if (in != null) in.close();
+			} catch (final Exception e) { e.printStackTrace(); }
+		}
+
+		// fail with a clear message instead of a NullPointerException further down
+		if (!(unserializedObject instanceof Classifier))
+			throw new IllegalStateException("no classifier could be loaded from " + URL);
+		final Classifier clr = (Classifier)unserializedObject;
+
+		// only linear regression models are recognized; anything else yields null
+		BoaModel m = null;
+		if (clr.toString().contains("Linear Regression")) {
+			m = new BoaLinearRegression(clr, o);
+		}
+		return m;
+	}
+
+ @FunctionSpec(name = "classify", returnType = "string", formalParameters = { "Model","tuple"})
+ public static String classify(final BoaModel model, final boa.BoaTup vector) throws Exception {
+ int NumOfAttributes = 0;
+ ArrayList fvAttributes = new ArrayList();
+ try {
+ String[] fieldNames = vector.getFieldNames();
+ int count = 0;
+ for(int i = 0; i < fieldNames.length; i++) {
+ if(vector.getValue(fieldNames[i]).getClass().isEnum()) {
+ ArrayList fvNominalVal = new ArrayList();
+ for(Object obj: vector.getValue(fieldNames[i]).getClass().getEnumConstants())
+ fvNominalVal.add(obj.toString());
+ fvAttributes.add(new Attribute("Nominal" + count, fvNominalVal));
+ count++;
+ }
+ else if(vector.getValue(fieldNames[i]).getClass().isArray()) {
+ int l = Array.getLength(vector.getValue(fieldNames[i])) - 1;
+ for(int j = 0; j <= l; j++) {
+ fvAttributes.add(new Attribute("Attribute" + count));
+ count++;
+ }
+ }
+ else {
+ fvAttributes.add(new Attribute("Attribute" + count));
+ count++;
+ }
+ }
+
+ String[] fields = ((boa.BoaTup)model.getObject()).getFieldNames();
+ Field lastfield = model.getObject().getClass().getField(fields[fields.length - 1]);
+ if(lastfield.getType().isEnum()) {
+ ArrayList fvNominalVal = new ArrayList();
+ for(Object obj: lastfield.getType().getEnumConstants())
+ fvNominalVal.add(obj.toString());
+ fvAttributes.add(new Attribute("Nominal" + count, fvNominalVal));
+ count++;
+ }
+ else {
+ fvAttributes.add(new Attribute("Attribute" + count));
+ count++;
+ }
+
+ NumOfAttributes = count;
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ }
+
+ Instances testingSet = new Instances("Classifier", fvAttributes, 1);
+ testingSet.setClassIndex(NumOfAttributes-1);
+
+ Instance instance = new DenseInstance(NumOfAttributes);
+
+ for(int i=0; i fvAttributes = new ArrayList();
+
+ for(int i=0; i < NumOfAttributes - 1; i++) {
+ fvAttributes.add(new Attribute("Attribute" + i));
+ }
+
+ try {
+ String[] fields = ((boa.BoaTup)model.getObject()).getFieldNames();
+ Field lastfield = model.getObject().getClass().getField(fields[fields.length - 1]);
+ if(lastfield.getType().isEnum()) {
+ ArrayList fvNominalVal = new ArrayList();
+ for(Object obj: lastfield.getType().getEnumConstants())
+ fvNominalVal.add(obj.toString());
+ fvAttributes.add(new Attribute("Nominal" + (NumOfAttributes - 1), fvNominalVal));
+ } else
+ fvAttributes.add(new Attribute("Attribute" + (NumOfAttributes - 1)));
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ }
+
+ Instances testingSet = new Instances("Classifier", fvAttributes, 1);
+ testingSet.setClassIndex(NumOfAttributes - 1);
+ Instance instance = new DenseInstance(NumOfAttributes);
+ for(int i=0; i 0) {
+ byte[] bytes = new byte[length];
+ in.readFully(bytes, 0, length);
+ ByteArrayInputStream bin = new ByteArrayInputStream(bytes);
+ ObjectInputStream dataIn = new ObjectInputStream(bin);
+ Object o = null;
+ try {
+ o = dataIn.readObject();
+ } catch(Exception e) {
+ e.printStackTrace();
+ }
+ this.tdata = (BoaTup)o;
+ }
}
/** {@inheritDoc} */
@@ -308,6 +407,14 @@ public void write(final DataOutput out) throws IOException {
Text.writeString(out, "");
else
Text.writeString(out, this.metadata);
+
+ if (this.tdata == null)
+ out.writeInt(0);
+ else {
+ byte[] serializedObject = this.tdata.serialize(this.tdata);
+ out.writeInt(serializedObject.length);
+ out.write(serializedObject);
+ }
}
/**
@@ -339,6 +446,11 @@ public String getMetadata() {
public void setMetadata(final String metadata) {
this.metadata = metadata;
}
+
+ public BoaTup getTuple() {
+ return this.tdata;
+ }
+
@Override
public int hashCode() {
@@ -373,4 +485,4 @@ public boolean equals(final Object obj) {
public String toString() {
return Arrays.toString(this.data) + ":" + this.metadata;
}
-}
+}
\ No newline at end of file
diff --git a/src/java/boa/runtime/BoaReducer.java b/src/java/boa/runtime/BoaReducer.java
index a1f75e034..d5e64d18b 100644
--- a/src/java/boa/runtime/BoaReducer.java
+++ b/src/java/boa/runtime/BoaReducer.java
@@ -25,6 +25,7 @@
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.Reducer.Context;
import org.apache.log4j.Logger;
import boa.aggregators.Aggregator;
@@ -78,15 +79,25 @@ public void setConf(final Configuration conf) {
protected void reduce(final EmitKey key, final Iterable values, final Context context) throws IOException, InterruptedException {
// get the aggregator named by the emit key
final Aggregator a = this.aggregators.get(key.getName());
-
+ boolean setVector = true;
+
a.setCombining(false);
a.start(key);
a.setContext(context);
+ int counter = 1;
for (final EmitValue value : values)
try {
+ if (value.getTuple() != null) {
+ a.aggregate(value.getTuple(), value.getMetadata());
+ } else {
+ if (setVector && value.getData().length > 1) {
+ a.setVectorSize(value.getData().length);
+ setVector = false;
+ }
for (final String s : value.getData())
a.aggregate(s, value.getMetadata());
+ }
} catch (final FinishedException e) {
// we are done
return;
diff --git a/src/java/boa/types/ml/BoaLinearRegression.java b/src/java/boa/types/ml/BoaLinearRegression.java
new file mode 100644
index 000000000..cf1fa4058
--- /dev/null
+++ b/src/java/boa/types/ml/BoaLinearRegression.java
@@ -0,0 +1,123 @@
+/*
+ * Copyright 2014, Hridesh Rajan, Robert Dyer,
+ * and Iowa State University of Science and Technology
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package boa.types.ml;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import boa.types.BoaType;
+import boa.types.BoaName;
+
+import weka.classifiers.Classifier;
+/**
+ * A {@link BoaType} representing a linear regression ML model with its attached types.
+ * NOTE(review): redeclares the clr/t/o fields and getters that {@link BoaModel} already has.
+ * @author ankuraga
+ */
+public class BoaLinearRegression extends BoaModel{
+ private Classifier clr;
+ private BoaType t;
+ private Object o;
+
+ /**
+ * Default BoaLinearRegression constructor.
+ * Leaves the classifier, the attached type and the type object unset.
+ */
+ public BoaLinearRegression(){
+ }
+
+ /**
+ * Construct a BoaLinearRegression from its attached type only.
+ *
+ * @param t
+ * A {@link BoaType} containing the types attached to this model
+ *
+ */
+ public BoaLinearRegression(BoaType t){
+ this.t = t;
+ }
+
+ /**
+ * Construct a BoaLinearRegression from a classifier and a type object.
+ *
+ * @param clr
+ * A {@link Classifier} containing the ML model
+ *
+ * @param o
+ * A {@link Object} containing the type object
+ *
+ */
+ public BoaLinearRegression(Classifier clr, Object o){
+ this.clr = clr;
+ this.o = o;
+ }
+
+ /**
+ * Get the classifier of this model.
+ *
+ * @return A {@link Classifier} representing the ML model
+ *
+ */
+ public Classifier getClassifier() {
+ return this.clr;
+ }
+
+ /**
+ * Get the type attached to this model.
+ *
+ * @return A {@link BoaType} representing the type attached to the ML model
+ *
+ */
+ public BoaType getType() {
+ return this.t;
+ }
+
+ /**
+ * Get the type object of this model.
+ *
+ * @return A {@link Object} representing the type object
+ *
+ */
+ public Object getObject() {
+ return this.o;
+ }
+
+ /** {@inheritDoc} This implementation always returns true. */
+ @Override
+ public boolean assigns(final BoaType that) {
+ return true;
+ }
+
+ /** {@inheritDoc} This implementation always returns true. */
+ @Override
+ public boolean accepts(final BoaType that) {
+ return true;
+ }
+
+ /** {@inheritDoc} Returns the fully qualified name of this class. */
+ @Override
+ public String toJavaType() {
+ return "boa.types.ml.BoaLinearRegression";
+ }
+
+ /** {@inheritDoc} Returns the class name, a slash, then the attached type. */
+ @Override
+ public String toString() {
+ return "boa.types.ml.BoaLinearRegression" + "/" + this.t;
+ }
+}
\ No newline at end of file
diff --git a/src/java/boa/types/ml/BoaModel.java b/src/java/boa/types/ml/BoaModel.java
new file mode 100644
index 000000000..420a146da
--- /dev/null
+++ b/src/java/boa/types/ml/BoaModel.java
@@ -0,0 +1,123 @@
+/*
+ * Copyright 2014, Hridesh Rajan, Robert Dyer,
+ * and Iowa State University of Science and Technology
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package boa.types.ml;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import boa.types.BoaName;
+import boa.types.BoaType;
+import weka.classifiers.Classifier;
+
+/**
+ * A {@link BoaType} representing an ML model of any kind.
+ *
+ * @author ankuraga
+ */
+public class BoaModel extends BoaType {
+ private Classifier clr;
+ private BoaType t;
+ private Object o;
+
+ /**
+ * Default BoaModel constructor.
+ * Leaves the classifier, the attached type and the type object unset.
+ */
+ public BoaModel(){
+ }
+
+ /**
+ * Construct a BoaModel from its attached type only.
+ *
+ * @param t
+ * A {@link BoaType} containing the types attached to this model
+ *
+ */
+ public BoaModel(BoaType t){
+ this.t = t;
+ }
+
+ /**
+ * Construct a BoaModel from a classifier and a type object.
+ *
+ * @param clr
+ * A {@link Classifier} containing the ML model
+ *
+ * @param o
+ * A {@link Object} containing the type object
+ *
+ */
+ public BoaModel(Classifier clr, Object o){
+ this.clr = clr;
+ this.o = o;
+ }
+
+ /**
+ * Get the classifier of this model.
+ *
+ * @return A {@link Classifier} representing the ML model
+ *
+ */
+ public Classifier getClassifier() {
+ return this.clr;
+ }
+
+ /**
+ * Get the type attached to this model.
+ *
+ * @return A {@link BoaType} representing the type attached to the ML model
+ *
+ */
+ public BoaType getType() {
+ return this.t;
+ }
+
+ /**
+ * Get the type object of this model.
+ *
+ * @return A {@link Object} representing the type object
+ *
+ */
+ public Object getObject() {
+ return this.o;
+ }
+
+ /** {@inheritDoc} This implementation always returns true. */
+ @Override
+ public boolean assigns(final BoaType that) {
+ return true;
+ }
+
+ /** {@inheritDoc} This implementation always returns true. */
+ @Override
+ public boolean accepts(final BoaType that) {
+ return true;
+ }
+
+ /** {@inheritDoc} Returns the fully qualified name of this class. */
+ @Override
+ public String toJavaType() {
+ return "boa.types.ml.BoaModel";
+ }
+
+ /** {@inheritDoc} Returns the class name, a slash, then the attached type. */
+ @Override
+ public String toString() {
+ return "boa.types.ml.BoaModel" + "/" + this.t;
+ }
+}
\ No newline at end of file