Skip to content
Snippets Groups Projects
Commit 070bd059 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[ShapeInf] use a different approach for custom op shape inference

unfortunately some bugs in onnxruntime made the previous approach
crash with certain test topologies, so we're opting for this
instead.
parent 5c828386
No related branches found
No related tags found
No related merge requests found
......@@ -4,53 +4,50 @@ import onnx.shape_inference as si
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
def infer_shapes(model):
"""Ensure every tensor in the model has a specified shape (ValueInfo)."""
def _make_shape_compatible_op(node):
"""Return a shape-compatible non-FINN op for a given FINN op. Used for
shape inference with custom ops."""
assert(node.domain == "finn")
if node.op_type == "MultiThreshold":
return helper.make_node("ReLU", [node.input[0]], [node.output[0]])
else:
raise Exception("No known shape-compatible op for %s" % node.op_type)
def _hide_finn_ops(model):
"""Replace any FINN ops by shape-compatible ones, and return a dict that
can be used to map the string representations of the new (shape-compatible)
ops back to the old ops."""
hidden_ops = {}
node_ind = 0
for node in model.graph.node:
node_ind += 1
if node.domain == "finn":
new_node = _make_shape_compatible_op(node)
hidden_ops[str(new_node)] = node
model.graph.node.insert(node_ind, new_node)
model.graph.node.remove(node)
return hidden_ops
def _restore_finn_ops(model, hidden_ops):
"""Replace any shape-compatible ops with the FINN ops that originally
generated them."""
node_ind = 0
for node in model.graph.node:
node_ind += 1
try:
old_node = hidden_ops[str(node)]
model.graph.node.insert(node_ind, old_node)
model.graph.node.remove(node)
except KeyError:
pass
# create an empty execution context
execution_context = model.make_empty_exec_context()
# execute node with empty context
oxe.execute_node(node, execution_context, model.graph)
# set the tensor shape for all outputs of the node
for output in node.output:
model.set_tensor_shape(output, execution_context[output].shape)
else:
# onnx shape inference unfortunately does not take single node,
# it can only analyze entire models -- so we create a model which solely
# consists of our current node.
node_inputs = list(
filter(lambda x: x.name in node.input, model.graph.input)
)
node_inputs += list(
filter(lambda x: x.name in node.input, model.graph.value_info)
)
node_outputs = list(
filter(lambda x: x.name in node.output, model.graph.output)
)
node_outputs += list(
filter(lambda x: x.name in node.output, model.graph.value_info)
)
node_graph = helper.make_graph(
nodes=[node],
name="single-node-exec",
inputs=node_inputs,
outputs=node_outputs,
)
node_model = helper.make_model(node_graph)
node_model = si.infer_shapes(node_model)
node_model = ModelWrapper(node_model)
# set the corresponding tensors in the whole model
for output in node.output:
model.set_tensor_shape(output, node_model.get_tensor_shape(output))
# single-step operation, no need to call multiple times so return
# model_was_changed = false
def infer_shapes(model):
"""Ensure every tensor in the model has a specified shape (ValueInfo)."""
# hide your riches!
hidden_ops = _hide_finn_ops(model)
# call regular ONNX shape inference
model = ModelWrapper(si.infer_shapes(model.model))
# bring back hidden ops
_restore_finn_ops(model, hidden_ops)
return (model, False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment