From 21e19d0241568fe3c4c68a4c56046fbf4fdf2655 Mon Sep 17 00:00:00 2001 From: Hendrik Borras <hendrikborras@web.de> Date: Fri, 29 Oct 2021 18:04:52 +0100 Subject: [PATCH] Updated finn-base commit and moved _get_signed_from_upstream function from finn-base to finn. --- docker/Dockerfile.finn | 2 +- .../qonnx/infer_quant_avg_pool_2d.py | 77 ++++++++++++++++++- 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile.finn b/docker/Dockerfile.finn index 8e2d2b62f..d44f16762 100644 --- a/docker/Dockerfile.finn +++ b/docker/Dockerfile.finn @@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg # git-based Python repo dependencies # these are installed in editable mode for easier co-development -ARG FINN_BASE_COMMIT="686639de94f96f02fef794a693a295591d3f25c6" +ARG FINN_BASE_COMMIT="901d38414c7507b1eb5a7afa76e58a4368864455" ARG QONNX_COMMIT="6d55dce220c7f744ef23585686460b9370b672a0" ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221" ARG BREVITAS_COMMIT="5b22551871d25bf5e26917fe5900fcaa49406faf" diff --git a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py index f563e885d..faad31fa0 100644 --- a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py +++ b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py @@ -30,6 +30,7 @@ import math from onnx import TensorProto, helper +from finn.core.datatype import DataType from finn.custom_op.registry import getCustomOp from finn.transformation.base import Transformation from finn.transformation.infer_datatypes import InferDataTypes @@ -37,6 +38,79 @@ from finn.transformation.infer_shapes import InferShapes from finn.util.basic import get_by_name +def _get_signed_from_upstream(model, trunc_node): + """ + Find out what the sign of the input to the trunc node is, + by looking at the upstream nodes. + """ + node = trunc_node + # Check if the input of this node already has a FINN datatype + signed = None + inp_dt = model.get_tensor_datatype(node.input[0]) + if inp_dt is not None and inp_dt is not DataType["FLOAT32"]: + signed = inp_dt.signed() + # Go further up the graph, since the datatype inference works top down + # these nodes should either be sign preserving ops or they already have a + # datatype defined for the output tensor. + curr_node = node + if signed is None: + while curr_node is not None: + if model.is_join_node(curr_node): + raise RuntimeError( + "Datatype Inference for the Trunc node only supports " + "linear nodes in the upstream path." + ) + next_node = model.find_direct_predecessors(curr_node) + if next_node is None: + raise RuntimeError( + "Could not infere the Datatype for the Trunc node due to " + "missing upstream ndoes." + ) + next_node = next_node[0] + out_dt = model.get_tensor_datatype(next_node.output[0]) + if out_dt is not None and out_dt is not DataType["FLOAT32"]: + signed = out_dt.signed() + break + # Special cases where the node has an internal or intrinsic datatype. + if next_node.op_type == "MultiThreshold": + mt_inst = getCustomOp(next_node) + out_dt = DataType[mt_inst.get_nodeattr("out_dtype")] + if out_dt is not None and out_dt is not DataType["FLOAT32"]: + signed = out_dt.signed() + break + if next_node.op_type == "BipolarQuant": + signed = True + break + if next_node.op_type == "Quant": + q_inst = getCustomOp(next_node) + out_dt = q_inst.get_integer_datatype(model) + if out_dt is not None and out_dt is not DataType["FLOAT32"]: + signed = out_dt.signed() + break + + # Check if we are allowed to move on to the next op + sign_preserving_ops = ["Add", "Mul", "AveragePool", "Pad"] + if next_node.op_type not in sign_preserving_ops: + raise RuntimeError( + f"Could not infere the Datatype for the Trunc node, " + f"because the sign of the input datatype could not be infered " + f"from upstream nodes. And traversal further up the graph was " + f"disallowed, since the next node type {next_node.op_type} " + f"is not in the list of " + f"sign preserving ops {sign_preserving_ops}." + ) + curr_node = next_node + + if signed is None: + raise RuntimeError( + "Could not infere the Datatype for the Trunc node, " + "because the sign of the input datatype could not be infered " + "from upstream nodes." + ) + + return signed + + class AvgPoolAndTruncToQuantAvgPool(Transformation): """ Convert a section of nodes of the pattern: @@ -159,8 +233,7 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation): math.log(2 ** trunc_in_bits / (k_s * k_s), 2) ) # Get sign - t_inst = getCustomOp(t_node) - signed = t_inst._get_signed_from_upstream(model) + signed = _get_signed_from_upstream(model, t_node) # ToDo: Change this to NHWC, # when the channels last layout comes around. data_layout = "NCHW" -- GitLab