From fb0ba564c28e286d3b12506968a57dfd99e12f72 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 10 Sep 2021 20:59:25 +0200
Subject: [PATCH] [Build] allow specifying optional start from/stop at steps

---
 src/finn/builder/build_dataflow.py        | 45 +++++++++++++++++++++--
 src/finn/builder/build_dataflow_config.py |  7 ++++
 2 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/src/finn/builder/build_dataflow.py b/src/finn/builder/build_dataflow.py
index 4aa1ad31e..c4664a547 100644
--- a/src/finn/builder/build_dataflow.py
+++ b/src/finn/builder/build_dataflow.py
@@ -62,7 +62,7 @@ class StreamToLogger(object):
         pass
 
 
-def resolve_build_steps(cfg: DataflowBuildConfig):
+def resolve_build_steps(cfg: DataflowBuildConfig, partial: bool = True):
     steps = cfg.steps
     if steps is None:
         steps = default_build_dataflow_steps
@@ -76,19 +76,56 @@ def resolve_build_steps(cfg: DataflowBuildConfig):
             steps_as_fxns.append(transform_step)
         else:
             raise Exception("Could not resolve build step: " + str(transform_step))
+    if partial:
+        step_names = list(map(lambda x: x.__name__, steps_as_fxns))
+        if cfg.start_step is None:
+            start_ind = 0
+        else:
+            start_ind = step_names.index(cfg.start_step)
+        if cfg.stop_step is None:
+            stop_ind = len(step_names) - 1
+        else:
+            stop_ind = step_names.index(cfg.stop_step)
+        steps_as_fxns = steps_as_fxns[start_ind : (stop_ind + 1)]
+
     return steps_as_fxns
 
 
+def resolve_step_filename(
+    step_name: str, cfg: DataflowBuildConfig, step_delta: int = 0
+):
+    step_names = list(
+        map(lambda x: x.__name__, resolve_build_steps(cfg, partial=False))
+    )
+    assert step_name in step_names, "start_step %s not found" + step_name
+    step_no = step_names.index(step_name) + step_delta
+    assert step_no >= 0, "Invalid step+delta combination"
+    assert step_no < len(step_names), "Invalid step+delta combination"
+    filename = cfg.output_dir + "/intermediate_models/"
+    filename += "%s.onnx" % (step_names[step_no])
+    return filename
+
+
 def build_dataflow_cfg(model_filename, cfg: DataflowBuildConfig):
     """Best-effort build a dataflow accelerator using the given configuration.
 
     :param model_filename: ONNX model filename to build
     :param cfg: Build configuration
     """
-    model = ModelWrapper(model_filename)
+    # if start_step is specified, override the input model
+    if cfg.start_step is None:
+        print("Building dataflow accelerator from " + model_filename)
+        model = ModelWrapper(model_filename)
+    else:
+        intermediate_model_filename = resolve_step_filename(cfg.start_step, cfg, -1)
+        print(
+            "Building dataflow accelerator from intermediate checkpoint"
+            + intermediate_model_filename
+        )
+        model = ModelWrapper(intermediate_model_filename)
     assert type(model) is ModelWrapper
     finn_build_dir = os.environ["FINN_BUILD_DIR"]
-    print("Building dataflow accelerator from " + model_filename)
+
     print("Intermediate outputs will be generated in " + finn_build_dir)
     print("Final outputs will be generated in " + cfg.output_dir)
     print("Build log is at " + cfg.output_dir + "/build_dataflow.log")
@@ -132,7 +169,7 @@ def build_dataflow_cfg(model_filename, cfg: DataflowBuildConfig):
             sys.stdout = stdout_orig
             sys.stderr = stderr_orig
             time_per_step[step_name] = step_end - step_start
-            chkpt_name = "%d_%s.onnx" % (step_num, step_name)
+            chkpt_name = "%s.onnx" % (step_name)
             if cfg.save_intermediate_models:
                 intermediate_model_dir = cfg.output_dir + "/intermediate_models"
                 if not os.path.exists(intermediate_model_dir):
diff --git a/src/finn/builder/build_dataflow_config.py b/src/finn/builder/build_dataflow_config.py
index 4f29b0b11..1655f2bc6 100644
--- a/src/finn/builder/build_dataflow_config.py
+++ b/src/finn/builder/build_dataflow_config.py
@@ -271,6 +271,13 @@ class DataflowBuildConfig:
     #: - functions are called with (model, DataflowBuildConfig) as args
     steps: Optional[List[Any]] = None
 
+    #: If given, start from this step, loading the intermediate model generated
+    #: from the previous step (save_intermediate_models must be enabled)
+    start_step: Optional[str] = None
+
+    #: If given, stop at this step.
+    stop_step: Optional[str] = None
+
     def _resolve_hls_clk_period(self):
         if self.hls_clk_period_ns is None:
             # use same clk for synth and hls if not explicitly specified
-- 
GitLab