Skip to content
Snippets Groups Projects
Commit 9b8c3032 authored by mmrahorovic's avatar mmrahorovic
Browse files

[custom_op]: Matrix_Vector_Activate_Batch instantiated instead of wrapper from fclayer.h

parent e570cb65
No related branches found
No related tags found
No related merge requests found
...@@ -1010,16 +1010,12 @@ class StreamingFCLayer_Batch(HLSCustomOp): ...@@ -1010,16 +1010,12 @@ class StreamingFCLayer_Batch(HLSCustomOp):
self.code_gen_dict["$GLOBALS$"] += ['#include "activations.hpp"'] self.code_gen_dict["$GLOBALS$"] += ['#include "activations.hpp"']
mem_mode = self.get_nodeattr("mem_mode") mem_mode = self.get_nodeattr("mem_mode")
if mem_mode == "const": if mem_mode not in ["const", "decoupled", "external"]:
# self.code_gen_dict["$GLOBALS$"] += ['#include "params.h"']
pass
elif mem_mode == "decoupled" or mem_mode == "external":
self.code_gen_dict["$GLOBALS$"] += ['#include "mvau.hpp"']
else:
raise Exception( raise Exception(
"""Please set mem_mode to "const", "decoupled", or "external", """Please set mem_mode to "const", "decoupled", or "external",
currently no other parameter value is supported!""" currently no other parameter value is supported!"""
) )
self.code_gen_dict["$GLOBALS$"] += ['#include "mvau.hpp"']
if self.calc_tmem() != 0: if self.calc_tmem() != 0:
# TODO find a better way of checking for no pregenerated thresholds # TODO find a better way of checking for no pregenerated thresholds
self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"'] self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"']
...@@ -1123,11 +1119,9 @@ class StreamingFCLayer_Batch(HLSCustomOp): ...@@ -1123,11 +1119,9 @@ class StreamingFCLayer_Batch(HLSCustomOp):
else: else:
threshs = "threshs" threshs = "threshs"
if mem_mode == "const": if mem_mode == "const":
node = self.onnx_node
self.code_gen_dict["$DOCOMPUTE$"] = [ self.code_gen_dict["$DOCOMPUTE$"] = [
"""{}<MW1, MH1, SIMD1, PE1, {}, {}, {}> """Matrix_Vector_Activate_Batch<MW1, MH1, SIMD1, PE1, 1, {}, {}, {}>
(in0, out, weights, {}, numReps, {});""".format( (in0, out, weights, {}, numReps, {});""".format(
node.op_type,
tmpl_args["TSrcI"], tmpl_args["TSrcI"],
tmpl_args["TDstI"], tmpl_args["TDstI"],
tmpl_args["TWeightI"], tmpl_args["TWeightI"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment