From ba9f5fc97d76fc6ae0c1d275cce58a691d553ad2 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 22 Jul 2020 14:32:24 +0100
Subject: [PATCH] [CustomOp&Test] Update get_exp_cycles for pool batch and add
 test

---
 src/finn/custom_op/fpgadataflow/pool_batch.py          | 10 ++++++++--
 tests/fpgadataflow/test_convert_to_hls_pool_batch.py   |  9 +++++++++
 tests/fpgadataflow/test_fpgadataflow_addstreams.py     |  6 ++++--
 .../fpgadataflow/test_fpgadataflow_channelwise_ops.py  |  6 ++++--
 4 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/pool_batch.py b/src/finn/custom_op/fpgadataflow/pool_batch.py
index d5efe3e81..3b0f3fbf0 100644
--- a/src/finn/custom_op/fpgadataflow/pool_batch.py
+++ b/src/finn/custom_op/fpgadataflow/pool_batch.py
@@ -137,8 +137,14 @@ class Pool_Batch(HLSCustomOp):
         return np.prod(folded_oshape[1:-1])
 
     def get_exp_cycles(self):
-        # Channels/PE * batch size * odim * odim
-        return np.prod(self.get_folded_output_shape()[:-1])
+        # (Channels * kernel * kernel) / PE * odim * odim * batch_size
+        ifm_ch = self.get_nodeattr("Channels")
+        pe = self.get_nodeattr("PE")
+        k = self.get_nodeattr("KernelSize")
+        odim = self.get_nodeattr("OutImgDim")
+        batch_size = self.get_nodeattr("BatchSize")
+        exp_cycles = (ifm_ch * k * k) / pe * odim * odim * batch_size
+        return int(exp_cycles)
 
     def get_instream_width(self):
         dt_bits = self.get_input_datatype().bitwidth()
diff --git a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
index aba973051..2915ba48c 100644
--- a/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
+++ b/tests/fpgadataflow/test_convert_to_hls_pool_batch.py
@@ -44,6 +44,7 @@ from finn.transformation.general import GiveUniqueNodeNames
 from finn.custom_op.registry import getCustomOp
 from finn.util.basic import gen_finn_dt_tensor
 from finn.transformation.infer_shapes import InferShapes
+from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
 
 
 def make_single_maxpool_modelwrapper(k, stride, pad, ifm_ch, ifm_dim, ofm_dim, idt):
@@ -210,3 +211,11 @@ def test_convert_to_hls_pool_batch(
             assert len(new_model.graph.node) == 5
     else:
         assert len(new_model.graph.node) == 1
+
+    if exec_mode == "rtlsim":
+        node = new_model.get_nodes_by_op_type("Pool_Batch")[0]
+        inst = getCustomOp(node)
+        sim_cycles = inst.get_nodeattr("sim_cycles")
+        exp_cycles_dict = new_model.analysis(exp_cycles_per_layer)
+        exp_cycles = exp_cycles_dict[str(node)]
+        assert np.isclose(exp_cycles, sim_cycles, atol=10)
diff --git a/tests/fpgadataflow/test_fpgadataflow_addstreams.py b/tests/fpgadataflow/test_fpgadataflow_addstreams.py
index 38f3584e9..70cf1f109 100644
--- a/tests/fpgadataflow/test_fpgadataflow_addstreams.py
+++ b/tests/fpgadataflow/test_fpgadataflow_addstreams.py
@@ -130,8 +130,10 @@ def test_fpgadataflow_addstreams(idt, ch, fold, exec_mode):
     assert (y_produced == y_expected).all(), exec_mode + " failed"
 
     if exec_mode == "rtlsim":
-        inst = getCustomOp(model.graph.node[0])
+        node = model.get_nodes_by_op_type("AddStreams_Batch")[0]
+        inst = getCustomOp(node)
         sim_cycles = inst.get_nodeattr("sim_cycles")
         exp_cycles_dict = model.analysis(exp_cycles_per_layer)
-        exp_cycles = exp_cycles_dict[str(model.graph.node[0])]
+        exp_cycles = exp_cycles_dict[str(node)]
         assert np.isclose(exp_cycles, sim_cycles, atol=10)
+        assert exp_cycles != 0
diff --git a/tests/fpgadataflow/test_fpgadataflow_channelwise_ops.py b/tests/fpgadataflow/test_fpgadataflow_channelwise_ops.py
index d93636942..1a5aa82e0 100644
--- a/tests/fpgadataflow/test_fpgadataflow_channelwise_ops.py
+++ b/tests/fpgadataflow/test_fpgadataflow_channelwise_ops.py
@@ -157,8 +157,10 @@ def test_fpgadataflow_channelwise_ops(idt, act, pdt, nf, ich, func, vecs, exec_m
         hls_synt_res_est = model.analysis(hls_synth_res_estimation)
         assert "ChannelwiseOp_Batch_0" in hls_synt_res_est
 
-        inst = getCustomOp(model.graph.node[0])
+        node = model.get_nodes_by_op_type("ChannelwiseOp_Batch")[0]
+        inst = getCustomOp(node)
         sim_cycles = inst.get_nodeattr("sim_cycles")
         exp_cycles_dict = model.analysis(exp_cycles_per_layer)
-        exp_cycles = exp_cycles_dict[str(model.graph.node[0])]
+        exp_cycles = exp_cycles_dict[str(node)]
         assert np.isclose(exp_cycles, sim_cycles, atol=10)
+        assert exp_cycles != 0
-- 
GitLab