diff --git a/tests/transformation/test_absorb_mul_into_topk.py b/tests/transformation/test_absorb_mul_into_topk.py index 1394220f7c336ccea8fe9c494734c4175bf2e847..6cc5c9847db6ed163e67ece6b11f854b946187d8 100644 --- a/tests/transformation/test_absorb_mul_into_topk.py +++ b/tests/transformation/test_absorb_mul_into_topk.py @@ -34,7 +34,6 @@ from finn.core.modelwrapper import ModelWrapper from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames -from finn.transformation.insert_topk import InsertTopK from finn.transformation.streamline.absorb import AbsorbScalarMulIntoTopK import finn.core.onnx_exec as oxe @@ -43,25 +42,35 @@ import finn.core.onnx_exec as oxe # parameter to indicate if mul parameter is scalar or not @pytest.mark.parametrize("scalar", [True, False]) def test_absorb_mul_into_topk(mul_positive, scalar): + K = 5 if scalar is True: shape = [1] else: shape = [1, 1, 1, 1000] + + out_shape = [1, 1, 1, K] inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 1, 1, 1000]) a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, shape) - outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 1, 1, 1000]) - - mul_node = helper.make_node("Mul", ["inp", "a0"], ["outp"]) + outp = helper.make_tensor_value_info("outp", TensorProto.INT64, out_shape) + k_value = helper.make_tensor_value_info("k_value", TensorProto.INT64, [1]) + topk_values = helper.make_tensor_value_info( + "topk_values", TensorProto.FLOAT, out_shape + ) + mul_node = helper.make_node("Mul", ["inp", "a0"], ["a1"]) + top_k = helper.make_node( + "TopK", ["a1", "k_value"], ["topk_values", "outp"], largest=1, axis=-1, sorted=1 + ) mul_graph = helper.make_graph( - nodes=[mul_node], + nodes=[mul_node, top_k], name="mul-graph", inputs=[inp], outputs=[outp], - value_info=[a0], + value_info=[a0, k_value, topk_values], ) model = helper.make_model(mul_graph, producer_name="mul_model") model = ModelWrapper(model) + # initialize values if mul_positive is True: a0_values = np.random.uniform(low=0.1, high=1, size=tuple(shape)).astype( @@ -72,7 +81,9 @@ def test_absorb_mul_into_topk(mul_positive, scalar): np.float32 ) model.set_initializer("a0", a0_values) - model = model.transform(InsertTopK()) + + k_tensor = np.array([K]).astype(np.int64) + model.set_initializer(k_value.name, k_tensor) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) model = model.transform(GiveUniqueNodeNames())