Skip to content
Snippets Groups Projects
Commit 21e19d02 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Updated finn-base commit and moved _get_signed_from_upstream function from finn-base to finn.

parent 0fda042a
No related branches found
No related tags found
No related merge requests found
......@@ -86,7 +86,7 @@ RUN pip install -e git+https://github.com/fbcotter/dataset_loading.git@0.0.4#egg
# git-based Python repo dependencies
# these are installed in editable mode for easier co-development
ARG FINN_BASE_COMMIT="686639de94f96f02fef794a693a295591d3f25c6"
ARG FINN_BASE_COMMIT="901d38414c7507b1eb5a7afa76e58a4368864455"
ARG QONNX_COMMIT="6d55dce220c7f744ef23585686460b9370b672a0"
ARG FINN_EXP_COMMIT="f82c0d9868bb88ea045dfadb28508d327d287221"
ARG BREVITAS_COMMIT="5b22551871d25bf5e26917fe5900fcaa49406faf"
......
......@@ -30,6 +30,7 @@
import math
from onnx import TensorProto, helper
from finn.core.datatype import DataType
from finn.custom_op.registry import getCustomOp
from finn.transformation.base import Transformation
from finn.transformation.infer_datatypes import InferDataTypes
......@@ -37,6 +38,79 @@ from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import get_by_name
def _get_signed_from_upstream(model, trunc_node):
"""
Find out what the sign of the input to the trunc node is,
by looking at the upstream nodes.
"""
node = trunc_node
# Check if the input of this node already has a FINN datatype
signed = None
inp_dt = model.get_tensor_datatype(node.input[0])
if inp_dt is not None and inp_dt is not DataType["FLOAT32"]:
signed = inp_dt.signed()
# Go further up the graph, since the datatype inference works top down
# these nodes should either be sign preserving ops or they already have a
# datatype defined for the output tensor.
curr_node = node
if signed is None:
while curr_node is not None:
if model.is_join_node(curr_node):
raise RuntimeError(
"Datatype Inference for the Trunc node only supports "
"linear nodes in the upstream path."
)
next_node = model.find_direct_predecessors(curr_node)
if next_node is None:
raise RuntimeError(
"Could not infere the Datatype for the Trunc node due to "
"missing upstream ndoes."
)
next_node = next_node[0]
out_dt = model.get_tensor_datatype(next_node.output[0])
if out_dt is not None and out_dt is not DataType["FLOAT32"]:
signed = out_dt.signed()
break
# Special cases where the node has an internal or intrinsic datatype.
if next_node.op_type == "MultiThreshold":
mt_inst = getCustomOp(next_node)
out_dt = DataType[mt_inst.get_nodeattr("out_dtype")]
if out_dt is not None and out_dt is not DataType["FLOAT32"]:
signed = out_dt.signed()
break
if next_node.op_type == "BipolarQuant":
signed = True
break
if next_node.op_type == "Quant":
q_inst = getCustomOp(next_node)
out_dt = q_inst.get_integer_datatype(model)
if out_dt is not None and out_dt is not DataType["FLOAT32"]:
signed = out_dt.signed()
break
# Check if we are allowed to move on to the next op
sign_preserving_ops = ["Add", "Mul", "AveragePool", "Pad"]
if next_node.op_type not in sign_preserving_ops:
raise RuntimeError(
f"Could not infere the Datatype for the Trunc node, "
f"because the sign of the input datatype could not be infered "
f"from upstream nodes. And traversal further up the graph was "
f"disallowed, since the next node type {next_node.op_type} "
f"is not in the list of "
f"sign preserving ops {sign_preserving_ops}."
)
curr_node = next_node
if signed is None:
raise RuntimeError(
"Could not infere the Datatype for the Trunc node, "
"because the sign of the input datatype could not be infered "
"from upstream nodes."
)
return signed
class AvgPoolAndTruncToQuantAvgPool(Transformation):
"""
Convert a section of nodes of the pattern:
......@@ -159,8 +233,7 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation):
math.log(2 ** trunc_in_bits / (k_s * k_s), 2)
)
# Get sign
t_inst = getCustomOp(t_node)
signed = t_inst._get_signed_from_upstream(model)
signed = _get_signed_from_upstream(model, t_node)
# ToDo: Change this to NHWC,
# when the channels last layout comes around.
data_layout = "NCHW"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment