diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py
index 3ef6ef1b13c9b59e8a346c83daddaab4fdf6a859..213d2cedf92c0276e33fcf2b50e6966aeee8c847 100644
--- a/src/finn/transformation/insert_topk.py
+++ b/src/finn/transformation/insert_topk.py
@@ -46,15 +46,6 @@ class InsertTopK(Transformation):
         self.largest = largest
         self.sorted = sorted
 
-    def is_scalar_linear(self, model, node):
-        # if is linear
-        test = (node.op_type == "Mul") or (node.op_type == "Add")
-        if test:
-            init = model.get_initializer(node.input[1])
-            test = test and (init is not None) and all(x == 1 for x in init.shape)
-            test = test and init > 0
-        return test
-
     def apply(self, model):
         # get name of output tensor
         graph_out_name = model.graph.output[0].name
@@ -64,19 +55,6 @@ class InsertTopK(Transformation):
         if final_node.op_type == "TopK":
             return (model, False)
         else:
-            # remove any scalar linear transformations at graph output
-            # because TopK is invariant to them
-            while self.is_scalar_linear(model, final_node):
-                # remove the predecessor
-                final_node_input = model.get_tensor_valueinfo(final_node.input[0])
-                model.graph.output.insert(0, final_node_input)
-                model.graph.output.pop(1)
-                model.graph.node.remove(final_node)
-                graph_out_name = model.graph.output[0].name
-                final_node = model.find_producer(graph_out_name)
-                if final_node is None:
-                    break
-
             out_shape = model.get_tensor_shape(graph_out_name)
             out_dtype = model.get_tensor_datatype(graph_out_name)
             # adjust shape
diff --git a/tests/fpgadataflow/test_convert_to_hls_layers_synthetic.py b/tests/fpgadataflow/test_convert_to_hls_layers_synthetic.py
index 81e86eece777e3d1f88d026ee940aabeb6be76e2..9d861929f3d421c431a27ccac5d513938aa7d726 100644
--- a/tests/fpgadataflow/test_convert_to_hls_layers_synthetic.py
+++ b/tests/fpgadataflow/test_convert_to_hls_layers_synthetic.py
@@ -52,7 +52,10 @@ import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
 from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
 from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
-from finn.transformation.streamline.absorb import AbsorbConsecutiveTransposes
+from finn.transformation.streamline.absorb import (
+    AbsorbScalarMulIntoTopK,
+    AbsorbConsecutiveTransposes,
+)
 from finn.transformation.streamline.collapse_repeated import (
     CollapseRepeatedMul,
     CollapseRepeatedAdd,
@@ -191,6 +194,7 @@ def test_convert_to_hls_layers_synthetic(ch, ifmdim, idt):
     model = model.transform(to_hls.InferGlobalAccPoolLayer())
     model = model.transform(MoveScalarLinearPastInvariants())
     model = model.transform(InsertTopK())
+    model = model.transform(AbsorbScalarMulIntoTopK())
     model = model.transform(InferDataTypes())
     model = model.transform(to_hls.InferLabelSelectLayer())
     model = model.transform(AbsorbConsecutiveTransposes())
diff --git a/tests/transformation/test_absorb_mul_into_topk.py b/tests/transformation/test_absorb_mul_into_topk.py
index 6cc5c9847db6ed163e67ece6b11f854b946187d8..1394220f7c336ccea8fe9c494734c4175bf2e847 100644
--- a/tests/transformation/test_absorb_mul_into_topk.py
+++ b/tests/transformation/test_absorb_mul_into_topk.py
@@ -34,6 +34,7 @@ from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
+from finn.transformation.insert_topk import InsertTopK
 from finn.transformation.streamline.absorb import AbsorbScalarMulIntoTopK
 import finn.core.onnx_exec as oxe
 
@@ -42,35 +43,25 @@ import finn.core.onnx_exec as oxe
 # parameter to indicate if mul parameter is scalar or not
 @pytest.mark.parametrize("scalar", [True, False])
 def test_absorb_mul_into_topk(mul_positive, scalar):
-    K = 5
     if scalar is True:
         shape = [1]
     else:
         shape = [1, 1, 1, 1000]
-
-    out_shape = [1, 1, 1, K]
     inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 1, 1, 1000])
     a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, shape)
-    outp = helper.make_tensor_value_info("outp", TensorProto.INT64, out_shape)
-    k_value = helper.make_tensor_value_info("k_value", TensorProto.INT64, [1])
-    topk_values = helper.make_tensor_value_info(
-        "topk_values", TensorProto.FLOAT, out_shape
-    )
-    mul_node = helper.make_node("Mul", ["inp", "a0"], ["a1"])
-    top_k = helper.make_node(
-        "TopK", ["a1", "k_value"], ["topk_values", "outp"], largest=1, axis=-1, sorted=1
-    )
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 1, 1, 1000])
+
+    mul_node = helper.make_node("Mul", ["inp", "a0"], ["outp"])
     mul_graph = helper.make_graph(
-        nodes=[mul_node, top_k],
+        nodes=[mul_node],
         name="mul-graph",
         inputs=[inp],
         outputs=[outp],
-        value_info=[a0, k_value, topk_values],
+        value_info=[a0],
     )
 
     model = helper.make_model(mul_graph, producer_name="mul_model")
     model = ModelWrapper(model)
-
     # initialize values
     if mul_positive is True:
         a0_values = np.random.uniform(low=0.1, high=1, size=tuple(shape)).astype(
@@ -81,9 +72,7 @@ def test_absorb_mul_into_topk(mul_positive, scalar):
             np.float32
         )
     model.set_initializer("a0", a0_values)
-
-    k_tensor = np.array([K]).astype(np.int64)
-    model.set_initializer(k_value.name, k_tensor)
+    model = model.transform(InsertTopK())
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
     model = model.transform(GiveUniqueNodeNames())