diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn
index 8e2d2b62f1d6b961111e2e8ab867aac96d2a55d8..d44f1676213e98eb2e079b40e86ef914a8838d09 100644
--- a/docker/Dockerfile.finn
+++ b/docker/Dockerfile.finn
@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
 
 # git-based Python repo dependencies
 # these are installed in editable mode for easier co-development
-ARG FINN_BASE_COMMIT="686639de94f96f02fef794a693a295591d3f25c6"
+ARG FINN_BASE_COMMIT="901d38414c7507b1eb5a7afa76e58a4368864455"
 ARG QONNX_COMMIT="6d55dce220c7f744ef23585686460b9370b672a0"
 ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
 ARG BREVITAS_COMMIT="5b22551871d25bf5e26917fe5900fcaa49406faf"
diff --git a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
index f563e885d86bb3e6a2ba1bccb7134e44df6a6189..faad31fa06e76b245f25b6f0aa583fec5c0da29a 100644
--- a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
+++ b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
@@ -30,6 +30,7 @@
 import math
 from onnx import TensorProto, helper
 
+from finn.core.datatype import DataType
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
 from finn.transformation.infer_datatypes import InferDataTypes
@@ -37,6 +38,79 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.util.basic import get_by_name
 
 
+def _get_signed_from_upstream(model, trunc_node):
+    """
+    Find out whether the input to the Trunc node is signed
+    by inspecting the upstream nodes.
+    """
+    node = trunc_node
+    # Check if the input of this node already has a FINN datatype
+    signed = None
+    inp_dt = model.get_tensor_datatype(node.input[0])
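+    # FLOAT32 is the default annotation for tensors without a known datatype,
+    # so only a non-float annotation tells us the sign directly.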
+    if inp_dt is not None and inp_dt is not DataType["FLOAT32"]:
+        signed = inp_dt.signed()
+    # Go further up the graph; since datatype inference works top-down,
+    # these nodes should either be sign-preserving ops or already have a
+    # datatype defined for their output tensor.
+    curr_node = node
+    if signed is None:
+        while curr_node is not None:
+            if model.is_join_node(curr_node):
+                raise RuntimeError(
+                    "Datatype Inference for the Trunc node only supports "
+                    "linear nodes in the upstream path."
+                )
+            next_node = model.find_direct_predecessors(curr_node)
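+            # find_direct_predecessors returns None when no input of the node
+            # is produced by another node (e.g. graph inputs or initializers).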
+            if next_node is None:
+                raise RuntimeError(
+                    "Could not infere the Datatype for the Trunc node due to "
+                    "missing upstream ndoes."
+                )
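+            # The upstream path is linear (join nodes are rejected above),
+            # so there is exactly one predecessor to inspect next.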
+            next_node = next_node[0]
+            out_dt = model.get_tensor_datatype(next_node.output[0])
+            if out_dt is not None and out_dt is not DataType["FLOAT32"]:
+                signed = out_dt.signed()
+                break
+            # Special cases where the node has an internal or intrinsic datatype.
+            if next_node.op_type == "MultiThreshold":
+                mt_inst = getCustomOp(next_node)
+                out_dt = DataType[mt_inst.get_nodeattr("out_dtype")]
+                if out_dt is not None and out_dt is not DataType["FLOAT32"]:
+                    signed = out_dt.signed()
+                    break
+            if next_node.op_type == "BipolarQuant":
+                signed = True
+                break
+            if next_node.op_type == "Quant":
+                q_inst = getCustomOp(next_node)
+                out_dt = q_inst.get_integer_datatype(model)
+                if out_dt is not None and out_dt is not DataType["FLOAT32"]:
+                    signed = out_dt.signed()
+                    break
+
+            # Check if we are allowed to move on to the next op
+            sign_preserving_ops = ["Add", "Mul", "AveragePool", "Pad"]
+            if next_node.op_type not in sign_preserving_ops:
+                raise RuntimeError(
+                    f"Could not infere the Datatype for the Trunc node, "
+                    f"because the sign of the input datatype could not be infered "
+                    f"from upstream nodes. And traversal further up the graph was "
+                    f"disallowed, since the next node type {next_node.op_type} "
+                    f"is not in the list of "
+                    f"sign preserving ops {sign_preserving_ops}."
+                )
+            curr_node = next_node
+
+    if signed is None:
+        raise RuntimeError(
+            "Could not infere the Datatype for the Trunc node, "
+            "because the sign of the input datatype could not be infered "
+            "from upstream nodes."
+        )
+
+    return signed
+
+
 class AvgPoolAndTruncToQuantAvgPool(Transformation):
     """
     Convert a section of nodes of the pattern:
@@ -159,8 +233,7 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation):
                             math.log(2 ** trunc_in_bits / (k_s * k_s), 2)
                         )
                         # Get sign
-                        t_inst = getCustomOp(t_node)
-                        signed = t_inst._get_signed_from_upstream(model)
+                        signed = _get_signed_from_upstream(model, t_node)
                         # ToDo: Change this to NHWC,
                         #  when the channels last layout comes around.
                         data_layout = "NCHW"