Преглед изворни кода

Call wrapper function in depthwise_conv unit tests

Patrik Laurell пре 4 година
родитељ
комит
ea60f7b098

+ 2 - 0
CMSIS/NN/Tests/UnitTest/README.md

@@ -9,6 +9,8 @@ The [Unity test framework](http://www.throwtheswitch.org/unity) is used for runn
 Python3 is required.
 It has been tested with Python 3.6 and it has been tested on Ubuntu 16 and 18.
 
+Make sure to use a `pip` version > 19.0 (or >20.3 for macOS), otherwise tensorflow 2 packages are not available.
+
 There is a requirement file that can be used to install the dependencies.
 
 ```

+ 69 - 1
CMSIS/NN/Tests/UnitTest/TestCases/test_arm_depthwise_conv_3x3_s8/test_arm_depthwise_conv_3x3_s8.c

@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
+ * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -83,6 +83,25 @@ void depthwise_kernel_3x3_arm_depthwise_conv_3x3_s8(void)
     free(ctx.buf);
     TEST_ASSERT_EQUAL(expected, result);
     TEST_ASSERT_TRUE(validate(output, depthwise_kernel_3x3_output_ref, DEPTHWISE_KERNEL_3X3_DST_SIZE));
+
+    ctx.buf = NULL;
+    ctx.size = 0;
+
+    result = arm_depthwise_conv_wrapper_s8(&ctx,
+                                           &dw_conv_params,
+                                           &quant_params,
+                                           &input_dims,
+                                           input_data,
+                                           &filter_dims,
+                                           kernel_data,
+                                           &bias_dims,
+                                           bias_data,
+                                           &output_dims,
+                                           output);
+
+    free(ctx.buf);
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, depthwise_kernel_3x3_output_ref, DEPTHWISE_KERNEL_3X3_DST_SIZE));
 }
 
 void depthwise_kernel_3x3_arm_depthwise_conv_3x3_1_s8(void)
@@ -142,6 +161,32 @@ void depthwise_kernel_3x3_arm_depthwise_conv_3x3_1_s8(void)
 
     free(ctx.buf);
     TEST_ASSERT_EQUAL(expected, result);
+
+    // The wrapper calls different functions for Cortex-M55 and Cortex-M7. Hence
+    // the different expected status.
+#if defined(ARM_MATH_MVEI)
+    const arm_status expected_wrapper = ARM_MATH_SUCCESS;
+#else
+    const arm_status expected_wrapper = ARM_MATH_ARGUMENT_ERROR;
+#endif
+
+    ctx.buf = NULL;
+    ctx.size = 0;
+
+    result = arm_depthwise_conv_wrapper_s8(&ctx,
+                                           &dw_conv_params,
+                                           &quant_params,
+                                           &input_dims,
+                                           input_data,
+                                           &filter_dims,
+                                           kernel_data,
+                                           &bias_dims,
+                                           bias_data,
+                                           &output_dims,
+                                           output);
+
+    free(ctx.buf);
+    TEST_ASSERT_EQUAL(expected_wrapper, result);
 }
 
 void depthwise_kernel_3x3_arm_depthwise_conv_3x3_2_s8(void)
@@ -201,4 +246,27 @@ void depthwise_kernel_3x3_arm_depthwise_conv_3x3_2_s8(void)
 
     free(ctx.buf);
     TEST_ASSERT_EQUAL(expected, result);
+
+    ctx.buf = NULL;
+    ctx.size = 0;
+
+    // When calling the wrapper the arm_depthwise_conv_3x3_s8 will
+    // not be called. arm_depthwise_conv_s8_opt will be called and
+    // exit with success.
+    const arm_status expected_wrapper = ARM_MATH_SUCCESS;
+
+    result = arm_depthwise_conv_wrapper_s8(&ctx,
+                                           &dw_conv_params,
+                                           &quant_params,
+                                           &input_dims,
+                                           input_data,
+                                           &filter_dims,
+                                           kernel_data,
+                                           &bias_dims,
+                                           bias_data,
+                                           &output_dims,
+                                           output);
+
+    free(ctx.buf);
+    TEST_ASSERT_EQUAL(expected_wrapper, result);
 }

+ 96 - 1
CMSIS/NN/Tests/UnitTest/TestCases/test_arm_depthwise_conv_s8/test_arm_depthwise_conv_s8.c

@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
+ * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -83,6 +83,25 @@ void basic_arm_depthwise_conv_s8(void)
     free(ctx.buf);
     TEST_ASSERT_EQUAL(expected, result);
     TEST_ASSERT_TRUE(validate(output, basic_output_ref, BASIC_DST_SIZE));
+
+    ctx.buf = NULL;
+    ctx.size = 0;
+
+    result = arm_depthwise_conv_wrapper_s8(&ctx,
+                                           &dw_conv_params,
+                                           &quant_params,
+                                           &input_dims,
+                                           input_data,
+                                           &filter_dims,
+                                           basic_weights,
+                                           &bias_dims,
+                                           bias_data,
+                                           &output_dims,
+                                           output);
+
+    free(ctx.buf);
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, basic_output_ref, BASIC_DST_SIZE));
 }
 
 void stride2pad1_arm_depthwise_conv_s8(void)
@@ -143,6 +162,25 @@ void stride2pad1_arm_depthwise_conv_s8(void)
     free(ctx.buf);
     TEST_ASSERT_EQUAL(expected, result);
     TEST_ASSERT_TRUE(validate(output, stride2pad1_output_ref, STRIDE2PAD1_DST_SIZE));
+
+    ctx.buf = NULL;
+    ctx.size = 0;
+
+    result = arm_depthwise_conv_wrapper_s8(&ctx,
+                                           &dw_conv_params,
+                                           &quant_params,
+                                           &input_dims,
+                                           input_data,
+                                           &filter_dims,
+                                           kernel_data,
+                                           &bias_dims,
+                                           bias_data,
+                                           &output_dims,
+                                           output);
+
+    free(ctx.buf);
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, stride2pad1_output_ref, STRIDE2PAD1_DST_SIZE));
 }
 
 void depthwise_2_arm_depthwise_conv_s8(void)
@@ -203,6 +241,25 @@ void depthwise_2_arm_depthwise_conv_s8(void)
     free(ctx.buf);
     TEST_ASSERT_EQUAL(expected, result);
     TEST_ASSERT_TRUE(validate(output, depthwise_2_output_ref, DEPTHWISE_2_DST_SIZE));
+
+    ctx.buf = NULL;
+    ctx.size = 0;
+
+    result = arm_depthwise_conv_s8(&ctx,
+                                   &dw_conv_params,
+                                   &quant_params,
+                                   &input_dims,
+                                   input_data,
+                                   &filter_dims,
+                                   kernel_data,
+                                   &bias_dims,
+                                   bias_data,
+                                   &output_dims,
+                                   output);
+
+    free(ctx.buf);
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, depthwise_2_output_ref, DEPTHWISE_2_DST_SIZE));
 }
 
 void depthwise_out_activation_arm_depthwise_conv_s8(void)
@@ -263,6 +320,25 @@ void depthwise_out_activation_arm_depthwise_conv_s8(void)
     free(ctx.buf);
     TEST_ASSERT_EQUAL(expected, result);
     TEST_ASSERT_TRUE(validate(output, depthwise_out_activation_output_ref, DEPTHWISE_OUT_ACTIVATION_DST_SIZE));
+
+    ctx.buf = NULL;
+    ctx.size = 0;
+
+    result = arm_depthwise_conv_s8(&ctx,
+                                   &dw_conv_params,
+                                   &quant_params,
+                                   &input_dims,
+                                   input_data,
+                                   &filter_dims,
+                                   kernel_data,
+                                   &bias_dims,
+                                   bias_data,
+                                   &output_dims,
+                                   output);
+
+    free(ctx.buf);
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, depthwise_out_activation_output_ref, DEPTHWISE_OUT_ACTIVATION_DST_SIZE));
 }
 
 void depthwise_mult_batches_arm_depthwise_conv_s8(void)
@@ -323,4 +399,23 @@ void depthwise_mult_batches_arm_depthwise_conv_s8(void)
     free(ctx.buf);
     TEST_ASSERT_EQUAL(expected, result);
     TEST_ASSERT_TRUE(validate(output, depthwise_mult_batches_output_ref, DEPTHWISE_MULT_BATCHES_DST_SIZE));
+
+    ctx.buf = NULL;
+    ctx.size = 0;
+
+    result = arm_depthwise_conv_s8(&ctx,
+                                   &dw_conv_params,
+                                   &quant_params,
+                                   &input_dims,
+                                   input_data,
+                                   &filter_dims,
+                                   kernel_data,
+                                   &bias_dims,
+                                   bias_data,
+                                   &output_dims,
+                                   output);
+
+    free(ctx.buf);
+    TEST_ASSERT_EQUAL(expected, result);
+    TEST_ASSERT_TRUE(validate(output, depthwise_mult_batches_output_ref, DEPTHWISE_MULT_BATCHES_DST_SIZE));
 }

+ 75 - 49
CMSIS/NN/Tests/UnitTest/requirements.txt

@@ -1,39 +1,49 @@
-absl-py==0.12.0
 appdirs==1.4.4
+apsw==3.16.2.post1
 asn1ate==0.6.0
-asn1crypto==1.4.0
-astunparse==1.6.3
-beautifulsoup4==4.6.3
-cachetools==4.2.1
+asn1crypto==0.24.0
+backports.functools-lru-cache==1.6.3
+backports.shutil-get-terminal-size==1.0.0
+beautifulsoup4==4.9.3
+cbor==1.0.0
 certifi==2020.12.5
 cffi==1.14.5
-chardet==3.0.4
+chardet==4.0.0
+CherryPy==8.9.1
 Click==7.0
 cmsis-pack-manager==0.2.10
-colorama==0.3.9
-cryptography==3.3.2
+colorama==0.4.4
+cryptography==2.9.2
+cssselect==1.0.3
+cssutils==1.0.2
+decorator==4.1.2
+dnspython==1.15.0
 ecdsa==0.16.1
+enum34==1.1.10
 fasteners==0.16
+feedparser==5.2.1
 flatbuffers==1.12
-future==0.16.0
-gast==0.3.3
-google-auth==1.27.1
-google-auth-oauthlib==0.4.3
-google-pasta==0.2.0
-grpcio==1.32.0
-h5py==2.10.0
+functools32==3.2.3.post2
+future==0.18.2
+gyp==0.1
+html5-parser==0.4.4
+html5lib==0.999999999
 icetea==1.2.4
-idna==2.7
-importlib-metadata==3.7.3
-intelhex==2.2.1
-Jinja2==2.11.3
+idna==2.10
+intelhex==2.3.0
+ipaddress==1.0.23
+ipython==5.5.0
+ipython-genutils==0.2.0
+Jinja2==2.10.3
 jsonmerge==1.8.0
 jsonschema==2.6.0
-junit-xml==1.8
-Keras-Preprocessing==1.1.2
+junit-xml==1.9
+keyring==10.6.0
+keyrings.alt==3.0
 lockfile==0.12.2
+lxml==4.2.1
 manifest-tool==1.5.2
-Markdown==3.3.4
+Markdown==2.6.9
 MarkupSafe==1.1.1
 mbed-cli==1.10.5
 mbed-cloud-sdk==2.0.8
@@ -41,41 +51,57 @@ mbed-flasher==0.10.1
 mbed-greentea==1.7.4
 mbed-host-tests==1.5.10
 mbed-ls==1.7.12
-mbed-os-tools==0.0.15
+mbed-os-tools==1.8.4
+mechanize==0.2.5
+mercurial==4.5.3
 milksnake==0.1.5
-numpy==1.19.5
-oauthlib==3.1.0
-opt-einsum==3.3.0
-packaging==20.9
-prettytable==0.7.2
+monotonic==1.5
+msgpack==0.5.6
+netifaces==0.10.4
+numpy==1.16.6
+olefile==0.45.1
+pathlib2==2.3.0
+pexpect==4.2.1
+pickleshare==0.7.4
+Pillow==5.1.0
+prettytable==1.0.1
+prompt-toolkit==1.0.15
 protobuf==3.5.2.post1
-psutil==5.6.6
+psutil==5.6.2
 pyasn1==0.2.3
-pyasn1-modules==0.2.8
+pycairo==1.16.2
 pycparser==2.20
-pycryptodome==3.7.3
+pycrypto==2.6.1
+pycryptodome==3.10.1
 pyelftools==0.25
+Pygments==2.2.0
+pygobject==3.26.1
+pykerberos==1.1.14
+pyOpenSSL==20.0.1
 pyparsing==2.4.7
-pyserial==3.4
+pyserial==3.5
 python-dateutil==2.8.1
-python-dotenv==0.15.0
+python-dotenv==0.17.0
 pyusb==1.1.1
-PyYAML==5.4
-requests==2.20.1
-requests-oauthlib==1.3.0
-rsa==4.7.2
+pyxdg==0.25
+PyYAML==4.2b1
+regex==2017.12.12
+repoze.lru==0.7
+requests==2.25.1
+Routes==2.4.1
+scandir==1.7
+SecretStorage==2.3.1
 semver==2.13.0
-six==1.12.0
-soupsieve==2.2
-tensorboard==2.4.1
-tensorboard-plugin-wit==1.8.0
-tensorflow>=2.4.2
-tensorflow-estimator==2.4.0
-termcolor==1.1.0
-typing-extensions==3.7.4.3
-urllib3==1.26.5
+simplegeneric==0.8.1
+simplejson==3.13.2
+six==1.15.0
+soupsieve==1.9.6
+traitlets==4.3.2
+typing==3.7.4.3
+urllib3==1.26.4
+uTidylib==0.3
+vboxapi==1.0
 wcwidth==0.2.5
-Werkzeug==1.0.1
-wrapt==1.12.1
+webencodings==0.5
+WebOb==1.7.3
 yattag==1.14.0
-zipp==3.4.1