Просмотр исходного кода

Merged feature/PROTO-6-generator-for-oneofs into develop

Bart Hertog 6 лет назад
Родитель
Сommit
3f545ce90a
6 измененных файлов с 549 добавлено и 90 удалено
  1. 3 1
      build_test.sh
  2. 270 85
      generator/Header_Template.h
  3. 49 3
      generator/protoc-gen-eams.py
  4. 23 0
      test/proto/oneof_fields.proto
  5. 181 0
      test/test_oneof_fields.cpp
  6. 23 1
      test_data.py

+ 3 - 1
build_test.sh

@@ -5,16 +5,18 @@ mkdir -p ./build/EAMS
 protoc --plugin=protoc-gen-eams=protoc-gen-eams -I./test/proto --eams_out=./build/EAMS ./test/proto/simple_types.proto
 protoc --plugin=protoc-gen-eams=protoc-gen-eams -I./test/proto --eams_out=./build/EAMS ./test/proto/nested_message.proto
 protoc --plugin=protoc-gen-eams=protoc-gen-eams -I./test/proto --eams_out=./build/EAMS ./test/proto/repeated_fields.proto
+protoc --plugin=protoc-gen-eams=protoc-gen-eams -I./test/proto --eams_out=./build/EAMS ./test/proto/oneof_fields.proto
 
 # For validation and testing generate the same message using python
 mkdir -p ./build/python
 protoc -I./test/proto --python_out=./build/python ./test/proto/simple_types.proto
 protoc -I./test/proto --python_out=./build/python ./test/proto/nested_message.proto
 protoc -I./test/proto --python_out=./build/python ./test/proto/repeated_fields.proto
+protoc -I./test/proto --python_out=./build/python ./test/proto/oneof_fields.proto
 
 
 # Build the tests
 mkdir -p build/test
 cd build/test/
 cmake -DCMAKE_BUILD_TYPE=Debug ../../
-make
+make -j16

+ 270 - 85
generator/Header_Template.h

@@ -7,7 +7,210 @@ enum {{ _enum.name }}
 };
 
 {% endmacro %}
