diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index 3856065f2fd1bcc52f557aaeb8feba58ad122041..1cba1538af07051b132cffd792eb0d1350b61b7c 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="ec3997c3f4276f7746bfd08a1a9508bd02a132fa"
+ARG FINN_BASE_COMMIT="aefd5c8779db22d8c7a60c53cc32179f0c2ced67"
 ARG QONNX_COMMIT="834610ba3f668971fe2800fde7f8d0c10d825d5b"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
 ARG BREVITAS_COMMIT="0eaff006407955153594254728baeb988edcd042"
diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py
index 7dd94230757abaf3120c0f2687d2807ccebccb54..5691f3c4cd080493a2bb98454804f86b6a8dd973 100644
--- a/src/finn/transformation/qonnx/fold_quant_weights.py
+++ b/src/finn/transformation/qonnx/fold_quant_weights.py
@@ -30,25 +30,11 @@ import numpy as np
 from onnx import TensorProto, helper
 
 import finn.core.onnx_exec as oxe
-from finn.core.datatype import DataType
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
 from finn.transformation.infer_shapes import InferShapes
 
 
-def get_dtype(bit_width: int, signed: bool) -> DataType:
-    bit_width = int(bit_width)
-    signed = bool(signed)
-    if bit_width == 1.0:
-        finn_dt = DataType["BIPOLAR"]
-    else:
-        if signed:
-            finn_dt = DataType["INT" + str(bit_width)]
-        else:
-            finn_dt = DataType["UINT" + str(bit_width)]
-    return finn_dt
-
-
 class FoldQuantWeights(Transformation):
     """Merges Quant nodes, which are used as weights into the initializer
     of the weight tensor.
@@ -123,16 +109,8 @@ class FoldQuantWeights(Transformation):
                         # Round, to correct for floating point errors
                         new_initializer = np.round(new_initializer)
                         model.set_initializer(node_out, new_initializer)
-                        if n.op_type == "Quant":
-                            bit_width = model.get_initializer(n.input[3])
-                            q_inst = getCustomOp(n)
-                            signed = q_inst.get_nodeattr("signed")
-                        elif n.op_type == "BinaryQuant":
-                            bit_width = 1.0
-                            signed = True
-                        else:
-                            raise RuntimeError("Got an unexpected quantizer node type")
-                        new_dtype = get_dtype(bit_width, signed)
+                        q_inst = getCustomOp(n)
+                        new_dtype = q_inst.get_internal_dtype(model)
                         model.set_tensor_datatype(node_out, new_dtype)
 
                         if target_node.op_type == "Conv" and len(scale.shape) > 0:
diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index 0814aecab407d37d7f87c28258674a3d13cba5b0..cbb94aa4846d8edb1456b559b2a4ca89deeaac47 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -95,9 +95,9 @@ class QuantActBaseHandler(ABC):
 
     def _extract_output_datatype(self):
         """Get the output datatype for the MultiThreshold node."""
-        dtype = self._model.get_tensor_datatype(self._q_node.output[0]).name
-        if dtype is not None:
-            dtype = dtype.replace("SCALED", "")
+        q_inst = getCustomOp(self._q_node)
+        dtype = q_inst.get_internal_dtype(self._model)
+        dtype = dtype.name
         return dtype
 
     def calculate_node_parameters(self):