Эх сурвалжийг харах

Work around the repeated fields.

Bart Hertog 6 жил өмнө
parent
commit
c9f7e852f6

+ 4 - 0
build_test.sh

@@ -6,10 +6,14 @@ protoc --plugin=protoc-gen-eams=protoc-gen-eams -I./test/proto --eams_out=./buil
 protoc --plugin=protoc-gen-eams=protoc-gen-eams -I./test/proto --eams_out=./build/EAMS ./test/proto/nested_message.proto
 protoc --plugin=protoc-gen-eams=protoc-gen-eams -I./test/proto --eams_out=./build/EAMS ./test/proto/repeated_fields.proto
 
+mkdir -p ./build/google
+protoc -I./test/proto --cpp_out=./build/google ./test/proto/repeated_fields.proto
+
 # For validation and testing generate the same message using python
 mkdir -p ./build/python
 protoc -I./test/proto --python_out=./build/python ./test/proto/simple_types.proto
 protoc -I./test/proto --python_out=./build/python ./test/proto/nested_message.proto
+protoc -I./test/proto --python_out=./build/python ./test/proto/repeated_fields.proto
 
 
 # Build the tests

+ 10 - 0
generator/Header_Template.h

@@ -9,6 +9,11 @@ enum {{ _enum.name }}
 {% endmacro %}
 
 {% macro msg_macro(msg) %}
+{% if msg.templates is defined %}
+{% for template in msg.templates %}
+{{"template<" if loop.first}}uint32_t {{template}}{{"SIZE, " if not loop.last}}{{"SIZE>" if loop.last}}
+{% endfor %}
+{% endif %}
 class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
 {
   public:
@@ -83,6 +88,11 @@ class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
         value.set(static_cast<uint32_t>({{field.variable_name}}));
         result = ::EmbeddedProto::{{field.serialization_func}}({{field.variable_id_name}}, value, buffer);
       }
+      {% elif field.is_repeated_field %}
+      if(0 != {{field.variable_name}}.get_length() && result)
+      {
+        // TODO
+      }
       {% else %}
       if(({{field.default_value}} != {{field.variable_name}}.get()) && result)
       {

+ 5 - 1
generator/protoc-gen-eams.py

@@ -138,7 +138,7 @@ class FieldTemplateParameters:
 
         self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
         if self.is_repeated_field:
-            self.repeated_type = "DynamicArraySize<" + self.type + ", " + self.variable_name + "SIZE>"
+            self.repeated_type = "::EmbeddedProto::DynamicArraySize<" + self.type + ", " + self.variable_name + "SIZE>"
 
         self.field_proto = field_proto
 
@@ -149,6 +149,10 @@ class MessageTemplateParameters:
     def __init__(self, msg_proto):
         self.name = msg_proto.name
         self.msg_proto = msg_proto
+        self.templates = []
+        for field in self.fields():
+            if field.is_repeated_field:
+                self.templates.append(field.variable_name)
 
     def fields(self):
         for f in self.msg_proto.field:

+ 16 - 2
src/DynamicArray.h

@@ -3,6 +3,7 @@
 #define _DYNAMIC_BUFFER_H_
 
 #include <cstring>
+#include <algorithm>    // std::min
 
 namespace EmbeddedProto
 {
@@ -38,6 +39,13 @@ namespace EmbeddedProto
       */
       virtual DATA_TYPE& get(uint32_t index) = 0;
 
+      //! Get a constatnt refernce to the value at the given index. 
+      /*!
+        \param[in] index The desired index to return.
+        \return The constant reference to the value at the given index.
+      */
+      virtual const DATA_TYPE& get(uint32_t index) const = 0;
+
       //! Get a refernce to the value at the given index. 
       /*!
         \param[in] index The desired index to return.
@@ -91,7 +99,8 @@ namespace EmbeddedProto
     public:
 
       DynamicArraySize()
-        : current_size_(0)
+        : current_size_(0),
+          data_{0}
       {
 
       }  
@@ -107,8 +116,13 @@ namespace EmbeddedProto
       DATA_TYPE* get_data() { return data_; }
 
       DATA_TYPE& get(uint32_t index) override { return data_[index]; }
+      const DATA_TYPE& get(uint32_t index) const override { return data_[index]; }
 
-      void set(uint32_t index, const DATA_TYPE& value) override { data_[index] = value; }
+      void set(uint32_t index, const DATA_TYPE& value) override 
+      { 
+        data_[index] = value;
+        current_size_ = std::max(index+1, current_size_); 
+      }
 
       bool set_data(const DATA_TYPE* data, const uint32_t length) override 
       {

+ 2 - 2
src/Fields.cpp

@@ -36,7 +36,7 @@ namespace EmbeddedProto
     return WireFormatter::SerializeVarint(WireFormatter::MakeTag(field_number, WireFormatter::WireType::VARINT), buffer) && serialize(x, buffer);
   }
 
-  bool serialize(uint32_t field_number, const boolean x, WriteBufferInterface& buffer) 
+  bool serialize(uint32_t field_number, const boolean& x, WriteBufferInterface& buffer) 
   { 
     return WireFormatter::SerializeVarint(WireFormatter::MakeTag(field_number, WireFormatter::WireType::VARINT), buffer) && serialize(x, buffer);
   }
@@ -103,7 +103,7 @@ namespace EmbeddedProto
     return WireFormatter::SerializeVarint(WireFormatter::ZigZagEncode(x.get()), buffer);
   }
 
-  bool serialize(const boolean x, WriteBufferInterface& buffer)
+  bool serialize(const boolean& x, WriteBufferInterface& buffer)
   {
     return buffer.push(x.get() ? 0x01 : 0x00);
   }

+ 39 - 16
src/Fields.h

@@ -35,24 +35,47 @@ namespace EmbeddedProto
       const TYPE& get() const { return value_; }
       TYPE& get() { return value_; }
 
+      operator TYPE() const { return value_; }
+
+      bool operator==(const TYPE& rhs) { return value_ == rhs; }
+      bool operator!=(const TYPE& rhs) { return value_ != rhs; }
+      bool operator>(const TYPE& rhs) { return value_ > rhs; }
+      bool operator<(const TYPE& rhs) { return value_ < rhs; }
+      bool operator>=(const TYPE& rhs) { return value_ >= rhs; }
+      bool operator<=(const TYPE& rhs) { return value_ <= rhs; }
+
+      template<class TYPE_RHS>
+      bool operator==(const FieldTemplate<TYPE_RHS>& rhs) { return value_ == rhs.get(); }
+      template<class TYPE_RHS>
+      bool operator!=(const FieldTemplate<TYPE_RHS>& rhs) { return value_ != rhs.get(); }
+      template<class TYPE_RHS>
+      bool operator>(const FieldTemplate<TYPE_RHS>& rhs) { return value_ > rhs.get(); }
+      template<class TYPE_RHS>
+      bool operator<(const FieldTemplate<TYPE_RHS>& rhs) { return value_ < rhs.get(); }
+      template<class TYPE_RHS>
+      bool operator>=(const FieldTemplate<TYPE_RHS>& rhs) { return value_ >= rhs.get(); }
+      template<class TYPE_RHS>
+      bool operator<=(const FieldTemplate<TYPE_RHS>& rhs) { return value_ <= rhs.get(); }
+
     private:
+
       TYPE value_;
   };
 
-  class int32 : public FieldTemplate<int32_t> { public: int32() : FieldTemplate<int32_t>(0) {}; };     
-  class int64 : public FieldTemplate<int64_t> { public: int64() : FieldTemplate<int64_t>(0) {}; };
-  class uint32 : public FieldTemplate<uint32_t> { public: uint32() : FieldTemplate<uint32_t>(0) {}; };
-  class uint64 : public FieldTemplate<uint64_t> { public: uint64() : FieldTemplate<uint64_t>(0) {}; };
-  class sint32 : public FieldTemplate<int32_t> { public: sint32() : FieldTemplate<int32_t>(0) {}; };
-  class sint64 : public FieldTemplate<int64_t> { public: sint64() : FieldTemplate<int64_t>(0) {}; };
-  class boolean : public FieldTemplate<bool> { public: boolean() : FieldTemplate<bool>(false) {}; };
-  // TODO enum
-  class fixed32 : public FieldTemplate<uint32_t> { public: fixed32() : FieldTemplate<uint32_t>(0) {}; };
-  class fixed64 : public FieldTemplate<uint64_t> { public: fixed64() : FieldTemplate<uint64_t>(0) {}; };
-  class sfixed32 : public FieldTemplate<int32_t> { public: sfixed32() : FieldTemplate<int32_t>(0) {}; };
-  class sfixed64 : public FieldTemplate<int64_t> { public: sfixed64() : FieldTemplate<int64_t>(0) {}; };
-  class floatfixed : public FieldTemplate<float> { public: floatfixed() : FieldTemplate<float>(0.0) {}; };
-  class doublefixed : public FieldTemplate<double> { public: doublefixed() : FieldTemplate<double>(0.0) {}; };
+
+  class int32 : public FieldTemplate<int32_t> { public: int32() : FieldTemplate<int32_t>(0) {}; int32(const int32_t& v) : FieldTemplate<int32_t>(v) {}; int32(const int32_t&& v) : FieldTemplate<int32_t>(v) {}; }; 
+  class int64 : public FieldTemplate<int64_t> { public: int64() : FieldTemplate<int64_t>(0) {}; int64(const int64_t& v) : FieldTemplate<int64_t>(v) {}; int64(const int64_t&& v) : FieldTemplate<int64_t>(v) {}; };
+  class uint32 : public FieldTemplate<uint32_t> { public: uint32() : FieldTemplate<uint32_t>(0) {}; uint32(const uint32_t& v) : FieldTemplate<uint32_t>(v) {}; uint32(const uint32_t&& v) : FieldTemplate<uint32_t>(v) {}; };
+  class uint64 : public FieldTemplate<uint64_t> { public: uint64() : FieldTemplate<uint64_t>(0) {}; uint64(const uint64_t& v) : FieldTemplate<uint64_t>(v) {}; uint64(const uint64_t&& v) : FieldTemplate<uint64_t>(v) {}; };
+  class sint32 : public FieldTemplate<int32_t> { public: sint32() : FieldTemplate<int32_t>(0) {}; sint32(const int32_t& v) : FieldTemplate<int32_t>(v) {}; sint32(const int32_t&& v) : FieldTemplate<int32_t>(v) {}; };
+  class sint64 : public FieldTemplate<int64_t> { public: sint64() : FieldTemplate<int64_t>(0) {}; sint64(const int64_t& v) : FieldTemplate<int64_t>(v) {}; sint64(const int64_t&& v) : FieldTemplate<int64_t>(v) {};};
+  class boolean : public FieldTemplate<bool> { public: boolean() : FieldTemplate<bool>(false) {}; boolean(const bool& v) : FieldTemplate<bool>(v) {}; boolean(const boolean&& v) : FieldTemplate<bool>(v) {}; };
+  class fixed32 : public FieldTemplate<uint32_t> { public: fixed32() : FieldTemplate<uint32_t>(0) {}; fixed32(const uint32_t& v) : FieldTemplate<uint32_t>(v) {}; fixed32(const uint32_t&& v) : FieldTemplate<uint32_t>(v) {}; };
+  class fixed64 : public FieldTemplate<uint64_t> { public: fixed64() : FieldTemplate<uint64_t>(0) {}; fixed64(const uint64_t& v) : FieldTemplate<uint64_t>(v) {}; fixed64(const uint64_t&& v) : FieldTemplate<uint64_t>(v) {}; };
+  class sfixed32 : public FieldTemplate<int32_t> { public: sfixed32() : FieldTemplate<int32_t>(0) {}; sfixed32(const int32_t& v) : FieldTemplate<int32_t>(v) {}; sfixed32(const int32_t&& v) : FieldTemplate<int32_t>(v) {}; };
+  class sfixed64 : public FieldTemplate<int64_t> { public: sfixed64() : FieldTemplate<int64_t>(0) {}; sfixed64(const int64_t& v) : FieldTemplate<int64_t>(v) {}; sfixed64(const int64_t&& v) : FieldTemplate<int64_t>(v) {}; };
+  class floatfixed : public FieldTemplate<float> { public: floatfixed() : FieldTemplate<float>(0.0F) {}; floatfixed(const float& v) : FieldTemplate<float>(v) {}; floatfixed(const float&& v) : FieldTemplate<float>(v) {}; };
+  class doublefixed : public FieldTemplate<double> { public: doublefixed() : FieldTemplate<double>(0.0) {}; doublefixed(const double& v) : FieldTemplate<double>(v) {}; doublefixed(const double&& v) : FieldTemplate<double>(v) {}; };
 
   bool serialize(uint32_t field_number, const int32& x, WriteBufferInterface& buffer);
   bool serialize(uint32_t field_number, const int64& x, WriteBufferInterface& buffer);
@@ -60,7 +83,7 @@ namespace EmbeddedProto
   bool serialize(uint32_t field_number, const uint64& x, WriteBufferInterface& buffer);
   bool serialize(uint32_t field_number, const sint32& x, WriteBufferInterface& buffer);
   bool serialize(uint32_t field_number, const sint64& x, WriteBufferInterface& buffer);
-  bool serialize(uint32_t field_number, const boolean x, WriteBufferInterface& buffer);
+  bool serialize(uint32_t field_number, const boolean& x, WriteBufferInterface& buffer);
   bool serialize(uint32_t field_number, const fixed32& x, WriteBufferInterface& buffer);
   bool serialize(uint32_t field_number, const fixed64& x, WriteBufferInterface& buffer);
   bool serialize(uint32_t field_number, const sfixed32& x, WriteBufferInterface& buffer); 
@@ -74,7 +97,7 @@ namespace EmbeddedProto
   bool serialize(const uint64& x, WriteBufferInterface& buffer);
   bool serialize(const sint32& x, WriteBufferInterface& buffer);
   bool serialize(const sint64& x, WriteBufferInterface& buffer);
-  bool serialize(const boolean x, WriteBufferInterface& buffer);
+  bool serialize(const boolean& x, WriteBufferInterface& buffer);
   bool serialize(const fixed32& x, WriteBufferInterface& buffer);
   bool serialize(const fixed64& x, WriteBufferInterface& buffer);
   bool serialize(const sfixed32& x, WriteBufferInterface& buffer); 

+ 1 - 1
test/proto/repeated_fields.proto

@@ -2,7 +2,7 @@
 
 syntax = "proto3";
 
-message message_a 
+message repeated_fields 
 {
   uint32 x            = 1;
   repeated uint32 y   = 2;

+ 79 - 0
test/test_DynamicArray.cpp

@@ -0,0 +1,79 @@
+#include <gtest/gtest.h>
+
+#include <DynamicArray.h>
+
+namespace test_EmbeddedAMS_DynamicArray
+{
+
+TEST(DynamicArray, construction) 
+{
+  static constexpr uint32_t SIZE = 3;
+  EmbeddedProto::DynamicArraySize<uint8_t, SIZE> x;
+}
+
+TEST(DynamicArray, size_uint8_t) 
+{
+  static constexpr uint32_t SIZE = 3;
+  EmbeddedProto::DynamicArraySize<uint8_t, SIZE> x;
+
+  EXPECT_EQ(0, x.get_size());
+  EXPECT_EQ(SIZE, x.get_max_size());
+  EXPECT_EQ(0, x.get_length());
+  EXPECT_EQ(SIZE, x.get_max_length());
+
+  x.add(1);
+  x.add(2);
+  EXPECT_EQ(2, x.get_size());
+  EXPECT_EQ(2, x.get_length());
+
+  x.add(3);
+
+  EXPECT_EQ(SIZE, x.get_size());
+  EXPECT_EQ(SIZE, x.get_max_size());
+  EXPECT_EQ(SIZE, x.get_length());
+  EXPECT_EQ(SIZE, x.get_max_length());
+}
+
+TEST(DynamicArray, size_uint32_t) 
+{  
+  static constexpr uint32_t SIZE = 3;
+  EmbeddedProto::DynamicArraySize<uint32_t, SIZE> x;
+
+  EXPECT_EQ(0, x.get_size());
+  EXPECT_EQ(SIZE*4, x.get_max_size());
+  EXPECT_EQ(0, x.get_length());
+  EXPECT_EQ(SIZE, x.get_max_length());
+
+  x.add(1);
+  x.add(2);
+  EXPECT_EQ(2*4, x.get_size());
+  EXPECT_EQ(2, x.get_length());
+
+  x.add(3);
+
+  EXPECT_EQ(SIZE*4, x.get_size());
+  EXPECT_EQ(SIZE*4, x.get_max_size());
+  EXPECT_EQ(SIZE, x.get_length());
+  EXPECT_EQ(SIZE, x.get_max_length());
+}
+
+TEST(DynamicArray, set) 
+{
+  static constexpr uint32_t SIZE = 3;
+  EmbeddedProto::DynamicArraySize<uint8_t, SIZE> x;
+
+  // First add a value in the middle and see if we have a size of two.
+  x.set(1, 2);
+  EXPECT_EQ(2, x.get(1));
+  EXPECT_EQ(2, x.get_length());
+
+  x.set(0, 1);
+  EXPECT_EQ(1, x.get(0));
+
+  x.set(2, 3);
+  EXPECT_EQ(3, x.get(2));
+
+
+}
+
+} // End namespace test_EmbeddedAMS_DynamicArray

+ 0 - 1
test/test_NestedMessage.cpp

@@ -1,5 +1,4 @@
 
-
 #include "gtest/gtest.h"
 
 #include <WireFormatter.h>

+ 24 - 0
test/test_RepeatedFields.cpp

@@ -0,0 +1,24 @@
+
+#include "gtest/gtest.h"
+
+#include <WireFormatter.h>
+#include <ReadBufferMock.h>
+#include <WriteBufferMock.h>
+
+#include <cstdint>    
+#include <limits> 
+
+// EAMS message definitions
+#include <repeated_fields.h>
+
+namespace test_EmbeddedAMS_RepeatedFields
+{
+
+TEST(NestedMessage, construction) 
+{
+  static constexpr uint32_t Y_SIZE = 3;
+  repeated_fields<Y_SIZE> a;
+
+}
+
+} // End of namespace test_EmbeddedAMS_RepeatedFields

+ 54 - 0
test/test_getters_setters_fields.cpp

@@ -0,0 +1,54 @@
+
+#include <gtest/gtest.h>
+
+#include <Fields.h>
+
+namespace test_EmbeddedAMS_Getters_Setters_Fields
+{
+
+
+TEST(getters_setters_fields, construction) 
+{
+  EmbeddedProto::int32 a;
+  EmbeddedProto::int32 b(1);
+  EmbeddedProto::int32 c = 1;
+  int32_t dd = 1;
+  EmbeddedProto::int32 d(dd);
+  EmbeddedProto::int32 e = dd;
+
+  EXPECT_EQ(0, a);
+  EXPECT_EQ(1, b);
+  EXPECT_EQ(1, c);
+  EXPECT_EQ(1, d);
+  EXPECT_EQ(1, e);
+}
+
+TEST(getters_setters_fields, comparison) 
+{
+  EmbeddedProto::int32 a(1);
+  EmbeddedProto::uint32 b(1);
+  EmbeddedProto::floatfixed c(0.5F);
+
+  EXPECT_TRUE(a == 1);
+  EXPECT_TRUE(a != 0);
+  EXPECT_TRUE(a > 0);
+  EXPECT_TRUE(a < 2);
+  EXPECT_TRUE(a >= 0);
+  EXPECT_TRUE(a >= 1);
+  EXPECT_FALSE(a >= 2);
+  EXPECT_TRUE(a <= 1);  
+  EXPECT_TRUE(a <= 2);
+  EXPECT_FALSE(a <= 0); 
+
+
+  EXPECT_TRUE(a == b);
+  EXPECT_FALSE(a != b);
+
+  EXPECT_TRUE(a > c);
+  EXPECT_TRUE(c < b);
+
+  EXPECT_TRUE(a >= b);
+  EXPECT_FALSE(a <= c);
+}
+
+} // End of namespace test_EmbeddedAMS_Getters_Setters_Fields

+ 20 - 1
test_data.py

@@ -1,5 +1,6 @@
 import build.python.simple_types_pb2 as st
 import build.python.nested_message_pb2 as nm
+import build.python.repeated_fields_pb2 as rf
 
 def test_simple_types():
     # A test function used to generate encoded data to test the implementation of the wireformatter
@@ -93,4 +94,22 @@ def test_nested_message():
     print(msg2)
 
 
-test_nested_message()
+def test_repeated_fields():
+    msg = rf.repeated_fields()
+
+    y = msg.y.append(1)
+    y = msg.y.append(1)
+    y = msg.y.append(255)
+
+    str = ""
+    msg_str = msg.SerializeToString()
+    print(len(msg_str))
+    print(msg_str)
+    for x in msg_str:
+      str += "0x{:02x}, ".format(x)
+
+    print(str)
+    print()
+
+
+test_repeated_fields()