diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py index c8bde7fea8ae8195001a7eccfd48baa4c48997ae..96221a1c43716c956a88f5749785f627750b4917 100644 --- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py +++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py @@ -33,6 +33,8 @@ from onnx import TensorProto, helper from finn.core.modelwrapper import ModelWrapper from finn.custom_op.registry import getCustomOp +np_default_dtype = np.float32 + class QuantActBaseHandler(ABC): """Base class for converting quantized activation expressed in the QONNX dialect @@ -164,17 +166,16 @@ class QuantActBaseHandler(ABC): if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant": # Get Quant parameters mul_scale = np.atleast_1d(mul_scale) - # ONNX only accepts 64bit floats as attributes - mul_scale = mul_scale.astype(dtype=np.float64) adder_bias = np.atleast_1d(adder_bias) - adder_bias = adder_bias.astype(dtype=np.float64) # Set Bias and scale - mt_inst.set_nodeattr("out_scale", mul_scale[0]) + # note calls to .item() to get Python float instead of numpy float + # ONNX attribute setting fails otherwise + mt_inst.set_nodeattr("out_scale", mul_scale[0].item()) # FINN applies scale first then bias, # which is the other way around in Brevitas, # we thus need to adjust the bias in the MultiThreshold node - finn_bias = adder_bias[0] * mul_scale[0] + finn_bias = adder_bias[0].item() * mul_scale[0].item() mt_inst.set_nodeattr("out_bias", finn_bias) # Set the output data type @@ -190,8 +191,7 @@ class QuantActBaseHandler(ABC): zero_bias = False if bias_scalar: adder_bias = np.atleast_1d(adder_bias) - # ONNX only accepts 64bit floats as attributes - adder_bias = adder_bias.astype(dtype=np.float64)[0] + adder_bias = adder_bias[0] add_shape = tuple() if adder_bias == 0.0: zero_bias = True @@ -234,7 +234,7 @@ class QuantActBaseHandler(ABC): unity_scale = False if scale_scalar: mul_scale = np.atleast_1d(mul_scale) - mul_scale = mul_scale.astype(dtype=np.float64)[0] + mul_scale = mul_scale[0] mul_shape = tuple() if mul_scale == 1.0: unity_scale = True @@ -313,7 +313,7 @@ class QuantReluHandler(QuantActBaseHandler): # No bias allowed for Relu activations, see: https://github.com/Xilinx/ # brevitas/blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/ # export/onnx/finn/handler/act.py#L48 - bias = np.array([0.0]) + bias = np.array([0.0], dtype=np_default_dtype) return bias def _calculate_thresholds(self): @@ -339,7 +339,9 @@ class QuantReluHandler(QuantActBaseHandler): num_scale_channels = flat_scale.shape[0] step = np.abs(flat_scale).astype(np.float32) min_threshold = step / 2 - thresholds = np.empty((num_scale_channels, num_thresholds)).astype(np.float32) + thresholds = np.empty( + (num_scale_channels, num_thresholds), dtype=np_default_dtype + ) for c in range(num_scale_channels): for t in range(num_thresholds): thresholds[c][t] = min_threshold[c] + step[c] * t @@ -438,13 +440,13 @@ class QuantIdentityHandler(QuantActBaseHandler): # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/ # onnx/finn/handler/act.py#L64 if bit_width == 1.0: - bias = np.array([-0.5]) + bias = np.array([-0.5], dtype=np_default_dtype) else: if narrow: min_non_scaled_val = -(2 ** (bit_width - 1) - 1) else: min_non_scaled_val = -(2 ** (bit_width - 1)) - bias = np.array([min_non_scaled_val]) + bias = np.array([min_non_scaled_val], dtype=np_default_dtype) return bias def _calculate_thresholds(self): @@ -463,7 +465,7 @@ class QuantIdentityHandler(QuantActBaseHandler): # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/ # export/onnx/finn/handler/act.py#L76 if bit_width == 1.0: - thresholds = np.empty([1, 1]) + thresholds = np.empty([1, 1], dtype=np_default_dtype) thresholds[0] = 0 return thresholds else: @@ -477,7 +479,9 @@ class QuantIdentityHandler(QuantActBaseHandler): num_scale_channels = flat_scale.shape[0] step = np.abs(flat_scale) half_step = step / 2.0 - thresholds = np.empty((num_scale_channels, num_thresholds)) + thresholds = np.empty( + (num_scale_channels, num_thresholds), dtype=np_default_dtype + ) # compute the value of the smallest threshold, we'll neg-bias all # generated thresholds by this much min_threshold = -half_step - step * ((num_thresholds // 2) - 1)