comparisons.cc

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace comparisons {
namespace {

struct OpData {
  ComparisonParams params;
};

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
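
// Each Eval function below follows the same pattern: fetch both inputs and
// the bool output, then dispatch on the element type of the first input.
// Inputs with identical shapes use the element-wise reference kernel, while
// mismatched shapes fall back to the Broadcast4DSlow* variant. Bool, float,
// and integer inputs use the *NoScaling kernels; quantized uint8/int8 inputs
// use the *WithScaling kernels, which consume the quantization parameters
// cached in OpData by Prepare().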
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data);
      break;
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

// TODO(renjieliu): Refactor the logic to avoid duplications.
TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data);
      break;
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowNotEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::NotEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::GreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowGreaterEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::GreaterEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::LessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::LessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::LessNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::LessWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::LessWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteUInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<uint8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<uint8_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowLessEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::LessEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace
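
// Init runs once per node during graph allocation: it reserves the OpData
// instance from the interpreter's persistent arena, and the framework passes
// the returned pointer back to Prepare and the Eval functions through
// node->user_data.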
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);

  if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {
    auto input1_offset = -input1->params.zero_point;
    auto input2_offset = -input2->params.zero_point;
    const int kLeftShift = 8;

    int32_t input1_multiplier;
    int input1_shift;
    QuantizeMultiplierSmallerThanOneExp(
        static_cast<double>(input1->params.scale), &input1_multiplier,
        &input1_shift);
    int32_t input2_multiplier;
    int input2_shift;
    QuantizeMultiplierSmallerThanOneExp(
        static_cast<double>(input2->params.scale), &input2_multiplier,
        &input2_shift);

    data->params.left_shift = kLeftShift;
    data->params.input1_offset = input1_offset;
    data->params.input1_multiplier = input1_multiplier;
    data->params.input1_shift = input1_shift;
    data->params.input2_offset = input2_offset;
    data->params.input2_multiplier = input2_multiplier;
    data->params.input2_shift = input2_shift;
  }
  return kTfLiteOk;
}
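
// Rough sketch of how the *WithScaling reference kernels consume the
// parameters cached above (see
// tensorflow/lite/kernels/internal/reference/comparisons.h for the actual
// implementation): each quantized element is first shifted to a common
// integer scale, approximately
//
//   scaled1 = MultiplyByQuantizedMultiplierSmallerThanOneExp(
//       (input1_val + input1_offset) << left_shift, input1_multiplier,
//       input1_shift);
//   scaled2 = MultiplyByQuantizedMultiplierSmallerThanOneExp(
//       (input2_val + input2_offset) << left_shift, input2_multiplier,
//       input2_shift);
//   output = scaled1 OP scaled2;
//
// so two tensors with different scales and zero points can be compared
// directly in the integer domain without dequantizing to float.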

}  // namespace comparisons

TfLiteRegistration Register_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::EqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_NOT_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::NotEqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_GREATER() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::GreaterEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_GREATER_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::GreaterEqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_LESS() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::LessEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_LESS_EQUAL() {
  return {/*init=*/comparisons::Init,
          /*free=*/nullptr,
          /*prepare=*/comparisons::Prepare,
          /*invoke=*/comparisons::LessEqualEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
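
// Example usage (a sketch only; the resolver class and its Add* method names
// differ between TFLM releases, so treat this as illustrative rather than the
// exact API):
//
//   tflite::MicroMutableOpResolver<1> resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_EQUAL,
//                       tflite::ops::micro::Register_EQUAL());
//
// Once registered, the interpreter invokes Init/Prepare during allocation and
// the corresponding Eval function on every Invoke() for each comparison node
// in the model.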