diff --git a/src/finn/transformation/fpgadataflow/make_pynq_driver.py b/src/finn/transformation/fpgadataflow/make_pynq_driver.py index 1e45a65720604144f67245b98dcbe3f6dc8363f5..a7bf9e6e6279923764009a00e2f805be1b1fa9c0 100644 --- a/src/finn/transformation/fpgadataflow/make_pynq_driver.py +++ b/src/finn/transformation/fpgadataflow/make_pynq_driver.py @@ -26,10 +26,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import os -import shutil -import warnings +import shutil from finn.custom_op.registry import getCustomOp from finn.transformation import Transformation from finn.util.basic import gen_finn_dt_tensor, get_finn_root, make_build_dir @@ -48,14 +46,11 @@ class MakePYNQDriver(Transformation): value. """ - def __init__(self): + def __init__(self, platform): super().__init__() + self.platform = platform def apply(self, model): - vivado_pynq_proj = model.get_metadata_prop("vivado_pynq_proj") - if vivado_pynq_proj is None or (not os.path.isdir(vivado_pynq_proj)): - warnings.warn("No PYNQ project found, apply MakePYNQProject first.") - # create a temporary folder for the generated driver pynq_driver_dir = make_build_dir(prefix="pynq_driver_") model.set_metadata_prop("pynq_driver_dir", pynq_driver_dir) @@ -68,11 +63,21 @@ class MakePYNQDriver(Transformation): o_tensor_shape_normal = tuple(model.get_tensor_shape(o_tensor_name)) i_tensor_dt = model.get_tensor_datatype(i_tensor_name) o_tensor_dt = model.get_tensor_datatype(o_tensor_name) - # extract HLSCustomOp instances to get folded i/o shapes - first_node = getCustomOp(model.find_consumer(i_tensor_name)) - last_node = getCustomOp(model.find_producer(o_tensor_name)) - i_tensor_shape_folded = tuple(first_node.get_folded_input_shape()) - o_tensor_shape_folded = tuple(last_node.get_folded_output_shape()) + # handle folded i/o shapes due to differences in DMA engines + if self.platform == "zynq": + # extract HLSCustomOp instances to get folded i/o shapes + first_node = getCustomOp(model.find_consumer(i_tensor_name)) + last_node = getCustomOp(model.find_producer(o_tensor_name)) + i_tensor_shape_folded = tuple(first_node.get_folded_input_shape()) + o_tensor_shape_folded = tuple(last_node.get_folded_output_shape()) + else: + i_tensor_shape_folded = list(i_tensor_shape_normal) + i_tensor_shape_folded.insert(-1, 1) + i_tensor_shape_folded = tuple(i_tensor_shape_folded) + o_tensor_shape_folded = list(o_tensor_shape_normal) + o_tensor_shape_folded.insert(-1, 1) + o_tensor_shape_folded = tuple(o_tensor_shape_folded) + # generate dummy folded i/o tensors and their packed versions i_tensor_dummy_folded = gen_finn_dt_tensor(i_tensor_dt, i_tensor_shape_folded) o_tensor_dummy_folded = gen_finn_dt_tensor(o_tensor_dt, o_tensor_shape_folded) @@ -99,6 +104,7 @@ class MakePYNQDriver(Transformation): ret = ret.replace("[1,", "[%s," % batch_var_name) return ret + driver = driver.replace("$PLATFORM$", self.platform) driver = driver.replace("$INPUT_FINN_DATATYPE$", str(i_tensor_dt)) driver = driver.replace("$INPUT_SHAPE_NORMAL$", mss(i_tensor_shape_normal)) driver = driver.replace("$INPUT_SHAPE_FOLDED$", mss(i_tensor_shape_folded)) diff --git a/src/finn/transformation/fpgadataflow/templates.py b/src/finn/transformation/fpgadataflow/templates.py index d22163cbe13514fbfe0fb410ef39c24acd7e8301..036286046ad72248fadc45db4ec695a24e2061e7 100644 --- a/src/finn/transformation/fpgadataflow/templates.py +++ b/src/finn/transformation/fpgadataflow/templates.py @@ -104,7 +104,7 @@ from finn.core.datatype import DataType from pynq.ps import Clocks class FINNAccelDriver(): - def __init__(self, N, bitfile, platform="zynq"): + def __init__(self, N, bitfile, platform="$PLATFORM$"): \"\"\"Instantiate the FINN accelerator driver. Gets batchsize (N) as integer and path to bitfile as string.\"\"\" self.platform = platform diff --git a/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py b/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py index f3250baa5e3c1fc5ae7cecde01be0e0ca22eac0c..4a5901d320f4b4486a0adc29e1e828464e7bc5b4 100644 --- a/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py +++ b/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py @@ -156,7 +156,7 @@ def test_end2end_zynqbuild_tfc_w1a1_make_driver(): model = load_test_checkpoint_or_skip( build_dir + "/end2end_zynqbuild_tfc_w1a1_folded.onnx" ) - model = model.transform(MakePYNQDriver()) + model = model.transform(MakePYNQDriver(platform="zynq-iodma")) model.save(build_dir + "/end2end_zynqbuild_tfc_w1a1_pynq_driver.onnx")