tests/pkg_tensorflow-lite: add mnist_mlp complete example
This commit is contained in:
parent
b376c7c5de
commit
26f3307d87
@ -3,9 +3,21 @@ DEVELHELP ?= 0
|
|||||||
|
|
||||||
include ../Makefile.tests_common
|
include ../Makefile.tests_common
|
||||||
|
|
||||||
EXAMPLE ?= hello_world
|
# Other available example: hello_world
|
||||||
|
EXAMPLE ?= mnist
|
||||||
|
|
||||||
USEPKG += tensorflow-lite
|
USEPKG += tensorflow-lite
|
||||||
|
|
||||||
|
# internal mnist example is available as an external module
|
||||||
|
ifeq (mnist,$(EXAMPLE))
|
||||||
|
# TensorFlow-Lite crashes on M4/M7 CPUs when FPU is enabled, so disable it by
|
||||||
|
# default for now
|
||||||
|
DISABLE_MODULE += cortexm_fpu
|
||||||
|
USEMODULE += $(EXAMPLE)
|
||||||
|
EXTERNAL_MODULE_DIRS += $(CURDIR)/$(EXAMPLE)
|
||||||
|
else
|
||||||
|
# Use upstream example
|
||||||
USEMODULE += tensorflow-lite-$(EXAMPLE)
|
USEMODULE += tensorflow-lite-$(EXAMPLE)
|
||||||
|
endif
|
||||||
|
|
||||||
include $(RIOTBASE)/Makefile.include
|
include $(RIOTBASE)/Makefile.include
|
||||||
|
|||||||
1
tests/pkg_tensorflow-lite/mnist/.gitignore
vendored
Normal file
1
tests/pkg_tensorflow-lite/mnist/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
model_basic.tflite
|
||||||
16
tests/pkg_tensorflow-lite/mnist/Makefile
Normal file
16
tests/pkg_tensorflow-lite/mnist/Makefile
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
MODULE = mnist
|
||||||
|
|
||||||
|
SRCXXEXT = cc
|
||||||
|
|
||||||
|
CXXEXFLAGS += -Wno-unused-parameter
|
||||||
|
CXXEXFLAGS += -Wno-type-limits
|
||||||
|
|
||||||
|
CFLAGS += -Wno-pedantic
|
||||||
|
|
||||||
|
# Add the tensorflow lite quantized model as a blob
|
||||||
|
BLOBS += model.tflite
|
||||||
|
|
||||||
|
# Add the input digit image as blob
|
||||||
|
BLOBS += digit
|
||||||
|
|
||||||
|
include $(RIOTBASE)/Makefile.base
|
||||||
BIN
tests/pkg_tensorflow-lite/mnist/digit
Normal file
BIN
tests/pkg_tensorflow-lite/mnist/digit
Normal file
Binary file not shown.
40
tests/pkg_tensorflow-lite/mnist/generate_digit.py
Executable file
40
tests/pkg_tensorflow-lite/mnist/generate_digit.py
Executable file
@ -0,0 +1,40 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""Generate a binary file from a sample image of the MNIST dataset.
|
||||||
|
Pixel of the sample are stored as float32, images have size 28x28.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from tensorflow.keras.datasets import mnist
|
||||||
|
|
||||||
|
|
||||||
|
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
_, (mnist_test, _) = mnist.load_data()
|
||||||
|
data = mnist_test[args.index]
|
||||||
|
data = data.astype('uint8')
|
||||||
|
|
||||||
|
output_path = os.path.join(SCRIPT_DIR, args.output)
|
||||||
|
np.ndarray.tofile(data, output_path)
|
||||||
|
|
||||||
|
if args.no_plot is False:
|
||||||
|
plt.gray()
|
||||||
|
plt.imshow(data.reshape(28, 28))
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-i", "--index", type=int, default=0,
|
||||||
|
help="Image index in MNIST test dataset")
|
||||||
|
parser.add_argument("-o", "--output", type=str, default='digit',
|
||||||
|
help="Output filename")
|
||||||
|
parser.add_argument("--no-plot", default=False, action='store_true',
|
||||||
|
help="Disable image display in matplotlib")
|
||||||
|
main(parser.parse_args())
|
||||||
127
tests/pkg_tensorflow-lite/mnist/main_functions.cc
Normal file
127
tests/pkg_tensorflow-lite/mnist/main_functions.cc
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (C) 2019 Inria
|
||||||
|
*
|
||||||
|
* This file is subject to the terms and conditions of the GNU Lesser
|
||||||
|
* General Public License v2.1. See the file LICENSE in the top level
|
||||||
|
* directory for more details.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup tests
|
||||||
|
*
|
||||||
|
* @file
|
||||||
|
* @brief TensorFlow Lite MNIST MLP inference functions
|
||||||
|
*
|
||||||
|
* @author Alexandre Abadie <alexandre.abadie@inria.fr>
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||||
|
#include "tensorflow/lite/micro/micro_interpreter.h"
|
||||||
|
#include "tensorflow/lite/micro/kernels/micro_ops.h"
|
||||||
|
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
#include "tensorflow/lite/version.h"
|
||||||
|
|
||||||
|
#include "blob/digit.h"
|
||||||
|
#include "blob/model.tflite.h"
|
||||||
|
|
||||||
|
#define THRESHOLD (0.5)
|
||||||
|
|
||||||
|
// Globals, used for compatibility with Arduino-style sketches.
|
||||||
|
namespace {
|
||||||
|
tflite::ErrorReporter* error_reporter = nullptr;
|
||||||
|
const tflite::Model* model = nullptr;
|
||||||
|
tflite::MicroInterpreter* interpreter = nullptr;
|
||||||
|
TfLiteTensor* input = nullptr;
|
||||||
|
TfLiteTensor* output = nullptr;
|
||||||
|
|
||||||
|
// Create an area of memory to use for input, output, and intermediate arrays.
|
||||||
|
// Finding the minimum value for your model may require some trial and error.
|
||||||
|
constexpr int kTensorArenaSize = 6 * 1024;
|
||||||
|
uint8_t tensor_arena[kTensorArenaSize];
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// The name of this function is important for Arduino compatibility.
|
||||||
|
void setup()
|
||||||
|
{
|
||||||
|
// Set up logging. Google style is to avoid globals or statics because of
|
||||||
|
// lifetime uncertainty, but since this has a trivial destructor it's okay.
|
||||||
|
static tflite::MicroErrorReporter micro_error_reporter;
|
||||||
|
error_reporter = µ_error_reporter;
|
||||||
|
|
||||||
|
// Map the model into a usable data structure. This doesn't involve any
|
||||||
|
// copying or parsing, it's a very lightweight operation.
|
||||||
|
model = tflite::GetModel(model_tflite);
|
||||||
|
if (model->version() != TFLITE_SCHEMA_VERSION) {
|
||||||
|
printf("Model provided is schema version %d not equal "
|
||||||
|
"to supported version %d.",
|
||||||
|
static_cast<uint8_t>(model->version()), TFLITE_SCHEMA_VERSION);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicitly load required operators
|
||||||
|
static tflite::MicroMutableOpResolver micro_mutable_op_resolver;
|
||||||
|
micro_mutable_op_resolver.AddBuiltin(
|
||||||
|
tflite::BuiltinOperator_FULLY_CONNECTED,
|
||||||
|
tflite::ops::micro::Register_FULLY_CONNECTED(), 1, 4);
|
||||||
|
micro_mutable_op_resolver.AddBuiltin(
|
||||||
|
tflite::BuiltinOperator_SOFTMAX,
|
||||||
|
tflite::ops::micro::Register_SOFTMAX(), 1, 2);
|
||||||
|
micro_mutable_op_resolver.AddBuiltin(
|
||||||
|
tflite::BuiltinOperator_QUANTIZE,
|
||||||
|
tflite::ops::micro::Register_QUANTIZE());
|
||||||
|
micro_mutable_op_resolver.AddBuiltin(
|
||||||
|
tflite::BuiltinOperator_DEQUANTIZE,
|
||||||
|
tflite::ops::micro::Register_DEQUANTIZE(), 1, 2);
|
||||||
|
|
||||||
|
// Build an interpreter to run the model with.
|
||||||
|
static tflite::MicroInterpreter static_interpreter(
|
||||||
|
model, micro_mutable_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
|
||||||
|
interpreter = &static_interpreter;
|
||||||
|
|
||||||
|
// Allocate memory from the tensor_arena for the model's tensors.
|
||||||
|
TfLiteStatus allocate_status = interpreter->AllocateTensors();
|
||||||
|
if (allocate_status != kTfLiteOk) {
|
||||||
|
puts("AllocateTensors() failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtain pointers to the model's input and output tensors.
|
||||||
|
input = interpreter->input(0);
|
||||||
|
output = interpreter->output(0);
|
||||||
|
|
||||||
|
// Copy digit array in input tensor
|
||||||
|
for (unsigned i = 0; i < digit_len; ++i) {
|
||||||
|
input->data.f[i] = static_cast<float>(digit[i]) / 255.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run inference, and report any error
|
||||||
|
TfLiteStatus invoke_status = interpreter->Invoke();
|
||||||
|
if (invoke_status != kTfLiteOk) {
|
||||||
|
puts("Invoke failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the best match from the output tensor
|
||||||
|
float val = 0;
|
||||||
|
uint8_t digit = 0;
|
||||||
|
for (unsigned i = 0; i < 10; ++i) {
|
||||||
|
float current = output->data.f[i];
|
||||||
|
if (current > THRESHOLD && current > val) {
|
||||||
|
val = current;
|
||||||
|
digit = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output the prediction, if there's one
|
||||||
|
if (val > 0) {
|
||||||
|
printf("Digit prediction: %d\n", digit);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
puts("No match found");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The name of this function is important for Arduino compatibility.
|
||||||
|
void loop() {}
|
||||||
117
tests/pkg_tensorflow-lite/mnist/mnist_mlp.py
Executable file
117
tests/pkg_tensorflow-lite/mnist/mnist_mlp.py
Executable file
@ -0,0 +1,117 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# imports for array-handling
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
# keras imports for the dataset and building our neural network
|
||||||
|
from tensorflow.keras.datasets import mnist
|
||||||
|
from tensorflow.keras.models import Sequential
|
||||||
|
from tensorflow.keras.layers import Dense, Dropout
|
||||||
|
|
||||||
|
|
||||||
|
# let's keep our keras backend tensorflow quiet
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
|
||||||
|
# load mnist dataset
|
||||||
|
(X_train, y_train), (X_test, y_test) = mnist.load_data()
|
||||||
|
|
||||||
|
# building the input vector from the 28x28 pixels
|
||||||
|
X_train = X_train.reshape(60000, 784)
|
||||||
|
X_test = X_test.reshape(10000, 784)
|
||||||
|
X_train = X_train.astype('float32')
|
||||||
|
X_test = X_test.astype('float32')
|
||||||
|
|
||||||
|
# Split the train set in a train + validation set
|
||||||
|
X_valid = X_train[50000:]
|
||||||
|
y_valid = y_train[50000:]
|
||||||
|
X_train = X_train[:50000]
|
||||||
|
y_train = y_train[:50000]
|
||||||
|
|
||||||
|
# Normalize the data
|
||||||
|
X_train = X_train / 255.0
|
||||||
|
X_test = X_test / 255.0
|
||||||
|
X_valid = X_valid / 255.0
|
||||||
|
|
||||||
|
# building a very simple linear stack of layers using a sequential model
|
||||||
|
model = Sequential([
|
||||||
|
Dense(64, activation='relu', input_shape=(784,)),
|
||||||
|
Dropout(0.2),
|
||||||
|
Dense(10, activation='softmax')
|
||||||
|
])
|
||||||
|
|
||||||
|
# compiling the sequential model
|
||||||
|
model.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'],
|
||||||
|
optimizer='adam')
|
||||||
|
|
||||||
|
batch_size = 32
|
||||||
|
epochs = 30
|
||||||
|
|
||||||
|
# training the model and saving metrics in history
|
||||||
|
history = model.fit(X_train, y_train,
|
||||||
|
batch_size=batch_size, epochs=epochs,
|
||||||
|
verbose=2,
|
||||||
|
validation_data=(X_valid, y_valid))
|
||||||
|
|
||||||
|
# saving the model
|
||||||
|
# Convert the model to the TensorFlow Lite format without quantization
|
||||||
|
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
# Save the basic model to disk
|
||||||
|
open("model_basic.tflite", "wb").write(tflite_model)
|
||||||
|
|
||||||
|
# Convert the model to the TensorFlow Lite format with quantization
|
||||||
|
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||||
|
|
||||||
|
(mnist_train, _), (_, _) = mnist.load_data()
|
||||||
|
mnist_train = mnist_train.reshape(60000, 784)
|
||||||
|
mnist_train = mnist_train.astype('float32')
|
||||||
|
mnist_train = mnist_train / 255.0
|
||||||
|
mnist_ds = tf.data.Dataset.from_tensor_slices((mnist_train)).batch(1)
|
||||||
|
|
||||||
|
|
||||||
|
def representative_data_gen():
|
||||||
|
for input_value in mnist_ds.take(100):
|
||||||
|
yield [input_value]
|
||||||
|
|
||||||
|
|
||||||
|
converter.representative_dataset = representative_data_gen
|
||||||
|
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
# # Save the quantized model to disk
|
||||||
|
open("model.tflite", "wb").write(tflite_model)
|
||||||
|
|
||||||
|
basic_model_size = os.path.getsize("model_basic.tflite")
|
||||||
|
print("Basic model is %d bytes" % basic_model_size)
|
||||||
|
quantized_model_size = os.path.getsize("model.tflite")
|
||||||
|
print("Quantized model is %d bytes" % quantized_model_size)
|
||||||
|
difference = basic_model_size - quantized_model_size
|
||||||
|
print("Difference is %d bytes" % difference)
|
||||||
|
|
||||||
|
# Now let's verify the model on a few input digits
|
||||||
|
# Instantiate an interpreter for the model
|
||||||
|
model_quantized_reloaded = tf.lite.Interpreter('model.tflite')
|
||||||
|
|
||||||
|
# Allocate memory for each model
|
||||||
|
model_quantized_reloaded.allocate_tensors()
|
||||||
|
|
||||||
|
# Get the input and output tensors so we can feed in values and get the results
|
||||||
|
model_quantized_input = model_quantized_reloaded.get_input_details()[0]["index"]
|
||||||
|
model_quantized_output = model_quantized_reloaded.get_output_details()[0]["index"]
|
||||||
|
|
||||||
|
# Create arrays to store the results
|
||||||
|
model_quantized_predictions = np.empty(X_test.size)
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
# Invoke the interpreter
|
||||||
|
model_quantized_reloaded.set_tensor(model_quantized_input, X_test[i:i+1, :])
|
||||||
|
model_quantized_reloaded.invoke()
|
||||||
|
model_quantized_prediction = model_quantized_reloaded.get_tensor(model_quantized_output)
|
||||||
|
|
||||||
|
print("Digit: {} - Prediction:\n{}".format(y_test[i], model_quantized_prediction))
|
||||||
|
print("")
|
||||||
BIN
tests/pkg_tensorflow-lite/mnist/model.tflite
Normal file
BIN
tests/pkg_tensorflow-lite/mnist/model.tflite
Normal file
Binary file not shown.
@ -5,7 +5,9 @@ from testrunner import run
|
|||||||
|
|
||||||
|
|
||||||
def testfunc(child):
|
def testfunc(child):
|
||||||
pass
|
# The default image of the test application contains a 7 (e.g. it's the
|
||||||
|
# first image in the MNIST test dataset)
|
||||||
|
child.expect_exact("Digit prediction: 7")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user