Skip to content
Snippets Groups Projects
Commit 458788fa authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] make test_const_folding_shapes more robust to example chgs

parent 372e168b
No related branches found
No related tags found
No related merge requests found
......@@ -65,7 +65,8 @@ def test_const_folding_shapes():
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
assert model.graph.node[0].op_type == "Reshape"
assert list(model.get_tensor_shape("0")) == [1, 1, 28, 28]
assert list(model.get_tensor_shape("27")) == [1, 784]
reshape_node = model.graph.node[0]
assert reshape_node.op_type == "Reshape"
assert list(model.get_tensor_shape(reshape_node.input[0])) == [1, 1, 28, 28]
assert list(model.get_tensor_shape(reshape_node.output[0])) == [1, 784]
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment