From 5a6c8a48e5b04e675a19b843c06f97c7fb032ad1 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sun, 3 Nov 2019 17:49:14 +0000
Subject: [PATCH] [Test] add test_const_folding_shapes

---
 tests/test_fold_constants.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/tests/test_fold_constants.py b/tests/test_fold_constants.py
index ed6bb3815..894887b0c 100644
--- a/tests/test_fold_constants.py
+++ b/tests/test_fold_constants.py
@@ -1,14 +1,19 @@
+import os
 from pkgutil import get_data
 
+import brevitas.onnx as bo
 import numpy as np
 import onnx
 import onnx.numpy_helper as np_helper
+from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
 import finn.transformation.fold_constants as fc
 import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
 
+export_onnx_path = "test_output_lfc.onnx"
+
 
 def test_const_folding():
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
@@ -24,3 +29,15 @@ def test_const_folding():
     assert np.isclose(
         np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3
     ).all()
+
+
+def test_const_folding_shapes():
+    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+    bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform_single(si.infer_shapes)
+    model = model.transform_repeated(fc.fold_constants)
+    assert model.graph.node[0].op_type == "Reshape"
+    assert list(model.get_tensor_shape("0")) == [1, 1, 28, 28]
+    assert list(model.get_tensor_shape("27")) == [1, 784]
+    os.remove(export_onnx_path)
-- 
GitLab