From 801233efe6c9ac232c64569dc618befc2f9f0f38 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sat, 9 Nov 2019 22:19:02 +0000
Subject: [PATCH] [Test] improve test_infer_shapes by manually removing shapes
 first

---
 tests/test_infer_shapes.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/test_infer_shapes.py b/tests/test_infer_shapes.py
index f4c27572c..9ca0c9303 100644
--- a/tests/test_infer_shapes.py
+++ b/tests/test_infer_shapes.py
@@ -3,6 +3,7 @@ from pkgutil import get_data
 import numpy as np
 from onnx import TensorProto, helper
 
+import finn.core.utils as util
 import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
 
@@ -26,7 +27,7 @@ def test_infer_shapes():
     # thresholds for one channel have to be sorted to guarantee the correct behavior
     mt_thresh0_values = np.empty([8, 7], dtype=np.float32)
     for i in range(len(mt_thresh0_values)):
-        mt_thresh0_values[i] = np.sort(np.random.random_sample(7,) * 10)
+        mt_thresh0_values[i] = np.sort(np.random.random_sample(7) * 10)
 
     model.set_initializer(mt_thresh0.name, mt_thresh0_values)
 
@@ -36,6 +37,9 @@ def test_infer_shapes():
     )
     Relu_node.output[0] = "mt_v0"
 
+    # explicitly remove any present shape from ReLU and MultiThreshold outputs
+    util.remove_by_name(model.graph.value_info, Relu_node.output[0])
+    util.remove_by_name(model.graph.value_info, mt_node.output[0])
     graph.node.insert(4, mt_node)
 
     # first check routine
-- 
GitLab