Skip to content
Snippets Groups Projects
Commit 223d2033 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[SameResize] ensure idt, PaddingStyle correctness

- SameResize_Batch hlslib op always pads with 0, i/odt must support
zero values
- Added assertion to check for PaddingStyle=2, the only tested
value
parent e1b6f89d
No related branches found
No related tags found
No related merge requests found
......@@ -86,6 +86,8 @@ class SameResize_Batch(HLSCustomOp):
node = self.onnx_node
# data type stays the same
dtype = model.get_tensor_datatype(node.input[0])
exp_idtype = self.get_input_datatype()
assert dtype == exp_idtype, "Unexpected datatype for SameResize_Batch"
model.set_tensor_datatype(node.output[0], dtype)
def verify_node(self):
......@@ -93,11 +95,15 @@ class SameResize_Batch(HLSCustomOp):
def get_input_datatype(self):
"""Returns FINN DataType of input."""
return DataType[self.get_nodeattr("inputDataType")]
ret = DataType[self.get_nodeattr("inputDataType")]
# the hlslib op always pads with zeroes, so ensure that the DataType
# is able to represent zeroes
assert ret.allowed(0), "SameResize_Batch DataType must support zero"
return ret
def get_output_datatype(self):
"""Returns FINN DataType of output. (Same as input datatype)"""
return DataType[self.get_nodeattr("inputDataType")]
return self.get_input_datatype()
def get_instream_width(self):
ibits = self.get_input_datatype().bitwidth()
......@@ -120,6 +126,7 @@ class SameResize_Batch(HLSCustomOp):
def defines(self, var):
numReps = 1
assert self.get_nodeattr("PaddingStyle") == 2, "Only PaddingStyle=2 supported"
self.code_gen_dict["$DEFINES$"] = [
"""#define ImgDim1 {}\n #define KernelDim1 {}\n
#define Stride1 {}\n #define NumChannels1 {}\n
......@@ -240,12 +247,7 @@ class SameResize_Batch(HLSCustomOp):
inp.shape == exp_ishape
), """Input shape doesn't
match expected shape (1, ImgDim, ImgDim, NumChannels)."""
if self.get_input_datatype() == DataType.BIPOLAR:
# store bipolar activations as binary
inp = (inp + 1) / 2
export_idt = DataType.BINARY
else:
export_idt = self.get_input_datatype()
export_idt = self.get_input_datatype()
# no reshaping for input since assuming no folding on input
# make copy before saving array
......@@ -290,11 +292,6 @@ class SameResize_Batch(HLSCustomOp):
mode
)
)
# binary -> bipolar if needed
if self.get_output_datatype() == DataType.BIPOLAR:
out = context[node.output[0]]
out = 2 * out - 1
context[node.output[0]] = out
assert (
context[node.output[0]].shape == exp_oshape
), """Output shape doesn't match expected shape
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment