diff --git a/src/finn/transformation/fold_constants.py b/src/finn/transformation/fold_constants.py
index a951f057ee6b123497c19b3ef2ab020cfd7e7b03..a0d21f1acdeb4810be7bf00cb18045bbc0a59a3b 100644
--- a/src/finn/transformation/fold_constants.py
+++ b/src/finn/transformation/fold_constants.py
@@ -13,7 +13,10 @@ def fold_constants(model):
         node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input))
         node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits))
         node_out = n.output[0]
-        if len(node_inp_dyn) == 0:
+        is_all_constant_inputs = len(node_inp_dyn) == 0
+        ishape = model.get_tensor_shape(n.input[0])
+        is_const_shape = (n.op_type == "Shape") and (ishape is not None)
+        if is_all_constant_inputs or is_const_shape:
             # this node has no dynamic inputs, only constant ones -- so we can
             # do constant folding.
             oxe.execute_node(n, execution_context, graph)