Skip to content
Snippets Groups Projects
Commit b1a30a7d authored by auphelia's avatar auphelia
Browse files

[Test] Changed verification method to a check if all tensor shapes are set and...

[Test] Changed verification method to a check if all tensor shapes are set and the output of the new node has specific value
parent 39350349
No related branches found
No related tags found
No related merge requests found
from pkgutil import get_data
import numpy as np
import onnx
import onnx.numpy_helper as np_helper
from onnx import TensorProto, helper
import finn.core.onnx_exec as oxe
import finn.transformation.infer_shapes as si
from finn.core.modelwrapper import ModelWrapper
......@@ -15,26 +12,16 @@ def test_infer_shapes():
raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
model = ModelWrapper(raw_m)
graph = model.graph
node_ind = 0
node_dict = {}
for n in graph.node:
node_ind += 1
node_dict[node_ind] = n
# multi-thresholding node to be inserted between the first Relu and MaxPool node
# get Relu node to use data to make a new Relu node and delete the old one
Relu_node = node_dict[4]
# get Relu node to use data
Relu_node = graph.node[3]
assert Relu_node.op_type == "Relu", "The wrong model was chosen for the check"
graph.node.remove(Relu_node)
# create new tensors (thresholds as constant) and add them to the graph info
mt_v0 = helper.make_tensor_value_info("mt_v0", TensorProto.FLOAT, [1, 8, 28, 28])
# create thresholds tensor as constant
mt_thresh0 = helper.make_tensor_value_info("mt_thresh0", TensorProto.FLOAT, [8, 7])
graph.value_info.append(mt_v0)
graph.value_info.append(mt_thresh0)
# random numbers for the thresholds
# thresholds for one channel have to be sorted to guarantee the correct behavior
mt_thresh0_values = np.empty([8, 7], dtype=np.float32)
......@@ -43,22 +30,28 @@ def test_infer_shapes():
model.set_initializer(mt_thresh0.name, mt_thresh0_values)
# create and insert new Relu node and one multi-thresholding node
new_Relu_node = helper.make_node("Relu", [Relu_node.input[0]], ["mt_v0"])
# add multi-thresholding node and change Relu node
mt_node = helper.make_node(
"MultiThreshold", ["mt_v0", "mt_thresh0"], [Relu_node.output[0]], domain="finn"
)
Relu_node.output[0] = "mt_v0"
graph.node.insert(4, new_Relu_node)
graph.node.insert(5, mt_node)
graph.node.insert(4, mt_node)
# test shape inference on mixed model
model = model.transform_single(si.infer_shapes)
# first check routine
# check if at least one shape is not specified
assert not (
model.check_all_tensor_shapes_specified()
), "All tensors are already specified before the shape inference execution"
# execution with input values from mnist-conv model
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
# perform shape inference on mixed model
model = model.transform_single(si.infer_shapes)
# run using FINN-based execution
input_dict = {"Input3": np_helper.to_array(input_tensor)}
oxe.execute_onnx(model, input_dict)
# second check routine
# now all shapes should be specified and mt_node output shape is (1,8,28,28)
assert (
model.check_all_tensor_shapes_specified()
), "There are still tensors that are not specified"
assert (model.get_tensor_shape(mt_node.output[0])) == (
[1, 8, 28, 28]
), "output of multi-thresholding node has wrong shape"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment