diff --git a/src/finn/transformation/insert_topk.py b/src/finn/transformation/insert_topk.py
index 213d2cedf92c0276e33fcf2b50e6966aeee8c847..3ef6ef1b13c9b59e8a346c83daddaab4fdf6a859 100644
--- a/src/finn/transformation/insert_topk.py
+++ b/src/finn/transformation/insert_topk.py
@@ -46,6 +46,15 @@ class InsertTopK(Transformation):
         self.largest = largest
         self.sorted = sorted
 
+    def is_scalar_linear(self, model, node):
+        # if is linear
+        test = (node.op_type == "Mul") or (node.op_type == "Add")
+        if test:
+            init = model.get_initializer(node.input[1])
+            test = test and (init is not None) and all(x == 1 for x in init.shape)
+            test = test and init > 0
+        return test
+
     def apply(self, model):
         # get name of output tensor
         graph_out_name = model.graph.output[0].name
@@ -55,6 +64,19 @@ class InsertTopK(Transformation):
         if final_node.op_type == "TopK":
             return (model, False)
         else:
+            # remove any scalar linear transformations at graph output
+            # because TopK is invariant to them
+            while self.is_scalar_linear(model, final_node):
+                # remove the predecessor
+                final_node_input = model.get_tensor_valueinfo(final_node.input[0])
+                model.graph.output.insert(0, final_node_input)
+                model.graph.output.pop(1)
+                model.graph.node.remove(final_node)
+                graph_out_name = model.graph.output[0].name
+                final_node = model.find_producer(graph_out_name)
+                if final_node is None:
+                    break
+
             out_shape = model.get_tensor_shape(graph_out_name)
             out_dtype = model.get_tensor_datatype(graph_out_name)
             # adjust shape
diff --git a/tests/transformation/test_absorb_mul_into_topk.py b/tests/transformation/test_absorb_mul_into_topk.py
index 1394220f7c336ccea8fe9c494734c4175bf2e847..6cc5c9847db6ed163e67ece6b11f854b946187d8 100644
--- a/tests/transformation/test_absorb_mul_into_topk.py
+++ b/tests/transformation/test_absorb_mul_into_topk.py
@@ -34,7 +34,6 @@ from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
-from finn.transformation.insert_topk import InsertTopK
 from finn.transformation.streamline.absorb import AbsorbScalarMulIntoTopK
 import finn.core.onnx_exec as oxe
 
@@ -43,25 +42,35 @@ import finn.core.onnx_exec as oxe
 # parameter to indicate if mul parameter is scalar or not
 @pytest.mark.parametrize("scalar", [True, False])
 def test_absorb_mul_into_topk(mul_positive, scalar):
+    K = 5
     if scalar is True:
         shape = [1]
     else:
         shape = [1, 1, 1, 1000]
+
+    out_shape = [1, 1, 1, K]
     inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 1, 1, 1000])
     a0 = helper.make_tensor_value_info("a0", TensorProto.FLOAT, shape)
-    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 1, 1, 1000])
-
-    mul_node = helper.make_node("Mul", ["inp", "a0"], ["outp"])
+    outp = helper.make_tensor_value_info("outp", TensorProto.INT64, out_shape)
+    k_value = helper.make_tensor_value_info("k_value", TensorProto.INT64, [1])
+    topk_values = helper.make_tensor_value_info(
+        "topk_values", TensorProto.FLOAT, out_shape
+    )
+    mul_node = helper.make_node("Mul", ["inp", "a0"], ["a1"])
+    top_k = helper.make_node(
+        "TopK", ["a1", "k_value"], ["topk_values", "outp"], largest=1, axis=-1, sorted=1
+    )
     mul_graph = helper.make_graph(
-        nodes=[mul_node],
+        nodes=[mul_node, top_k],
         name="mul-graph",
         inputs=[inp],
         outputs=[outp],
-        value_info=[a0],
+        value_info=[a0, k_value, topk_values],
     )
 
     model = helper.make_model(mul_graph, producer_name="mul_model")
     model = ModelWrapper(model)
+
     # initialize values
     if mul_positive is True:
         a0_values = np.random.uniform(low=0.1, high=1, size=tuple(shape)).astype(
@@ -72,7 +81,9 @@ def test_absorb_mul_into_topk(mul_positive, scalar):
             np.float32
         )
     model.set_initializer("a0", a0_values)
-    model = model.transform(InsertTopK())
+
+    k_tensor = np.array([K]).astype(np.int64)
+    model.set_initializer(k_value.name, k_tensor)
     model = model.transform(InferShapes())
     model = model.transform(InferDataTypes())
     model = model.transform(GiveUniqueNodeNames())
diff --git a/tests/transformation/test_topk_insert.py b/tests/transformation/test_topk_insert.py
index a18e63384150f140cb63ec7b438283eb4797266c..b85ed4aa6999faf751e535c1cc687d639c4eb74f 100644
--- a/tests/transformation/test_topk_insert.py
+++ b/tests/transformation/test_topk_insert.py
@@ -1,4 +1,4 @@
-import os
+# import os
 import onnx
 from finn.util.test import get_test_model_trained
 import brevitas.onnx as bo
@@ -57,4 +57,4 @@ def test_topk_insert(k):
     output_pysim_topk = output_pysim_topk.astype(np.int).flatten()
 
     assert np.array_equal(output_golden_topk, output_pysim_topk)
-    os.remove(export_onnx_path)
+    # os.remove(export_onnx_path)