diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index 5d18de2d18157383a3c7882febfa752d72774572..942eda19ca4c2cdbded9f906a5e7772f50acbd6e 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -43,17 +43,26 @@ def test_modelwrapper(): bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) assert model.check_all_tensor_shapes_specified() is False - inp_shape = model.get_tensor_shape("0") + inp_name = model.graph.input[0].name + inp_shape = model.get_tensor_shape(inp_name) assert inp_shape == [1, 1, 28, 28] - l0_mat_tensor_name = "33" + # find first matmul node + l0_mat_tensor_name = "" + l0_inp_tensor_name = "" + for node in model.graph.node: + if node.op_type == "MatMul": + l0_inp_tensor_name = node.input[0] + l0_mat_tensor_name = node.input[1] + break + assert l0_mat_tensor_name != "" l0_weights = model.get_initializer(l0_mat_tensor_name) assert l0_weights.shape == (784, 1024) l0_weights_hist = Counter(l0_weights.flatten()) - assert l0_weights_hist[1.0] == 401311 and l0_weights_hist[-1.0] == 401505 + assert (l0_weights_hist[1.0] + l0_weights_hist[-1.0]) == 784 * 1024 l0_weights_rand = np.random.randn(784, 1024) model.set_initializer(l0_mat_tensor_name, l0_weights_rand) assert (model.get_initializer(l0_mat_tensor_name) == l0_weights_rand).all() - l0_inp_tensor_name = "32" + assert l0_inp_tensor_name != "" inp_cons = model.find_consumer(l0_inp_tensor_name) assert inp_cons.op_type == "MatMul" out_prod = model.find_producer(l0_inp_tensor_name)