Преглед изворни кода

Enabled String and Byte arrays to be repeated fields.

Bart Hertog пре 5 година
родитељ
комит
709782936f

+ 1 - 0
generator/templates/FieldRepeated_GetSet.h

@@ -84,6 +84,7 @@ inline void set_{{field.get_name()}}(uint32_t index, const {{field.get_base_type
 inline void set_{{field.get_name()}}(const {{field.get_type()}}& values) { {{field.get_variable_name()}} = values; }
 inline void add_{{field.get_name()}}(const {{field.get_base_type()}}& value) { {{field.get_variable_name()}}.add(value); }
 inline {{field.get_type()}}& mutable_{{field.get_name()}}() { return {{field.get_variable_name()}}; }
+inline {{field.get_base_type()}}& mutable_{{field.get_name()}}(uint32_t index) { return {{field.get_variable_name()}}[index]; }
 {% endif %}
 inline const {{field.get_type()}}& get_{{field.get_name()}}() const { return {{field.get_variable_name()}}; }
 inline const {{field.get_type()}}& {{field.get_name()}}() const { return {{field.get_variable_name()}}; }

+ 12 - 9
src/FieldStringBytes.h

@@ -45,8 +45,10 @@ namespace EmbeddedProto
   namespace internal
   {
 
+    class BaseStringBytes : public Field {};
+
     template<uint32_t MAX_LENGTH, class DATA_TYPE>
-    class FieldStringBytes : public Field
+    class FieldStringBytes : public BaseStringBytes
     {
       static_assert(std::is_same<uint8_t, DATA_TYPE>::value || std::is_same<char, DATA_TYPE>::value, 
                     "This class only supports unit8_t or chars.");
@@ -151,6 +153,10 @@ namespace EmbeddedProto
                                                     WireFormatter::WireType::LENGTH_DELIMITED);
               return_value = WireFormatter::SerializeVarint(tag, buffer);
               if(Error::NO_ERRORS == return_value) 
+              {
+                return_value = WireFormatter::SerializeVarint(current_length_, buffer);
+              }
+              if(Error::NO_ERRORS == return_value) 
               {
                 return_value = serialize(buffer);
               }
@@ -166,15 +172,12 @@ namespace EmbeddedProto
 
         Error serialize(WriteBufferInterface& buffer) const override 
         { 
-          Error return_value = WireFormatter::SerializeVarint(current_length_, buffer);
-          if(Error::NO_ERRORS == return_value) 
+          Error return_value = Error::NO_ERRORS;
+          const void* void_pointer = static_cast<const void*>(&(data_[0]));
+          const uint8_t* byte_pointer = static_cast<const uint8_t*>(void_pointer);
+          if(!buffer.push(byte_pointer, current_length_))
           {
-            const void* void_pointer = static_cast<const void*>(&(data_[0]));
-            const uint8_t* byte_pointer = static_cast<const uint8_t*>(void_pointer);
-            if(!buffer.push(byte_pointer, current_length_))
-            {
-              return_value = Error::BUFFER_FULL;
-            }
+            return_value = Error::BUFFER_FULL;
           }
           return return_value;
         }

+ 13 - 5
src/RepeatedField.h

@@ -34,7 +34,8 @@
 #include "Fields.h"
 #include "MessageInterface.h"
 #include "MessageSizeCalculator.h"
-#include "ReadBufferSection.h" 
+#include "ReadBufferSection.h"
+#include "FieldStringBytes.h"
 #include "Errors.h"
 
 #include <cstdint>
@@ -51,8 +52,9 @@ namespace EmbeddedProto
     static_assert(std::is_base_of<::EmbeddedProto::Field, DATA_TYPE>::value, "A Field can only be used as template paramter.");
 
     //! Check how this field shoeld be serialized, packed or not.
-    static constexpr bool REPEATED_FIELD_IS_PACKED = !std::is_base_of<MessageInterface, 
-                                                                      DATA_TYPE>::value;
+    static constexpr bool REPEATED_FIELD_IS_PACKED = 
+          !(std::is_base_of<MessageInterface, DATA_TYPE>::value 
+            || std::is_base_of<internal::BaseStringBytes, DATA_TYPE>::value);
 
     public:
 
@@ -165,8 +167,14 @@ namespace EmbeddedProto
         else 
         {
           const uint32_t size_x = this->serialized_size_unpacked(field_number);
-          return_value = (size_x <= buffer.get_available_size()) ? serialize_unpacked(field_number, buffer) 
-                                                                 : Error::BUFFER_FULL;
+          if(size_x <= buffer.get_available_size()) 
+          {
+            return_value = serialize_unpacked(field_number, buffer);
+          }
+          else 
+          {
+            return_value = Error::BUFFER_FULL;
+          }
         }
 
         return return_value;

+ 6 - 0
test/proto/string_bytes.proto

@@ -50,3 +50,9 @@ message string_or_bytes
     bytes b = 2;
   }
 }
+
+message repeated_string_bytes
+{
+  repeated string array_of_txt = 1;
+  repeated bytes array_of_bytes = 2;
+}

+ 70 - 0
test/test_string_bytes.cpp

@@ -362,4 +362,74 @@ TEST(FieldBytes, oneof_deserialize)
   EXPECT_EQ(0, msg.b()[3]);
 }
 
+TEST(RepeatedStringBytes, empty) 
+{ 
+  repeated_string_bytes<3, 10, 3, 10> msg;
+  Mocks::WriteBufferMock buffer;
+  EXPECT_CALL(buffer, get_available_size()).Times(2).WillRepeatedly(Return(99));  
+  EXPECT_EQ(::EmbeddedProto::Error::NO_ERRORS, msg.serialize(buffer));
+}
+
+TEST(RepeatedStringBytes, get_set) 
+{ 
+  repeated_string_bytes<3, 15, 3, 15> msg;
+
+  ::EmbeddedProto::FieldString<15> str;
+  msg.add_array_of_txt(str);
+  msg.mutable_array_of_txt(0) = "Foo bar 1";
+  msg.add_array_of_txt(str);
+  msg.mutable_array_of_txt(1) = "Foo bar 2";
+
+  str = "Foo bar 3";
+  msg.add_array_of_txt(str);
+  
+  EXPECT_EQ(3, msg.array_of_txt().get_length());
+  EXPECT_EQ(0, msg.array_of_bytes().get_length());
+  EXPECT_STREQ(msg.array_of_txt(0).get_const(), "Foo bar 1");
+  EXPECT_STREQ(msg.array_of_txt(1).get_const(), "Foo bar 2");
+  EXPECT_STREQ(msg.array_of_txt(2).get_const(), "Foo bar 3");
+}
+
+
+TEST(RepeatedStringBytes, serialize) 
+{ 
+  InSequence s;
+
+  repeated_string_bytes<3, 15, 3, 15> msg;
+  Mocks::WriteBufferMock buffer;
+
+  ::EmbeddedProto::FieldString<15> str;
+  msg.add_array_of_txt(str);
+  msg.mutable_array_of_txt(0) = "Foo bar 1";
+  msg.add_array_of_txt(str);
+  msg.mutable_array_of_txt(1) = "";
+  msg.add_array_of_txt(str);
+  msg.mutable_array_of_txt(2) = "Foo bar 3";
+
+  // We need 24 bytes to serialze the strings above.
+  EXPECT_CALL(buffer, get_available_size()).Times(1).WillOnce(Return(24));
+
+  // The first string.
+  // Id and size of array of txt.
+  EXPECT_CALL(buffer, push(0x0a)).Times(1).WillOnce(Return(true));
+  EXPECT_CALL(buffer, push(0x09)).Times(1).WillOnce(Return(true));
+  // The string is pushed as an array, we do not know the pointer value so use _, but we do know 
+  // the size.
+  EXPECT_CALL(buffer, push(_, 9)).Times(1).WillOnce(Return(true));
+  
+
+  // The empty string
+  EXPECT_CALL(buffer, push(0x0a)).Times(1).WillOnce(Return(true));
+  EXPECT_CALL(buffer, push(0x00)).Times(1).WillOnce(Return(true));
+  
+  // The last string
+  EXPECT_CALL(buffer, push(0x0a)).Times(1).WillOnce(Return(true));
+  EXPECT_CALL(buffer, push(0x09)).Times(1).WillOnce(Return(true));
+  EXPECT_CALL(buffer, push(_, 9)).Times(1).WillOnce(Return(true));
+
+  EXPECT_CALL(buffer, get_available_size()).Times(1).WillOnce(Return(0));
+
+  EXPECT_EQ(::EmbeddedProto::Error::NO_ERRORS, msg.serialize(buffer));
+}
+
 } // End of namespace test_EmbeddedAMS_string_bytes

