diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index 3856065f2fd1bcc52f557aaeb8feba58ad122041..1cba1538af07051b132cffd792eb0d1350b61b7c 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="ec3997c3f4276f7746bfd08a1a9508bd02a132fa"
+ARG FINN_BASE_COMMIT="aefd5c8779db22d8c7a60c53cc32179f0c2ced67"
 ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
 ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py
index 7dd94230757abaf3120c0f2687d2807ccebccb54..5691f3c4cd080493a2bb98454804f86b6a8dd973 100644
--- a/src/finn/transformation/qonnx/fold_quant_weights.py
+++ b/src/finn/transformation/qonnx/fold_quant_weights.py
@@ -30,25 +30,11 @@ import numpy as np
 from onnx import TensorProto, helper
 
 import finn.core.onnx_exec as oxe
-from finn.core.datatype import DataType
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
 from finn.transformation.infer_shapes import InferShapes
 
 
-def get_dtype(bit_width: int, signed: bool) -> DataType:
-    bit_width = int(bit_width)
-    signed = bool(signed)
-    if bit_width == 1.0:
-        finn_dt = DataType["BIPOLAR"]
-    else:
-        if signed:
-            finn_dt = DataType["INT" + str(bit_width)]
-        else:
-            finn_dt = DataType["UINT" + str(bit_width)]
-    return finn_dt
-
-
 class FoldQuantWeights(Transformation):
     """Merges Quant nodes, which are used as weights into the initializer
     of the weight tensor.
@@ -123,16 +109,8 @@ class FoldQuantWeights(Transformation):
                         # Round, to correct for floating point errors
                         new_initializer = np.round(new_initializer)
                         model.set_initializer(node_out, new_initializer)
-                        if n.op_type == "Quant":
-                            bit_width = model.get_initializer(n.input[3])
-                            q_inst = getCustomOp(n)
-                            signed = q_inst.get_nodeattr("signed")
-                        elif n.op_type == "BinaryQuant":
-                            bit_width = 1.0
-                            signed = True
-                        else:
-                            raise RuntimeError("Got an unexpected quantizer node type")
-                        new_dtype = get_dtype(bit_width, signed)
+                        q_inst = getCustomOp(n)
+                        new_dtype = q_inst.get_internal_dtype(model)
                         model.set_tensor_datatype(node_out, new_dtype)
 
                         if target_node.op_type == "Conv" and len(scale.shape) > 0:
diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index 0814aecab407d37d7f87c28258674a3d13cba5b0..cbb94aa4846d8edb1456b559b2a4ca89deeaac47 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -95,9 +95,9 @@ class QuantActBaseHandler(ABC):
 
     def _extract_output_datatype(self):
         """Get the output datatype for the MultiThreshold node."""
-        dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
-        if dtype is not None:
-            dtype = dtype.replace("SCALED", "")
+        q_inst = getCustomOp(self._q_node)
+        dtype = q_inst.get_internal_dtype(self._model)
+        dtype = dtype.name
         return dtype
 
     def calculate_node_parameters(self):
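
Note for reviewers: all three changes hinge on a get_internal_dtype() method
on the Quant/BinaryQuant custom ops, which is why FINN_BASE_COMMIT is bumped;
the method is expected to ship with the updated finn-base dependency rather
than being defined in this repo. The sketch below is a minimal reconstruction
of what that method presumably computes, pieced together from the removed
get_dtype() helper and op_type dispatch above. The actual finn-base
implementation may differ, and the free-function form (taking the CustomOp
wrapper returned by getCustomOp(n) plus the ModelWrapper) is an illustrative
assumption, not the real signature.

    from finn.core.datatype import DataType


    def get_internal_dtype_sketch(q_inst, model):
        # Sketch only: assumed equivalent of the logic this diff deletes.
        node = q_inst.onnx_node
        if node.op_type == "BinaryQuant":
            # BinaryQuant nodes are implicitly 1-bit and signed.
            bit_width, signed = 1, True
        elif node.op_type == "Quant":
            # Quant nodes carry bit_width as their fourth input and
            # signedness as a node attribute.
            bit_width = int(model.get_initializer(node.input[3]))
            signed = bool(q_inst.get_nodeattr("signed"))
        else:
            raise RuntimeError("Got an unexpected quantizer node type")
        if bit_width == 1:
            # 1-bit quantization maps to FINN's BIPOLAR datatype.
            return DataType["BIPOLAR"]
        return DataType[("INT" if signed else "UINT") + str(bit_width)]

Compared with the removed _extract_output_datatype() body, deriving the
datatype from the quantizer's own parameters avoids the fragile string-level
stripping of "SCALED" from a tensor datatype name.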