diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py index dc1a05a018a2767767cf2f2b811d763539df0841..bd1605c96d8bb7a0687686b48912bd3b61a8f6cc 100644 --- a/src/finn/custom_op/registry.py +++ b/src/finn/custom_op/registry.py @@ -1,7 +1,7 @@ # make sure new CustomOp subclasses are imported here so that they get # registered and plug in correctly into the infrastructure from finn.custom_op.fpgadataflow.convolutioninputgenerator import ( - ConvolutionInputGenerator + ConvolutionInputGenerator, ) from finn.custom_op.fpgadataflow.streamingfclayer_batch import StreamingFCLayer_Batch from finn.custom_op.fpgadataflow.streamingmaxpool_batch import StreamingMaxPool_Batch @@ -18,3 +18,15 @@ custom_op["StreamingMaxPool_Batch"] = StreamingMaxPool_Batch custom_op["StreamingFCLayer_Batch"] = StreamingFCLayer_Batch custom_op["ConvolutionInputGenerator"] = ConvolutionInputGenerator custom_op["TLastMarker"] = TLastMarker + + +def getCustomOp(node): + "Return a FINN CustomOp instance for the given ONNX node, if it exists." + op_type = node.op_type + try: + # lookup op_type in registry of CustomOps + inst = custom_op[op_type](node) + return inst + except KeyError: + # exception if op_type is not supported + raise Exception("Custom op_type %s is currently not supported." % op_type)