Skip to content
Snippets Groups Projects
Commit 078a4d74 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Fixed a bug in the QuantReluHandler, where the export would fail if the...

Fixed a bug in the QuantReluHandler, where the export would fail if the activation was at the top of the network.
parent 9be84cec
No related branches found
No related tags found
No related merge requests found
...@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg ...@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
# git-based Python repo dependencies # git-based Python repo dependencies
# these are installed in editable mode for easier co-development # these are installed in editable mode for easier co-development
ARG FINN_BASE_COMMIT="aefd5c8779db22d8c7a60c53cc32179f0c2ced67" ARG FINN_BASE_COMMIT="22886049c96048d7150efe261a82b9c3f469c100"
ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b" ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221" ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042" ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
......
...@@ -89,7 +89,7 @@ class QuantActBaseHandler(ABC): ...@@ -89,7 +89,7 @@ class QuantActBaseHandler(ABC):
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def _remove_activation_node(self): def _remove_activation_node(self, multi_threshold_node):
"""Remove the activation node in front of the Quant node.""" """Remove the activation node in front of the Quant node."""
raise NotImplementedError() raise NotImplementedError()
...@@ -273,7 +273,7 @@ class QuantActBaseHandler(ABC): ...@@ -273,7 +273,7 @@ class QuantActBaseHandler(ABC):
up_stream_node = mul_node up_stream_node = mul_node
# Remove activation node # Remove activation node
self._remove_activation_node() self._remove_activation_node(mt_node)
# Remove the Quant node # Remove the Quant node
graph.node.remove(n) graph.node.remove(n)
...@@ -360,7 +360,7 @@ class QuantReluHandler(QuantActBaseHandler): ...@@ -360,7 +360,7 @@ class QuantReluHandler(QuantActBaseHandler):
scale = quant_scale scale = quant_scale
return scale return scale
def _remove_activation_node(self): def _remove_activation_node(self, multi_threshold_node):
# Find the activation node # Find the activation node
act_node = self._model.find_direct_predecessors(self._q_node) act_node = self._model.find_direct_predecessors(self._q_node)
if act_node is None: if act_node is None:
...@@ -375,11 +375,8 @@ class QuantReluHandler(QuantActBaseHandler): ...@@ -375,11 +375,8 @@ class QuantReluHandler(QuantActBaseHandler):
"of Relu activations." "of Relu activations."
) )
# Reroute possible predecessors # Reroute upstream tensor
act_predecessors = self._model.find_direct_predecessors(act_node) multi_threshold_node.input[0] = act_node.input[0]
if act_node is not None:
for act_pre in act_predecessors:
act_pre.output[0] = act_node.output[0]
# Remove the activation node # Remove the activation node
self._model.graph.node.remove(act_node) self._model.graph.node.remove(act_node)
...@@ -517,6 +514,6 @@ class QuantIdentityHandler(QuantActBaseHandler): ...@@ -517,6 +514,6 @@ class QuantIdentityHandler(QuantActBaseHandler):
scale = quant_scale * 2 scale = quant_scale * 2
return scale return scale
def _remove_activation_node(self): def _remove_activation_node(self, multi_threshold_node):
# The Quant identity activation has per definition no explicit activation node # The Quant identity activation has per definition no explicit activation node
return return
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment