|
|
@@ -97,7 +97,7 @@ class TestSettings(ABC):
|
|
|
UINT8_MIN = 0
|
|
|
|
|
|
def __init__(self, dataset, testtype, args, in_ch, out_ch, x_in, y_in, w_x, w_y, stride_x, stride_y, pad, randmin,
|
|
|
- randmax, outminrange=-128, outmaxrange=127, batches=1):
|
|
|
+ randmax, outminrange=-128, outmaxrange=127, batches=1, generate_bias=True):
|
|
|
|
|
|
self.tensor_flow_reference_version = ("// Generated by {} using TFL version {} as reference.\n".
|
|
|
format(os.path.basename(__file__), tf.__version__))
|
|
|
@@ -128,6 +128,9 @@ class TestSettings(ABC):
|
|
|
# which may cause output to differ.
|
|
|
self.output_scale = 1.0
|
|
|
|
|
|
+ # Bias is optional.
|
|
|
+ self.generate_bias = generate_bias
|
|
|
+
|
|
|
self.generated_header_files = []
|
|
|
self.pregenerated_data_dir = self.PREGEN
|
|
|
self.testdataset = DEFAULT_TESTDATA_SET
|
|
|
@@ -269,7 +272,9 @@ class TestSettings(ABC):
|
|
|
|
|
|
def get_randomized_bias_data(self, biases):
|
|
|
# Generate or load saved bias data unless hardcoded data provided
|
|
|
- if biases is not None:
|
|
|
+ if not self.generate_bias:
|
|
|
+ biases = tf.reshape(np.full([self.output_ch], 0), [self.output_ch])
|
|
|
+ elif biases is not None:
|
|
|
biases = tf.reshape(biases, [self.output_ch])
|
|
|
else:
|
|
|
biases = self.get_randomized_data([self.output_ch],
|
|
|
@@ -442,9 +447,9 @@ class TestSettings(ABC):
|
|
|
class ConvSettings(TestSettings):
|
|
|
|
|
|
def __init__(self, dataset, testtype, args, in_ch=1, out_ch=1, x_in=7, y_in=7, w_x=3, w_y=3, stride_x=2, stride_y=2,
|
|
|
- pad=True, randmin=-7, randmax=7, outminrange=-128, outmaxrange=127, batches=1):
|
|
|
+ pad=True, randmin=-7, randmax=7, outminrange=-128, outmaxrange=127, batches=1, generate_bias=True):
|
|
|
super().__init__(dataset, testtype, args, in_ch, out_ch, x_in, y_in, w_x, w_y, stride_x, stride_y, pad,
|
|
|
- randmin, randmax, outminrange, outmaxrange, batches)
|
|
|
+ randmin, randmax, outminrange, outmaxrange, batches, generate_bias=generate_bias)
|
|
|
|
|
|
self.scaling_factors = []
|
|
|
|
|
|
@@ -669,9 +674,9 @@ class FullyConnectedSettings(TestSettings):
|
|
|
def __init__(self, dataset, testtype, args, in_ch=1, out_ch=1, x_in=1, y_in=1, w_x=1, w_y=1, stride_x=1, stride_y=1,
|
|
|
pad=False, randmin=-7, randmax=7, outminrange=-128, outmaxrange=127, batches=1, input_scale=1.0,
|
|
|
input_zero_point=0, weights_scale=1.0, weights_zero_point=0, bias_scale=1.0, output_scale=1.0,
|
|
|
- output_zero_point=0):
|
|
|
+ output_zero_point=0, generate_bias=True):
|
|
|
super().__init__(dataset, testtype, args, in_ch, out_ch, x_in, y_in, w_x, w_y, stride_x, stride_y, pad, randmin,
|
|
|
- randmax, outminrange, outmaxrange, batches)
|
|
|
+ randmax, outminrange, outmaxrange, batches, generate_bias=generate_bias)
|
|
|
|
|
|
if not self.test_type == 'fully_connected':
|
|
|
raise RuntimeError("Invalid test type {}".format(self.test_type))
|
|
|
@@ -801,12 +806,16 @@ def load_all_testdatasets():
|
|
|
output_zero_point=-2)
|
|
|
dataset = 'fully_connected_mve_0'
|
|
|
ALL_TESTDATA_SETS[dataset] = FullyConnectedSettings(dataset, type_of_test, args, in_ch=16, out_ch=9, x_in=1, y_in=1,
|
|
|
- input_zero_point=-3, w_x=1, w_y=1, randmin=-4, randmax=4, batches=1,
|
|
|
- output_zero_point=-2)
|
|
|
+ input_zero_point=-3, w_x=1, w_y=1, randmin=-4, randmax=4,
|
|
|
+ batches=1, output_zero_point=-2)
|
|
|
dataset = 'fully_connected_mve_1'
|
|
|
ALL_TESTDATA_SETS[dataset] = FullyConnectedSettings(dataset, type_of_test, args, in_ch=20, out_ch=4, x_in=1, y_in=1,
|
|
|
input_zero_point=-1, weights_zero_point=3, w_x=1, w_y=1,
|
|
|
randmin=-4, randmax=4, batches=1, output_zero_point=3)
|
|
|
+ dataset = 'fully_connected_null_bias_0'
|
|
|
+ ALL_TESTDATA_SETS[dataset] = FullyConnectedSettings(dataset, type_of_test, args, in_ch=33, out_ch=5,
|
|
|
+ input_zero_point=-1, weights_zero_point=3,
|
|
|
+ randmin=-4, randmax=4, batches=2, generate_bias=False)
|
|
|
type_of_test = 'avgpool'
|
|
|
dataset = 'avgpooling'
|
|
|
ALL_TESTDATA_SETS[dataset] = PoolingSettings(dataset, type_of_test, args, channels=8, x_in=22, y_in=12, stride_x=9,
|