Skip to content
Snippets Groups Projects
Commit 3dc8cdf2 authored by auphelia's avatar auphelia
Browse files

[Test] Update test for MoveFlattenPastAffine to use data layout

parent 88d96953
No related branches found
No related tags found
No related merge requests found
......@@ -25,27 +25,40 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
import numpy as np
from onnx import TensorProto, helper
from finn.core.modelwrapper import ModelWrapper
from finn.core.datatype import DataType
import finn.core.data_layout as DataLayout
from finn.util.basic import gen_finn_dt_tensor
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.general import GiveUniqueNodeNames, GiveReadableTensorNames
from finn.transformation.streamline.reorder import MoveFlattenPastAffine
import finn.core.onnx_exec as oxe
# data layout
@pytest.mark.parametrize("data_layout", [DataLayout.NHWC, DataLayout.NCHW])
# batch size
@pytest.mark.parametrize("batch_size", [1, 2])
def test_move_flatten_past_affine(data_layout, batch_size):
if data_layout == DataLayout.NHWC:
ishape = [batch_size, 1, 1, 1024]
oshape = [batch_size, 1000]
else:
ishape = [batch_size, 1024, 1, 1]
oshape = [batch_size, 1000]
def test_move_flatten_past_affine():
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 1, 1, 1024])
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
a0 = helper.make_tensor_value_info("a1", TensorProto.FLOAT, [1024, 1000])
a1 = helper.make_tensor_value_info("a2", TensorProto.FLOAT, [])
a2 = helper.make_tensor_value_info("a3", TensorProto.FLOAT, [1000])
outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, oshape)
outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 1000])
flatten_node = helper.make_node("Flatten", ["inp"], ["flatten_out"])
matmul_node = helper.make_node("MatMul", ["flatten_out", "a0"], ["matmul_out"])
mul_node = helper.make_node("Mul", ["matmul_out", "a1"], ["mul_out"])
......@@ -71,18 +84,23 @@ def test_move_flatten_past_affine():
model.set_initializer("a2", a2_values)
model.set_tensor_datatype("inp", DataType.INT2)
model.set_tensor_datatype("flatten_out", DataType.INT2)
model.set_tensor_layout("inp", data_layout)
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model = model.transform(InferDataLayouts())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
# compare execution before and after transformation
inp_values = gen_finn_dt_tensor(DataType.INT2, [1, 1, 1, 1024])
inp_values = gen_finn_dt_tensor(DataType.INT2, ishape)
idict = {"inp": inp_values}
model_transformed = model.transform(MoveFlattenPastAffine())
assert oxe.compare_execution(model, model_transformed, idict)
# check if nodes have new order in transformed graph
assert model.graph != model_transformed.graph
assert model_transformed.graph.node[-1].op_type == "Flatten"
# depending on data layout check if graph is transformed or not
if data_layout == DataLayout.NHWC:
# check if nodes have new order in transformed graph
assert model.graph != model_transformed.graph
assert model_transformed.graph.node[-1].op_type == "Flatten"
else:
assert model.graph == model_transformed.graph
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment