Skip to content
Snippets Groups Projects
Commit e3d8a5f4 authored by auphelia's avatar auphelia
Browse files

[Test] Extend change_datalayout test to check layout annotation after InferDataLayout transform

parent 5905f974
No related branches found
No related tags found
No related merge requests found
......@@ -31,9 +31,11 @@ from onnx import helper, TensorProto
from finn.custom_op.maxpoolnhwc import compute_pool_output_dim
from finn.core.modelwrapper import ModelWrapper
from finn.core.datatype import DataType
import finn.core.data_layout as DataLayout
from finn.transformation.change_datalayout import ChangeDataLayoutQuantAvgPool2d
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
from finn.util.basic import gen_finn_dt_tensor
from finn.util.basic import get_by_name
......@@ -87,11 +89,13 @@ def test_change_datalayout_quantavgpool(s, k, ibits, obits, signed, c, idim):
model = ModelWrapper(model)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model = model.transform(InferDataLayouts())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model_transformed = model.transform(ChangeDataLayoutQuantAvgPool2d())
model_transformed = model_transformed.transform(InferShapes())
model_transformed = model_transformed.transform(InferDataTypes())
model_transformed = model_transformed.transform(InferDataLayouts())
model_transformed = model_transformed.transform(GiveUniqueNodeNames())
model_transformed = model_transformed.transform(GiveReadableTensorNames())
inp_values = gen_finn_dt_tensor(dtype, [n, c, idim, idim])
......@@ -100,6 +104,9 @@ def test_change_datalayout_quantavgpool(s, k, ibits, obits, signed, c, idim):
assert len(model.graph.node) + 2 == len(model_transformed.graph.node)
assert model_transformed.graph.node[-1].op_type == "Transpose"
assert model_transformed.graph.node[0].op_type == "Transpose"
quant_node = model_transformed.graph.node[1]
d_layout = get_by_name(quant_node.attribute, "data_layout").s.decode("UTF-8")
# check if QuantAvgPool2d node has datalayout set correctly
node = model_transformed.graph.node[1]
d_layout = get_by_name(node.attribute, "data_layout").s.decode("UTF-8")
assert d_layout == "NHWC"
assert model_transformed.get_tensor_layout(node.input[0]) == DataLayout.NHWC
assert model_transformed.get_tensor_layout(node.output[0]) == DataLayout.NHWC
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment