From 5a6c8a48e5b04e675a19b843c06f97c7fb032ad1 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Sun, 3 Nov 2019 17:49:14 +0000 Subject: [PATCH] [Test] add test_const_folding_shapes --- tests/test_fold_constants.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_fold_constants.py b/tests/test_fold_constants.py index ed6bb3815..894887b0c 100644 --- a/tests/test_fold_constants.py +++ b/tests/test_fold_constants.py @@ -1,14 +1,19 @@ +import os from pkgutil import get_data +import brevitas.onnx as bo import numpy as np import onnx import onnx.numpy_helper as np_helper +from models.LFC import LFC import finn.core.onnx_exec as oxe import finn.transformation.fold_constants as fc import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +export_onnx_path = "test_output_lfc.onnx" + def test_const_folding(): raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") @@ -24,3 +29,15 @@ def test_const_folding(): assert np.isclose( np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3 ).all() + + +def test_const_folding_shapes(): + lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) + bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform_single(si.infer_shapes) + model = model.transform_repeated(fc.fold_constants) + assert model.graph.node[0].op_type == "Reshape" + assert list(model.get_tensor_shape("0")) == [1, 1, 28, 28] + assert list(model.get_tensor_shape("27")) == [1, 784] + os.remove(export_onnx_path) -- GitLab