Skip to content
Snippets Groups Projects
Commit 024003b4 authored by auphelia's avatar auphelia
Browse files

[Test] Extend test lowering of dw conv

parent f9443cb1
No related branches found
No related tags found
No related merge requests found
......@@ -26,6 +26,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
import os
from onnx import helper, TensorProto
import pkg_resources as pk
......@@ -72,14 +73,22 @@ def test_conv_lowering_cnv_w1a1():
os.remove(export_onnx_path)
def test_depthwise_conv_lowering():
idt = odt = wdt = DataType.INT4
k = 3
stride = 1
ifm_dim = 4
ifm_ch = 2
# input datatype
@pytest.mark.parametrize("idt", [DataType.INT2, DataType.INT4])
# kernel size
@pytest.mark.parametrize("k", [2, 4])
# input dimension
@pytest.mark.parametrize("ifm_dim", [4, 6])
# input channels
@pytest.mark.parametrize("ifm_ch", [2, 3])
# stride
@pytest.mark.parametrize("stride", [1, 2])
# padding
@pytest.mark.parametrize("padding", [[0, 0, 0, 0], [1, 1, 1, 1]])
def test_depthwise_conv_lowering(idt, k, ifm_dim, ifm_ch, stride, padding):
odt = wdt = idt
ofm_ch = ifm_ch
ofm_dim = compute_conv_output_dim(ifm_dim, k, stride)
ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad=padding[0])
# set up onnx model
inp = helper.make_tensor_value_info(
......@@ -96,7 +105,7 @@ def test_depthwise_conv_lowering():
inputs=["inp", "W"],
outputs=["outp"],
kernel_shape=[k, k],
pads=[0, 0, 0, 0],
pads=padding,
strides=[stride, stride],
group=ifm_ch,
)
......@@ -115,6 +124,7 @@ def test_depthwise_conv_lowering():
model.set_tensor_datatype("W", wdt)
w_tensor = gen_finn_dt_tensor(wdt, [ofm_ch, 1, k, k])
model.set_initializer("W", w_tensor)
model = model.transform(InferShapes())
input_tensor = gen_finn_dt_tensor(idt, [1, ifm_ch, ifm_dim, ifm_dim])
input_dict = {"inp": input_tensor}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment