diff --git a/egads b/egads new file mode 100755 index 0000000..ab8b304 --- /dev/null +++ b/egads @@ -0,0 +1,29 @@ +#!/usr/bin/perl + +sub usage { + print "Usage: egads args\n"; + print " Commands:\n"; + print " StreamForecast\n"; + print " TrainForecastingModel\n"; + print " MakeEmptyModel\n"; + exit (0); +} + +sub launch { + my $cmd = "java -cp lib/OpenForecast-0.5.0.jar:target/egads-jar-with-dependencies.jar -Dlog4j.configurationFile=log4j.xml com.yahoo.egads." . join (' ', @ARGV); + print STDERR $cmd, "\n"; + system $cmd; +} + +my $command = $ARGV[0]; + +if ($command eq "MakeEmptyModel") { + launch(); +} elsif ($command eq "StreamForecast") { + launch(); +} elsif ($command eq "TrainForecastingModel") { + launch(); +} else { + usage(); +} + diff --git a/pom.xml b/pom.xml index fd2f3ba..5073e90 100644 --- a/pom.xml +++ b/pom.xml @@ -20,6 +20,11 @@ + + gnu.getopt + java-getopt + 1.0.13 + org.apache.logging.log4j log4j-api diff --git a/src/main/java/com/yahoo/egads/Egads.java b/src/main/java/com/yahoo/egads/Egads.java index 9ee8be7..07c4eeb 100644 --- a/src/main/java/com/yahoo/egads/Egads.java +++ b/src/main/java/com/yahoo/egads/Egads.java @@ -2,10 +2,13 @@ package com.yahoo.egads; -import java.util.Properties; import java.io.FileInputStream; import java.io.InputStream; -import com.yahoo.egads.utilities.*; +import java.util.Properties; + +import com.yahoo.egads.utilities.FileInputProcessor; +import com.yahoo.egads.utilities.InputProcessor; +import com.yahoo.egads.utilities.StdinProcessor; /* * Call stack. diff --git a/src/main/java/com/yahoo/egads/StreamForecast.java b/src/main/java/com/yahoo/egads/StreamForecast.java new file mode 100644 index 0000000..fa43aa1 --- /dev/null +++ b/src/main/java/com/yahoo/egads/StreamForecast.java @@ -0,0 +1,92 @@ +package com.yahoo.egads; + +import gnu.getopt.Getopt; + +import java.util.HashMap; +import java.util.Scanner; + +import com.yahoo.egads.data.FileModelStore; +import com.yahoo.egads.data.Model; +import com.yahoo.egads.data.ModelStore; +import com.yahoo.egads.data.TimeSeries; +import com.yahoo.egads.models.tsmm.StreamingOlympicModel; +import com.yahoo.egads.models.tsmm.TimeSeriesStreamingModel; + +public class StreamForecast { + public static void main(String[] args) { + HashMap options = processOptions(args); + Scanner sc = new Scanner(System.in); + ModelStore ms = new FileModelStore ("models"); + while (sc.hasNextLine()) { + String line = sc.nextLine(); + String[] fields = line.split(","); + String series; + int timestamp; + double measured; + double forecast = Double.NaN; + try { + series = fields[0]; + timestamp = Integer.parseInt(fields[1]); + measured = Double.parseDouble(fields[2]); + TimeSeriesStreamingModel model = (TimeSeriesStreamingModel) ms.getModel(series, Model.ModelType.FORECAST); + if (model == null) { + model = new StreamingOlympicModel(); + ms.storeModel(series, model); + System.err.println ("No such model " + series); + } + TimeSeries.Entry e = new TimeSeries.Entry(timestamp, (float)measured); + forecast = model.predict(e); + model.update(e); + if (!options.containsKey(new Integer('q'))) { + System.out.println(String.join(",", series, String.format("%d", timestamp), String.format("%f", measured), String.format("%f", forecast))); + } + } catch (Exception e) { + System.err.println("Invalid input line " + line); + continue; + } + } + ms.writeCachedModels(); + if (options.containsKey(new Integer('t'))) { + for (Model m : ms.getCachedModels()) { + System.out.println(m.errorSummaryString()); + } + } + sc.close(); + } + + public static void usage() { + System.out.println ("Usage: StreamForecast [-m ] [-p ] [-t] [-q] [-h]"); + System.out.println (" Modeltypes:"); + System.out.println (" sos: Streaming Olympic Scoring"); + System.out.println (" Default preperties file is config.ini"); + System.out.println (" -t: Run in model test mode. This outputs error stats at the end of the model run"); + System.exit(0); + } + + public static HashMap processOptions (String[] args) { + HashMap result = new HashMap(); + // defaults + result.put(new Integer('m'), "sos"); + result.put(new Integer('p'), "config.ini"); + + Getopt g = new Getopt("TrainForecastingModel", args, "m:p:htq"); + int c; + while ((c = g.getopt()) != -1) { + switch (c) { + case 'm': + case 'n': + case 'p': + result.put(c, g.getOptarg()); + break; + case 't': + case 'q': + result.put(c, "True"); + break; + case 'h': + usage(); + } + } + return result; + } + +} diff --git a/src/main/java/com/yahoo/egads/TrainForecastingModel.java b/src/main/java/com/yahoo/egads/TrainForecastingModel.java new file mode 100644 index 0000000..1d53d7d --- /dev/null +++ b/src/main/java/com/yahoo/egads/TrainForecastingModel.java @@ -0,0 +1,87 @@ +package com.yahoo.egads; + +import gnu.getopt.Getopt; + +import java.io.FileInputStream; +import java.util.HashMap; +import java.util.Properties; +import java.util.Scanner; + +import com.yahoo.egads.data.FileModelStore; +import com.yahoo.egads.data.ModelFactory; +import com.yahoo.egads.data.ModelStore; +import com.yahoo.egads.data.TimeSeries; +import com.yahoo.egads.models.tsmm.TimeSeriesModel; + +public class TrainForecastingModel { + public static void main(String[] args) throws Exception { + HashMap options = processOptions(args); + HashMap inputs = new HashMap(); + Scanner sc = new Scanner(System.in); + ModelStore ms = new FileModelStore ("models"); + Properties osProps = new Properties(); + osProps.load (new FileInputStream(options.get(new Integer('p')))); + while (sc.hasNextLine()) { + String line = sc.nextLine(); + String[] fields = line.split(","); + String series; + int timestamp; + float measured; + try { + series = fields[0]; + timestamp = Integer.parseInt(fields[1]); + measured = Float.parseFloat(fields[2]); + } catch (Exception e) { + System.err.println("Invalid input line " + line); + continue; + } + if (!inputs.containsKey(series)) { + inputs.put(series, new TimeSeries.DataSequence()); + } + TimeSeries.DataSequence seq = inputs.get(series); + seq.add(new TimeSeries.Entry(timestamp, measured)); + } + sc.close(); + ModelFactory mf = new ModelFactory(osProps); + for (String series : inputs.keySet()) { + TimeSeries.DataSequence seq = inputs.get(series); + TimeSeriesModel m = mf.getTSModel(options.get(new Integer('m'))); + if (m != null) { + m.train(seq); + ms.storeModel(series, m); + } + } + } + + public static void usage() { + System.out.println ("Usage: TrainForecastingModel [-m ] [-p ] [-h]"); + System.out.println (" Modeltypes:"); + System.out.println (" sos: Streaming Olympic Scoring"); + System.out.println (" Default preperties file is config.ini"); + System.exit(0); + } + + public static HashMap processOptions (String[] args) { + HashMap result = new HashMap(); + // defaults + result.put(new Integer('m'), "sos"); + result.put(new Integer('p'), "config.ini"); + + Getopt g = new Getopt("TrainForecastingModel", args, "m:p:h"); + int c; + while ((c = g.getopt()) != -1) { + switch (c) { + case 'm': + result.put(c, g.getOptarg()); + break; + case 'p': + result.put(c, g.getOptarg()); + break; + case 'h': + usage(); + } + } + return result; + } + +} diff --git a/src/main/java/com/yahoo/egads/batch/TrainForecastingModel.java b/src/main/java/com/yahoo/egads/batch/TrainForecastingModel.java deleted file mode 100644 index 72bfc7f..0000000 --- a/src/main/java/com/yahoo/egads/batch/TrainForecastingModel.java +++ /dev/null @@ -1,55 +0,0 @@ -package com.yahoo.egads.batch; - -import java.util.Properties; -import java.util.Scanner; -import java.util.HashMap; - -import com.yahoo.egads.data.FileModelStore; -import com.yahoo.egads.data.ModelStore; -import com.yahoo.egads.data.TimeSeries; -import com.yahoo.egads.models.tsmm.OlympicModel; - -public class TrainForecastingModel { - public static void main(String[] args) { - HashMap inputs = new HashMap(); - Scanner sc = new Scanner(System.in); - ModelStore m = new FileModelStore ("models"); - Properties osProps = new Properties(); - osProps.setProperty("TIME_SHIFTS", "0,1"); - osProps.setProperty("BASE_WINDOWS", "24,168"); - osProps.setProperty("NUM_WEEKS", "6"); - osProps.setProperty("NUM_TO_DROP", "1"); - while (sc.hasNextLine()) { - String line = sc.nextLine(); - String[] fields = line.split(","); - String series; - int timestamp; - float measured; - try { - series = fields[0]; - timestamp = Integer.parseInt(fields[1]); - measured = Float.parseFloat(fields[2]); - } catch (Exception e) { - System.err.println("Invalid input line " + line); - continue; - } - if (!inputs.containsKey(series)) { - inputs.put(series, new TimeSeries.DataSequence()); - } - TimeSeries.DataSequence seq = inputs.get(series); - seq.add(new TimeSeries.Entry(timestamp, measured)); - } - sc.close(); - for (String series : inputs.keySet()) { - TimeSeries.DataSequence seq = inputs.get(series); - OlympicModel o = new OlympicModel(osProps); - o.train(seq); - System.out.println (series + ":"); - for (TimeSeries.Entry e : seq) { - System.out.println(e.time + ": " + e.value); - } - m.storeModel(series, o); - } - } - -} diff --git a/src/main/java/com/yahoo/egads/data/Entry.java b/src/main/java/com/yahoo/egads/data/Entry.java deleted file mode 100644 index a7d1b54..0000000 --- a/src/main/java/com/yahoo/egads/data/Entry.java +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2015, Yahoo Inc. - * Copyrights licensed under the GPL License. - * See the accompanying LICENSE file for terms. - */ - -// A simple egads entry class. - -package com.yahoo.egads.data; - -public class Entry { - public T ts; - public F val; - - public Entry(T ts, F val) { - this.ts = ts; - this.val = val; - } - - public String toString() { - return this.ts + "," + this.val; - } -} diff --git a/src/main/java/com/yahoo/egads/data/FileModelStore.java b/src/main/java/com/yahoo/egads/data/FileModelStore.java index c5a72f9..65f3eff 100644 --- a/src/main/java/com/yahoo/egads/data/FileModelStore.java +++ b/src/main/java/com/yahoo/egads/data/FileModelStore.java @@ -5,20 +5,36 @@ import java.io.FileOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.util.Collection; +import java.util.HashMap; public class FileModelStore implements ModelStore { + HashMap cache; String path; + protected static org.apache.logging.log4j.Logger logger = org.apache.logging.log4j.LogManager.getLogger(FileModelStore.class.getName()); + public FileModelStore (String path) { - File dir = new File (path); - dir.mkdirs(); this.path = path; + cache = new HashMap(); + new File (path).mkdirs(); + } + + private String getFilename (String tag, Model.ModelType type) { + String filename = tag.replaceAll("[^\\w_-]", "_"); + if (type == Model.ModelType.ANOMALY) { + filename = "anomaly." + filename; + } else if (type == Model.ModelType.FORECAST) { + filename = "forecast." + filename; + } + return filename; } @Override public void storeModel(String tag, Model m) { - String filename = tag.replaceAll("[^\\w_-]", "_"); + String filename = getFilename(tag, m.getModelType()); String fqn = path + "/" + filename; try { + m.clearModified(); ObjectOutputStream o = new ObjectOutputStream(new FileOutputStream (fqn)); o.writeObject(m); o.close(); @@ -28,18 +44,36 @@ public void storeModel(String tag, Model m) { } @Override - public Model retrieveModel(String tag) { - String filename = tag.replaceAll("[^\\w_-]", "_"); + public Model getModel(String tag, Model.ModelType type) { + String filename = getFilename(tag, type); + if (cache.containsKey(filename)) { + return cache.get(filename); + } String fqn = path + "/" + filename; Model m = null; try { ObjectInputStream o = new ObjectInputStream(new FileInputStream(fqn)); m = (Model) o.readObject(); o.close(); + cache.put(filename, m); + return m; } catch (Exception e) { - e.printStackTrace(); + logger.debug("Model not found: " + tag); } - return m; + return null; + } + public void writeCachedModels() { + for (String key : cache.keySet()) { + Model model = cache.get(key); + if (model.isModified()) { +// The key always has the model type prepended - remove it before storing + key = key.replaceFirst("[a-zA-Z]*\\.", ""); + storeModel(key, model); + } + } + } + public Collection getCachedModels() { + return cache.values(); } } diff --git a/src/main/java/com/yahoo/egads/data/Model.java b/src/main/java/com/yahoo/egads/data/Model.java index f6b6f22..f681970 100644 --- a/src/main/java/com/yahoo/egads/data/Model.java +++ b/src/main/java/com/yahoo/egads/data/Model.java @@ -13,11 +13,15 @@ import java.io.Serializable; public interface Model extends JsonAble, Serializable { + enum ModelType {FORECAST, ANOMALY}; // resets the model. public void reset(); // Gets the model name and type public String getModelName(); - public String getModelType(); + public ModelType getModelType(); + public boolean isModified (); + public void clearModified(); + public String errorSummaryString(); } diff --git a/src/main/java/com/yahoo/egads/data/ModelStore.java b/src/main/java/com/yahoo/egads/data/ModelStore.java index a8316e2..ab83b2b 100644 --- a/src/main/java/com/yahoo/egads/data/ModelStore.java +++ b/src/main/java/com/yahoo/egads/data/ModelStore.java @@ -1,6 +1,12 @@ package com.yahoo.egads.data; +import java.util.Collection; + +import com.yahoo.egads.data.Model.ModelType; + public interface ModelStore { public void storeModel(String tag, Model m); - public Model retrieveModel (String tag); + Model getModel(String tag, ModelType type); + public void writeCachedModels(); + public Collection getCachedModels(); } diff --git a/src/main/java/com/yahoo/egads/models/adm/AnomalyDetectionAbstractModel.java b/src/main/java/com/yahoo/egads/models/adm/AnomalyDetectionAbstractModel.java index f000384..b6bd87d 100644 --- a/src/main/java/com/yahoo/egads/models/adm/AnomalyDetectionAbstractModel.java +++ b/src/main/java/com/yahoo/egads/models/adm/AnomalyDetectionAbstractModel.java @@ -12,21 +12,26 @@ import org.json.JSONStringer; import com.yahoo.egads.data.JsonEncoder; +import com.yahoo.egads.data.Model; +import com.yahoo.egads.data.Model.ModelType; +import com.yahoo.egads.models.tsmm.TimeSeriesModel; public abstract class AnomalyDetectionAbstractModel implements AnomalyDetectionModel { - protected org.apache.logging.log4j.Logger logger; protected float sDAutoSensitivity = 3; protected float amntAutoSensitivity = (float) 0.05; protected String outputDest = ""; protected String modelName; + protected boolean modified; + + protected static org.apache.logging.log4j.Logger logger = org.apache.logging.log4j.LogManager.getLogger(AnomalyDetectionModel.class.getName()); public String getModelName() { return modelName; } - public String getModelType() { - return "Anomaly"; + public ModelType getModelType() { + return Model.ModelType.ANOMALY; } @Override @@ -65,4 +70,19 @@ public AnomalyDetectionAbstractModel(Properties config) { } this.outputDest = config.getProperty("OUTPUT"); } + + public boolean isModified () { + return modified; + } + public void clearModified() { + modified = false; + } + + public String errorSummaryString() { + return ""; + } + + public void clearErrorStats() { + } + } diff --git a/src/main/java/com/yahoo/egads/models/tsmm/OlympicModel.java b/src/main/java/com/yahoo/egads/models/tsmm/OlympicModel.java index f6097fd..078a941 100644 --- a/src/main/java/com/yahoo/egads/models/tsmm/OlympicModel.java +++ b/src/main/java/com/yahoo/egads/models/tsmm/OlympicModel.java @@ -23,26 +23,6 @@ public class OlympicModel extends TimeSeriesAbstractModel { private static final long serialVersionUID = 1L; - public int getNumWeeks() { - return numWeeks; - } - - public int getNumToDrop() { - return numToDrop; - } - - public int[] getTimeShifts() { - return timeShifts; - } - - public int[] getBaseWindows() { - return baseWindows; - } - - public ArrayList getModel() { - return model; - } - // Number of weeks to look back when computing the // estimate. protected int numWeeks; diff --git a/src/main/java/com/yahoo/egads/models/tsmm/StreamingOlympicModel.java b/src/main/java/com/yahoo/egads/models/tsmm/StreamingOlympicModel.java new file mode 100644 index 0000000..3c00b37 --- /dev/null +++ b/src/main/java/com/yahoo/egads/models/tsmm/StreamingOlympicModel.java @@ -0,0 +1,141 @@ +/* + * Copyright 2015, Yahoo Inc. + * Copyrights licensed under the GPL License. + * See the accompanying LICENSE file for terms. + */ + +// Olympic scoring model considers the average of the last k weeks +// (dropping the b highest and lowest values) as the current prediction. + +package com.yahoo.egads.models.tsmm; + +import com.yahoo.egads.data.*; +import com.yahoo.egads.data.TimeSeries.Entry; + +import java.util.HashMap; +import java.util.Properties; +import java.util.ArrayList; +import java.util.Collections; + +import com.yahoo.egads.utilities.FileUtils; + +public class StreamingOlympicModel extends TimeSeriesStreamingModel { + // methods //////////////////////////////////////////////// + + private static final long serialVersionUID = 1L; + + private HashMap model; + protected int period; + protected double smoothingFactor; + + public StreamingOlympicModel() { + super(); + smoothingFactor = 0.4; + period = 86400 * 7; + model = new HashMap(); + } + public StreamingOlympicModel(double smoothingFactor, int period) { + super(); + this.smoothingFactor = smoothingFactor; + this.period = period; + this.model = new HashMap(); + } + + public void reset() { + model = new HashMap(); + } + + private long timeToModelTime (long time) { + if (period == 86400 * 7) { + return weeklyOffset(time); + } + if (period == 86400) { + return dailyOffset(time); + } + return time % period; + } + + public void update (TimeSeries.Entry entry) { + long modelTime = timeToModelTime(entry.time); + if (model.containsKey(modelTime)) { + model.put(modelTime, model.get(modelTime) * (1 - smoothingFactor) + entry.value * smoothingFactor); + } else { + model.put(modelTime, (double)entry.value); + } + modified = true; + } + + public double predict (TimeSeries.Entry entry) { + long modelTime = timeToModelTime(entry.time); + double prediction; + if (model.containsKey(modelTime)) { + prediction = model.get(modelTime); + } else { + prediction = entry.value; + } + double error = entry.value - prediction; + sumErr += error; + sumAbsErr += Math.abs(error); + sumAbsPercentErr += 100 * Math.abs(error / entry.value); + sumErrSquared += error * error; + processedPoints++; + return prediction; + } + + private void runSeries (TimeSeries.DataSequence data) { + clearErrorStats(); + for (TimeSeries.Entry entry : data) { + predict(entry); + update(entry); + } + } + + public void train(TimeSeries.DataSequence data) { + StreamingOlympicModel winner = null; + double sf = 0.0; + for (sf = 0.0; sf <= 1; sf += 0.1) { + StreamingOlympicModel m = new StreamingOlympicModel(sf, this.period); + m.runSeries(data); + logger.debug ("Testing Smoothing Factor " + String.format("%.2f", m.smoothingFactor) + " -> "+ m.errorSummaryString()); + if (betterThan(m, winner)) { + winner = m; + } + } + double min = winner.smoothingFactor - 0.09; + if (min < 0) min = 0; + double max = winner.smoothingFactor + 0.09; + if (max >= 1) max = .99; + for (sf = min; sf <= max; sf += 0.01) { + StreamingOlympicModel m = new StreamingOlympicModel(sf, this.period); + m.runSeries(data); + logger.debug ("Testing Smoothing Factor " + String.format("%.2f", m.smoothingFactor) + " -> "+ m.errorSummaryString()); + if (betterThan(m, winner)) { + winner = m; + } + } + this.smoothingFactor = winner.smoothingFactor; + reset(); + runSeries(data); + logger.debug ("Winner: Smoothing Factor = " + String.format("%.2f", this.smoothingFactor)); + } + + public double getSmoothingFactor() { + return smoothingFactor; + } + + public void setSmoothingFactor(double smoothingFactor) { + this.smoothingFactor = smoothingFactor; + } + + public void update(TimeSeries.DataSequence data) { + + } + + public String getModelName() { + return "OlympicModel"; + } + + public void predict(TimeSeries.DataSequence sequence) throws Exception { + return; + } +} diff --git a/src/main/java/com/yahoo/egads/models/tsmm/TimeSeriesAbstractModel.java b/src/main/java/com/yahoo/egads/models/tsmm/TimeSeriesAbstractModel.java index 35c23b5..84bec2e 100644 --- a/src/main/java/com/yahoo/egads/models/tsmm/TimeSeriesAbstractModel.java +++ b/src/main/java/com/yahoo/egads/models/tsmm/TimeSeriesAbstractModel.java @@ -6,6 +6,8 @@ package com.yahoo.egads.models.tsmm; +import java.util.Calendar; +import java.util.Date; import java.util.Properties; import org.json.JSONObject; @@ -21,25 +23,48 @@ public abstract class TimeSeriesAbstractModel implements TimeSeriesModel { - // Accuracy stats for this model. + private static final long serialVersionUID = 1L; + // Accuracy stats for this model. protected double bias; protected double mad; protected double mape; protected double mse; protected double sae; protected String modelName; + protected Properties config; + protected boolean modified; - static org.apache.logging.log4j.Logger logger = org.apache.logging.log4j.LogManager.getLogger(TimeSeriesModel.class.getName()); + protected static org.apache.logging.log4j.Logger logger = org.apache.logging.log4j.LogManager.getLogger(TimeSeriesModel.class.getName()); protected boolean errorsInit = false; protected int dynamicParameters = 0; + public TimeSeriesAbstractModel(Properties config) { + this.config = config; + modified = false; + if (config.getProperty("DYNAMIC_PARAMETERS") != null) { + this.dynamicParameters = new Integer(config.getProperty("DYNAMIC_PARAMETERS")); + } + } + + protected long weeklyOffset (long time) { + Date d = new Date(time * 1000); + long ret = d.getDay() * 86400 + d.getHours() * 3600 + d.getMinutes() * 60 + d.getSeconds(); + return ret; + } + + protected long dailyOffset (long time) { + Date d = new Date(time * 1000); + long ret = d.getHours() * 3600 + d.getMinutes() * 60 + d.getSeconds(); + return ret; + } + public String getModelName() { return modelName; } - public String getModelType() { - return "Forecast"; + public ModelType getModelType() { + return Model.ModelType.FORECAST; } @Override @@ -51,15 +76,18 @@ public void toJson(JSONStringer json_out) throws Exception { public void fromJson(JSONObject json_obj) throws Exception { JsonEncoder.fromJson(this, json_obj); } + - // Acts as a factory method. - public TimeSeriesAbstractModel(Properties config) { - if (config.getProperty("DYNAMIC_PARAMETERS") != null) { - this.dynamicParameters = new Integer(config.getProperty("DYNAMIC_PARAMETERS")); - } - + public void clearErrorStats() { + bias = 0.0; + mad = 0.0; + mape = 0.0; + mse = 0.0; + sae = 0.0; + errorsInit = false; } + // Acts as a factory method. protected static boolean betterThan(TimeSeriesAbstractModel model1, TimeSeriesAbstractModel model2) { // Special case. Any model is better than no model! if (model2 == null) { @@ -97,6 +125,8 @@ protected static boolean betterThan(TimeSeriesAbstractModel model1, TimeSeriesAb } else if (model1.getSAE() - model2.getSAE() >= tolerance) { score--; } + + logger.debug ("Comparison score: " + score); if (score == 0) { // At this point, we're still unsure which one is best @@ -105,11 +135,20 @@ protected static boolean betterThan(TimeSeriesAbstractModel model1, TimeSeriesAb model1.getBias() - model2.getBias() + model1.getMAD() - model2.getMAD() + model1.getMAPE() - model2.getMAPE() + model1.getMSE() - model2.getMSE() + model1.getSAE() - model2.getSAE(); + logger.debug ("Diff: " + diff); return (diff < 0); } return (score > 0); } + + public String errorSummaryString () { + return ("B:" + String.format("%.2f", getBias()) + + "\tMAD:" + String.format("%.2f", getMAD()) + + "\tMAPE:" + String.format("%.2f", getMAPE()) + + "\tMSE:" + String.format("%.2f", getMSE()) + + "\tSAE:" + String.format("%.2f", getSAE())); + } /* * Forecasting model already has the errors defined. @@ -139,15 +178,15 @@ protected void initForecastErrors(ForecastingModel forecaster, TimeSeries.DataSe * Initializes all errors given the model. */ protected void initForecastErrors(ArrayList model, TimeSeries.DataSequence data) { - // Reset various helper summations + clearErrorStats(); + + int n = data.size(); double sumErr = 0.0; double sumAbsErr = 0.0; double sumAbsPercentErr = 0.0; double sumErrSquared = 0.0; int processedPoints = 0; - int n = data.size(); - for (int i = 0; i < n; i++) { // Calculate error in forecast, and update sums appropriately double error = model.get(i) - data.get(i).value; @@ -157,13 +196,14 @@ protected void initForecastErrors(ArrayList model, TimeSeries.DataSequenc sumErrSquared += error * error; processedPoints++; } - this.bias = sumErr / processedPoints; - this.mad = sumAbsErr / processedPoints; - this.mape = sumAbsPercentErr / processedPoints; - this.mse = sumErrSquared / processedPoints; - this.sae = sumAbsErr; + bias = sumErr / processedPoints; + mad = sumAbsErr / processedPoints; + mape = sumAbsPercentErr / processedPoints; + mse = sumErrSquared / processedPoints; + sae = sumAbsErr; errorsInit = true; } + /** * Returns the bias - the arithmetic mean of the errors - obtained from applying the current forecasting model to @@ -234,4 +274,12 @@ public double getSAE() { } return sae; } + public boolean isModified () { + return modified; + } + + public void clearModified() { + modified = false; + } + } diff --git a/src/main/java/com/yahoo/egads/models/tsmm/TimeSeriesModel.java b/src/main/java/com/yahoo/egads/models/tsmm/TimeSeriesModel.java index 562b485..4ef7449 100644 --- a/src/main/java/com/yahoo/egads/models/tsmm/TimeSeriesModel.java +++ b/src/main/java/com/yahoo/egads/models/tsmm/TimeSeriesModel.java @@ -14,10 +14,10 @@ public interface TimeSeriesModel extends Model { // methods //////////////////////////////////////////////// - public abstract void train(TimeSeries.DataSequence data) throws Exception; + public void train(TimeSeries.DataSequence data) throws Exception; - public abstract void update(TimeSeries.DataSequence data) throws Exception; + public void update(TimeSeries.DataSequence data) throws Exception; // predicts the values of the time series specified by the 'time' fields of the sequence and sets the 'value' fields of the sequence - public abstract void predict(TimeSeries.DataSequence sequence) throws Exception; + public void predict(TimeSeries.DataSequence sequence) throws Exception; } diff --git a/src/main/java/com/yahoo/egads/streaming/ProcessStream.java b/src/main/java/com/yahoo/egads/streaming/ProcessStream.java deleted file mode 100644 index fa6bc51..0000000 --- a/src/main/java/com/yahoo/egads/streaming/ProcessStream.java +++ /dev/null @@ -1,31 +0,0 @@ -package com.yahoo.egads.streaming; - -import java.util.Scanner; - -public class ProcessStream { - public static void main(String[] args) { - Scanner sc = new Scanner(System.in); - while (sc.hasNextLine()) { - String line = sc.nextLine(); - String[] fields = line.split(","); - String series; - int timestamp; - double measured; - double forecast = Double.NaN; - try { - series = fields[0]; - timestamp = Integer.parseInt(fields[1]); - measured = Double.parseDouble(fields[2]); - if (fields.length > 3) { - forecast = Double.parseDouble(fields[3]); - } - } catch (Exception e) { - System.err.println("Invalid input line " + line); - continue; - } - System.out.println(series + "/" + timestamp + ": " + measured + "/" + forecast); - } - sc.close(); - } - -}