diff --git a/src/finn/transformation/convert_qonnx_to_finn.py b/src/finn/transformation/convert_qonnx_to_finn.py
index 9d0116170846bbada0ce4482667575af717d3b4a..58127d286ec2f0549b4f68394aa0c8d2bce873b0 100644
--- a/src/finn/transformation/convert_qonnx_to_finn.py
+++ b/src/finn/transformation/convert_qonnx_to_finn.py
@@ -198,6 +198,10 @@ class QuantActBaseHandler(ABC):
     def _remove_activation_node(self):
         pass
 
+    @abstractmethod
+    def _check_compatibility(self):
+        pass
+
     def _extract_output_datatype(self):
         dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
         if "SCALED" in dtype:
@@ -214,6 +218,8 @@
         }
 
     def replace_quant_node(self):
+        # Check that we actually support what the user is trying to do
+        self._check_compatibility()
         # Shorten instance variables
         model = self._model
         graph = model.graph
@@ -334,6 +340,16 @@ class QuantReluHandler(QuantActBaseHandler):
     # zero_pt = model.get_initializer(n.input[2])
     # signed = q_inst.get_nodeattr("signed")
 
+    def _check_compatibility(self):
+        q_inst = getCustomOp(self._q_node)
+        narrow = q_inst.get_nodeattr("narrow")
+        signed = q_inst.get_nodeattr("signed")
+        if signed or narrow:
+            raise ValueError(
+                "FINN only supports unsigned and non-narrow Quant nodes "
+                "for Relu activations."
+            )
+
     def _calculate_act_bias(self):
         # No bias allowed for Relu activations, see: https://github.com/Xilinx/
         # brevitas/blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
@@ -413,6 +429,15 @@ class QuantIdentityHandler(QuantActBaseHandler):
     # zero_pt = model.get_initializer(n.input[2])
     # signed = q_inst.get_nodeattr("signed")
 
+    def _check_compatibility(self):
+        # Gather parameters to check
+        q_inst = getCustomOp(self._q_node)
+        signed = q_inst.get_nodeattr("signed")
+        if not signed:
+            raise ValueError(
+                "FINN only supports signed Quant nodes for identity activations."
+            )
+
     def _calculate_act_bias(self):
         # Gather parameters
         bit_width = self._model.get_initializer(self._q_node.input[3])
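
Note: the change above wires a compatibility check into the base handler's replace_quant_node() flow before any graph rewriting happens, and each concrete handler decides what it supports. The following is a minimal, self-contained sketch of that template-method shape, not FINN code; the class and attribute names (QuantActHandlerSketch, ReluHandlerSketch, nodeattrs) are made up for illustration.

from abc import ABC, abstractmethod


class QuantActHandlerSketch(ABC):
    """Illustrative stand-in for QuantActBaseHandler: replace_quant_node()
    asks the concrete handler whether the node is supported before doing work."""

    def __init__(self, nodeattrs):
        # nodeattrs stands in for the Quant node attributes,
        # e.g. {"signed": 0, "narrow": 0}; purely hypothetical here.
        self.nodeattrs = nodeattrs

    @abstractmethod
    def _check_compatibility(self):
        pass

    def replace_quant_node(self):
        # Fail fast, mirroring the added self._check_compatibility() call,
        # before any of the (omitted) graph rewriting would run.
        self._check_compatibility()
        print("node replaced")


class ReluHandlerSketch(QuantActHandlerSketch):
    def _check_compatibility(self):
        if self.nodeattrs.get("signed") or self.nodeattrs.get("narrow"):
            raise ValueError(
                "Only unsigned, non-narrow Quant nodes are supported here."
            )


ReluHandlerSketch({"signed": 0, "narrow": 0}).replace_quant_node()  # prints "node replaced"
try:
    ReluHandlerSketch({"signed": 1, "narrow": 0}).replace_quant_node()
except ValueError as e:
    print(f"rejected: {e}")  # unsupported configuration is reported before any rewrite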