From 9052ed2d61dccc261f4b969bb1377c57514bb4c9 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Thu, 24 Oct 2019 11:59:15 +0100 Subject: [PATCH] [Analysis] add is_linear and get_per_tensor_fanouts --- src/finn/analysis/__init__.py | 8 ++++++++ src/finn/analysis/topology.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 src/finn/analysis/__init__.py create mode 100644 src/finn/analysis/topology.py diff --git a/src/finn/analysis/__init__.py b/src/finn/analysis/__init__.py new file mode 100644 index 000000000..18e1efb37 --- /dev/null +++ b/src/finn/analysis/__init__.py @@ -0,0 +1,8 @@ +""" +How to write an analysis pass for FINN +-------------------------------------- + +An analysis pass traverses the graph structure and produces information about +certain properties. The convention is to take in a ModelWrapper, and return +a dictionary of named properties that the analysis extracts. +""" diff --git a/src/finn/analysis/topology.py b/src/finn/analysis/topology.py new file mode 100644 index 000000000..fdfad4109 --- /dev/null +++ b/src/finn/analysis/topology.py @@ -0,0 +1,18 @@ +def is_linear(model): + """Checks whether the given model graph is linear. This is done by looking + at the fan-out of each tensor. All tensors have a fan-out <= 1 in a linear + graph. Returns {"is_linear", Bool}""" + per_tensor_fanouts = get_per_tensor_fanouts(model) + # check for tensors that have fanout > 1 + multi_fanouts = list(filter(lambda x: x[1] > 1, per_tensor_fanouts.items())) + return {"is_linear": len(multi_fanouts) == 0} + + +def get_per_tensor_fanouts(model): + """Returns a dictionary of (tensor_name, tensor_fanout) for the model.""" + # make execution context to get a list of tensors + per_tensor_fanouts = model.make_empty_exec_context() + # replace every tensor with its fanout + for tensor_name in per_tensor_fanouts.keys(): + per_tensor_fanouts[tensor_name] = model.get_tensor_fanout(tensor_name) + return per_tensor_fanouts -- GitLab