Skip to content
Snippets Groups Projects
Commit 987a03da authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

Merge branch 'feature/tensor_layout_annotation' into dev

parents 9ddad020 9465bbb9
No related merge requests found
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Predefined lists of strings that give a canonical way of expressing
# data layout annotations. Each list holds one single-letter label per
# tensor dimension, in dimension order.
NHWC = list("NHWC")  # ["N", "H", "W", "C"]
NCHW = list("NCHW")  # ["N", "C", "H", "W"]
NC = list("NC")  # ["N", "C"]
# empty list: marker for a layout that could not be determined
UNKNOWN = []
......@@ -137,11 +137,16 @@ class ModelWrapper:
qnt_annotations = graph.quantization_annotation
ret = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
if ret is not None:
ret = util.get_by_name(
ret_dt = util.get_by_name(
ret.quant_parameter_tensor_names, "finn_datatype", "key"
)
if ret is not None:
ret.value = datatype.name
if ret_dt is not None:
ret_dt.value = datatype.name
else:
dt = onnx.StringStringEntryProto()
dt.key = "finn_datatype"
dt.value = datatype.name
ret.quant_parameter_tensor_names.append(dt)
else:
qa = onnx.TensorAnnotation()
dt = onnx.StringStringEntryProto()
......@@ -434,3 +439,58 @@ class ModelWrapper:
n_ind += 1
except ValueError:
return None
def get_tensor_layout(self, tensor_name):
    """Returns the data layout annotation of tensor with given name.

    The data layout is expressed as a list of strings with as many
    elements as the number of dimensions in the tensor shape. Each
    string annotates what is contained in that dimension. If there is no
    data layout annotation, None will be returned.

    Examples of data layout annotations:
    ["N", "C"] is tensor[batch][channel]
    ["N", "C", "H", "W"] is tensor[batch][channel][height][width]
    ["N", "H", "W", "C"] is tensor[batch][height][width][channel]

    The annotation is stored as a string under the "tensor_layout" key of
    the tensor's quantization annotation and is parsed back with
    ast.literal_eval, which raises ValueError or SyntaxError if the stored
    string is not a valid Python literal.
    """
    # local import so this method does not depend on a file-level import
    import ast

    graph = self._model_proto.graph
    qnt_annotations = graph.quantization_annotation
    annotation = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
    if annotation is None:
        # tensor has no annotation entry at all
        return None
    layout_entry = util.get_by_name(
        annotation.quant_parameter_tensor_names, "tensor_layout", "key"
    )
    if layout_entry is None:
        # tensor is annotated, but not with a layout
        return None
    # The value comes from the model file, which may be untrusted: parse it
    # as a literal only instead of executing it with eval().
    return ast.literal_eval(layout_entry.value)
def set_tensor_layout(self, tensor_name, data_layout):
    """Sets the data layout annotation of tensor with given name. See
    get_tensor_layout for examples.

    data_layout must be a list; if the tensor has a known shape, the list
    must have one element per tensor dimension. Both conditions are
    checked with assert, so a violation raises AssertionError.

    The layout is stored as str(data_layout) under the "tensor_layout" key
    of the tensor's quantization annotation. An existing entry is
    overwritten; a missing entry or annotation is created.
    """
    tensor_shape = self.get_tensor_shape(tensor_name)
    # isinstance instead of an exact type comparison, so list subclasses
    # are accepted as well
    assert isinstance(data_layout, list), "data_layout must be a list"
    if tensor_shape is not None:
        assert len(tensor_shape) == len(data_layout), (
            "Mismatch between number\n"
            "of dimensions of tensor shape and data layout annotation."
        )
    layout_str = str(data_layout)
    graph = self._model_proto.graph
    qnt_annotations = graph.quantization_annotation
    annotation = util.get_by_name(qnt_annotations, tensor_name, "tensor_name")
    if annotation is None:
        # no annotation for this tensor yet: create one holding the layout
        annotation = onnx.TensorAnnotation()
        entry = onnx.StringStringEntryProto()
        entry.key = "tensor_layout"
        entry.value = layout_str
        annotation.tensor_name = tensor_name
        annotation.quant_parameter_tensor_names.append(entry)
        qnt_annotations.append(annotation)
        return
    existing = util.get_by_name(
        annotation.quant_parameter_tensor_names, "tensor_layout", "key"
    )
    if existing is not None:
        # overwrite the previously stored layout
        existing.value = layout_str
    else:
        # annotation exists (e.g. for other keys) but has no layout entry
        entry = onnx.StringStringEntryProto()
        entry.key = "tensor_layout"
        entry.value = layout_str
        annotation.quant_parameter_tensor_names.append(entry)
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import finn.custom_op.registry as registry
import finn.core.data_layout as DataLayout
from finn.transformation import Transformation
import warnings
from finn.util.basic import get_by_name
def _dims_to_layout(model, node, ndims):
if ndims == 2:
return DataLayout.NC
else:
if node.domain == "finn":
if node.op_type == "MultiThreshold":
mt_inst = registry.getCustomOp(node)
layout = mt_inst.get_nodeattr("data_layout")
if layout == "NHWC" and ndims == 4:
return DataLayout.NHWC
elif layout == "NCHW" and ndims == 4:
return DataLayout.NCHW
else:
return DataLayout.UNKNOWN
else:
if ndims == 4:
return DataLayout.NHWC
else:
return DataLayout.UNKNOWN
else:
# propagate input layout to output
# TODO this won't work for concat, squeeze/unsqueeze/reshape...
return model.get_tensor_layout(node.input[0])
def _infer_node_data_layout(model, node):
    """Infer output data layout annotation(s) for a particular node.

    Transpose nodes outside the "finn" domain get their first output
    annotated with the first input's layout permuted by the "perm"
    attribute. For every other node, each output is annotated with the
    guess that _dims_to_layout makes from its number of dimensions.

    Returns True if any changes were made."""
    layouts_before = [model.get_tensor_layout(name) for name in node.output]
    if node.domain != "finn" and node.op_type == "Transpose":
        # grab input annotation and switch it around using perm
        perm = get_by_name(node.attribute, "perm").ints
        src_layout = model.get_tensor_layout(node.input[0])
        permuted = [src_layout[axis] for axis in perm]
        model.set_tensor_layout(node.output[0], permuted)
    else:
        # try to guess based on number of output dims
        for out_name in node.output:
            rank = len(model.get_tensor_shape(out_name))
            guessed = _dims_to_layout(model, node, rank)
            model.set_tensor_layout(out_name, guessed)
    # compare old and new output layouts to see if anything changed
    layouts_after = [model.get_tensor_layout(name) for name in node.output]
    return layouts_after != layouts_before
class InferDataLayouts(Transformation):
    """Try to infer data layout annotations info for all input/intermediate/output
    tensors based on inputs and node type."""

    def apply(self, model):
        """Annotate the first graph input if needed, then every node output.

        Returns a (model, graph_modified) tuple, where graph_modified is
        True if any layout annotation was added or changed. Raises
        Exception if the first graph input has no layout annotation and is
        neither 2- nor 4-dimensional.
        """
        graph = model.graph
        modified = False
        # first, make sure that the global input has an annotation
        # this is really hard to do in general, so we do some bad guesswork
        top_input = graph.input[0].name
        if model.get_tensor_layout(top_input) is None:
            rank = len(model.get_tensor_shape(top_input))
            if rank == 4:
                warnings.warn("Assuming 4D input is NCHW")
                model.set_tensor_layout(top_input, DataLayout.NCHW)
            elif rank == 2:
                warnings.warn("Assuming 2D input is NC")
                model.set_tensor_layout(top_input, DataLayout.NC)
            else:
                raise Exception(
                    "Unknown number of dims for input, don't know\nhow to annotate"
                )
            # reached only when one of the two guesses above was applied
            modified = True
        for node in graph.node:
            modified |= _infer_node_data_layout(model, node)
        return (model, modified)
......@@ -31,6 +31,7 @@ import onnx
from collections import Counter
import brevitas.onnx as bo
import numpy as np
import finn.core.data_layout as DataLayout
from finn.core.modelwrapper import ModelWrapper
from finn.util.test import get_test_model_trained
......@@ -67,6 +68,11 @@ def test_modelwrapper():
assert inp_cons.op_type == "MatMul"
out_prod = model.find_producer(l0_inp_tensor_name)
assert out_prod.op_type == "MultiThreshold"
inp_layout = model.get_tensor_layout(inp_name)
assert inp_layout is None
inp_layout = DataLayout.NCHW
model.set_tensor_layout(inp_name, inp_layout)
assert model.get_tensor_layout(inp_name) == inp_layout
os.remove(export_onnx_path)
......
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import brevitas.onnx as bo
import finn.transformation.streamline.absorb as absorb
from finn.transformation.streamline.reorder import MakeMaxPoolNHWC
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline import Streamline
from finn.util.test import get_test_model_trained
from finn.transformation.double_to_single_float import DoubleToSingleFloat
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
from finn.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
from finn.transformation.infer_data_layouts import InferDataLayouts
import finn.core.data_layout as DataLayout
export_onnx_path_cnv = "test_output_cnv.onnx"


def test_infer_data_layouts():
    """Check InferDataLayouts on the CNV test network at three stages:
    after streamlining, after lowering convolutions to matrix multiplies,
    and after conversion to HLS layers."""
    cnv = get_test_model_trained("CNV", 1, 1)
    bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path_cnv)
    # try/finally so the exported file is removed even if an assertion or
    # transformation below fails
    try:
        model = ModelWrapper(export_onnx_path_cnv)
        model = model.transform(DoubleToSingleFloat())
        model = model.transform(InferShapes())
        model = model.transform(FoldConstants())
        model = model.transform(GiveUniqueNodeNames())
        model = model.transform(GiveReadableTensorNames())
        model = model.transform(Streamline())
        model = model.transform(InferDataLayouts())

        # stage 1: streamlined network
        assert model.get_tensor_layout("global_in") == DataLayout.NCHW
        assert model.get_tensor_layout("Conv_0_out0") == DataLayout.NCHW
        assert model.get_tensor_layout("MaxPool_0_out0") == DataLayout.NCHW
        assert model.get_tensor_layout("MultiThreshold_6_out0") == DataLayout.NCHW
        assert model.get_tensor_layout("Reshape_0_out0") == DataLayout.NC
        assert model.get_tensor_layout("MatMul_0_out0") == DataLayout.NC
        assert model.get_tensor_layout("global_out") == DataLayout.NC

        model = model.transform(LowerConvsToMatMul())
        model = model.transform(MakeMaxPoolNHWC())
        model = model.transform(GiveUniqueNodeNames())
        model = model.transform(GiveReadableTensorNames())
        model = model.transform(InferDataLayouts())

        # stage 2: convolutions lowered to matrix multiplies
        assert model.get_tensor_layout("global_in") == DataLayout.NCHW
        assert model.get_tensor_layout("Transpose_0_out0") == DataLayout.NHWC
        assert model.get_tensor_layout("Im2Col_0_out0") == DataLayout.NHWC
        # note: im2col output isn't really NHWC or any other common layout
        # since the concept of channels changes with lowering... but it is
        # conceptually close to NHWC since the innermost dim gets multiplied
        assert model.get_tensor_layout("MatMul_0_out0") == DataLayout.NHWC
        assert model.get_tensor_layout("Transpose_1_out0") == DataLayout.NCHW
        assert model.get_tensor_layout("Transpose_2_out0") == DataLayout.NHWC
        assert model.get_tensor_layout("MaxPoolNHWC_0_out0") == DataLayout.NHWC
        assert model.get_tensor_layout("Reshape_0_out0") == DataLayout.NC
        assert model.get_tensor_layout("global_out") == DataLayout.NC

        model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
        model = model.transform(ConvertBipolarMatMulToXnorPopcount())
        model = model.transform(Streamline())
        model = model.transform(to_hls.InferBinaryStreamingFCLayer())
        model = model.transform(to_hls.InferQuantizedStreamingFCLayer())
        model = model.transform(to_hls.InferConvInpGen())
        model = model.transform(to_hls.InferStreamingMaxPool())
        model = model.transform(GiveUniqueNodeNames())
        model = model.transform(GiveReadableTensorNames())
        model = model.transform(InferDataLayouts())

        # stage 3: converted to HLS layers
        assert model.get_tensor_layout("global_in") == DataLayout.NCHW
        assert model.get_tensor_layout("Transpose_0_out0") == DataLayout.NHWC
        # note: im2col output isn't really NHWC or any other common layout
        # since the concept of channels changes with lowering... but it is
        # conceptually close to NHWC since the innermost dim gets multiplied
        assert (
            model.get_tensor_layout("ConvolutionInputGenerator_0_out0")
            == DataLayout.NHWC
        )
        assert (
            model.get_tensor_layout("StreamingFCLayer_Batch_3_out0")
            == DataLayout.NHWC
        )
        assert model.get_tensor_layout("Reshape_0_out0") == DataLayout.NC
        assert (
            model.get_tensor_layout("StreamingFCLayer_Batch_6_out0") == DataLayout.NC
        )
        assert model.get_tensor_layout("global_out") == DataLayout.NC
    finally:
        os.remove(export_onnx_path_cnv)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment