class InferThresholdingLayer(Transformation):
    """Convert any MultiThreshold into a standalone thresholding HLS layer
    (a ``Thresholding_Batch`` node in the ``finn`` domain).

    Nodes whose input datatype is not integer (i.e. float inputs) are left
    untouched, since the HLS thresholding primitive requires integer inputs.
    Returns the (possibly transformed) model and a flag indicating whether
    the graph was modified.
    """

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        graph_modified = False
        for node in graph.node:
            node_ind += 1
            if node.op_type == "MultiThreshold":
                thl_input = node.input[0]
                thl_threshold = node.input[1]
                thl_output = node.output[0]
                thl_in_shape = model.get_tensor_shape(thl_input)
                idt = model.get_tensor_datatype(thl_input)

                # skip conversion for layers with float input
                if not idt.is_integer():
                    continue

                # the number of channels is the innermost input dimension;
                # thresholding is applied channel-wise, so NumChannels = ifc
                ifc = int(thl_in_shape[-1])
                # create node with no parallelization first
                pe = 1
                assert ifc % pe == 0, "Requirement IFC divisible by PE is violated."

                odt = model.get_tensor_datatype(thl_output)
                # create and insert new Thresholding_Batch node in place of
                # the MultiThreshold node (same input/output tensor names, so
                # no rewiring of neighbors is needed)
                new_node = helper.make_node(
                    "Thresholding_Batch",
                    [thl_input, thl_threshold],
                    [thl_output],
                    domain="finn",
                    backend="fpgadataflow",
                    NumChannels=ifc,
                    PE=pe,
                    inputDataType=idt.name,
                    outputDataType=odt.name,
                    # all outer dims (e.g. batch, spatial) are folded into
                    # the stream of input vectors
                    numInputVectors=list(thl_in_shape[:-1]),
                )
                graph.node.insert(node_ind, new_node)
                # remove old node
                graph.node.remove(node)
                graph_modified = True

        if graph_modified:
            # re-derive shapes/datatypes after the graph surgery
            model = model.transform(InferShapes())
            model = model.transform(InferDataTypes())
        return (model, graph_modified)
model.transform(to_hls.InferQuantizedStreamingFCLayer()) for node in model.graph.node: @@ -102,7 +108,12 @@ def test_convert_to_hls_layers_cnv_w1a1(): model = model.transform(to_hls.InferStreamingMaxPool()) # check topology status finn_nodes = model.get_finn_nodes() - assert len(finn_nodes) == 18 + if fused_activation: + assert len(finn_nodes) == 18 + else: + assert len(finn_nodes) == 26 + thr_nodes = model.get_nodes_by_op_type("Thresholding_Batch") + assert len(thr_nodes) == 8 non_finn_nodes = model.get_non_finn_nodes() assert len(non_finn_nodes) == 4 exp_non_finn_nodes = ["Transpose", "Reshape", "Mul", "Add"]