From 078a4d745a989626107521baa6d22d2ad4f96a7d Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Fri, 15 Oct 2021 11:04:42 +0100
Subject: [PATCH] Fixed a bug in the QuantReluHandler, where the export would
 fail if the activation was at the top of the network.

---
 docker/Dockerfile.finn                              |  2 +-
 .../qonnx/qonnx_activation_handlers.py              | 15 ++++++---------
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index 1cba1538a..3c165e7f8 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="aefd5c8779db22d8c7a60c53cc32179f0c2ced67"
+ARG FINN_BASE_COMMIT="22886049c96048d7150efe261a82b9c3f469c100"
 ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
 ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index a64ffb538..068457512 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -89,7 +89,7 @@ class QuantActBaseHandler(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def _remove_activation_node(self):
+    def _remove_activation_node(self, multi_threshold_node):
         """Remove the activation node in front of the Quant node."""
         raise NotImplementedError()
 
@@ -273,7 +273,7 @@ class QuantActBaseHandler(ABC):
         up_stream_node = mul_node
 
         # Remove activation node
-        self._remove_activation_node()
+        self._remove_activation_node(mt_node)
 
         # Remove the Quant node
         graph.node.remove(n)
@@ -360,7 +360,7 @@ class QuantReluHandler(QuantActBaseHandler):
             scale = quant_scale
         return scale
 
-    def _remove_activation_node(self):
+    def _remove_activation_node(self, multi_threshold_node):
         # Find the activation node
         act_node = self._model.find_direct_predecessors(self._q_node)
         if act_node is None:
@@ -375,11 +375,8 @@ class QuantReluHandler(QuantActBaseHandler):
                 "of Relu activations."
             )
 
-        # Reroute possible predecessors
-        act_predecessors = self._model.find_direct_predecessors(act_node)
-        if act_node is not None:
-            for act_pre in act_predecessors:
-                act_pre.output[0] = act_node.output[0]
+        # Reroute upstream tensor
+        multi_threshold_node.input[0] = act_node.input[0]
 
         # Remove the activation node
         self._model.graph.node.remove(act_node)
@@ -517,6 +514,6 @@ class QuantIdentityHandler(QuantActBaseHandler):
             scale = quant_scale * 2
         return scale
 
-    def _remove_activation_node(self):
+    def _remove_activation_node(self, multi_threshold_node):
         # The Quant identity activation has per definition no explicit activation node
         return
-- 
GitLab