Commit 5ba522db authored by Tobi-Alonso

Merge upstream changes in absorb.py

parent 58f1fe7f
@@ -28,14 +28,81 @@
import numpy as np
from onnx import helper as oh
import warnings
from finn.core.datatype import DataType
import finn.core.data_layout as DataLayout
from finn.transformation import Transformation
from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
class AbsorbSignBiasIntoMultiThreshold(Transformation):
    """Absorb scalar bias originating from signed int export back into
    MultiThreshold and re-evaluate the output datatype."""

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        graph_modified = False
        for n in graph.node:
            # search for (MultiThreshold, Add) pair
            node_ind += 1
            if (
                n.op_type == "MultiThreshold"
                and not model.is_fork_node(n)
                and not model.is_join_node(n)
            ):
                consumer = model.find_consumer(n.output[0])
                if consumer is not None and consumer.op_type == "Add":
                    mt_node = n
                    add_node = consumer
                    threshold_name = mt_node.input[1]
                    add_weight_name = add_node.input[1]
                    T = model.get_initializer(threshold_name)
                    A = model.get_initializer(add_weight_name)
                    if (A is None) or (T is None):
                        warnings.warn("Threshold or add bias not constant, skipping")
                        continue
                    end_name = add_node.output[0]
                    # we can only absorb scalar adds
                    is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
                    if not is_scalar:
                        continue
                    bias = A.flatten()[0]
                    # set MultiThreshold bias property
                    mt_inst = getCustomOp(mt_node)
                    bias += mt_inst.get_nodeattr("out_bias")
                    mt_inst.set_nodeattr("out_bias", bias)
                    graph_modified = True
                    # compute new DataType for MultiThreshold output
                    steps = T.shape[-1]
                    new_min = bias
                    new_max = steps + bias
                    odt = DataType.get_smallest_possible(steps).name.replace(
                        "UINT", "INT"
                    )
                    odt = DataType[odt]
                    assert odt.allowed(new_max) and odt.allowed(new_min), (
                        "Could not compute new MultiThreshold DataType "
                        "(min = %d max = %d)" % (new_min, new_max)
                    )
                    mt_inst.set_nodeattr("out_dtype", odt.name)
                    # remove Add node, rewire MultiThreshold
                    graph.node.remove(add_node)
                    mt_node.output[0] = end_name
                    # set datatype
                    model.set_tensor_datatype(end_name, odt)
        if graph_modified:
            model = model.transform(InferDataTypes())
        return (model, graph_modified)
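As a usage sketch (not part of this commit; the module path matches the absorb.py file this diff touches, while the model filenames are placeholders), the transformation is applied through ModelWrapper.transform like any other FINN transformation:

# usage sketch, assumed filenames
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.streamline.absorb import AbsorbSignBiasIntoMultiThreshold

model = ModelWrapper("model.onnx")  # hypothetical signed-int exported model
model = model.transform(AbsorbSignBiasIntoMultiThreshold())
model.save("model_absorbed.onnx")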
class AbsorbAddIntoMultiThreshold(Transformation):
    """Absorb preceding Add ops into MultiThreshold by updating the threshold
    values. Only scalar/1D add vectors can be absorbed."""
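A minimal sketch of the threshold rewrite this docstring describes (illustrative values, not taken from the diff): comparing x + A against thresholds T is equivalent to comparing x against T - A, so the Add weight can be folded into the thresholds.

# illustrative only: fold an Add weight A into MultiThreshold thresholds T
import numpy as np

T = np.array([[-1.0, 0.0, 1.0]])  # hypothetical threshold steps, one channel
A = np.array([0.5])               # hypothetical add weight
T_new = T - A.reshape(-1, 1)      # x + A >= T  <=>  x >= T - A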
@@ -292,6 +359,104 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
        return (model, graph_modified)
class AbsorbTransposeIntoFlatten(Transformation):
    """Absorb a Transpose node into a succeeding Flatten node, if H=W=1 and the
    first dimension stays the same. Can also be applied if the flatten is
    implemented implicitly by a Reshape node with shape [1, -1] and the first
    input dimension is 1."""

    def apply(self, model):
        graph = model.graph
        graph_modified = False
        node_ind = 0
        for n in graph.node:
            node_ind += 1
            if (
                n.op_type == "Reshape"
                and (model.get_initializer(n.input[1]) == [1, -1]).all()
            ) or n.op_type == "Flatten":
                prod = model.find_producer(n.input[0])
                if (
                    prod is not None
                    and prod.op_type == "Transpose"
                    # ensure the transpose does not move the first dimension
                    and get_by_name(prod.attribute, "perm").ints[0] == 0
                ):
                    # check the data layout to interpret the input shape correctly
                    data_layout = model.get_tensor_layout(prod.input[0])
                    if data_layout is None:
                        warnings.warn(
                            """Data layout for input tensor of Transpose node is not set.
                            To use AbsorbTransposeIntoFlatten transformation
                            please set tensor data layout."""
                        )
                        continue
                    elif data_layout == DataLayout.NCHW:
                        (b, c, h, w) = model.get_tensor_shape(prod.input[0])
                        # the transpose can only be absorbed if h = w = 1;
                        # otherwise the absorption would change the output values
                        if h != 1 or w != 1:
                            continue
                        # onnx Flatten keeps the first dim by default and
                        # flattens the rest, so this transformation only works
                        # with b != 1 if the model already contains a Flatten
                        # node rather than a Reshape with shape = [1, -1].
                        # If the first input dim is not 1, Flatten and Reshape
                        # (with shape = [1, -1]) give different results.
                        if n.op_type == "Reshape" and b != 1:
                            continue
                    elif data_layout == DataLayout.NHWC:
                        (b, h, w, c) = model.get_tensor_shape(prod.input[0])
                        if h != 1 or w != 1:
                            continue
                        if n.op_type == "Reshape" and b != 1:
                            continue
                    # create a single Flatten node and remove the obsolete nodes
                    node = oh.make_node("Flatten", [prod.input[0]], [n.output[0]])
                    graph.node.remove(n)
                    graph.node.remove(prod)
                    graph.node.insert(node_ind, node)
                    graph_modified = True
        if graph_modified:
            model = model.transform(InferDataTypes())
        return (model, graph_modified)
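A small numpy check (illustrative, not part of the diff) of why the absorption is safe when h = w = 1: the transpose then only permutes singleton axes, so flattening before or after it yields identical data.

import numpy as np

x = np.arange(8).reshape(1, 8, 1, 1)  # NCHW tensor with h = w = 1
x_t = x.transpose(0, 2, 3, 1)         # perm keeps the first dimension in place
assert (x.reshape(1, -1) == x_t.reshape(1, -1)).all()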
class AbsorbScalarMulIntoTopK(Transformation):
    """Absorb a Mul node into a succeeding TopK node if the mul is scalar."""

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        graph_modified = False
        for n in graph.node:
            node_ind += 1
            if n.op_type == "TopK":
                prod = model.find_producer(n.input[0])
                if prod is not None and prod.op_type == "Mul":
                    prod_input = prod.input[0]
                    param_name = prod.input[1]
                    A = model.get_initializer(param_name)
                    if A is None:
                        warnings.warn("Param is not constant, skipping")
                        continue
                    if all(x == 1 for x in A.shape) and A > 0:
                        # if the mul is scalar and positive, we can delete the
                        # Mul node and rewire the TopK node: TopK only depends
                        # on the relative order of its input values, and that
                        # order does not change when every value is multiplied
                        # by the same positive scalar
                        graph.node.remove(prod)
                        n.input[0] = prod_input
                        # set the datatype to float32 to avoid errors
                        model.set_tensor_datatype(n.input[0], DataType.FLOAT32)
                        graph_modified = True
        if graph_modified:
            model = model.transform(InferShapes())
            model = model.transform(InferDataTypes())
        return (model, graph_modified)
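A quick numpy illustration (not part of the diff) of the invariant the comment relies on: TopK selects by relative order, which a positive scalar multiplication preserves.

import numpy as np

scores = np.array([0.1, 0.7, 0.2])
k = 2
# the top-k indices are unchanged by multiplication with a positive scalar
assert (np.argsort(-scores)[:k] == np.argsort(-(3.0 * scores))[:k]).all()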
class AbsorbConsecutiveTransposes(Transformation):
    """Remove (Transpose -> Transpose) patterns when the input and output
    of the pattern have the same layout."""
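As a sketch of the condition this class name describes (illustrative, the class body is elided by the diff): two consecutive Transposes cancel exactly when their composed permutation is the identity, e.g. NCHW -> NHWC -> NCHW.

perm_a = [0, 2, 3, 1]  # NCHW -> NHWC
perm_b = [0, 3, 1, 2]  # NHWC -> NCHW
composed = [perm_a[i] for i in perm_b]  # permutation of applying both in sequence
assert composed == [0, 1, 2, 3]  # identity, so the pair can be removed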