From e9653644ee0a4eec909d85017637b5c10410e711 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Tue, 24 Mar 2020 10:51:58 +0000 Subject: [PATCH] [ShapeInf] pass model to CustomOp make_shape_compatible_op fxns --- src/finn/custom_op/__init__.py | 2 +- .../custom_op/fpgadataflow/convolutioninputgenerator.py | 2 +- src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py | 2 +- src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py | 2 +- src/finn/custom_op/fpgadataflow/tlastmarker.py | 2 +- src/finn/custom_op/im2col.py | 2 +- src/finn/custom_op/maxpoolnhwc.py | 2 +- src/finn/custom_op/multithreshold.py | 2 +- src/finn/custom_op/streamingdataflowpartition.py | 6 ++---- src/finn/custom_op/xnorpopcount.py | 2 +- src/finn/transformation/infer_shapes.py | 6 +++--- 11 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/finn/custom_op/__init__.py b/src/finn/custom_op/__init__.py index 39de40f1e..ab6e03bee 100644 --- a/src/finn/custom_op/__init__.py +++ b/src/finn/custom_op/__init__.py @@ -94,7 +94,7 @@ class CustomOp(ABC): pass @abstractmethod - def make_shape_compatible_op(self): + def make_shape_compatible_op(self, model): """Returns a standard ONNX op which is compatible with this CustomOp for performing shape inference.""" pass diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py index 14016ce9c..55daff5f7 100644 --- a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py +++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py @@ -58,7 +58,7 @@ class ConvolutionInputGenerator(HLSCustomOp): my_attrs.update(super().get_nodeattr_types()) return my_attrs - def make_shape_compatible_op(self): + def make_shape_compatible_op(self, model): pass def infer_node_datatype(self, model): diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py index 9b4dfe69f..96465b9de 100644 --- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py @@ -98,7 +98,7 @@ class StreamingFCLayer_Batch(HLSCustomOp): pe = self.get_nodeattr("PE") return mh // pe - def make_shape_compatible_op(self): + def make_shape_compatible_op(self, model): pass def infer_node_datatype(self, model): diff --git a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py index 43951332d..c3b3f7dce 100644 --- a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py @@ -41,7 +41,7 @@ class StreamingMaxPool_Batch(HLSCustomOp): my_attrs.update(super().get_nodeattr_types()) return my_attrs - def make_shape_compatible_op(self): + def make_shape_compatible_op(self, model): pass def infer_node_datatype(self, model): diff --git a/src/finn/custom_op/fpgadataflow/tlastmarker.py b/src/finn/custom_op/fpgadataflow/tlastmarker.py index c0f599958..4d4dee650 100644 --- a/src/finn/custom_op/fpgadataflow/tlastmarker.py +++ b/src/finn/custom_op/fpgadataflow/tlastmarker.py @@ -59,7 +59,7 @@ class TLastMarker(HLSCustomOp): i_tensor = context[i_name] context[o_name] = i_tensor - def make_shape_compatible_op(self): + def make_shape_compatible_op(self, model): # not supported for shape inference pass diff --git a/src/finn/custom_op/im2col.py b/src/finn/custom_op/im2col.py index e2fe918ab..6f425565f 100644 --- a/src/finn/custom_op/im2col.py +++ b/src/finn/custom_op/im2col.py @@ -78,7 +78,7 @@ class Im2Col(CustomOp): "pad_value": ("i", False, 0), } - def make_shape_compatible_op(self): + def make_shape_compatible_op(self, model): k = self.get_nodeattr("kernel_size") stride = self.get_nodeattr("stride") ishape = self.get_nodeattr("input_shape") diff --git a/src/finn/custom_op/maxpoolnhwc.py b/src/finn/custom_op/maxpoolnhwc.py index 824b37159..7586a859c 100644 --- a/src/finn/custom_op/maxpoolnhwc.py +++ b/src/finn/custom_op/maxpoolnhwc.py @@ -39,7 +39,7 @@ class MaxPoolNHWC(CustomOp): # no specific attributes for MaxPoolNHWC return {} - def make_shape_compatible_op(self): + def make_shape_compatible_op(self, model): raise Exception("MaxPoolNHWC does not yet support shape inference") def infer_node_datatype(self, model): diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py index 56c49e66f..37f8e0950 100644 --- a/src/finn/custom_op/multithreshold.py +++ b/src/finn/custom_op/multithreshold.py @@ -109,7 +109,7 @@ class MultiThreshold(CustomOp): "data_layout": ("s", False, "NCHW"), } - def make_shape_compatible_op(self): + def make_shape_compatible_op(self, model): node = self.onnx_node return helper.make_node("Relu", [node.input[0]], [node.output[0]]) diff --git a/src/finn/custom_op/streamingdataflowpartition.py b/src/finn/custom_op/streamingdataflowpartition.py index 586537460..b63326d67 100644 --- a/src/finn/custom_op/streamingdataflowpartition.py +++ b/src/finn/custom_op/streamingdataflowpartition.py @@ -36,11 +36,9 @@ class StreamingDataflowPartition(CustomOp): bitfile by itself.""" def get_nodeattr_types(self): - return { - "model": ("s", True, ""), - } + return {"model": ("s", True, "")} - def make_shape_compatible_op(self): + def make_shape_compatible_op(self, model): pass def infer_node_datatype(self, model): diff --git a/src/finn/custom_op/xnorpopcount.py b/src/finn/custom_op/xnorpopcount.py index 511a120b1..d15a7315f 100644 --- a/src/finn/custom_op/xnorpopcount.py +++ b/src/finn/custom_op/xnorpopcount.py @@ -65,7 +65,7 @@ class XnorPopcountMatMul(CustomOp): def get_nodeattr_types(self): return {} - def make_shape_compatible_op(self): + def make_shape_compatible_op(self, model): node = self.onnx_node return helper.make_node( "MatMul", [node.input[0], node.input[1]], [node.output[0]] diff --git a/src/finn/transformation/infer_shapes.py b/src/finn/transformation/infer_shapes.py index 74a3e62e3..361ef7f6a 100644 --- a/src/finn/transformation/infer_shapes.py +++ b/src/finn/transformation/infer_shapes.py @@ -33,7 +33,7 @@ from finn.core.modelwrapper import ModelWrapper from finn.transformation import Transformation -def _make_shape_compatible_op(node): +def _make_shape_compatible_op(node, model): """Return a shape-compatible non-FINN op for a given FINN op. Used for shape inference with custom ops.""" assert node.domain == "finn", 'Node domain is not set to "finn".' @@ -41,7 +41,7 @@ def _make_shape_compatible_op(node): try: # lookup op_type in registry of CustomOps inst = registry.custom_op[op_type](node) - return inst.make_shape_compatible_op() + return inst.make_shape_compatible_op(model) except KeyError: # exception if op_type is not supported raise Exception("Custom op_type %s is currently not supported." % op_type) @@ -56,7 +56,7 @@ def _hide_finn_ops(model): for node in model.graph.node: node_ind += 1 if node.domain == "finn": - new_node = _make_shape_compatible_op(node) + new_node = _make_shape_compatible_op(node, model) hidden_ops[str(new_node)] = node model.graph.node.insert(node_ind, new_node) model.graph.node.remove(node) -- GitLab