|
|
@@ -2,6 +2,7 @@ import io
|
|
|
import sys
|
|
|
import os
|
|
|
import jinja2
|
|
|
+from copy import deepcopy
|
|
|
|
|
|
from google.protobuf.compiler import plugin_pb2 as plugin
|
|
|
from google.protobuf.descriptor_pb2 import DescriptorProto, FieldDescriptorProto, EnumDescriptorProto
|
|
|
@@ -92,7 +93,6 @@ class FieldTemplateParameters:
|
|
|
else:
|
|
|
self.variable_full_name = self.variable_name
|
|
|
|
|
|
-
|
|
|
self.of_type_message = FieldDescriptorProto.TYPE_MESSAGE == field_proto.type
|
|
|
self.wire_type = self.type_to_wire_type[field_proto.type]
|
|
|
|
|
|
@@ -102,16 +102,41 @@ class FieldTemplateParameters:
|
|
|
self.type = self.type_to_cpp_type[field_proto.type]
|
|
|
|
|
|
self.of_type_enum = FieldDescriptorProto.TYPE_ENUM == field_proto.type
|
|
|
+ self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
|
|
|
+
|
|
|
+ self.default_value = None
|
|
|
+ self.repeated_type = None
|
|
|
+ self.templates = []
|
|
|
+
|
|
|
+ self.field_proto = field_proto
|
|
|
+
|
|
|
+ def update_templates(self, messages):
|
|
|
+ if self.of_type_message:
|
|
|
+ for msg in messages:
|
|
|
+ if msg.name == self.type:
|
|
|
+ msg_templates = deepcopy(msg.templates)
|
|
|
+ for tmpl in msg_templates:
|
|
|
+ tmpl["name"] = self.variable_name + tmpl["name"]
|
|
|
+ self.templates.extend(msg_templates)
|
|
|
+
|
|
|
+ if self.templates:
|
|
|
+ self.type += "<"
|
|
|
+ for tmpl in self.templates[:-1]:
|
|
|
+ self.type += tmpl["name"] + ", "
|
|
|
+ self.type += self.templates[-1]["name"] + ">"
|
|
|
+
|
|
|
+ break
|
|
|
+
|
|
|
if self.of_type_enum:
|
|
|
self.default_value = "static_cast<" + self.type + ">(0)"
|
|
|
else:
|
|
|
- self.default_value = self.type_to_default_value[field_proto.type]
|
|
|
+ self.default_value = self.type_to_default_value[self.field_proto.type]
|
|
|
|
|
|
- self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
|
|
|
if self.is_repeated_field:
|
|
|
- self.repeated_type = "::EmbeddedProto::RepeatedFieldSize<" + self.type + ", " + self.variable_name + "SIZE>"
|
|
|
+ self.repeated_type = "::EmbeddedProto::RepeatedFieldSize<" + self.type + ", " + self.variable_name \
|
|
|
+ + "SIZE>"
|
|
|
+ self.templates.append({"type": "uint32_t", "name": self.variable_name + "SIZE"})
|
|
|
|
|
|
- self.field_proto = field_proto
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
@@ -123,58 +148,78 @@ class OneofTemplateParameters:
|
|
|
self.index = index
|
|
|
self.msg_proto = msg_proto
|
|
|
|
|
|
- def fields(self):
|
|
|
- # Yield all the fields in this oneof
|
|
|
+ self.fields_array = []
|
|
|
+ # Loop over all the fields in this oneof
|
|
|
for f in self.msg_proto.field:
|
|
|
if f.HasField('oneof_index') and self.index == f.oneof_index:
|
|
|
- yield FieldTemplateParameters(f, self.name)
|
|
|
+ self.fields_array.append(FieldTemplateParameters(f, self.name))
|
|
|
+
|
|
|
+ def fields(self):
|
|
|
+ for f in self.fields_array:
|
|
|
+ yield f
|
|
|
|
|
|
+ def update_templates(self, messages):
|
|
|
+ for f in self.fields_array:
|
|
|
+ f.update_templates(messages)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
+
|
|
|
class MessageTemplateParameters:
|
|
|
def __init__(self, msg_proto):
|
|
|
self.name = msg_proto.name
|
|
|
self.msg_proto = msg_proto
|
|
|
self.has_fields = len(self.msg_proto.field) > 0
|
|
|
self.has_oneofs = len(self.msg_proto.oneof_decl) > 0
|
|
|
+
|
|
|
+ self.fields_array = []
|
|
|
+ # Loop over only the normal fields in this message.
|
|
|
+ for f in self.msg_proto.field:
|
|
|
+ if not f.HasField('oneof_index'):
|
|
|
+ self.fields_array.append(FieldTemplateParameters(f))
|
|
|
+
|
|
|
+ self.oneof_fields = []
|
|
|
+ # Loop over all the oneofs in this message.
|
|
|
+ for index, oneof in enumerate(self.msg_proto.oneof_decl):
|
|
|
+ self.oneof_fields.append(OneofTemplateParameters(oneof.name, index, self.msg_proto))
|
|
|
+
|
|
|
self.templates = []
|
|
|
self.field_ids = []
|
|
|
|
|
|
- for field in self.fields():
|
|
|
+ for field in self.fields_array:
|
|
|
self.field_ids.append((field.variable_id, field.variable_id_name))
|
|
|
- if field.is_repeated_field:
|
|
|
- self.templates.append(field.variable_name)
|
|
|
|
|
|
- for oneof in self.oneofs():
|
|
|
+ for oneof in self.oneof_fields:
|
|
|
for field in oneof.fields():
|
|
|
self.field_ids.append((field.variable_id, field.variable_id_name))
|
|
|
- if field.is_repeated_field:
|
|
|
- self.templates.append(field.variable_name)
|
|
|
|
|
|
# Sort the field id's such they will appear in order in the id enum.
|
|
|
self.field_ids.sort()
|
|
|
|
|
|
def fields(self):
|
|
|
- # Yield only the normal fields in this message.
|
|
|
- for f in self.msg_proto.field:
|
|
|
- if not f.HasField('oneof_index'):
|
|
|
- yield FieldTemplateParameters(f)
|
|
|
+ for f in self.fields_array:
|
|
|
+ yield f
|
|
|
|
|
|
def oneofs(self):
|
|
|
- # Yield all the oneofs in this message.
|
|
|
- for index, oneof in enumerate(self.msg_proto.oneof_decl):
|
|
|
- yield OneofTemplateParameters(oneof.name, index, self.msg_proto)
|
|
|
+ for o in self.oneof_fields:
|
|
|
+ yield o
|
|
|
|
|
|
def nested_enums(self):
|
|
|
# Yield all the enumerations defined in the scope of this message.
|
|
|
for enum in self.msg_proto.enum_type:
|
|
|
yield EnumTemplateParameters(enum)
|
|
|
|
|
|
+ def update_templates(self, messages):
|
|
|
+ for field in self.fields_array:
|
|
|
+ field.update_templates(messages)
|
|
|
|
|
|
-def generate_messages(message_types):
|
|
|
- for msg in message_types:
|
|
|
- yield MessageTemplateParameters(msg)
|
|
|
+ self.templates.extend(field.templates)
|
|
|
+
|
|
|
+ for oneof in self.oneof_fields:
|
|
|
+ for field in oneof.fields():
|
|
|
+ field.update_templates(messages)
|
|
|
+
|
|
|
+ self.templates.extend(field.templates)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
@@ -187,19 +232,25 @@ def generate_code(request, respones):
|
|
|
template_file = "Header_Template.h"
|
|
|
template = template_env.get_template(template_file)
|
|
|
|
|
|
+ messages_array = []
|
|
|
+
|
|
|
# Loop over all proto files in the request
|
|
|
for proto_file in request.proto_file:
|
|
|
|
|
|
if "proto2" == proto_file.syntax:
|
|
|
raise Exception(proto_file.name + ": Sorry, proto2 is not supported, please use proto3.")
|
|
|
|
|
|
- messages_generator = generate_messages(proto_file.message_type)
|
|
|
+ for msg_type in proto_file.message_type:
|
|
|
+ msg = MessageTemplateParameters(msg_type)
|
|
|
+ msg.update_templates(messages_array)
|
|
|
+ messages_array.append(msg)
|
|
|
+
|
|
|
enums_generator = generate_enums(proto_file.enum_type)
|
|
|
|
|
|
filename_str = os.path.splitext(proto_file.name)[0]
|
|
|
|
|
|
try:
|
|
|
- file_str = template.render(filename=filename_str, namespace=proto_file.package, messages=messages_generator,
|
|
|
+ file_str = template.render(filename=filename_str, namespace=proto_file.package, messages=messages_array,
|
|
|
enums=enums_generator)
|
|
|
except jinja2.TemplateError as e:
|
|
|
print("TemplateError exception: " + str(e))
|