From 205528b689894396fb0f709b1850ce74abf5c48a Mon Sep 17 00:00:00 2001
From: icolbert <Ian.Colbert@amd.com>
Date: Tue, 21 Feb 2023 15:59:45 -0800
Subject: [PATCH] Adding unit test

---
 tests/end2end/test_end2end_bnn_pynq.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/tests/end2end/test_end2end_bnn_pynq.py b/tests/end2end/test_end2end_bnn_pynq.py
index 858363d6d..a627606f4 100644
--- a/tests/end2end/test_end2end_bnn_pynq.py
+++ b/tests/end2end/test_end2end_bnn_pynq.py
@@ -89,6 +89,8 @@ from finn.transformation.streamline.reorder import (
     MakeMaxPoolNHWC,
     MoveScalarLinearPastInvariants,
 )
+from finn.transformation.fpgadataflow.minimize_accumulator_width import MinimizeAccumulatorWidth
+from finn.transformation.fpgadataflow.minimize_weight_bit_width import MinimizeWeightBitWidth
 from finn.util.basic import get_finn_root
 from finn.util.gdrive import upload_to_end2end_dashboard
 from finn.util.pytorch import ToTensor
@@ -511,11 +513,23 @@ class TestEnd2End:
         model = folding_fxn(model)
         model.save(get_checkpoint_name(topology, wbits, abits, QONNX_export, "fold"))
 
+    def test_minimize_bit_width(self, topology, wbits, abits, QONNX_export):
+        prev_chkpt_name = get_checkpoint_name(
+            topology, wbits, abits, QONNX_export, "fold"
+        )
+        model = load_test_checkpoint_or_skip(prev_chkpt_name)
+        model = model.transform(MinimizeAccumulatorWidth())
+        model = model.transform(MinimizeWeightBitWidth())
+        curr_chkpt_name = get_checkpoint_name(
+            topology, wbits, abits, QONNX_export, "minimize_bit_width"
+        )
+        model.save(curr_chkpt_name)
+
     @pytest.mark.slow
     @pytest.mark.vivado
     def test_cppsim(self, topology, wbits, abits, QONNX_export):
         prev_chkpt_name = get_checkpoint_name(
-            topology, wbits, abits, QONNX_export, "fold"
+            topology, wbits, abits, QONNX_export, "minimize_bit_width"
         )
         model = load_test_checkpoint_or_skip(prev_chkpt_name)
         model = model.transform(PrepareCppSim())
-- 
GitLab