diff --git a/tests/fpgadataflow/test_fpgadataflow_ipstitch.py b/tests/fpgadataflow/test_fpgadataflow_ipstitch.py index b830693c32afe629dd6fc70868d0bddacac4c887..a9f5bf5ffa1f816b82ef701800e92249056b7c74 100644 --- a/tests/fpgadataflow/test_fpgadataflow_ipstitch.py +++ b/tests/fpgadataflow/test_fpgadataflow_ipstitch.py @@ -54,6 +54,10 @@ from finn.util.basic import gen_finn_dt_tensor, pynq_part_map from finn.util.fpgadataflow import pyverilate_stitched_ip from finn.util.test import load_test_checkpoint_or_skip from finn.transformation.fpgadataflow.synth_ooc import SynthOutOfContext +from finn.transformation.infer_data_layouts import InferDataLayouts +from finn.transformation.fpgadataflow.insert_iodma import InsertIODMA +from finn.transformation.fpgadataflow.floorplan import Floorplan + test_pynq_board = os.getenv("PYNQ_BOARD", default="Pynq-Z1") test_fpga_part = pynq_part_map[test_pynq_board] @@ -390,3 +394,19 @@ def test_fpgadataflow_ipstitch_remote_execution(): assert np.isclose(outp["outp"], x).all() except KeyError: pytest.skip("PYNQ board IP address not specified") + + +def test_fpgadataflow_ipstitch_iodma_floorplan(): + model = create_one_fc_model() + if model.graph.node[0].op_type == "StreamingDataflowPartition": + sdp_node = getCustomOp(model.graph.node[0]) + assert sdp_node.__class__.__name__ == "StreamingDataflowPartition" + assert os.path.isfile(sdp_node.get_nodeattr("model")) + model = load_test_checkpoint_or_skip(sdp_node.get_nodeattr("model")) + model = model.transform(InferDataLayouts()) + model = model.transform(InsertIODMA()) + model = model.transform(Floorplan()) + assert getCustomOp(model.graph.node[0]).get_nodeattr("partition_id") == 0 + assert getCustomOp(model.graph.node[1]).get_nodeattr("partition_id") == 2 + assert getCustomOp(model.graph.node[2]).get_nodeattr("partition_id") == 1 + model.save(ip_stitch_model_dir + "/test_fpgadataflow_ipstitch_iodma_floorplan.onnx")