Skip to content
Snippets Groups Projects
Commit 9052ed2d authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Analysis] add is_linear and get_per_tensor_fanouts

parent 9544e260
No related branches found
No related tags found
No related merge requests found
"""
How to write an analysis pass for FINN
--------------------------------------
An analysis pass traverses the graph structure and produces information about
certain properties. The convention is to take in a ModelWrapper, and return
a dictionary of named properties that the analysis extracts.
"""
def is_linear(model):
"""Checks whether the given model graph is linear. This is done by looking
at the fan-out of each tensor. All tensors have a fan-out <= 1 in a linear
graph. Returns {"is_linear", Bool}"""
per_tensor_fanouts = get_per_tensor_fanouts(model)
# check for tensors that have fanout > 1
multi_fanouts = list(filter(lambda x: x[1] > 1, per_tensor_fanouts.items()))
return {"is_linear": len(multi_fanouts) == 0}
def get_per_tensor_fanouts(model):
"""Returns a dictionary of (tensor_name, tensor_fanout) for the model."""
# make execution context to get a list of tensors
per_tensor_fanouts = model.make_empty_exec_context()
# replace every tensor with its fanout
for tensor_name in per_tensor_fanouts.keys():
per_tensor_fanouts[tensor_name] = model.get_tensor_fanout(tensor_name)
return per_tensor_fanouts
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment