diff --git a/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py b/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py
index 9a897d9fa16064017dfc02f500d2360ae8431b4a..fead30650c60d38f9cd70de8f1515f847e15276f 100644
--- a/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py
+++ b/src/finn/custom_op/fpgadataflow/vector_vector_activate_batch.py
@@ -26,9 +26,9 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
     def get_nodeattr_types(self):
         my_attrs = {
             "PE": ("i", True, 0),
-            "Dim": ("i", True, 0),
+            "Dim": ("ints", True, []),  # [H, W]
             "Channels": ("i", True, 0),
-            "Kernel": ("i", True, 0),
+            "Kernel": ("ints", True, []),  # [H, W]
             "resType": ("s", False, "auto", {"auto", "lut", "dsp"}),
             "ActVal": ("i", False, 0),
             # FINN DataTypes for inputs, weights, outputs
@@ -45,10 +45,10 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
 
     def minimize_accumulator_width(self, model):
         weights = model.get_initializer(self.onnx_node.input[1])
-        k = self.get_nodeattr("Kernel")
+        k_h, k_w = self.get_nodeattr("Kernel")
         fm = self.get_nodeattr("Channels")
         # put weights into the shape expected by calculate_matvec_accumulator_range
-        weights = weights.reshape(fm, k * k).transpose()
+        weights = weights.reshape(fm, k_h * k_w).transpose()
         if len(self.onnx_node.input) > 2:
             thresholds = model.get_initializer(self.onnx_node.input[2])
         else:
@@ -85,9 +85,11 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
                 tdt = DataType.get_smallest_possible(0 - tdt_max)
             else:
                 tdt = DataType.get_smallest_possible(tdt_max)
-            assert np.vectorize(tdt.allowed)(threshold_tensor).all(), (
-                "Thresholds in %s can't be expressed with type %s"
-                % (self.onnx_node.name, str(tdt))
+            assert np.vectorize(tdt.allowed)(
+                threshold_tensor
+            ).all(), "Thresholds in %s can't be expressed with type %s" % (
+                self.onnx_node.name,
+                str(tdt),
             )
             self.set_nodeattr("accDataType", tdt.name)
         else:
@@ -110,9 +112,9 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
     def calc_wmem(self):
         """Calculates and returns WMEM."""
         ch = self.get_nodeattr("Channels")
-        k = self.get_nodeattr("Kernel")
+        k_h, k_w = self.get_nodeattr("Kernel")
         pe = self.get_nodeattr("PE")
-        wmem = k * k * ch // pe
+        wmem = k_h * k_w * ch // pe
         return wmem
 
     def calc_tmem(self):
@@ -181,34 +183,34 @@ class Vector_Vector_Activate_Batch(HLSCustomOp):
         return out_width
 
     def get_folded_input_shape(self):
-        k = self.get_nodeattr("Kernel")
-        sf = k * k
-        dim = self.get_nodeattr("Dim")
+        k_h, k_w = self.get_nodeattr("Kernel")
+        sf = k_h * k_w
+        dim_h, dim_w = self.get_nodeattr("Dim")
         ch = self.get_nodeattr("Channels")
         pe = self.get_nodeattr("PE")
         nf = ch // pe
-        folded_input_shape = tuple([1, dim, dim, sf * nf, pe])
+        folded_input_shape = tuple([1, dim_h, dim_w, sf * nf, pe])
         return folded_input_shape
 
     def get_folded_output_shape(self):
         ch = self.get_nodeattr("Channels")
         pe = self.get_nodeattr("PE")
         nf = ch // pe
-        dim = self.get_nodeattr("Dim")
-        folded_output_shape = tuple([1, dim, dim, nf, pe])
+        dim_h, dim_w = self.get_nodeattr("Dim")
+        folded_output_shape = tuple([1, dim_h, dim_w, nf, pe])
         return folded_output_shape
 
     def get_normal_input_shape(self):
-        dim = self.get_nodeattr("Dim")
+        dim_h, dim_w = self.get_nodeattr("Dim")
         ch = self.get_nodeattr("Channels")
-        k = self.get_nodeattr("Kernel")
-        normal_input_shape = tuple([1, dim, dim, k * k * ch])
+        k_h, k_w = self.get_nodeattr("Kernel")
+        normal_input_shape = tuple([1, dim_h, dim_w, k_h * k_w * ch])
         return normal_input_shape
 
     def get_normal_output_shape(self):
         ch = self.get_nodeattr("Channels")
-        dim = self.get_nodeattr("Dim")
self.get_nodeattr("Dim") - normal_output_shape = tuple([1, dim, dim, ch]) + dim_h, dim_w = self.get_nodeattr("Dim") + normal_output_shape = tuple([1, dim_h, dim_w, ch]) return normal_output_shape def get_number_output_values(self): @@ -218,13 +220,13 @@ class Vector_Vector_Activate_Batch(HLSCustomOp): def get_exp_cycles(self): pe = self.get_nodeattr("PE") ch = self.get_nodeattr("Channels") - dim = self.get_nodeattr("Dim") - k = self.get_nodeattr("Kernel") + dim_h, dim_w = self.get_nodeattr("Dim") + k_h, k_w = self.get_nodeattr("Kernel") # currently FINN supports for vvau a batch size of 1 batch_size = 1 # since mmv != 1 is not supported yet, we set mmv for now to 1 mmv = 1 - exp_cycles = ((ch * k * k) / pe) * batch_size * (dim * dim) / mmv + exp_cycles = ((ch * k_h * k_w) / pe) * batch_size * (dim_h * dim_w) / mmv return int(exp_cycles) def get_template_param_values(self): @@ -251,17 +253,17 @@ class Vector_Vector_Activate_Batch(HLSCustomOp): def get_hls_compatible_weight_tensor(self, orig_weight_matrix): pe = self.get_nodeattr("PE") ch = self.get_nodeattr("Channels") - k = self.get_nodeattr("Kernel") + k_h, k_w = self.get_nodeattr("Kernel") wmem = self.calc_wmem() assert orig_weight_matrix.shape == ( ch, 1, - k, - k, + k_h, + k_w, ), """Weights matrix doesn't have expected shape (channels, 1, kernel_size, kernel_size)""" ret = orig_weight_matrix - ret = ret.reshape(ch, k * k) + ret = ret.reshape(ch, k_h * k_w) # distribute rows between PEs ret = interleave_matrix_outer_dim_from_partitions(ret, pe) ret = ret.reshape(1, pe, wmem, 1) @@ -338,9 +340,11 @@ class Vector_Vector_Activate_Batch(HLSCustomOp): threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds) # get computed threshold datatype from attribute tdt = DataType[self.get_nodeattr("accDataType")] - assert np.vectorize(tdt.allowed)(threshold_tensor).all(), ( - "Thresholds in %s can't be expressed with type %s" - % (self.onnx_node.name, str(tdt)) + assert np.vectorize(tdt.allowed)( + threshold_tensor + ).all(), "Thresholds in %s can't be expressed with type %s" % ( + self.onnx_node.name, + str(tdt), ) thresholds_hls_code = numpy_to_hls_code( threshold_tensor, tdt, "thresholds", False, True @@ -455,10 +459,10 @@ class Vector_Vector_Activate_Batch(HLSCustomOp): self.code_gen_dict["$GLOBALS$"] += ['#include "thresh.h"'] def defines(self, var): - dim = self.get_nodeattr("Dim") - numReps = 1 * dim * dim - kernel = self.get_nodeattr("Kernel") - innerProdDim = kernel * kernel + dim_h, dim_w = self.get_nodeattr("Dim") + numReps = 1 * dim_h * dim_w + k_h, k_w = self.get_nodeattr("Kernel") + innerProdDim = k_h * k_w self.code_gen_dict["$DEFINES$"] = [ """#define Channels1 {}\n #define InnerProdDim {}\n #define SIMD1 1\n #define PE1 {}\n #define numReps {}""".format( @@ -664,8 +668,8 @@ class Vector_Vector_Activate_Batch(HLSCustomOp): else: mult_luts = (2 * math.ceil((W + A) / 6) - 1) * (W + A) # accumulator - k = self.get_nodeattr("Kernel") - acc_bits = W + A + math.ceil(math.log(k * k, 2)) + k_h, k_w = self.get_nodeattr("Kernel") + acc_bits = W + A + math.ceil(math.log(k_h * k_w, 2)) acc_luts = acc_bits # thresholds and threshold comparators thr_luts = 0 @@ -694,20 +698,20 @@ class Vector_Vector_Activate_Batch(HLSCustomOp): return int(mult_dsp) def get_op_and_param_counts(self): - k = self.get_nodeattr("Kernel") + k_h, k_w = self.get_nodeattr("Kernel") fm = self.get_nodeattr("Channels") - dim = self.get_nodeattr("Dim") + dim_h, dim_w = self.get_nodeattr("Dim") weight_bits = self.get_weight_datatype().bitwidth() inp_bits = 
-        num_repetitions = int(dim * dim)
-        mac_count = k * k * fm * num_repetitions
+        num_repetitions = int(dim_h * dim_w)
+        mac_count = k_h * k_w * fm * num_repetitions
         # cannonicalize op type: highest bitwidth operand first s.t.
         # e.g. mac_8bx4b and mac_4bx8b don't appear as two different op types
         bw1 = min(inp_bits, weight_bits)
         bw2 = max(inp_bits, weight_bits)
         mac_op_type = "op_mac_%dbx%db" % (bw1, bw2)
         weight_param_type = "param_weight_%db" % (weight_bits)
-        weight_count = k * k * fm
+        weight_count = k_h * k_w * fm
         ret_dict = {mac_op_type: mac_count, weight_param_type: weight_count}
         if self.get_nodeattr("noActivation") == 0:
             tdt = DataType[self.get_nodeattr("accDataType")]
diff --git a/tests/fpgadataflow/test_fpgadataflow_vvau.py b/tests/fpgadataflow/test_fpgadataflow_vvau.py
new file mode 100644
index 0000000000000000000000000000000000000000..4756d4fe18ccd4934b4041c70bf2f3a1bb577ec7
--- /dev/null
+++ b/tests/fpgadataflow/test_fpgadataflow_vvau.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.util.basic import gen_finn_dt_tensor
+from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
+from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
+from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
+from finn.transformation.general import GiveUniqueNodeNames
+from finn.custom_op.general.multithreshold import multithreshold
+
+from finn.custom_op.registry import getCustomOp
+from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
+
+
+def _infer_sparse_weight_tensor(W_conv, k_h, k_w, channels):
+    """Converts depthwise conv weights of shape (channels, 1, k_h, k_w) into
+    the equivalent sparse matmul weight matrix of shape
+    (k_h * k_w * channels, channels), with zeros off the channel diagonal."""
+    W_sparse = np.zeros((channels, channels, k_h, k_w))
+    for ch in range(channels):
+        W_sparse[ch][ch] = W_conv[ch][0]
+    W_conv = W_sparse.astype(np.float32)
+    W_matmul = W_conv.transpose(0, 2, 3, 1)
+    W_matmul = W_matmul.reshape(channels, channels * k_h * k_w)
+    W_matmul = W_matmul.T
+
+    return W_matmul
+
+
+def _calculate_dot_prod_range(dt_a, dt_b, length):
+    """Returns the (min, max) values a dot product between two (un)signed
+    vectors of types dt_a and dt_b with length elements can take."""
+    # sentinel bounds; products of the narrow datatypes tested here fit
+    # comfortably within +/- 2**30
+    min_prod = 2 ** 30
+    max_prod = -(2 ** 30)
+    for a_val in [dt_a.min(), dt_a.max()]:
+        for b_val in [dt_b.min(), dt_b.max()]:
+            prod = a_val * b_val * length
+            if prod < min_prod:
+                min_prod = prod
+            if prod > max_prod:
+                max_prod = prod
+    return (min_prod, max_prod)
+
+
+def _make_single_vvau_modelwrapper(
+    W, pe, k_h, k_w, channels, dim_h, dim_w, wdt, idt, odt, T=None, tdt=None
+):
+    in_shape = [1, dim_h, dim_w, k_h * k_w * channels]  # [N, H, W, K*K*CH]
+    out_shape = [
+        1,
+        dim_h,
+        dim_w,
+        channels,
+    ]  # [N, H, W, OFM_CH] (OFM_CH=IFM_CH because depthwise convolution)
+
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, in_shape)
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, out_shape)
+
+    if T is not None:
+        no_act = 0
+        node_inp_list = ["inp", "weights", "thresh"]
+        actval = odt.min()
+    else:
+        no_act = 1
+        node_inp_list = ["inp", "weights"]
+        actval = 0
+
+    VVAU_node = helper.make_node(
+        "Vector_Vector_Activate_Batch",
+        node_inp_list,
+        ["outp"],
+        domain="finn.custom_op.fpgadataflow",
+        backend="fpgadataflow",
+        PE=pe,
+        Dim=[dim_h, dim_w],
+        Channels=channels,
+        Kernel=[k_h, k_w],
+        resType="lut",
+        ActVal=actval,
+        inputDataType=idt.name,
+        weightDataType=wdt.name,
+        outputDataType=odt.name,
+        noActivation=no_act,
+    )
+
+    graph = helper.make_graph(
+        nodes=[VVAU_node], name="vvau_graph", inputs=[inp], outputs=[outp]
+    )
+
+    model = helper.make_model(graph, producer_name="vvau-model")
+    model = ModelWrapper(model)
+
+    model.set_tensor_datatype("inp", idt)
+    model.set_tensor_datatype("outp", odt)
+    model.set_tensor_datatype("weights", wdt)
+
+    model.set_initializer("weights", W)
+    model.set_tensor_shape("weights", (channels, 1, k_h, k_w))
+
+    if T is not None:
+        model.set_tensor_datatype("thresh", tdt)
+        model.set_initializer("thresh", T)
+
+    return model
+
+
+def prepare_inputs(input_tensor):
+    return {"inp": input_tensor}
+
+
+# input datatype
+@pytest.mark.parametrize("idt", [DataType.UINT4, DataType.UINT8])
+# weight datatype
+@pytest.mark.parametrize("wdt", [DataType.INT4])
+# activation: None or DataType
+@pytest.mark.parametrize("act", [DataType.UINT4, None])
+# PE
+@pytest.mark.parametrize("pe", [1, "channels"])
+# Input image shape
+@pytest.mark.parametrize("dim_h", [10])
+@pytest.mark.parametrize("dim_w", [10, 1])
+# Kernel shape
+@pytest.mark.parametrize("k_h", [3])
+@pytest.mark.parametrize("k_w", [3, 1])
+# Number of input and output channels
+@pytest.mark.parametrize("channels", [3, 4])
+# execution mode
+@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
+@pytest.mark.slow
+@pytest.mark.vivado
+def test_fpgadataflow_vvau(
+    idt, wdt, act, pe, dim_h, dim_w, k_h, k_w, channels, exec_mode
+):
+    if pe == "channels":
+        pe = channels
+
+    if dim_w == 1 and k_w != 1:
+        pytest.skip("1D image requires 1D kernel, skipping.")
+
+    if channels % pe != 0:
+        pytest.skip("Requirement Channels divisible by PE is violated.")
+
+    # Generate weights in expected shape for ONNX and HLS node
+    W = gen_finn_dt_tensor(wdt, (channels, 1, k_h, k_w))  # shape: [channels, 1, k, k]
+    W_onnx = _infer_sparse_weight_tensor(
+        W, k_h, k_w, channels
+    )  # shape: [k*k*channels, channels]
+
+    # Generate inputs in expected format for ONNX and HLS node:
+    # rearrange each pixel from [k_h*k_w, channels] ordering to the
+    # [channels//pe, k_h*k_w, pe] ordering the VVAU expects
+    x = gen_finn_dt_tensor(idt, (1, dim_h, dim_w, k_h * k_w * channels))
+    x_vvau = x.reshape(1, dim_h, dim_w, k_h * k_w, channels // pe, pe)
+    x_vvau = x_vvau.transpose(0, 1, 2, 4, 3, 5)
+    x_vvau = x_vvau.reshape(1, dim_h, dim_w, channels * k_h * k_w)
+
+    if act is None:
+        T = None
+        tdt = None
+        odt = DataType.INT32
+    else:
+        odt = act
+        (min_v, max_v) = _calculate_dot_prod_range(idt, wdt, k_h * k_w * channels)
+        n_steps = act.get_num_possible_values() - 1
+        T = np.random.randint(min_v, max_v - 1, (channels, n_steps)).astype(np.float32)
+        T = np.sort(T, axis=1)
+        tdt = DataType.INT32
+
+    model = _make_single_vvau_modelwrapper(
+        W, pe, k_h, k_w, channels, dim_h, dim_w, wdt, idt, odt, T, tdt
+    )
+
+    if exec_mode == "cppsim":
+        model = model.transform(SetExecMode("cppsim"))
+        model = model.transform(PrepareCppSim())
+        model = model.transform(CompileCppSim())
+    elif exec_mode == "rtlsim":
+        model = model.transform(SetExecMode("rtlsim"))
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(PrepareIP("xc7z020clg400-1", 5))
+        model = model.transform(HLSSynthIP())
+        model = model.transform(PrepareRTLSim())
+    else:
+        raise Exception("Unknown exec_mode in test_fpgadataflow_vvau")
+
+    input_dict = prepare_inputs(x_vvau)
+
+    # Calculate output
+    y_expected = np.matmul(x, W_onnx)  # Y is in [N, H, W, C] format
+    if T is not None:
+        # Reshape Y, as multithreshold expects Y to be in [N, C, H, W] format
+        y_expected = np.transpose(y_expected, (0, 3, 1, 2))
+        y_expected = multithreshold(y_expected, T)
+        y_expected = np.transpose(y_expected, (0, 2, 3, 1))
+        # signed offset
+        y_expected += act.min()
+
+    y_produced = oxe.execute_onnx(model, input_dict, return_full_exec_context=False)[
+        "outp"
+    ]
+
+    assert (y_produced == y_expected).all(), exec_mode + " failed"
+
+    if exec_mode == "rtlsim":
+        node = model.get_nodes_by_op_type("Vector_Vector_Activate_Batch")[0]
+        inst = getCustomOp(node)
+        cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
+        exp_cycles_dict = model.analysis(exp_cycles_per_layer)
+        exp_cycles = exp_cycles_dict[node.name]
+        assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
+        assert exp_cycles != 0
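
Note for downstream code: this patch changes the "Dim" and "Kernel" node attributes from scalar "i" values to two-element "ints" lists in [H, W] order, so any transform or script that creates or reads Vector_Vector_Activate_Batch nodes must switch to the list form (a previously square Kernel=3 becomes Kernel=[3, 3]). The snippet below is a minimal sketch, not part of this diff; the helper name normalize_hw_attr and its use for upgrading graphs saved before this change are assumptions for illustration only.

def normalize_hw_attr(value):
    # hypothetical helper (not in FINN): accept either a legacy scalar
    # attribute or the new [H, W] list and always return [H, W]
    if isinstance(value, (list, tuple)):
        assert len(value) == 2, "expected a two-element [H, W] list"
        return list(value)
    # legacy graphs stored a single int k, meaning a square k x k shape
    return [int(value), int(value)]

For example, normalize_hw_attr(3) and normalize_hw_attr([3, 3]) both yield [3, 3], matching the Kernel=[k_h, k_w] / Dim=[dim_h, dim_w] form exercised by the new test above.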