diff --git a/run-docker.sh b/run-docker.sh
index d63c84beac2aceb2820c1ad42c50cba9a5ed5009..e8a3d105ea558bec93d8c75de6924def166cbe5a 100755
--- a/run-docker.sh
+++ b/run-docker.sh
@@ -84,7 +84,7 @@ VIVADO_IP_CACHE=$BUILD_LOCAL/vivado_ip_cache
 
 # clone dependency repos
 git clone --branch feature/finn_onnx_export $BREVITAS_REPO $BREVITAS_LOCAL || git -C "$BREVITAS_LOCAL" pull
-git clone $EXAMPLES_REPO $EXAMPLES_LOCAL || git -C "$EXAMPLES_LOCAL" pull
+git clone $EXAMPLES_REPO $EXAMPLES_LOCAL || git -C "$EXAMPLES_LOCAL" checkout feature/rework_scaling_clipping; git -C "$EXAMPLES_LOCAL" pull
 git clone $CNPY_REPO $CNPY_LOCAL || git -C "$CNPY_LOCAL" pull
 git clone $FINN_HLS_REPO $FINN_HLS_LOCAL || git -C "$FINN_HLS_LOCAL" checkout master; git -C "$FINN_HLS_LOCAL" pull
 git clone $PYVERILATOR_REPO $PYVERILATOR_LOCAL || git -C "$PYVERILATOR_LOCAL" pull
diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py
index 73bdbc4177867350eecf75cef0943b01522e8508..56c49e66fad50b72703447a091876121ad80e300 100644
--- a/src/finn/custom_op/multithreshold.py
+++ b/src/finn/custom_op/multithreshold.py
@@ -50,7 +50,9 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None):
     or equal to. The output tensor will be scaled by out_scale and biased
     by out_bias."""
-    # the inputs are expected to be in the shape (N,C,H,W)
+    # the inputs are expected to be in the shape (N,C,H,W) or (N, C)
+    # the MultiThreshold node supports a data_layout attribute that can be set
+    # to 'NHWC' so that both input and output use the (N,H,W,C) layout instead
     # N : Batch size
     # C : Number of channels
     # H : Height of the input images
     # W : Width of the input images
@@ -104,6 +106,7 @@ class MultiThreshold(CustomOp):
             "out_dtype": ("s", True, ""),
             "out_scale": ("f", False, 1.0),
             "out_bias": ("f", False, 0.0),
+            "data_layout": ("s", False, "NCHW"),
         }
 
     def make_shape_compatible_op(self):
@@ -123,26 +126,40 @@ class MultiThreshold(CustomOp):
         # retrieve attributes if output scaling is used
         out_scale = self.get_nodeattr("out_scale")
         out_bias = self.get_nodeattr("out_bias")
+        # transpose input if NHWC data layout is chosen
+        data_layout = self.get_nodeattr("data_layout")
+        if data_layout == "NHWC":
+            if v.ndim == 4:
+                # NHWC -> NCHW
+                v = np.transpose(v, (0, 3, 1, 2))
+            elif v.ndim == 2:
+                # no HW dimension means NHWC and NCHW layouts are equivalent
+                pass
+            else:
+                raise Exception(
+                    "Unknown data_layout and input ndim"
+                    " combination for MultiThreshold."
+                )
         # calculate output
         output = multithreshold(v, thresholds, out_scale, out_bias)
         # setting context according to output
+        if data_layout == "NHWC":
+            if output.ndim == 4:
+                # NCHW -> NHWC
+                output = np.transpose(output, (0, 2, 3, 1))
+            elif output.ndim == 2:
+                # no HW dimension means NHWC and NCHW layouts are equivalent
+                pass
+            else:
+                raise Exception(
+                    "Unknown data_layout and output ndim"
+                    " combination for MultiThreshold."
+                )
         context[node.output[0]] = output
 
     def verify_node(self):
         info_messages = []
 
-        # verify number of attributes
-        num_of_attr = 3
-        if len(self.onnx_node.attribute) == num_of_attr:
-            info_messages.append("The number of attributes is correct")
-        else:
-            info_messages.append(
-                """The number of attributes is incorrect,
-            {} should have {} attributes""".format(
-                    self.onnx_node.op_type, num_of_attr
-                )
-            )
-
         # verify that "domain" is set to "finn"
         domain_value = self.onnx_node.domain
         if domain_value == "finn":
@@ -152,8 +169,6 @@ class MultiThreshold(CustomOp):
 
         # verify that all necessary attributes exist
         try:
-            self.get_nodeattr("out_scale")
-            self.get_nodeattr("out_bias")
             self.get_nodeattr("out_dtype")
             info_messages.append("All necessary attributes exist")
         except Exception:
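
The new data_layout attribute defaults to "NCHW", so existing graphs are unaffected. As a minimal usage sketch (mirroring the test added to tests/core/test_custom_onnx_exec.py later in this diff; tensor names are illustrative), the node is built like any other FINN custom op, with the layout passed as an extra ONNX attribute:

    from onnx import helper

    node_def = helper.make_node(
        "MultiThreshold",
        ["v", "thresholds"],  # input and threshold tensor names (illustrative)
        ["out"],
        domain="finn",
        data_layout="NHWC",  # omit to keep the default "NCHW"
    )

At execution time the input is transposed NHWC -> NCHW, thresholded, and the result transposed back, so the threshold tensor itself always stays channel-major.
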
diff --git a/src/finn/transformation/double_to_single_float.py b/src/finn/transformation/double_to_single_float.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f7eb1cc82b346321eb6ee711b87c2c98ccfdaee
--- /dev/null
+++ b/src/finn/transformation/double_to_single_float.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from finn.transformation import Transformation
+import numpy as np
+
+
+class DoubleToSingleFloat(Transformation):
+    """Convert any float64 initializers to float32."""
+
+    def apply(self, model):
+        graph_modified = False
+        init_names = [x.name for x in model.graph.initializer]
+        for nm in init_names:
+            init = model.get_initializer(nm)
+            if init.dtype == np.float64:
+                init_f32 = init.astype(np.float32)
+                model.set_initializer(nm, init_f32)
+                graph_modified = True
+        return (model, graph_modified)
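
DoubleToSingleFloat follows the standard FINN transformation interface, so applying it is a one-liner; a minimal sketch, assuming model is a ModelWrapper:

    from finn.transformation.double_to_single_float import DoubleToSingleFloat

    model = model.transform(DoubleToSingleFloat())

The tests below run it directly after Brevitas export and before shape inference, since exported graphs may carry float64 initializers while the rest of the flow works in float32.
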
diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index 176772be902474b1ca0ff96ec6b8f88304749550..53c73e1dc4fe0bfab53e3f126add992cb338c11d 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -82,7 +82,7 @@ class GiveReadableTensorNames(Transformation):
 
 
 class ConvertSubToAdd(Transformation):
-    """Convert sub nodes to add nodes of appropriate sign."""
+    """Convert subtract-a-constant nodes to add-a-constant nodes."""
 
     def apply(self, model):
         graph = model.graph
@@ -94,3 +94,18 @@ class ConvertSubToAdd(Transformation):
                     model.set_initializer(n.input[1], -A)
         # return model_was_changed = False as single iteration is always enough
         return (model, False)
+
+
+class ConvertDivToMul(Transformation):
+    """Convert divide by constant nodes to multiply by constant nodes."""
+
+    def apply(self, model):
+        graph = model.graph
+        for n in graph.node:
+            if n.op_type == "Div":
+                A = model.get_initializer(n.input[1])
+                if A is not None:
+                    n.op_type = "Mul"
+                    model.set_initializer(n.input[1], 1.0 / A)
+        # return model_was_changed = False as single iteration is always enough
+        return (model, False)
diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py
index b5a5bd2f65b41fdb1d0e1c048949c206adfa357b..c9c73fa4c8303ee28bc1cc6aee879d633740e01e 100644
--- a/src/finn/transformation/streamline/__init__.py
+++ b/src/finn/transformation/streamline/__init__.py
@@ -30,6 +30,7 @@ from finn.transformation import Transformation
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.general import (
     ConvertSubToAdd,
+    ConvertDivToMul,
     GiveReadableTensorNames,
     GiveUniqueNodeNames,
 )
@@ -39,6 +40,7 @@ from finn.transformation.streamline.absorb import (
     AbsorbMulIntoMultiThreshold,
     FactorOutMulSignMagnitude,
     Absorb1BitMulIntoMatMul,
+    Absorb1BitMulIntoConv,
 )
 
 from finn.transformation.streamline.collapse_repeated import (
@@ -50,6 +52,8 @@ from finn.transformation.streamline.reorder import (
     MoveAddPastMul,
     MoveScalarMulPastMatMul,
     MoveScalarAddPastMatMul,
+    MoveScalarAddPastConv,
+    MoveScalarMulPastConv,
 )
 
 from finn.transformation.streamline.round_thresholds import RoundAndClipThresholds
@@ -63,11 +67,14 @@ class Streamline(Transformation):
     def apply(self, model):
         streamline_transformations = [
             ConvertSubToAdd(),
+            ConvertDivToMul(),
             BatchNormToAffine(),
             ConvertSignToThres(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
+            MoveScalarAddPastConv(),
             MoveScalarMulPastMatMul(),
+            MoveScalarMulPastConv(),
             MoveAddPastMul(),
             CollapseRepeatedAdd(),
             CollapseRepeatedMul(),
@@ -75,6 +82,7 @@ class Streamline(Transformation):
             FactorOutMulSignMagnitude(),
             AbsorbMulIntoMultiThreshold(),
             Absorb1BitMulIntoMatMul(),
+            Absorb1BitMulIntoConv(),
             RoundAndClipThresholds(),
         ]
         for trn in streamline_transformations:
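
ConvertDivToMul relies on the identity x / A == x * (1 / A) for a constant divisor A, and only fires when the divisor is an initializer (get_initializer returns None for dynamic inputs). A minimal numpy sketch of the equivalence being exploited:

    import numpy as np

    x = np.array([2.0, 4.0, 8.0], dtype=np.float32)
    A = np.float32(4.0)
    # dividing by A equals multiplying by its reciprocal
    assert np.allclose(x / A, x * (np.float32(1.0) / A))

Rewriting Div to Mul early in the Streamline pipeline lets the later Mul-oriented passes (move-past, collapse, absorb) treat all scale factors uniformly.
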
diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 5a9355bcfefc57f9608490a156906e65f7672271..92945a9eff1cc45ce295ccd76b40a39b429f45f8 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -55,7 +55,8 @@ class AbsorbAddIntoMultiThreshold(Transformation):
                     start_name = n.input[0]
                     # we can only absorb 0d or 1d adds
                     is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
-                    is_1d = A.ndim > 0 and np.prod(A.shape) == A.shape[-1]
+                    actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
+                    is_1d = actual_ndims == 1
                     if is_scalar or is_1d:
                         Tnew = T - A.reshape(-1, 1)
                         # Tnew = T - A.reshape(-1, T.shape[1])
@@ -85,7 +86,8 @@ class AbsorbMulIntoMultiThreshold(Transformation):
                 assert A is not None, "Initializer for mul weights is not set."
                 is_signed = (A < 0).any()
                 is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
-                is_1d = A.ndim > 0 and np.prod(A.shape) == A.shape[-1]
+                actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
+                is_1d = actual_ndims == 1
                 consumer = model.find_consumer(n.output[0])
                 if consumer is not None and consumer.op_type == "MultiThreshold":
                     if not is_signed and (is_1d or is_scalar):
@@ -122,7 +124,8 @@ class FactorOutMulSignMagnitude(Transformation):
                 A = model.get_initializer(mul_weight_name)
                 assert A is not None, "Initializer for mul weights is not set."
                 is_scalar = np.prod(A.shape) == 1
-                is_1d = len(A.shape) == 2 and A.shape[0] == 1
+                actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
+                is_1d = actual_ndims == 1
                 is_not_bipolar = (
                     model.get_tensor_datatype(mul_weight_name) != DataType.BIPOLAR
                 )
@@ -183,3 +186,44 @@ class Absorb1BitMulIntoMatMul(Transformation):
                         graph.node.remove(consumer)
                         graph_modified = True
         return (model, graph_modified)
+
+
+class Absorb1BitMulIntoConv(Transformation):
+    """Absorb bipolar or binary multiplications into the preceding convolution."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Conv":
+                conv_weight_name = n.input[1]
+                W = model.get_initializer(conv_weight_name)
+                Wdt = model.get_tensor_datatype(conv_weight_name)
+                assert W is not None, "Initializer for conv weights is not set."
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "Mul":
+                    mul_weight_name = consumer.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    assert A is not None, "Initializer for mul weights is not set."
+                    is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
+                    is_scalar = np.prod(A.shape) == 1
+                    actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
+                    is_1d = actual_ndims == 1
+                    if is_1bit and (is_1d or is_scalar):
+                        # move the mul to the OFM position, since the mul is
+                        # applied on the outputs channelwise or as scalar
+                        Wnew = A.reshape(-1, 1, 1, 1) * W
+                        assert (
+                            Wnew.shape == W.shape
+                        ), """Shape of new weights is not
+                        the same as the shape of the conv weights before."""
+                        check_fxn = np.vectorize(lambda x: Wdt.allowed(x))
+                        # only absorb if permitted by W datatype
+                        if check_fxn(Wnew).all():
+                            model.set_initializer(conv_weight_name, Wnew)
+                            n.output[0] = consumer.output[0]
+                            graph.node.remove(consumer)
+                            graph_modified = True
+        return (model, graph_modified)
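
The shared actual_ndims test now treats a tensor of shape (1, 64, 1, 1) as 1-D, since only one axis has extent greater than one. Absorb1BitMulIntoConv itself is sound because a channelwise (or scalar) scale after a convolution can equally be applied to the filters: scaling output channel c by a_c scales every weight of filter c. A minimal numpy sketch with illustrative shapes (OIHW weight layout, as ONNX Conv uses):

    import numpy as np

    W = np.random.randn(64, 3, 3, 3).astype(np.float32)  # (OFM, IFM, kH, kW)
    A = np.sign(np.random.randn(64)).astype(np.float32)  # bipolar channel scales
    Wnew = A.reshape(-1, 1, 1, 1) * W  # same reshape the transformation uses

Since A is restricted to 1-bit values, Wnew can at most flip or zero out weights; the check_fxn guard still verifies every scaled weight is representable in the original weight datatype before the change is committed.
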
+ start_name = n.input[0] + end_name = consumer.output[0] + conv_out_shape = model.get_tensor_shape(end_name) + if all(x == 1 for x in A.shape): + # create a tensor filled with the add constant, in + # the shape expected by the convolution + conv_in_const = np.zeros(conv_in_shape, dtype=np.float32) + conv_in_const.fill(A.item()) + # create an execution context and put in const input + exec_ctx = model.make_empty_exec_context() + exec_ctx[conv_in_name] = conv_in_const + # execute the conv node only + execute_node(conv_node, exec_ctx, model.graph) + # retrieve the conv output + Anew = exec_ctx[end_name] + # strip out repetition + Anew = Anew[0, :, 0, 0].reshape(1, -1, 1, 1) + # update the add weight + model.set_initializer(add_weight_name, Anew) + # rewire add input to be conv input + conv_node.input[0] = start_name + model.set_tensor_shape(start_name, conv_in_shape) + # use old conv input tensor as conv output + conv_node.output[0] = conv_in_name + model.set_tensor_shape(conv_in_name, conv_out_shape) + # use new conv output as new add node input + add_node.input[0] = conv_in_name + # use old conv output as new add node output + add_node.output[0] = end_name + # move add node past conv node + graph.node.remove(add_node) + graph.node.insert(node_ind, add_node) + graph_modified = True + model = model.transform(InferShapes()) + return (model, graph_modified) + + +class MoveScalarMulPastConv(Transformation): + """Move scalar mul operations past conv operations. We want to have muls + next to each other such that they can be collapsed into a single mul.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Mul": + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "Conv": + mul_weight_name = n.input[1] + A = model.get_initializer(mul_weight_name) + assert A is not None, "Initializer for mul weights is not set." 
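
Both reorderings exploit the linearity of convolution. For a scalar mul the swap is exact as-is, conv(a * x) == a * conv(x); for a scalar add, conv(x + c) == conv(x) + conv(c), so the pass runs the conv once on a constant-filled tensor to turn the pre-conv offset into an equivalent post-conv, per-channel add. A minimal numpy sketch of both identities, with the conv reduced to a dot product over one window (illustrative names, bias-free conv assumed):

    import numpy as np

    w = np.random.randn(16).astype(np.float32)  # one flattened conv window
    x = np.random.randn(16).astype(np.float32)
    a, c = np.float32(2.0), np.float32(0.5)
    ones = np.ones(16, dtype=np.float32)
    assert np.isclose(np.dot(w, a * x), a * np.dot(w, x))
    assert np.isclose(np.dot(w, x + c), np.dot(w, x) + np.dot(w, c * ones))

The strip-out step Anew[0, :, 0, 0] in MoveScalarAddPastConv assumes the conv's response to a constant input is the same at every spatial position (one value per output channel), which holds when border effects such as padding do not vary that response.
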
diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py
index ecc7cb192177cb2bb57a8d2efdfea91f22a488a1..f99a453d05d7cb3c824784e80103b6021f072a79 100644
--- a/src/finn/util/basic.py
+++ b/src/finn/util/basic.py
@@ -40,6 +40,9 @@ from finn.core.datatype import DataType
 pynq_part_map = dict()
 pynq_part_map["Ultra96"] = "xczu3eg-sbva484-1-e"
 pynq_part_map["Pynq-Z1"] = "xc7z020clg400-1"
+pynq_part_map["Pynq-Z2"] = "xc7z020clg400-1"
+pynq_part_map["ZCU104"] = "xczu7ev-ffvc1156-2-e"
+
 
 def get_finn_root():
diff --git a/src/finn/util/test.py b/src/finn/util/test.py
index 428ac3ea63e6913ed12364785b3ebfae527d1fdb..34edc3cacdecc461d1254c35c026c56ff8813549 100644
--- a/src/finn/util/test.py
+++ b/src/finn/util/test.py
@@ -53,7 +53,11 @@ def get_test_model_trained(netname, wbits, abits):
     and activations from the FINN Brevitas test networks."""
     model_def_fxn = get_test_model_def_fxn(netname)
     checkpoint_loc = get_trained_checkpoint(netname, wbits, abits)
-    fc = model_def_fxn(weight_bit_width=wbits, act_bit_width=abits, in_bit_width=abits)
+    if netname == "CNV":
+        ibits = 8
+    else:
+        ibits = abits
+    fc = model_def_fxn(weight_bit_width=wbits, act_bit_width=abits, in_bit_width=ibits)
     checkpoint = torch.load(checkpoint_loc, map_location="cpu")
     fc.load_state_dict(checkpoint["state_dict"])
     return fc.eval()
@@ -62,5 +66,9 @@ def get_test_model_untrained(netname, wbits, abits):
     """Returns untrained model specified by input arguments."""
     model_def_fxn = get_test_model_def_fxn(netname)
-    fc = model_def_fxn(weight_bit_width=wbits, act_bit_width=abits, in_bit_width=abits)
+    if netname == "CNV":
+        ibits = 8
+    else:
+        ibits = abits
+    fc = model_def_fxn(weight_bit_width=wbits, act_bit_width=abits, in_bit_width=ibits)
     return fc.eval()
diff --git a/tests/brevitas/test_brevitas_act_export.py b/tests/brevitas/test_brevitas_act_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c4a99151d1105ad4258a8d7d6c19cc72da7a99
--- /dev/null
+++ b/tests/brevitas/test_brevitas_act_export.py
@@ -0,0 +1,43 @@
+import numpy as np
+import torch
+import brevitas.onnx as bo
+from brevitas.nn import QuantHardTanh
+from brevitas.core.restrict_val import RestrictValueType
+from brevitas.core.scaling import ScalingImplType
+from models.common import get_quant_type
+import pytest
+from finn.core.modelwrapper import ModelWrapper
+import finn.core.onnx_exec as oxe
+from finn.transformation.infer_shapes import InferShapes
+
+export_onnx_path = "test_act.onnx"
+
+
+@pytest.mark.parametrize("abits", [1, 2, 4, 8])
+@pytest.mark.parametrize("narrow_range", [False, True])
+@pytest.mark.parametrize("max_val", [1.0, 1 - 2 ** (-7)])
+def test_brevitas_act_export(abits, narrow_range, max_val):
+    act_quant_type = get_quant_type(abits)
+    min_val = -1.0
+    ishape = (1, 10)
+    b_act = QuantHardTanh(
+        bit_width=abits,
+        quant_type=act_quant_type,
+        max_val=max_val,
+        min_val=min_val,
+        restrict_scaling_type=RestrictValueType.LOG_FP,
+        scaling_impl_type=ScalingImplType.CONST,
+        narrow_range=narrow_range,
+    )
+    bo.export_finn_onnx(b_act, ishape, export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(
+        np.float32
+    )
+    idict = {model.graph.input[0].name: inp_tensor}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+    inp_tensor = torch.from_numpy(inp_tensor).float()
+    expected = b_act.forward(inp_tensor).detach().numpy()
+    assert np.isclose(produced, expected, atol=1e-3).all()
diff --git a/tests/brevitas/test_brevitas_cnv.py b/tests/brevitas/test_brevitas_cnv.py
index 3c9f8d08223ed57d5b7409093f39d1af9613be83..8d21a33f78ca6f1229bdc11e753fc17cdf170242 100644
--- a/tests/brevitas/test_brevitas_cnv.py
+++ b/tests/brevitas/test_brevitas_cnv.py
@@ -28,6 +28,7 @@
 
 import os
 import pkg_resources as pk
+import pytest
 
 import brevitas.onnx as bo
 import numpy as np
@@ -37,68 +38,34 @@ import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.infer_shapes import InferShapes
-from finn.util.test import get_test_model_trained, get_test_model_untrained
+from finn.transformation.general import GiveUniqueNodeNames
+from finn.transformation.double_to_single_float import DoubleToSingleFloat
+from finn.util.test import get_test_model_trained
 
 export_onnx_path = "test_output_cnv.onnx"
 
 
-def test_brevitas_cnv_w1a1_export():
-    cnv = get_test_model_untrained("CNV", 1, 1)
-    bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
-    model = ModelWrapper(export_onnx_path)
-    assert model.graph.node[2].op_type == "Sign"
-    assert model.graph.node[3].op_type == "Conv"
-    conv0_wname = model.graph.node[3].input[1]
-    assert list(model.get_initializer(conv0_wname).shape) == [64, 3, 3, 3]
-    assert model.graph.node[4].op_type == "Mul"
-    os.remove(export_onnx_path)
-
-
-def test_brevitas_cnv_w1a1_export_exec():
-    cnv = get_test_model_trained("CNV", 1, 1)
+@pytest.mark.parametrize("abits", [1, 2])
+@pytest.mark.parametrize("wbits", [1, 2])
+def test_brevitas_cnv_export_exec(wbits, abits):
+    if wbits > abits:
+        pytest.skip("No wbits > abits cases at the moment")
+    cnv = get_test_model_trained("CNV", wbits, abits)
     bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(DoubleToSingleFloat())
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
-    model.save(export_onnx_path)
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
     input_tensor = np.load(fn)["arr_0"].astype(np.float32)
     assert input_tensor.shape == (1, 3, 32, 32)
     # run using FINN-based execution
-    input_dict = {"0": input_tensor}
-    output_dict = oxe.execute_onnx(model, input_dict)
-    produced = output_dict[list(output_dict.keys())[0]]
+    input_dict = {model.graph.input[0].name: input_tensor}
+    output_dict = oxe.execute_onnx(model, input_dict, True)
+    produced = output_dict[model.graph.output[0].name]
     # do forward pass in PyTorch/Brevitas
     input_tensor = torch.from_numpy(input_tensor).float()
     expected = cnv.forward(input_tensor).detach().numpy()
     assert np.isclose(produced, expected, atol=1e-3).all()
     os.remove(export_onnx_path)
-
-
-def test_brevitas_cnv_w1a1_pytorch():
-    # load pretrained weights into CNV-w1a1
-    cnv = get_test_model_trained("CNV", 1, 1)
-    fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
-    input_tensor = np.load(fn)["arr_0"]
-    input_tensor = torch.from_numpy(input_tensor).float()
-    assert input_tensor.shape == (1, 3, 32, 32)
-    # do forward pass in PyTorch/Brevitas
-    produced = cnv.forward(input_tensor).detach().numpy()
-    expected = np.asarray(
-        [
-            [
-                3.7939777,
-                -2.3108773,
-                0.06898145,
-                0.55185133,
-                0.37939775,
-                -1.9659703,
-                -0.3104164,
-                -2.828238,
-                2.6902752,
-                0.48286998,
-            ]
-        ],
-        dtype=np.float32,
-    )
-    assert np.isclose(produced, expected, atol=1e-3).all()
diff --git a/tests/core/test_custom_onnx_exec.py b/tests/core/test_custom_onnx_exec.py
index 29ef2ee560d498eba04845fc0a6051fd0cae14ab..086681dde0ff029ceaa7d3274bad4d3f15bd32fc 100644
--- a/tests/core/test_custom_onnx_exec.py
+++ b/tests/core/test_custom_onnx_exec.py
@@ -254,3 +254,24 @@ def test_execute_custom_node_multithreshold():
     ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
     outputs_scaled = 2.0 * outputs - 1.0
     assert (execution_context["out"] == outputs_scaled).all()
+
+    # test the optional data layout option for MultiThreshold
+    node_def = helper.make_node(
+        "MultiThreshold",
+        ["v", "thresholds"],
+        ["out"],
+        domain="finn",
+        data_layout="NHWC",
+    )
+
+    v_nhwc = helper.make_tensor_value_info("v", TensorProto.FLOAT, [6, 2, 2, 3])
+    out_nhwc = helper.make_tensor_value_info("out", TensorProto.FLOAT, [6, 2, 2, 3])
+    inputs_nhwc = np.transpose(inputs, (0, 2, 3, 1))  # NCHW -> NHWC
+    outputs_nhwc = np.transpose(outputs, (0, 2, 3, 1))  # NCHW -> NHWC
+    execution_context["v"] = inputs_nhwc
+
+    graph_def = helper.make_graph(
+        [node_def], "test_model", [v_nhwc, thresholds], [out_nhwc]
+    )
+    ex_cu_node.execute_custom_node(node_def, execution_context, graph_def)
+    assert (execution_context["out"] == outputs_nhwc).all()
diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py
index 5d18de2d18157383a3c7882febfa752d72774572..942eda19ca4c2cdbded9f906a5e7772f50acbd6e 100644
--- a/tests/core/test_modelwrapper.py
+++ b/tests/core/test_modelwrapper.py
@@ -43,17 +43,26 @@ def test_modelwrapper():
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     assert model.check_all_tensor_shapes_specified() is False
-    inp_shape = model.get_tensor_shape("0")
+    inp_name = model.graph.input[0].name
+    inp_shape = model.get_tensor_shape(inp_name)
     assert inp_shape == [1, 1, 28, 28]
-    l0_mat_tensor_name = "33"
+    # find first matmul node
+    l0_mat_tensor_name = ""
+    l0_inp_tensor_name = ""
+    for node in model.graph.node:
+        if node.op_type == "MatMul":
+            l0_inp_tensor_name = node.input[0]
+            l0_mat_tensor_name = node.input[1]
+            break
+    assert l0_mat_tensor_name != ""
     l0_weights = model.get_initializer(l0_mat_tensor_name)
     assert l0_weights.shape == (784, 1024)
     l0_weights_hist = Counter(l0_weights.flatten())
-    assert l0_weights_hist[1.0] == 401311 and l0_weights_hist[-1.0] == 401505
+    assert (l0_weights_hist[1.0] + l0_weights_hist[-1.0]) == 784 * 1024
     l0_weights_rand = np.random.randn(784, 1024)
     model.set_initializer(l0_mat_tensor_name, l0_weights_rand)
     assert (model.get_initializer(l0_mat_tensor_name) == l0_weights_rand).all()
-    l0_inp_tensor_name = "32"
+    assert l0_inp_tensor_name != ""
     inp_cons = model.find_consumer(l0_inp_tensor_name)
     assert inp_cons.op_type == "MatMul"
     out_prod = model.find_producer(l0_inp_tensor_name)
diff --git a/tests/transformation/streamline/test_streamline_cnv.py b/tests/transformation/streamline/test_streamline_cnv.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec5bf441b736b5faed0024749b5b77f213949029
--- /dev/null
+++ b/tests/transformation/streamline/test_streamline_cnv.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import brevitas.onnx as bo
+import numpy as np
+import pytest
+import pkg_resources as pk
+
+import finn.core.onnx_exec as oxe
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import Streamline
+from finn.util.test import get_test_model_trained
+from finn.util.basic import make_build_dir
+from finn.transformation.double_to_single_float import DoubleToSingleFloat
+
+export_onnx_path = make_build_dir("test_streamline_cnv_")
+
+
+# act bits
+@pytest.mark.parametrize("abits", [1])
+# weight bits
+@pytest.mark.parametrize("wbits", [1])
+# network topology / size
+@pytest.mark.parametrize("size", ["CNV"])
+def test_streamline_cnv(size, wbits, abits):
+    if wbits > abits:
+        pytest.skip("No wbits > abits cases at the moment")
+    nname = "%s_%dW%dA" % (size, wbits, abits)
+    finn_onnx = export_onnx_path + "/%s.onnx" % nname
+    fc = get_test_model_trained(size, wbits, abits)
+    bo.export_finn_onnx(fc, (1, 3, 32, 32), finn_onnx)
+    model = ModelWrapper(finn_onnx)
+    model = model.transform(DoubleToSingleFloat())
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
+    # load one of the test vectors
+    fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
+    input_tensor = np.load(fn)["arr_0"].astype(np.float32)
+    assert input_tensor.shape == (1, 3, 32, 32)
+    # run using FINN-based execution
+    input_dict = {"global_in": input_tensor}
+    expected_ctx = oxe.execute_onnx(model, input_dict, True)
+    expected = expected_ctx[model.graph.output[0].name]
+    model.save("orig_cnv.onnx")
+    model = model.transform(Streamline())
+    model.save("streamlined_cnv.onnx")
+    produced_ctx = oxe.execute_onnx(model, input_dict, True)
+    produced = produced_ctx[model.graph.output[0].name]
+    assert np.isclose(expected, produced, atol=1e-3).all()
+    assert model.graph.node[0].op_type == "MultiThreshold"
diff --git a/tests/transformation/test_fold_constants.py b/tests/transformation/test_fold_constants.py
index cd1c346593e3666ce8a89bd4248fa8436423de6d..685c14a98b9031096aaf5b244c4f484d4f308bca 100644
--- a/tests/transformation/test_fold_constants.py
+++ b/tests/transformation/test_fold_constants.py
@@ -65,7 +65,8 @@ def test_const_folding_shapes():
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
-    assert model.graph.node[0].op_type == "Reshape"
-    assert list(model.get_tensor_shape("0")) == [1, 1, 28, 28]
-    assert list(model.get_tensor_shape("27")) == [1, 784]
+    reshape_node = model.graph.node[0]
+    assert reshape_node.op_type == "Reshape"
+    assert list(model.get_tensor_shape(reshape_node.input[0])) == [1, 1, 28, 28]
+    assert list(model.get_tensor_shape(reshape_node.output[0])) == [1, 784]
     os.remove(export_onnx_path)