|
|
@@ -85,6 +85,8 @@ class FieldTemplateParameters:
|
|
|
self.variable_id_name = self.name.upper()
|
|
|
self.variable_id = field_proto.number
|
|
|
|
|
|
+ self.templates = []
|
|
|
+
|
|
|
if oneof:
|
|
|
# When set this field is part of a oneof.
|
|
|
self.which_oneof = "which_" + oneof + "_"
|
|
|
@@ -92,26 +94,48 @@ class FieldTemplateParameters:
|
|
|
else:
|
|
|
self.variable_full_name = self.variable_name
|
|
|
|
|
|
-
|
|
|
self.of_type_message = FieldDescriptorProto.TYPE_MESSAGE == field_proto.type
|
|
|
self.wire_type = self.type_to_wire_type[field_proto.type]
|
|
|
+ self.type = None
|
|
|
+
|
|
|
+ self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
|
|
|
+ self.repeated_type = None
|
|
|
+ self.repeated_size_template = None
|
|
|
+
|
|
|
+ self.field_proto = field_proto
|
|
|
+
|
|
|
+ self.of_type_enum = False
|
|
|
+ self.default_value = None
|
|
|
+
|
|
|
+ def set_templates(self, templates):
|
|
|
+
|
|
|
+ self.templates = templates
|
|
|
+
|
|
|
+ if FieldDescriptorProto.TYPE_MESSAGE == self.field_proto.type or \
|
|
|
+ FieldDescriptorProto.TYPE_ENUM == self.field_proto.type:
|
|
|
+ self.type = self.field_proto.type_name if "." != self.field_proto.type_name[0] else \
|
|
|
+ self.field_proto.type_name[1:]
|
|
|
+
|
|
|
+ if self.templates:
|
|
|
+ self.type += "<"
|
|
|
+ for tmpl in self.templates[:-1]:
|
|
|
+ self.type += tmpl + ", "
|
|
|
+ self.type += self.templates[-1] + ">"
|
|
|
|
|
|
- if FieldDescriptorProto.TYPE_MESSAGE == field_proto.type or FieldDescriptorProto.TYPE_ENUM == field_proto.type:
|
|
|
- self.type = field_proto.type_name if "." != field_proto.type_name[0] else field_proto.type_name[1:]
|
|
|
else:
|
|
|
- self.type = self.type_to_cpp_type[field_proto.type]
|
|
|
+ self.type = self.type_to_cpp_type[self.field_proto.type]
|
|
|
|
|
|
- self.of_type_enum = FieldDescriptorProto.TYPE_ENUM == field_proto.type
|
|
|
+ self.of_type_enum = FieldDescriptorProto.TYPE_ENUM == self.field_proto.type
|
|
|
if self.of_type_enum:
|
|
|
self.default_value = "static_cast<" + self.type + ">(0)"
|
|
|
else:
|
|
|
- self.default_value = self.type_to_default_value[field_proto.type]
|
|
|
+ self.default_value = self.type_to_default_value[self.field_proto.type]
|
|
|
|
|
|
- self.is_repeated_field = field_proto.label == FieldDescriptorProto.LABEL_REPEATED
|
|
|
if self.is_repeated_field:
|
|
|
- self.repeated_type = "::EmbeddedProto::RepeatedFieldSize<" + self.type + ", " + self.variable_name + "SIZE>"
|
|
|
+ self.repeated_size_template = self.variable_name + "SIZE"
|
|
|
+ self.repeated_type = "::EmbeddedProto::RepeatedFieldSize<" + self.type + ", " + \
|
|
|
+ self.repeated_size_template + ">"
|
|
|
|
|
|
- self.field_proto = field_proto
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
@@ -123,6 +147,8 @@ class OneofTemplateParameters:
|
|
|
self.index = index
|
|
|
self.msg_proto = msg_proto
|
|
|
|
|
|
+ self.templates = []
|
|
|
+
|
|
|
self.fields_array = []
|
|
|
# Loop over all the fields in this oneof
|
|
|
for f in self.msg_proto.field:
|
|
|
@@ -133,7 +159,6 @@ class OneofTemplateParameters:
|
|
|
for f in self.fields_array:
|
|
|
yield f
|
|
|
|
|
|
-
|
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
@@ -160,14 +185,10 @@ class MessageTemplateParameters:
|
|
|
|
|
|
for field in self.fields_array:
|
|
|
self.field_ids.append((field.variable_id, field.variable_id_name))
|
|
|
- if field.is_repeated_field:
|
|
|
- self.templates.append(field.variable_name)
|
|
|
|
|
|
for oneof in self.oneof_fields:
|
|
|
for field in oneof.fields():
|
|
|
self.field_ids.append((field.variable_id, field.variable_id_name))
|
|
|
- if field.is_repeated_field:
|
|
|
- self.templates.append(field.variable_name)
|
|
|
|
|
|
# Sort the field id's such they will appear in order in the id enum.
|
|
|
self.field_ids.sort()
|
|
|
@@ -185,6 +206,34 @@ class MessageTemplateParameters:
|
|
|
for enum in self.msg_proto.enum_type:
|
|
|
yield EnumTemplateParameters(enum)
|
|
|
|
|
|
+ def set_templates(self, messages):
|
|
|
+ for field in self.fields_array:
|
|
|
+ if field.is_repeated_field:
|
|
|
+ self.templates.append(field.variable_name)
|
|
|
+
|
|
|
+ # Loop over messages and find relevant templates.
|
|
|
+ for msg in messages:
|
|
|
+ if (field.name == msg.name) and msg.templates:
|
|
|
+ templates = [msg.name.capitalize() + "_" + tmpl for tmpl in msg.templates]
|
|
|
+ field.extend(templates)
|
|
|
+ self.templates.extend(templates)
|
|
|
+
|
|
|
+ field.set_templates(self.templates)
|
|
|
+
|
|
|
+ for oneof in self.oneof_fields:
|
|
|
+ for field in self.oneof_fields:
|
|
|
+ if field.is_repeated_field:
|
|
|
+ self.templates.append(field.variable_name)
|
|
|
+
|
|
|
+ # Loop over messages and find relevant templates.
|
|
|
+ for msg in messages:
|
|
|
+ if (oneof.name == msg.name) and msg.templates:
|
|
|
+ templates = [msg.name.capitalize() + "_" + tmpl for tmpl in msg.templates]
|
|
|
+ field.templates.extend(templates)
|
|
|
+ self.templates.extend(templates)
|
|
|
+
|
|
|
+ field.set_templates(self.templates)
|
|
|
+
|
|
|
|
|
|
def generate_messages(message_types):
|
|
|
for msg in message_types:
|
|
|
@@ -207,13 +256,19 @@ def generate_code(request, respones):
|
|
|
if "proto2" == proto_file.syntax:
|
|
|
raise Exception(proto_file.name + ": Sorry, proto2 is not supported, please use proto3.")
|
|
|
|
|
|
- messages_generator = generate_messages(proto_file.message_type)
|
|
|
+ message_array = []
|
|
|
+ for msg in proto_file.message_type:
|
|
|
+ msg = MessageTemplateParameters(msg)
|
|
|
+ msg.set_templates(message_array)
|
|
|
+ message_array.append(msg)
|
|
|
+
|
|
|
+ #messages_generator = generate_messages(proto_file.message_type)
|
|
|
enums_generator = generate_enums(proto_file.enum_type)
|
|
|
|
|
|
filename_str = os.path.splitext(proto_file.name)[0]
|
|
|
|
|
|
try:
|
|
|
- file_str = template.render(filename=filename_str, namespace=proto_file.package, messages=messages_generator,
|
|
|
+ file_str = template.render(filename=filename_str, namespace=proto_file.package, messages=message_array,
|
|
|
enums=enums_generator)
|
|
|
except jinja2.TemplateError as e:
|
|
|
print("TemplateError exception: " + str(e))
|