Skip to content
Snippets Groups Projects
Commit fb0ba564 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Build] allow specifying optional start from/stop at steps

parent 178a09d3
No related branches found
No related tags found
No related merge requests found
......@@ -62,7 +62,7 @@ class StreamToLogger(object):
pass
def resolve_build_steps(cfg: DataflowBuildConfig):
def resolve_build_steps(cfg: DataflowBuildConfig, partial: bool = True):
steps = cfg.steps
if steps is None:
steps = default_build_dataflow_steps
......@@ -76,19 +76,56 @@ def resolve_build_steps(cfg: DataflowBuildConfig):
steps_as_fxns.append(transform_step)
else:
raise Exception("Could not resolve build step: " + str(transform_step))
if partial:
step_names = list(map(lambda x: x.__name__, steps_as_fxns))
if cfg.start_step is None:
start_ind = 0
else:
start_ind = step_names.index(cfg.start_step)
if cfg.stop_step is None:
stop_ind = len(step_names) - 1
else:
stop_ind = step_names.index(cfg.stop_step)
steps_as_fxns = steps_as_fxns[start_ind : (stop_ind + 1)]
return steps_as_fxns
def resolve_step_filename(
step_name: str, cfg: DataflowBuildConfig, step_delta: int = 0
):
step_names = list(
map(lambda x: x.__name__, resolve_build_steps(cfg, partial=False))
)
assert step_name in step_names, "start_step %s not found" + step_name
step_no = step_names.index(step_name) + step_delta
assert step_no >= 0, "Invalid step+delta combination"
assert step_no < len(step_names), "Invalid step+delta combination"
filename = cfg.output_dir + "/intermediate_models/"
filename += "%s.onnx" % (step_names[step_no])
return filename
def build_dataflow_cfg(model_filename, cfg: DataflowBuildConfig):
"""Best-effort build a dataflow accelerator using the given configuration.
:param model_filename: ONNX model filename to build
:param cfg: Build configuration
"""
model = ModelWrapper(model_filename)
# if start_step is specified, override the input model
if cfg.start_step is None:
print("Building dataflow accelerator from " + model_filename)
model = ModelWrapper(model_filename)
else:
intermediate_model_filename = resolve_step_filename(cfg.start_step, cfg, -1)
print(
"Building dataflow accelerator from intermediate checkpoint"
+ intermediate_model_filename
)
model = ModelWrapper(intermediate_model_filename)
assert type(model) is ModelWrapper
finn_build_dir = os.environ["FINN_BUILD_DIR"]
print("Building dataflow accelerator from " + model_filename)
print("Intermediate outputs will be generated in " + finn_build_dir)
print("Final outputs will be generated in " + cfg.output_dir)
print("Build log is at " + cfg.output_dir + "/build_dataflow.log")
......@@ -132,7 +169,7 @@ def build_dataflow_cfg(model_filename, cfg: DataflowBuildConfig):
sys.stdout = stdout_orig
sys.stderr = stderr_orig
time_per_step[step_name] = step_end - step_start
chkpt_name = "%d_%s.onnx" % (step_num, step_name)
chkpt_name = "%s.onnx" % (step_name)
if cfg.save_intermediate_models:
intermediate_model_dir = cfg.output_dir + "/intermediate_models"
if not os.path.exists(intermediate_model_dir):
......
......@@ -271,6 +271,13 @@ class DataflowBuildConfig:
#: - functions are called with (model, DataflowBuildConfig) as args
steps: Optional[List[Any]] = None
#: If given, start from this step, loading the intermediate model generated
#: from the previous step (save_intermediate_models must be enabled)
start_step: Optional[str] = None
#: If given, stop at this step.
stop_step: Optional[str] = None
def _resolve_hls_clk_period(self):
if self.hls_clk_period_ns is None:
# use same clk for synth and hls if not explicitly specified
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment