IMU_Classifier.ino

#include <LSM6DS3.h>
#include <Wire.h>

#include <TensorFlowLite.h>
#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/schema/schema_generated.h>
//#include <tensorflow/lite/version.h>

#include "model.h"

const float accelerationThreshold = 2.5; // threshold of significant motion in G's
const int numSamples = 119;

int samplesRead = numSamples;

LSM6DS3 myIMU(I2C_MODE, 0x6A);
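// NOTE: 0x6A is the LSM6DS3's default I2C address (used, for example, by the
// on-board IMU of the Seeed XIAO nRF52840 Sense); parts with SDO/SA0 pulled
// high respond at 0x6B instead.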

// global variables used for TensorFlow Lite (Micro)
tflite::MicroErrorReporter tflErrorReporter;

// pull in all the TFLM ops; you can remove this line and pull in only the
// TFLM ops you need if you would like to reduce the compiled size of the
// sketch.
tflite::AllOpsResolver tflOpsResolver;
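// A minimal sketch of that alternative, assuming the gesture model uses only
// fully connected layers with ReLU activations and a softmax output (check the
// model's actual ops before switching):
//
//   #include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
//   tflite::MicroMutableOpResolver<3> tflOpsResolver;
//   // ...and in setup(), before creating the interpreter:
//   tflOpsResolver.AddFullyConnected();
//   tflOpsResolver.AddRelu();
//   tflOpsResolver.AddSoftmax();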

const tflite::Model* tflModel = nullptr;
tflite::MicroInterpreter* tflInterpreter = nullptr;
TfLiteTensor* tflInputTensor = nullptr;
TfLiteTensor* tflOutputTensor = nullptr;

// Create a static memory buffer for TFLM; the size may need to be
// adjusted based on the model you are using
constexpr int tensorArenaSize = 8 * 1024;
byte tensorArena[tensorArenaSize] __attribute__((aligned(16)));

// array to map gesture index to a name
const char* GESTURES[] = {
  "punch",
  "flex"
};

#define NUM_GESTURES (sizeof(GESTURES) / sizeof(GESTURES[0]))

void setup() {
  Serial.begin(9600);
  while (!Serial);

  if (myIMU.begin() != 0) {
    Serial.println("Device error");
  } else {
    Serial.println("Device OK!");
  }
  Serial.println();

  // get the TFL representation of the model byte array
  tflModel = tflite::GetModel(model);
  if (tflModel->version() != TFLITE_SCHEMA_VERSION) {
    Serial.println("Model schema mismatch!");
    while (1);
  }

  // Create an interpreter to run the model
  tflInterpreter = new tflite::MicroInterpreter(tflModel, tflOpsResolver, tensorArena, tensorArenaSize, &tflErrorReporter);

  // Allocate memory for the model's input and output tensors
  tflInterpreter->AllocateTensors();
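  // AllocateTensors() returns a TfLiteStatus and fails if tensorArenaSize is
  // too small for the model. Optionally, the call above can be replaced with a
  // checked version, for example:
  //   if (tflInterpreter->AllocateTensors() != kTfLiteOk) {
  //     Serial.println("AllocateTensors() failed - increase tensorArenaSize");
  //     while (1);
  //   }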

  // Get pointers for the model's input and output tensors
  tflInputTensor = tflInterpreter->input(0);
  tflOutputTensor = tflInterpreter->output(0);
}

void loop() {
  float aX, aY, aZ, gX, gY, gZ;

  // wait for significant motion
  while (samplesRead == numSamples) {
    // read the acceleration data
    aX = myIMU.readFloatAccelX();
    aY = myIMU.readFloatAccelY();
    aZ = myIMU.readFloatAccelZ();

    // sum up the absolutes
    float aSum = fabs(aX) + fabs(aY) + fabs(aZ);

    // check if it's above the threshold
    if (aSum >= accelerationThreshold) {
      // reset the sample read count
      samplesRead = 0;
      break;
    }
  }

  // check if all the required samples have been read since
  // the last time significant motion was detected
  while (samplesRead < numSamples) {
    // read the acceleration and gyroscope data (no data-ready check is done
    // here; the readings are taken directly on each pass through the loop)
    aX = myIMU.readFloatAccelX();
    aY = myIMU.readFloatAccelY();
    aZ = myIMU.readFloatAccelZ();
    gX = myIMU.readFloatGyroX();
    gY = myIMU.readFloatGyroY();
    gZ = myIMU.readFloatGyroZ();

    // normalize the IMU data between 0 and 1 and store it in the model's
    // input tensor
    tflInputTensor->data.f[samplesRead * 6 + 0] = (aX + 4.0) / 8.0;
    tflInputTensor->data.f[samplesRead * 6 + 1] = (aY + 4.0) / 8.0;
    tflInputTensor->data.f[samplesRead * 6 + 2] = (aZ + 4.0) / 8.0;
    tflInputTensor->data.f[samplesRead * 6 + 3] = (gX + 2000.0) / 4000.0;
    tflInputTensor->data.f[samplesRead * 6 + 4] = (gY + 2000.0) / 4000.0;
    tflInputTensor->data.f[samplesRead * 6 + 5] = (gZ + 2000.0) / 4000.0;
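    // NOTE: this scaling maps acceleration readings in +/-4 g and gyroscope
    // readings in +/-2000 dps into the range [0, 1]; if the IMU is configured
    // with different full-scale ranges, adjust the offsets and divisors to match.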

    samplesRead++;

    if (samplesRead == numSamples) {
      // Run inferencing
      TfLiteStatus invokeStatus = tflInterpreter->Invoke();
      if (invokeStatus != kTfLiteOk) {
        Serial.println("Invoke failed!");
        while (1);
        return;
      }

      // Loop through the output tensor values from the model
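      // (with a softmax output layer these are per-gesture confidence scores
      // that sum to 1)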
      for (int i = 0; i < NUM_GESTURES; i++) {
        Serial.print(GESTURES[i]);
        Serial.print(": ");
        Serial.println(tflOutputTensor->data.f[i], 6);
      }
      Serial.println();
    }
  }
}