From a6c863cd4bfc0b2d5a726ba85147039ce41c4473 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Mon, 6 Jul 2020 14:22:29 +0100
Subject: [PATCH] [Test] test RemoveStaticGraphInputs and
 RemoveUnusedInitAndValueInfo

---
 tests/brevitas/test_brevitas_cnv.py                   |  5 ++++-
 tests/brevitas/test_brevitas_fc.py                    |  4 ++++
 tests/end2end/test_end2end_cnv_w1a1.py                |  9 ++++++++-
 tests/end2end/test_end2end_cnv_w2a2.py                |  9 ++++++++-
 tests/end2end/test_end2end_tfc_w1a1.py                |  9 ++++++++-
 tests/end2end/test_end2end_tfc_w1a2.py                |  9 ++++++++-
 tests/end2end/test_end2end_tfc_w2a2.py                |  9 ++++++++-
 .../transformation/streamline/test_streamline_cnv.py  | 11 ++++++++++-
 tests/transformation/streamline/test_streamline_fc.py | 11 ++++++++++-
 9 files changed, 68 insertions(+), 8 deletions(-)

diff --git a/tests/brevitas/test_brevitas_cnv.py b/tests/brevitas/test_brevitas_cnv.py
index f91ca600d..764671bee 100644
--- a/tests/brevitas/test_brevitas_cnv.py
+++ b/tests/brevitas/test_brevitas_cnv.py
@@ -38,7 +38,7 @@ import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
 from finn.transformation.infer_shapes import InferShapes
-from finn.transformation.general import GiveUniqueNodeNames
+from finn.transformation.general import GiveUniqueNodeNames, RemoveStaticGraphInputs
 from finn.transformation.double_to_single_float import DoubleToSingleFloat
 from finn.util.test import get_test_model_trained
 
@@ -57,6 +57,9 @@ def test_brevitas_cnv_export_exec(wbits, abits):
     model = model.transform(DoubleToSingleFloat())
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
+    model = model.transform(RemoveStaticGraphInputs())
+    assert len(model.graph.input) == 1
+    assert len(model.graph.output) == 1
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
     input_tensor = np.load(fn)["arr_0"].astype(np.float32)
     input_tensor = input_tensor / 255
diff --git a/tests/brevitas/test_brevitas_fc.py b/tests/brevitas/test_brevitas_fc.py
index db18d91e3..9369b2538 100644
--- a/tests/brevitas/test_brevitas_fc.py
+++ b/tests/brevitas/test_brevitas_fc.py
@@ -39,6 +39,7 @@ import torch
 import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.general import RemoveStaticGraphInputs
 from finn.transformation.infer_shapes import InferShapes
 from finn.util.basic import make_build_dir
 from finn.util.test import get_test_model_trained
@@ -63,6 +64,9 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits):
     model = ModelWrapper(finn_onnx)
     model = model.transform(InferShapes())
     model = model.transform(FoldConstants())
+    model = model.transform(RemoveStaticGraphInputs())
+    assert len(model.graph.input) == 1
+    assert len(model.graph.output) == 1
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
diff --git a/tests/end2end/test_end2end_cnv_w1a1.py b/tests/end2end/test_end2end_cnv_w1a1.py
index e3f281904..c54079785 100644
--- a/tests/end2end/test_end2end_cnv_w1a1.py
+++ b/tests/end2end/test_end2end_cnv_w1a1.py
@@ -42,7 +42,12 @@ from finn.transformation.double_to_single_float import DoubleToSingleFloat
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
 from finn.transformation.fold_constants import FoldConstants
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedInitAndValueInfo,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.streamline import Streamline
 from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
 from finn.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
@@ -97,6 +102,7 @@ def test_end2end_cnv_w1a1_import_and_tidy():
     model = model.transform(FoldConstants())
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
+    model = model.transform(RemoveStaticGraphInputs())
     model.save(build_dir + "/end2end_cnv_w1a1_tidy.onnx")
 
 
@@ -108,6 +114,7 @@ def test_end2end_cnv_w1a1_streamline():
     model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
     model = model.transform(ConvertBipolarMatMulToXnorPopcount())
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedInitAndValueInfo())
     model.save(build_dir + "/end2end_cnv_w1a1_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_cnv_w2a2.py b/tests/end2end/test_end2end_cnv_w2a2.py
index 31ccebd4c..ae1771fab 100644
--- a/tests/end2end/test_end2end_cnv_w2a2.py
+++ b/tests/end2end/test_end2end_cnv_w2a2.py
@@ -42,7 +42,12 @@ from finn.transformation.double_to_single_float import DoubleToSingleFloat
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
 from finn.transformation.fold_constants import FoldConstants
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedInitAndValueInfo,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.streamline import Streamline
 from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
 import finn.transformation.streamline.absorb as absorb
@@ -96,6 +101,7 @@ def test_end2end_cnv_w2a2_import_and_tidy():
     model = model.transform(FoldConstants())
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
+    model = model.transform(RemoveStaticGraphInputs())
     model.save(build_dir + "/end2end_cnv_w2a2_tidy.onnx")
 
 
@@ -106,6 +112,7 @@ def test_end2end_cnv_w2a2_streamline():
     model = model.transform(MakeMaxPoolNHWC())
     model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedInitAndValueInfo())
     model.save(build_dir + "/end2end_cnv_w2a2_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_tfc_w1a1.py b/tests/end2end/test_end2end_tfc_w1a1.py
index ebfed5e57..8dafc7fec 100644
--- a/tests/end2end/test_end2end_tfc_w1a1.py
+++ b/tests/end2end/test_end2end_tfc_w1a1.py
@@ -63,7 +63,12 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
 )
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedInitAndValueInfo,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import Streamline
@@ -98,12 +103,14 @@ def test_end2end_tfc_w1a1_import_and_tidy():
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
     model = model.transform(InferDataTypes())
+    model = model.transform(RemoveStaticGraphInputs())
     model.save(build_dir + "/end2end_tfc_w1a1_tidy.onnx")
 
 
 def test_end2end_tfc_w1a1_streamline():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w1a1_tidy.onnx")
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedInitAndValueInfo())
     model.save(build_dir + "/end2end_tfc_w1a1_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_tfc_w1a2.py b/tests/end2end/test_end2end_tfc_w1a2.py
index d4c005a86..cba558d83 100644
--- a/tests/end2end/test_end2end_tfc_w1a2.py
+++ b/tests/end2end/test_end2end_tfc_w1a2.py
@@ -61,7 +61,12 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
 )
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedInitAndValueInfo,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import Streamline
@@ -93,12 +98,14 @@ def test_end2end_tfc_w1a2_import_and_tidy():
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
     model = model.transform(InferDataTypes())
+    model = model.transform(RemoveStaticGraphInputs())
     model.save(build_dir + "/end2end_tfc_w1a2_tidy.onnx")
 
 
 def test_end2end_tfc_w1a2_streamline():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w1a2_tidy.onnx")
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedInitAndValueInfo())
     model.save(build_dir + "/end2end_tfc_w1a2_streamlined.onnx")
 
 
diff --git a/tests/end2end/test_end2end_tfc_w2a2.py b/tests/end2end/test_end2end_tfc_w2a2.py
index 19d3f86e0..44a8a8b17 100644
--- a/tests/end2end/test_end2end_tfc_w2a2.py
+++ b/tests/end2end/test_end2end_tfc_w2a2.py
@@ -61,7 +61,12 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
 )
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
 from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedInitAndValueInfo,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.infer_datatypes import InferDataTypes
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import Streamline
@@ -93,12 +98,14 @@ def test_end2end_tfc_w2a2_import_and_tidy():
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
     model = model.transform(InferDataTypes())
+    model = model.transform(RemoveStaticGraphInputs())
     model.save(build_dir + "/end2end_tfc_w2a2_tidy.onnx")
 
 
 def test_end2end_tfc_w2a2_streamline():
     model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w2a2_tidy.onnx")
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedInitAndValueInfo())
     model.save(build_dir + "/end2end_tfc_w2a2_streamlined.onnx")
 
 
diff --git a/tests/transformation/streamline/test_streamline_cnv.py b/tests/transformation/streamline/test_streamline_cnv.py
index 103967dfb..050e12a97 100644
--- a/tests/transformation/streamline/test_streamline_cnv.py
+++ b/tests/transformation/streamline/test_streamline_cnv.py
@@ -34,7 +34,12 @@ import pkg_resources as pk
 import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedInitAndValueInfo,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import Streamline
 from finn.util.test import get_test_model_trained
@@ -62,6 +67,7 @@ def test_streamline_cnv(size, wbits, abits):
     model = model.transform(FoldConstants())
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
+    model = model.transform(RemoveStaticGraphInputs())
     # load one of the test vectors
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
     input_tensor = np.load(fn)["arr_0"].astype(np.float32)
@@ -73,6 +79,9 @@ def test_streamline_cnv(size, wbits, abits):
     expected = expected_ctx[model.graph.output[0].name]
     # model.save("orig_cnv.onnx")
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedInitAndValueInfo())
+    assert len(model.graph.initializer) == 21
+    assert len(model.graph.value_info) == 43
     # model.save("streamlined_cnv.onnx")
     assert len(model.graph.node) == 23
     produced_ctx = oxe.execute_onnx(model, input_dict, True)
diff --git a/tests/transformation/streamline/test_streamline_fc.py b/tests/transformation/streamline/test_streamline_fc.py
index c68561239..cb356dc38 100644
--- a/tests/transformation/streamline/test_streamline_fc.py
+++ b/tests/transformation/streamline/test_streamline_fc.py
@@ -37,7 +37,12 @@ import pytest
 import finn.core.onnx_exec as oxe
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation.fold_constants import FoldConstants
-from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.general import (
+    RemoveUnusedInitAndValueInfo,
+    RemoveStaticGraphInputs,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames,
+)
 from finn.transformation.infer_shapes import InferShapes
 from finn.transformation.streamline import Streamline
 from finn.util.test import get_test_model_trained
@@ -65,6 +70,7 @@ def test_streamline_fc(size, wbits, abits):
     model = model.transform(FoldConstants())
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
+    model = model.transform(RemoveStaticGraphInputs())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
@@ -73,6 +79,9 @@ def test_streamline_fc(size, wbits, abits):
     expected_ctx = oxe.execute_onnx(model, input_dict, True)
     expected = expected_ctx[model.graph.output[0].name]
     model = model.transform(Streamline())
+    model = model.transform(RemoveUnusedInitAndValueInfo())
+    assert len(model.graph.initializer) == 11
+    assert len(model.graph.value_info) == 21
     produced_ctx = oxe.execute_onnx(model, input_dict, True)
     produced = produced_ctx[model.graph.output[0].name]
     assert np.isclose(expected, produced, atol=1e-3).all()
-- 
GitLab