diff --git a/src/finn/analysis/__init__.py b/src/finn/analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18e1efb37e9be82de26d933cccaf64a85bc8ff22 --- /dev/null +++ b/src/finn/analysis/__init__.py @@ -0,0 +1,8 @@ +""" +How to write an analysis pass for FINN +-------------------------------------- + +An analysis pass traverses the graph structure and produces information about +certain properties. The convention is to take in a ModelWrapper, and return +a dictionary of named properties that the analysis extracts. +""" diff --git a/src/finn/analysis/topology.py b/src/finn/analysis/topology.py new file mode 100644 index 0000000000000000000000000000000000000000..fdfad410969cf69249e7fd3a8b436c864c09bfba --- /dev/null +++ b/src/finn/analysis/topology.py @@ -0,0 +1,18 @@ +def is_linear(model): + """Checks whether the given model graph is linear. This is done by looking + at the fan-out of each tensor. All tensors have a fan-out <= 1 in a linear + graph. Returns {"is_linear", Bool}""" + per_tensor_fanouts = get_per_tensor_fanouts(model) + # check for tensors that have fanout > 1 + multi_fanouts = list(filter(lambda x: x[1] > 1, per_tensor_fanouts.items())) + return {"is_linear": len(multi_fanouts) == 0} + + +def get_per_tensor_fanouts(model): + """Returns a dictionary of (tensor_name, tensor_fanout) for the model.""" + # make execution context to get a list of tensors + per_tensor_fanouts = model.make_empty_exec_context() + # replace every tensor with its fanout + for tensor_name in per_tensor_fanouts.keys(): + per_tensor_fanouts[tensor_name] = model.get_tensor_fanout(tensor_name) + return per_tensor_fanouts