+ 24 - 4
test_data.py

@@ -36,7 +36,8 @@ import nested_message_pb2 as nm
 import repeated_fields_pb2 as rf
 import oneof_fields_pb2 as of
 import file_to_include_pb2 as fti
-import include_other_files_pb2 as iof
+#import include_other_files_pb2 as iof
+import string_bytes_pb2 as sb
 
 
 def test_simple_types():
@@ -192,7 +193,7 @@ def test_repeated_message():
 
 
 def test_string():
-    msg = rf.text()
+    msg = sb.text()
 
     msg.txt = "Foo bar"
 
@@ -208,7 +209,7 @@ def test_string():
 
 
 def test_bytes():
-    msg = rf.raw_bytes()
+    msg = sb.raw_bytes()
 
     msg.b = b'\x01\x02\x03\x00'
 
@@ -223,6 +224,24 @@ def test_bytes():
     print()
 
 
+def test_repeated_string_bytes():
+    msg = sb.repeated_string_bytes()
+
+    msg.array_of_txt.append("Foo bar 1")
+    msg.array_of_txt.append("")
+    msg.array_of_txt.append("Foo bar 3")
+
+    str = ""
+    msg_str = msg.SerializeToString()
+    print(len(msg_str))
+    print(msg_str)
+    for x in msg_str:
+        str += "0x{:02x}, ".format(x)
+
+    print(str)
+    print()
+
+
 def test_oneof_fields():
     msg = of.message_oneof()
 
@@ -270,7 +289,8 @@ def test_included_proto():
 #test_repeated_message()
 #test_string()
 #test_bytes()
-test_nested_message()
+test_repeated_string_bytes()
+#test_nested_message()
 #test_oneof_fields()
 #test_included_proto()