diff --git a/src/finn/transformation/merge_onnx_models.py b/src/finn/transformation/merge_onnx_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dc6127ed189311c72a119932394aca4745e3608
--- /dev/null
+++ b/src/finn/transformation/merge_onnx_models.py
@@ -0,0 +1,222 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import copy
+import warnings
+from onnx import helper
+
+from finn.transformation import Transformation
+from finn.core.modelwrapper import ModelWrapper
+import finn.util.basic as util
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_data_layouts import InferDataLayouts
+from finn.transformation.general import (
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+    GiveUniqueParameterTensors,
+)
+
+
+class MergeONNXModels(Transformation):
+    """Merges two models. The model passed to the transformation is inserted
+    before the model the transformation is applied on, and the resulting
+    merged model is returned.
+    The transformation tries to connect graph.output[0] of the pre model
+    to graph.input[0] of the post model.
+    If either model has more than one dynamic input or output, a warning
+    is raised."""
+
+    def __init__(self, pre_model):
+        super().__init__()
+        # use a deep copy of the model that is inserted at the beginning of
+        # the other model to ensure that it stays unchanged
+        self.pre_model = copy.deepcopy(pre_model)
+
+    def apply(self, model):
+        # single-pass transformation, so graph_modified stays False
+        graph_modified = False
+        pre_model = self.pre_model
+        post_model = copy.deepcopy(model)
+
+        # check for dynamic outputs of pre model
+        dyn_outp = []
+        for outp in pre_model.graph.output:
+            init_val = pre_model.get_initializer(outp.name)
+            if init_val is None:
+                dyn_outp.append(outp)
+
+        if len(dyn_outp) != 1:
+            warnings.warn(
+                "The pre model does not have exactly one dynamic output! The "
+                "transformation tries to connect the first dynamic output to "
+                "the first dynamic input of the post model."
+            )
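+        # note: only dyn_outp[0] and dyn_inp[0] are connected below; any
+        # additional dynamic outputs of the pre model stay unconnected in
+        # the merged graph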
+
+        # check for dynamic inputs of post model
+        dyn_inp = []
+        for inp in post_model.graph.input:
+            init_val = post_model.get_initializer(inp.name)
+            if init_val is None:
+                dyn_inp.append(inp)
+
+        if len(dyn_inp) != 1:
+            warnings.warn(
+                "The post model does not have exactly one dynamic input! The "
+                "transformation tries to connect the first dynamic input to "
+                "the first dynamic output of the pre model."
+            )
+
+        # erase all node names to avoid conflicts
+        for n in pre_model.graph.node:
+            n.name = ""
+        for n in post_model.graph.node:
+            n.name = ""
+
+        # randomize all tensor names
+        names1 = pre_model.get_all_tensor_names()
+        names2 = post_model.get_all_tensor_names()
+        used_names = names1 + names2
+
+        # pre_model
+        for tensor_name in names1:
+            new_name = util.random_string()
+            while new_name in used_names:
+                new_name = util.random_string()
+            pre_model.rename_tensor(tensor_name, new_name)
+            used_names.append(new_name)
+
+        # post_model
+        for tensor_name in names2:
+            new_name = util.random_string()
+            while new_name in used_names:
+                new_name = util.random_string()
+            post_model.rename_tensor(tensor_name, new_name)
+            used_names.append(new_name)
+
+        # check if models can be merged
+        output_model_a = dyn_outp[0].name
+        input_model_b = dyn_inp[0].name
+        output_a_shape = pre_model.get_tensor_shape(output_model_a)
+        input_b_shape = post_model.get_tensor_shape(input_model_b)
+        assert (
+            output_a_shape == input_b_shape
+        ), "Models can't be merged! Shapes don't match."
+
+        # connect the output of the pre model to the input of the post model
+        for n in pre_model.graph.node:
+            if output_model_a == n.output[0]:
+                n.output[0] = input_model_b
+
+        # extract information for new model
+
+        # nodes
+        node_list_a = pre_model.graph.node
+        node_list_b = post_model.graph.node
+
+        # node_list_a is a repeated protobuf field, so appending extends it in place
+        node_list = node_list_a
+        for node in node_list_b:
+            node_list.append(node)
+
+        # input and output
+        inp = pre_model.graph.input[0]
+        outp = post_model.graph.output[0]
+
+        # create new graph and model
+        new_graph = helper.make_graph(
+            nodes=node_list,
+            name="fuse-graph",
+            inputs=[inp],
+            outputs=[outp],
+            value_info=[],
+        )
+
+        new_model = helper.make_model(new_graph, producer_name="fuse_model")
+        new_model = ModelWrapper(new_model)
+
+        # add value info from both models to new model
+        # pre model
+        vi_pre = [x for x in pre_model.graph.input]
+        vi_pre += [x for x in pre_model.graph.output]
+        vi_pre += [x for x in pre_model.graph.value_info]
+        for vi in vi_pre:
+            # preserve initializers, quantization/sparsity annotation, etc.
+            # initializer
+            init_val = pre_model.get_initializer(vi.name)
+            if init_val is not None:
+                new_model.set_initializer(vi.name, init_val)
+            # FINN datatype
+            dtype = pre_model.get_tensor_datatype(vi.name)
+            new_model.set_tensor_datatype(vi.name, dtype)
+            # data layout
+            data_layout = pre_model.get_tensor_layout(vi.name)
+            if data_layout is not None:
+                new_model.set_tensor_layout(vi.name, data_layout)
+            # sparsity
+            sparsity = pre_model.get_tensor_sparsity(vi.name)
+            if sparsity is not None:
+                new_model.set_tensor_sparsity(vi.name, sparsity)
+            # graph input should not be part of graph.value_info, so don't insert
+            # if current vi == inp, but the quantization annotation is preserved
+            if vi == inp:
+                continue
+            new_model.graph.value_info.append(vi)
+
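+        # the post model's tensors now get the same annotation transfer, so
+        # that none of its initializers or annotations are lost in the merge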
+        # post model
+        vi_model = [x for x in post_model.graph.input]
+        vi_model += [x for x in post_model.graph.output]
+        vi_model += [x for x in post_model.graph.value_info]
+        for vi in vi_model:
+            # preserve initializers, quantization/sparsity annotation, etc.
+            # initializer
+            init_val = post_model.get_initializer(vi.name)
+            if init_val is not None:
+                new_model.set_initializer(vi.name, init_val)
+            # FINN datatype
+            dtype = post_model.get_tensor_datatype(vi.name)
+            new_model.set_tensor_datatype(vi.name, dtype)
+            # data layout
+            data_layout = post_model.get_tensor_layout(vi.name)
+            if data_layout is not None:
+                new_model.set_tensor_layout(vi.name, data_layout)
+            # sparsity
+            sparsity = post_model.get_tensor_sparsity(vi.name)
+            if sparsity is not None:
+                new_model.set_tensor_sparsity(vi.name, sparsity)
+            # graph output should not be part of graph.value_info, so don't insert
+            # if current vi == outp, but the quantization annotation is preserved
+            if vi == outp:
+                continue
+            new_model.graph.value_info.append(vi)
+
+        # tidy up new model
+        model = new_model
+        model = model.transform(InferShapes())
+        model = model.transform(InferDataTypes())
+        model = model.transform(InferDataLayouts())
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(GiveUniqueParameterTensors())
+        model = model.transform(GiveReadableTensorNames())
+
+        return (model, graph_modified)
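+
+# Minimal usage sketch (hypothetical file names; both models are assumed to
+# have exactly one dynamic input and output, as in the test below):
+#
+#   from finn.core.modelwrapper import ModelWrapper
+#   pre = ModelWrapper("preprocessing.onnx")
+#   net = ModelWrapper("network.onnx")
+#   net = net.transform(MergeONNXModels(pre))  # pre now runs before net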
diff --git a/tests/transformation/test_merge_onnx_models.py b/tests/transformation/test_merge_onnx_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..db7c990baddfb50a39603937a9c5b73f512a0e59
--- /dev/null
+++ b/tests/transformation/test_merge_onnx_models.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from pkgutil import get_data
+
+import numpy as np
+import onnx
+import onnx.numpy_helper as np_helper
+from onnx import TensorProto, helper
+
+from finn.core.modelwrapper import ModelWrapper
+from finn.core.datatype import DataType
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_data_layouts import InferDataLayouts
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.merge_onnx_models import MergeONNXModels
+import finn.core.onnx_exec as oxe
+
+
+def test_merge_onnx_models():
+    # load pre model
+    raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
+    model1 = ModelWrapper(raw_m)
+    # the input of model1 comes from a uint8 vector, so we set the FINN
+    # datatype of the input tensor to DataType.UINT8 to verify that datatypes
+    # are correctly preserved in the merged model
+    model1.set_tensor_datatype(model1.graph.input[0].name, DataType.UINT8)
+    model1 = model1.transform(InferShapes())
+    model1 = model1.transform(GiveUniqueNodeNames())
+    model1 = model1.transform(GiveReadableTensorNames())
+
+    # set up post model
+    shape = [1, 10]
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, shape)
+    a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, [])
+    a1 = helper.make_tensor_value_info("a1", TensorProto.FLOAT, [])
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, shape)
+
+    mul_node = helper.make_node("Mul", ["inp", "a0"], ["mul_out"])
+    div_node = helper.make_node("Div", ["mul_out", "a1"], ["outp"])
+
+    graph = helper.make_graph(
+        nodes=[mul_node, div_node],
+        name="model2-graph",
+        inputs=[inp],
+        outputs=[outp],
+        value_info=[a0, a1],
+    )
+
+    model2 = helper.make_model(graph, producer_name="model2")
+    model2 = ModelWrapper(model2)
+    # initialize model2
+    a0_value = np.random.uniform(low=0, high=1, size=(1)).astype(np.float32)
+    model2.set_initializer("a0", a0_value)
+    a1_value = np.random.uniform(low=0.1, high=1, size=(1)).astype(np.float32)
+    model2.set_initializer("a1", a1_value)
+    # set a dummy sparsity annotation to check if it gets correctly transferred
+    # to the merged model
+    sparsity = {"dw": {"kernel_shape": 0}}
+    model2.set_tensor_sparsity("a1", sparsity)
+    model2 = model2.transform(InferShapes())
+    model2 = model2.transform(InferDataTypes())
+    model2 = model2.transform(InferDataLayouts())
+    model2 = model2.transform(GiveUniqueNodeNames())
+    model2 = model2.transform(GiveReadableTensorNames())
+
+    # execute both models before merging, feeding the output of model1 to model2
+    # load one of the test vectors
+    raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
+    inp_values = onnx.load_tensor_from_string(raw_i)
+    inp_values = np_helper.to_array(inp_values)
+    idict = {model1.graph.input[0].name: inp_values}
+    odict = oxe.execute_onnx(model1, idict)
+    temp = odict[model1.graph.output[0].name]
+
+    idict = {model2.graph.input[0].name: temp}
+    odict = oxe.execute_onnx(model2, idict)
+    outp = odict[model2.graph.output[0].name]
+
+    # merge models
+    model_transformed = model2.transform(MergeONNXModels(model1))
+
+    idict = {model_transformed.graph.input[0].name: inp_values}
+    odict = oxe.execute_onnx(model_transformed, idict)
+    outp_transformed = odict[model_transformed.graph.output[0].name]
+
+    assert (outp == outp_transformed).all()
+    assert len(model_transformed.graph.node) == len(model1.graph.node) + len(
+        model2.graph.node
+    )
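+
+    # annotation checks: the merge should also carry over FINN-specific tensor
+    # metadata (sparsity annotations, datatypes) from both source models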
+    # the sparsity annotation of input[1] of the division node was set to a
+    # dummy value before the merge; look for the division node in the merged
+    # model and check that the annotation is still the same
+    for n in model_transformed.graph.node:
+        if n.op_type == "Div":
+            tensor_name = n.input[1]
+            set_sparsity = model_transformed.get_tensor_sparsity(tensor_name)
+            assert sparsity == set_sparsity
+
+    # check that the FINN datatype of graph.input[0] is still set to UINT8
+    # (GiveReadableTensorNames renames the merged graph input to "global_in")
+    assert model_transformed.get_tensor_datatype("global_in") == DataType.UINT8
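+
+# quick way to run only this test (assuming a working FINN dev environment):
+#   pytest tests/transformation/test_merge_onnx_models.py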