From 9b8c303292e742de235f8dba8fd90680ef9fa473 Mon Sep 17 00:00:00 2001
From: mmrahorovic <mmrahorovic@hotmail.com>
Date: Tue, 7 Jun 2022 16:37:02 +0100
Subject: [PATCH] [custom_op]: Matrix_Vector_Activate_Batch instantiated
 instead of wrapper from fclayer.h

---
 .../custom_op/fpgadataflow/streamingfclayer_batch.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index 2a47c8d80..9cf758a15 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -1010,16 +1010,12 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         self.code_gen_dict["$GLOBALS$"] += ['#include "activations.hpp"']
 
         mem_mode = self.get_nodeattr("mem_mode")
-        if mem_mode == "const":
-            # self.code_gen_dict["$GLOBALS$"] += ['#include "params.h"']
-            pass
-        elif mem_mode == "decoupled" or mem_mode == "external":
-            self.code_gen_dict["$GLOBALS$"] += ['#include "mvau.hpp"']
-        else:
+        if mem_mode not in ["const", "decoupled", "external"]:
             raise Exception(
                 """Please set mem_mode to "const", "decoupled", or "external",
                 currently no other parameter value is supported!"""
             )
+        self.code_gen_dict["$GLOBALS$"] += ['#include "mvau.hpp"']
         if self.calc_tmem() != 0:
             # TODO find a better way of checking for no pregenerated thresholds
             self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"']
@@ -1123,11 +1119,9 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         else:
             threshs = "threshs"
         if mem_mode == "const":
-            node = self.onnx_node
             self.code_gen_dict["$DOCOMPUTE$"] = [
-                """{}<MW1, MH1, SIMD1, PE1, {}, {}, {}>
+                """Matrix_Vector_Activate_Batch<MW1, MH1, SIMD1, PE1, 1, {}, {}, {}>
                 (in0, out, weights, {}, numReps, {});""".format(
-                    node.op_type,
                     tmpl_args["TSrcI"],
                     tmpl_args["TDstI"],
                     tmpl_args["TWeightI"],
-- 
GitLab