Skip to content
Snippets Groups Projects
Commit c5214879 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Eltwise] register op and fix bugs

parent 214612ff
No related branches found
No related tags found
No related merge requests found
......@@ -38,6 +38,7 @@ from finn.custom_op.fpgadataflow.convolutioninputgenerator1d import (
)
from finn.custom_op.fpgadataflow.downsampler import DownSampler
from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch
from finn.custom_op.fpgadataflow.eltwise import StreamingEltwise
from finn.custom_op.fpgadataflow.fmpadding_batch import FMPadding_Batch
from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch
from finn.custom_op.fpgadataflow.iodma import IODMA
......@@ -85,3 +86,4 @@ custom_op["UpsampleNearestNeighbour_Batch"] = UpsampleNearestNeighbour_Batch
custom_op["Lookup"] = Lookup
custom_op["StreamingConcat"] = StreamingConcat
custom_op["CheckSum"] = CheckSum
custom_op["StreamingEltwise"] = StreamingEltwise
......@@ -216,7 +216,7 @@ class StreamingEltwise(HLSCustomOp):
assert (
inp.shape == exp_ishape
), """Input0 shape doesn't match expected shape ."""
export_idt = self.get_input_datatype()
export_idt0 = self.get_input_datatype(0)
# reshape input into folded form
inp = inp.reshape(folded_ishape)
# make copy before saving array
......@@ -229,7 +229,7 @@ class StreamingEltwise(HLSCustomOp):
assert (
inp.shape == exp_ishape
), """Input1 shape doesn't match expected shape ."""
export_idt = self.get_input_datatype()
export_idt1 = self.get_input_datatype(1)
# reshape input into folded form
inp = inp.reshape(folded_ishape)
# make copy before saving array
......@@ -246,12 +246,13 @@ class StreamingEltwise(HLSCustomOp):
), "cppsim did not produce expected output shape"
elif mode == "rtlsim":
sim = self.get_rtlsim()
nbits = self.get_instream_width()
nbits0 = self.get_instream_width(0)
nbits1 = self.get_instream_width(1)
rtlsim_inp0 = npy_to_rtlsim_input(
"{}/input_0.npy".format(code_gen_dir), export_idt, nbits
"{}/input_0.npy".format(code_gen_dir), export_idt0, nbits0
)
rtlsim_inp1 = npy_to_rtlsim_input(
"{}/input_1.npy".format(code_gen_dir), export_idt, nbits
"{}/input_1.npy".format(code_gen_dir), export_idt1, nbits1
)
super().reset_rtlsim(sim)
super().toggle_clk(sim)
......@@ -283,7 +284,7 @@ class StreamingEltwise(HLSCustomOp):
def global_includes(self):
self.code_gen_dict["$GLOBALS$"] = [
'#include "eltwise.hpp"',
'#include "interpret.hpp',
'#include "interpret.hpp"',
]
def defines(self, var):
......@@ -344,7 +345,7 @@ class StreamingEltwise(HLSCustomOp):
out_hls_type,
)
self.code_gen_dict["$DOCOMPUTE$"] = [
"""{}<{}, {}, {}, {}, {}, {}, {}>(in0, in1, out);""".format(
"""{}<{}, {}, {}, {}, {}, {}>(in0, in1, out, {});""".format(
"StreamingEltwise",
self.get_nodeattr("NumChannels"),
self.get_nodeattr("PE"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment