diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py index 2c5e69b1791836c526ca7c95e91e242c01177c42..fc525c3770f3ad420ecd93f35bfdf8ccfbb2908a 100644 --- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py @@ -18,20 +18,26 @@ class StreamingFCLayer_Batch(HLSCustomOp): self.resDataType = get_by_name(node.attribute, "resDataType").s.decode("utf-8") def global_includes(self, node): - pass + self.code_gen_dict["$GLOBALS$"] = [""] def defines(self, node): - pass + numReps = 2 + self.code_gen_dict["$DEFINES$"] = [ + """#define MW {}\n #define MH {}\n + #define SIMD {}\n #define PE {}\n #define numReps {}""".format( + self.MW, self.MH, self.SIMD, self.PE, numReps + ) + ] def read_npy_data(self, node): pass - + def strm_decl(self, node): pass def docompute(self, node): pass - + def dataoutstrm(self, node): pass