diff --git a/tests/test_infer_shapes.py b/tests/test_infer_shapes.py index f4c27572cf30decce49f93fb32f12868440b00e2..9ca0c9303b4cda55a954dee0292255bba21eed83 100644 --- a/tests/test_infer_shapes.py +++ b/tests/test_infer_shapes.py @@ -3,6 +3,7 @@ from pkgutil import get_data import numpy as np from onnx import TensorProto, helper +import finn.core.utils as util import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper @@ -26,7 +27,7 @@ def test_infer_shapes(): # thresholds for one channel have to be sorted to guarantee the correct behavior mt_thresh0_values = np.empty([8, 7], dtype=np.float32) for i in range(len(mt_thresh0_values)): - mt_thresh0_values[i] = np.sort(np.random.random_sample(7,) * 10) + mt_thresh0_values[i] = np.sort(np.random.random_sample(7) * 10) model.set_initializer(mt_thresh0.name, mt_thresh0_values) @@ -36,6 +37,9 @@ def test_infer_shapes(): ) Relu_node.output[0] = "mt_v0" + # explicitly remove any present shape from ReLU and MultiThreshold outputs + util.remove_by_name(model.graph.value_info, Relu_node.output[0]) + util.remove_by_name(model.graph.value_info, mt_node.output[0]) graph.node.insert(4, mt_node) # first check routine