-
+{# #}
+{# ------------------------------------------------------------------------------------------------------------------ #}
+{# #}
+{% macro field_get_set_macro(_field) %}
+{% if _field.is_repeated_field %}
+inline const {{_field.type}}& {{_field.name}}(uint32_t index) const { return {{_field.variable_full_name}}[index]; }
+{% if _field.which_oneof is defined %}
+inline void clear_{{_field.name}}()
+{
+  if(id::{{_field.variable_id_name}} == {{_field.which_oneof}})
+  {
+    {{_field.which_oneof}} = id::NOT_SET;
+    {{_field.variable_full_name}}.clear();
+  }
+}
+inline void set_{{_field.name}}(uint32_t index, const {{_field.type}}& value)
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  {{_field.variable_full_name}}.set(index, value);
+}
+inline void set_{{_field.name}}(uint32_t index, const {{_field.type}}&& value)
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  {{_field.variable_full_name}}.set(index, value);
+}
+inline void add_{{_field.name}}(const {{_field.type}}& value)
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  {{_field.variable_full_name}}.add(value);
+}
+inline {{_field.repeated_type}}& mutable_{{_field.name}}()
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  return {{_field.variable_full_name}};
+}
+{% else %}
+inline void clear_{{_field.name}}() { {{_field.variable_full_name}}.clear(); }
+inline void set_{{_field.name}}(uint32_t index, const {{_field.type}}& value) { {{_field.variable_full_name}}.set(index, value); }
+inline void set_{{_field.name}}(uint32_t index, const {{_field.type}}&& value) { {{_field.variable_full_name}}.set(index, value); }
+inline void add_{{_field.name}}(const {{_field.type}}& value) { {{_field.variable_full_name}}.add(value); }
+inline {{_field.repeated_type}}& mutable_{{_field.name}}() { return {{_field.variable_full_name}}; }
+{% endif %}
+inline const {{_field.repeated_type}}& get_{{_field.name}}() const { return {{_field.variable_full_name}}; }
+{% elif _field.of_type_message %}
+inline const {{_field.type}}& {{_field.name}}() const { return {{_field.variable_full_name}}; }
+{% if _field.which_oneof is defined %}
+inline void clear_{{_field.name}}()
+{
+  if(id::{{_field.variable_id_name}} == {{_field.which_oneof}})
+  {
+    {{_field.which_oneof}} = id::NOT_SET;
+    {{_field.variable_full_name}}.clear();
+  }
+}
+inline void set_{{_field.name}}(const {{_field.type}}& value)
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  {{_field.variable_full_name}} = value;
+}
+inline void set_{{_field.name}}(const {{_field.type}}&& value)
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  {{_field.variable_full_name}} = value;
+}
+inline {{_field.type}}& mutable_{{_field.name}}()
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  return {{_field.variable_full_name}};
+}
+{% else %}
+inline void clear_{{_field.name}}() { {{_field.variable_full_name}}.clear(); }
+inline void set_{{_field.name}}(const {{_field.type}}& value) { {{_field.variable_full_name}} = value; }
+inline void set_{{_field.name}}(const {{_field.type}}&& value) { {{_field.variable_full_name}} = value; }
+inline {{_field.type}}& mutable_{{_field.name}}() { return {{_field.variable_full_name}}; }
+{% endif %}
+inline const {{_field.type}}& get_{{_field.name}}() const { return {{_field.variable_full_name}}; }
+{% elif _field.of_type_enum %}
+inline {{_field.type}} {{_field.name}}() const { return {{_field.variable_full_name}}; }
+{% if _field.which_oneof is defined %}
+inline void clear_{{_field.name}}()
+{
+  if(id::{{_field.variable_id_name}} == {{_field.which_oneof}})
+  {
+    {{_field.which_oneof}} = id::NOT_SET;
+    {{_field.variable_full_name}} = static_cast<{{_field.type}}>({{_field.default_value}});
+  }
+}
+inline void set_{{_field.name}}(const {{_field.type}}& value)
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  {{_field.variable_full_name}} = value;
+}
+inline void set_{{_field.name}}(const {{_field.type}}&& value)
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  {{_field.variable_full_name}} = value;
+}
+{% else %}
+inline void clear_{{_field.name}}() { {{_field.variable_full_name}} = static_cast<{{_field.type}}>({{_field.default_value}}); }
+inline void set_{{_field.name}}(const {{_field.type}}& value) { {{_field.variable_full_name}} = value; }
+inline void set_{{_field.name}}(const {{_field.type}}&& value) { {{_field.variable_full_name}} = value; }
+{% endif %}    inline {{_field.type}} get_{{_field.name}}() const { return {{_field.variable_full_name}}; }
+{% else %}
+inline {{_field.type}}::FIELD_TYPE {{_field.name}}() const { return {{_field.variable_full_name}}.get(); }
+{% if _field.which_oneof is defined %}
+inline void clear_{{_field.name}}()
+{
+  if(id::{{_field.variable_id_name}} == {{_field.which_oneof}})
+  {
+    {{_field.which_oneof}} = id::NOT_SET;
+    {{_field.variable_full_name}}.set({{_field.default_value}});
+  }
+}
+inline void set_{{_field.name}}(const {{_field.type}}::FIELD_TYPE& value)
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  {{_field.variable_full_name}}.set(value);
+}
+inline void set_{{_field.name}}(const {{_field.type}}::FIELD_TYPE&& value)
+{
+  {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  {{_field.variable_full_name}}.set(value);
+}
+{% else %}
+inline void clear_{{_field.name}}() { {{_field.variable_full_name}}.set({{_field.default_value}}); }
+inline void set_{{_field.name}}(const {{_field.type}}::FIELD_TYPE& value) { {{_field.variable_full_name}}.set(value); }
+inline void set_{{_field.name}}(const {{_field.type}}::FIELD_TYPE&& value) { {{_field.variable_full_name}}.set(value); }
+{% endif %}
+inline {{_field.type}}::FIELD_TYPE get_{{_field.name}}() const { return {{_field.variable_full_name}}.get(); }
+{% endif %}
+{% endmacro %}
+{# #}
+{# ------------------------------------------------------------------------------------------------------------------ #}
+{# #}
+{% macro field_serialize_macro(_field) %}
+{% if _field.is_repeated_field %}
+if(result)
+{
+  result = {{_field.variable_full_name}}.serialize(static_cast<uint32_t>(id::{{_field.variable_id_name}}), buffer);
+}
+{% elif _field.of_type_message %}
+if(result)
+{
+  const ::EmbeddedProto::MessageInterface* x = &{{_field.variable_full_name}};
+  result = x->serialize(static_cast<uint32_t>(id::{{_field.variable_id_name}}), buffer);
+}
+{% elif _field.of_type_enum %}
+if(({{_field.default_value}} != {{_field.variable_full_name}}) && result)
+{
+  EmbeddedProto::uint32 value;
+  value.set(static_cast<uint32_t>({{_field.variable_full_name}}));
+  result = value.serialize(static_cast<uint32_t>(id::{{_field.variable_id_name}}), buffer);
+}
+{% else %}
+if(({{_field.default_value}} != {{_field.variable_full_name}}.get()) && result)
+{
+  result = {{_field.variable_full_name}}.serialize(static_cast<uint32_t>(id::{{_field.variable_id_name}}), buffer);
+} {% endif %} {% endmacro %}
+{# #}
+{# ------------------------------------------------------------------------------------------------------------------ #}
+{# #}
+{% macro field_deserialize_macro(_field) %}
+{% if _field.is_repeated_field %}
+if(::EmbeddedProto::WireFormatter::WireType::LENGTH_DELIMITED == wire_type)
+{
+  result = {{_field.variable_full_name}}.deserialize(buffer);
+}
+{% else %}
+if(::EmbeddedProto::WireFormatter::WireType::{{_field.wire_type}} == wire_type)
+{
+  {% if _field.of_type_message %}
+  uint32_t size;
+  result = ::EmbeddedProto::WireFormatter::DeserializeVarint(buffer, size);
+  ::EmbeddedProto::ReadBufferSection bufferSection(buffer, size);
+  result = result && {{_field.variable_full_name}}.deserialize(bufferSection);
+  {% elif _field.of_type_enum %}
+  uint32_t value;
+  result = ::EmbeddedProto::WireFormatter::DeserializeVarint(buffer, value);
+  if(result)
+  {
+    {{_field.variable_full_name}} = static_cast<{{_field.type}}>(value);
+    {% if _field.which_oneof is defined %}
+    {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+    {% endif %}
+  }
+  {% else %}
+  result = {{_field.variable_full_name}}.deserialize(buffer);
+  {% if _field.which_oneof is defined %}
+  if(result)
+  {
+    {{_field.which_oneof}} = id::{{_field.variable_id_name}};
+  }
+  {% endif %}
+  {% endif %}
+}
+{% endif %}
+else
+{
+  // TODO Error wire type does not match field.
+  result = false;
+} {% endmacro %}
+{# #}
+{# ------------------------------------------------------------------------------------------------------------------ #}
+{# #}
 {% macro msg_macro(msg) %}
 {% if msg.templates is defined %}
 {% for template in msg.templates %}
@@ -20,11 +223,14 @@ class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
     {{ msg.name }}() :
     {% for field in msg.fields() %}
         {% if field.of_type_enum %}
-        {{field.variable_name}}({{field.default_value}}){{"," if not loop.last}}
+        {{field.variable_full_name}}({{field.default_value}}){{"," if not loop.last}}
         {% else %}
-        {{field.variable_name}}(){{"," if not loop.last}}
+        {{field.variable_full_name}}(){{"," if not loop.last}}{{"," if loop.last and msg.has_oneofs}}
         {% endif %}
     {% endfor %}
+    {% for oneof in msg.oneofs() %}
+        {{oneof.which_oneof}}(id::NOT_SET){{"," if not loop.last}}
+    {% endfor %}
     {
 
     };
@@ -34,67 +240,44 @@ class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
     {{ enum_macro(enum) }}
 
     {% endfor %}
+    enum class id
+    {
+      NOT_SET = 0,
+      {% for id_set in msg.field_ids %}
+      {{id_set[1]}} = {{id_set[0]}}{{ "," if not loop.last }}
+      {% endfor %}
+    };
+
     {% for field in msg.fields() %}
-    static const uint32_t {{field.variable_id_name}} = {{field.variable_id}};
-    {% if field.is_repeated_field %}
-    inline const {{field.type}}& {{field.name}}(uint32_t index) const { return {{field.variable_name}}[index]; }
-    inline void clear_{{field.name}}() { {{field.variable_name}}.clear(); }
-    inline void set_{{field.name}}(uint32_t index, const {{field.type}}& value) { {{field.variable_name}}.set(index, value); }
-    inline void set_{{field.name}}(uint32_t index, const {{field.type}}&& value) { {{field.variable_name}}.set(index, value); }
-    inline void add_{{field.name}}(const {{field.type}}& value) { {{field.variable_name}}.add(value); }
-    inline const {{field.repeated_type}}& get_{{field.name}}() const { return {{field.variable_name}}; }
-    inline {{field.repeated_type}}& mutable_{{field.name}}() { return {{field.variable_name}}; }
-    {% elif field.of_type_message %}
-    inline const {{field.type}}& {{field.name}}() const { return {{field.variable_name}}; }
-    inline void clear_{{field.name}}() { {{field.variable_name}}.clear(); }
-    inline void set_{{field.name}}(const {{field.type}}& value) { {{field.variable_name}} = value; }
-    inline void set_{{field.name}}(const {{field.type}}&& value) { {{field.variable_name}} = value; }
-    inline const {{field.type}}& get_{{field.name}}() const { return {{field.variable_name}}; }
-    inline {{field.type}}& mutable_{{field.name}}() { return {{field.variable_name}}; }
-    {% elif field.of_type_enum %}
-    inline {{field.type}} {{field.name}}() const { return {{field.variable_name}}; }
-    inline void clear_{{field.name}}() { {{field.variable_name}} = static_cast<{{field.type}}>({{field.default_value}}); }
-    inline void set_{{field.name}}(const {{field.type}}& value) { {{field.variable_name}} = value; }
-    inline void set_{{field.name}}(const {{field.type}}&& value) { {{field.variable_name}} = value; }
-    inline {{field.type}} get_{{field.name}}() const { return {{field.variable_name}}; }
-    {% else %}
-    inline {{field.type}}::FIELD_TYPE {{field.name}}() const { return {{field.variable_name}}.get(); }
-    inline void clear_{{field.name}}() { {{field.variable_name}}.set({{field.default_value}}); }
-    inline void set_{{field.name}}(const {{field.type}}::FIELD_TYPE& value) { {{field.variable_name}}.set(value); }
-    inline void set_{{field.name}}(const {{field.type}}::FIELD_TYPE&& value) { {{field.variable_name}}.set(value); }
-    inline {{field.type}}::FIELD_TYPE get_{{field.name}}() const { return {{field.variable_name}}.get(); }
-    {% endif %}
+    {{ field_get_set_macro(field)|indent(4) }}
+    {% endfor %}
+    {% for oneof in msg.oneofs() %}
+    id get_which_{{oneof.name}}() const { return {{oneof.which_oneof}}; }
 
+    {% for field in oneof.fields() %}
+    {{ field_get_set_macro(field)|indent(4) }}
+    {% endfor %}
     {% endfor %}
     bool serialize(::EmbeddedProto::WriteBufferInterface& buffer) const final
     {
       bool result = true;
 
       {% for field in msg.fields() %}
-      {% if field.is_repeated_field %}
-      if(result)
-      {
-        result = {{field.variable_name}}.serialize({{field.variable_id_name}}, buffer);
-      }
-      {% elif field.of_type_message %}
-      if(result)
-      {
-        const ::EmbeddedProto::MessageInterface* x = &{{field.variable_name}};
-        result = x->serialize({{field.variable_id_name}}, buffer);
-      }
-      {% elif field.of_type_enum %}
-      if(({{field.default_value}} != {{field.variable_name}}) && result)
-      {
-        EmbeddedProto::uint32 value;
-        value.set(static_cast<uint32_t>({{field.variable_name}}));
-        result = value.serialize({{field.variable_id_name}}, buffer);
-      }
-      {% else %}
-      if(({{field.default_value}} != {{field.variable_name}}.get()) && result)
+      {{ field_serialize_macro(field)|indent(6) }}
+
+      {% endfor %}
+      {% for oneof in msg.oneofs() %}
+      switch({{oneof.which_oneof}})
       {
-        result = {{field.variable_name}}.serialize({{field.variable_id_name}}, buffer);
+        {% for field in oneof.fields() %}
+        case id::{{field.variable_id_name}}:
+          {{ field_serialize_macro(field)|indent(12) }}
+          break;
+
+        {% endfor %}
+        default:
+          break;
       }
-      {% endif %}
 
       {% endfor %}
       return result;
@@ -111,41 +294,22 @@ class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
         switch(id_number)
         {
           {% for field in msg.fields() %}
-          case {{field.variable_id_name}}:
+          case static_cast<uint32_t>(id::{{field.variable_id_name}}):
+          {
+            {{ field_deserialize_macro(field)|indent(12) }}
+            break;
+          }
+
+          {% endfor %}
+          {% for oneof in msg.oneofs() %}
+          {% for field in oneof.fields() %}
+          case static_cast<uint32_t>(id::{{field.variable_id_name}}):
           {
-            {% if field.is_repeated_field %}
-            if(::EmbeddedProto::WireFormatter::WireType::LENGTH_DELIMITED == wire_type)
-            {
-              result = {{field.variable_name}}.deserialize(buffer);
-            }
-            {% else %}
-            if(::EmbeddedProto::WireFormatter::WireType::{{field.wire_type}} == wire_type)
-            {
-              {% if field.of_type_message %}
-              uint32_t size;
-              result = ::EmbeddedProto::WireFormatter::DeserializeVarint(buffer, size);
-              ::EmbeddedProto::ReadBufferSection bufferSection(buffer, size);
-              result = result && {{field.variable_name}}.deserialize(bufferSection);
-              {% elif field.of_type_enum %}
-              uint32_t value;
-              result = ::EmbeddedProto::WireFormatter::DeserializeVarint(buffer, value);
-              if(result)
-              {
-                {{field.variable_name}} = static_cast<{{field.type}}>(value);
-              }
-              {% else %}
-              result = {{field.variable_name}}.deserialize(buffer);
-              {% endif %}
-            }
-            {% endif %}
-            else
-            {
-              // TODO Error wire type does not match field.
-              result = false;
-            }
+            {{ field_deserialize_macro(field)|indent(12) }}
             break;
           }
 
+          {% endfor %}
           {% endfor %}
           default:
             break;
@@ -170,8 +334,29 @@ class {{ msg.name }} final: public ::EmbeddedProto::MessageInterface
     {{field.type}} {{field.variable_name}};
     {% endif %}
     {% endfor %}
+
+    {% for oneof in msg.oneofs() %}
+    id {{oneof.which_oneof}};
+    union {{oneof.name}}
+    {
+      {{oneof.name}}() {}
+      ~{{oneof.name}}() {}
+      {% for field in oneof.fields() %}
+      {% if field.is_repeated_field %}
+      {{field.repeated_type}} {{field.variable_name}};
+      {% else %}
+      {{field.type}} {{field.variable_name}};
+      {% endif %}
+      {% endfor %}
+    };
+    {{oneof.name}} {{oneof.name}}_;
+
+    {% endfor %}
 };
 {% endmacro %}
+{# #}
+{# ------------------------------------------------------------------------------------------------------------------ #}
+{# #}
 // This file is generated. Please do not edit!
 #ifndef _{{filename.upper()}}_H_
 #define _{{filename.upper()}}_H_

+ 49 - 3
generator/protoc-gen-eams.py

@@ -79,12 +79,20 @@ class FieldTemplateParameters:
                          FieldDescriptorProto.TYPE_FLOAT:    "FIXED32",
                          FieldDescriptorProto.TYPE_SFIXED32: "FIXED32"}
 
-    def __init__(self, field_proto):
+    def __init__(self, field_proto, oneof=None):
         self.name = field_proto.name
         self.variable_name = self.name + "_"
-        self.variable_id_name = self.name + "_id"
+        self.variable_id_name = self.name.upper()
         self.variable_id = field_proto.number
 
+        if oneof:
+            # When set this field is part of a oneof.
+            self.which_oneof = "which_" + oneof + "_"
+            self.variable_full_name = oneof + "_." + self.variable_name
+        else:
+            self.variable_full_name = self.variable_name
+
+
         self.of_type_message = FieldDescriptorProto.TYPE_MESSAGE == field_proto.type
         self.wire_type = self.type_to_wire_type[field_proto.type]
 
@@ -108,20 +116,58 @@ class FieldTemplateParameters:
 # -----------------------------------------------------------------------------
 
 
+class OneofTemplateParameters:
+    def __init__(self, name, index, msg_proto):
+        self.name = name
+        self.which_oneof = "which_" + name + "_"
+        self.index = index
+        self.msg_proto = msg_proto
+
+    def fields(self):
+        # Yield all the fields in this oneof
+        for f in self.msg_proto.field:
+            if f.HasField('oneof_index') and self.index == f.oneof_index:
+                yield FieldTemplateParameters(f, self.name)
+
+
+# -----------------------------------------------------------------------------
+
 class MessageTemplateParameters:
     def __init__(self, msg_proto):
         self.name = msg_proto.name
         self.msg_proto = msg_proto
+        self.has_fields = len(self.msg_proto.field) > 0
+        self.has_oneofs = len(self.msg_proto.oneof_decl) > 0
         self.templates = []
+        self.field_ids = []
+
         for field in self.fields():
+            self.field_ids.append((field.variable_id, field.variable_id_name))
             if field.is_repeated_field:
                 self.templates.append(field.variable_name)
 
+        for oneof in self.oneofs():
+            for field in oneof.fields():
+                self.field_ids.append((field.variable_id, field.variable_id_name))
+                if field.is_repeated_field:
+                    self.templates.append(field.variable_name)
+
+        # Sort the field id's such they will appear in order in the id enum.
+        self.field_ids.sort()
+
     def fields(self):
+        # Yield only the normal fields in this message.
         for f in self.msg_proto.field:
-            yield FieldTemplateParameters(f)
+            if not f.HasField('oneof_index'):
+                yield FieldTemplateParameters(f)
+
+    def oneofs(self):
+        # Yield all the oneofs in this message.
+        for index, oneof in enumerate(self.msg_proto.oneof_decl):
+            yield OneofTemplateParameters(oneof.name, index, self.msg_proto)
 
     def nested_enums(self):
+        # Yield all the enumerations defined in the scope of this message.
         for enum in self.msg_proto.enum_type:
             yield EnumTemplateParameters(enum)
 

+ 23 - 0
test/proto/oneof_fields.proto

@@ -0,0 +1,23 @@
+
+// This file is used to test oneof fields in messages.
+
+syntax = "proto3";
+
+message message_oneof 
+{
+  int32 a = 1;
+
+  oneof xyz {
+    int32 x = 5;
+    int32 y = 6;
+    int32 z = 7;
+  }
+
+  int32 b = 10;
+
+  oneof uvw {
+    float u = 15;
+    float v = 16;
+    float w = 17;
+  }
+}

+ 181 - 0
test/test_oneof_fields.cpp

@@ -0,0 +1,181 @@
+
+#include "gtest/gtest.h"
+
+#include <WireFormatter.h>
+#include <ReadBufferMock.h>
+#include <WriteBufferMock.h>
+
+#include <cstdint>    
+#include <limits> 
+
+// EAMS message definitions
+#include <oneof_fields.h>
+
+using ::testing::_;
+using ::testing::InSequence;
+using ::testing::Return;
+using ::testing::SetArgReferee;
+
+
+TEST(OneofField, construction) 
+{
+  message_oneof msg;
+}
+
+TEST(OneofField, serialize_zero) 
+{
+  message_oneof msg;
+  Mocks::WriteBufferMock buffer;
+  
+  EXPECT_CALL(buffer, push(_)).Times(0);
+  EXPECT_CALL(buffer, push(_,_)).Times(0);
+  EXPECT_CALL(buffer, get_available_size()).Times(0);
+
+  EXPECT_TRUE(msg.serialize(buffer));
+}
+
+TEST(OneofField, set_get_clear)
+{
+  message_oneof msg;
+  EXPECT_EQ(message_oneof::id::NOT_SET, msg.get_which_xyz());
+  msg.set_x(1);
+  EXPECT_EQ(1, msg.get_x());
+  EXPECT_EQ(message_oneof::id::X, msg.get_which_xyz());
+  msg.clear_x();
+
+  EXPECT_EQ(message_oneof::id::NOT_SET, msg.get_which_xyz());
+  msg.set_y(1);
+  EXPECT_EQ(1, msg.get_y());
+  EXPECT_EQ(message_oneof::id::Y, msg.get_which_xyz());
+  msg.clear_y();
+
+  EXPECT_EQ(message_oneof::id::NOT_SET, msg.get_which_xyz());
+  msg.set_z(1);
+  EXPECT_EQ(1, msg.get_z());
+  EXPECT_EQ(message_oneof::id::Z, msg.get_which_xyz());
+  msg.clear_z();
+}
+
+TEST(OneofField, serialize_ones) 
+{
+  InSequence s;
+  message_oneof msg;
+  Mocks::WriteBufferMock buffer;
+
+  // X
+  msg.set_a(1);
+  msg.set_b(1);
+  msg.set_x(1);
+  
+  uint8_t expected_x[] = {0x08, 0x01,  // a
+                          0x50, 0x01,  // b
+                          0x28, 0x01}; // x
+
+  for(auto e : expected_x) {
+    EXPECT_CALL(buffer, push(e)).Times(1).WillOnce(Return(true));
+  }
+
+  EXPECT_TRUE(msg.serialize(buffer));
+
+  // Y
+  msg.set_y(1);
+  uint8_t expected_y[] = {0x08, 0x01,  // a
+                          0x50, 0x01,  // b
+                          0x30, 0x01}; // y
+
+  for(auto e : expected_y) {
+    EXPECT_CALL(buffer, push(e)).Times(1).WillOnce(Return(true));
+  }
+
+  EXPECT_TRUE(msg.serialize(buffer));
+
+  // z
+  msg.set_z(1);
+  uint8_t expected_z[] = {0x08, 0x01,  // a
+                          0x50, 0x01,  // b
+                          0x38, 0x01}; // z
+
+  for(auto e : expected_z) {
+    EXPECT_CALL(buffer, push(e)).Times(1).WillOnce(Return(true));
+  }
+
+  EXPECT_TRUE(msg.serialize(buffer));
+}
+
+TEST(OneofField, serialize_second_oneof)
+{
+  InSequence s;
+  message_oneof msg;
+  Mocks::WriteBufferMock buffer;
+
+  // X and V
+  msg.set_a(1);
+  msg.set_b(1);
+  msg.set_x(1);
+  msg.set_v(1);
+
+  uint8_t expected_z[] = {0x08, 0x01,  // a
+                          0x50, 0x01,  // b
+                          0x28, 0x01,  // x
+                          0x85, 0x01, 0x00, 0x00, 0x80, 0x3f}; // v
+
+  for(auto e : expected_z) {
+    EXPECT_CALL(buffer, push(e)).Times(1).WillOnce(Return(true));
+  }
+
+  EXPECT_TRUE(msg.serialize(buffer));
+}
+
+TEST(OneofField, deserialize) 
+{
+  InSequence s;
+
+  message_oneof msg;
+  Mocks::ReadBufferMock buffer;
+
+  uint8_t referee[] = {0x08, 0x01,  // a
+                       0x50, 0x01,  // b
+                       0x30, 0x01}; // y
+
+  for(auto r: referee) {
+    EXPECT_CALL(buffer, pop(_)).Times(1).WillOnce(DoAll(SetArgReferee<0>(r), Return(true)));
+  }
+  EXPECT_CALL(buffer, pop(_)).Times(1).WillOnce(Return(false));
+
+  EXPECT_TRUE(msg.deserialize(buffer));
+
+  EXPECT_EQ(1, msg.get_a());
+  EXPECT_EQ(1, msg.get_b());
+  EXPECT_EQ(message_oneof::id::Y, msg.get_which_xyz());
+  EXPECT_EQ(1, msg.get_y());
+}
+
+TEST(OneofField, deserialize_second_oneof) 
+{
+  InSequence s;
+
+  message_oneof msg;
+  Mocks::ReadBufferMock buffer;
+
+  uint8_t referee[] = {0x08, 0x01,  // a
+                       0x50, 0x01,  // b
+                       0x28, 0x01,  // x
+                       0x85, 0x01, 0x00, 0x00, 0x80, 0x3f}; // v
+
+  for(auto r: referee) {
+    EXPECT_CALL(buffer, pop(_)).Times(1).WillOnce(DoAll(SetArgReferee<0>(r), Return(true)));
+  }
+  EXPECT_CALL(buffer, pop(_)).Times(1).WillOnce(Return(false));
+
+  EXPECT_TRUE(msg.deserialize(buffer));
+
+  EXPECT_EQ(1, msg.get_a());
+  EXPECT_EQ(1, msg.get_b());
+  EXPECT_EQ(message_oneof::id::X, msg.get_which_xyz());
+  EXPECT_EQ(1, msg.get_x());
+  EXPECT_EQ(message_oneof::id::V, msg.get_which_uvw());
+  EXPECT_EQ(1.0, msg.get_y());
+}
+
+
+

+ 23 - 1
test_data.py

@@ -1,6 +1,7 @@
 import build.python.simple_types_pb2 as st
 import build.python.nested_message_pb2 as nm
 import build.python.repeated_fields_pb2 as rf
+import build.python.oneof_fields_pb2 as of
 
 
 def test_simple_types():
@@ -155,6 +156,27 @@ def test_repeated_message():
     print(str)
     print()
 
+
+def test_oneof_fields():
+    msg = of.message_oneof()
+
+    msg.a = 1
+    msg.b = 1
+    msg.x = 1
+    msg.v = 1
+
+    str = ""
+    msg_str = msg.SerializeToString()
+    print(len(msg_str))
+    print(msg_str)
+    for x in msg_str:
+      str += "0x{:02x}, ".format(x)
+
+    print(str)
+    print()
+
+
 #test_repeated_fields()
-test_repeated_message()
+#test_repeated_message()
 #test_nested_message()
+test_oneof_fields()