Skip to content
Snippets Groups Projects
Commit 801233ef authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] improve test_infer_shapes by manually removing shapes first

parent 20b11dfe
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@ from pkgutil import get_data
import numpy as np
from onnx import TensorProto, helper
import finn.core.utils as util
import finn.transformation.infer_shapes as si
from finn.core.modelwrapper import ModelWrapper
......@@ -26,7 +27,7 @@ def test_infer_shapes():
# thresholds for one channel have to be sorted to guarantee the correct behavior
mt_thresh0_values = np.empty([8, 7], dtype=np.float32)
for i in range(len(mt_thresh0_values)):
mt_thresh0_values[i] = np.sort(np.random.random_sample(7,) * 10)
mt_thresh0_values[i] = np.sort(np.random.random_sample(7) * 10)
model.set_initializer(mt_thresh0.name, mt_thresh0_values)
......@@ -36,6 +37,9 @@ def test_infer_shapes():
)
Relu_node.output[0] = "mt_v0"
# explicitly remove any present shape from ReLU and MultiThreshold outputs
util.remove_by_name(model.graph.value_info, Relu_node.output[0])
util.remove_by_name(model.graph.value_info, mt_node.output[0])
graph.node.insert(4, mt_node)
# first check routine
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment