diff --git a/tests/transformation/test_fold_constants.py b/tests/transformation/test_fold_constants.py index cd1c346593e3666ce8a89bd4248fa8436423de6d..685c14a98b9031096aaf5b244c4f484d4f308bca 100644 --- a/tests/transformation/test_fold_constants.py +++ b/tests/transformation/test_fold_constants.py @@ -65,7 +65,8 @@ def test_const_folding_shapes(): model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) model = model.transform(FoldConstants()) - assert model.graph.node[0].op_type == "Reshape" - assert list(model.get_tensor_shape("0")) == [1, 1, 28, 28] - assert list(model.get_tensor_shape("27")) == [1, 784] + reshape_node = model.graph.node[0] + assert reshape_node.op_type == "Reshape" + assert list(model.get_tensor_shape(reshape_node.input[0])) == [1, 1, 28, 28] + assert list(model.get_tensor_shape(reshape_node.output[0])) == [1, 784] os.remove(export_onnx_path)