Bart Hertog 6 лет назад
Родитель
Сommit
229a60c8aa
2 измененных файлов с 48 добавлено и 15 удалено
  1. 1 1
      generator/Header_Template.h
  2. 47 14
      generator/protoc-gen-eams.py

+ 1 - 1
generator/Header_Template.h

@@ -214,7 +214,7 @@ else
 {% macro msg_macro(msg) %}
 {% if msg.templates is defined %}
 {% for template in msg.templates %}
-{{"template<" if loop.first}}uint32_t {{template}}{{"SIZE, " if not loop.last}}{{"SIZE>" if loop.last}}
+{{"template<" if loop.first}}{{template["type"]}} {{template["name"]}}{{", " if not loop.last}}{{">" if loop.last}}
 {% endfor %}
 {% endif %}
 class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface

+ 47 - 14
generator/protoc-gen-eams.py

@@ -92,7 +92,6 @@ class FieldTemplateParameters:
         else:
             self.variable_full_name = self.variable_name
 
-
         self.of_type_message = FieldDescriptorProto.TYPE_MESSAGE == field_proto.type
         self.wire_type = self.type_to_wire_type[field_proto.type]
 
@@ -102,16 +101,38 @@ class FieldTemplateParameters:
             self.type = self.type_to_cpp_type[field_proto.type]
 
         self.of_type_enum = FieldDescriptorProto.TYPE_ENUM == field_proto.type
+        self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
+
+        self.default_value = None
+        self.repeated_type = None
+        self.templates = []
+
+        self.field_proto = field_proto
+
+    def update_templates(self, messages):
+        if self.of_type_message:
+            for msg in messages:
+                if msg.name == self.type:
+                    self.templates.extend(msg.templates)
+
+                    if self.templates:
+                        self.type += "<"
+                        for tmpl in self.templates[:-1]["name"]:
+                            self.type += tmpl + ", "
+                        self.type += self.templates[-1]["name"] + ">"
+
+                    break
+
         if self.of_type_enum:
             self.default_value = "static_cast<" + self.type + ">(0)"
         else:
-            self.default_value = self.type_to_default_value[field_proto.type]
+            self.default_value = self.type_to_default_value[self.field_proto.type]
 
-        self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
         if self.is_repeated_field:
-            self.repeated_type = "::EmbeddedProto::RepeatedFieldSize<" + self.type + ", " + self.variable_name + "SIZE>"
+            self.repeated_type = "::EmbeddedProto::RepeatedFieldSize<" + self.type + ", " +  self.variable_name \
+                                 + "SIZE>"
+            self.templates.append({"type": "uint32_t", "name": self.variable_name + "SIZE>"})
 
-        self.field_proto = field_proto
 
 # -----------------------------------------------------------------------------
 
@@ -133,6 +154,9 @@ class OneofTemplateParameters:
         for f in self.fields_array:
             yield f
 
+    def update_templates(self, messages):
+        for f in self.fields_array:
+            f.update_templates(messages)
 
 # -----------------------------------------------------------------------------
 
@@ -160,14 +184,10 @@ class MessageTemplateParameters:
 
         for field in self.fields_array:
             self.field_ids.append((field.variable_id, field.variable_id_name))
-            if field.is_repeated_field:
-                self.templates.append(field.variable_name)
 
         for oneof in self.oneof_fields:
             for field in oneof.fields():
                 self.field_ids.append((field.variable_id, field.variable_id_name))
-                if field.is_repeated_field:
-                    self.templates.append(field.variable_name)
 
         # Sort the field id's such they will appear in order in the id enum.
         self.field_ids.sort()
@@ -185,10 +205,17 @@ class MessageTemplateParameters:
         for enum in self.msg_proto.enum_type:
             yield EnumTemplateParameters(enum)
 
+    def update_templates(self, messages):
+        for field in self.fields_array:
+            field.update_templates(messages)
+
+            self.templates.extend(field.templates)
 
-def generate_messages(message_types):
-    for msg in message_types:
-        yield MessageTemplateParameters(msg)
+        for oneof in self.oneof_fields:
+            for field in oneof.fields():
+                field.update_templates(messages)
+
+                self.templates.extend(field.templates)
 
 # -----------------------------------------------------------------------------
 
@@ -201,19 +228,25 @@ def generate_code(request, respones):
     template_file = "Header_Template.h"
     template = template_env.get_template(template_file)
 
+    messages_array = []
+
     # Loop over all proto files in the request
     for proto_file in request.proto_file:
 
         if "proto2" == proto_file.syntax:
             raise Exception(proto_file.name + ": Sorry, proto2 is not supported, please use proto3.")
 
-        messages_generator = generate_messages(proto_file.message_type)
+        for msg_type in proto_file.message_type:
+            msg = MessageTemplateParameters(msg_type)
+            msg.update_templates(messages_array)
+            messages_array.append(msg)
+
         enums_generator = generate_enums(proto_file.enum_type)
 
         filename_str = os.path.splitext(proto_file.name)[0]
 
         try:
-            file_str = template.render(filename=filename_str, namespace=proto_file.package, messages=messages_generator,
+            file_str = template.render(filename=filename_str, namespace=proto_file.package, messages=messages_array,
                                        enums=enums_generator)
         except jinja2.TemplateError as e:
             print("TemplateError exception: " + str(e))