diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..146a55a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+tensorflow_model/logs/
+tensorflow_model/MNIST_data/
+tensorflow_model/out/
+
diff --git a/MnistAndroid/.gitignore b/MnistAndroid/.gitignore
index 39fb081..f237317 100644
--- a/MnistAndroid/.gitignore
+++ b/MnistAndroid/.gitignore
@@ -7,3 +7,4 @@
/build
/captures
.externalNativeBuild
+.idea/
diff --git a/MnistAndroid/.idea/compiler.xml b/MnistAndroid/.idea/compiler.xml
deleted file mode 100644
index 96cc43e..0000000
--- a/MnistAndroid/.idea/compiler.xml
+++ /dev/null
@@ -1,22 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/MnistAndroid/.idea/copyright/profiles_settings.xml b/MnistAndroid/.idea/copyright/profiles_settings.xml
deleted file mode 100644
index e7bedf3..0000000
--- a/MnistAndroid/.idea/copyright/profiles_settings.xml
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-
\ No newline at end of file
diff --git a/MnistAndroid/.idea/gradle.xml b/MnistAndroid/.idea/gradle.xml
deleted file mode 100644
index 7ac24c7..0000000
--- a/MnistAndroid/.idea/gradle.xml
+++ /dev/null
@@ -1,18 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/MnistAndroid/.idea/misc.xml b/MnistAndroid/.idea/misc.xml
deleted file mode 100644
index 5d19981..0000000
--- a/MnistAndroid/.idea/misc.xml
+++ /dev/null
@@ -1,46 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/MnistAndroid/.idea/modules.xml b/MnistAndroid/.idea/modules.xml
deleted file mode 100644
index 02e6fb4..0000000
--- a/MnistAndroid/.idea/modules.xml
+++ /dev/null
@@ -1,9 +0,0 @@
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/MnistAndroid/.idea/runConfigurations.xml b/MnistAndroid/.idea/runConfigurations.xml
deleted file mode 100644
index 7f68460..0000000
--- a/MnistAndroid/.idea/runConfigurations.xml
+++ /dev/null
@@ -1,12 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/MnistAndroid/app/CMakeLists.txt b/MnistAndroid/app/CMakeLists.txt
deleted file mode 100644
index f8e6e8b..0000000
--- a/MnistAndroid/app/CMakeLists.txt
+++ /dev/null
@@ -1,44 +0,0 @@
-# For more information about using CMake with Android Studio, read the
-# documentation: https://d.android.com/studio/projects/add-native-code.html
-
-# Sets the minimum version of CMake required to build the native library.
-
-cmake_minimum_required(VERSION 3.4.1)
-
-# Creates and names a library, sets it as either STATIC
-# or SHARED, and provides the relative paths to its source code.
-# You can define multiple libraries, and CMake builds them for you.
-# Gradle automatically packages shared libraries with your APK.
-
-add_library( # Sets the name of the library.
- native-lib
-
- # Sets the library as a shared library.
- SHARED
-
- # Provides a relative path to your source file(s).
- src/main/cpp/native-lib.cpp )
-
-# Searches for a specified prebuilt library and stores the path as a
-# variable. Because CMake includes system libraries in the search path by
-# default, you only need to specify the name of the public NDK library
-# you want to add. CMake verifies that the library exists before
-# completing its build.
-
-find_library( # Sets the name of the path variable.
- log-lib
-
- # Specifies the name of the NDK library that
- # you want CMake to locate.
- log )
-
-# Specifies libraries CMake should link to your target library. You
-# can link multiple libraries, such as libraries you define in this
-# build script, prebuilt third-party libraries, or system libraries.
-
-target_link_libraries( # Specifies the target library.
- native-lib
-
- # Links the target library to the log library
- # included in the NDK.
- ${log-lib} )
\ No newline at end of file
diff --git a/MnistAndroid/app/build.gradle b/MnistAndroid/app/build.gradle
index a33663b..3992eef 100644
--- a/MnistAndroid/app/build.gradle
+++ b/MnistAndroid/app/build.gradle
@@ -9,12 +9,6 @@ android {
targetSdkVersion 25
versionCode 1
versionName "1.0"
- testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
- externalNativeBuild {
- cmake {
- cppFlags ""
- }
- }
}
buildTypes {
release {
@@ -22,20 +16,9 @@ android {
proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
}
}
- externalNativeBuild {
- cmake {
- path "CMakeLists.txt"
- }
- }
}
dependencies {
- compile fileTree(include: ['*.jar'], dir: 'libs')
- androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
compile 'com.android.support:appcompat-v7:25.3.1'
- compile 'com.android.support.constraint:constraint-layout:1.0.2'
- testCompile 'junit:junit:4.12'
- compile files('libs/libandroid_tensorflow_inference_java.jar')
+ compile 'org.tensorflow:tensorflow-android:1.2.0-rc0'
}
diff --git a/MnistAndroid/app/libs/libandroid_tensorflow_inference_java.jar b/MnistAndroid/app/libs/libandroid_tensorflow_inference_java.jar
deleted file mode 100644
index 3b8d93b..0000000
Binary files a/MnistAndroid/app/libs/libandroid_tensorflow_inference_java.jar and /dev/null differ
diff --git a/MnistAndroid/app/src/androidTest/java/mariannelinhares/mnistandroid/ExampleInstrumentedTest.java b/MnistAndroid/app/src/androidTest/java/mariannelinhares/mnistandroid/ExampleInstrumentedTest.java
deleted file mode 100644
index fb64682..0000000
--- a/MnistAndroid/app/src/androidTest/java/mariannelinhares/mnistandroid/ExampleInstrumentedTest.java
+++ /dev/null
@@ -1,26 +0,0 @@
-package mariannelinhares.mnistandroid;
-
-import android.content.Context;
-import android.support.test.InstrumentationRegistry;
-import android.support.test.runner.AndroidJUnit4;
-
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-import static org.junit.Assert.*;
-
-/**
- * Instrumentation test, which will execute on an Android device.
- *
- * @see Testing documentation
- */
-@RunWith(AndroidJUnit4.class)
-public class ExampleInstrumentedTest {
- @Test
- public void useAppContext() throws Exception {
- // Context of the app under test.
- Context appContext = InstrumentationRegistry.getTargetContext();
-
- assertEquals("mariannelinhares.mnistandroid", appContext.getPackageName());
- }
-}
diff --git a/MnistAndroid/app/src/main/assets/expert-graph.pb b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-keras.pb
similarity index 52%
rename from MnistAndroid/app/src/main/assets/expert-graph.pb
rename to MnistAndroid/app/src/main/assets/opt_mnist_convnet-keras.pb
index 3533274..80c7ecd 100644
Binary files a/MnistAndroid/app/src/main/assets/expert-graph.pb and b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-keras.pb differ
diff --git a/MnistAndroid/app/src/main/assets/graph.pb b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-tf.pb
similarity index 52%
rename from MnistAndroid/app/src/main/assets/graph.pb
rename to MnistAndroid/app/src/main/assets/opt_mnist_convnet-tf.pb
index dcd1321..279adec 100644
Binary files a/MnistAndroid/app/src/main/assets/graph.pb and b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-tf.pb differ
diff --git a/MnistAndroid/app/src/main/cpp/native-lib.cpp b/MnistAndroid/app/src/main/cpp/native-lib.cpp
deleted file mode 100644
index cbb9c07..0000000
--- a/MnistAndroid/app/src/main/cpp/native-lib.cpp
+++ /dev/null
@@ -1,11 +0,0 @@
-#include
-#include
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_mariannelinhares_mnistandroid_MainActivity_stringFromJNI(
- JNIEnv* env,
- jobject /* this */) {
- std::string hello = "Hello from C++";
- return env->NewStringUTF(hello.c_str());
-}
diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java
index eb086be..dd39deb 100644
--- a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java
+++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java
@@ -15,7 +15,9 @@
See the License for the specific language governing permissions and
limitations under the License.
- From: https://raw.githubusercontent.com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp/narr/tensorflowmnist/DrawModel.java
+ From: https://raw.githubusercontent
+ .com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp/narr/tensorflowmnist
+ /DrawModel.java
*/
import android.app.Activity;
@@ -25,41 +27,32 @@
import android.view.View;
import android.widget.Button;
import android.widget.TextView;
-
-import java.util.concurrent.Executor;
-import java.util.concurrent.Executors;
-
+import java.util.ArrayList;
+import java.util.List;
+import mariannelinhares.mnistandroid.models.Classification;
+import mariannelinhares.mnistandroid.models.Classifier;
+import mariannelinhares.mnistandroid.models.TensorFlowClassifier;
import mariannelinhares.mnistandroid.views.DrawModel;
import mariannelinhares.mnistandroid.views.DrawView;
/**
* Changed by marianne-linhares on 21/04/17.
- * https://raw.githubusercontent.com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp/narr/tensorflowmnist/DrawModel.java
+ * https://raw.githubusercontent.com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp
+ * /narr/tensorflowmnist/DrawModel.java
*/
public class MainActivity extends Activity implements View.OnClickListener, View.OnTouchListener {
+ private static final int PIXEL_WIDTH = 28;
+
// ui related
private Button clearBtn, classBtn;
private TextView resText;
-
- // tensorflow input and output
- private static final int INPUT_SIZE = 28;
- private static final String INPUT_NAME = "input";
- private static final String OUTPUT_NAME = "output";
-
- private static final String MODEL_FILE = "file:///android_asset/expert-graph.pb";
- private static final String LABEL_FILE = "file:///android_asset/labels.txt";
-
- private Classifier classifier;
-
- private Executor executor = Executors.newSingleThreadExecutor();
+ private List mClassifiers = new ArrayList<>();
// views related
private DrawModel drawModel;
private DrawView drawView;
- private static final int PIXEL_WIDTH = 28;
-
private PointF mTmpPiont = new PointF();
private float mLastX;
@@ -71,97 +64,84 @@ protected void onCreate(Bundle savedInstanceState) {
setContentView(R.layout.activity_main);
//get drawing view
- drawView = (DrawView)findViewById(R.id.draw);
+ drawView = (DrawView) findViewById(R.id.draw);
drawModel = new DrawModel(PIXEL_WIDTH, PIXEL_WIDTH);
drawView.setModel(drawModel);
drawView.setOnTouchListener(this);
//clear button
- clearBtn = (Button)findViewById(R.id.btn_clear);
+ clearBtn = (Button) findViewById(R.id.btn_clear);
clearBtn.setOnClickListener(this);
//class button
- classBtn = (Button)findViewById(R.id.btn_class);
+ classBtn = (Button) findViewById(R.id.btn_class);
classBtn.setOnClickListener(this);
// res text
- resText = (TextView)findViewById(R.id.tfRes);
+ resText = (TextView) findViewById(R.id.tfRes);
// tensorflow
loadModel();
}
+ @Override
+ protected void onResume() {
+ drawView.onResume();
+ super.onResume();
+ }
+
+ @Override
+ protected void onPause() {
+ drawView.onPause();
+ super.onPause();
+ }
+
private void loadModel() {
- executor.execute(new Runnable() {
+ new Thread(new Runnable() {
@Override
public void run() {
try {
- classifier = Classifier.create(getApplicationContext().getAssets(),
- MODEL_FILE,
- LABEL_FILE,
- INPUT_SIZE,
- INPUT_NAME,
- OUTPUT_NAME);
+ mClassifiers.add(
+ TensorFlowClassifier.create(getAssets(), "TensorFlow",
+ "opt_mnist_convnet-tf.pb", "labels.txt", PIXEL_WIDTH,
+ "input", "output", true));
+ mClassifiers.add(
+ TensorFlowClassifier.create(getAssets(), "Keras",
+ "opt_mnist_convnet-keras.pb", "labels.txt", PIXEL_WIDTH,
+ "conv2d_1_input", "dense_2/Softmax", false));
} catch (final Exception e) {
- throw new RuntimeException("Error initializing TensorFlow!", e);
+ throw new RuntimeException("Error initializing classifiers!", e);
}
}
- });
- }
-
- /**
- * A native method that is implemented by the 'native-lib' native library,
- * which is packaged with this application.
- */
- public native String stringFromJNI();
-
- // Used to load the 'native-lib' library on application startup.
- static {
- System.loadLibrary("native-lib");
+ }).start();
}
-
@Override
- public void onClick(View view){
-
- if(view.getId() == R.id.btn_clear) {
+ public void onClick(View view) {
+ if (view.getId() == R.id.btn_clear) {
drawModel.clear();
drawView.reset();
drawView.invalidate();
- resText.setText("Result: ");
- }
- else if(view.getId() == R.id.btn_class){
-
+ resText.setText("");
+ } else if (view.getId() == R.id.btn_class) {
float pixels[] = drawView.getPixelData();
- final Classification res = classifier.recognize(pixels);
- String result = "Result: ";
- if (res.getLabel() == null) {
- resText.setText(result + "?");
- }
- else {
- result += res.getLabel();
- result += "\nwith probability: " + res.getConf();
- resText.setText(result);
+ String text = "";
+ for (Classifier classifier : mClassifiers) {
+ final Classification res = classifier.recognize(pixels);
+ if (res.getLabel() == null) {
+ text += classifier.name() + ": ?\n";
+ } else {
+ text += String.format("%s: %s, %f\n", classifier.name(), res.getLabel(),
+ res.getConf());
+ }
}
+ resText.setText(text);
}
}
- @Override
- protected void onResume() {
- drawView.onResume();
- super.onResume();
- }
-
- @Override
- protected void onPause() {
- drawView.onPause();
- super.onPause();
- }
-
-
@Override
public boolean onTouch(View v, MotionEvent event) {
int action = event.getAction() & MotionEvent.ACTION_MASK;
@@ -169,11 +149,9 @@ public boolean onTouch(View v, MotionEvent event) {
if (action == MotionEvent.ACTION_DOWN) {
processTouchDown(event);
return true;
-
} else if (action == MotionEvent.ACTION_MOVE) {
processTouchMove(event);
return true;
-
} else if (action == MotionEvent.ACTION_UP) {
processTouchUp();
return true;
diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classification.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classification.java
similarity index 58%
rename from MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classification.java
rename to MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classification.java
index 95bfb1f..5ba6c0b 100644
--- a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classification.java
+++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classification.java
@@ -1,4 +1,4 @@
-package mariannelinhares.mnistandroid;
+package mariannelinhares.mnistandroid.models;
/**
* Created by marianne-linhares on 20/04/17.
@@ -9,16 +9,12 @@ public class Classification {
private float conf;
private String label;
- public Classification(float conf, String label) {
- update(conf, label);
- }
-
- public Classification() {
- this.conf = (float)-1.0;
+ Classification() {
+ this.conf = -1.0F;
this.label = null;
}
- public void update(float conf, String label) {
+ void update(float conf, String label) {
this.conf = conf;
this.label = label;
}
@@ -30,5 +26,4 @@ public String getLabel() {
public float getConf() {
return conf;
}
-
}
diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classifier.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classifier.java
new file mode 100644
index 0000000..fd551d1
--- /dev/null
+++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classifier.java
@@ -0,0 +1,11 @@
+package mariannelinhares.mnistandroid.models;
+
+/**
+ * Created by Piasy{github.com/Piasy} on 29/05/2017.
+ */
+
+public interface Classifier {
+ String name();
+
+ Classification recognize(final float[] pixels);
+}
diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/TensorFlowClassifier.java
similarity index 55%
rename from MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java
rename to MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/TensorFlowClassifier.java
index f9b34ed..a908f67 100644
--- a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java
+++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/TensorFlowClassifier.java
@@ -1,38 +1,38 @@
-package mariannelinhares.mnistandroid;
+package mariannelinhares.mnistandroid.models;
import android.content.res.AssetManager;
-
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
-
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
/**
- * Changed from https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master/app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java
+ * Changed from https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master
+ * /app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java
* Created by marianne-linhares on 20/04/17.
*/
-public class Classifier {
+public class TensorFlowClassifier implements Classifier {
// Only returns if at least this confidence
private static final float THRESHOLD = 0.1f;
private TensorFlowInferenceInterface tfHelper;
+ private String name;
private String inputName;
private String outputName;
private int inputSize;
+ private boolean feedKeepProb;
private List labels;
private float[] output;
private String[] outputNames;
- static private List readLabels(Classifier c, AssetManager am, String fileName) throws IOException {
- BufferedReader br = null;
- br = new BufferedReader(new InputStreamReader(am.open(fileName)));
+ private static List readLabels(AssetManager am, String fileName) throws IOException {
+ BufferedReader br = new BufferedReader(new InputStreamReader(am.open(fileName)));
String line;
List labels = new ArrayList<>();
@@ -44,44 +44,49 @@ static private List readLabels(Classifier c, AssetManager am, String fil
return labels;
}
+ public static TensorFlowClassifier create(AssetManager assetManager, String name,
+ String modelPath, String labelFile, int inputSize, String inputName, String outputName,
+ boolean feedKeepProb) throws IOException {
+ TensorFlowClassifier c = new TensorFlowClassifier();
- static public Classifier create(AssetManager assetManager, String modelPath, String labelPath,
- int inputSize, String inputName, String outputName)
- throws IOException {
-
- Classifier c = new Classifier();
+ c.name = name;
c.inputName = inputName;
c.outputName = outputName;
- // Read labels
- String labelFile = labelPath.split("file:///android_asset/")[1];
- c.labels = readLabels(c, assetManager, labelFile);
-
- c.tfHelper = new TensorFlowInferenceInterface();
- if (c.tfHelper.initializeTensorFlow(assetManager, modelPath) != 0) {
- throw new RuntimeException("TF initialization failed");
- }
+ c.labels = readLabels(assetManager, labelFile);
+ c.tfHelper = new TensorFlowInferenceInterface(assetManager, modelPath);
int numClasses = 10;
c.inputSize = inputSize;
// Pre-allocate buffer.
- c.outputNames = new String[]{ outputName };
+ c.outputNames = new String[] { outputName };
c.outputName = outputName;
c.output = new float[numClasses];
+ c.feedKeepProb = feedKeepProb;
+
return c;
}
+ @Override
+ public String name() {
+ return name;
+ }
+
+ @Override
public Classification recognize(final float[] pixels) {
- tfHelper.fillNodeFloat(inputName, new int[]{inputSize * inputSize}, pixels);
- tfHelper.runInference(outputNames);
+ tfHelper.feed(inputName, pixels, 1, inputSize, inputSize, 1);
+ if (feedKeepProb) {
+ tfHelper.feed("keep_prob", new float[] { 1 });
+ }
+ tfHelper.run(outputNames);
- tfHelper.readNodeFloat(outputName, output);
+ tfHelper.fetch(outputName, output);
// Find the best classification
Classification ans = new Classification();
diff --git a/MnistAndroid/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so b/MnistAndroid/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so
deleted file mode 100644
index 9390465..0000000
Binary files a/MnistAndroid/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so and /dev/null differ
diff --git a/MnistAndroid/app/src/main/jniLibs/x86/libtensorflow_mnist.so b/MnistAndroid/app/src/main/jniLibs/x86/libtensorflow_mnist.so
deleted file mode 100644
index d4572f3..0000000
Binary files a/MnistAndroid/app/src/main/jniLibs/x86/libtensorflow_mnist.so and /dev/null differ
diff --git a/MnistAndroid/app/src/main/res/layout/activity_main.xml b/MnistAndroid/app/src/main/res/layout/activity_main.xml
index e07fe80..b3f727e 100644
--- a/MnistAndroid/app/src/main/res/layout/activity_main.xml
+++ b/MnistAndroid/app/src/main/res/layout/activity_main.xml
@@ -1,47 +1,50 @@
+ xmlns:android="http://schemas.android.com/apk/res/android"
+ xmlns:tools="http://schemas.android.com/tools"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:orientation="vertical"
+ android:paddingBottom="@dimen/activity_vertical_margin"
+ android:paddingLeft="@dimen/activity_horizontal_margin"
+ android:paddingRight="@dimen/activity_horizontal_margin"
+ android:paddingTop="@dimen/activity_vertical_margin"
+ tools:context="mariannelinhares.mnistandroid.MainActivity"
+ >
+ android:id="@+id/draw"
+ android:layout_width="match_parent"
+ android:layout_height="0dp"
+ android:layout_weight="1"
+ />
+ android:layout_width="match_parent"
+ android:layout_height="wrap_content"
+ android:orientation="horizontal"
+ >
+ android:id="@+id/btn_clear"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:text="Clear"
+ />
+ android:id="@+id/btn_class"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:text="Detect"
+ />
+
-
-
-
+ android:textAppearance="?android:attr/textAppearanceMedium"
+ />
\ No newline at end of file
diff --git a/MnistAndroid/app/src/test/java/mariannelinhares/mnistandroid/ExampleUnitTest.java b/MnistAndroid/app/src/test/java/mariannelinhares/mnistandroid/ExampleUnitTest.java
deleted file mode 100644
index 2a0350d..0000000
--- a/MnistAndroid/app/src/test/java/mariannelinhares/mnistandroid/ExampleUnitTest.java
+++ /dev/null
@@ -1,17 +0,0 @@
-package mariannelinhares.mnistandroid;
-
-import org.junit.Test;
-
-import static org.junit.Assert.*;
-
-/**
- * Example local unit test, which will execute on the development machine (host).
- *
- * @see Testing documentation
- */
-public class ExampleUnitTest {
- @Test
- public void addition_isCorrect() throws Exception {
- assertEquals(4, 2 + 2);
- }
-}
\ No newline at end of file
diff --git a/MnistAndroid/build.gradle b/MnistAndroid/build.gradle
index b78a0b8..7033e6a 100644
--- a/MnistAndroid/build.gradle
+++ b/MnistAndroid/build.gradle
@@ -5,7 +5,7 @@ buildscript {
jcenter()
}
dependencies {
- classpath 'com.android.tools.build:gradle:2.3.1'
+ classpath 'com.android.tools.build:gradle:2.3.2'
// NOTE: Do not place your application dependencies here; they belong
// in the individual module build.gradle files
@@ -15,6 +15,9 @@ buildscript {
allprojects {
repositories {
jcenter()
+ flatDir {
+ dirs "$rootProject.projectDir/aars"
+ }
}
}
diff --git a/MnistAndroid/gradle/wrapper/gradle-wrapper.properties b/MnistAndroid/gradle/wrapper/gradle-wrapper.properties
index 45120ad..0caac1b 100644
--- a/MnistAndroid/gradle/wrapper/gradle-wrapper.properties
+++ b/MnistAndroid/gradle/wrapper/gradle-wrapper.properties
@@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-3.5-all.zip
diff --git a/README.md b/README.md
index 192b8d3..bd4ac5c 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,6 @@ how to save your model and export it for Android or other devices check the
very simple tutorial bellow.
The UI and expert-graph.pb model were taken from: https://github.com/miyosuda/TensorFlowAndroidMNIST, so thank you miyousuda.
-The TensorFlow jar and so armeabi-v7a were taken from: https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample,
-so thank you MindorksOpenSource.
-The Tensorflow so of x86 was taken from: https://github.com/cesardelgadof/TensorFlowAndroidMNIST, so thank you cesardelgadof.
If you have no ideia what I just said above, have a look on the instructions bellow.
@@ -32,7 +29,7 @@ A full example can be seen [here](https://github.com/mari-linhares/mnist-android
Example: `_w = sess.eval(w)`, where w was learned from training.
3. Rewrite your model changing the variables for constants with value = in memory copy of learned variables.
Example: `w_save = tf.constant(_w)`
-
+
Also make sure to put names in the input and output of the model, this will be needed for the model later.
Example:
`x = tf.placeholder(tf.float32, [None, 1000], name='input')`
@@ -42,23 +39,10 @@ A full example can be seen [here](https://github.com/mari-linhares/mnist-android
## How to run my model with Android?
-You need two things:
-
-1. [The TensorFlow jar](https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master/app/libs/libandroid_tensorflow_inference_java.jar)
- Move it to the libs folder, right click and add as library.
-
-2. The TensorFlow so file for the desired architecture:
-[x86](https://github.com/cesardelgadof/TensorFlowAndroidMNIST/blob/master/app/src/main/jniLibs/x86/libtensorflow_mnist.so)
-[armeabi-v7a](https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/tree/master/app/src/main/jniLibs/armeabi-v7a)
-
-Creat the jniLibs/x86 folder or the jniLibs/armeabi-v7a folder at the main folder.
-Move it to app/src/main/jniLibs/x86/libtensorflow_inference.so or app/src/jniLibs/armeabi-v7a/libtensorflow_inference.so
-
-If you want to generate these files yourself, [here](https://blog.mindorks.com/android-tensorflow-machine-learning-example-ff0e9b2654cc) is a nice tutorial of how to do it.
+You need `tensorflow.aar`, which can be downloaded from [the nightly build artifact of TensorFlow CI](http://ci.tensorflow.org/view/Nightly/job/nightly-android/), here we use [the #124 build](http://ci.tensorflow.org/view/Nightly/job/nightly-android/124/artifact/).
## Interacting with TensorFlow
To interact with TensorFlow you will need an instance of TensorFlowInferenceInterface, you can see more details about it [here](https://github.com/mari-linhares/mnist-android-tensorflow/blob/master/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java)
Thank you, have fun!
-
diff --git a/tensorflow_model/convnet.py b/tensorflow_model/convnet.py
deleted file mode 100644
index 0ebdab1..0000000
--- a/tensorflow_model/convnet.py
+++ /dev/null
@@ -1,205 +0,0 @@
-# needed libraries
-import tensorflow as tf
-
-from tensorflow.examples.tutorials.mnist import input_data
-
-logs_path = '/tmp/tensorflow_logs/convnet'
-
-# mnist.train = 55,000 input data
-# mnist.test = 10,000 input data
-# mnist.validate = 5,000 input data
-mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
-
-# Implementing Convnet with TF
-def weight_variable(shape, name=None):
- # break simmetry
- if name:
- w = tf.truncated_normal(shape, stddev=0.1, name=name)
- else:
- w = tf.truncated_normal(shape, stddev=0.1)
-
- return tf.Variable(w)
-
-
-def bias_variable(shape, name=None):
- # avoid dead neurons
- if name:
- b = tf.constant(0.1, shape=shape, name=name)
- else:
- b = tf.constant(0.1, shape=shape)
- return tf.Variable(b)
-
-
-# pool
-def max_pool_2x2(x):
- return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
- strides=[1, 2, 2, 1], padding='SAME')
-
-def new_conv_layer(x, w):
- return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
-
-# our network!!!
-
-g = tf.Graph()
-
-with g.as_default():
-
- # input data
- x = tf.placeholder(tf.float32, shape=[None, 28*28], name='input_data')
- x_image = tf.reshape(x, [-1, 28, 28, 1])
- # correct labels
- y_ = tf.placeholder(tf.float32, shape=[None, 10], name='correct_labels')
-
- # fist conv layer
- with tf.name_scope('convLayer1'):
- w1 = weight_variable([5, 5, 1, 32])
- b1 = bias_variable([32])
- convlayer1 = tf.nn.relu(new_conv_layer(x_image, w1) + b1)
- max_pool1 = max_pool_2x2(convlayer1)
-
- # second conv layer
- with tf.name_scope('convLayer2'):
- w2 = weight_variable([5, 5, 32, 64])
- b2 = bias_variable([64])
- convlayer2 = tf.nn.relu(new_conv_layer(max_pool1, w2) + b2)
- max_pool2 = max_pool_2x2(convlayer2)
-
- # flat layer
- with tf.name_scope('flattenLayer'):
- flat_layer = tf.reshape(max_pool2, [-1, 7 * 7 * 64])
-
- # fully connected layer
- with tf.name_scope('FullyConnectedLayer'):
- wfc1 = weight_variable([7 * 7 * 64, 1024])
- bfc1 = bias_variable([1024])
- fc1 = tf.nn.relu(tf.matmul(flat_layer, wfc1) + bfc1)
-
- # DROPOUT
- with tf.name_scope('Dropout'):
- keep_prob = tf.placeholder(tf.float32)
- drop_layer = tf.nn.dropout(fc1, keep_prob)
-
- # final layer
- with tf.name_scope('FinalLayer'):
- w_f = weight_variable([1024, 10])
- b_f = bias_variable([10])
- y_f = tf.matmul(drop_layer, w_f) + b_f
- y_f_softmax = tf.nn.softmax(y_f)
-
- # loss
- loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_,
- logits=y_f))
-
- # train step
- train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
-
- # accuracy
- correct_prediction = tf.equal(tf.argmax(y_f_softmax, 1), tf.argmax(y_, 1))
- accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
-
- # Create a summary to monitor loss tensor
- tf.summary.scalar("loss", loss)
- # Create a summary to monitor accuracy tensor
- tf.summary.scalar("accuracy", accuracy)
- # Merge all summaries into a single op
- merged_summary_op = tf.summary.merge_all()
-
- # init
- init = tf.global_variables_initializer()
-
- # Running the graph
-
- num_steps = 3000
- batch_size = 16
- test_size = 10000
- test_accuracy = 0.0
-
- sess = tf.Session()
-
- sess.run(init)
- # op to write logs to Tensorboard
- summary_writer = tf.summary.FileWriter(logs_path,
- graph=tf.get_default_graph())
-
- for step in range(num_steps):
- batch = mnist.train.next_batch(batch_size)
-
- ts, error, acc, summary = sess.run([train_step, loss, accuracy,
- merged_summary_op],
- feed_dict={x: batch[0],
- y_: batch[1],
- keep_prob: 0.5})
- if step % 100 == 0:
- train_accuracy = accuracy.eval({
- x: batch[0], y_: batch[1], keep_prob: 1.0}, sess)
- print('step %d, training accuracy %f' % (step, train_accuracy))
- '''
- print 'Done!'
- print 'Evaluating...'
- for i in xrange(test_size/50):
- batch = mnist.test.next_batch(50)
- acc = accuracy.eval({x: batch[0], y_: batch[1],
- keep_prob: 1.0}, sess)
- if i % 10 == 0:
- print('%d: test accuracy %f' % (i, acc))
- test_accuracy += acc
- print 'avg test accuracy:', test_accuracy/(test_size/50.0)
- '''
-
-# copying variables as constants to export graph
-_w1 = w1.eval(sess)
-_b1 = b1.eval(sess)
-_w2 = w2.eval(sess)
-_b2 = b2.eval(sess)
-_wfc1 = wfc1.eval(sess)
-_bfc1 = bfc1.eval(sess)
-_w_f = w_f.eval(sess)
-_b_f = b_f.eval(sess)
-
-sess.close()
-
-g2 = tf.Graph()
-with g2.as_default():
-
- # input data
- x2 = tf.placeholder(tf.float32, shape=[None, 28*28], name='input')
- x2_image = tf.reshape(x2, [-1, 28, 28, 1])
- # correct labels
- y2_ = tf.placeholder(tf.float32, shape=[None, 10])
-
- w1_2 = tf.constant(_w1)
- b1_2 = tf.constant(_b1)
- convlayer1_2 = tf.nn.relu(new_conv_layer(x2_image, w1_2) + b1_2)
- max_pool1_2 = max_pool_2x2(convlayer1_2)
-
- w2_2 = tf.constant(_w2)
- b2_2 = tf.constant(_b2)
- convlayer2_2 = tf.nn.relu(new_conv_layer(max_pool1_2, w2_2) + b2_2)
- max_pool2_2 = max_pool_2x2(convlayer2_2)
-
- # flat layer
- flat_layer_2 = tf.reshape(max_pool2_2, [-1, 7 * 7 * 64])
-
- # fully connected layer
- wfc1_2 = tf.constant(_wfc1)
- bfc1_2 = tf.constant(_bfc1)
- fc1_2 = tf.nn.relu(tf.matmul(flat_layer_2, wfc1_2) + bfc1_2)
-
- # no dropout layer
-
- # final layer
- w_f_2 = tf.constant(_w_f)
- b_f_2 = tf.constant(_b_f)
- y_f_2 = tf.matmul(fc1_2, w_f_2) + b_f_2
- y_f_softmax_2 = tf.nn.softmax(y_f_2, name='output')
-
- # init
- init_2 = tf.global_variables_initializer()
-
- sess_2 = tf.Session()
- init_2 = tf.initialize_all_variables()
- sess_2.run(init_2)
-
- graph_def = g2.as_graph_def()
- tf.train.write_graph(graph_def, '', 'graph.pb', as_text=False)
-
diff --git a/tensorflow_model/graph.pb b/tensorflow_model/graph.pb
deleted file mode 100644
index dcd1321..0000000
Binary files a/tensorflow_model/graph.pb and /dev/null differ
diff --git a/tensorflow_model/mnist_convnet.py b/tensorflow_model/mnist_convnet.py
new file mode 100644
index 0000000..e38f874
--- /dev/null
+++ b/tensorflow_model/mnist_convnet.py
@@ -0,0 +1,140 @@
+# Python 3.6.0
+# tensorflow 1.1.0
+
+import os
+import os.path as path
+
+import tensorflow as tf
+from tensorflow.python.tools import freeze_graph
+from tensorflow.python.tools import optimize_for_inference_lib
+
+from tensorflow.examples.tutorials.mnist import input_data
+
+MODEL_NAME = 'mnist_convnet'
+NUM_STEPS = 3000
+BATCH_SIZE = 16
+
+def model_input(input_node_name, keep_prob_node_name):
+ x = tf.placeholder(tf.float32, shape=[None, 28*28], name=input_node_name)
+ keep_prob = tf.placeholder(tf.float32, name=keep_prob_node_name)
+ y_ = tf.placeholder(tf.float32, shape=[None, 10])
+ return x, keep_prob, y_
+
+def build_model(x, keep_prob, y_, output_node_name):
+ x_image = tf.reshape(x, [-1, 28, 28, 1])
+ # 28*28*1
+
+ conv1 = tf.layers.conv2d(x_image, 64, 3, 1, 'same', activation=tf.nn.relu)
+ # 28*28*64
+ pool1 = tf.layers.max_pooling2d(conv1, 2, 2, 'same')
+ # 14*14*64
+
+ conv2 = tf.layers.conv2d(pool1, 128, 3, 1, 'same', activation=tf.nn.relu)
+ # 14*14*128
+ pool2 = tf.layers.max_pooling2d(conv2, 2, 2, 'same')
+ # 7*7*128
+
+ conv3 = tf.layers.conv2d(pool2, 256, 3, 1, 'same', activation=tf.nn.relu)
+ # 7*7*256
+ pool3 = tf.layers.max_pooling2d(conv3, 2, 2, 'same')
+ # 4*4*256
+
+ flatten = tf.reshape(pool3, [-1, 4*4*256])
+ fc = tf.layers.dense(flatten, 1024, activation=tf.nn.relu)
+ dropout = tf.nn.dropout(fc, keep_prob)
+ logits = tf.layers.dense(dropout, 10)
+ outputs = tf.nn.softmax(logits, name=output_node_name)
+
+ # loss
+ loss = tf.reduce_mean(
+ tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
+
+ # train step
+ train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
+
+ # accuracy
+ correct_prediction = tf.equal(tf.argmax(outputs, 1), tf.argmax(y_, 1))
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+
+ tf.summary.scalar("loss", loss)
+ tf.summary.scalar("accuracy", accuracy)
+ merged_summary_op = tf.summary.merge_all()
+
+ return train_step, loss, accuracy, merged_summary_op
+
+def train(x, keep_prob, y_, train_step, loss, accuracy,
+ merged_summary_op, saver):
+ print("training start...")
+
+ mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
+
+ init_op = tf.global_variables_initializer()
+
+ with tf.Session() as sess:
+ sess.run(init_op)
+
+ tf.train.write_graph(sess.graph_def, 'out',
+ MODEL_NAME + '.pbtxt', True)
+
+ # op to write logs to Tensorboard
+ summary_writer = tf.summary.FileWriter('logs/',
+ graph=tf.get_default_graph())
+
+ for step in range(NUM_STEPS):
+ batch = mnist.train.next_batch(BATCH_SIZE)
+ if step % 100 == 0:
+ train_accuracy = accuracy.eval(feed_dict={
+ x: batch[0], y_: batch[1], keep_prob: 1.0})
+ print('step %d, training accuracy %f' % (step, train_accuracy))
+ _, summary = sess.run([train_step, merged_summary_op],
+ feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
+ summary_writer.add_summary(summary, step)
+
+ saver.save(sess, 'out/' + MODEL_NAME + '.chkp')
+
+ test_accuracy = accuracy.eval(feed_dict={x: mnist.test.images,
+ y_: mnist.test.labels,
+ keep_prob: 1.0})
+ print('test accuracy %g' % test_accuracy)
+
+ print("training finished!")
+
+def export_model(input_node_names, output_node_name):
+ freeze_graph.freeze_graph('out/' + MODEL_NAME + '.pbtxt', None, False,
+ 'out/' + MODEL_NAME + '.chkp', output_node_name, "save/restore_all",
+ "save/Const:0", 'out/frozen_' + MODEL_NAME + '.pb', True, "")
+
+ input_graph_def = tf.GraphDef()
+ with tf.gfile.Open('out/frozen_' + MODEL_NAME + '.pb', "rb") as f:
+ input_graph_def.ParseFromString(f.read())
+
+ output_graph_def = optimize_for_inference_lib.optimize_for_inference(
+ input_graph_def, input_node_names, [output_node_name],
+ tf.float32.as_datatype_enum)
+
+ with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
+ f.write(output_graph_def.SerializeToString())
+
+ print("graph saved!")
+
+def main():
+ if not path.exists('out'):
+ os.mkdir('out')
+
+ input_node_name = 'input'
+ keep_prob_node_name = 'keep_prob'
+ output_node_name = 'output'
+
+ x, keep_prob, y_ = model_input(input_node_name, keep_prob_node_name)
+
+ train_step, loss, accuracy, merged_summary_op = build_model(x, keep_prob,
+ y_, output_node_name)
+ saver = tf.train.Saver()
+
+ train(x, keep_prob, y_, train_step, loss, accuracy,
+ merged_summary_op, saver)
+
+ export_model([input_node_name, keep_prob_node_name], output_node_name)
+
+if __name__ == '__main__':
+ main()
diff --git a/tensorflow_model/mnist_convnet_keras.py b/tensorflow_model/mnist_convnet_keras.py
new file mode 100644
index 0000000..8d1fdfb
--- /dev/null
+++ b/tensorflow_model/mnist_convnet_keras.py
@@ -0,0 +1,116 @@
+# Python 3.6.0
+# tensorflow 1.1.0
+# Keras 2.0.4
+
+import os
+import os.path as path
+
+import keras
+from keras.datasets import mnist
+from keras.models import Sequential
+from keras.layers import Input, Dense, Dropout, Flatten
+from keras.layers import Conv2D, MaxPooling2D
+from keras import backend as K
+
+import tensorflow as tf
+from tensorflow.python.tools import freeze_graph
+from tensorflow.python.tools import optimize_for_inference_lib
+
+MODEL_NAME = 'mnist_convnet'
+EPOCHS = 1
+BATCH_SIZE = 128
+
+
+def load_data():
+ (x_train, y_train), (x_test, y_test) = mnist.load_data()
+ x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
+ x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
+ x_train = x_train.astype('float32')
+ x_test = x_test.astype('float32')
+ x_train /= 255
+ x_test /= 255
+ y_train = keras.utils.to_categorical(y_train, 10)
+ y_test = keras.utils.to_categorical(y_test, 10)
+ return x_train, y_train, x_test, y_test
+
+
+def build_model():
+ model = Sequential()
+ model.add(Conv2D(filters=64, kernel_size=3, strides=1, \
+ padding='same', activation='relu', \
+ input_shape=[28, 28, 1]))
+ # 28*28*64
+ model.add(MaxPooling2D(pool_size=2, strides=2, padding='same'))
+ # 14*14*64
+
+ model.add(Conv2D(filters=128, kernel_size=3, strides=1, \
+ padding='same', activation='relu'))
+ # 14*14*128
+ model.add(MaxPooling2D(pool_size=2, strides=2, padding='same'))
+ # 7*7*128
+
+ model.add(Conv2D(filters=256, kernel_size=3, strides=1, \
+ padding='same', activation='relu'))
+ # 7*7*256
+ model.add(MaxPooling2D(pool_size=2, strides=2, padding='same'))
+ # 4*4*256
+
+ model.add(Flatten())
+ model.add(Dense(1024, activation='relu'))
+ #model.add(Dropout(0.5))
+ model.add(Dense(10, activation='softmax'))
+ return model
+
+
+def train(model, x_train, y_train, x_test, y_test):
+ model.compile(loss=keras.losses.categorical_crossentropy, \
+ optimizer=keras.optimizers.Adadelta(), \
+ metrics=['accuracy'])
+
+ model.fit(x_train, y_train, \
+ batch_size=BATCH_SIZE, \
+ epochs=EPOCHS, \
+ verbose=1, \
+ validation_data=(x_test, y_test))
+
+
+def export_model(saver, model, input_node_names, output_node_name):
+ tf.train.write_graph(K.get_session().graph_def, 'out', \
+ MODEL_NAME + '_graph.pbtxt')
+
+ saver.save(K.get_session(), 'out/' + MODEL_NAME + '.chkp')
+
+ freeze_graph.freeze_graph('out/' + MODEL_NAME + '_graph.pbtxt', None, \
+ False, 'out/' + MODEL_NAME + '.chkp', output_node_name, \
+ "save/restore_all", "save/Const:0", \
+ 'out/frozen_' + MODEL_NAME + '.pb', True, "")
+
+ input_graph_def = tf.GraphDef()
+ with tf.gfile.Open('out/frozen_' + MODEL_NAME + '.pb', "rb") as f:
+ input_graph_def.ParseFromString(f.read())
+
+ output_graph_def = optimize_for_inference_lib.optimize_for_inference(
+ input_graph_def, input_node_names, [output_node_name],
+ tf.float32.as_datatype_enum)
+
+ with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
+ f.write(output_graph_def.SerializeToString())
+
+ print("graph saved!")
+
+
+def main():
+ if not path.exists('out'):
+ os.mkdir('out')
+
+ x_train, y_train, x_test, y_test = load_data()
+
+ model = build_model()
+
+ train(model, x_train, y_train, x_test, y_test)
+
+ export_model(tf.train.Saver(), model, ["conv2d_1_input"], "dense_2/Softmax")
+
+
+if __name__ == '__main__':
+ main()