Jelajahi Sumber

Support NUL inside string values (issue #1646)

Benoit Blanchon 4 tahun lalu
induk
melakukan
be70f6ddd7

+ 1 - 0
CHANGELOG.md

@@ -14,6 +14,7 @@ HEAD
 * Fix `JsonVariant::memoryUsage()` for raw strings
 * Fix `call of overloaded 'swap(BasicJsonDocument&, BasicJsonDocument&)' is ambiguous` (issue #1678)
 * Fix inconsistent pool size in `BasicJsonDocument`'s copy constructor
+* Support NUL in string values (issue #1646)
 
 v6.18.5 (2021-09-28)
 -------

+ 13 - 1
extras/tests/Cpp17/string_view.cpp

@@ -8,7 +8,7 @@
 #endif
 
 TEST_CASE("string_view") {
-  StaticJsonDocument<128> doc;
+  StaticJsonDocument<256> doc;
   JsonVariant variant = doc.to<JsonVariant>();
 
   SECTION("deserializeJson()") {
@@ -57,6 +57,12 @@ TEST_CASE("string_view") {
 
     doc.add(std::string_view("example two", 7));
     REQUIRE(doc.memoryUsage() == JSON_ARRAY_SIZE(2) + 8);
+
+    doc.add(std::string_view("example\0tree", 12));
+    REQUIRE(doc.memoryUsage() == JSON_ARRAY_SIZE(3) + 21);
+
+    doc.add(std::string_view("example\0tree and a half", 12));
+    REQUIRE(doc.memoryUsage() == JSON_ARRAY_SIZE(4) + 21);
   }
 
   SECTION("as<std::string_view>()") {
@@ -72,6 +78,12 @@ TEST_CASE("string_view") {
     REQUIRE(doc["s"].is<std::string_view>() == true);
     REQUIRE(doc["i"].is<std::string_view>() == false);
   }
+
+  SECTION("String containing NUL") {
+    doc.set(std::string("hello\0world", 11));
+    REQUIRE(doc.as<std::string_view>().size() == 11);
+    REQUIRE(doc.as<std::string_view>() == std::string_view("hello\0world", 11));
+  }
 }
 
 using ARDUINOJSON_NAMESPACE::adaptString;

+ 2 - 3
extras/tests/JsonDeserializer/string.cpp

@@ -60,9 +60,8 @@ TEST_CASE("\\u0000") {
   CHECK(result[4] == 'z');
   CHECK(result[5] == 0);
 
-  // ArduinoJson strings doesn't store string length, so the following returns 2
-  // instead of 5 (issue #1646)
-  CHECK(doc.as<std::string>().size() == 2);
+  CHECK(doc.as<JsonString>().size() == 5);
+  CHECK(doc.as<std::string>().size() == 5);
 }
 
 TEST_CASE("Truncated JSON string") {

+ 4 - 0
extras/tests/JsonSerializer/JsonVariant.cpp

@@ -63,6 +63,10 @@ TEST_CASE("serializeJson(JsonVariant)") {
     SECTION("Escape tab") {
       check(std::string("hello\tworld"), "\"hello\\tworld\"");
     }
+
+    SECTION("NUL char") {
+      check(std::string("hello\0world", 11), "\"hello\\u0000world\"");
+    }
   }
 
   SECTION("SerializedValue<const char*>") {

+ 10 - 0
extras/tests/JsonSerializer/std_string.cpp

@@ -45,3 +45,13 @@ TEST_CASE("serialize JsonObject to std::string") {
     REQUIRE("{\r\n  \"key\": \"value\"\r\n}" == json);
   }
 }
+
+TEST_CASE("serialize an std::string containing a NUL") {
+  StaticJsonDocument<256> doc;
+  doc.set(std::string("hello\0world", 11));
+  CHECK(doc.memoryUsage() == 12);
+
+  std::string json;
+  serializeJson(doc, json);
+  CHECK("\"hello\\u0000world\"" == json);
+}

+ 25 - 0
extras/tests/MemoryPool/saveString.cpp

@@ -12,6 +12,10 @@ static const char *saveString(MemoryPool &pool, const char *s) {
   return pool.saveString(adaptString(const_cast<char *>(s)));
 }
 
+static const char *saveString(MemoryPool &pool, const char *s, size_t n) {
+  return pool.saveString(adaptString(s, n));
+}
+
 TEST_CASE("MemoryPool::saveString()") {
   char buffer[32];
   MemoryPool pool(buffer, 32);
@@ -30,6 +34,27 @@ TEST_CASE("MemoryPool::saveString()") {
     REQUIRE(pool.size() == 6);
   }
 
+  SECTION("Deduplicates identical strings that contain NUL") {
+    const char *a = saveString(pool, "hello\0world", 11);
+    const char *b = saveString(pool, "hello\0world", 11);
+    REQUIRE(a == b);
+    REQUIRE(pool.size() == 12);
+  }
+
+  SECTION("Reuse part of a string if it ends with NUL") {
+    const char *a = saveString(pool, "hello\0world", 11);
+    const char *b = saveString(pool, "hello");
+    REQUIRE(a == b);
+    REQUIRE(pool.size() == 12);
+  }
+
+  SECTION("Don't stop on first NUL") {
+    const char *a = saveString(pool, "hello");
+    const char *b = saveString(pool, "hello\0world", 11);
+    REQUIRE(a != b);
+    REQUIRE(pool.size() == 18);
+  }
+
   SECTION("Returns NULL when full") {
     REQUIRE(pool.capacity() == 32);
 

+ 5 - 0
src/ArduinoJson/Json/JsonSerializer.hpp

@@ -68,6 +68,11 @@ class JsonSerializer : public Visitor<size_t> {
     return bytesWritten();
   }
 
+  size_t visitString(const char *value, size_t n) {
+    _formatter.writeString(value, n);
+    return bytesWritten();
+  }
+
   size_t visitRawJson(const char *data, size_t n) {
     _formatter.writeRaw(data, n);
     return bytesWritten();

+ 10 - 1
src/ArduinoJson/Json/TextFormatter.hpp

@@ -41,13 +41,22 @@ class TextFormatter {
     writeRaw('\"');
   }
 
+  void writeString(const char *value, size_t n) {
+    ARDUINOJSON_ASSERT(value != NULL);
+    writeRaw('\"');
+    while (n--) writeChar(*value++);
+    writeRaw('\"');
+  }
+
   void writeChar(char c) {
     char specialChar = EscapeSequence::escapeChar(c);
     if (specialChar) {
       writeRaw('\\');
       writeRaw(specialChar);
-    } else {
+    } else if (c) {
       writeRaw(c);
+    } else {
+      writeRaw("\\u0000");
     }
   }
 

+ 5 - 5
src/ArduinoJson/Memory/MemoryPool.hpp

@@ -62,12 +62,12 @@ class MemoryPool {
   template <typename TAdaptedString>
   const char* saveString(const TAdaptedString& str) {
     if (str.isNull())
-      return 0;
+      return CopiedString();
 
 #if ARDUINOJSON_ENABLE_STRING_DEDUPLICATION
     const char* existingCopy = findString(str);
     if (existingCopy)
-      return existingCopy;
+      return CopiedString(existingCopy, str.size());
 #endif
 
     size_t n = str.size();
@@ -77,7 +77,7 @@ class MemoryPool {
       str.copyTo(newCopy, n);
       newCopy[n] = 0;  // force null-terminator
     }
-    return newCopy;
+    return CopiedString(newCopy, n);
   }
 
   void getFreeZone(char** zoneStart, size_t* zoneSize) const {
@@ -89,14 +89,14 @@ class MemoryPool {
 #if ARDUINOJSON_ENABLE_STRING_DEDUPLICATION
     const char* dup = findString(adaptString(_left, len));
     if (dup)
-      return dup;
+      return CopiedString(dup, len);
 #endif
 
     const char* str = _left;
     _left += len;
     *_left++ = 0;
     checkInvariants();
-    return str;
+    return CopiedString(str, len);
   }
 
   void markAsOverflowed() {

+ 4 - 2
src/ArduinoJson/MsgPack/MsgPackSerializer.hpp

@@ -78,9 +78,11 @@ class MsgPackSerializer : public Visitor<size_t> {
   }
 
   size_t visitString(const char* value) {
-    ARDUINOJSON_ASSERT(value != NULL);
+    return visitString(value, strlen(value));
+  }
 
-    size_t n = strlen(value);
+  size_t visitString(const char* value, size_t n) {
+    ARDUINOJSON_ASSERT(value != NULL);
 
     if (n < 0x20) {
       writeByte(uint8_t(0xA0 + n));

+ 2 - 2
src/ArduinoJson/StringStorage/StringCopier.hpp

@@ -24,7 +24,7 @@ class StringCopier {
   string_type save() {
     ARDUINOJSON_ASSERT(_ptr);
     ARDUINOJSON_ASSERT(_size < _capacity);  // needs room for the terminator
-    return _pool->saveStringFromFreeZone(_size);
+    return string_type(_pool->saveStringFromFreeZone(_size), _size);
   }
 
   void append(const char* s) {
@@ -54,7 +54,7 @@ class StringCopier {
     ARDUINOJSON_ASSERT(_ptr);
     ARDUINOJSON_ASSERT(_size < _capacity);
     _ptr[_size] = 0;
-    return _ptr;
+    return string_type(_ptr, _size);
   }
 
  private:

+ 1 - 1
src/ArduinoJson/StringStorage/StringMover.hpp

@@ -35,7 +35,7 @@ class StringMover {
   }
 
   string_type str() const {
-    return string_type(_startPtr);
+    return string_type(_startPtr, size());
   }
 
   size_t size() const {

+ 7 - 1
src/ArduinoJson/Strings/StoredString.hpp

@@ -12,7 +12,8 @@ namespace ARDUINOJSON_NAMESPACE {
 template <typename TStoragePolicy>
 class StoredString {
  public:
-  StoredString(const char* p) : _data(p) {}
+  StoredString() : _data(0), _size(0) {}
+  StoredString(const char* p, size_t n) : _data(p), _size(n) {}
 
   operator const char*() const {
     return _data;
@@ -22,8 +23,13 @@ class StoredString {
     return _data;
   }
 
+  size_t size() const {
+    return _size;
+  }
+
  private:
   const char* _data;
+  size_t _size;
 };
 
 typedef StoredString<storage_policies::store_by_address> LinkedString;

+ 13 - 2
src/ArduinoJson/Strings/String.hpp

@@ -10,9 +10,15 @@ namespace ARDUINOJSON_NAMESPACE {
 
 class String : public SafeBoolIdom<String> {
  public:
-  String() : _data(0), _isStatic(true) {}
+  String() : _data(0), _size(0), _isStatic(true) {}
+
   String(const char* data, bool isStaticData = true)
-      : _data(data), _isStatic(isStaticData) {}
+      : _data(data),
+        _size(data ? ::strlen(data) : 0),
+        _isStatic(isStaticData) {}
+
+  String(const char* data, size_t sz, bool isStaticData = true)
+      : _data(data), _size(sz), _isStatic(isStaticData) {}
 
   const char* c_str() const {
     return _data;
@@ -26,6 +32,10 @@ class String : public SafeBoolIdom<String> {
     return _isStatic;
   }
 
+  size_t size() const {
+    return _size;
+  }
+
   // safe bool idiom
   operator bool_type() const {
     return _data ? safe_true() : safe_false();
@@ -53,6 +63,7 @@ class String : public SafeBoolIdom<String> {
 
  private:
   const char* _data;
+  size_t _size;
   bool _isStatic;
 };
 

+ 9 - 13
src/ArduinoJson/Variant/ConverterImpl.hpp

@@ -205,7 +205,7 @@ class MemoryPoolPrint : public Print {
 
   CopiedString str() {
     ARDUINOJSON_ASSERT(_size < _capacity);
-    return _pool->saveStringFromFreeZone(_size);
+    return CopiedString(_pool->saveStringFromFreeZone(_size), _size);
   }
 
   size_t write(uint8_t c) {
@@ -257,8 +257,7 @@ inline void convertToJson(const ::Printable& src, VariantRef dst) {
 #if ARDUINOJSON_ENABLE_ARDUINO_STRING
 
 inline void convertFromJson(VariantConstRef src, ::String& dst) {
-  const VariantData* data = getData(src);
-  String str = data != 0 ? data->asString() : String();
+  String str = src.as<String>();
   if (str)
     dst = str.c_str();
   else
@@ -266,8 +265,7 @@ inline void convertFromJson(VariantConstRef src, ::String& dst) {
 }
 
 inline bool canConvertFromJson(VariantConstRef src, const ::String&) {
-  const VariantData* data = getData(src);
-  return data && data->isString();
+  return src.is<String>();
 }
 
 #endif
@@ -275,17 +273,15 @@ inline bool canConvertFromJson(VariantConstRef src, const ::String&) {
 #if ARDUINOJSON_ENABLE_STD_STRING
 
 inline void convertFromJson(VariantConstRef src, std::string& dst) {
-  const VariantData* data = getData(src);
-  String str = data != 0 ? data->asString() : String();
+  String str = src.as<String>();
   if (str)
-    dst.assign(str.c_str());
+    dst.assign(str.c_str(), str.size());
   else
     serializeJson(src, dst);
 }
 
 inline bool canConvertFromJson(VariantConstRef src, const std::string&) {
-  const VariantData* data = getData(src);
-  return data && data->isString();
+  return src.is<String>();
 }
 
 #endif
@@ -293,13 +289,13 @@ inline bool canConvertFromJson(VariantConstRef src, const std::string&) {
 #if ARDUINOJSON_ENABLE_STRING_VIEW
 
 inline void convertFromJson(VariantConstRef src, std::string_view& dst) {
-  const char* str = src.as<const char*>();
+  String str = src.as<String>();
   if (str)  // the standard doesn't allow passing null to the constructor
-    dst = std::string_view(str);
+    dst = std::string_view(str.c_str(), str.size());
 }
 
 inline bool canConvertFromJson(VariantConstRef src, const std::string_view&) {
-  return src.is<const char*>();
+  return src.is<String>();
 }
 
 #endif

+ 2 - 2
src/ArduinoJson/Variant/SlotFunctions.hpp

@@ -30,14 +30,14 @@ template <typename TAdaptedString>
 inline bool slotSetKey(VariantSlot* var, TAdaptedString key, MemoryPool*,
                        storage_policies::store_by_address) {
   ARDUINOJSON_ASSERT(var);
-  var->setKey(LinkedString(key.data()));
+  var->setKey(LinkedString(key.data(), key.size()));
   return true;
 }
 
 template <typename TAdaptedString>
 inline bool slotSetKey(VariantSlot* var, TAdaptedString key, MemoryPool* pool,
                        storage_policies::store_by_copy) {
-  CopiedString dup = pool->saveString(key);
+  CopiedString dup(pool->saveString(key), key.size());
   if (!dup)
     return false;
   ARDUINOJSON_ASSERT(var);

+ 2 - 2
src/ArduinoJson/Variant/VariantCompare.hpp

@@ -27,7 +27,7 @@ struct Comparer<T, typename enable_if<IsString<T>::value>::type>
 
   explicit Comparer(T value) : rhs(value) {}
 
-  CompareResult visitString(const char *lhs) {
+  CompareResult visitString(const char *lhs, size_t) {
     int i = adaptString(rhs).compare(lhs);
     if (i < 0)
       return COMPARE_RESULT_GREATER;
@@ -150,7 +150,7 @@ struct Comparer<T, typename enable_if<IsVisitable<T>::value>::type>
     return accept(comparer);
   }
 
-  CompareResult visitString(const char *lhs) {
+  CompareResult visitString(const char *lhs, size_t) {
     Comparer<const char *> comparer(lhs);
     return accept(comparer);
   }

+ 1 - 2
src/ArduinoJson/Variant/VariantContent.hpp

@@ -49,10 +49,9 @@ union VariantContent {
   UInt asUnsignedInteger;
   Integer asSignedInteger;
   CollectionData asCollection;
-  const char *asString;
   struct {
     const char *data;
     size_t size;
-  } asRaw;
+  } asString;
 };
 }  // namespace ARDUINOJSON_NAMESPACE

+ 21 - 15
src/ArduinoJson/Variant/VariantData.hpp

@@ -51,11 +51,13 @@ class VariantData {
 
       case VALUE_IS_LINKED_STRING:
       case VALUE_IS_OWNED_STRING:
-        return visitor.visitString(_content.asString);
+        return visitor.visitString(_content.asString.data,
+                                   _content.asString.size);
 
       case VALUE_IS_OWNED_RAW:
       case VALUE_IS_LINKED_RAW:
-        return visitor.visitRawJson(_content.asRaw.data, _content.asRaw.size);
+        return visitor.visitRawJson(_content.asString.data,
+                                    _content.asString.size);
 
       case VALUE_IS_SIGNED_INTEGER:
         return visitor.visitSignedInteger(_content.asSignedInteger);
@@ -105,10 +107,13 @@ class VariantData {
         return toObject().copyFrom(src._content.asCollection, pool);
       case VALUE_IS_OWNED_STRING:
         return storeString(
-            adaptString(const_cast<char *>(src._content.asString)), pool);
+            adaptString(const_cast<char *>(src._content.asString.data),
+                        src._content.asString.size),
+            pool);
       case VALUE_IS_OWNED_RAW:
         return storeOwnedRaw(
-            serialized(src._content.asRaw.data, src._content.asRaw.size), pool);
+            serialized(src._content.asString.data, src._content.asString.size),
+            pool);
       default:
         setType(src.type());
         _content = src._content;
@@ -186,8 +191,8 @@ class VariantData {
   void setLinkedRaw(SerializedValue<const char *> value) {
     if (value.data()) {
       setType(VALUE_IS_LINKED_RAW);
-      _content.asRaw.data = value.data();
-      _content.asRaw.size = value.size();
+      _content.asString.data = value.data();
+      _content.asString.size = value.size();
     } else {
       setType(VALUE_IS_NULL);
     }
@@ -198,8 +203,8 @@ class VariantData {
     const char *dup = pool->saveString(adaptString(value.data(), value.size()));
     if (dup) {
       setType(VALUE_IS_OWNED_RAW);
-      _content.asRaw.data = dup;
-      _content.asRaw.size = value.size();
+      _content.asString.data = dup;
+      _content.asString.size = value.size();
       return true;
     } else {
       setType(VALUE_IS_NULL);
@@ -226,13 +231,15 @@ class VariantData {
   void setString(CopiedString s) {
     ARDUINOJSON_ASSERT(s);
     setType(VALUE_IS_OWNED_STRING);
-    _content.asString = s.c_str();
+    _content.asString.data = s.c_str();
+    _content.asString.size = s.size();
   }
 
   void setString(LinkedString s) {
     ARDUINOJSON_ASSERT(s);
     setType(VALUE_IS_LINKED_STRING);
-    _content.asString = s.c_str();
+    _content.asString.data = s.c_str();
+    _content.asString.size = s.size();
   }
 
   template <typename TAdaptedString>
@@ -255,11 +262,10 @@ class VariantData {
   size_t memoryUsage() const {
     switch (type()) {
       case VALUE_IS_OWNED_STRING:
-        return strlen(_content.asString) + 1;
       case VALUE_IS_OWNED_RAW:
         // We always add a zero at the end: the deduplication function uses it
         // to detect the beginning of the next string.
-        return _content.asRaw.size + 1;
+        return _content.asString.size + 1;
       case VALUE_IS_OBJECT:
       case VALUE_IS_ARRAY:
         return _content.asCollection.memoryUsage();
@@ -312,7 +318,7 @@ class VariantData {
 
   void movePointers(ptrdiff_t stringDistance, ptrdiff_t variantDistance) {
     if (_flags & OWNED_VALUE_BIT)
-      _content.asString += stringDistance;
+      _content.asString.data += stringDistance;
     if (_flags & COLLECTION_MASK)
       _content.asCollection.movePointers(stringDistance, variantDistance);
   }
@@ -342,7 +348,7 @@ class VariantData {
     if (value.isNull())
       setNull();
     else
-      setString(LinkedString(value.data()));
+      setString(LinkedString(value.data(), value.size()));
     return true;
   }
 
@@ -358,7 +364,7 @@ class VariantData {
       setNull();
       return false;
     }
-    setString(CopiedString(copy));
+    setString(CopiedString(copy, value.size()));
     return true;
   }
 };

+ 5 - 5
src/ArduinoJson/Variant/VariantImpl.hpp

@@ -26,7 +26,7 @@ inline T VariantData::asIntegral() const {
       return convertNumber<T>(_content.asSignedInteger);
     case VALUE_IS_LINKED_STRING:
     case VALUE_IS_OWNED_STRING:
-      return parseNumber<T>(_content.asString);
+      return parseNumber<T>(_content.asString.data);
     case VALUE_IS_FLOAT:
       return convertNumber<T>(_content.asFloat);
     default:
@@ -62,7 +62,7 @@ inline T VariantData::asFloat() const {
       return static_cast<T>(_content.asSignedInteger);
     case VALUE_IS_LINKED_STRING:
     case VALUE_IS_OWNED_STRING:
-      return parseNumber<T>(_content.asString);
+      return parseNumber<T>(_content.asString.data);
     case VALUE_IS_FLOAT:
       return static_cast<T>(_content.asFloat);
     default:
@@ -73,11 +73,11 @@ inline T VariantData::asFloat() const {
 inline String VariantData::asString() const {
   switch (type()) {
     case VALUE_IS_LINKED_STRING:
-      return String(_content.asString, true);
+      return String(_content.asString.data, _content.asString.size, true);
     case VALUE_IS_OWNED_STRING:
-      return String(_content.asString, false);
+      return String(_content.asString.data, _content.asString.size, false);
     default:
-      return 0;
+      return String();
   }
 }
 

+ 1 - 1
src/ArduinoJson/Variant/VariantSlot.hpp

@@ -108,7 +108,7 @@ class VariantSlot {
     if (_flags & OWNED_KEY_BIT)
       _key += stringDistance;
     if (_flags & OWNED_VALUE_BIT)
-      _content.asString += stringDistance;
+      _content.asString.data += stringDistance;
     if (_flags & COLLECTION_MASK)
       _content.asCollection.movePointers(stringDistance, variantDistance);
   }

+ 1 - 1
src/ArduinoJson/Variant/Visitor.hpp

@@ -46,7 +46,7 @@ struct Visitor {
     return TResult();
   }
 
-  TResult visitString(const char *) {
+  TResult visitString(const char *, size_t) {
     return TResult();
   }
 };