Skip to content
Snippets Groups Projects
Unverified Commit 0e3f663b authored by Lucian Petrica's avatar Lucian Petrica Committed by GitHub
Browse files

Automatically set folding attributes (#251)


* Added transform and test for auto-folding to throughput target

* [SetFolding] improve comments

* [SetFolding] allow customizing mvau_wwidth_max

* [SetFolding] put ops into groups to reduce code redundancy

* [SetFolding] add a two-pass relaxation option

* [LabelSelect] implement get_exp_cycles

* [SetFolding] handle LabelSelect too

* [SetFolding] extra comment

* [SetFolding] fix swg setting, add warning

* [SetFolding] call GiveUniqueNodeNames before AnnotateCycles

* [Build] add build config option to target FPS

* [Test] use target_fps for example dataflow build config file

Co-authored-by: Yaman Umuroglu <yamanu@xilinx.com>
parent 620b96f3
No related branches found
No related tags found
No related merge requests found
......@@ -352,3 +352,9 @@ class LabelSelect_Batch(HLSCustomOp):
self.code_gen_dict["$PRAGMAS$"].append(
"#pragma HLS INTERFACE ap_ctrl_none port=return"
)
def get_exp_cycles(self):
    """Return the expected cycle count for this node.

    The estimate is the "Labels" attribute divided by the "PE" attribute,
    truncated to an integer.
    """
    labels_per_pe = self.get_nodeattr("Labels") / self.get_nodeattr("PE")
    return int(labels_per_pe)
{
"output_dir": "output_tfc_w1a1_Pynq-Z1",
"folding_config_file": "folding_config.json",
"target_fps": 100000,
"mvau_wwidth_max": 10000,
"synth_clk_period_ns": 10.0,
"board": "Pynq-Z1",
"shell_flow_type": "vivado_zynq",
......@@ -8,13 +9,5 @@
"pynq_driver",
"stitched_ip",
"bitfile"
],
"fpga_part": null,
"auto_fifo_depths": true,
"hls_clk_period_ns": null,
"default_mem_mode": "decoupled",
"vitis_platform": null,
"vitis_floorplan_file": null,
"save_intermediate_models": true,
"enable_debug": false
]
}
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from finn.custom_op.registry import getCustomOp
from finn.transformation.base import Transformation
from finn.util.fpgadataflow import is_fpgadataflow_node
from finn.analysis.fpgadataflow.dataflow_performance import dataflow_performance
from finn.transformation.fpgadataflow.annotate_cycles import AnnotateCycles
from finn.transformation.general import GiveUniqueNodeNames
import warnings
def divisors(num):
    """Lazily yield every positive divisor of num in ascending order.

    Yields nothing when num is smaller than 1.
    """
    candidate = 1
    while candidate <= num:
        if not num % candidate:
            yield candidate
        candidate += 1
class SetFolding(Transformation):
    """Attempt to set parallelism attributes in all nodes to meet a specific
    target expressed as cycles per frame target_cycles_per_frame. For each
    HLSCustomOp node type, the attribute may vary but is typically one of {PE, SIMD},
    and has a certain allowed-maximum value and divisibility constraints,
    which SetFolding will take into account. Note that the algorithm implemented
    by SetFolding is very simple and it is often possible to hand-tune the returned
    parallelism configuration for better results.

    In the returned model, each node's
    cycles_estimate attribute will be set to its estimated number of cycles.

    If two_pass_relaxation is enabled,
    SetFolding will internally run a second time if the target cycles from the
    first pass could not be achieved, instead using the achievable target (which
    may be constrained by a single node) to obtain a balanced pipeline.

    Notable exceptions and special behavior:

    * When folding dense convolution/FC compute engines (StreamingFCLayer_Batch),
      which have two attributes (PE and SIMD):

      * first increases SIMD while weight stream width per PE is <= mvau_wwidth_max
        (configurable in the SetFolding initializer, defaults to 36)
      * then increases PE until the target is met or max PE reached

    * When folding depthwise convolutions ("VVAU"/Vector_Vector_Activate_Batch)
      or spatial reduction ops (Pool_Batch):

      * the producer of the node is expected to be a ConvolutionInputGenerator
        with depthwise=1, whose SIMD value will be set equal to the PE value of
        its consumer node
    """

    def __init__(
        self, target_cycles_per_frame=1000, mvau_wwidth_max=36, two_pass_relaxation=True
    ):
        """Store the folding target and options.

        target_cycles_per_frame: cycle budget each node is folded towards.
        mvau_wwidth_max: upper bound on (weight bitwidth * SIMD) used while
            raising SIMD on StreamingFCLayer_Batch nodes.
        two_pass_relaxation: if True, run a second pass with the achieved
            maximum cycle count as target when the first pass misses the target.
        """
        super().__init__()
        self.target_cycles_per_frame = target_cycles_per_frame
        self.mvau_wwidth_max = mvau_wwidth_max
        self.two_pass_relaxation = two_pass_relaxation

    def optimize_attribute_val(self, node_inst, max_val, attr_name):
        """Raise attribute attr_name of node_inst through the divisors of max_val.

        The attribute is first reset to 1, then each divisor is tried in
        ascending order. The search stops at the first value for which the
        node's expected cycles are strictly below the target; if no value
        achieves that, the attribute is left at max_val.
        """
        node_inst.set_nodeattr(attr_name, 1)
        for val in divisors(max_val):
            node_inst.set_nodeattr(attr_name, val)
            cyc = node_inst.get_exp_cycles()
            if cyc < self.target_cycles_per_frame:
                # finish if target met
                break

    def apply(self, model):
        """Set folding attributes on every fpgadataflow node of model.

        Returns a tuple (model, False) where model has been passed through
        GiveUniqueNodeNames and AnnotateCycles (and possibly a second
        SetFolding pass).

        Raises Exception if a depthwise/pooling node is not fed by a
        ConvolutionInputGenerator, including the case where its input has no
        producer node at all.
        """
        graph = model.graph
        # these ops use PE parallelism, up to a max value of NumChannels
        pe_ops = [
            "AddStreams_Batch",
            "ChannelwiseOp_Batch",
            "DuplicateStreams_Batch",
            "GlobalAccPool_Batch",
            "Thresholding_Batch",
        ]
        # these ops use SIMD parallelism, up to a max value of NumChannels
        # ConvolutionInputGenerator has a special case when depthwise=1
        simd_ops = ["DownSampler", "FMPadding_Batch", "ConvolutionInputGenerator"]
        # these ops are preceded by depthwise SWG and have special behavior,
        # as explained in the SetFolding docstring
        depthwise_op_exceptions = ["Vector_Vector_Activate_Batch", "Pool_Batch"]
        for node in graph.node:
            if not is_fpgadataflow_node(node):
                continue
            op_type = node.op_type
            node_inst = getCustomOp(node)
            if op_type == "StreamingFCLayer_Batch":
                max_simd = node_inst.get_nodeattr("MW")
                max_pe = node_inst.get_nodeattr("MH")
                node_inst.set_nodeattr("PE", 1)
                node_inst.set_nodeattr("SIMD", 1)
                # increase SIMD until either we meet
                # the target or weight stream becomes
                # too wide
                for simd_val in divisors(max_simd):
                    prev_simd_val = node_inst.get_nodeattr("SIMD")
                    node_inst.set_nodeattr("SIMD", simd_val)
                    cyc = node_inst.get_exp_cycles()
                    if cyc < self.target_cycles_per_frame:
                        # finish if target met
                        break
                    if (
                        node_inst.get_weight_datatype().bitwidth()
                        * node_inst.get_nodeattr("SIMD")
                        > self.mvau_wwidth_max
                    ):
                        # revert if we've gone above width threshold
                        node_inst.set_nodeattr("SIMD", prev_simd_val)
                        break
                # increase PE until target met or reached max_pe
                self.optimize_attribute_val(node_inst, max_pe, "PE")
            elif op_type in pe_ops:
                max_pe = node_inst.get_nodeattr("NumChannels")
                self.optimize_attribute_val(node_inst, max_pe, "PE")
            elif op_type == "LabelSelect_Batch":
                max_pe = node_inst.get_nodeattr("Labels")
                self.optimize_attribute_val(node_inst, max_pe, "PE")
            elif op_type in depthwise_op_exceptions:
                max_pe = node_inst.get_nodeattr("Channels")
                self.optimize_attribute_val(node_inst, max_pe, "PE")
                # also set the folding of the upstream DW SWU
                # which must be identical to this node
                swu_node = model.find_producer(node.input[0])
                if swu_node is None:
                    # input comes straight from a graph input/initializer:
                    # fail with a clear message instead of an AttributeError
                    raise Exception(
                        "Expected SWU on DW op input, found no producer for tensor "
                        + node.input[0]
                    )
                if swu_node.op_type == "ConvolutionInputGenerator":
                    swu_node_inst = getCustomOp(swu_node)
                    pe = node_inst.get_nodeattr("PE")
                    swu_node_inst.set_nodeattr("SIMD", pe)
                else:
                    raise Exception(
                        "Expected SWU on DW op input, found " + swu_node.op_type
                    )
            elif op_type in simd_ops:
                if op_type == "ConvolutionInputGenerator":
                    depthwise = node_inst.get_nodeattr("depthwise")
                    if depthwise == 0:
                        max_simd = node_inst.get_nodeattr("IFMChannels")
                        self.optimize_attribute_val(node_inst, max_simd, "SIMD")
                    else:
                        # depthwise SWGs are handled separately
                        continue
                else:
                    max_simd = node_inst.get_nodeattr("NumChannels")
                    self.optimize_attribute_val(node_inst, max_simd, "SIMD")
            else:
                warnings.warn(
                    "SetFolding doesn't know how to handle op_type " + op_type
                )

        # node names must be unique before the per-node cycles are annotated
        model = model.transform(GiveUniqueNodeNames())
        model = model.transform(AnnotateCycles())
        if self.two_pass_relaxation:
            perf_dict = model.analysis(dataflow_performance)
            if perf_dict["max_cycles"] > self.target_cycles_per_frame:
                # run again, but with lower target (that we managed) -- this
                # may be coming from a single node's constraints, but we want
                # to balance the entire dataflow pipeline instead
                # no two_pass_relaxation this time -- no guarantee we'll
                # converge otherwise
                warnings.warn(
                    "Node %s is bottleneck with %d cycles, running second pass"
                    % (perf_dict["max_cycles_node_name"], perf_dict["max_cycles"])
                )
                model = model.transform(
                    SetFolding(
                        target_cycles_per_frame=perf_dict["max_cycles"],
                        mvau_wwidth_max=self.mvau_wwidth_max,
                        two_pass_relaxation=False,
                    )
                )

        # NOTE(review): the False flag presumably tells the transformation
        # framework not to re-apply this transformation -- confirm against
        # the Transformation base class
        return (model, False)
......@@ -69,6 +69,7 @@ from finn.transformation.fpgadataflow.make_zynq_proj import ZynqBuild
from finn.transformation.fpgadataflow.vitis_build import VitisBuild, VitisOptStrategy
from finn.transformation.fpgadataflow.make_pynq_driver import MakePYNQDriver
from finn.util.basic import pynq_part_map, alveo_part_map
from finn.transformation.fpgadataflow.set_folding import SetFolding
from finn.transformation.fpgadataflow.create_dataflow_partition import (
CreateDataflowPartition,
)
......@@ -156,26 +157,45 @@ class DataflowBuildConfig:
#: Directory where the final build outputs will be written into
output_dir: str
#: Path to folding configuration JSON file. May also contain FIFO sizes
#: and any other HLS node config if desired.
#: Will be applied with :py:mod:`finn.transformation.general.ApplyConfig`
folding_config_file: str
#: Target clock frequency (in nanoseconds) for Vivado synthesis.
#: e.g. synth_clk_period_ns=5.0 will target a 200 MHz clock.
#: If hls_clk_period_ns is not specified it will default to this value.
synth_clk_period_ns: float
#: Which output(s) to generate from the build flow.
#: Which output(s) to generate from the build flow. See documentation of
#: DataflowOutputType for available options.
generate_outputs: List[DataflowOutputType]
#: (Optional) Path to configuration JSON file. May include parallelization,
#: FIFO sizes, RAM and implementation style attributes and so on.
#: If the parallelization attributes (PE, SIMD) are part of the config,
#: this will override the automatically generated parallelization
#: attributes inferred from target_fps (if any)
#: Will be applied with :py:mod:`finn.transformation.general.ApplyConfig`
folding_config_file: Optional[str] = None
#: (Optional) Target inference performance in frames per second.
#: Note that target may not be achievable due to specific layer constraints,
#: or due to resource limitations of the FPGA.
#: If parallelization attributes are specified as part of folding_config_file
#: that will override the target_fps setting here.
target_fps: Optional[int] = None
#: (Optional) Control the maximum width of the per-PE MVAU stream while
#: exploring the parallelization attributes to reach target_fps
#: Only relevant if target_fps is specified.
#: Set this to a large value (e.g. 10000) if targeting full unfolding or
#: very high performance.
mvau_wwidth_max: Optional[int] = 36
#: Target board, only needed for generating full bitfiles where the FINN
#: design is integrated into a shell.
#: e.g. "Pynq-Z1" or "U250"
board: Optional[str] = None
#: Target shell flow, only needed for generating full bitfiles where the FINN
#: design is integrated into a shell.
#: design is integrated into a shell. See documentation of ShellFlowType
#: for options.
shell_flow_type: Optional[ShellFlowType] = None
#: Target Xilinx FPGA part. Only needed when board is not specified.
......@@ -277,6 +297,14 @@ class DataflowBuildConfig:
raise Exception("Could not resolve build step: " + str(transform_step))
return steps_as_fxns
def _resolve_cycles_per_frame(self):
    """Convert target_fps into a per-frame cycle budget.

    Returns None when target_fps is not set. Otherwise returns the number of
    clock cycles per second (derived from synth_clk_period_ns, given in
    nanoseconds) divided by target_fps, truncated to an int.
    """
    fps = self.target_fps
    if fps is None:
        return None
    # a clock period of N ns gives 10^9 / N cycles every second
    cycles_per_second = 10 ** 9 / self.synth_clk_period_ns
    return int(cycles_per_second / fps)
def _resolve_vitis_opt_strategy(self):
# convert human-readable enum to value expected by v++
name_to_strategy = {
......@@ -373,12 +401,25 @@ def step_create_dataflow_partition(model: ModelWrapper, cfg: DataflowBuildConfig
return model
def step_target_fps_parallelization(model: ModelWrapper, cfg: DataflowBuildConfig):
    """Derive parallelization attributes with SetFolding when the build
    config resolves to a cycles-per-frame target; otherwise return the model
    unchanged."""

    cycles_budget = cfg._resolve_cycles_per_frame()
    if cycles_budget is None:
        # no target_fps given, nothing to do in this step
        return model
    folding = SetFolding(cycles_budget, mvau_wwidth_max=cfg.mvau_wwidth_max)
    return model.transform(folding)
def step_apply_folding_config(model: ModelWrapper, cfg: DataflowBuildConfig):
    """Apply the folding configuration file onto the model to set folding (parallelization)
    and other attributes, if config file is specified.

    When cfg.folding_config_file is None the model is returned untouched.
    """

    if cfg.folding_config_file is not None:
        # give nodes unique names first, then apply the config file
        model = model.transform(GiveUniqueNodeNames())
        model = model.transform(ApplyConfig(cfg.folding_config_file))
    return model
......@@ -517,6 +558,7 @@ default_build_dataflow_steps = [
"step_streamline",
"step_convert_to_hls",
"step_create_dataflow_partition",
"step_target_fps_parallelization",
"step_apply_folding_config",
"step_hls_ipgen",
"step_set_fifo_depths",
......@@ -531,6 +573,7 @@ _internal_step_lookup = {
"step_streamline": step_streamline,
"step_convert_to_hls": step_convert_to_hls,
"step_create_dataflow_partition": step_create_dataflow_partition,
"step_target_fps_parallelization": step_target_fps_parallelization,
"step_apply_folding_config": step_apply_folding_config,
"step_hls_ipgen": step_hls_ipgen,
"step_set_fifo_depths": step_set_fifo_depths,
......
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
import numpy as np
import math
import random
from onnx import TensorProto, helper
from finn.custom_op.registry import getCustomOp
from finn.core.datatype import DataType
from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fpgadataflow.set_folding import SetFolding
from finn.transformation.general import GiveUniqueNodeNames
from finn.transformation.fpgadataflow.create_dataflow_partition import (
CreateDataflowPartition,
)
from finn.util.test import load_test_checkpoint_or_skip
def make_multi_fclayer_model(ch, wdt, adt, tdt, nnodes):
    """Build a chain of nnodes StreamingFCLayer_Batch nodes wrapped in a ModelWrapper.

    Every layer is ch x ch (MW = MH = ch), starts with PE = SIMD = 1 and
    shares one randomly drawn weight matrix and one threshold matrix.
    wdt, adt and tdt are the weight, activation and threshold datatypes.
    """
    # draw the weights before the thresholds so the RNG is consumed in a fixed order
    weights = np.random.randint(wdt.min(), wdt.max() + 1, size=(ch, ch))
    weights = weights.astype(np.float32)
    n_thresholds = 2 ** adt.bitwidth() - 1
    thresholds = np.random.randint(tdt.min(), tdt.max() + 1, size=(ch, n_thresholds))
    thresholds = thresholds.astype(np.float32)

    # tensor chain: inp -> inter_1 -> ... -> inter_{nnodes-1} -> outp
    tensor_names = ["inp"] + [f"inter_{k}" for k in range(1, nnodes)] + ["outp"]
    tensors = [
        helper.make_tensor_value_info(tname, TensorProto.FLOAT, [1, ch])
        for tname in tensor_names
    ]

    fc_nodes = [
        helper.make_node(
            "StreamingFCLayer_Batch",
            [tensors[idx].name, f"weights_{idx}", f"thresh_{idx}"],
            [tensors[idx + 1].name],
            domain="finn.custom_op.fpgadataflow",
            backend="fpgadataflow",
            MW=ch,
            MH=ch,
            SIMD=1,
            PE=1,
            inputDataType=adt.name,
            weightDataType=wdt.name,
            outputDataType=adt.name,
            ActVal=0,
            binaryXnorMode=0,
            noActivation=0,
        )
        for idx in range(nnodes)
    ]

    graph = helper.make_graph(
        nodes=fc_nodes,
        name="fclayer_graph",
        inputs=[tensors[0]],
        outputs=[tensors[-1]],
    )
    model = ModelWrapper(helper.make_model(graph, producer_name="fclayer-model"))

    model.set_tensor_datatype("inp", adt)
    model.set_tensor_datatype("outp", adt)

    for idx in range(nnodes):
        # register the layer's output tensor plus its weight/threshold initializers
        model.graph.value_info.append(tensors[idx + 1])
        model.set_initializer(f"weights_{idx}", weights)
        model.set_initializer(f"thresh_{idx}", thresholds)
        model.set_tensor_datatype(f"weights_{idx}", wdt)
        model.set_tensor_datatype(f"thresh_{idx}", tdt)

    return model
# desired frames per second
@pytest.mark.parametrize("target_fps", [30, 10 ** 5, 10 ** 7])
# target chip or board
@pytest.mark.parametrize("platform", ["Pynq-Z1", "Ultra96", "U200"])
def test_set_folding(target_fps, platform):
    """Check that SetFolding reaches the cycle target derived from target_fps,
    or at least the per-platform minimum cycle count."""
    model = make_multi_fclayer_model(
        128, DataType.INT4, DataType.INT2, DataType.INT16, 5
    )
    model = model.transform(GiveUniqueNodeNames())

    # extract the dataflow partition that SetFolding operates on
    parent_model = model.transform(CreateDataflowPartition())
    sdp_nodes = parent_model.get_nodes_by_op_type("StreamingDataflowPartition")
    sdp_inst = getCustomOp(sdp_nodes[0])
    dataflow_model = load_test_checkpoint_or_skip(sdp_inst.get_nodeattr("model"))

    # translate frames per second into a cycle budget at a 5 ns clock
    clk_ns = 5
    target_cycles_per_frame = int((10 ** 9 / clk_ns) / target_fps)
    dataflow_model = dataflow_model.transform(SetFolding(target_cycles_per_frame))

    per_layer_cycles = dataflow_model.analysis(exp_cycles_per_layer)
    achieved_cycles_per_frame = max(per_layer_cycles.values())

    min_cycles = {"Pynq-Z1": 128, "Ultra96": 64, "U200": 1}
    allowed_cycles = max(min_cycles[platform], target_cycles_per_frame)
    assert achieved_cycles_per_frame <= allowed_cycles, "Folding target not met"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment