Skip to content
Snippets Groups Projects
Commit dd0619ec authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

Merge remote-tracking branch 'upstream/dev' into feature/fix_globalaccpool_batch_shape

parents cd72c1cc cf54f617
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,7 @@ Requirements
* A working Vivado 2019.1 installation
* A `VIVADO_PATH` environment variable pointing to the Vivado installation directory (e.g. the directory where settings64.sh is located)
* (optional) A PYNQ board with a network connection
* the ``bitstring`` package must be installed on the PYNQ: ``sudo pip install bitstring``
* the ``bitstring`` package must be installed on the PYNQ: ``sudo pip3 install bitstring``
Running FINN in Docker
======================
......
......@@ -290,12 +290,15 @@ class StreamingFCLayer_Batch(HLSCustomOp):
return out_width
def get_weightstream_width(self):
"""Returns weight stream width. Used in decoupled mode."""
pe = self.get_nodeattr("PE")
simd = self.get_nodeattr("SIMD")
wp = self.get_weight_datatype().bitwidth()
w_width = pe * simd * wp
return w_width
"""Returns weight stream width. Used only in decoupled mode."""
if self.get_nodeattr("mem_mode") == "decoupled":
pe = self.get_nodeattr("PE")
simd = self.get_nodeattr("SIMD")
wp = self.get_weight_datatype().bitwidth()
w_width = pe * simd * wp
return w_width
else:
return 0
def get_weightstream_width_padded(self):
"""Returns weight stream width padded to a multiple of 8. This is required
......
......@@ -99,6 +99,7 @@ set_top $config_toplevelfxn
open_solution sol1
set_part $config_proj_part
config_compile -ignore_long_run_time -disable_unroll_code_size_check
config_interface -m_axi_addr64
config_rtl -auto_prefix
$EXTRA_DIRECTIVES$
......
......@@ -4,6 +4,7 @@ import onnxruntime as rt
from finn.custom_op import CustomOp
from finn.core.datatype import DataType
from finn.custom_op.maxpoolnhwc import compute_pool_output_dim
class QuantAvgPool2d(CustomOp):
......@@ -16,20 +17,51 @@ class QuantAvgPool2d(CustomOp):
"kernel": ("i", True, 1),
"ibits": ("i", True, 1),
"obits": ("i", True, 1),
# determines if values are signed (set to "1") or unsigned ("0")
"signed": ("i", True, 0),
# data layout attribute can be set to "NCHW" or "NHWC"
"data_layout": ("s", False, "NCHW"),
}
def make_shape_compatible_op(self, model):
node = self.onnx_node
k = self.get_nodeattr("kernel")
s = self.get_nodeattr("stride")
return helper.make_node(
"AveragePool",
inputs=[node.input[0]],
outputs=[node.output[0]],
kernel_shape=[k, k],
strides=[s, s],
)
data_layout = self.get_nodeattr("data_layout")
if data_layout == "NCHW":
return helper.make_node(
"AveragePool",
inputs=[node.input[0]],
outputs=[node.output[0]],
kernel_shape=[k, k],
strides=[s, s],
)
elif data_layout == "NHWC":
iname = node.input[0]
ishape = model.get_tensor_shape(iname)
(n, hi, wi, c) = ishape
ho = compute_pool_output_dim(hi, k, s)
wo = compute_pool_output_dim(wi, k, s)
oshape = (n, ho, wo, c)
# implement tensor with correct shape
values = np.random.randn(*oshape).astype(np.float32)
return helper.make_node(
"Constant",
inputs=[],
outputs=[node.output[0]],
value=helper.make_tensor(
name="const_tensor",
data_type=TensorProto.FLOAT,
dims=values.shape,
vals=values.flatten().astype(float),
),
)
else:
raise Exception(
"""Datalayout for QuantAvgPool2d is set to an invalid value.
Has to be set to "NCHW" or "NHWC"."""
)
def infer_node_datatype(self, model):
node = self.onnx_node
......@@ -48,8 +80,12 @@ class QuantAvgPool2d(CustomOp):
node = self.onnx_node
k = self.get_nodeattr("kernel")
s = self.get_nodeattr("stride")
ishape = context[node.input[0]].shape
inp_values = context[node.input[0]]
oshape = context[node.output[0]].shape
if self.get_nodeattr("data_layout") == "NHWC":
inp_values = inp_values.transpose(0, 3, 1, 2)
oshape = (context[node.output[0]]).transpose(0, 3, 1, 2).shape
ishape = inp_values.shape
inp = helper.make_tensor_value_info(node.input[0], TensorProto.FLOAT, ishape)
outp = helper.make_tensor_value_info(node.output[0], TensorProto.FLOAT, oshape)
node_avgpool = helper.make_node(
......@@ -66,7 +102,7 @@ class QuantAvgPool2d(CustomOp):
outputs=[outp],
)
model_avgpool = helper.make_model(graph_avgpool)
idict = {node.input[0]: context[node.input[0]]}
idict = {node.input[0]: inp_values}
sess = rt.InferenceSession(model_avgpool.SerializeToString())
result_temp = sess.run(None, idict)
# remove scaling introduced by average
......@@ -77,7 +113,16 @@ class QuantAvgPool2d(CustomOp):
max_bit_width = int(max_value).bit_length()
shift_bits = max_bit_width - self.get_nodeattr("obits")
result = np.right_shift(result_temp.astype(int), shift_bits)
if self.get_nodeattr("data_layout") == "NHWC":
result = result.transpose(0, 2, 3, 1)
context[node.output[0]] = result.astype(np.float32)
def verify_node(self):
pass
info_messages = []
# verify that "domain" is set to "finn"
domain_value = self.onnx_node.domain
if domain_value == "finn":
info_messages.append("Attribute domain is set correctly")
else:
info_messages.append('Attribute domain should be set to "finn"')
return info_messages
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from onnx import helper, TensorProto
from finn.transformation import Transformation
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import get_by_name
class ChangeDataLayoutQuantAvgPool2d(Transformation):
"""Replace QuantAvgPool2d with datalayout (N,C,H,W) with Transpose nodes
and QuantAvgPool2dNHWC with datalayout (N,H,W,C)"""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "QuantAvgPool2d" and (
get_by_name(n.attribute, "data_layout") is None
or get_by_name(n.attribute, "data_layout").s.decode("UTF-8") == "NCHW"
):
graph_modified = True
node_input = n.input[0]
node_output = n.output[0]
s = get_by_name(n.attribute, "stride").i
k = get_by_name(n.attribute, "kernel").i
ibits = get_by_name(n.attribute, "ibits").i
obits = get_by_name(n.attribute, "obits").i
signed = get_by_name(n.attribute, "signed").i
batchsize = model.get_tensor_shape(n.input[0])[0] # assume NCHW
channels = model.get_tensor_shape(n.input[0])[1] # assume NCHW
idim = model.get_tensor_shape(n.input[0])[-1] # assume NCHW
odim = model.get_tensor_shape(n.output[0])[-1] # assume NCHW
# create new nodes
# NCHW -> NHWC
# create new intermediate values
inp_trans_out = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
(batchsize, idim, idim, channels), # NHWC
)
graph.value_info.append(inp_trans_out)
inp_trans_out = inp_trans_out.name
quantavg_out = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
(batchsize, odim, odim, channels),
)
graph.value_info.append(quantavg_out)
quantavg_out = quantavg_out.name
inp_trans_node = helper.make_node(
"Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]
)
quantavg_node = helper.make_node(
"QuantAvgPool2d",
[inp_trans_out],
[quantavg_out],
domain="finn",
stride=s,
kernel=k,
ibits=ibits,
obits=obits,
signed=signed,
data_layout="NHWC",
)
# NHWC -> NCHW
out_trans_node = helper.make_node(
"Transpose", [quantavg_out], [node_output], perm=[0, 3, 1, 2]
)
# insert nodes
graph.node.insert(node_ind, inp_trans_node)
graph.node.insert(node_ind + 1, quantavg_node)
graph.node.insert(node_ind + 2, out_trans_node)
# remove old nodes
graph.node.remove(n)
# set shapes
model.set_tensor_shape(inp_trans_out, (batchsize, idim, idim, channels))
model.set_tensor_shape(quantavg_out, (batchsize, odim, odim, channels))
model = model.transform(InferShapes())
return (model, graph_modified)
......@@ -38,7 +38,7 @@ def _dims_to_layout(model, node, ndims):
return DataLayout.NC
else:
if node.domain == "finn":
if node.op_type == "MultiThreshold":
if node.op_type == "MultiThreshold" or node.op_type == "QuantAvgPool2d":
mt_inst = registry.getCustomOp(node)
layout = mt_inst.get_nodeattr("data_layout")
if layout == "NHWC" and ndims == 4:
......
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
from onnx import helper, TensorProto
from finn.custom_op.maxpoolnhwc import compute_pool_output_dim
from finn.core.modelwrapper import ModelWrapper
from finn.core.datatype import DataType
import finn.core.data_layout as DataLayout
from finn.transformation.change_datalayout import ChangeDataLayoutQuantAvgPool2d
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
from finn.util.basic import gen_finn_dt_tensor
from finn.util.basic import get_by_name
import finn.core.onnx_exec as oxe
# stride
@pytest.mark.parametrize("s", [1, 2])
# kernel
@pytest.mark.parametrize("k", [3, 4])
# ibits
@pytest.mark.parametrize("ibits", [4, 8])
# obits
@pytest.mark.parametrize("obits", [2, 4])
# signed
@pytest.mark.parametrize("signed", [False, True])
# channels
@pytest.mark.parametrize("c", [2, 3])
# input dimension
@pytest.mark.parametrize("idim", [6, 7])
def test_change_datalayout_quantavgpool(s, k, ibits, obits, signed, c, idim):
n = 1
odim = compute_pool_output_dim(idim, k, s)
# determine input FINN datatype
if signed is True:
prefix = "INT"
else:
prefix = "UINT"
dt_name = prefix + str(ibits)
dtype = DataType[dt_name]
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [n, c, idim, idim])
outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [n, c, odim, odim])
node = helper.make_node(
"QuantAvgPool2d",
["inp"],
["outp"],
domain="finn",
stride=s,
kernel=k,
ibits=ibits,
obits=obits,
signed=signed,
data_layout="NCHW",
)
graph = helper.make_graph(
nodes=[node], name="single-quantavgpool", inputs=[inp], outputs=[outp]
)
model = helper.make_model(graph)
model = ModelWrapper(model)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model = model.transform(InferDataLayouts())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model_transformed = model.transform(ChangeDataLayoutQuantAvgPool2d())
model_transformed = model_transformed.transform(InferShapes())
model_transformed = model_transformed.transform(InferDataTypes())
model_transformed = model_transformed.transform(InferDataLayouts())
model_transformed = model_transformed.transform(GiveUniqueNodeNames())
model_transformed = model_transformed.transform(GiveReadableTensorNames())
inp_values = gen_finn_dt_tensor(dtype, [n, c, idim, idim])
idict = {"inp": inp_values}
assert oxe.compare_execution(model, model_transformed, idict)
assert len(model.graph.node) + 2 == len(model_transformed.graph.node)
assert model_transformed.graph.node[-1].op_type == "Transpose"
assert model_transformed.graph.node[0].op_type == "Transpose"
# check if QuantAvgPool2d node has datalayout set correctly
node = model_transformed.graph.node[1]
d_layout = get_by_name(node.attribute, "data_layout").s.decode("UTF-8")
assert d_layout == "NHWC"
assert model_transformed.get_tensor_layout(node.input[0]) == DataLayout.NHWC
assert model_transformed.get_tensor_layout(node.output[0]) == DataLayout.NHWC
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment