diff --git a/src/finn/transformation/fpgadataflow/make_pynq_driver.py b/src/finn/transformation/fpgadataflow/make_pynq_driver.py
index 1e45a65720604144f67245b98dcbe3f6dc8363f5..a7bf9e6e6279923764009a00e2f805be1b1fa9c0 100644
--- a/src/finn/transformation/fpgadataflow/make_pynq_driver.py
+++ b/src/finn/transformation/fpgadataflow/make_pynq_driver.py
@@ -26,10 +26,8 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import os
-import shutil
-import warnings
 
+import shutil
 from finn.custom_op.registry import getCustomOp
 from finn.transformation import Transformation
 from finn.util.basic import gen_finn_dt_tensor, get_finn_root, make_build_dir
@@ -48,14 +46,11 @@ class MakePYNQDriver(Transformation):
     value.
     """
 
-    def __init__(self):
+    def __init__(self, platform):
         super().__init__()
+        self.platform = platform
 
     def apply(self, model):
-        vivado_pynq_proj = model.get_metadata_prop("vivado_pynq_proj")
-        if vivado_pynq_proj is None or (not os.path.isdir(vivado_pynq_proj)):
-            warnings.warn("No PYNQ project found, apply MakePYNQProject first.")
-
         # create a temporary folder for the generated driver
         pynq_driver_dir = make_build_dir(prefix="pynq_driver_")
         model.set_metadata_prop("pynq_driver_dir", pynq_driver_dir)
@@ -68,11 +63,21 @@ class MakePYNQDriver(Transformation):
         o_tensor_shape_normal = tuple(model.get_tensor_shape(o_tensor_name))
         i_tensor_dt = model.get_tensor_datatype(i_tensor_name)
         o_tensor_dt = model.get_tensor_datatype(o_tensor_name)
-        # extract HLSCustomOp instances to get folded i/o shapes
-        first_node = getCustomOp(model.find_consumer(i_tensor_name))
-        last_node = getCustomOp(model.find_producer(o_tensor_name))
-        i_tensor_shape_folded = tuple(first_node.get_folded_input_shape())
-        o_tensor_shape_folded = tuple(last_node.get_folded_output_shape())
+        # handle folded i/o shapes due to differences in DMA engines
+        if self.platform == "zynq":
+            # extract HLSCustomOp instances to get folded i/o shapes
+            first_node = getCustomOp(model.find_consumer(i_tensor_name))
+            last_node = getCustomOp(model.find_producer(o_tensor_name))
+            i_tensor_shape_folded = tuple(first_node.get_folded_input_shape())
+            o_tensor_shape_folded = tuple(last_node.get_folded_output_shape())
+        else:
+            i_tensor_shape_folded = list(i_tensor_shape_normal)
+            i_tensor_shape_folded.insert(-1, 1)
+            i_tensor_shape_folded = tuple(i_tensor_shape_folded)
+            o_tensor_shape_folded = list(o_tensor_shape_normal)
+            o_tensor_shape_folded.insert(-1, 1)
+            o_tensor_shape_folded = tuple(o_tensor_shape_folded)
+
         # generate dummy folded i/o tensors and their packed versions
         i_tensor_dummy_folded = gen_finn_dt_tensor(i_tensor_dt, i_tensor_shape_folded)
         o_tensor_dummy_folded = gen_finn_dt_tensor(o_tensor_dt, o_tensor_shape_folded)
@@ -99,6 +104,7 @@ class MakePYNQDriver(Transformation):
             ret = ret.replace("[1,", "[%s," % batch_var_name)
             return ret
 
+        driver = driver.replace("$PLATFORM$", self.platform)
         driver = driver.replace("$INPUT_FINN_DATATYPE$", str(i_tensor_dt))
         driver = driver.replace("$INPUT_SHAPE_NORMAL$", mss(i_tensor_shape_normal))
         driver = driver.replace("$INPUT_SHAPE_FOLDED$", mss(i_tensor_shape_folded))
diff --git a/src/finn/transformation/fpgadataflow/templates.py b/src/finn/transformation/fpgadataflow/templates.py
index d22163cbe13514fbfe0fb410ef39c24acd7e8301..036286046ad72248fadc45db4ec695a24e2061e7 100644
--- a/src/finn/transformation/fpgadataflow/templates.py
+++ b/src/finn/transformation/fpgadataflow/templates.py
@@ -104,7 +104,7 @@ from finn.core.datatype import DataType
 from pynq.ps import Clocks
 
 class FINNAccelDriver():
-    def __init__(self, N, bitfile, platform="zynq"):
+    def __init__(self, N, bitfile, platform="$PLATFORM$"):
         \"\"\"Instantiate the FINN accelerator driver.
         Gets batchsize (N) as integer and path to bitfile as string.\"\"\"
         self.platform = platform
diff --git a/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py b/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py
index f3250baa5e3c1fc5ae7cecde01be0e0ca22eac0c..4a5901d320f4b4486a0adc29e1e828464e7bc5b4 100644
--- a/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py
+++ b/tests/end2end/test_zynqbuild_end2end_tfc_w1a1.py
@@ -156,7 +156,7 @@ def test_end2end_zynqbuild_tfc_w1a1_make_driver():
     model = load_test_checkpoint_or_skip(
         build_dir + "/end2end_zynqbuild_tfc_w1a1_folded.onnx"
     )
-    model = model.transform(MakePYNQDriver())
+    model = model.transform(MakePYNQDriver(platform="zynq-iodma"))
     model.save(build_dir + "/end2end_zynqbuild_tfc_w1a1_pynq_driver.onnx")