diff --git a/src/finn/transformation/infer_shapes.py b/src/finn/transformation/infer_shapes.py index 0043e063e0c2a986e3e8f491b8ea3876348fa384..44ec42049f606cefd4f4b0ef6f585d5591f2bd9f 100644 --- a/src/finn/transformation/infer_shapes.py +++ b/src/finn/transformation/infer_shapes.py @@ -1,18 +1,19 @@ import onnx.helper as helper import onnx.shape_inference as si -import finn.core.onnx_exec as oxe from finn.core.modelwrapper import ModelWrapper + def _make_shape_compatible_op(node): """Return a shape-compatible non-FINN op for a given FINN op. Used for shape inference with custom ops.""" - assert(node.domain == "finn") + assert node.domain == "finn" if node.op_type == "MultiThreshold": - return helper.make_node("ReLU", [node.input[0]], [node.output[0]]) + return helper.make_node("Relu", [node.input[0]], [node.output[0]]) else: raise Exception("No known shape-compatible op for %s" % node.op_type) + def _hide_finn_ops(model): """Replace any FINN ops by shape-compatible ones, and return a dict that can be used to map the string representations of the new (shape-compatible) @@ -28,6 +29,7 @@ def _hide_finn_ops(model): model.graph.node.remove(node) return hidden_ops + def _restore_finn_ops(model, hidden_ops): """Replace any shape-compatible ops with the FINN ops that originally generated them."""