From 362601eddb9a1f64639810893b565b1580ce5a57 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 14 Jul 2022 16:38:16 +0200
Subject: [PATCH] [HLSCustomOp] add optional target_dir arg for inp2npy

---
 src/finn/custom_op/fpgadataflow/hlscustomop.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/hlscustomop.py b/src/finn/custom_op/fpgadataflow/hlscustomop.py
index 9978ab0c7..6692fa4b6 100644
--- a/src/finn/custom_op/fpgadataflow/hlscustomop.py
+++ b/src/finn/custom_op/fpgadataflow/hlscustomop.py
@@ -397,18 +397,19 @@ class HLSCustomOp(CustomOp):
         builder.build(code_gen_dir)
         self.set_nodeattr("executable_path", builder.executable_path)
 
-    def dynamic_input_to_npy(self, context, count):
+    def dynamic_input_to_npy(self, context, count, target_dir=""):
         """Saves input (given context) into .npy files.
 
         Count indicates the number of inputs that have to be saved."""
         node = self.onnx_node
-        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
-        if code_gen_dir == "":
-            raise Exception(
+        if target_dir == "":
+            code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+            if code_gen_dir == "":
+                raise Exception(
+                    """
+    Found no codegen dir for this node, did you run the prepare_cppsim transformation?
                 """
-Found no codegen dir for this node, did you run the prepare_cppsim transformation?
-            """
-            )
+                )
         # create a npy file for each input of the node (in_ind is input index)
         # assuming dynamic inputs start from 0
         for in_ind in range(count):
@@ -427,7 +428,7 @@ Found no codegen dir for this node, did you run the prepare_cppsim transformatio
             # make copy before saving the array
             reshaped_input = reshaped_input.copy()
             np.save(
-                os.path.join(code_gen_dir, "input_{}.npy".format(in_ind)),
+                os.path.join(target_dir, "input_{}.npy".format(in_ind)),
                 reshaped_input,
             )
 
-- 
GitLab