diff --git a/src/finn/transformation/fold_constants.py b/src/finn/transformation/fold_constants.py index a951f057ee6b123497c19b3ef2ab020cfd7e7b03..a0d21f1acdeb4810be7bf00cb18045bbc0a59a3b 100644 --- a/src/finn/transformation/fold_constants.py +++ b/src/finn/transformation/fold_constants.py @@ -13,7 +13,10 @@ def fold_constants(model): node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input)) node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits)) node_out = n.output[0] - if len(node_inp_dyn) == 0: + is_all_constant_inputs = len(node_inp_dyn) == 0 + ishape = model.get_tensor_shape(n.input[0]) + is_const_shape = (n.op_type == "Shape") and (ishape is not None) + if is_all_constant_inputs or is_const_shape: # this node has no dynamic inputs, only constant ones -- so we can # do constant folding. oxe.execute_node(n, execution_context, graph)