Skip to content
Snippets Groups Projects
Commit 8ed8d8b2 authored by auphelia's avatar auphelia
Browse files

Merge branch 'dev' into feature/adjusting_clk_frequency

parents 8088d406 76a4d609
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,7 @@ BREVITAS_COMMIT=215cf44c76d562339fca368c8c3afee3110033e8
BREVITAS_EXAMPLES_COMMIT=2059f96bd576bf71f32c757e7f92617a70190c90
CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
HLSLIB_COMMIT=6b88db826bb023937506913a23d964775a7606af
PYVERILATOR_COMMIT=fb1afefa5b207acf6fec28f8abb72a862f2ca1d2
PYVERILATOR_COMMIT=1d89cb0d4e0c97469cc6352c611f876ec13edfa6
PYNQSHELL_COMMIT=0c82a61b0ec1a07fa275a14146233824ded7a13d
......
......@@ -79,3 +79,26 @@ def node_inputs_in_expected_order(model):
if n.op_type != "Add":
all_OK = all_OK and (model.get_initializer(n.input[1]) is not None)
return {"node_inputs_in_expected_order": all_OK}
def nodes_topologically_sorted(model):
"""Verifies that graph.node is topologically sorted. This is required by the
ONNX specification.
Returns {"nodes_topologically_sorted": Bool}."""
# get successors of every node and check that
# successor index > current node index
all_OK = True
for n in model.graph.node:
successors = model.find_direct_successors(n)
if successors is not None:
for successor in successors:
# check the condition by checking the antithesis
index_n = model.get_node_index(n)
index_suc = model.get_node_index(successor)
if index_n > index_suc:
all_OK = False
return {"nodes_topologically_sorted": all_OK}
......@@ -288,6 +288,46 @@ class ModelWrapper:
except ValueError:
return None
def find_consumers(self, tensor_name):
"""Finds and returns a list of the nodes that consume tensor with
given name."""
consumers = []
for n in self._model_proto.graph.node:
for inp_tensor in n.input:
if inp_tensor == tensor_name:
consumers.append(n)
if consumers != []:
return consumers
else:
return None
def find_direct_successors(self, node):
"""Finds and returns a list of the nodes that are successors of
given node."""
successors = []
for outp_tensor in node.output:
tensor_consumer_list = self.find_consumers(outp_tensor)
if tensor_consumer_list is not None:
for consumer in tensor_consumer_list:
successors.append(consumer)
if successors != []:
return successors
else:
return None
def find_direct_predecessors(self, node):
"""Finds and returns a list of the nodes that are predecessors of
given node."""
predecessors = []
for inp_tensor in node.input:
producer = self.find_producer(inp_tensor)
if producer is not None:
predecessors.append(producer)
if predecessors != []:
return predecessors
else:
return None
def get_all_tensor_names(self):
"""Returns a list of all (input, output and value_info) tensor names
in the graph."""
......@@ -383,3 +423,14 @@ class ModelWrapper:
def get_non_finn_nodes(self):
"""Returns a list of nodes where domain != 'finn'."""
return list(filter(lambda x: x.domain != "finn", self.graph.node))
def get_node_index(self, node):
"""Returns current index of given node."""
n_ind = 0
try:
for n in self.graph.node:
if n == node:
return n_ind
n_ind += 1
except ValueError:
return None
......@@ -38,6 +38,7 @@ from finn.core.modelwrapper import ModelWrapper
from finn.core.remote_exec import remote_exec
from finn.core.rtlsim_exec import rtlsim_exec
from finn.custom_op.registry import getCustomOp
import finn.analysis.topology as ta
def execute_node(node, context, graph):
......@@ -121,6 +122,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False):
if not model.check_all_tensor_shapes_specified():
raise Exception("Found unspecified tensor shapes, try infer_shapes")
ret = model.analysis(ta.nodes_topologically_sorted)
assert (
ret["nodes_topologically_sorted"] is True
), """Nodes must be
topologically sorted."""
graph = model.graph
# first, we need to make sure that every variable required by the graph has
......
......@@ -26,11 +26,13 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
from pkgutil import get_data
import onnx.helper as oh
from onnx import TensorProto
import brevitas.onnx as bo
from finn.util.test import get_test_model_trained
import finn.analysis.topology as ta
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes
......@@ -88,3 +90,116 @@ def test_node_inputs_in_expected_order():
# this model has an (unnecessary) dynamic reshape for its weight tensor
# and so it fails the check
assert ret["node_inputs_in_expected_order"] is False
def test_nodes_topologically_sorted():
# test analysis pass (nodes_topologically_sorted) with different models
# test with data/onnx/finn-hls-model/tfc_w1_a1_after_conv_to_hls.onnx
raw_m = get_data(
"finn", "data/onnx/finn-hls-model/tfc_w1_a1_after_conv_to_hls.onnx"
)
model = ModelWrapper(raw_m)
ret = model.analysis(ta.nodes_topologically_sorted)
assert ret["nodes_topologically_sorted"] is True
# remove first node and add it at the end
graph = model.graph
first_node = graph.node[0]
graph.node.remove(first_node)
graph.node.append(first_node)
ret = model.analysis(ta.nodes_topologically_sorted)
assert ret["nodes_topologically_sorted"] is False
# test with data/onnx/mnist-conv/model.onnx
raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
model = ModelWrapper(raw_m)
ret = model.analysis(ta.nodes_topologically_sorted)
assert ret["nodes_topologically_sorted"] is True
# remove first node and add it at the end
graph = model.graph
first_node = graph.node[0]
graph.node.remove(first_node)
graph.node.append(first_node)
ret = model.analysis(ta.nodes_topologically_sorted)
assert ret["nodes_topologically_sorted"] is False
# test with manually created small network
Neg_node = oh.make_node("Neg", inputs=["in1"], outputs=["neg1"])
Round_node = oh.make_node("Round", inputs=["neg1"], outputs=["round1"])
Ceil_node = oh.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"])
Add_node = oh.make_node("Add", inputs=["round1", "ceil1"], outputs=["out1"])
in1 = oh.make_tensor_value_info("in1", TensorProto.FLOAT, [4, 4])
out1 = oh.make_tensor_value_info("out1", TensorProto.FLOAT, [4, 4])
graph = oh.make_graph(
nodes=[Neg_node, Round_node, Ceil_node, Add_node],
name="simple_graph",
inputs=[in1],
outputs=[out1],
value_info=[
oh.make_tensor_value_info("neg1", TensorProto.FLOAT, [4, 4]),
oh.make_tensor_value_info("round1", TensorProto.FLOAT, [4, 4]),
oh.make_tensor_value_info("ceil1", TensorProto.FLOAT, [4, 4]),
],
)
onnx_model = oh.make_model(graph, producer_name="simple-model")
model = ModelWrapper(onnx_model)
ret = model.analysis(ta.nodes_topologically_sorted)
assert ret["nodes_topologically_sorted"] is True
# create same graph but with "wrong" node order
graph = oh.make_graph(
nodes=[Round_node, Ceil_node, Neg_node, Add_node],
name="simple_graph",
inputs=[in1],
outputs=[out1],
value_info=[
oh.make_tensor_value_info("neg1", TensorProto.FLOAT, [4, 4]),
oh.make_tensor_value_info("round1", TensorProto.FLOAT, [4, 4]),
oh.make_tensor_value_info("ceil1", TensorProto.FLOAT, [4, 4]),
],
)
onnx_model = oh.make_model(graph, producer_name="simple-model")
model = ModelWrapper(onnx_model)
ret = model.analysis(ta.nodes_topologically_sorted)
assert ret["nodes_topologically_sorted"] is False
# test with data/onnx/finn-hls-model/finn-hls-onnx-model.onnx
raw_m = get_data("finn", "data/onnx/finn-hls-model/finn-hls-onnx-model.onnx")
model = ModelWrapper(raw_m)
ret = model.analysis(ta.nodes_topologically_sorted)
assert ret["nodes_topologically_sorted"] is True
# remove first node and add it at the end
graph = model.graph
first_node = graph.node[0]
graph.node.remove(first_node)
graph.node.append(first_node)
ret = model.analysis(ta.nodes_topologically_sorted)
assert ret["nodes_topologically_sorted"] is False
# test with cnv_w1a1
build_dir = "/tmp/" + os.environ["FINN_INST_NAME"]
cnv = get_test_model_trained("CNV", 1, 1)
bo.export_finn_onnx(
cnv, (1, 3, 32, 32), build_dir + "/end2end_cnv_w1a1_export.onnx"
)
model = ModelWrapper(build_dir + "/end2end_cnv_w1a1_export.onnx")
ret = model.analysis(ta.nodes_topologically_sorted)
assert ret["nodes_topologically_sorted"] is True
# remove first node and add it at the end
graph = model.graph
first_node = graph.node[0]
graph.node.remove(first_node)
graph.node.append(first_node)
ret = model.analysis(ta.nodes_topologically_sorted)
assert ret["nodes_topologically_sorted"] is False
import os
import numpy as np
import torch
import brevitas.onnx as bo
......@@ -41,3 +42,4 @@ def test_brevitas_act_export(abits, narrow_range, max_val):
inp_tensor = torch.from_numpy(inp_tensor).float()
expected = b_act.forward(inp_tensor).detach().numpy()
assert np.isclose(produced, expected, atol=1e-3).all()
os.remove(export_onnx_path)
......@@ -27,8 +27,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import onnx
from collections import Counter
import brevitas.onnx as bo
import numpy as np
......@@ -68,3 +68,56 @@ def test_modelwrapper():
out_prod = model.find_producer(l0_inp_tensor_name)
assert out_prod.op_type == "Sign"
os.remove(export_onnx_path)
def test_modelwrapper_graph_order():
# create small network with properties to be tested
Neg_node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["neg1"],)
Round_node = onnx.helper.make_node("Round", inputs=["neg1"], outputs=["round1"],)
Ceil_node = onnx.helper.make_node("Ceil", inputs=["neg1"], outputs=["ceil1"],)
Add_node = onnx.helper.make_node(
"Add", inputs=["round1", "ceil1"], outputs=["out1"],
)
in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
graph = onnx.helper.make_graph(
nodes=[Neg_node, Round_node, Ceil_node, Add_node],
name="simple_graph",
inputs=[in1],
outputs=[out1],
value_info=[
onnx.helper.make_tensor_value_info("neg1", onnx.TensorProto.FLOAT, [4, 4]),
onnx.helper.make_tensor_value_info(
"round1", onnx.TensorProto.FLOAT, [4, 4]
),
onnx.helper.make_tensor_value_info("ceil1", onnx.TensorProto.FLOAT, [4, 4]),
],
)
onnx_model = onnx.helper.make_model(graph, producer_name="simple-model")
model = ModelWrapper(onnx_model)
# test graph order functions
assert model.find_consumers("in1") == [Neg_node]
assert model.find_consumers("neg1") == [Round_node, Ceil_node]
assert model.find_consumers("round1") == [Add_node]
assert model.find_consumers("ceil1") == [Add_node]
assert model.find_consumers("out1") is None
assert model.find_direct_successors(Neg_node) == [Round_node, Ceil_node]
assert model.find_direct_successors(Round_node) == [Add_node]
assert model.find_direct_successors(Ceil_node) == [Add_node]
assert model.find_direct_successors(Add_node) is None
assert model.find_direct_predecessors(Neg_node) is None
assert model.find_direct_predecessors(Round_node) == [Neg_node]
assert model.find_direct_predecessors(Ceil_node) == [Neg_node]
assert model.find_direct_predecessors(Add_node) == [Round_node, Ceil_node]
assert model.get_node_index(Neg_node) == 0
assert model.get_node_index(Round_node) == 1
assert model.get_node_index(Ceil_node) == 2
assert model.get_node_index(Add_node) == 3
import os
import onnx
from finn.util.test import get_test_model_trained
import brevitas.onnx as bo
......@@ -56,3 +57,4 @@ def test_topk_insert(k):
output_pysim_topk = output_pysim_topk.astype(np.int).flatten()
assert np.array_equal(output_golden_topk, output_pysim_topk)
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment