From 458788fa2a4cf0730f740afb2963fb9f5b3bea94 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Fri, 20 Mar 2020 21:29:16 +0000
Subject: [PATCH] [Test] make test_const_folding_shapes more robust to example
 chgs

---
 tests/transformation/test_fold_constants.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/transformation/test_fold_constants.py b/tests/transformation/test_fold_constants.py
index cd1c34659..685c14a98 100644
--- a/tests/transformation/test_fold_constants.py
+++ b/tests/transformation/test_fold_constants.py
@@ -65,7 +65,8 @@ def test_const_folding_shapes():
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
-    assert model.graph.node[0].op_type == "Reshape"
-    assert list(model.get_tensor_shape("0")) == [1, 1, 28, 28]
-    assert list(model.get_tensor_shape("27")) == [1, 784]
+    reshape_node = model.graph.node[0]
+    assert reshape_node.op_type == "Reshape"
+    assert list(model.get_tensor_shape(reshape_node.input[0])) == [1, 1, 28, 28]
+    assert list(model.get_tensor_shape(reshape_node.output[0])) == [1, 784]
     os.remove(export_onnx_path)
-- 
GitLab