Эх сурвалжийг харах

Git work in progress on the template parameters pass on.

Bart Hertog 6 жил өмнө
parent
commit
f3943de457

+ 71 - 16
generator/protoc-gen-eams.py

@@ -85,6 +85,8 @@ class FieldTemplateParameters:
         self.variable_id_name = self.name.upper()
         self.variable_id = field_proto.number
 
+        self.templates = []
+
         if oneof:
             # When set this field is part of a oneof.
             self.which_oneof = "which_" + oneof + "_"
@@ -92,26 +94,48 @@ class FieldTemplateParameters:
         else:
             self.variable_full_name = self.variable_name
 
-
         self.of_type_message = FieldDescriptorProto.TYPE_MESSAGE == field_proto.type
         self.wire_type = self.type_to_wire_type[field_proto.type]
+        self.type = None
+
+        self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
+        self.repeated_type = None
+        self.repeated_size_template = None
+
+        self.field_proto = field_proto
+
+        self.of_type_enum = False
+        self.default_value = None
+
+    def set_templates(self, templates):
+
+        self.templates = templates
+
+        if FieldDescriptorProto.TYPE_MESSAGE == self.field_proto.type or \
+                FieldDescriptorProto.TYPE_ENUM == self.field_proto.type:
+            self.type = self.field_proto.type_name if "." != self.field_proto.type_name[0] else \
+                self.field_proto.type_name[1:]
+
+            if self.templates:
+                self.type += "<"
+                for tmpl in self.templates[:-1]:
+                    self.type += tmpl + ", "
+                self.type += self.templates[-1] + ">"
 
-        if FieldDescriptorProto.TYPE_MESSAGE == field_proto.type or FieldDescriptorProto.TYPE_ENUM == field_proto.type:
-            self.type = field_proto.type_name if "." != field_proto.type_name[0] else field_proto.type_name[1:]
         else:
-            self.type = self.type_to_cpp_type[field_proto.type]
+            self.type = self.type_to_cpp_type[self.field_proto.type]
 
-        self.of_type_enum = FieldDescriptorProto.TYPE_ENUM == field_proto.type
+        self.of_type_enum = FieldDescriptorProto.TYPE_ENUM == self.field_proto.type
         if self.of_type_enum:
             self.default_value = "static_cast<" + self.type + ">(0)"
         else:
-            self.default_value = self.type_to_default_value[field_proto.type]
+            self.default_value = self.type_to_default_value[self.field_proto.type]
 
-        self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
         if self.is_repeated_field:
-            self.repeated_type = "::EmbeddedProto::RepeatedFieldSize<" + self.type + ", " + self.variable_name + "SIZE>"
+            self.repeated_size_template = self.variable_name + "SIZE"
+            self.repeated_type = "::EmbeddedProto::RepeatedFieldSize<" + self.type + ", " + \
+                                 self.repeated_size_template + ">"
 
-        self.field_proto = field_proto
 
 # -----------------------------------------------------------------------------
 
@@ -123,6 +147,8 @@ class OneofTemplateParameters:
         self.index = index
         self.msg_proto = msg_proto
 
+        self.templates = []
+
         self.fields_array = []
         # Loop over all the fields in this oneof
         for f in self.msg_proto.field:
@@ -133,7 +159,6 @@ class OneofTemplateParameters:
         for f in self.fields_array:
             yield f
 
-
 # -----------------------------------------------------------------------------
 
 
@@ -160,14 +185,10 @@ class MessageTemplateParameters:
 
         for field in self.fields_array:
             self.field_ids.append((field.variable_id, field.variable_id_name))
-            if field.is_repeated_field:
-                self.templates.append(field.variable_name)
 
         for oneof in self.oneof_fields:
             for field in oneof.fields():
                 self.field_ids.append((field.variable_id, field.variable_id_name))
-                if field.is_repeated_field:
-                    self.templates.append(field.variable_name)
 
         # Sort the field id's such they will appear in order in the id enum.
         self.field_ids.sort()
@@ -185,6 +206,34 @@ class MessageTemplateParameters:
         for enum in self.msg_proto.enum_type:
             yield EnumTemplateParameters(enum)
 
+    def set_templates(self, messages):
+        for field in self.fields_array:
+            if field.is_repeated_field:
+                self.templates.append(field.variable_name)
+
+            # Loop over messages and find relevant templates.
+            for msg in messages:
+                if (field.name == msg.name) and msg.templates:
+                    templates = [msg.name.capitalize() + "_" + tmpl for tmpl in msg.templates]
+                    field.extend(templates)
+                    self.templates.extend(templates)
+
+            field.set_templates(self.templates)
+
+        for oneof in self.oneof_fields:
+            for field in self.oneof_fields:
+                if field.is_repeated_field:
+                    self.templates.append(field.variable_name)
+
+                # Loop over messages and find relevant templates.
+                for msg in messages:
+                    if (oneof.name == msg.name) and msg.templates:
+                        templates = [msg.name.capitalize() + "_" + tmpl for tmpl in msg.templates]
+                        field.templates.extend(templates)
+                        self.templates.extend(templates)
+
+                field.set_templates(self.templates)
+
 
 def generate_messages(message_types):
     for msg in message_types:
@@ -207,13 +256,19 @@ def generate_code(request, respones):
         if "proto2" == proto_file.syntax:
             raise Exception(proto_file.name + ": Sorry, proto2 is not supported, please use proto3.")
 
-        messages_generator = generate_messages(proto_file.message_type)
+        message_array = []
+        for msg in proto_file.message_type:
+            msg = MessageTemplateParameters(msg)
+            msg.set_templates(message_array)
+            message_array.append(msg)
+
+        #messages_generator = generate_messages(proto_file.message_type)
         enums_generator = generate_enums(proto_file.enum_type)
 
         filename_str = os.path.splitext(proto_file.name)[0]
 
         try:
-            file_str = template.render(filename=filename_str, namespace=proto_file.package, messages=messages_generator,
+            file_str = template.render(filename=filename_str, namespace=proto_file.package, messages=message_array,
                                        enums=enums_generator)
         except jinja2.TemplateError as e:
             print("TemplateError exception: " + str(e))