From 17051adb86eb3eb5c2e51b13375dc2f6282c47ab Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 20 Oct 2022 23:36:43 +0200
Subject: [PATCH] [Test] use higher PE config for dyn conv tests

---
 ...fpgadataflow_convinputgenerator_rtl_dynamic.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl_dynamic.py b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl_dynamic.py
index 3f8743062..d7085e849 100644
--- a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl_dynamic.py
+++ b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator_rtl_dynamic.py
@@ -57,6 +57,7 @@ from finn.transformation.fpgadataflow.create_dataflow_partition import (
 )
 from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP
 from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.insert_dwc import InsertDWC
 from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
 from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
 from finn.util.basic import pyverilate_get_liveness_threshold_cycles
@@ -152,6 +153,7 @@ def config_hook(configs):
         return None
 
     def write_swg_config(sim):
+        reset_rtlsim(sim)
         for axi_name, config in configs:
             # Write config registers to the SWG/FMPadding dict
             # defines (addr, value) tuples
@@ -193,7 +195,7 @@ def test_fpgadataflow_conv_dynamic(cfg):
     k = cfg["k"]
     stride = cfg["stride"]
     ofm = cfg["ofm"]
-    idt = DataType["UINT8"]
+    idt = DataType["UINT4"]
     wdt = DataType["INT2"]
     exp_cfgs = []
     largest_model = None
@@ -236,14 +238,19 @@ def test_fpgadataflow_conv_dynamic(cfg):
     dyn_nodes = model.get_nodes_by_op_type("ConvolutionInputGenerator_rtl")
     dyn_nodes += model.get_nodes_by_op_type("FMPadding_rtl")
     for swg_node in dyn_nodes:
-        getCustomOp(swg_node).set_nodeattr("SIMD", 1)
+        getCustomOp(swg_node).set_nodeattr("SIMD", 4)
         getCustomOp(swg_node).set_nodeattr("dynamic_mode", 1)
         getCustomOp(swg_node).set_nodeattr("inFIFODepths", [16])
         getCustomOp(swg_node).set_nodeattr("outFIFODepths", [16])
     comp_nodes = model.get_nodes_by_op_type("MatrixVectorActivation")
     comp_nodes += model.get_nodes_by_op_type("VectorVectorActivation")
     for comp_node in comp_nodes:
-        getCustomOp(comp_node).set_nodeattr("PE", 1)
+        if depthwise:
+            getCustomOp(comp_node).set_nodeattr("PE", 4)
+        else:
+            getCustomOp(comp_node).set_nodeattr("SIMD", 4)
+            getCustomOp(comp_node).set_nodeattr("PE", 4)
+    model = model.transform(InsertDWC())
     model = model.transform(InsertFIFO())
     model = model.transform(GiveUniqueNodeNames())
     model = model.transform(GiveReadableTensorNames())
@@ -306,6 +313,8 @@ def test_fpgadataflow_conv_dynamic(cfg):
         ctx = {"global_in": inp.transpose(0, 2, 3, 1)}
         liveness_prev = pyverilate_get_liveness_threshold_cycles()
         os.environ["LIVENESS_THRESHOLD"] = "100000"
+        # model.set_metadata_prop("rtlsim_trace", "trace_%d.vcd" % idim)
+        # import pdb; pdb.set_trace()
         rtlsim_exec(model, ctx, pre_hook=config_hook(configs))
         os.environ["LIVENESS_THRESHOLD"] = str(liveness_prev)
         ret = ctx["global_out"].transpose(0, 3, 1, 2)
-- 
GitLab