diff --git a/tests/test_infer_shapes.py b/tests/test_infer_shapes.py
index f4c27572cf30decce49f93fb32f12868440b00e2..9ca0c9303b4cda55a954dee0292255bba21eed83 100644
--- a/tests/test_infer_shapes.py
+++ b/tests/test_infer_shapes.py
@@ -3,6 +3,7 @@ from pkgutil import get_data
 import numpy as np
 from onnx import TensorProto, helper
 
+import finn.core.utils as util
 import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
 
@@ -26,7 +27,7 @@ def test_infer_shapes():
     # thresholds for one channel have to be sorted to guarantee the correct behavior
     mt_thresh0_values = np.empty([8, 7], dtype=np.float32)
     for i in range(len(mt_thresh0_values)):
-        mt_thresh0_values[i] = np.sort(np.random.random_sample(7,) * 10)
+        mt_thresh0_values[i] = np.sort(np.random.random_sample(7) * 10)
 
     model.set_initializer(mt_thresh0.name, mt_thresh0_values)
 
@@ -36,6 +37,9 @@ def test_infer_shapes():
     )
     Relu_node.output[0] = "mt_v0"
 
+    # explicitly remove any present shape from ReLU and MultiThreshold outputs
+    util.remove_by_name(model.graph.value_info, Relu_node.output[0])
+    util.remove_by_name(model.graph.value_info, mt_node.output[0])
     graph.node.insert(4, mt_node)
 
     # first check routine