From fb0ba564c28e286d3b12506968a57dfd99e12f72 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Fri, 10 Sep 2021 20:59:25 +0200 Subject: [PATCH] [Build] allow specifying optional start from/stop at steps --- src/finn/builder/build_dataflow.py | 45 +++++++++++++++++++++-- src/finn/builder/build_dataflow_config.py | 7 ++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/finn/builder/build_dataflow.py b/src/finn/builder/build_dataflow.py index 4aa1ad31e..c4664a547 100644 --- a/src/finn/builder/build_dataflow.py +++ b/src/finn/builder/build_dataflow.py @@ -62,7 +62,7 @@ class StreamToLogger(object): pass -def resolve_build_steps(cfg: DataflowBuildConfig): +def resolve_build_steps(cfg: DataflowBuildConfig, partial: bool = True): steps = cfg.steps if steps is None: steps = default_build_dataflow_steps @@ -76,19 +76,56 @@ def resolve_build_steps(cfg: DataflowBuildConfig): steps_as_fxns.append(transform_step) else: raise Exception("Could not resolve build step: " + str(transform_step)) + if partial: + step_names = list(map(lambda x: x.__name__, steps_as_fxns)) + if cfg.start_step is None: + start_ind = 0 + else: + start_ind = step_names.index(cfg.start_step) + if cfg.stop_step is None: + stop_ind = len(step_names) - 1 + else: + stop_ind = step_names.index(cfg.stop_step) + steps_as_fxns = steps_as_fxns[start_ind : (stop_ind + 1)] + return steps_as_fxns +def resolve_step_filename( + step_name: str, cfg: DataflowBuildConfig, step_delta: int = 0 +): + step_names = list( + map(lambda x: x.__name__, resolve_build_steps(cfg, partial=False)) + ) + assert step_name in step_names, "start_step %s not found" + step_name + step_no = step_names.index(step_name) + step_delta + assert step_no >= 0, "Invalid step+delta combination" + assert step_no < len(step_names), "Invalid step+delta combination" + filename = cfg.output_dir + "/intermediate_models/" + filename += "%s.onnx" % (step_names[step_no]) + return filename + + def build_dataflow_cfg(model_filename, cfg: DataflowBuildConfig): """Best-effort build a dataflow accelerator using the given configuration. :param model_filename: ONNX model filename to build :param cfg: Build configuration """ - model = ModelWrapper(model_filename) + # if start_step is specified, override the input model + if cfg.start_step is None: + print("Building dataflow accelerator from " + model_filename) + model = ModelWrapper(model_filename) + else: + intermediate_model_filename = resolve_step_filename(cfg.start_step, cfg, -1) + print( + "Building dataflow accelerator from intermediate checkpoint" + + intermediate_model_filename + ) + model = ModelWrapper(intermediate_model_filename) assert type(model) is ModelWrapper finn_build_dir = os.environ["FINN_BUILD_DIR"] - print("Building dataflow accelerator from " + model_filename) + print("Intermediate outputs will be generated in " + finn_build_dir) print("Final outputs will be generated in " + cfg.output_dir) print("Build log is at " + cfg.output_dir + "/build_dataflow.log") @@ -132,7 +169,7 @@ def build_dataflow_cfg(model_filename, cfg: DataflowBuildConfig): sys.stdout = stdout_orig sys.stderr = stderr_orig time_per_step[step_name] = step_end - step_start - chkpt_name = "%d_%s.onnx" % (step_num, step_name) + chkpt_name = "%s.onnx" % (step_name) if cfg.save_intermediate_models: intermediate_model_dir = cfg.output_dir + "/intermediate_models" if not os.path.exists(intermediate_model_dir): diff --git a/src/finn/builder/build_dataflow_config.py b/src/finn/builder/build_dataflow_config.py index 4f29b0b11..1655f2bc6 100644 --- a/src/finn/builder/build_dataflow_config.py +++ b/src/finn/builder/build_dataflow_config.py @@ -271,6 +271,13 @@ class DataflowBuildConfig: #: - functions are called with (model, DataflowBuildConfig) as args steps: Optional[List[Any]] = None + #: If given, start from this step, loading the intermediate model generated + #: from the previous step (save_intermediate_models must be enabled) + start_step: Optional[str] = None + + #: If given, stop at this step. + stop_step: Optional[str] = None + def _resolve_hls_clk_period(self): if self.hls_clk_period_ns is None: # use same clk for synth and hls if not explicitly specified -- GitLab