diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py index 213d2cedf92c0276e33fcf2b50e6966aeee8c847..3ef6ef1b13c9b59e8a346c83daddaab4fdf6a859 100644 --- a/src/finn/transformation/insert_topk.py +++ b/src/finn/transformation/insert_topk.py @@ -46,6 +46,15 @@ class InsertTopK(Transformation): self.largest = largest self.sorted = sorted + def is_scalar_linear(self, model, node): + # if is linear + test = (node.op_type == "Mul") or (node.op_type == "Add") + if test: + init = model.get_initializer(node.input[1]) + test = test and (init is not None) and all(x == 1 for x in init.shape) + test = test and init > 0 + return test + def apply(self, model): # get name of output tensor graph_out_name = model.graph.output[0].name @@ -55,6 +64,19 @@ class InsertTopK(Transformation): if final_node.op_type == "TopK": return (model, False) else: + # remove any scalar linear transformations at graph output + # because TopK is invariant to them + while self.is_scalar_linear(model, final_node): + # remove the predecessor + final_node_input = model.get_tensor_valueinfo(final_node.input[0]) + model.graph.output.insert(0, final_node_input) + model.graph.output.pop(1) + model.graph.node.remove(final_node) + graph_out_name = model.graph.output[0].name + final_node = model.find_producer(graph_out_name) + if final_node is None: + break + out_shape = model.get_tensor_shape(graph_out_name) out_dtype = model.get_tensor_datatype(graph_out_name) # adjust shape diff --git a/tests/transformation/test_absorb_mul_into_topk.py b/tests/transformation/test_absorb_mul_into_topk.py index 1394220f7c336ccea8fe9c494734c4175bf2e847..6cc5c9847db6ed163e67ece6b11f854b946187d8 100644 --- a/tests/transformation/test_absorb_mul_into_topk.py +++ b/tests/transformation/test_absorb_mul_into_topk.py @@ -34,7 +34,6 @@ from finn.core.modelwrapper import ModelWrapper from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames -from finn.transformation.insert_topk import InsertTopK from finn.transformation.streamline.absorb import AbsorbScalarMulIntoTopK import finn.core.onnx_exec as oxe @@ -43,25 +42,35 @@ import finn.core.onnx_exec as oxe # parameter to indicate if mul parameter is scalar or not @pytest.mark.parametrize("scalar", [True, False]) def test_absorb_mul_into_topk(mul_positive, scalar): + K = 5 if scalar is True: shape = [1] else: shape = [1, 1, 1, 1000] + + out_shape = [1, 1, 1, K] inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 1, 1, 1000]) a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, shape) - outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 1, 1, 1000]) - - mul_node = helper.make_node("Mul", ["inp", "a0"], ["outp"]) + outp = helper.make_tensor_value_info("outp", TensorProto.INT64, out_shape) + k_value = helper.make_tensor_value_info("k_value", TensorProto.INT64, [1]) + topk_values = helper.make_tensor_value_info( + "topk_values", TensorProto.FLOAT, out_shape + ) + mul_node = helper.make_node("Mul", ["inp", "a0"], ["a1"]) + top_k = helper.make_node( + "TopK", ["a1", "k_value"], ["topk_values", "outp"], largest=1, axis=-1, sorted=1 + ) mul_graph = helper.make_graph( - nodes=[mul_node], + nodes=[mul_node, top_k], name="mul-graph", inputs=[inp], outputs=[outp], - value_info=[a0], + value_info=[a0, k_value, topk_values], ) model = helper.make_model(mul_graph, producer_name="mul_model") model = ModelWrapper(model) + # initialize values if mul_positive is True: a0_values = np.random.uniform(low=0.1, high=1, size=tuple(shape)).astype( @@ -72,7 +81,9 @@ def test_absorb_mul_into_topk(mul_positive, scalar): np.float32 ) model.set_initializer("a0", a0_values) - model = model.transform(InsertTopK()) + + k_tensor = np.array([K]).astype(np.int64) + model.set_initializer(k_value.name, k_tensor) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) model = model.transform(GiveUniqueNodeNames()) diff --git a/tests/transformation/test_topk_insert.py b/tests/transformation/test_topk_insert.py index a18e63384150f140cb63ec7b438283eb4797266c..b85ed4aa6999faf751e535c1cc687d639c4eb74f 100644 --- a/tests/transformation/test_topk_insert.py +++ b/tests/transformation/test_topk_insert.py @@ -1,4 +1,4 @@ -import os +# import os import onnx from finn.util.test import get_test_model_trained import brevitas.onnx as bo @@ -57,4 +57,4 @@ def test_topk_insert(k): output_pysim_topk = output_pysim_topk.astype(np.int).flatten() assert np.array_equal(output_golden_topk, output_pysim_topk) - os.remove(export_onnx_path) + # os.remove(export_onnx_path)