Skip to content
Snippets Groups Projects
Commit 96179ebe authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[fpgadataflow] Fix tensor element bitwidth missmatch for labelselect_batch and channelwise_op_batch

parent 357f1572
No related branches found
No related tags found
No related merge requests found
...@@ -41,6 +41,8 @@ from finn.util.data_packing import ( ...@@ -41,6 +41,8 @@ from finn.util.data_packing import (
) )
from . import templates from . import templates
import warnings
# ONNX i/o tensor shape assumptions for channelwise ops: # ONNX i/o tensor shape assumptions for channelwise ops:
# input 0 is the input tensor, shape (..., NumChannels) # input 0 is the input tensor, shape (..., NumChannels)
# input 1 is the channelwise parameter tensor, shape (NumChannels, params_per_channel) # input 1 is the channelwise parameter tensor, shape (NumChannels, params_per_channel)
...@@ -48,6 +50,37 @@ from . import templates ...@@ -48,6 +50,37 @@ from . import templates
# the ... here can be any shape (representing groups of vectors) # the ... here can be any shape (representing groups of vectors)
def get_smallest_possible(vals):
"""Returns smallest (fewest bits) possible DataType that can represent
value. Prefers unsigned integers where possible."""
vals = np.array(vals)
for v in vals:
assert int(v) == v, "Error float value"
for k in DataType.__members__:
dt = DataType[k]
if dt in [DataType.BIPOLAR, DataType.TERNARY, DataType.FLOAT32]:
# not currently supported
continue
if (dt.min() <= vals).all() and (vals <= dt.max()).all():
return dt
warnings.warn(
"""InferChannelwiseLinearLayer: Output values may not be
representable with supported data types.
Setting maximum width data type available.
This will lead to errors if there are no constrains on the input
"""
)
if (0 <= vals).all():
return DataType.UINT64
else:
return DataType.INT64
class ChannelwiseOp_Batch(HLSCustomOp): class ChannelwiseOp_Batch(HLSCustomOp):
"""Class that corresponds to finn-hls Thresholding_Batch function. """Class that corresponds to finn-hls Thresholding_Batch function.
It can implement a variety of channel-wise parametrized operations, It can implement a variety of channel-wise parametrized operations,
...@@ -109,10 +142,36 @@ class ChannelwiseOp_Batch(HLSCustomOp): ...@@ -109,10 +142,36 @@ class ChannelwiseOp_Batch(HLSCustomOp):
def infer_node_datatype(self, model): def infer_node_datatype(self, model):
node = self.onnx_node node = self.onnx_node
# check input datatype against property # check input datatype against property
idt_name = self.get_input_datatype().name idt = model.get_tensor_datatype(node.input[0])
exp_idt_name = self.get_nodeattr("inputDataType") exp_idt_name = self.get_nodeattr("inputDataType")
assert exp_idt_name == idt_name, "Bad input DataType for ChannelwiseOp layer" if exp_idt_name != idt.name:
# TODO: dynamically infer/update odt based on idt as done in ConvertToHLSLayers? func = self.get_nodeattr("Func")
assert func in ["add", "mul"], "Bad input DataType for ChannelwiseOp layer"
self.set_nodeattr("inputDataType", idt.name)
# update the func in ['add','mul'] cases
# get parameter ranges
param = model.get_initializer(node.input[1])
param_min = min(param.flatten())
param_max = max(param.flatten())
# set function and determine output data type
if func == "add":
out_min = idt.min() + param_min
out_max = idt.max() + param_max
odt = get_smallest_possible([out_min, out_max])
elif func == "mul":
possible_limits = []
possible_limits += [idt.min() * param_min]
possible_limits += [idt.min() * param_max]
possible_limits += [idt.max() * param_min]
possible_limits += [idt.max() * param_max]
odt = get_smallest_possible(possible_limits)
self.set_nodeattr("outputDataType", odt.name)
# set output datatype from property # set output datatype from property
odt = self.get_output_datatype() odt = self.get_output_datatype()
model.set_tensor_datatype(node.output[0], odt) model.set_tensor_datatype(node.output[0], odt)
......
...@@ -113,6 +113,11 @@ class LabelSelect_Batch(HLSCustomOp): ...@@ -113,6 +113,11 @@ class LabelSelect_Batch(HLSCustomOp):
) )
def infer_node_datatype(self, model): def infer_node_datatype(self, model):
node = self.onnx_node
# check input datatype against property
idt = model.get_tensor_datatype(node.input[0])
self.set_nodeattr("inputDataType", idt.name)
odt = self.get_output_datatype() odt = self.get_output_datatype()
model.set_tensor_datatype(self.onnx_node.output[0], odt) model.set_tensor_datatype(self.onnx_node.output[0], odt)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment