Skip to content
Snippets Groups Projects
Commit 20b11dfe authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[ShapeInf] correct ReLU op_type typo in MultiThreshold shape inf

parent 7aadf843
No related branches found
No related tags found
No related merge requests found
import onnx.helper as helper
import onnx.shape_inference as si
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
def _make_shape_compatible_op(node):
"""Return a shape-compatible non-FINN op for a given FINN op. Used for
shape inference with custom ops."""
assert(node.domain == "finn")
assert node.domain == "finn"
if node.op_type == "MultiThreshold":
return helper.make_node("ReLU", [node.input[0]], [node.output[0]])
return helper.make_node("Relu", [node.input[0]], [node.output[0]])
else:
raise Exception("No known shape-compatible op for %s" % node.op_type)
def _hide_finn_ops(model):
"""Replace any FINN ops by shape-compatible ones, and return a dict that
can be used to map the string representations of the new (shape-compatible)
......@@ -28,6 +29,7 @@ def _hide_finn_ops(model):
model.graph.node.remove(node)
return hidden_ops
def _restore_finn_ops(model, hidden_ops):
"""Replace any shape-compatible ops with the FINN ops that originally
generated them."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment