From 6d7026eeaf7be834ba108e61e05a339e01f78b3d Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Thu, 14 Oct 2021 19:21:49 +0100
Subject: [PATCH] Removed reliance on ScaledInt data type.

---
 docker/Dockerfile.finn                        |  2 +-
 .../qonnx/fold_quant_weights.py               | 26 ++-----------------
 .../qonnx/qonnx_activation_handlers.py        |  6 ++---
 3 files changed, 6 insertions(+), 28 deletions(-)

diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index 3856065f2..1cba1538a 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="ec3997c3f4276f7746bfd08a1a9508bd02a132fa"
+ARG FINN_BASE_COMMIT="aefd5c8779db22d8c7a60c53cc32179f0c2ced67"
 ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
 ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py
index 7dd942307..5691f3c4c 100644
--- a/src/finn/transformation/qonnx/fold_quant_weights.py
+++ b/src/finn/transformation/qonnx/fold_quant_weights.py
@@ -30,25 +30,11 @@ import numpy as np
 from onnx import TensorProto, helper
 
 import finn.core.onnx_exec as oxe
-from finn.core.datatype import DataType
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
 from finn.transformation.infer_shapes import InferShapes
 
 
-def get_dtype(bit_width: int, signed: bool) -> DataType:
-    bit_width = int(bit_width)
-    signed = bool(signed)
-    if bit_width == 1.0:
-        finn_dt = DataType["BIPOLAR"]
-    else:
-        if signed:
-            finn_dt = DataType["INT" + str(bit_width)]
-        else:
-            finn_dt = DataType["UINT" + str(bit_width)]
-    return finn_dt
-
-
 class FoldQuantWeights(Transformation):
     """Merges Quant nodes, which are used as weights into the initializer
     of the weight tensor.
@@ -123,16 +109,8 @@ class FoldQuantWeights(Transformation):
                         # Round, to correct for floating point errors
                         new_initializer = np.round(new_initializer)
                         model.set_initializer(node_out, new_initializer)
-                        if n.op_type == "Quant":
-                            bit_width = model.get_initializer(n.input[3])
-                            q_inst = getCustomOp(n)
-                            signed = q_inst.get_nodeattr("signed")
-                        elif n.op_type == "BinaryQuant":
-                            bit_width = 1.0
-                            signed = True
-                        else:
-                            raise RuntimeError("Got an unexpected quantizer node type")
-                        new_dtype = get_dtype(bit_width, signed)
+                        q_inst = getCustomOp(n)
+                        new_dtype = q_inst.get_internal_dtype(model)
                         model.set_tensor_datatype(node_out, new_dtype)
 
                         if target_node.op_type == "Conv" and len(scale.shape) > 0:
diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index 0814aecab..cbb94aa48 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -95,9 +95,9 @@ class QuantActBaseHandler(ABC):
 
     def _extract_output_datatype(self):
         """Get the output datatype for the MultiThreshold node."""
-        dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
-        if dtype is not None:
-            dtype = dtype.replace("SCALED", "")
+        q_inst = getCustomOp(self._q_node)
+        dtype = q_inst.get_internal_dtype(self._model)
+        dtype = dtype.name
         return dtype
 
     def calculate_node_parameters(self):
-- 
GitLab