Commit 280c9baa authored by Tobi-Alonso

Merge remote-tracking branch 'upstream/dev' into feature/absorb_transform_only_if_linear

parents 736b4e56 6b04574d
@@ -6,3 +6,5 @@ Contributors
* Jakoba Petri-Koenig (@auphelia)
* Andrea Rigoni (@AndreaRigoni)
* Hendrik Borras (@HenniOVP)
* Lucian Petrica (@quetric)
* Tobias Alonso (@Tobi-Alonso)
@@ -40,6 +40,7 @@ from finn.util.basic import (
from finn.util.fpgadataflow import (
IPGenBuilder,
pyverilate_get_liveness_threshold_cycles,
rtlsim_multi_io,
)
from . import templates
@@ -318,14 +319,24 @@ Found no codegen dir for this node, did you run the prepare_cppsim transformatio
)
def npy_to_dynamic_output(self, context):
"""Reads the output from a .npy file and saves it at the right place in
the context dictionary."""
# TODO support multi-output nodes as needed
"""Reads the output from an output.npy file generated from cppsim and
places its content into the context dictionary."""
node = self.onnx_node
code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
output = np.load("{}/output.npy".format(code_gen_dir))
context[node.output[0]] = output
def npy_to_dynamic_outputs(self, context, npy_list):
"""Reads the output from .npy files generated from cppsim and places
their content into the context dictionary.
npy_list is a list specifying which files to read, and its order must
match the order of node outputs."""
node = self.onnx_node
code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
for i in range(len(npy_list)):
output = np.load("{}/{}".format(code_gen_dir, npy_list[i]))
context[node.output[i]] = output
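# Hypothetical usage sketch (file names are illustrative): a two-output
# node's execute_node would pass the cppsim output files in node-output
# order, e.g.
#
#   self.npy_to_dynamic_outputs(context, ["output0.npy", "output1.npy"])
#   out0 = context[self.onnx_node.output[0]]
#   out1 = context[self.onnx_node.output[1]]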
def exec_precompiled_singlenode_model(self):
"""Executes precompiled executable."""
executable_path = self.get_nodeattr("executable_path")
@@ -421,6 +432,16 @@ compilation transformations?
sim.stop_vcd_trace()
return outputs
def rtlsim_multi_io(self, sim, io_dict):
"Run rtlsim for this node, supports multiple i/o streams."
trace_file = self.get_nodeattr("rtlsim_trace")
if trace_file == "default":
trace_file = self.onnx_node.name + ".vcd"
num_out_values = self.get_number_output_values()
total_cycle_count = rtlsim_multi_io(sim, io_dict, num_out_values, trace_file)
self.set_nodeattr("sim_cycles", total_cycle_count)
def execute_node(self, context, graph):
"""Executes single node using cppsim or rtlsim."""
mode = self.get_nodeattr("exec_mode")
......
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import numpy as np
from finn.core.datatype import DataType
from finn.custom_op.fpgadataflow import HLSCustomOp
from onnx import TensorProto, helper
from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
class DuplicateStreams_Batch(HLSCustomOp):
"""Class that corresponds to finn-hlslib function of the same name."""
def __init__(self, onnx_node):
super().__init__(onnx_node)
def get_nodeattr_types(self):
my_attrs = {
"NumChannels": ("i", True, 0),
"PE": ("i", True, 0),
# FINN DataTypes for input
"inputDataType": ("s", True, ""),
# number of input vectors, examples:
# [1] is a single vector (like a FC layer with batch=1)
# [4] is four vectors (like a FC layer with batch=4)
# [1, 4, 4] is four * four vectors (like a conv layer with batch=1)
"numInputVectors": ("ints", False, [1]),
}
my_attrs.update(super().get_nodeattr_types())
return my_attrs
def get_normal_input_shape(self):
ch = self.get_nodeattr("NumChannels")
vecs = list(self.get_nodeattr("numInputVectors"))
ishape = tuple(vecs + [ch])
return ishape
def get_folded_input_shape(self):
ch = self.get_nodeattr("NumChannels")
pe = self.get_nodeattr("PE")
vecs = list(self.get_nodeattr("numInputVectors"))
assert ch % pe == 0, "PE must divide NumChannels"
folds = int(ch / pe)
folded_ishape = tuple(vecs + [folds, pe])
return folded_ishape
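# Worked example (illustrative values): with NumChannels=64, PE=2 and
# numInputVectors=[1, 7, 7], the normal input shape is (1, 7, 7, 64)
# and the folded input shape is (1, 7, 7, 32, 2), i.e. the 64 channels
# are streamed as 32 folds of PE=2 elements each.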
def get_normal_output_shape(self):
return self.get_normal_input_shape()
def get_folded_output_shape(self):
return self.get_folded_input_shape()
def make_shape_compatible_op(self, model):
exp_ishape = self.get_normal_input_shape()
oshape = self.get_normal_output_shape()
ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
assert ishape == exp_ishape, "Unexpected input shape."
# create a dummy Split node for shape inference: splitting a doubled-up
# random tensor along axis 0 yields two outputs of the expected shape;
# helper.make_node expects tensor names, so the doubled-up tensor is
# registered as an initializer first
values = np.random.randn(*oshape).astype(np.float32)
split_input = np.concatenate((values, values), axis=0)
split_in_name = self.onnx_node.name + "_split_in"
model.set_initializer(split_in_name, split_input)
return helper.make_node(
"Split",
inputs=[split_in_name],
outputs=[self.onnx_node.output[0], self.onnx_node.output[1]],
axis=0,
)
def infer_node_datatype(self, model):
odt = self.get_output_datatype()
model.set_tensor_datatype(self.onnx_node.output[0], odt)
model.set_tensor_datatype(self.onnx_node.output[1], odt)
def verify_node(self):
info_messages = []
# verify that "domain" is set to "finn"
domain_value = self.onnx_node.domain
if domain_value == "finn":
info_messages.append("Attribute domain is set correctly")
else:
info_messages.append('Attribute domain should be set to "finn"')
# verify that "backend" is set to "fpgadataflow"
backend_value = self.get_nodeattr("backend")
if backend_value == "fpgadataflow":
info_messages.append("Attribute backend is set correctly")
else:
info_messages.append('Attribute backend should be set to "fpgadataflow"')
# verify that all necessary attributes exist
try:
self.get_nodeattr("code_gen_dir_cppsim")
self.get_nodeattr("executable_path")
self.get_nodeattr("NumChannels")
self.get_nodeattr("PE")
self.get_nodeattr("inputDataType")
info_messages.append("All necessary attributes exist")
except Exception:
info_messages.append(
"""The required GlobalAccPool_Batch attributes do not exist."""
)
return info_messages
def get_input_datatype(self):
"""Returns FINN DataType of input."""
return DataType[self.get_nodeattr("inputDataType")]
def get_output_datatype(self):
"""Returns FINN DataType of output."""
return DataType[self.get_nodeattr("inputDataType")]
def get_instream_width(self):
"""Returns input stream width."""
ibits = self.get_input_datatype().bitwidth()
pe = self.get_nodeattr("PE")
in_width = pe * ibits
return in_width
def get_outstream_width(self):
"""Returns output stream width."""
obits = self.get_output_datatype().bitwidth()
pe = self.get_nodeattr("PE")
out_width = pe * obits
return out_width
def get_number_output_values(self):
return 2 * np.prod(self.get_folded_output_shape()[1:-1])
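# Worked example (illustrative values): for a folded output shape of
# (1, 7, 7, 32, 2), each output stream carries 7 * 7 * 32 = 1568 words;
# since this node drives two output streams, rtlsim must observe
# 2 * 1568 = 3136 output values in total before it can finish.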
def execute_node(self, context, graph):
mode = self.get_nodeattr("exec_mode")
node = self.onnx_node
exp_ishape = self.get_normal_input_shape()
exp_oshape = self.get_normal_output_shape()
folded_ishape = self.get_folded_input_shape()
folded_oshape = self.get_folded_output_shape()
if mode == "cppsim":
code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
elif mode == "rtlsim":
code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
else:
raise Exception(
"""Invalid value for attribute exec_mode! Is currently set to: {}
has to be set to one of the following values ("cppsim", "rtlsim")""".format(
mode
)
)
inp = context[node.input[0]]
assert str(inp.dtype) == "float32", "Input datatype is not float32"
assert inp.shape == exp_ishape, """Input shape doesn't match expected shape."""
export_idt = self.get_input_datatype()
# reshape input into folded form
inp = inp.reshape(folded_ishape)
# make copy before saving array
reshaped_input = inp.copy()
np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input)
if mode == "cppsim":
# execute the precompiled model
super().exec_precompiled_singlenode_model()
# load output npy file
super().npy_to_dynamic_outputs(context, ["output0.npy", "output1.npy"])
assert (
context[node.output[0]].shape == folded_oshape
), "cppsim did not produce expected folded output shape"
assert (
context[node.output[1]].shape == folded_oshape
), "cppsim did not produce expected folded output shape"
context[node.output[0]] = context[node.output[0]].reshape(*exp_oshape)
context[node.output[1]] = context[node.output[1]].reshape(*exp_oshape)
elif mode == "rtlsim":
sim = self.get_rtlsim()
nbits = self.get_instream_width()
rtlsim_inp = npy_to_rtlsim_input(
"{}/input_0.npy".format(code_gen_dir), export_idt, nbits
)
super().reset_rtlsim(sim)
super().toggle_clk(sim)
rtlsim_dict = {
"inputs": {"in0": rtlsim_inp},
"outputs": {"out0": [], "out1": []},
}
self.rtlsim_multi_io(sim, rtlsim_dict)
odt = self.get_output_datatype()
target_bits = odt.bitwidth()
packed_bits = self.get_outstream_width()
out_shape = self.get_folded_output_shape()
out_npy_path = "{}/output0.npy".format(code_gen_dir)
rtlsim_output_to_npy(
rtlsim_dict["outputs"]["out0"],
out_npy_path,
odt,
out_shape,
packed_bits,
target_bits,
)
# load and reshape output 0
output = np.load(out_npy_path)
output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape)
context[node.output[0]] = output
out_npy_path = "{}/output1.npy".format(code_gen_dir)
rtlsim_output_to_npy(
rtlsim_dict["outputs"]["out1"],
out_npy_path,
odt,
out_shape,
packed_bits,
target_bits,
)
# load and reshape output 1
output = np.load(out_npy_path)
output = np.asarray([output], dtype=np.float32).reshape(*exp_oshape)
context[node.output[1]] = output
else:
raise Exception(
"""Invalid value for attribute exec_mode! Is currently set to: {}
has to be set to one of the following values ("cppsim", "rtlsim")""".format(
mode
)
)
assert (
context[node.output[0]].shape == exp_oshape
), """Output0 shape doesn't match expected shape."""
assert (
context[node.output[1]].shape == exp_oshape
), """Output1 shape doesn't match expected shape."""
def global_includes(self):
self.code_gen_dict["$GLOBALS$"] = ['#include "streamtools.h"']
def defines(self, var):
self.code_gen_dict["$DEFINES$"] = []
def read_npy_data(self):
code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
dtype = self.get_input_datatype()
elem_bits = dtype.bitwidth()
packed_bits = self.get_instream_width()
packed_hls_type = "ap_uint<%d>" % packed_bits
elem_hls_type = dtype.get_hls_datatype_str()
npy_type = "float"
npy_in = "%s/input_0.npy" % code_gen_dir
self.code_gen_dict["$READNPYDATA$"] = []
self.code_gen_dict["$READNPYDATA$"].append(
'npy2apintstream<%s, %s, %d, %s>("%s", in0);'
% (packed_hls_type, elem_hls_type, elem_bits, npy_type, npy_in)
)
def strm_decl(self):
self.code_gen_dict["$STREAMDECLARATIONS$"] = []
self.code_gen_dict["$STREAMDECLARATIONS$"].append(
'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width())
)
self.code_gen_dict["$STREAMDECLARATIONS$"].append(
'hls::stream<ap_uint<{}>> out0 ("out0");'.format(self.get_outstream_width())
)
self.code_gen_dict["$STREAMDECLARATIONS$"].append(
'hls::stream<ap_uint<{}>> out1 ("out1");'.format(self.get_outstream_width())
)
def docompute(self):
self.code_gen_dict["$DOCOMPUTE$"] = [
"""DuplicateStreams_Batch<{}, {}> (in0, out0, out1, 1);""".format(
self.get_outstream_width(), self.get_number_output_values() // 2,
)
]
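# Note (hedged reading of the finn-hlslib call, not documented here): the
# two template arguments appear to be the stream data width and the number
# of words per output stream, which is why get_number_output_values(),
# which counts both output streams, is halved for the second argument.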
def dataoutstrm(self):
code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
dtype = self.get_output_datatype()
elem_bits = dtype.bitwidth()
packed_bits = self.get_outstream_width()
packed_hls_type = "ap_uint<%d>" % packed_bits
elem_hls_type = dtype.get_hls_datatype_str()
npy_type = "float"
npy_out = "%s/output0.npy" % code_gen_dir
npy_out1 = "%s/output1.npy" % code_gen_dir
oshape = self.get_folded_output_shape()
oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}")
self.code_gen_dict["$DATAOUTSTREAM$"] = [
'apintstream2npy<%s, %s, %d, %s>(out0, %s, "%s");'
% (
packed_hls_type,
elem_hls_type,
elem_bits,
npy_type,
oshape_cpp_str,
npy_out,
)
]
self.code_gen_dict["$DATAOUTSTREAM$"] += [
'apintstream2npy<%s, %s, %d, %s>(out1, %s, "%s");'
% (
packed_hls_type,
elem_hls_type,
elem_bits,
npy_type,
oshape_cpp_str,
npy_out1,
)
]
def save_as_npy(self):
self.code_gen_dict["$SAVEASCNPY$"] = []
def blackboxfunction(self):
self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
"""void {}(hls::stream<ap_uint<{}>> &in0,
hls::stream<ap_uint<{}>> &out0,
hls::stream<ap_uint<{}>> &out1)""".format(
self.onnx_node.name,
self.get_instream_width(),
self.get_outstream_width(),
self.get_outstream_width(),
)
]
def pragmas(self):
self.code_gen_dict["$PRAGMAS$"] = ["#pragma HLS INTERFACE axis port=in0"]
self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out0")
self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out1")
self.code_gen_dict["$PRAGMAS$"].append(
"#pragma HLS INTERFACE ap_ctrl_none port=return"
)
@@ -44,8 +44,10 @@ from finn.custom_op.fpgadataflow.streamingdatawidthconverter_batch import (
StreamingDataWidthConverter_Batch,
)
from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch
from finn.custom_op.fpgadataflow.thresholding_batch import Thresholding_Batch
from finn.custom_op.fpgadataflow.addstreams_batch import AddStreams_Batch
from finn.custom_op.fpgadataflow.labelselect_batch import LabelSelect_Batch
from finn.custom_op.fpgadataflow.duplicatestreams_batch import DuplicateStreams_Batch
# create a mapping of all known CustomOp names and classes
custom_op = {}
@@ -62,8 +64,10 @@ custom_op["MaxPoolNHWC"] = MaxPoolNHWC
custom_op["StreamingDataWidthConverter_Batch"] = StreamingDataWidthConverter_Batch
custom_op["StreamingFIFO"] = StreamingFIFO
custom_op["GlobalAccPool_Batch"] = GlobalAccPool_Batch
custom_op["Thresholding_Batch"] = Thresholding_Batch
custom_op["AddStreams_Batch"] = AddStreams_Batch
custom_op["LabelSelect_Batch"] = LabelSelect_Batch
custom_op["DuplicateStreams_Batch"] = DuplicateStreams_Batch
def getCustomOp(node):
......
@@ -33,6 +33,7 @@ from finn.transformation import Transformation
from finn.custom_op.registry import getCustomOp
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
import finn.core.data_layout as DataLayout
class InferConvInpGen(Transformation):
@@ -398,3 +399,59 @@ class InferQuantizedStreamingFCLayer(Transformation):
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return (model, graph_modified)
class InferThresholdingLayer(Transformation):
"""Convert any MultiThreshold into a standalone thresholding HLS layer."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for node in graph.node:
node_ind += 1
if node.op_type == "MultiThreshold":
thl_input = node.input[0]
thl_threshold = node.input[1]
thl_output = node.output[0]
thl_in_shape = model.get_tensor_shape(thl_input)
idt = model.get_tensor_datatype(thl_input)
# skip conversion for layers with float input
if not idt.is_integer():
continue
# skip conversion if input is not NHWC or NC
thl_in_layout = model.get_tensor_layout(thl_input)
if thl_in_layout != DataLayout.NHWC and thl_in_layout != DataLayout.NC:
continue
# now safe to assume number of channels is in last dimension
ifc = int(thl_in_shape[-1])
# create node with no parallelization first
pe = 1
assert ifc % pe == 0, "Requirement IFC divisible by PE is violated."
odt = model.get_tensor_datatype(thl_output)
# create and insert new Thresholding_Batch node
new_node = helper.make_node(
"Thresholding_Batch",
[thl_input, thl_threshold],
[thl_output],
domain="finn",
backend="fpgadataflow",
NumChannels=ifc,
PE=pe,
inputDataType=idt.name,
outputDataType=odt.name,
numInputVectors=list(thl_in_shape[:-1]),
)
graph.node.insert(node_ind, new_node)
# remove old node
graph.node.remove(node)
graph_modified = True
if graph_modified:
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return (model, graph_modified)
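# A minimal usage sketch for this transformation (hedged; "model.onnx" is
# a placeholder path, and the import alias follows the test suite):
#
#   import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
#   from finn.core.modelwrapper import ModelWrapper
#   from finn.transformation.infer_data_layouts import InferDataLayouts
#
#   model = ModelWrapper("model.onnx")
#   # layout annotations must be present, since MultiThreshold nodes
#   # that are not NHWC or NC are skipped by the transformation
#   model = model.transform(InferDataLayouts())
#   model = model.transform(to_hls.InferThresholdingLayer())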
@@ -127,3 +127,91 @@ def is_fpgadataflow_node(node):
is_node = True
return is_node
def rtlsim_multi_io(sim, io_dict, num_out_values, trace_file=""):
"""Runs the pyverilator simulation by passing the input values to the simulation,
toggle the clock and observing the execution time. Function contains also an
observation loop that can abort the simulation if no output value is produced
after a set number of cycles. Can handle multiple i/o streams. See function
implementation for details on how the top-level signals should be named.
sim: the PyVerilator object for simulation
io_dict: a dict of dicts in the following format:
{"inputs" : {"in0" : <input_data>, "in1" : <input_data>},
"outputs" : {"out0" : [], "out1" : []} }
<input_data> is a list of Python arbitrary-precision ints indicating
what data to push into the simulation, and the output lists are
similarly filled when the simulation is complete
num_out_values: total number of output values to be read from the
simulation before it finishes and returns.
returns: number of clock cycles elapsed for completion
"""
if trace_file != "":
sim.start_vcd_trace(trace_file)
for outp in io_dict["outputs"]:
sim.io[outp + "_V_V_TREADY"] = 1
# observe if output is completely calculated
# total_cycle_count will contain the number of cycles the calculation ran
output_done = False
total_cycle_count = 0
output_count = 0
old_output_count = 0
# avoid infinite looping of simulation by aborting when there is no change
# in output values after a set number of cycles (the liveness threshold)
no_change_count = 0
liveness_threshold = pyverilate_get_liveness_threshold_cycles()
while not (output_done):
for inp in io_dict["inputs"]:
inputs = io_dict["inputs"][inp]
sim.io[inp + "_V_V_TVALID"] = 1 if len(inputs) > 0 else 0
sim.io[inp + "_V_V_TDATA"] = inputs[0] if len(inputs) > 0 else 0
if sim.io[inp + "_V_V_TREADY"] == 1 and sim.io[inp + "_V_V_TVALID"] == 1:
inputs = inputs[1:]
io_dict["inputs"][inp] = inputs
for outp in io_dict["outputs"]:
outputs = io_dict["outputs"][outp]
if sim.io[outp + "_V_V_TVALID"] == 1 and sim.io[outp + "_V_V_TREADY"] == 1:
outputs = outputs + [sim.io[outp + "_V_V_TDATA"]]
output_count += 1
io_dict["outputs"][outp] = outputs
sim.io.ap_clk = 1
sim.io.ap_clk = 0
total_cycle_count = total_cycle_count + 1
if output_count == old_output_count:
no_change_count = no_change_count + 1
else:
no_change_count = 0
old_output_count = output_count
# check if all expected output words received
if output_count == num_out_values:
output_done = True
# end sim on timeout
if no_change_count == liveness_threshold:
if trace_file != "":
sim.flush_vcd_trace()
sim.stop_vcd_trace()
raise Exception(
"Error in simulation! Takes too long to produce output. "
"Consider setting the LIVENESS_THRESHOLD env.var. to a "
"larger value."
)
if trace_file != "":
sim.flush_vcd_trace()
sim.stop_vcd_trace()
return total_cycle_count
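# Usage sketch (see DuplicateStreams_Batch.execute_node above for a real
# caller; "packed_input_words" is a placeholder for the packed integer
# words produced e.g. by npy_to_rtlsim_input, and stream names must follow
# the in<i>/out<i> convention assumed by the _V_V_T* signal lookups):
#
#   io_dict = {
#       "inputs": {"in0": packed_input_words},
#       "outputs": {"out0": [], "out1": []},  # filled in place
#   }
#   cycles = rtlsim_multi_io(sim, io_dict, num_out_values, trace_file="")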
@@ -39,6 +39,7 @@ from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.streamline import Streamline
from finn.util.test import get_test_model_trained
from finn.transformation.double_to_single_float import DoubleToSingleFloat
@@ -54,7 +55,9 @@ export_onnx_path_cnv = "test_output_cnv.onnx"
@pytest.mark.vivado
# Standalone or fused thresholding-based activation
@pytest.mark.parametrize("fused_activation", [True, False])
def test_convert_to_hls_layers_cnv_w1a1(fused_activation):
cnv = get_test_model_trained("CNV", 1, 1)
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path_cnv)
model = ModelWrapper(export_onnx_path_cnv)
@@ -69,6 +72,7 @@ def test_convert_to_hls_layers_cnv_w1a1():
model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
model = model.transform(ConvertBipolarMatMulToXnorPopcount())
model = model.transform(Streamline())
model = model.transform(InferDataLayouts())
# model.save("golden.onnx")
# load one of the test vectors
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
@@ -80,6 +84,10 @@ def test_convert_to_hls_layers_cnv_w1a1():
expected_ctx = oxe.execute_onnx(model, input_dict, True)
expected = expected_ctx[model.graph.output[0].name]
# if we infer thresholding first, all MultiThresholds get converted to HLS
# subsequently, the FC inference will generate passthrough MVAUs
if not fused_activation:
model = model.transform(to_hls.InferThresholdingLayer())
model = model.transform(to_hls.InferBinaryStreamingFCLayer())
model = model.transform(to_hls.InferQuantizedStreamingFCLayer())
for node in model.graph.node:
@@ -102,7 +110,12 @@ def test_convert_to_hls_layers_cnv_w1a1():
model = model.transform(to_hls.InferStreamingMaxPool())
# check topology status
finn_nodes = model.get_finn_nodes()
if fused_activation:
assert len(finn_nodes) == 18
else:
assert len(finn_nodes) == 26
thr_nodes = model.get_nodes_by_op_type("Thresholding_Batch")
assert len(thr_nodes) == 8
non_finn_nodes = model.get_non_finn_nodes()
assert len(non_finn_nodes) == 4
exp_non_finn_nodes = ["Transpose", "Reshape", "Mul", "Add"]
......
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
from onnx import TensorProto, helper
import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.general import GiveUniqueNodeNames
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
from finn.util.basic import gen_finn_dt_tensor
from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
ReplaceVerilogRelPaths,
)
def make_dupstreams_modelwrapper(ch, pe, idim, idt):
shape = [1, idim, idim, ch]
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, shape)
outp0 = helper.make_tensor_value_info("outp0", TensorProto.FLOAT, shape)
outp1 = helper.make_tensor_value_info("outp1", TensorProto.FLOAT, shape)
dupstrm_node = helper.make_node(
"DuplicateStreams_Batch",
["inp"],
["outp0", "outp1"],
domain="finn",
backend="fpgadataflow",
NumChannels=ch,
PE=pe,
inputDataType=idt.name,
numInputVectors=[1, idim, idim],
)
graph = helper.make_graph(
nodes=[dupstrm_node], name="graph", inputs=[inp], outputs=[outp0, outp1]
)
model = helper.make_model(graph, producer_name="dupstreams-model")
model = ModelWrapper(model)
model.set_tensor_datatype("inp", idt)
return model
def prepare_inputs(input_tensor, idt):
return {"inp": input_tensor}
# data type
@pytest.mark.parametrize("idt", [DataType.INT4, DataType.UINT16])
# channels
@pytest.mark.parametrize("ch", [64])
# folding
@pytest.mark.parametrize("fold", [-1, 2, 1])
# image dimension
@pytest.mark.parametrize("imdim", [7])
# execution mode
@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
@pytest.mark.vivado
def test_fpgadataflow_duplicatestreams(idt, ch, fold, imdim, exec_mode):
if fold == -1:
pe = 1
else:
pe = ch // fold
assert ch % pe == 0
# generate input data
x = gen_finn_dt_tensor(idt, (1, imdim, imdim, ch))
model = make_dupstreams_modelwrapper(ch, pe, imdim, idt)
if exec_mode == "cppsim":
model = model.transform(PrepareCppSim())
model = model.transform(CompileCppSim())
model = model.transform(SetExecMode("cppsim"))
elif exec_mode == "rtlsim":
model = model.transform(SetExecMode("rtlsim"))
model = model.transform(GiveUniqueNodeNames())
model = model.transform(PrepareIP("xc7z020clg400-1", 5))
model = model.transform(HLSSynthIP())
model = model.transform(ReplaceVerilogRelPaths())
model = model.transform(PrepareRTLSim())
else:
raise Exception("Unknown exec_mode")
# prepare input data and execute
input_dict = prepare_inputs(x, idt)
output_dict = oxe.execute_onnx(model, input_dict)
y0 = output_dict["outp0"]
y1 = output_dict["outp1"]
expected_y = x
assert (y0 == expected_y).all(), exec_mode + " failed"
assert (y1 == expected_y).all(), exec_mode + " failed"
# Copyright (c) 2020, Xilinx
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of FINN nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
import numpy as np
from onnx import TensorProto, helper
import finn.core.onnx_exec as oxe
from finn.analysis.fpgadataflow.hls_synth_res_estimation import hls_synth_res_estimation
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.custom_op.multithreshold import multithreshold
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.general import GiveUniqueNodeNames
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
from finn.util.basic import gen_finn_dt_tensor
from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
ReplaceVerilogRelPaths,
)
def make_single_thresholding_modelwrapper(T, pe, idt, odt):
NumChannels = T.shape[0]
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, NumChannels])
outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, NumChannels])
node_inp_list = ["inp", "thresh"]
Thresholding_node = helper.make_node(
"Thresholding_Batch",
node_inp_list,
["outp"],
domain="finn",
backend="fpgadataflow",
NumChannels=NumChannels,
PE=pe,
inputDataType=idt.name,
outputDataType=odt.name,
)
graph = helper.make_graph(
nodes=[Thresholding_node],
name="thresholding_graph",
inputs=[inp],
outputs=[outp],
)
model = helper.make_model(graph, producer_name="thresholding-model")
model = ModelWrapper(model)
model.set_tensor_datatype("inp", idt)
model.set_tensor_datatype("outp", odt)
model.set_tensor_datatype("thresh", idt)
model.set_initializer("thresh", T)
return model
# activation: DataType
@pytest.mark.parametrize("act", [DataType.INT4, DataType.BIPOLAR])
# input datatype
@pytest.mark.parametrize("idt", [DataType.INT16, DataType.UINT16])
# folding, -1 is maximum possible
@pytest.mark.parametrize("nf", [-1, 2, 1])
# number of input features
@pytest.mark.parametrize("ich", [16])
# execution mode
@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
@pytest.mark.vivado
@pytest.mark.slow
def test_fpgadataflow_thresholding(idt, act, nf, ich, exec_mode):
if nf == -1:
nf = ich
pe = ich // nf
assert ich % pe == 0
# generate input data
x = gen_finn_dt_tensor(idt, (1, ich))
odt = act
n_steps = act.get_num_possible_values() - 1
T = np.random.randint(idt.min(), idt.max() + 1, (ich, n_steps)).astype(np.float32)
# provide non-decreasing thresholds
T = np.sort(T, axis=1)
model = make_single_thresholding_modelwrapper(T, pe, idt, odt)
if exec_mode == "cppsim":
model = model.transform(PrepareCppSim())
model = model.transform(CompileCppSim())
model = model.transform(SetExecMode("cppsim"))
elif exec_mode == "rtlsim":
model = model.transform(SetExecMode("rtlsim"))
model = model.transform(GiveUniqueNodeNames())
model = model.transform(PrepareIP("xc7z020clg400-1", 5))
model = model.transform(HLSSynthIP())
model = model.transform(ReplaceVerilogRelPaths())
model = model.transform(PrepareRTLSim())
else:
raise Exception("Unknown exec_mode")
# package input data as dictionary
input_dict = {"inp": x}
y = multithreshold(x, T)
if act == DataType.BIPOLAR:
# binary to bipolar
y = 2 * y - 1
else:
# signed offset
y += act.min()
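# Worked example of this mapping (values follow from the DataType ranges):
# for act = DataType.INT4, multithreshold yields levels 0..15 and adding
# act.min() == -8 shifts them into the signed range -8..7; for BIPOLAR,
# the levels {0, 1} map to {-1, +1} via 2 * y - 1.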
oshape = model.get_tensor_shape("outp")
y_expected = y.reshape(oshape)
# execute model
y_produced = oxe.execute_onnx(model, input_dict)["outp"]
y_produced = y_produced.reshape(y_expected.shape)
assert (y_produced == y_expected).all(), exec_mode + " failed"
if exec_mode == "rtlsim":
hls_synt_res_est = model.analysis(hls_synth_res_estimation)
assert "Thresholding_Batch_0" in hls_synt_res_est