From dc8d8b7f09883066f1d684eccfc4f129666f1519 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sat, 16 Nov 2019 21:28:47 +0000
Subject: [PATCH] [CustomOp] new infrastructure for custom ops, impl MultiThres

---
 src/finn/custom_op/__init__.py       | 18 ++++++
 src/finn/custom_op/multithreshold.py | 91 ++++++++++++++++++++++++++++
 src/finn/custom_op/registry.py       |  8 +++
 3 files changed, 117 insertions(+)
 create mode 100644 src/finn/custom_op/__init__.py
 create mode 100644 src/finn/custom_op/multithreshold.py
 create mode 100644 src/finn/custom_op/registry.py

diff --git a/src/finn/custom_op/__init__.py b/src/finn/custom_op/__init__.py
new file mode 100644
index 000000000..fd7024aff
--- /dev/null
+++ b/src/finn/custom_op/__init__.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+
+
+class CustomOp(ABC):
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def make_shape_compatible_op(self, node):
+        pass
+
+    @abstractmethod
+    def infer_node_datatype(self, node, model):
+        pass
+
+    @abstractmethod
+    def execute_node(self, node, context, graph):
+        pass
diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py
new file mode 100644
index 000000000..ec8f1ad4c
--- /dev/null
+++ b/src/finn/custom_op/multithreshold.py
@@ -0,0 +1,91 @@
+import numpy as np
+import onnx.helper as helper
+
+from finn.core.datatype import DataType
+from finn.core.utils import get_by_name
+from finn.custom_op import CustomOp
+
+
+class MultiThreshold(CustomOp):
+    def make_shape_compatible_op(self, node):
+        return helper.make_node("Relu", [node.input[0]], [node.output[0]])
+
+    def infer_node_datatype(self, node, model):
+        try:
+            odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
+            model.set_tensor_datatype(node.output[0], DataType[odt])
+        except AttributeError:
+            # number of thresholds decides # output bits
+            # use get_smallest_possible, assuming unsigned
+            n_thres = model.get_tensor_shape(node.input[1])[1]
+            odtype = DataType.get_smallest_possible(n_thres)
+            model.set_tensor_datatype(node.output[0], odtype)
+
+    def execute_node(self, node, context, graph):
+        # save inputs
+        v = context[node.input[0]]
+        thresholds = context[node.input[1]]
+        # retrieve attributes if output scaling is used
+        try:
+            out_scale = get_by_name(node.attribute, "out_scale").f
+        except AttributeError:
+            out_scale = None
+        try:
+            out_bias = get_by_name(node.attribute, "out_bias").f
+        except AttributeError:
+            out_bias = None
+        # calculate output
+        output = self._execute(v, thresholds, out_scale, out_bias)
+        # setting context according to output
+        context[node.output[0]] = output
+
+    def _compare(self, x, y):
+        if x >= y:
+            return 1.0
+        else:
+            return 0.0
+
+    def _execute(self, v, thresholds, out_scale=None, out_bias=None):
+        # the inputs are expected to be in the shape (N,C,H,W)
+        # N : Batch size
+        # C : Number of channels
+        # H : Heigth of the input images
+        # W : Width of the input images
+        #
+        # the thresholds are expected to be in the shape (C, B)
+        # C : Number of channels (must be the same value as C in input tensor
+        #     or 1 if all channels use the same threshold value)
+        # B : Desired activation steps => i.e. for 4-bit activation,
+        #     B=7 (2^(n)-1 and n=4)
+        # the output tensor will be scaled by out_scale and biased by out_bias
+        # assert threshold shape
+        is_global_threshold = thresholds.shape[0] == 1
+        assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
+        # save the required shape sizes for the loops (N, C and B)
+        num_batch = v.shape[0]
+        num_channel = v.shape[1]
+        num_act = thresholds.shape[1]
+        # reshape inputs to enable channel-wise reading
+        vr = v.reshape((v.shape[0], v.shape[1], -1))
+        # save the new shape size of the images
+        num_img_elem = vr.shape[2]
+        # initiate output tensor
+        ret = np.zeros_like(vr)
+        # iterate over thresholds channel-wise
+        for t in range(num_channel):
+            channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
+            # iterate over batches
+            for b in range(num_batch):
+                # iterate over image elements on which the thresholds will be applied
+                for elem in range(num_img_elem):
+                    # iterate over the different thresholds for one channel
+                    for a in range(num_act):
+                        # apply successive thresholding to every element
+                        ret[b][t][elem] += self._compare(
+                            vr[b][t][elem], channel_thresh[a]
+                        )
+        if out_scale is None:
+            out_scale = 1.0
+        if out_bias is None:
+            out_bias = 0.0
+        return out_scale * ret.reshape(v.shape) + out_bias
diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
new file mode 100644
index 000000000..07bdbf1ca
--- /dev/null
+++ b/src/finn/custom_op/registry.py
@@ -0,0 +1,8 @@
+# make sure new CustomOp subclasses are imported here so that they get
+# registered and plug in correctly into the infrastructure
+from finn.custom_op.multithreshold import MultiThreshold
+
+# create a mapping of all known CustomOp names and classes
+custom_op = {}
+
+custom_op["MultiThreshold"] = MultiThreshold
-- 
GitLab