diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index 53c73e1dc4fe0bfab53e3f126add992cb338c11d..64f4e3183082baada5fb97e49e4566525eddbc52 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -81,6 +81,54 @@ class GiveReadableTensorNames(Transformation):
         return (model, False)
 
 
+class GiveUniqueParameterTensors(Transformation):
+    """Make every parameter (initializer) tensor unique, so that changing one
+    node's parameter does not silently affect other nodes that share it."""
+
+    def apply(self, model):
+        model_tensor_names = model.get_all_tensor_names()
+
+        graph = model.graph
+        graph_modified = False
+        seen_parameters = []
+        for n in graph.node:
+            # copy inputs since they may be modified
+            node_inputs_list = [x for x in n.input]
+            for input_idx, node_input in enumerate(node_inputs_list):
+                # check if it's a parameter
+                input_init = model.get_initializer(node_input)
+                if input_init is None:
+                    # dynamic input
+                    continue
+
+                # check if repeated
+                if node_input not in seen_parameters:
+                    # first occurrence
+                    seen_parameters += [node_input]
+                    continue
+
+                # give a new, unused name to the repeated tensor
+                for trials in range(10):
+                    new_param_name = util.random_string(stringLength=6)
+                    if new_param_name not in model_tensor_names:
+                        break
+                else:
+                    raise Exception(
+                        "Not able to create a new tensor name after 10 trials. "
+                        + "Is the net too big for the chosen random tensor name "
+                        + "length? Try a larger stringLength."
+                    )
+
+                model_tensor_names += [new_param_name]
+
+                model.set_initializer(new_param_name, input_init)
+
+                # point node input to new tensor
+                n.input[input_idx] = new_param_name
+
+        return (model, graph_modified)
+
+
 class ConvertSubToAdd(Transformation):
     """Convert subtract-a-constant nodes to add-a-constant nodes."""
 
diff --git a/tests/transformation/test_general_transformation.py b/tests/transformation/test_general_transformation.py
index 33b6041a170f3c0de8f741ef3ecb28682f6429ea..153af378eb3e07d5824f114fd194730048fb4953 100644
--- a/tests/transformation/test_general_transformation.py
+++ b/tests/transformation/test_general_transformation.py
@@ -31,6 +31,12 @@ from pkgutil import get_data
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.general import GiveUniqueNodeNames
+import numpy as np
+import onnx
+import finn.core.onnx_exec as oxe
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.general import GiveUniqueParameterTensors
+
 
 
 def test_give_unique_node_names():
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
@@ -39,3 +45,76 @@ def test_give_unique_node_names():
     assert model.graph.node[0].name == "Reshape_0"
     assert model.graph.node[1].name == "Conv_0"
     assert model.graph.node[11].name == "Add_2"
+
+
+def test_give_unique_parameter_tensors():
+
+    # Create model
+    input_shape = [4, 4]
+    in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, input_shape)
+    out1 = onnx.helper.make_tensor_value_info(
+        "out1", onnx.TensorProto.FLOAT, input_shape
+    )
+
+    graph_nodes = []
+    graph_nodes += [
+        onnx.helper.make_node("Add", inputs=["in1", "param1"], outputs=["t1"])
+    ]
+
+    graph_nodes += [
+        onnx.helper.make_node("Sum", inputs=["t1", "param1", "param1"], outputs=["t2"])
+    ]
+
+    graph_nodes += [
+        onnx.helper.make_node("Sum", inputs=["t2", "param2", "param1"], outputs=["t3"])
+    ]
+
+    graph_nodes += [
+        onnx.helper.make_node("Add", inputs=["t3", "param1"], outputs=["out1"])
+    ]
+
+    onnx_graph = onnx.helper.make_graph(
+        nodes=graph_nodes, name="simple_graph", inputs=[in1], outputs=[out1],
+    )
+
+    onnx_model = onnx.helper.make_model(onnx_graph, producer_name="simple-model")
+    model = ModelWrapper(onnx_model)
+
+    # Set param values
+    np.random.seed(0)
+    param1 = np.random.rand(*input_shape).astype(np.float32)
+    param2 = np.random.rand(*input_shape).astype(np.float32)
+    model.set_initializer("param1", param1)
+    model.set_initializer("param2", param2)
+    model = model.transform(InferShapes())
+
+    # Apply transformation
+    new_model = model.transform(GiveUniqueParameterTensors())
+    new_model = new_model.transform(InferShapes())
+
+    # Test 1: does the transform break the model?
+    # compare execution results before and after
+    input_tensor = np.random.rand(*input_shape).astype(np.float32)
+    input_dict = {"in1": input_tensor}
+
+    # run original
+    expected_context = oxe.execute_onnx(model, input_dict)
+    expected_output = expected_context[model.graph.output[0].name]
+
+    # run modified
+    produced_context = oxe.execute_onnx(new_model, input_dict)
+    produced_output = produced_context[new_model.graph.output[0].name]
+
+    assert np.isclose(
+        expected_output, produced_output, atol=1e-8
+    ).all(), "GiveUniqueParameterTensors() transform breaks the model"
+
+    # Test 2: are all parameter inputs now distinct tensors?
+    param_set = set()
+    param_cnt = 0
+    for n in new_model.graph.node:
+        for i in range(1, len(n.input)):
+            param_set |= {n.input[i]}
+            param_cnt += 1
+
+    assert len(param_set) == param_cnt, "There are still parameters reused"
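
For context, a minimal usage sketch of the new transformation, following the same ModelWrapper/transform pattern as the test above. The file names below are placeholders, not paths from this repository.

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.general import GiveUniqueParameterTensors
from finn.transformation.infer_shapes import InferShapes

# "some_model.onnx" is a placeholder for any FINN-compatible ONNX model
model = ModelWrapper("some_model.onnx")

# duplicate every initializer that feeds more than one node input, so that
# later per-node parameter edits cannot leak into other nodes
model = model.transform(GiveUniqueParameterTensors())

# re-run shape inference so the freshly created parameter tensors are annotated
model = model.transform(InferShapes())

model.save("some_model_unique_params.onnx")

Re-running InferShapes afterwards mirrors the test and keeps tensor annotations consistent for the duplicated initializers.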