From e9653644ee0a4eec909d85017637b5c10410e711 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Tue, 24 Mar 2020 10:51:58 +0000
Subject: [PATCH] [ShapeInf] pass model to CustomOp make_shape_compatible_op
 functions

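Pass the ModelWrapper down from the InferShapes transformation so that
each CustomOp's make_shape_compatible_op implementation can query the
model (for instance, for input tensor shapes) while constructing its
shape-compatible standard ONNX op.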
---
 src/finn/custom_op/__init__.py                              | 2 +-
 .../custom_op/fpgadataflow/convolutioninputgenerator.py     | 2 +-
 src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py   | 2 +-
 src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py   | 2 +-
 src/finn/custom_op/fpgadataflow/tlastmarker.py              | 2 +-
 src/finn/custom_op/im2col.py                                | 2 +-
 src/finn/custom_op/maxpoolnhwc.py                           | 2 +-
 src/finn/custom_op/multithreshold.py                        | 2 +-
 src/finn/custom_op/streamingdataflowpartition.py            | 6 ++----
 src/finn/custom_op/xnorpopcount.py                          | 2 +-
 src/finn/transformation/infer_shapes.py                     | 6 +++---
 11 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/src/finn/custom_op/__init__.py b/src/finn/custom_op/__init__.py
index 39de40f1e..ab6e03bee 100644
--- a/src/finn/custom_op/__init__.py
+++ b/src/finn/custom_op/__init__.py
@@ -94,7 +94,7 @@ class CustomOp(ABC):
         pass
 
     @abstractmethod
-    def make_shape_compatible_op(self):
+    def make_shape_compatible_op(self, model):
         """Returns a standard ONNX op which is compatible with this CustomOp
         for performing shape inference."""
         pass
diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
index 14016ce9c..55daff5f7 100644
--- a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
+++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
@@ -58,7 +58,7 @@ class ConvolutionInputGenerator(HLSCustomOp):
         my_attrs.update(super().get_nodeattr_types())
         return my_attrs
 
-    def make_shape_compatible_op(self):
+    def make_shape_compatible_op(self, model):
         pass
 
     def infer_node_datatype(self, model):
diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index 9b4dfe69f..96465b9de 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -98,7 +98,7 @@ class StreamingFCLayer_Batch(HLSCustomOp):
             pe = self.get_nodeattr("PE")
             return mh // pe
 
-    def make_shape_compatible_op(self):
+    def make_shape_compatible_op(self, model):
         pass
 
     def infer_node_datatype(self, model):
diff --git a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
index 43951332d..c3b3f7dce 100644
--- a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
@@ -41,7 +41,7 @@ class StreamingMaxPool_Batch(HLSCustomOp):
         my_attrs.update(super().get_nodeattr_types())
         return my_attrs
 
-    def make_shape_compatible_op(self):
+    def make_shape_compatible_op(self, model):
         pass
 
     def infer_node_datatype(self, model):
diff --git a/src/finn/custom_op/fpgadataflow/tlastmarker.py b/src/finn/custom_op/fpgadataflow/tlastmarker.py
index c0f599958..4d4dee650 100644
--- a/src/finn/custom_op/fpgadataflow/tlastmarker.py
+++ b/src/finn/custom_op/fpgadataflow/tlastmarker.py
@@ -59,7 +59,7 @@ class TLastMarker(HLSCustomOp):
         i_tensor = context[i_name]
         context[o_name] = i_tensor
 
-    def make_shape_compatible_op(self):
+    def make_shape_compatible_op(self, model):
         # not supported for shape inference
         pass
 
diff --git a/src/finn/custom_op/im2col.py b/src/finn/custom_op/im2col.py
index e2fe918ab..6f425565f 100644
--- a/src/finn/custom_op/im2col.py
+++ b/src/finn/custom_op/im2col.py
@@ -78,7 +78,7 @@ class Im2Col(CustomOp):
             "pad_value": ("i", False, 0),
         }
 
-    def make_shape_compatible_op(self):
+    def make_shape_compatible_op(self, model):
         k = self.get_nodeattr("kernel_size")
         stride = self.get_nodeattr("stride")
         ishape = self.get_nodeattr("input_shape")
diff --git a/src/finn/custom_op/maxpoolnhwc.py b/src/finn/custom_op/maxpoolnhwc.py
index 824b37159..7586a859c 100644
--- a/src/finn/custom_op/maxpoolnhwc.py
+++ b/src/finn/custom_op/maxpoolnhwc.py
@@ -39,7 +39,7 @@ class MaxPoolNHWC(CustomOp):
         # no specific attributes for MaxPoolNHWC
         return {}
 
-    def make_shape_compatible_op(self):
+    def make_shape_compatible_op(self, model):
         raise Exception("MaxPoolNHWC does not yet support shape inference")
 
     def infer_node_datatype(self, model):
diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py
index 56c49e66f..37f8e0950 100644
--- a/src/finn/custom_op/multithreshold.py
+++ b/src/finn/custom_op/multithreshold.py
@@ -109,7 +109,7 @@ class MultiThreshold(CustomOp):
             "data_layout": ("s", False, "NCHW"),
         }
 
-    def make_shape_compatible_op(self):
+    def make_shape_compatible_op(self, model):
         node = self.onnx_node
         return helper.make_node("Relu", [node.input[0]], [node.output[0]])
 
diff --git a/src/finn/custom_op/streamingdataflowpartition.py b/src/finn/custom_op/streamingdataflowpartition.py
index 586537460..b63326d67 100644
--- a/src/finn/custom_op/streamingdataflowpartition.py
+++ b/src/finn/custom_op/streamingdataflowpartition.py
@@ -36,11 +36,9 @@ class StreamingDataflowPartition(CustomOp):
     bitfile by itself."""
 
     def get_nodeattr_types(self):
-        return {
-            "model": ("s", True, ""),
-        }
+        return {"model": ("s", True, "")}
 
-    def make_shape_compatible_op(self):
+    def make_shape_compatible_op(self, model):
         pass
 
     def infer_node_datatype(self, model):
diff --git a/src/finn/custom_op/xnorpopcount.py b/src/finn/custom_op/xnorpopcount.py
index 511a120b1..d15a7315f 100644
--- a/src/finn/custom_op/xnorpopcount.py
+++ b/src/finn/custom_op/xnorpopcount.py
@@ -65,7 +65,7 @@ class XnorPopcountMatMul(CustomOp):
     def get_nodeattr_types(self):
         return {}
 
-    def make_shape_compatible_op(self):
+    def make_shape_compatible_op(self, model):
         node = self.onnx_node
         return helper.make_node(
             "MatMul", [node.input[0], node.input[1]], [node.output[0]]
diff --git a/src/finn/transformation/infer_shapes.py b/src/finn/transformation/infer_shapes.py
index 74a3e62e3..361ef7f6a 100644
--- a/src/finn/transformation/infer_shapes.py
+++ b/src/finn/transformation/infer_shapes.py
@@ -33,7 +33,7 @@ from finn.core.modelwrapper import ModelWrapper
 from finn.transformation import Transformation
 
 
-def _make_shape_compatible_op(node):
+def _make_shape_compatible_op(node, model):
     """Return a shape-compatible non-FINN op for a given FINN op. Used for
     shape inference with custom ops."""
     assert node.domain == "finn", 'Node domain is not set to "finn".'
@@ -41,7 +41,7 @@ def _make_shape_compatible_op(node):
     try:
         # lookup op_type in registry of CustomOps
         inst = registry.custom_op[op_type](node)
-        return inst.make_shape_compatible_op()
+        return inst.make_shape_compatible_op(model)
     except KeyError:
         # exception if op_type is not supported
         raise Exception("Custom op_type %s is currently not supported." % op_type)
@@ -56,7 +56,7 @@ def _hide_finn_ops(model):
     for node in model.graph.node:
         node_ind += 1
         if node.domain == "finn":
-            new_node = _make_shape_compatible_op(node)
+            new_node = _make_shape_compatible_op(node, model)
             hidden_ops[str(new_node)] = node
             model.graph.node.insert(node_ind, new_node)
             model.graph.node.remove(node)
-- 
GitLab
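
The sketch below illustrates how a CustomOp subclass can make use of the
model argument that this patch passes in: it queries the ModelWrapper for
the input shape and returns a standard ONNX Constant node of the expected
output shape, which stock ONNX shape inference can then propagate. The
method body, the NHWC-style output shape, and the tensor name "const" are
illustrative assumptions, not code from this patch.

# Illustrative sketch, not part of the patch above: a possible
# make_shape_compatible_op inside some CustomOp subclass that needs the
# model to determine its output shape.
import numpy as np
from onnx import TensorProto, helper


def make_shape_compatible_op(self, model):
    node = self.onnx_node
    # the ModelWrapper passed in by InferShapes can now be queried,
    # e.g. for the shape of this node's input tensor
    ishape = model.get_tensor_shape(node.input[0])
    # hypothetical op: output is the input permuted from NCHW to NHWC
    oshape = [ishape[0], ishape[2], ishape[3], ishape[1]]
    # return a Constant of the desired output shape; standard ONNX shape
    # inference will then assign this shape to the output tensor
    values = np.zeros(oshape, dtype=np.float32)
    return helper.make_node(
        "Constant",
        inputs=[],
        outputs=[node.output[0]],
        value=helper.make_tensor(
            name="const",
            data_type=TensorProto.FLOAT,
            dims=oshape,
            vals=values.flatten().tolist(),
        ),
    )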