diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index f1b14d7cd0c797cbd01f21cc69b867b711e9184b..88754b974790de38b4eddff3d9a2d0d80491ca23 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -1180,8 +1180,9 @@ class InferDuplicateStreamsLayer(Transformation):
         for node in graph.node:
             node_ind += 1
             successors = model.find_consumers(node.output[0])
-            if successors is not None and len(successors) == 2:
+            if successors is not None and len(successors) >= 2:
                 output_tensor = node.output[0]
+                n_outputs = len(successors)
 
                 dt = model.get_tensor_datatype(output_tensor)
 
@@ -1192,7 +1193,7 @@ class InferDuplicateStreamsLayer(Transformation):
                 # create clone tensors
                 out_shape = model.get_tensor_shape(output_tensor)
                 out_tensor_clones = []
-                for i in range(2):
+                for i in range(n_outputs):
                     clone = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(), TensorProto.FLOAT, out_shape
                     )
@@ -1215,6 +1216,7 @@ class InferDuplicateStreamsLayer(Transformation):
                     PE=pe,
                     inputDataType=dt.name,
                     numInputVectors=vecs,
+                    NumOutputStreams=n_outputs,
                     name="DuplicateStreams_Batch_" + node.name,
                 )
 
diff --git a/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py b/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py
index 4592fc2f8a2e908a91dd74fef58fb03467f9e7c0..1faf647df225853cf026a49adbfc6bb9d8f1b670 100644
--- a/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py
+++ b/tests/fpgadataflow/test_fpgadataflow_duplicatestreams.py
@@ -48,26 +48,32 @@ from finn.transformation.infer_shapes import InferShapes
 from finn.util.basic import gen_finn_dt_tensor
 
 
-def make_dupstreams_modelwrapper(ch, pe, idim, idt):
+def make_dupstreams_modelwrapper(ch, pe, idim, idt, n_dupl):
     shape = [1, idim, idim, ch]
     inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, shape)
-    outp0 = helper.make_tensor_value_info("outp0", TensorProto.FLOAT, shape)
-    outp1 = helper.make_tensor_value_info("outp1", TensorProto.FLOAT, shape)
+    out_names = []
+    out_vi = []
+    for i in range(n_dupl):
+        outp_name = "outp%d" % i
+        out_names.append(outp_name)
+        out_vi.append(
+            helper.make_tensor_value_info(outp_name, TensorProto.FLOAT, shape)
+        )
 
     dupstrm_node = helper.make_node(
         "DuplicateStreams_Batch",
         ["inp"],
-        ["outp0", "outp1"],
+        out_names,
         domain="finn.custom_op.fpgadataflow",
         backend="fpgadataflow",
         NumChannels=ch,
-        NumOutputStreams=2,
+        NumOutputStreams=n_dupl,
         PE=pe,
         inputDataType=idt.name,
         numInputVectors=[1, idim, idim],
     )
     graph = helper.make_graph(
-        nodes=[dupstrm_node], name="graph", inputs=[inp], outputs=[outp0, outp1]
+        nodes=[dupstrm_node], name="graph", inputs=[inp], outputs=out_vi
     )
 
     model = helper.make_model(graph, producer_name="addstreams-model")
@@ -93,10 +99,12 @@ def prepare_inputs(input_tensor, idt):
 @pytest.mark.parametrize("fold", [-1, 2, 1])
 # image dimension
 @pytest.mark.parametrize("imdim", [7])
+# amount of duplication
+@pytest.mark.parametrize("n_dupl", [2, 3])
 # execution mode
 @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
 @pytest.mark.vivado
-def test_fpgadataflow_duplicatestreams(idt, ch, fold, imdim, exec_mode):
+def test_fpgadataflow_duplicatestreams(idt, ch, fold, imdim, n_dupl, exec_mode):
     if fold == -1:
         pe = 1
     else:
@@ -106,7 +114,7 @@ def test_fpgadataflow_duplicatestreams(idt, ch, fold, imdim, exec_mode):
     # generate input data
     x = gen_finn_dt_tensor(idt, (1, imdim, imdim, ch))
 
-    model = make_dupstreams_modelwrapper(ch, pe, imdim, idt)
+    model = make_dupstreams_modelwrapper(ch, pe, imdim, idt, n_dupl)
 
     if exec_mode == "cppsim":
         model = model.transform(PrepareCppSim())
@@ -124,12 +132,11 @@ def test_fpgadataflow_duplicatestreams(idt, ch, fold, imdim, exec_mode):
     # prepare input data and execute
     input_dict = prepare_inputs(x, idt)
     output_dict = oxe.execute_onnx(model, input_dict)
-    y0 = output_dict["outp0"]
-    y1 = output_dict["outp1"]
-    expected_y = x
 
-    assert (y0 == expected_y).all(), exec_mode + " failed"
-    assert (y1 == expected_y).all(), exec_mode + " failed"
+    expected_y = x
+    for i in range(n_dupl):
+        y = output_dict["outp%d" % i]
+        assert (y == expected_y).all(), exec_mode + " failed"
 
     if exec_mode == "rtlsim":
         node = model.get_nodes_by_op_type("DuplicateStreams_Batch")[0]
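
For context, a minimal usage sketch (not part of the diff; the file names are placeholders): with this change, InferDuplicateStreamsLayer handles any tensor with two or more consumers and records the actual fanout in the NumOutputStreams attribute of the inserted DuplicateStreams_Batch node.

from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fpgadataflow.convert_to_hls_layers import (
    InferDuplicateStreamsLayer,
)

# "model.onnx" stands for a hypothetical prepared FINN model in which some
# tensor feeds two or more consumer nodes
model = ModelWrapper("model.onnx")
# insert one DuplicateStreams_Batch node per forked tensor, with
# NumOutputStreams set to the number of consumers found
model = model.transform(InferDuplicateStreamsLayer())
model.save("model_dupstreams.onnx")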