Skip to content
Snippets Groups Projects
Commit c5214879 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Eltwise] register op and fix bugs

parent 214612ff
No related branches found
No related tags found
No related merge requests found
...@@ -38,6 +38,7 @@ from finn.custom_op.fpgadataflow.convolutioninputgenerator1d import ( ...@@ -38,6 +38,7 @@ from finn.custom_op.fpgadataflow.convolutioninputgenerator1d import (
) )
from finn.custom_op.fpgadataflow.downsampler import DownSampler from finn.custom_op.fpgadataflow.downsampler import DownSampler
from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch
from finn.custom_op.fpgadataflow.eltwise import StreamingEltwise
from finn.custom_op.fpgadataflow.fmpadding_batch import FMPadding_Batch from finn.custom_op.fpgadataflow.fmpadding_batch import FMPadding_Batch
from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch
from finn.custom_op.fpgadataflow.iodma import IODMA from finn.custom_op.fpgadataflow.iodma import IODMA
...@@ -85,3 +86,4 @@ custom_op["UpsampleNearestNeighbour_Batch"] = UpsampleNearestNeighbour_Batch ...@@ -85,3 +86,4 @@ custom_op["UpsampleNearestNeighbour_Batch"] = UpsampleNearestNeighbour_Batch
custom_op["Lookup"] = Lookup custom_op["Lookup"] = Lookup
custom_op["StreamingConcat"] = StreamingConcat custom_op["StreamingConcat"] = StreamingConcat
custom_op["CheckSum"] = CheckSum custom_op["CheckSum"] = CheckSum
custom_op["StreamingEltwise"] = StreamingEltwise
...@@ -216,7 +216,7 @@ class StreamingEltwise(HLSCustomOp): ...@@ -216,7 +216,7 @@ class StreamingEltwise(HLSCustomOp):
assert ( assert (
inp.shape == exp_ishape inp.shape == exp_ishape
), """Input0 shape doesn't match expected shape .""" ), """Input0 shape doesn't match expected shape ."""
export_idt = self.get_input_datatype() export_idt0 = self.get_input_datatype(0)
# reshape input into folded form # reshape input into folded form
inp = inp.reshape(folded_ishape) inp = inp.reshape(folded_ishape)
# make copy before saving array # make copy before saving array
...@@ -229,7 +229,7 @@ class StreamingEltwise(HLSCustomOp): ...@@ -229,7 +229,7 @@ class StreamingEltwise(HLSCustomOp):
assert ( assert (
inp.shape == exp_ishape inp.shape == exp_ishape
), """Input1 shape doesn't match expected shape .""" ), """Input1 shape doesn't match expected shape ."""
export_idt = self.get_input_datatype() export_idt1 = self.get_input_datatype(1)
# reshape input into folded form # reshape input into folded form
inp = inp.reshape(folded_ishape) inp = inp.reshape(folded_ishape)
# make copy before saving array # make copy before saving array
...@@ -246,12 +246,13 @@ class StreamingEltwise(HLSCustomOp): ...@@ -246,12 +246,13 @@ class StreamingEltwise(HLSCustomOp):
), "cppsim did not produce expected output shape" ), "cppsim did not produce expected output shape"
elif mode == "rtlsim": elif mode == "rtlsim":
sim = self.get_rtlsim() sim = self.get_rtlsim()
nbits = self.get_instream_width() nbits0 = self.get_instream_width(0)
nbits1 = self.get_instream_width(1)
rtlsim_inp0 = npy_to_rtlsim_input( rtlsim_inp0 = npy_to_rtlsim_input(
"{}/input_0.npy".format(code_gen_dir), export_idt, nbits "{}/input_0.npy".format(code_gen_dir), export_idt0, nbits0
) )
rtlsim_inp1 = npy_to_rtlsim_input( rtlsim_inp1 = npy_to_rtlsim_input(
"{}/input_1.npy".format(code_gen_dir), export_idt, nbits "{}/input_1.npy".format(code_gen_dir), export_idt1, nbits1
) )
super().reset_rtlsim(sim) super().reset_rtlsim(sim)
super().toggle_clk(sim) super().toggle_clk(sim)
...@@ -283,7 +284,7 @@ class StreamingEltwise(HLSCustomOp): ...@@ -283,7 +284,7 @@ class StreamingEltwise(HLSCustomOp):
def global_includes(self): def global_includes(self):
self.code_gen_dict["$GLOBALS$"] = [ self.code_gen_dict["$GLOBALS$"] = [
'#include "eltwise.hpp"', '#include "eltwise.hpp"',
'#include "interpret.hpp', '#include "interpret.hpp"',
] ]
def defines(self, var): def defines(self, var):
...@@ -344,7 +345,7 @@ class StreamingEltwise(HLSCustomOp): ...@@ -344,7 +345,7 @@ class StreamingEltwise(HLSCustomOp):
out_hls_type, out_hls_type,
) )
self.code_gen_dict["$DOCOMPUTE$"] = [ self.code_gen_dict["$DOCOMPUTE$"] = [
"""{}<{}, {}, {}, {}, {}, {}, {}>(in0, in1, out);""".format( """{}<{}, {}, {}, {}, {}, {}>(in0, in1, out, {});""".format(
"StreamingEltwise", "StreamingEltwise",
self.get_nodeattr("NumChannels"), self.get_nodeattr("NumChannels"),
self.get_nodeattr("PE"), self.get_nodeattr("PE"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment