Skip to content
Snippets Groups Projects
Commit 372e168b authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] make test_modelwrapper more robust to example model changes

parent d79c6317
No related branches found
No related tags found
No related merge requests found
......@@ -43,17 +43,26 @@ def test_modelwrapper():
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
assert model.check_all_tensor_shapes_specified() is False
inp_shape = model.get_tensor_shape("0")
inp_name = model.graph.input[0].name
inp_shape = model.get_tensor_shape(inp_name)
assert inp_shape == [1, 1, 28, 28]
l0_mat_tensor_name = "33"
# find first matmul node
l0_mat_tensor_name = ""
l0_inp_tensor_name = ""
for node in model.graph.node:
if node.op_type == "MatMul":
l0_inp_tensor_name = node.input[0]
l0_mat_tensor_name = node.input[1]
break
assert l0_mat_tensor_name != ""
l0_weights = model.get_initializer(l0_mat_tensor_name)
assert l0_weights.shape == (784, 1024)
l0_weights_hist = Counter(l0_weights.flatten())
assert l0_weights_hist[1.0] == 401311 and l0_weights_hist[-1.0] == 401505
assert (l0_weights_hist[1.0] + l0_weights_hist[-1.0]) == 784 * 1024
l0_weights_rand = np.random.randn(784, 1024)
model.set_initializer(l0_mat_tensor_name, l0_weights_rand)
assert (model.get_initializer(l0_mat_tensor_name) == l0_weights_rand).all()
l0_inp_tensor_name = "32"
assert l0_inp_tensor_name != ""
inp_cons = model.find_consumer(l0_inp_tensor_name)
assert inp_cons.op_type == "MatMul"
out_prod = model.find_producer(l0_inp_tensor_name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment