diff --git a/src/finn/core/onnx_exec.py b/src/finn/core/onnx_exec.py
index cee9b703a8c85a405551bf8ea219f168a087dd5a..6c2bf6d053d3a24a5e457cd7152964bc5fa3e8d1 100644
--- a/src/finn/core/onnx_exec.py
+++ b/src/finn/core/onnx_exec.py
@@ -52,6 +52,13 @@ def execute_node(node, context, graph):
         sdp_node = getCustomOp(node)
         model = ModelWrapper(sdp_node.get_nodeattr("model"))
         inp_ctx = dict(filter(lambda x: x[0] in node.input, context.items()))
+        # input may have been renamed in partition
+        assert len(inp_ctx) == 1
+        old_iname = node.input[0]
+        new_iname = model.graph.input[0].name
+        if old_iname != new_iname:
+            inp_ctx[new_iname] = inp_ctx[old_iname]
+            del inp_ctx[old_iname]
         ret = execute_onnx(model, inp_ctx, False)
         context.update(ret)
     else: