From 4e84e8c9917a4f25229fff3dd26e1b8f90044436 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Fri, 5 Jun 2020 11:29:33 +0100
Subject: [PATCH] [CustomOp] First draft of QuantAvgPool2d custom op

---
 src/finn/custom_op/quantavgpool2d.py | 60 ++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100644 src/finn/custom_op/quantavgpool2d.py

diff --git a/src/finn/custom_op/quantavgpool2d.py b/src/finn/custom_op/quantavgpool2d.py
new file mode 100644
index 000000000..359b11e7c
--- /dev/null
+++ b/src/finn/custom_op/quantavgpool2d.py
@@ -0,0 +1,60 @@
+import numpy as np
+from onnx import TensorProto, helper
+
+from finn.custom_op import CustomOp
+from finn.custom_op.im2col import compute_conv_output_dim
+
+
+class QuantAvgPool2d(CustomOp):
+    """Class that corresponds to the quantized average pooling
+    layer from brevitas"""
+
+    def get_nodeattr_types(self):
+        return {
+            "stride": ("i", True, 1),
+            "kernel": ("i", True, 1),
+            "ibits": ("s", True, ""),
+            "obits": ("i", False, 0),
+            "signed": ("i", False, 0),
+        }
+
+    def make_shape_compatible_op(self, model):
+        node = self.onnx_node
+        inp = node.input[0]
+        ishape = model.get_tensor_shape(inp)
+        # we assume that the shape is (NCHW) and H=W
+        assert len(ishape) == 4, "Unexpected input shape for QuantAvgPool2d"
+        assert (
+            ishape[2] == ishape[3]
+        ), "QuantAvgPool2d for non-square images unsupported"
+        ch = ishape[1]
+        idim = ishape[2]
+        k = self.get_nodeattr("kernel")
+        stride = self.get_nodeattr("stride")
+        odim = compute_conv_output_dim(idim, k, stride)
+
+        # implement tensor with correct shape
+        values = np.random.randn(1, ch, odim, odim).astype(np.float32)
+        return helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
+        )
+
+    def infer_node_datatype(self, model):
+        node = self.onnx_node
+        # data type stays the same
+        dtype = model.get_tensor_datatype(node.input[0])
+        model.set_tensor_datatype(node.output[0], dtype)
+
+    def execute_node(self, context, graph):
+        pass
+
+    def verify_node(self):
+        pass
-- 
GitLab