From 801233efe6c9ac232c64569dc618befc2f9f0f38 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Sat, 9 Nov 2019 22:19:02 +0000 Subject: [PATCH] [Test] improve test_infer_shapes by manually removing shapes first --- tests/test_infer_shapes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_infer_shapes.py b/tests/test_infer_shapes.py index f4c27572c..9ca0c9303 100644 --- a/tests/test_infer_shapes.py +++ b/tests/test_infer_shapes.py @@ -3,6 +3,7 @@ from pkgutil import get_data import numpy as np from onnx import TensorProto, helper +import finn.core.utils as util import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper @@ -26,7 +27,7 @@ def test_infer_shapes(): # thresholds for one channel have to be sorted to guarantee the correct behavior mt_thresh0_values = np.empty([8, 7], dtype=np.float32) for i in range(len(mt_thresh0_values)): - mt_thresh0_values[i] = np.sort(np.random.random_sample(7,) * 10) + mt_thresh0_values[i] = np.sort(np.random.random_sample(7) * 10) model.set_initializer(mt_thresh0.name, mt_thresh0_values) @@ -36,6 +37,9 @@ def test_infer_shapes(): ) Relu_node.output[0] = "mt_v0" + # explicitly remove any present shape from ReLU and MultiThreshold outputs + util.remove_by_name(model.graph.value_info, Relu_node.output[0]) + util.remove_by_name(model.graph.value_info, mt_node.output[0]) graph.node.insert(4, mt_node) # first check routine -- GitLab