From 1d54e903b6bc967cb0bc8b0b9c1adaa1aca9297c Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Wed, 8 Jul 2020 17:54:20 +0100
Subject: [PATCH] [Transform] Add first draft of InferVVAU trafo

---
 .../fpgadataflow/convert_to_hls_layers.py     | 145 ++++++++++++++++++
 1 file changed, 145 insertions(+)

diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index ba054c40b..700db1436 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -595,6 +595,151 @@ class InferQuantizedStreamingFCLayer(Transformation):
         return (model, graph_modified)
 
 
+class InferVVAU(Transformation):
+    """Convert MatMul layers with quantized inputs and weights to
+    Vector_Vector_Activate_Batch layers, if the sparsity annotation
+    of the weight matrix indicates that the MatMul layer belongs to
+    a depthwise convolution. Any immediately following MultiThreshold
+    layers will also be absorbed into the VVAU."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if (
+                n.op_type == "MatMul"
+                and model.get_tensor_sparsity(n.input[1]) is not None
+            ):
+                sparsity = model.get_tensor_sparsity(n.input[1])
+                try:
+                    k = sparsity["dw"]["kernel_shape"]
+                except KeyError:
+                    raise Exception(
+                        """Sparsity doesn't indicate that MatMul
+                        belongs to a depthwise convolution."""
+                    )
+
+                mm_input = n.input[0]
+                mm_weight = n.input[1]
+                mm_output = n.output[0]
+                mm_in_shape = model.get_tensor_shape(mm_input)
+                mm_out_shape = model.get_tensor_shape(mm_output)
+                idt = model.get_tensor_datatype(mm_input)
+                wdt = model.get_tensor_datatype(mm_weight)
+                if idt.is_integer() and wdt.is_integer():
+                    mm_output = n.output[0]
+                    W = model.get_initializer(mm_weight)
+                    # infer dense weight tensor from sparse weight matrix
+                    # kernel size k which was extracted above and the value of
+                    # the channels is used.
+                    # the weight matrix has a shape of (k * k * Channels, Channels)
+                    # we need to reverse the creation of the sparse weight matrix
+                    # to achieve a weight tensor of shape (Channels, 1, k, k)
+                    channels = int(W.shape[1])
+                    # transpose to achieve a shape of (k * k * Channels, Channels)
+                    W = W.T
+                    # reshape to (Channels, k, k, Channels) to transpose afterwards
+                    # to (Channels, Channels, k, k)
+                    W = W.reshape(channels, k, k, channels)
+                    W = W.transpose(0, 3, 1, 2)
+                    # now we can extract the values using a for loop over the channels
+                    # and fill a zero numpy array in the correct shape
+                    w_tensor = np.zeros((channels, 1, k, k))
+                    for ch in range(channels):
+                        w_tensor[ch][0] = W[ch][ch]
+                    model.set_initializer(mm_weight, w_tensor)
+                    model.set_tensor_shape(mm_weight, (channels, 1, k, k))
+                    # create node with pe=channels as default
+                    pe = channels
+                    assert (
+                        channels % pe == 0
+                    ), "Requirement Channels divisable by PE is violated."
+                    # see if we have any following thresholds
+                    consumer = model.find_consumer(mm_output)
+                    if consumer is not None and consumer.op_type == "MultiThreshold":
+                        # TODO ensure integer thresholds?
+                        # create MVTU (i.e. including activation)
+                        mt_output = consumer.output[0]
+                        mt_out_shape = model.get_tensor_shape(mt_output)
+                        mt_thres = consumer.input[1]
+                        T = model.get_initializer(mt_thres)
+                        assert (
+                            T.shape[0] == 1 or T.shape[0] == channels
+                        ), """First dimension of
+                        thresholds neither 1 nor Channels."""
+                        odt = model.get_tensor_datatype(mt_output)
+                        scale = getCustomOp(consumer).get_nodeattr("out_scale")
+                        assert (
+                            scale == 1.0
+                        ), "out_scale must be equal to 1.0 for HLS conversion."
+                        actval = getCustomOp(consumer).get_nodeattr("out_bias")
+                        assert (
+                            int(actval) == actval
+                        ), "out_bias must be integer for HLS conversion."
+                        actval = int(actval)
+                        assert (not odt.signed()) or (
+                            actval < 0
+                        ), "Signed output requres actval < 0"
+                        model.set_tensor_shape(mm_input, mm_in_shape)
+                        model.set_tensor_shape(mt_output, mt_out_shape)
+                        # create and insert new Vector_Vector_Activate_Batch node
+                        new_node = helper.make_node(
+                            "Vector_Vector_Activate_Batch",
+                            [mm_input, mm_weight, mt_thres],
+                            [mt_output],
+                            domain="finn",
+                            backend="fpgadataflow",
+                            resType="ap_resource_lut()",
+                            PE=pe,
+                            Dim=mm_in_shape[1],
+                            Channels=channels,
+                            Kernel=k,
+                            inputDataType=idt.name,
+                            weightDataType=wdt.name,
+                            outputDataType=odt.name,
+                            ActVal=actval,
+                            noActivation=0,
+                        )
+                        graph.node.insert(node_ind, new_node)
+                        # remove old nodes
+                        graph.node.remove(n)
+                        graph.node.remove(consumer)
+                        graph_modified = True
+                    else:
+                        # no activation, matmul only
+                        odt = model.get_tensor_datatype(mm_output)
+                        model.set_tensor_shape(mm_input, mm_in_shape)
+                        model.set_tensor_shape(mm_output, mm_out_shape)
+                        # create and insert new StreamingFCLayer node
+                        new_node = helper.make_node(
+                            "Vector_Vector_Activate_Batch",
+                            [mm_input, mm_weight],
+                            [mm_output],
+                            domain="finn",
+                            backend="fpgadataflow",
+                            resType="ap_resource_lut()",
+                            PE=pe,
+                            Dim=mm_in_shape[1],
+                            Channels=channels,
+                            Kernel=k,
+                            inputDataType=idt.name,
+                            weightDataType=wdt.name,
+                            outputDataType=odt.name,
+                            ActVal=0,
+                            noActivation=1,
+                        )
+                        graph.node.insert(node_ind, new_node)
+                        # remove old node
+                        graph.node.remove(n)
+                        graph_modified = True
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
+        return (model, graph_modified)
+
+
 class InferThresholdingLayer(Transformation):
     """Convert any MultiThreshold into a standalone thresholding HLS layer."""
 
-- 
GitLab