From d2341c2f0a04f2da9bc3b2ac46cd6e7c5fd7fdcf Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 26 May 2021 14:36:32 +0100
Subject: [PATCH] [Test] shape inf + fold consts to get rid of Cast in ToTensor

---
 tests/end2end/test_end2end_bnn_pynq.py     | 2 ++
 tests/end2end/test_end2end_mobilenet_v1.py | 1 +
 2 files changed, 3 insertions(+)

diff --git a/tests/end2end/test_end2end_bnn_pynq.py b/tests/end2end/test_end2end_bnn_pynq.py
index d0a9aa1c0..a6e7ad642 100644
--- a/tests/end2end/test_end2end_bnn_pynq.py
+++ b/tests/end2end/test_end2end_bnn_pynq.py
@@ -347,6 +347,8 @@ class TestEnd2End:
         assert os.path.isfile(chkpt_preproc_name)
         # join preprocessing and core model
         pre_model = ModelWrapper(chkpt_preproc_name)
+        pre_model = pre_model.transform(InferShapes())
+        pre_model = pre_model.transform(FoldConstants())
         model = model.transform(MergeONNXModels(pre_model))
         # add input quantization annotation: UINT8 for all BNN-PYNQ models
         global_inp_name = model.graph.input[0].name
diff --git a/tests/end2end/test_end2end_mobilenet_v1.py b/tests/end2end/test_end2end_mobilenet_v1.py
index 6e11a30a4..79263a709 100644
--- a/tests/end2end/test_end2end_mobilenet_v1.py
+++ b/tests/end2end/test_end2end_mobilenet_v1.py
@@ -101,6 +101,7 @@ def test_end2end_mobilenet_export():
     # set input finn datatype to UINT8
     preproc_model.set_tensor_datatype(preproc_model.graph.input[0].name, DataType.UINT8)
     preproc_model = preproc_model.transform(InferShapes())
+    preproc_model = preproc_model.transform(FoldConstants())
     preproc_model = preproc_model.transform(GiveUniqueNodeNames())
     preproc_model = preproc_model.transform(GiveUniqueParameterTensors())
     preproc_model = preproc_model.transform(GiveReadableTensorNames())
-- 
GitLab