Преглед на файлове

Merged feature/PROTO-30-pass-templates-on-for-nested-me into develop

Bart Hertog преди 6 години
родител
ревизия
e05829e300
променени са 5 файла, в които са добавени 110 реда и са изтрити 54 реда
  1. 1 1
      generator/Header_Template.h
  2. 77 26
      generator/protoc-gen-eams.py
  3. 8 3
      test/proto/repeated_fields.proto
  4. 0 0
      test/test_EmbeddedProto.cpp
  5. 24 24
      test/test_RepeatedFieldMessage.cpp

+ 1 - 1
generator/Header_Template.h

@@ -214,7 +214,7 @@ else
 {% macro msg_macro(msg) %}
 {% if msg.templates is defined %}
 {% for template in msg.templates %}
-{{"template<" if loop.first}}uint32_t {{template}}{{"SIZE, " if not loop.last}}{{"SIZE>" if loop.last}}
+{{"template<" if loop.first}}{{template['type']}} {{template['name']}}{{", " if not loop.last}}{{">" if loop.last}}
 {% endfor %}
 {% endif %}
 class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface

+ 77 - 26
generator/protoc-gen-eams.py

@@ -2,6 +2,7 @@ import io
 import sys
 import os
 import jinja2
+from copy import deepcopy
 
 from google.protobuf.compiler import plugin_pb2 as plugin
 from google.protobuf.descriptor_pb2 import DescriptorProto, FieldDescriptorProto, EnumDescriptorProto
@@ -92,7 +93,6 @@ class FieldTemplateParameters:
         else:
             self.variable_full_name = self.variable_name
 
-
         self.of_type_message = FieldDescriptorProto.TYPE_MESSAGE == field_proto.type
         self.wire_type = self.type_to_wire_type[field_proto.type]
 
@@ -102,16 +102,41 @@ class FieldTemplateParameters:
             self.type = self.type_to_cpp_type[field_proto.type]
 
         self.of_type_enum = FieldDescriptorProto.TYPE_ENUM == field_proto.type
+        self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
+
+        self.default_value = None
+        self.repeated_type = None
+        self.templates = []
+
+        self.field_proto = field_proto
+
+    def update_templates(self, messages):
+        if self.of_type_message:
+            for msg in messages:
+                if msg.name == self.type:
+                    msg_templates = deepcopy(msg.templates)
+                    for tmpl in msg_templates:
+                        tmpl["name"] = self.variable_name + tmpl["name"]
+                    self.templates.extend(msg_templates)
+
+                    if self.templates:
+                        self.type += "<"
+                        for tmpl in self.templates[:-1]:
+                            self.type += tmpl["name"] + ", "
+                        self.type += self.templates[-1]["name"] + ">"
+
+                    break
+
         if self.of_type_enum:
             self.default_value = "static_cast<" + self.type + ">(0)"
         else:
-            self.default_value = self.type_to_default_value[field_proto.type]
+            self.default_value = self.type_to_default_value[self.field_proto.type]
 
-        self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
         if self.is_repeated_field:
-            self.repeated_type = "::EmbeddedProto::RepeatedFieldSize<" + self.type + ", " + self.variable_name + "SIZE>"
+            self.repeated_type = "::EmbeddedProto::RepeatedFieldSize<" + self.type + ", " + self.variable_name \
+                                 + "SIZE>"
+            self.templates.append({"type": "uint32_t", "name": self.variable_name + "SIZE"})
 
-        self.field_proto = field_proto
 
 # -----------------------------------------------------------------------------
 
@@ -123,58 +148,78 @@ class OneofTemplateParameters:
         self.index = index
         self.msg_proto = msg_proto
 
-    def fields(self):
-        # Yield all the fields in this oneof
+        self.fields_array = []
+        # Loop over all the fields in this oneof
         for f in self.msg_proto.field:
             if f.HasField('oneof_index') and self.index == f.oneof_index:
-                yield FieldTemplateParameters(f, self.name)
+                self.fields_array.append(FieldTemplateParameters(f, self.name))
+
+    def fields(self):
+        for f in self.fields_array:
+            yield f
 
+    def update_templates(self, messages):
+        for f in self.fields_array:
+            f.update_templates(messages)
 
 # -----------------------------------------------------------------------------
 
+
 class MessageTemplateParameters:
     def __init__(self, msg_proto):
         self.name = msg_proto.name
         self.msg_proto = msg_proto
         self.has_fields = len(self.msg_proto.field) > 0
         self.has_oneofs = len(self.msg_proto.oneof_decl) > 0
+
+        self.fields_array = []
+        # Loop over only the normal fields in this message.
+        for f in self.msg_proto.field:
+            if not f.HasField('oneof_index'):
+                self.fields_array.append(FieldTemplateParameters(f))
+
+        self.oneof_fields = []
+        # Loop over all the oneofs in this message.
+        for index, oneof in enumerate(self.msg_proto.oneof_decl):
+            self.oneof_fields.append(OneofTemplateParameters(oneof.name, index, self.msg_proto))
+
         self.templates = []
         self.field_ids = []
 
-        for field in self.fields():
+        for field in self.fields_array:
             self.field_ids.append((field.variable_id, field.variable_id_name))
-            if field.is_repeated_field:
-                self.templates.append(field.variable_name)
 
-        for oneof in self.oneofs():
+        for oneof in self.oneof_fields:
             for field in oneof.fields():
                 self.field_ids.append((field.variable_id, field.variable_id_name))
-                if field.is_repeated_field:
-                    self.templates.append(field.variable_name)
 
         # Sort the field id's such they will appear in order in the id enum.
         self.field_ids.sort()
 
     def fields(self):
-        # Yield only the normal fields in this message.
-        for f in self.msg_proto.field:
-            if not f.HasField('oneof_index'):
-                yield FieldTemplateParameters(f)
+        for f in self.fields_array:
+            yield f
 
     def oneofs(self):
-        # Yield all the oneofs in this message.
-        for index, oneof in enumerate(self.msg_proto.oneof_decl):
-            yield OneofTemplateParameters(oneof.name, index, self.msg_proto)
+        for o in self.oneof_fields:
+            yield o
 
     def nested_enums(self):
         # Yield all the enumerations defined in the scope of this message.
         for enum in self.msg_proto.enum_type:
             yield EnumTemplateParameters(enum)
 
+    def update_templates(self, messages):
+        for field in self.fields_array:
+            field.update_templates(messages)
 
-def generate_messages(message_types):
-    for msg in message_types:
-        yield MessageTemplateParameters(msg)
+            self.templates.extend(field.templates)
+
+        for oneof in self.oneof_fields:
+            for field in oneof.fields():
+                field.update_templates(messages)
+
+                self.templates.extend(field.templates)
 
 # -----------------------------------------------------------------------------
 
@@ -187,19 +232,25 @@ def generate_code(request, respones):
     template_file = "Header_Template.h"
     template = template_env.get_template(template_file)
 
+    messages_array = []
+
     # Loop over all proto files in the request
     for proto_file in request.proto_file:
 
         if "proto2" == proto_file.syntax:
             raise Exception(proto_file.name + ": Sorry, proto2 is not supported, please use proto3.")
 
-        messages_generator = generate_messages(proto_file.message_type)
+        for msg_type in proto_file.message_type:
+            msg = MessageTemplateParameters(msg_type)
+            msg.update_templates(messages_array)
+            messages_array.append(msg)
+
         enums_generator = generate_enums(proto_file.enum_type)
 
         filename_str = os.path.splitext(proto_file.name)[0]
 
         try:
-            file_str = template.render(filename=filename_str, namespace=proto_file.package, messages=messages_generator,
+            file_str = template.render(filename=filename_str, namespace=proto_file.package, messages=messages_array,
                                        enums=enums_generator)
         except jinja2.TemplateError as e:
             print("TemplateError exception: " + str(e))

+ 8 - 3
test/proto/repeated_fields.proto

@@ -17,7 +17,12 @@ message repeated_nested_message
 
 message repeated_message
 {
-  uint32 x                            = 1;
-  repeated repeated_nested_message y  = 2;
-  uint32 z                            = 3;
+  uint32 a                            = 1;
+  repeated repeated_nested_message b  = 2;
+  uint32 c                            = 3;
+}
+
+message nested_repeated_message
+{
+    repeated_message rm = 1;
 }

+ 0 - 0
test/test_main.cpp → test/test_EmbeddedProto.cpp


+ 24 - 24
test/test_RepeatedFieldMessage.cpp

@@ -89,9 +89,9 @@ TEST(RepeatedFieldMessage, serialize_array_zero_messages)
   rnm.set_u(0);
   rnm.set_v(0);
 
-  msg.add_y(rnm);
-  msg.add_y(rnm);
-  msg.add_y(rnm);
+  msg.add_b(rnm);
+  msg.add_b(rnm);
+  msg.add_b(rnm);
 
   EXPECT_CALL(buffer, get_available_size()).Times(1).WillOnce(Return(6));
 
@@ -140,15 +140,15 @@ TEST(RepeatedFieldMessage, serialize_array_zero_one_zero_messages)
   
   rnm.set_u(0);
   rnm.set_v(0);
-  msg.add_y(rnm);
+  msg.add_b(rnm);
 
   rnm.set_u(1);
   rnm.set_v(1);
-  msg.add_y(rnm);
+  msg.add_b(rnm);
   
   rnm.set_u(0);
   rnm.set_v(0);
-  msg.add_y(rnm);
+  msg.add_b(rnm);
 
   EXPECT_CALL(buffer, get_available_size()).Times(1).WillOnce(Return(10));
 
@@ -349,15 +349,15 @@ TEST(RepeatedFieldMessage, deserialize_one_message_array)
 
   EXPECT_TRUE(msg.deserialize(buffer));
 
-  EXPECT_EQ(1, msg.get_x());
-  EXPECT_EQ(3, msg.get_y().get_length());
-  EXPECT_EQ(0, msg.y(0).u());
-  EXPECT_EQ(0, msg.y(0).v());
-  EXPECT_EQ(1, msg.y(1).u());
-  EXPECT_EQ(1, msg.y(1).v());
-  EXPECT_EQ(0, msg.y(2).u());
-  EXPECT_EQ(0, msg.y(2).v());
-  EXPECT_EQ(1, msg.get_z());
+  EXPECT_EQ(1, msg.get_a());
+  EXPECT_EQ(3, msg.get_b().get_length());
+  EXPECT_EQ(0, msg.b(0).u());
+  EXPECT_EQ(0, msg.b(0).v());
+  EXPECT_EQ(1, msg.b(1).u());
+  EXPECT_EQ(1, msg.b(1).v());
+  EXPECT_EQ(0, msg.b(2).u());
+  EXPECT_EQ(0, msg.b(2).v());
+  EXPECT_EQ(1, msg.get_c());
 }
 
 TEST(RepeatedFieldMessage, deserialize_mixed_message_array) 
@@ -384,15 +384,15 @@ TEST(RepeatedFieldMessage, deserialize_mixed_message_array)
 
   EXPECT_TRUE(msg.deserialize(buffer));
 
-  EXPECT_EQ(1, msg.get_x());
-  EXPECT_EQ(3, msg.get_y().get_length());
-  EXPECT_EQ(0, msg.y(0).u());
-  EXPECT_EQ(0, msg.y(0).v());
-  EXPECT_EQ(1, msg.y(1).u());
-  EXPECT_EQ(1, msg.y(1).v());
-  EXPECT_EQ(0, msg.y(2).u());
-  EXPECT_EQ(0, msg.y(2).v());
-  EXPECT_EQ(1, msg.get_z());
+  EXPECT_EQ(1, msg.get_a());
+  EXPECT_EQ(3, msg.get_b().get_length());
+  EXPECT_EQ(0, msg.b(0).u());
+  EXPECT_EQ(0, msg.b(0).v());
+  EXPECT_EQ(1, msg.b(1).u());
+  EXPECT_EQ(1, msg.b(1).v());
+  EXPECT_EQ(0, msg.b(2).u());
+  EXPECT_EQ(0, msg.b(2).v());
+  EXPECT_EQ(1, msg.get_c());
 }
 
 TEST(RepeatedFieldMessage, deserialize_max)