Skip to content
Snippets Groups Projects
Commit 1fc1d1d0 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

Merge branch 'feature/nhwc_flatten' of https://github.com/fpjentzsch/finn into...

Merge branch 'feature/nhwc_flatten' of https://github.com/fpjentzsch/finn into fpjentzsch-feature/nhwc_flatten
parents 4e0a1051 4e50e7ab
No related branches found
No related tags found
No related merge requests found
......@@ -29,11 +29,9 @@ def _suitable_node(node):
def _suitable_folded_shapes(ishape, oshape):
i_dummy = np.random.rand(*ishape)
o_dummy = np.random.rand(*oshape)
ishape_canonical = np.squeeze(i_dummy).shape
oshape_canonical = np.squeeze(o_dummy).shape
return ishape_canonical == oshape_canonical
matching_stream_width = ishape[-1] == oshape[-1]
matching_size = np.prod(ishape) == np.prod(oshape)
return matching_stream_width and matching_size
class InsertFIFO(Transformation):
......
from finn.transformation.base import Transformation
from finn.util.basic import get_by_name, is_finn_op
from finn.custom_op.registry import getCustomOp
import warnings
def _is_fpgadataflow_node(node):
......@@ -18,33 +20,66 @@ def _is_fpgadataflow_node(node):
class RemoveCNVtoFCFlatten(Transformation):
    """Removes a flatten node if it is between two fpgadataflow nodes.

    For an NHWC-Conv to FC transition, the preceding transpose is absorbed.
    The flatten operation can also be implemented by a reshape node.
    """

    def apply(self, model):
        """Remove flatten nodes (and an absorbable preceding transpose).

        Args:
            model: the model wrapper whose graph is rewritten in place.

        Returns:
            tuple: ``(model, graph_modified)`` where ``graph_modified`` is True
            if at least one node was removed.
        """
        graph = model.graph
        graph_modified = False
        for n in graph.node:
            # also support implicit flatten via reshape, e.g. reshape(1,-1)
            if n.op_type not in ("Flatten", "Reshape"):
                continue
            ishape = model.get_tensor_shape(n.input[0])
            oshape = model.get_tensor_shape(n.output[0])
            # only a 2D output that keeps the leading dim counts as a flatten
            if not (len(oshape) == 2 and ishape[0] == oshape[0]):
                continue
            producer = model.find_producer(n.input[0])
            if _is_fpgadataflow_node(producer) is True:
                # standalone flatten, remove
                consumer = model.find_consumer(n.output[0])
                if _is_fpgadataflow_node(consumer) is True:
                    graph_modified = True
                    consumer.input[0] = n.input[0]
                    graph.node.remove(n)
            elif producer is not None and producer.op_type == "Transpose":
                # transpose + flatten, absorb into following node
                if self._absorb_transpose_flatten(model, producer, n):
                    graph_modified = True
        return (model, graph_modified)

    def _absorb_transpose_flatten(self, model, transp_node, flatten_node):
        """Fold a transpose->flatten pair into the following FC layer's weights.

        Returns True if both nodes were removed from the graph, else False.
        """
        # check if transpose converts NHWC to NCHW
        # NOTE(review): assumes get_by_name returns None for a missing
        # attribute -- confirm against finn.util.basic.
        perm_attr = get_by_name(transp_node.attribute, "perm")
        perms = list(perm_attr.ints) if perm_attr is not None else None
        if perms != [0, 3, 1, 2]:
            warnings.warn("Unsupported transpose node before flatten layer")
            return False
        producer = model.find_producer(transp_node.input[0])
        if _is_fpgadataflow_node(producer) is not True:
            return False
        consumer = model.find_consumer(flatten_node.output[0])
        if consumer is None or consumer.op_type != "StreamingFCLayer_Batch":
            # single-line literal: a backslash continuation would embed the
            # source indentation in the emitted message
            warnings.warn("Could not absorb transpose->flatten into subsequent node")
            return False
        fc_inst = getCustomOp(consumer)
        mw = fc_inst.get_nodeattr("MW")
        mh = fc_inst.get_nodeattr("MH")
        (_, h, w, c) = model.get_tensor_shape(transp_node.input[0])
        # absorb transpose into weight matrix,
        # allowing FC layer to operate on the NHWC input
        W = model.get_initializer(consumer.input[1])
        assert W is not None, "Initializer for matmul weights is not set."
        # weight rows are viewed as (c, h, w) and reordered to (h, w, c)
        W_new = W.reshape(c, h, w, mh)
        W_new = W_new.transpose((1, 2, 0, 3))
        W_new = W_new.reshape(mw, mh)
        model.set_initializer(consumer.input[1], W_new)
        # remove transpose & flatten nodes
        consumer.input[0] = transp_node.input[0]
        model.graph.node.remove(flatten_node)
        model.graph.node.remove(transp_node)
        return True
