Просмотр исходного кода

Rewrote the ID's into an enum class to avoid the initialization of static const members in c++14.

Bart Hertog 6 лет назад
Родитель
Сommit
b4acb6f9b3
3 измененных файлов с 57 добавлено и 40 удалено
  1. 39 32
      generator/Header_Template.h
  2. 11 1
      generator/protoc-gen-eams.py
  3. 7 7
      test/test_oneof_fields.cpp

+ 39 - 32
generator/Header_Template.h

@@ -11,36 +11,35 @@ enum {{ _enum.name }}
 {# ------------------------------------------------------------------------------------------------------------------ #}
 {# #}
 {% macro field_get_set_macro(_field) %}
-static constexpr uint32_t {{_field.variable_id_name}} = {{_field.variable_id}};
 {% if _field.is_repeated_field %}
 inline const {{_field.type}}& {{_field.name}}(uint32_t index) const { return {{_field.variable_name}}[index]; }
 {% if _field.which_oneof is defined %}
 inline void clear_{{_field.name}}()
 {
-  if({{_field.variable_id_name}} == {{_field.which_oneof}})
+  if(id::{{_field.variable_id_name}} == {{_field.which_oneof}})
   {
-    {{_field.which_oneof}} = 0;
+    {{_field.which_oneof}} = id::NOT_SET;
     {{_field.variable_name}}.clear();
   }
 }
 inline void set_{{_field.name}}(uint32_t index, const {{_field.type}}& value)
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   {{_field.variable_name}}.set(index, value);
 }
 inline void set_{{_field.name}}(uint32_t index, const {{_field.type}}&& value)
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   {{_field.variable_name}}.set(index, value);
 }
 inline void add_{{_field.name}}(const {{_field.type}}& value)
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   {{_field.variable_name}}.add(value);
 }
 inline {{_field.repeated_type}}& mutable_{{_field.name}}()
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   return {{_field.variable_name}};
 }
 {% else %}
@@ -56,25 +55,25 @@ inline const {{_field.type}}& {{_field.name}}() const { return {{_field.variable
 {% if _field.which_oneof is defined %}
 inline void clear_{{_field.name}}()
 {
-  if({{_field.variable_id_name}} == {{_field.which_oneof}})
+  if(id::{{_field.variable_id_name}} == {{_field.which_oneof}})
   {
-    {{_field.which_oneof}} = 0;
+    {{_field.which_oneof}} = id::NOT_SET;
     {{_field.variable_name}}.clear();
   }
 }
 inline void set_{{_field.name}}(const {{_field.type}}& value)
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   {{_field.variable_name}} = value;
 }
 inline void set_{{_field.name}}(const {{_field.type}}&& value)
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   {{_field.variable_name}} = value;
 }
 inline {{_field.type}}& mutable_{{_field.name}}()
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   return {{_field.variable_name}};
 }
 {% else %}
@@ -89,20 +88,20 @@ inline {{_field.type}} {{_field.name}}() const { return {{_field.variable_name}}
 {% if _field.which_oneof is defined %}
 inline void clear_{{_field.name}}()
 {
-  if({{_field.variable_id_name}} == {{_field.which_oneof}})
+  if(id::{{_field.variable_id_name}} == {{_field.which_oneof}})
   {
-    {{_field.which_oneof}} = 0;
+    {{_field.which_oneof}} = id::NOT_SET;
     {{_field.variable_name}} = static_cast<{{_field.type}}>({{_field.default_value}});
   }
 }
 inline void set_{{_field.name}}(const {{_field.type}}& value)
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   {{_field.variable_name}} = value;
 }
 inline void set_{{_field.name}}(const {{_field.type}}&& value)
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   {{_field.variable_name}} = value;
 }
 {% else %}
@@ -115,20 +114,20 @@ inline {{_field.type}}::FIELD_TYPE {{_field.name}}() const { return {{_field.var
 {% if _field.which_oneof is defined %}
 inline void clear_{{_field.name}}()
 {
-  if({{_field.variable_id_name}} == {{_field.which_oneof}})
+  if(id::{{_field.variable_id_name}} == {{_field.which_oneof}})
   {
-    {{_field.which_oneof}} = 0;
+    {{_field.which_oneof}} = id::NOT_SET;
     {{_field.variable_name}}.set({{_field.default_value}});
   }
 }
 inline void set_{{_field.name}}(const {{_field.type}}::FIELD_TYPE& value)
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   {{_field.variable_name}}.set(value);
 }
 inline void set_{{_field.name}}(const {{_field.type}}::FIELD_TYPE&& value)
 {
-  {{_field.which_oneof}} = {{_field.variable_id}};
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   {{_field.variable_name}}.set(value);
 }
 {% else %}
@@ -146,25 +145,25 @@ inline {{_field.type}}::FIELD_TYPE get_{{_field.name}}() const { return {{_field
 {% if _field.is_repeated_field %}
 if(result)
 {
-  result = {{_field.variable_name}}.serialize({{_field.variable_id_name}}, buffer);
+  result = {{_field.variable_name}}.serialize(static_cast<uint32_t>(id::{{_field.variable_id_name}}), buffer);
 }
 {% elif _field.of_type_message %}
 if(result)
 {
   const ::EmbeddedProto::MessageInterface* x = &{{_field.variable_name}};
-  result = x->serialize({{_field.variable_id_name}}, buffer);
+  result = x->serialize(static_cast<uint32_t>(id::{{_field.variable_id_name}}), buffer);
 }
 {% elif _field.of_type_enum %}
 if(({{_field.default_value}} != {{_field.variable_name}}) && result)
 {
   EmbeddedProto::uint32 value;
   value.set(static_cast<uint32_t>({{_field.variable_name}}));
-  result = value.serialize({{_field.variable_id_name}}, buffer);
+  result = value.serialize(static_cast<uint32_t>(id::{{_field.variable_id_name}}), buffer);
 }
 {% else %}
 if(({{_field.default_value}} != {{_field.variable_name}}.get()) && result)
 {
-  result = {{_field.variable_name}}.serialize({{_field.variable_id_name}}, buffer);
+  result = {{_field.variable_name}}.serialize(static_cast<uint32_t>(id::{{_field.variable_id_name}}), buffer);
 } {% endif %} {% endmacro %}
 {# #}
 {# ------------------------------------------------------------------------------------------------------------------ #}
@@ -190,7 +189,7 @@ if(::EmbeddedProto::WireFormatter::WireType::{{_field.wire_type}} == wire_type)
   {
     {{_field.variable_name}} = static_cast<{{_field.type}}>(value);
     {% if _field.which_oneof is defined %}
-    {{_field.which_oneof}} = {{_field.variable_id}};
+    {{_field.which_oneof}} = id::{{_field.variable_id_name}};
     {% endif %}
   }
   {% else %}
@@ -198,7 +197,7 @@ if(::EmbeddedProto::WireFormatter::WireType::{{_field.wire_type}} == wire_type)
   {% if _field.which_oneof is defined %}
   if(result)
   {
-    {{_field.which_oneof}} = {{_field.variable_id}};
+    {{_field.which_oneof}} = id::{{_field.variable_id_name}};
   }
   {% endif %}
   {% endif %}
@@ -240,11 +239,19 @@ class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
     {{ enum_macro(enum) }}
 
     {% endfor %}
+    enum class id
+    {
+      NOT_SET = 0,
+      {% for id_set in msg.field_ids %}
+      {{id_set[1]}} = {{id_set[0]}}{{ "," if not loop.last }}
+      {% endfor %}
+    };
+
     {% for field in msg.fields() %}
     {{ field_get_set_macro(field)|indent(4) }}
     {% endfor %}
     {% for oneof in msg.oneofs() %}
-    uint32_t get_which_{{oneof.name}}() const { return {{oneof.which_oneof}}; }
+    id get_which_{{oneof.name}}() const { return {{oneof.which_oneof}}; }
 
     {% for field in oneof.fields() %}
     {{ field_get_set_macro(field)|indent(4) }}
@@ -259,12 +266,12 @@ class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
 
       {% endfor %}
       {% for oneof in msg.oneofs() %}
-      if((0 != {{oneof.which_oneof}}) && result)
+      if(result)
       {
         switch({{oneof.which_oneof}})
         {
           {% for field in oneof.fields() %}
-          case {{field.variable_id_name}}:
+          case id::{{field.variable_id_name}}:
             {{ field_serialize_macro(field)|indent(12) }}
             break;
 
@@ -289,7 +296,7 @@ class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
         switch(id_number)
         {
           {% for field in msg.fields() %}
-          case {{field.variable_id_name}}:
+          case static_cast<uint32_t>(id::{{field.variable_id_name}}):
           {
             {{ field_deserialize_macro(field)|indent(12) }}
             break;
@@ -298,7 +305,7 @@ class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
           {% endfor %}
           {% for oneof in msg.oneofs() %}
           {% for field in oneof.fields() %}
-          case {{field.variable_id_name}}:
+          case static_cast<uint32_t>(id::{{field.variable_id_name}}):
           {
             {{ field_deserialize_macro(field)|indent(12) }}
             break;
@@ -331,7 +338,7 @@ class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
     {% endfor %}
 
     {% for oneof in msg.oneofs() %}
-    ::EmbeddedProto::uint32 {{oneof.which_oneof}};
+    id {{oneof.which_oneof}};
     union
     {
       {% for field in oneof.fields() %}

+ 11 - 1
generator/protoc-gen-eams.py

@@ -82,7 +82,7 @@ class FieldTemplateParameters:
     def __init__(self, field_proto, which_oneof=None):
         self.name = field_proto.name
         self.variable_name = self.name + "_"
-        self.variable_id_name = self.name + "_id"
+        self.variable_id_name = self.name.upper()
         self.variable_id = field_proto.number
 
         if which_oneof:
@@ -135,12 +135,22 @@ class MessageTemplateParameters:
         self.has_fields = len(self.msg_proto.field) > 0
         self.has_oneofs = len(self.msg_proto.oneof_decl) > 0
         self.templates = []
+        self.field_ids = []
 
         #TODO this creates a bug if a oneof field is also a repeated_field.
         for field in self.fields():
+            self.field_ids.append((field.variable_id, field.variable_id_name))
             if field.is_repeated_field:
                 self.templates.append(field.variable_name)
 
+        for oneof in self.oneofs():
+            for field in oneof.fields():
+                self.field_ids.append((field.variable_id, field.variable_id_name))
+
+        # Sort the field id's such they will appear in order in the id enum.
+        self.field_ids.sort()
+
+
     def fields(self):
         # Yield only the normal fields in this message.
         for f in self.msg_proto.field:

+ 7 - 7
test/test_oneof_fields.cpp

@@ -37,22 +37,22 @@ TEST(OneofField, serialize_zero)
 TEST(OneofField, set_get_clear)
 {
   message_oneof msg;
-  EXPECT_EQ(0, msg.get_which_xyz());
+  EXPECT_EQ(message_oneof::id::NOT_SET, msg.get_which_xyz());
   msg.set_x(1);
   EXPECT_EQ(1, msg.get_x());
-  EXPECT_EQ(5, msg.get_which_xyz());
+  EXPECT_EQ(message_oneof::id::X, msg.get_which_xyz());
   msg.clear_x();
 
-  EXPECT_EQ(0, msg.get_which_xyz());
+  EXPECT_EQ(message_oneof::id::NOT_SET, msg.get_which_xyz());
   msg.set_y(1);
   EXPECT_EQ(1, msg.get_y());
-  EXPECT_EQ(6, msg.get_which_xyz());
+  EXPECT_EQ(message_oneof::id::Y, msg.get_which_xyz());
   msg.clear_y();
 
-  EXPECT_EQ(0, msg.get_which_xyz());
+  EXPECT_EQ(message_oneof::id::NOT_SET, msg.get_which_xyz());
   msg.set_z(1);
   EXPECT_EQ(1, msg.get_z());
-  EXPECT_EQ(7, msg.get_which_xyz());
+  EXPECT_EQ(message_oneof::id::Z, msg.get_which_xyz());
   msg.clear_z();
 }
 
@@ -122,7 +122,7 @@ TEST(OneofField, deserialize)
 
   EXPECT_EQ(1, msg.get_a());
   EXPECT_EQ(1, msg.get_b());
-  EXPECT_EQ(6, msg.get_which_xyz());
+  EXPECT_EQ(message_oneof::id::Y, msg.get_which_xyz());
   EXPECT_EQ(1, msg.get_y());
 }