diff --git a/src/finn/custom_op/fpgadataflow/eltwise.py b/src/finn/custom_op/fpgadataflow/eltwise.py index f17eb6fdf3375711607f7fd45668982358aacb8f..2395d451d171f8d70a5fa2f5980178d95e5fe131 100644 --- a/src/finn/custom_op/fpgadataflow/eltwise.py +++ b/src/finn/custom_op/fpgadataflow/eltwise.py @@ -59,6 +59,15 @@ class StreamingEltwise(HLSCustomOp): my_attrs.update(super().get_nodeattr_types()) return my_attrs + def get_eltwise_op_lambda(self): + eltwise_op = self.get_nodeattr("eltwiseOp") + eltwise_ops = { + "Add": "[](auto a, auto b) { return a + b; }", + "Sub": "[](auto a, auto b) { return a - b; }", + "AbsDiff": "[](auto a, auto b) { return a>b? a-b : b-a; }", + } + return eltwise_ops[eltwise_op] + def get_normal_input_shape(self, ind=0): ich = self.get_nodeattr("NumChannels") vecs = list(self.get_nodeattr("numInputVectors")) @@ -338,7 +347,8 @@ class StreamingEltwise(HLSCustomOp): slice_in0 = "Slice<%s>" % elem_hls_type_0 slice_in1 = "Slice<%s>" % elem_hls_type_1 slice_out = "Slice<%s>" % out_hls_type - eltwise_op_str = "%sEltwiseFunction<%s, %s, %s>()" % ( + eltwise_op_str = self.get_eltwise_op_lambda() + "%sEltwiseFunction<%s, %s, %s>()" % ( op, elem_hls_type_0, elem_hls_type_1,