tests/pkg_emlearn: add sample application

This commit is contained in:
Alexandre Abadie 2020-01-07 11:04:57 +01:00
parent 5325233928
commit b5dd94d223
No known key found for this signature in database
GPG Key ID: 1C919A403CAE1405
10 changed files with 187 additions and 0 deletions

2
tests/pkg_emlearn/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
model.h
tmp/

View File

@ -0,0 +1,13 @@
# Shared makefile for the applications located in the parent directory
include ../Makefile.tests_common
# Build the application with the emlearn package
USEPKG += emlearn
# Embed the `digit` file; main.c includes it as "blob/digit.h"
BLOBS += digit
# main.c includes "model.h", so it is registered as a build dependency
BUILDDEPS += model.h
include $(RIOTBASE)/Makefile.include
# Regenerate model.h whenever the `model` file changes: run generate_model.py,
# then append a C comment line so the generated header ends with a newline.
# NOTE(review): whether `echo` expands "\n" depends on the shell in use;
# `printf` would be the portable choice -- confirm before changing.
model.h: $(CURDIR)/model
$(Q)$(CURDIR)/generate_model.py
$(Q)echo "/* fix for no newline at eof */\n" >> model.h

View File

@ -0,0 +1,40 @@
## Emlearn package test application
This application shows how to use a machine learning model with emlearn on RIOT
in order to predict a value from a hand written digit image.
The model is a [Scikit-Learn](https://scikit-learn.org) random forest estimator
trained on the MNIST dataset.
### Expected output
The default digit to predict is a hand-written '6', so the application output
is the following:
```
Predicted digit: 6
```
### Use the Python scripts
The application comes with 3 Python scripts:
- `generate_digit.py` is used to generate a new digit file. This file is
embedded in the firmware image and is used as input for the inference engine.
Use the `-i` option to select a different digit.
For example, the following command:
```
$ ./generate_digit.py -i 1
```
will generate a digit containing a '9'.
The digit is displayed at the end of the script so one knows which digit is
stored.
Note that each time a new digit is generated, the firmware image must be
rebuilt to include this new digit.
- `train_model.py` is used to train a new Scikit-Learn Random Forest estimator.
The trained model is stored in the `model` binary file.
```
$ ./train_model.py
```
will just train the model.
- `generate_model.py` is used to generate the `model.h` header file from the
`model` binary file. The script is called automatically by the build system
when the `model` binary file is updated.

BIN
tests/pkg_emlearn/digit Normal file

Binary file not shown.

View File

@ -0,0 +1,43 @@
#!/usr/bin/env python3
"""Generate a binary file from a sample image of the MNIST dataset.
Pixels of the sample are stored as float32, images have size 8x8.
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn import datasets
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
def main(args):
    """Write one digit sample to a binary file and optionally display it.

    The digits dataset is split with a fixed random state; the sample at
    position ``args.index`` of the test part is written as raw float32
    values to ``args.output``, joined onto this script's directory.

    :param args: parsed command line arguments providing the ``index``,
                 ``output`` and ``no_plot`` attributes.
    """
    output_path = os.path.join(SCRIPT_DIR, args.output)
    digits = datasets.load_digits()
    # Same seed and split call as train_model.py, so the selected sample
    # comes from the part of the dataset that was not used for training.
    rnd = 42
    _, test_data, _, _ = train_test_split(digits.data, digits.target,
                                          random_state=rnd)
    sample = test_data[args.index]
    # Raw dump of the values as float32, without any header
    sample.astype('float32').tofile(output_path)
    if not args.no_plot:
        plt.gray()
        plt.imshow(sample.reshape(8, 8))
        plt.show()
if __name__ == '__main__':
    # Command line interface: which sample to dump, where to store it and
    # whether the matplotlib preview is shown.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-i", "--index",
        type=int,
        default=0,
        help="Image index in MNIST test dataset",
    )
    cli.add_argument(
        "-o", "--output",
        type=str,
        default='digit',
        help="Output filename",
    )
    cli.add_argument(
        "--no-plot",
        default=False,
        action='store_true',
        help="Disable image display in matplotlib",
    )
    options = cli.parse_args()
    main(options)

View File

@ -0,0 +1,8 @@
#!/usr/bin/env python3
import emlearn
import joblib
# Load the estimator stored in the `model` file of the current directory and
# convert it with emlearn.
# NOTE: joblib.load unpickles the file, so only run this on a trusted model.
cmodel = emlearn.convert(joblib.load("model"))
# Write the converted model out as the `model.h` header
cmodel.save(file='model.h')

36
tests/pkg_emlearn/main.c Normal file
View File

@ -0,0 +1,36 @@
/*
* Copyright (C) 2019 Inria
*
* This file is subject to the terms and conditions of the GNU Lesser
* General Public License v2.1. See the file LICENSE in the top level
* directory for more details.
*/
/**
* @ingroup tests
* @{
*
* @file
* @brief Emlearn test application
*
* @author Alexandre Abadie <alexandre.abadie@inria.fr>
*
* @}
*/
#include <stdio.h>
#include <inttypes.h>
#include "model.h"
/* the digit array included must be 4-byte aligned */
__attribute__((__aligned__(4)))
#include "blob/digit.h"
int main(void)
{
    /* The blob is handed to the model as an array of floats; its length in
     * bytes is divided by 4 (>> 2) to obtain the number of features. */
    const float *features = (const float *)digit;
    int32_t predicted = model_predict(features, digit_len >> 2);

    printf("Predicted digit: %" PRIi32 "\n", predicted);

    return 0;
}

BIN
tests/pkg_emlearn/model Normal file

Binary file not shown.

View File

@ -0,0 +1,18 @@
#!/usr/bin/env python3
# Copyright (C) 2019 Inria
#
# This file is subject to the terms and conditions of the GNU Lesser
# General Public License v2.1. See the file LICENSE in the top level
# directory for more details.
import sys
from testrunner import run
def testfunc(child):
child.expect_exact("Predicted digit: 6")
if __name__ == "__main__":
    # Propagate the result of the test run as the process exit status
    exit_code = run(testfunc)
    sys.exit(exit_code)

View File

@ -0,0 +1,27 @@
#!/usr/bin/env python3
import joblib
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics, datasets
# Seed shared by the train/test split and the estimator, so that every run
# produces the same split and the same model
rnd = 42

# Announce the load before it happens so the progress output matches the work
print('Loading digits dataset. 8x8=64 features')
digits = datasets.load_digits()
Xtrain, Xtest, ytrain, ytest = train_test_split(
    digits.data, digits.target, random_state=rnd
)

# 0.95+ with n_estimators=10, max_depth=10
trees = 10
max_depth = 10

print('Training {} trees with max_depth {}'.format(trees, max_depth))
model = RandomForestClassifier(
    n_estimators=trees, max_depth=max_depth, random_state=rnd
)
model.fit(Xtrain, ytrain)

# Predict on the held-out samples; accuracy_score expects (y_true, y_pred)
ypred = model.predict(Xtest)
accuracy = metrics.accuracy_score(ytest, ypred)
print('Accuracy on validation set {:.2f}%'.format(accuracy * 100))

# Store the model in a binary file
joblib.dump(model, "model")