Pārlūkot izejas kodu

Detect string length overflows

Benoit Blanchon 2 gadi atpakaļ
vecāks
revīzija
6fe4b9c01d

+ 17 - 0
extras/tests/JsonDeserializer/errors.cpp

@@ -101,3 +101,20 @@ TEST_CASE("deserializeJson() returns EmptyInput") {
     REQUIRE(err == DeserializationError::EmptyInput);
   }
 }
+
+TEST_CASE("deserializeJson() returns NoMemory if string length overflows") {
+  JsonDocument doc;
+  auto maxLength = ArduinoJson::detail::StringNode::maxLength;
+
+  SECTION("max length should succeed") {
+    auto err = deserializeJson(doc, "\"" + std::string(maxLength, 'a') + "\"");
+
+    REQUIRE(err == DeserializationError::Ok);
+  }
+
+  SECTION("one above max length should fail") {
+    auto err =
+        deserializeJson(doc, "\"" + std::string(maxLength + 1, 'a') + "\"");
+    REQUIRE(err == DeserializationError::NoMemory);
+  }
+}

+ 12 - 0
extras/tests/JsonDocument/overflowed.cpp

@@ -80,4 +80,16 @@ TEST_CASE("JsonDocument::overflowed()") {
     doc.shrinkToFit();
     CHECK(doc.overflowed() == true);
   }
+
+  SECTION("returns false when string length doesn't overflow") {
+    auto maxLength = ArduinoJson::detail::StringNode::maxLength;
+    CHECK(doc.set(std::string(maxLength, 'a')) == true);
+    CHECK(doc.overflowed() == false);
+  }
+
+  SECTION("returns true when string length overflows") {
+    auto maxLength = ArduinoJson::detail::StringNode::maxLength;
+    CHECK(doc.set(std::string(maxLength + 1, 'a')) == false);
+    CHECK(doc.overflowed() == true);
+  }
 }

+ 23 - 0
extras/tests/MsgPackDeserializer/errors.cpp

@@ -240,3 +240,26 @@ TEST_CASE("deserializeMsgPack() replaces unsupported types by null") {
                           20) == "[null,42]");
   }
 }
+
+TEST_CASE("deserializeMsgPack() returns NoMemory is string length overflows") {
+  JsonDocument doc;
+  auto maxLength = ArduinoJson::detail::StringNode::maxLength;
+
+  SECTION("max length should succeed") {
+    auto len = maxLength;
+    std::string prefix = {'\xdb', char(len >> 24), char(len >> 16),
+                          char(len >> 8), char(len)};
+
+    auto err = deserializeMsgPack(doc, prefix + std::string(len, 'a'));
+    REQUIRE(err == DeserializationError::Ok);
+  }
+
+  SECTION("one above max length should fail") {
+    auto len = maxLength + 1;
+    std::string prefix = {'\xdb', char(len >> 24), char(len >> 16),
+                          char(len >> 8), char(len)};
+
+    auto err = deserializeMsgPack(doc, prefix + std::string(len, 'a'));
+    REQUIRE(err == DeserializationError::NoMemory);
+  }
+}

+ 11 - 2
src/ArduinoJson/Memory/StringNode.hpp

@@ -7,6 +7,7 @@
 #include <ArduinoJson/Memory/Allocator.hpp>
 #include <ArduinoJson/Namespace.hpp>
 #include <ArduinoJson/Polyfills/assert.hpp>
+#include <ArduinoJson/Polyfills/limits.hpp>
 
 #include <stddef.h>  // offsetof
 #include <stdint.h>  // uint16_t
@@ -19,11 +20,15 @@ struct StringNode {
   uint16_t references;
   char data[1];
 
+  static constexpr size_t maxLength = numeric_limits<uint16_t>::highest();
+
   static constexpr size_t sizeForLength(size_t n) {
     return n + 1 + offsetof(StringNode, data);
   }
 
   static StringNode* create(size_t length, Allocator* allocator) {
+    if (length > maxLength)
+      return nullptr;
     auto node = reinterpret_cast<StringNode*>(
         allocator->allocate(sizeForLength(length)));
     if (node) {
@@ -36,8 +41,12 @@ struct StringNode {
   static StringNode* resize(StringNode* node, size_t length,
                             Allocator* allocator) {
     ARDUINOJSON_ASSERT(node != nullptr);
-    auto newNode = reinterpret_cast<StringNode*>(
-        allocator->reallocate(node, sizeForLength(length)));
+    StringNode* newNode;
+    if (length <= maxLength)
+      newNode = reinterpret_cast<StringNode*>(
+          allocator->reallocate(node, sizeForLength(length)));
+    else
+      newNode = nullptr;
     if (newNode)
       newNode->length = uint16_t(length);
     else

+ 4 - 4
src/ArduinoJson/Polyfills/limits.hpp

@@ -19,10 +19,10 @@ struct numeric_limits;
 
 template <typename T>
 struct numeric_limits<T, typename enable_if<is_unsigned<T>::value>::type> {
-  static T lowest() {
+  static constexpr T lowest() {
     return 0;
   }
-  static T highest() {
+  static constexpr T highest() {
     return T(-1);
   }
 };
@@ -30,10 +30,10 @@ struct numeric_limits<T, typename enable_if<is_unsigned<T>::value>::type> {
 template <typename T>
 struct numeric_limits<
     T, typename enable_if<is_integral<T>::value && is_signed<T>::value>::type> {
-  static T lowest() {
+  static constexpr T lowest() {
     return T(T(1) << (sizeof(T) * 8 - 1));
   }
-  static T highest() {
+  static constexpr T highest() {
     return T(~lowest());
   }
 };