diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py index e8f6372ab22ba994641d542d144d61c1bb766552..bb7c04566177ce23ff1b5bced346f00afe59ad0a 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py @@ -1673,9 +1673,9 @@ class InferConcatLayer(Transformation): return (model, graph_modified) -class InferStreamingEltwiseAbsDiff(Transformation): - """Convert eltwise Sub -> Abs to StreamingEltwise layer - with AbsDiffEltwise op.""" +class InferStreamingEltwise(Transformation): + """Convert eltwise Sub or Sub -> Abs to StreamingEltwise layer + with SubEltwise or AbsDiffEltwise op.""" def apply(self, model): graph = model.graph @@ -1701,14 +1701,14 @@ class InferStreamingEltwiseAbsDiff(Transformation): if not (idt0.is_integer() and idt1.is_integer()): continue + eltwiseOp = "Sub" + nodes_to_remove = [node] # look for a downstream Abs node res_consumer = model.find_consumer(result) - if res_consumer is None: - continue - if res_consumer.op_type != "Abs": - continue - - result = res_consumer.output[0] + if (res_consumer is not None) and (res_consumer.op_type == "Abs"): + eltwiseOp = "AbsDiff" + result = res_consumer.output[0] + nodes_to_remove.append(res_consumer) # check layout and convert if necessary in0_layout = model.get_tensor_layout(in0) @@ -1749,14 +1749,14 @@ class InferStreamingEltwiseAbsDiff(Transformation): PE=pe, inputDataType0=idt0.name, inputDataType1=idt1.name, - eltwiseOp="AbsDiff", + eltwiseOp=eltwiseOp, numInputVectors=in0_shape[:-1], name="StreamingEltwise_" + node.name, ) graph.node.insert(insert_point, new_node) # remove old nodes - graph.node.remove(node) - graph.node.remove(res_consumer) + for nd in nodes_to_remove: + graph.node.remove(nd) graph_modified = True # if graph_modified: diff --git a/tests/fpgadataflow/test_fpgadataflow_eltwise.py b/tests/fpgadataflow/test_fpgadataflow_eltwise.py index 76a68f3779b9273da589f32e90ef10e4b21d6752..bfc007421ec62c5708a27b901247255095e73887 100644 --- a/tests/fpgadataflow/test_fpgadataflow_eltwise.py +++ b/tests/fpgadataflow/test_fpgadataflow_eltwise.py @@ -21,9 +21,17 @@ from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode -def build_model(shp, dt0, dt1): +def build_model(shp, dt0, dt1, do_abs): np.random.seed(0) shp_str = str(shp) + if do_abs: + graph = """ + sub_out = Sub(in0, in1) + out0 = Abs(sub_out) + """ + else: + graph = "out0 = Sub(in0, in1)" + input = f""" < ir_version: 7, @@ -31,8 +39,7 @@ def build_model(shp, dt0, dt1): > agraph (float{shp_str} in0, float{shp_str} in1) => (float{shp_str} out0) {{ - sub_out = Sub(in0, in1) - out0 = Abs(sub_out) + {graph} }} """ model = oprs.parse_model(input) @@ -51,11 +58,13 @@ def build_model(shp, dt0, dt1): @pytest.mark.parametrize("ch", [1, 64]) # folding @pytest.mark.parametrize("fold", [-1, 2, 1]) +# include Abs output node or not +@pytest.mark.parametrize("do_abs", [True, False]) # execution mode @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"]) @pytest.mark.fpgadataflow @pytest.mark.vivado -def test_fpgadataflow_eltwise(dt0, ch, fold, exec_mode): +def test_fpgadataflow_eltwise(dt0, ch, fold, do_abs, exec_mode): if fold == -1: pe = 1 else: @@ -63,12 +72,12 @@ def test_fpgadataflow_eltwise(dt0, ch, fold, exec_mode): assert ch % pe == 0 dt1 = DataType["UINT8"] shp = [1, 4, 2, ch] - model = build_model(shp, dt0, dt1) + model = build_model(shp, dt0, dt1, do_abs) in0 = gen_finn_dt_tensor(dt0, shp) in1 = gen_finn_dt_tensor(dt1, shp) idict = {"in0": in0, "in1": in1} y_expected = execute_onnx(model, idict)["out0"] - model = model.transform(to_hls.InferStreamingEltwiseAbsDiff()) + model = model.transform(to_hls.InferStreamingEltwise()) assert len(model.graph.node) == 1 assert model.graph.node[0].op_type == "StreamingEltwise" getCustomOp(model.graph.node[0]).set_nodeattr("PE", pe)