Skip to content
Snippets Groups Projects
Commit f5fbd0bc authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Refactor] all CustomOps take an onnx_node in constructor

parent 894039f0
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@ def execute_custom_node(node, context, graph):
op_type = node.op_type
try:
# lookup op_type in registry of CustomOps
inst = registry.custom_op[op_type]()
inst = registry.custom_op[op_type](node)
inst.execute_node(node, context, graph)
except KeyError:
# exception if op_type is not supported
......
......@@ -2,8 +2,10 @@ from abc import ABC, abstractmethod
class CustomOp(ABC):
def __init__(self):
def __init__(self, onnx_node):
super().__init__()
self.onnx_node = onnx_node
# TODO consider specifying a list of allowed attributes
@abstractmethod
def make_shape_compatible_op(self, node):
......
......@@ -4,8 +4,8 @@ from finn.custom_op import CustomOp
class HLSCustomOp(CustomOp):
def __init__(self):
super().__init__()
def __init__(self, onnx_node):
super().__init__(onnx_node)
# template for single node execution
self.docompute_template = """
#include "cnpy.h"
......@@ -35,7 +35,7 @@ class HLSCustomOp(CustomOp):
"""
self.code_gen_dict = {}
self.tmp_dir = " "
self.tmp_dir = " "
@abstractmethod
def get_attributes(self, node):
......
......@@ -11,8 +11,8 @@ from finn.custom_op.fpgadataflow import HLSCustomOp
class StreamingFCLayer_Batch(HLSCustomOp):
def __init__(self):
super().__init__()
def __init__(self, onnx_node):
super().__init__(onnx_node)
self.WMEM = 0
self.TMEM = 0
......
......@@ -13,7 +13,7 @@ def _infer_node_datatype(model, node):
# handle DataType inference for CustomOp
try:
# lookup op_type in registry of CustomOps
inst = registry.custom_op[op_type]()
inst = registry.custom_op[op_type](node)
inst.infer_node_datatype(node, model)
except KeyError:
# exception if op_type is not supported
......
......@@ -5,7 +5,6 @@ from finn.core.modelwrapper import ModelWrapper
from finn.transformation import Transformation
def _make_shape_compatible_op(node):
"""Return a shape-compatible non-FINN op for a given FINN op. Used for
shape inference with custom ops."""
......@@ -13,7 +12,7 @@ def _make_shape_compatible_op(node):
op_type = node.op_type
try:
# lookup op_type in registry of CustomOps
inst = registry.custom_op[op_type]()
inst = registry.custom_op[op_type](node)
return inst.make_shape_compatible_op(node)
except KeyError:
# exception if op_type is not supported
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment