diff --git a/src/finn/custom_op/fpgadataflow/sameresize_batch.py b/src/finn/custom_op/fpgadataflow/sameresize_batch.py
index cf279dcc889d3afaa4da96833067e36371e6fc01..c459cac1e9c17336200a1fc85aad2af5e14e2c61 100644
--- a/src/finn/custom_op/fpgadataflow/sameresize_batch.py
+++ b/src/finn/custom_op/fpgadataflow/sameresize_batch.py
@@ -86,6 +86,8 @@ class SameResize_Batch(HLSCustomOp):
         node = self.onnx_node
         # data type stays the same
         dtype = model.get_tensor_datatype(node.input[0])
+        exp_idtype = self.get_input_datatype()
+        assert dtype == exp_idtype, "Unexpected datatype for SameResize_Batch"
         model.set_tensor_datatype(node.output[0], dtype)
 
     def verify_node(self):
@@ -93,11 +95,15 @@ class SameResize_Batch(HLSCustomOp):
 
     def get_input_datatype(self):
         """Returns FINN DataType of input."""
-        return DataType[self.get_nodeattr("inputDataType")]
+        ret = DataType[self.get_nodeattr("inputDataType")]
+        # the hlslib op always pads with zeroes, so ensure that the DataType
+        # is able to represent zeroes
+        assert ret.allowed(0), "SameResize_Batch DataType must support zero"
+        return ret
 
     def get_output_datatype(self):
         """Returns FINN DataType of output. (Same as input datatype)"""
-        return DataType[self.get_nodeattr("inputDataType")]
+        return self.get_input_datatype()
 
     def get_instream_width(self):
         ibits = self.get_input_datatype().bitwidth()
@@ -120,6 +126,7 @@ class SameResize_Batch(HLSCustomOp):
 
     def defines(self, var):
         numReps = 1
+        assert self.get_nodeattr("PaddingStyle") == 2, "Only PaddingStyle=2 supported"
         self.code_gen_dict["$DEFINES$"] = [
             """#define ImgDim1 {}\n #define KernelDim1 {}\n
             #define Stride1 {}\n #define NumChannels1 {}\n
@@ -240,12 +247,7 @@ class SameResize_Batch(HLSCustomOp):
             inp.shape == exp_ishape
         ), """Input shape doesn't match expected shape
        (1, ImgDim, ImgDim, NumChannels)."""
-        if self.get_input_datatype() == DataType.BIPOLAR:
-            # store bipolar activations as binary
-            inp = (inp + 1) / 2
-            export_idt = DataType.BINARY
-        else:
-            export_idt = self.get_input_datatype()
+        export_idt = self.get_input_datatype()
         # no reshaping for input since assuming no folding on input
         # make copy before saving array
@@ -290,11 +292,6 @@ class SameResize_Batch(HLSCustomOp):
                     mode
                 )
             )
-        # binary -> bipolar if needed
-        if self.get_output_datatype() == DataType.BIPOLAR:
-            out = context[node.output[0]]
-            out = 2 * out - 1
-            context[node.output[0]] = out
         assert (
             context[node.output[0]].shape == exp_oshape
         ), """Output shape doesn't match expected shape
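
Note on the change (not part of the diff): the removed BIPOLAR special-casing and the new allowed(0) assertion both hinge on whether the stream DataType can represent the zero value used for padding. A minimal sketch of that check, assuming the finn.core.datatype.DataType enum used by this version of FINN:

    from finn.core.datatype import DataType

    # The hlslib padding op fills the border with zeroes, so the DataType
    # configured for SameResize_Batch must be able to represent 0.
    print(DataType.INT4.allowed(0))     # True  -> passes the new assert
    print(DataType.BIPOLAR.allowed(0))  # False -> the new assert would fire

Since a BIPOLAR input can never pass this check, the bipolar<->binary conversion paths in execute_node become unreachable, which appears to be why they are deleted here.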