......@@ -309,7 +309,8 @@ class Absorb1BitMulIntoConv(Transformation):
class AbsorbTransposeIntoMultiThreshold(Transformation):
"""Change (NHWCTranpose -> MultiThreshold -> NCHWTranspose) to (MultiThreshold)
with NHWC mode."""
with NHWC mode. For (NHWCTranpose -> MultiThreshold -> Flatten), move Transpose
past MultiThreshold to prepare for the RemoveCNVtoFCFlatten() transformation."""
def apply(self, model):
graph = model.graph
......@@ -338,23 +339,34 @@ class AbsorbTransposeIntoMultiThreshold(Transformation):
graph.node.remove(n)
graph.node.remove(final_t_cand)
graph_modified = True
elif final_t_cand.op_type == "Reshape":
# also support implicit flatten via reshape, e.g. reshape(1,-1)
elif (
final_t_cand.op_type == "Flatten"
or final_t_cand.op_type == "Reshape"
):
ishape = model.get_tensor_shape(final_t_cand.input[0])
oshape = model.get_tensor_shape(final_t_cand.output[0])
if len(oshape) == 2:
if len(oshape) == 2 and ishape[0] == oshape[0]:
# transition to FC part, can still use NHWC
mt = getCustomOp(mt_cand)
mt.set_nodeattr("data_layout", "NHWC")
# get rid of first tranpose node
mt_cand.input[0] = n.input[0]
graph.node.remove(n)
# fix output shape for MultiThreshold
mt_ishape = model.get_tensor_shape(mt_cand.input[0])
(b, h, w, c) = mt_ishape
assert (
h == 1 and w == 1
), """Untested spatial dim
in conv->fc transition, proceed with caution!"""
model.set_tensor_shape(mt_cand.output[0], mt_ishape)
graph.node.remove(n)
# re-insert Transpose behind MultiThreshold
transpose_output = model.make_new_valueinfo_name()
new_transpose = oh.make_node(
"Transpose",
[mt_cand.output[0]],
[transpose_output],
perm=[0, 3, 1, 2],
)
graph.node.insert(node_ind + 1, new_transpose)
final_t_cand.input[0] = transpose_output
graph_modified = True
if graph_modified:
model = model.transform(InferDataTypes())
......
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from onnx import TensorProto, helper
import numpy as np
import pytest
from finn.core.datatype import DataType
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.general import GiveUniqueNodeNames
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.util.basic import gen_finn_dt_tensor
import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.custom_op.general.im2col import compute_conv_output_dim
import finn.transformation.streamline.absorb as absorb
from finn.transformation.general import RemoveUnusedTensors
from finn.transformation.streamline import Streamline
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
from finn.transformation.streamline.reorder import MoveScalarLinearPastInvariants
import finn.core.data_layout as DataLayout
def get_multithreshold_rand_params(channels, num_of_thres, seed=None):
    """Generate random, rounded threshold values for a MultiThreshold node.

    For every channel a random step size in [0, 30) and a random bias in
    (-10, 0] are drawn; threshold ``k`` of a channel is
    ``round((k + bias) * step)``.

    Args:
        channels: number of channels (rows of the result).
        num_of_thres: number of thresholds per channel (columns of the result).
        seed: if not None, NumPy's *global* random state is re-seeded with
            this value before drawing, which also affects later
            ``np.random`` calls made by the caller.

    Returns:
        np.ndarray: float32 array of shape ``(channels, num_of_thres)`` holding
        integer-valued thresholds.
    """
    if seed is not None:
        np.random.seed(seed)
    # draw order (steps, then bias) is kept so seeded results stay stable
    steps = np.random.rand(channels, 1) * 30
    bias = np.random.rand(channels, 1) * -10
    # (num_of_thres,) broadcasts against (channels, 1) -> (channels, num_of_thres)
    thres = (np.arange(num_of_thres) + bias) * steps
    return np.round(thres.astype(np.float32))
# conv_config: input_shape, kernel_shape, stride, pad
@pytest.mark.parametrize(
    "conv_config",
    [
        ((6, 6), (3, 3), (1, 1), (1, 1)),
        # TODO: enable 1d conv test cases
        # ((12, 1), (3, 1), (1, 1), (1, 0)),
        # ((1, 15), (1, 5), (1, 1), (0, 2)),
    ],
)
@pytest.mark.parametrize("depthwise", [False, True])
@pytest.mark.parametrize("use_reshape", [False, True])
def test_convert_to_hls_conv_fc_transition(conv_config, depthwise, use_reshape):
    """Check the Conv -> MultiThreshold -> flatten -> MatMul -> MultiThreshold
    conversion to HLS layers.

    The flatten step is built either as a ``Flatten`` node or as a
    ``Reshape`` to ``(1, -1)`` (``use_reshape``). After streamlining and HLS
    conversion the cppsim model must produce the same output as the original
    model, and no Flatten/Reshape node may remain.
    """
    np.random.seed(0)
    idt = DataType.UINT4
    odt = DataType.UINT4
    conv_weight_dt = DataType.INT4
    fc_weight_dt = DataType.INT4

    input_shape, kernel_shape, stride, pad = conv_config
    kernel_size_h, kernel_size_w = kernel_shape
    input_size_h, input_size_w = input_shape
    stride_h, stride_w = stride
    pad_h, pad_w = pad

    in_chn = 4
    fc_filters = 16

    if depthwise is True:
        group = out_chn = in_chn
        conv_param_shape = [out_chn, 1, kernel_size_h, kernel_size_w]
    else:
        group = 1
        out_chn = 8
        conv_param_shape = [out_chn, in_chn, kernel_size_h, kernel_size_w]

    output_size_h = compute_conv_output_dim(
        input_size_h, kernel_size_h, stride_h, 2 * pad_h
    )
    output_size_w = compute_conv_output_dim(
        input_size_w, kernel_size_w, stride_w, 2 * pad_w
    )

    input_shape = [1, in_chn, input_size_h, input_size_w]
    fc_param_shape = [out_chn * output_size_h * output_size_w, fc_filters]
    output_shape = [1, fc_filters]

    # ONNX Conv attributes; kept in a separate name so the parametrized
    # ``conv_config`` tuple is not shadowed
    conv_attrs = {}
    conv_attrs["dilations"] = [1, 1]
    conv_attrs["group"] = group
    conv_attrs["kernel_shape"] = [kernel_size_h, kernel_size_w]
    conv_attrs["pads"] = [pad_h, pad_w, pad_h, pad_w]
    conv_attrs["strides"] = [stride_h, stride_w]

    global_in = helper.make_tensor_value_info(
        "global_in", TensorProto.FLOAT, input_shape
    )
    global_out = helper.make_tensor_value_info(
        "global_out", TensorProto.FLOAT, output_shape
    )
    value_info = [
        helper.make_tensor_value_info(
            "conv_param", TensorProto.FLOAT, conv_param_shape
        ),
        helper.make_tensor_value_info("thres1_param", TensorProto.FLOAT, (out_chn, 15)),
        helper.make_tensor_value_info(
            "matmul_param", TensorProto.FLOAT, fc_param_shape
        ),
        helper.make_tensor_value_info(
            "thres2_param", TensorProto.FLOAT, (fc_filters, 15)
        ),
        helper.make_tensor_value_info("reshape_shape", TensorProto.INT64, []),
    ]

    # the flatten step can be expressed either way; both must be handled
    if use_reshape:
        flatten_node = helper.make_node(
            "Reshape", ["thres1_out", "reshape_shape"], ["flatten_out"]
        )
    else:
        flatten_node = helper.make_node(
            "Flatten", ["thres1_out"], ["flatten_out"], axis=1
        )

    modelproto = helper.make_model(
        helper.make_graph(
            name="test",
            inputs=[global_in],
            outputs=[global_out],
            value_info=value_info,
            nodes=[
                helper.make_node(
                    "Conv", ["global_in", "conv_param"], ["conv_out"], **conv_attrs
                ),
                helper.make_node(
                    "MultiThreshold",
                    ["conv_out", "thres1_param"],
                    ["thres1_out"],
                    domain="finn.custom_op.general",
                    out_dtype="UINT4",
                ),
                flatten_node,
                helper.make_node(
                    "MatMul", ["flatten_out", "matmul_param"], ["matmul_out"]
                ),
                helper.make_node(
                    "MultiThreshold",
                    ["matmul_out", "thres2_param"],
                    ["global_out"],
                    domain="finn.custom_op.general",
                    out_dtype="UINT4",
                ),
            ],
        )
    )

    model = ModelWrapper(modelproto)
    model.set_tensor_datatype("global_in", idt)
    model.set_tensor_layout("global_in", DataLayout.NCHW)
    model.set_tensor_datatype("global_out", odt)
    model.set_tensor_datatype("conv_param", conv_weight_dt)
    model.set_tensor_datatype("matmul_param", fc_weight_dt)
    model.set_tensor_datatype("thres1_param", DataType.INT32)
    model.set_tensor_datatype("thres2_param", DataType.INT32)

    model.set_initializer(
        "conv_param", gen_finn_dt_tensor(conv_weight_dt, conv_param_shape)
    )
    model.set_initializer(
        "thres1_param", get_multithreshold_rand_params(out_chn, 15, seed=0)
    )
    model.set_initializer(
        "thres2_param", get_multithreshold_rand_params(fc_filters, 15, seed=0)
    )
    model.set_initializer(
        "matmul_param", gen_finn_dt_tensor(fc_weight_dt, fc_param_shape)
    )
    # Reshape's shape input is declared INT64 above; pin the dtype so it does
    # not depend on the platform's default integer width
    model.set_initializer("reshape_shape", np.array([1, -1], dtype=np.int64))

    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    model = model.transform(InferDataLayouts())

    # streamlining
    new_model = model.transform(MoveScalarLinearPastInvariants())
    new_model = new_model.transform(Streamline())
    new_model = new_model.transform(LowerConvsToMatMul())
    new_model = new_model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
    new_model = new_model.transform(Streamline())
    new_model = new_model.transform(InferDataLayouts())
    new_model = new_model.transform(RemoveUnusedTensors())

    # convert_to_hls
    if depthwise is True:
        new_model = new_model.transform(to_hls.InferVVAU())
    new_model = new_model.transform(to_hls.InferQuantizedStreamingFCLayer())
    new_model = new_model.transform(to_hls.InferThresholdingLayer())
    new_model = new_model.transform(to_hls.InferConvInpGen())
    new_model = new_model.transform(to_hls.InferStreamingMaxPool())
    new_model = new_model.transform(RemoveCNVtoFCFlatten())
    new_model = new_model.transform(absorb.AbsorbConsecutiveTransposes())
    new_model = new_model.transform(GiveUniqueNodeNames())
    new_model = new_model.transform(InferDataLayouts())

    # prepare cppsim
    new_model = new_model.transform(PrepareCppSim())
    new_model = new_model.transform(CompileCppSim())
    new_model = new_model.transform(SetExecMode("cppsim"))

    # check for correct execution
    x = gen_finn_dt_tensor(idt, input_shape)
    inp_dict = {model.graph.input[0].name: x}
    assert oxe.compare_execution(model, new_model, inp_dict)

    num_transpose = len(new_model.get_nodes_by_op_type("Transpose"))
    num_flatten = len(new_model.get_nodes_by_op_type("Flatten"))
    num_reshape = len(new_model.get_nodes_by_op_type("Reshape"))

    # check if transpose->flatten was removed; separate asserts so a failure
    # names the offending node type
    assert num_transpose == 1
    assert num_flatten == 0
    assert num_reshape == 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment