tests/pkg_cmsis-nn: add cifar10 example application

This commit is contained in:
Alexandre Abadie 2020-01-08 09:42:06 +01:00
parent fdd2b97387
commit db7d8a58e8
No known key found for this signature in database
GPG Key ID: 1C919A403CAE1405
8 changed files with 2083 additions and 0 deletions

View File

@ -0,0 +1,22 @@
BOARD ?= nucleo-l476rg
include ../Makefile.tests_common
USEPKG += cmsis-nn
BLOBS += input
# Boards that were tested and are known to work
# This package only works with Cortex M3, M4 and M7 CPUs but there's no easy
# way provided by the build system to filter them at that level (arch_cortexm is
# the only feature available) for the moment.
BOARD_WHITELIST := \
b-l475e-iot01a \
iotlab-m3 \
nrf52832-mdk \
nrf52dk \
nucleo-l476rg \
same54-xpro \
stm32f723e-disco
#
include $(RIOTBASE)/Makefile.include

View File

@ -0,0 +1,31 @@
## ARM CMSIS-NN package
This application shows how to use the neural network API provided by the ARM CMSIS
package in order to determine the type of "object" present in an RGB image.
The image are part of the [SIFAR10 dataset](http://www.cs.toronto.edu/~kriz/cifar.html)
which contains 10 classes of objects: plane, car, cat, bird, deer, dog, frog,
horse, ship and truck.
Expected output
---------------
```
Predicted class: cat
```
Change the input image
----------------------
Use the `generate_image.py` script and the `-i` option to generate a new
input image.
For example, the following command
```
./generate_image.py -i 1
```
will generate an input containing an image with a boat.
The generated image is displayed at the end of the script execution, for visual
validation of the prediction made by the neural network running on the device.
Note that each time a new image is generated, the firmware must be rebuilt so
that it embeds the new image.

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
"""Generate a binary file from a sample image of the CIFAR-10 dataset.
Pixel of the sample are stored as uint8, images have size 32x32x3.
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import cifar10
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
def main(args):
_, (cifar10_test, _) = cifar10.load_data()
data = cifar10_test[args.index]
data = data.astype('uint8')
output_path = os.path.join(SCRIPT_DIR, args.output)
np.ndarray.tofile(data, output_path)
if args.no_plot is False:
plt.imshow(data)
plt.show()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--index", type=int, default=0,
help="Image index in CIFAR test dataset")
parser.add_argument("-o", "--output", type=str, default='input',
help="Output filename")
parser.add_argument("--no-plot", default=False, action='store_true',
help="Disable image display in matplotlib")
main(parser.parse_args())

1
tests/pkg_cmsis-nn/input Normal file
View File

@ -0,0 +1 @@
žp1Ÿo/¥t3¦v5 p.œm)¢s/Ÿq-žo,Ÿq)¡t) o4¡o1¦u)©u-ªw,§u(¢q& o' p+œm,•k-k-”j+•k,<2C>e'Œb+<2B>a)<29>a&‰_$~[$tU!˜p3—n(Ÿr-¦t8¢p1 q+¤u/¢r-£t.œn&o)Ÿn6£q4ªw)«u(«s!©s o!šp)—s2n5h7Œf4<66>d0•i2“f.f-Ža&<26>b"ˆ_}[ wX"—n/—m!žo$§o0 j*£s,¥u-¥u-£s+¢s+žr0<72>m9¡o3¦s&§r%©q#ªt'Ÿr/o6y`1nZ4bN2eM/rU2xV0†`7<>g3Œc'Žc#b"_"xY!k(n  m®p,§n+§u.©x0©w0¥s,¥u-§{9¿K<E2809A>o)¢s/¤r6žp:•oChP/gWAbZL\ZTPKBJ?2VF4S>'qU-„b.Œf+Œe'ˆc'^$k)œr0¡s1ªr/©r+£q(©x/¦t,¤q)¤t*­€;ö֤Ük—r8o<ŽlGoP2N5UE8qgbpnojrvafi]^]JHCTNFUI/iS-€`0Še.…^$<24>]$”m6…h@d9“p5¡s,¥q'§t)§s)£o%¥t'£v*´ŠU<C5A0>zN€f:aK+B2E:+B8-YSLvqnzyxwzzrtt^``cda[[V::/C:%lT1Œi:Šb,†_(d9m_P/%XJ™u0ªv+¨s(ªv+©u*¦t%¤x'“k4<6B>b;lKdWFDC9NSHHK@STJ„yŽ„|vlicZkf^sogUSM?GE./'O=$„b:<3A>c0†]'ƒsZc`\*+&F@)<29>o8§u*¥r$¨t'«w1¡q3Œm3x^1nM<6E>ƒktj]XWO[_XUXRMME|vk£™Œˆ|pf]QjbXd]TUQJ6<:1519/ kS2Šg3ˆa'ª¡<C2AA>gii6:;|yq™|R¡q+£u)¦z2¥yB®‡_qY;}iN<69><4E><79>€yoeVPJRQMTURPNIQG=Š}p‡{qg]WOFSMEVRLGIC895(#J;#…j;‰g-´°£†<E280A0>^dišš•®•pžt3œt/™v<Ï´íÖÆÏ´¦œƒw®™”ƒ}}nk]UOVTOJJG;95LD:‰}p<>…z…|rjbYVQJWUNTUNKLG21+(_K,„g9··¯ltzŽ—ž¥©¨±œzp2Ÿv3zY/ÕųíàâÜ¿¼¤‡ƒ·Ÿ›œ‰„}lhxohNLEPPM-,([UM¯¥š<C2A5>“‰“Škd\WSMgf`XXONOI;;;)$!;.hQ.¼¿½dlt‡<74>™ª¯²»§ˆ¦x;­{7†],u_P¶¼Ç«¤ªŽ…¹¡—½«Ÿ†wjuk_fbYTTO&&"}yqÒÉÀ ˜Ž]YRSPK^]Xhh^UWQIKN757>70L8½ÂÂZ`i<7F>¯´¹®œ…¦{D²{5Ÿm/aD,¨š˜¨<CB9C>~‰r^º¦”ØÊ· <C2A0>{qbxrisrm22/–“ŒÂ»²›•Œ{vo[XSTSOTTP__UVWQTWYIIIOJ@I7½ÀÁ]_g˜š£¹¼ÀwnbˆjB­|:§t2gH'“„x}g§•½®âØÈ´¬<C2B4><C2AC>ƒu~ukurmGGDš˜“ºµ®•<C2AE>ˆrnhWUPPPLHIFPPHcd^decZXQaYE^I"ÂÄÄlkp¨§¬ºº¼immcYCœw>§z7dJ"sjXŠ{gƹ©¾´©¬¥ŸŒŒš<C592>ˆ}gd_GGF˜˜•³¯ª‰…€znmiUVS[][_`Zmnhstod`PaU5u_/ÅÅÅ„<C385>ˆ¬§®¸²µ‰ŽNSMŒxX}Ms^4x]<5D>ƒtæÝÓòìæŠ‰‡ƒypypelh__XK<58>†v¨Ÿ˜“ŠpleWUPGHDWXWihcpmcxn]gV6y`0ˆh0ËËÌ ¨¤²¿¶¼¨ª¬NVZ~}~Š~qŠyR`P%š<>­£¢˜<C2A2>Œ„uqjXqjZee\ieWpZ:«<>hœŠm”<6D>~‡vmiaNLHOOM^]^e[RkS7}X-—l7<6C>h.Ö××£¦´¤§¸·¸Â°¶º^fi``fœ•”‰oj]=<3D>tivi_rfYtiYf[IsnbV[Xeg_<67>€fv`@D8 €xi…~sKE=<83:85GFAf]Nt^@<40>pDt@Œn6ÔÓͲ¸À§¯½­µÁ°¸¼|ƒ…VX`<60><EFBFBD>™”<E284A2>‡€ohZPM@7†yl|o`<60>ud“<64>…U\]\`]x„u]um\kcVKD:@;4,)'A><VE(…i;w> x6šs-—o.ÇÀ´»½»«°µ®³¹±¶¸<C2B6>˜VZcwy„z|ˆ<CB86>†~F;3<>vllaV†{¸°¨tvvIKIƒwg‰|i†<69>vYVN31,432/14ZZ]y[<£vD«y@¤q4žo2•k.¥œÃÁ»³²¯±­¬µµ´˜<C2B4> cgoƒ‡«¯¹gio]ZWPMI]ZVzvt²­­¿¶±””ddeYNBWM?<=9.46&.3!).9E<GSldK<64>}R<>{L€m=qExi?ux|ÃÈȱ²°²©¨µ³³Š<C2B3>“SW[™Ÿõ÷úÛÞá…Œ<E280A6><EFBFBD>“•œ¤°¶À¾ÄÐÂÀŨ¬µ}…<>nmm=>>#1:"6D1FW:Qf=Un:ToEczHewNhxE`p;\p7ZsOi…¯ÅÕ®·À°¬±±±¶ŒmpqÓÓÑýü÷üýüÐàè|<7C><>r„•|<7C>¢t…œz…˜h|”D]wDWh<Re4To2Tn3Us8]}8^ƒ3[+`‡3h<33>;lŽ0a„+a‰*_„)Y‡`‰¨<E280B0>¨¼¨®¼²¶À¥ª®¥¦¤öõíýûñãçänˆ™<Xo5Pi1Li1Kk0He-Os*Qx.Qq*Rt&V}.Z}.Y~+W€*Y„.].^‰2`‰7`‡5^†3_-Z…[<5B>W;f†ƒ™°¦³¿„ˆ‰Â½µþúòñõõ<C3B5>Ÿ¯=^2Tv2Tw3Uy1Sx2Tt/Vu*Tu'Rs"Oq#Sx'V}&U}*Y-\†8g>gŽ;eŽ8f2c<32>.^Œ3g•0o¢^Œ"U|Ijˆ€”§€ˆ<E282AC>×ÕÑÿýù»ÆÍB]v6[€2X}4Z4Z.Sy-Rs+Rq)Qp$Pq'Su(V{(Yƒ+\†._Š;l>n˜@m“;l•6lš2i˜F{§S‰¶4r¥#c“V)SzB_~€‘¤àåêðõ÷|<7C>™:\r1W{8^ƒ6\<5C>,Rw,Rw/Sw.Tw+Sw+V{,X,Zƒ-a<>6j:nš6i.a<>+_Œ$[Š3lžI²UжL}©2n¢#b•YŠ#V…,S~NjŠÊÛéÓäêa~ŒAh~6^<5E>0W|:a…0W{(Pt-Rw/Tz0W~/Y.Y„3aŒ'\Š']0f”/e“']U…(e™C<E284A2>¶C~°.bŽ3`2l¡#a“ \<5C>!X<>)XŠ.T}h…ŸªÅÓ@dw6ay4^€5_=g:d‡6`ƒ-Sx*Ov)Px.X1\‡.\ˆ*_(]Š'\ˆ%Z‡(]Š,f—?}²/n¤<g3]ˆD|±*d”X‰&[%W+Y„*OqGk…1YrMiGi&Ru1]€8d‡:f‰5\€8^ƒ<c‰9c5aŠ2_‰-^ˆ'Xƒ!S}*[…>pšO„³Iƒµ8t¨&a @l(U=t¨1f”#U„+[<5B>'Z*\†,X}(Qp*UsHhCfJmGjIl$Ps/Vx8_€>e‡Bm<42>KwœEq˜1_†+X+X<i<>UªmœÅ]‘¾<s¤RR~@k6k 8i•-Y„+V†(Y†(\„(W{&Qs$OrEiBeIlEhIl?b:Y Fd/Wv=h‰Jw˜Bo5`ƒ4_-W{CmYƒ§i¶Y‡¯0cM|"T<>Cn

133
tests/pkg_cmsis-nn/main.c Normal file
View File

@ -0,0 +1,133 @@
/*
* Copyright (C) 2019 Inria
*
* This file is subject to the terms and conditions of the GNU Lesser
* General Public License v2.1. See the file LICENSE in the top level
* directory for more details.
*/
/**
* @ingroup tests
* @{
*
* @file
* @brief Sample application for ARM CMSIS-NN package
*
* This example is adapted from ARM CMSIS CIFAR10 example to RIOT by Alexandre Abadie
* https://github.com/ARM-software/CMSIS_5/tree/develop/CMSIS/NN/Examples/ARM/arm_nn_examples/cifar10
*
* @author Alexandre Abadie <alexandre.abadie@inria.fr>
*/
#include <stdint.h>
#include <stdio.h>
#include "arm_math.h"
#include "parameter.h"
#include "weights.h"
#include "arm_nnfunctions.h"
#include "blob/input.h"
/* There are 10 different classes of objects in the CIFAR10 dataset */
#define CLASSES_NUMOF 10
/* include the input and weights */
static const q7_t conv1_wt[CONV1_IM_CH * CONV1_KER_DIM * CONV1_KER_DIM * CONV1_OUT_CH] = CONV1_WT;
static const q7_t conv1_bias[CONV1_OUT_CH] = CONV1_BIAS;
static const q7_t conv2_wt[CONV2_IM_CH * CONV2_KER_DIM * CONV2_KER_DIM * CONV2_OUT_CH] = CONV2_WT;
static const q7_t conv2_bias[CONV2_OUT_CH] = CONV2_BIAS;
static const q7_t conv3_wt[CONV3_IM_CH * CONV3_KER_DIM * CONV3_KER_DIM * CONV3_OUT_CH] = CONV3_WT;
static const q7_t conv3_bias[CONV3_OUT_CH] = CONV3_BIAS;
static const q7_t ip1_wt[IP1_DIM * IP1_OUT] = IP1_WT;
static const q7_t ip1_bias[IP1_OUT] = IP1_BIAS;
/* Here the image_data should be the raw uint8 type RGB image in [RGB, RGB, RGB ... RGB] format */
// static const uint8_t image_data[CONV1_IM_CH * CONV1_IM_DIM * CONV1_IM_DIM] = IMG_DATA;
static q7_t output_data[IP1_OUT];
/* vector buffer: max(im2col buffer,average pool buffer, fully connected buffer) */
static q7_t col_buffer[2 * 5 * 5 * 32 * 2];
static q7_t img_buffer1[32 * 32 * 10 * 4];
static q7_t *img_buffer2 = (q7_t *)(img_buffer1 + (32 * 32 * 32));
static const char classes[CLASSES_NUMOF][6] = {
"plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck" };
int main(void)
{
printf("start execution\n");
uint8_t *image_data = (uint8_t *)input;
/* input pre-processing */
int mean_data[3] = INPUT_MEAN_SHIFT;
unsigned int scale_data[3] = INPUT_RIGHT_SHIFT;
for (unsigned i = 0; i < input_len; i += 3) {
img_buffer2[i] = (q7_t)__SSAT( ((((int)image_data[i] - mean_data[0]) << 7) + (0x1 << (scale_data[0] - 1)))
>> scale_data[0], 8);
img_buffer2[i + 1] = (q7_t)__SSAT( ((((int)image_data[i + 1] - mean_data[1]) << 7) + (0x1 << (scale_data[1] - 1)))
>> scale_data[1], 8);
img_buffer2[i + 2] = (q7_t)__SSAT( ((((int)image_data[i + 2] - mean_data[2]) << 7) + (0x1 << (scale_data[2] - 1)))
>> scale_data[2], 8);
}
/* conv1 img_buffer2 -> img_buffer1 */
arm_convolve_HWC_q7_RGB(img_buffer2, CONV1_IM_DIM, CONV1_IM_CH, conv1_wt, CONV1_OUT_CH, CONV1_KER_DIM, CONV1_PADDING,
CONV1_STRIDE, conv1_bias, CONV1_BIAS_LSHIFT, CONV1_OUT_RSHIFT, img_buffer1, CONV1_OUT_DIM,
(q15_t *)col_buffer, NULL);
arm_relu_q7(img_buffer1, CONV1_OUT_DIM * CONV1_OUT_DIM * CONV1_OUT_CH);
/* pool1 img_buffer1 -> img_buffer2 */
arm_maxpool_q7_HWC(img_buffer1, CONV1_OUT_DIM, CONV1_OUT_CH, POOL1_KER_DIM,
POOL1_PADDING, POOL1_STRIDE, POOL1_OUT_DIM, NULL, img_buffer2);
/* conv2 img_buffer2 -> img_buffer1 */
arm_convolve_HWC_q7_fast(img_buffer2, CONV2_IM_DIM, CONV2_IM_CH, conv2_wt, CONV2_OUT_CH, CONV2_KER_DIM,
CONV2_PADDING, CONV2_STRIDE, conv2_bias, CONV2_BIAS_LSHIFT, CONV2_OUT_RSHIFT, img_buffer1,
CONV2_OUT_DIM, (q15_t *)col_buffer, NULL);
arm_relu_q7(img_buffer1, CONV2_OUT_DIM * CONV2_OUT_DIM * CONV2_OUT_CH);
/* pool2 img_buffer1 -> img_buffer2 */
arm_maxpool_q7_HWC(img_buffer1, CONV2_OUT_DIM, CONV2_OUT_CH, POOL2_KER_DIM,
POOL2_PADDING, POOL2_STRIDE, POOL2_OUT_DIM, col_buffer, img_buffer2);
/* conv3 img_buffer2 -> img_buffer1 */
arm_convolve_HWC_q7_fast(img_buffer2, CONV3_IM_DIM, CONV3_IM_CH, conv3_wt, CONV3_OUT_CH, CONV3_KER_DIM,
CONV3_PADDING, CONV3_STRIDE, conv3_bias, CONV3_BIAS_LSHIFT, CONV3_OUT_RSHIFT, img_buffer1,
CONV3_OUT_DIM, (q15_t *)col_buffer, NULL);
arm_relu_q7(img_buffer1, CONV3_OUT_DIM * CONV3_OUT_DIM * CONV3_OUT_CH);
/* pool3 img_buffer-> img_buffer2 */
arm_maxpool_q7_HWC(img_buffer1, CONV3_OUT_DIM, CONV3_OUT_CH, POOL3_KER_DIM,
POOL3_PADDING, POOL3_STRIDE, POOL3_OUT_DIM, col_buffer, img_buffer2);
arm_fully_connected_q7_opt(img_buffer2, ip1_wt, IP1_DIM, IP1_OUT, IP1_BIAS_LSHIFT, IP1_OUT_RSHIFT, ip1_bias,
output_data, (q15_t *)img_buffer1);
arm_softmax_q7(output_data, CLASSES_NUMOF, output_data);
int val = -1;
uint8_t class_idx = 0;
for (unsigned i = 0; i < CLASSES_NUMOF; i++) {
if (output_data[i] > val) {
val = output_data[i];
class_idx = i;
}
}
if (val > 0) {
printf("Predicted class: %s\n", classes[class_idx]);
}
else {
puts("No match found");
}
return 0;
}

View File

@ -0,0 +1,73 @@
/*
* Copyright (C) 2019 Inria
*
* This file is subject to the terms and conditions of the GNU Lesser
* General Public License v2.1. See the file LICENSE in the top level
* directory for more details.
*/
/**
* @ingroup tests
* @{
*
* @file
* @brief CNN parameters
*/
#ifndef PARAMETER_H
#define PARAMETER_H
#ifdef __cplusplus
extern "C" {
#endif
#define CONV1_IM_DIM 32
#define CONV1_IM_CH 3
#define CONV1_KER_DIM 5
#define CONV1_PADDING 2
#define CONV1_STRIDE 1
#define CONV1_OUT_CH 32
#define CONV1_OUT_DIM 32
#define POOL1_KER_DIM 3
#define POOL1_STRIDE 2
#define POOL1_PADDING 0
#define POOL1_OUT_DIM 16
#define CONV2_IM_DIM 16
#define CONV2_IM_CH 32
#define CONV2_KER_DIM 5
#define CONV2_PADDING 2
#define CONV2_STRIDE 1
#define CONV2_OUT_CH 16
#define CONV2_OUT_DIM 16
#define POOL2_KER_DIM 3
#define POOL2_STRIDE 2
#define POOL2_PADDING 0
#define POOL2_OUT_DIM 8
#define CONV3_IM_DIM 8
#define CONV3_IM_CH 16
#define CONV3_KER_DIM 5
#define CONV3_PADDING 2
#define CONV3_STRIDE 1
#define CONV3_OUT_CH 32
#define CONV3_OUT_DIM 8
#define POOL3_KER_DIM 3
#define POOL3_STRIDE 2
#define POOL3_PADDING 0
#define POOL3_OUT_DIM 4
#define IP1_DIM 4*4*32
#define IP1_IM_DIM 4
#define IP1_IM_CH 32
#define IP1_OUT 10
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif /* PARAMETER_H */
/** @} */

View File

@ -0,0 +1,12 @@
#!/usr/bin/env python3
import sys
from testrunner import run
def testfunc(child):
child.expect_exact("Predicted class: cat")
if __name__ == "__main__":
sys.exit(run(testfunc))

1772
tests/pkg_cmsis-nn/weights.h Normal file

File diff suppressed because it is too large Load Diff