diff --git a/src/finn/analysis/topology.py b/src/finn/analysis/topology.py index fdfad410969cf69249e7fd3a8b436c864c09bfba..afb7895c30bf091c8a97ac40ab8cd60025ba6c4c 100644 --- a/src/finn/analysis/topology.py +++ b/src/finn/analysis/topology.py @@ -1,3 +1,6 @@ +import numpy as np + + def is_linear(model): """Checks whether the given model graph is linear. This is done by looking at the fan-out of each tensor. All tensors have a fan-out <= 1 in a linear @@ -16,3 +19,11 @@ def get_per_tensor_fanouts(model): for tensor_name in per_tensor_fanouts.keys(): per_tensor_fanouts[tensor_name] = model.get_tensor_fanout(tensor_name) return per_tensor_fanouts + + +def all_tensors_f32(model): + """Checks whether all tensors have a float32 dtype, extra quantization + annotations notwithstanding.""" + all_tensors = model.make_empty_exec_context().items() + non_f32_tensors = filter(lambda x: x[1].dtype != np.float32, all_tensors) + return {"all_tensors_f32": len(list(non_f32_tensors)) == 0}