diff --git a/src/finn/custom_op/fpgadataflow/eltwise.py b/src/finn/custom_op/fpgadataflow/eltwise.py index 2395d451d171f8d70a5fa2f5980178d95e5fe131..d8c55b228307a90af1f7fd6aeb152266a5230331 100644 --- a/src/finn/custom_op/fpgadataflow/eltwise.py +++ b/src/finn/custom_op/fpgadataflow/eltwise.py @@ -61,10 +61,19 @@ class StreamingEltwise(HLSCustomOp): def get_eltwise_op_lambda(self): eltwise_op = self.get_nodeattr("eltwiseOp") + idt0 = self.get_input_datatype(0) + idt1 = self.get_input_datatype(1) + odt = self.get_output_datatype() + tin0 = idt0.get_hls_datatype_str() + tin1 = idt1.get_hls_datatype_str() + tout = odt.get_hls_datatype_str() eltwise_ops = { - "Add": "[](auto a, auto b) { return a + b; }", - "Sub": "[](auto a, auto b) { return a - b; }", - "AbsDiff": "[](auto a, auto b) { return a>b? a-b : b-a; }", + # "Add": "[](auto a, auto b) { return a + b; }", + # "Sub": "[](auto a, auto b) { return a - b; }", + # "AbsDiff": "[](auto a, auto b) { return a>b? a-b : b-a; }", + "Add": f"add<{tin0}, {tin1}, {tout}>()", + "Sub": f"sub<{tin0}, {tin1}, {tout}>()", + "AbsDiff": f"absdiff<{tin0}, {tin1}, {tout}>()", } return eltwise_ops[eltwise_op] @@ -296,6 +305,32 @@ class StreamingEltwise(HLSCustomOp): '#include "interpret.hpp"', ] + self.code_gen_dict["$GLOBALS$"].extend( + [ + "template<typename TI1, typename TI2, typename TO>", + "struct absdiff {", + "TO operator()(TI1 const &a, TI2 const &b) const {", + "#pragma HLS inline", + "return a>b? a-b : b-a;", + "}", + "};", + "template<typename TI1, typename TI2, typename TO>", + "struct sub {", + "TO operator()(TI1 const &a, TI2 const &b) const {", + "#pragma HLS inline", + "return a-b;", + "}", + "};", + "template<typename TI1, typename TI2, typename TO>", + "struct add {", + "TO operator()(TI1 const &a, TI2 const &b) const {", + "#pragma HLS inline", + "return a+b;", + "}", + "};", + ] + ) + def defines(self, var): self.code_gen_dict["$DEFINES$"] = []