From 4dad79922b3ca6d1e0bf8a1f5c1affeb44dbd360 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Wed, 19 Feb 2020 23:31:41 +0000 Subject: [PATCH] [Test] parametrize npy2aptintstream, test for INT32 --- tests/util/test_data_packing.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/tests/util/test_data_packing.py b/tests/util/test_data_packing.py index 2c175953e..d96490497 100644 --- a/tests/util/test_data_packing.py +++ b/tests/util/test_data_packing.py @@ -2,6 +2,8 @@ import os import shutil import subprocess +import pytest + import numpy as np import finn.util.basic as cutil @@ -15,7 +17,10 @@ from finn.util.data_packing import ( ) -def make_npy2apintstream_testcase(ndarray, dtype): +@pytest.mark.parametrize("dtype", [DataType.BINARY, DataType.INT2, DataType.INT32]) +@pytest.mark.parametrize("test_shape", [(1, 2, 4), (1, 1, 64), (2, 64)]) +def test_npy2apintstream(test_shape, dtype): + ndarray = cutil.gen_finn_dt_tensor(dtype, test_shape) test_dir = cutil.make_build_dir(prefix="test_npy2apintstream_") shape = ndarray.shape elem_bits = dtype.bitwidth() @@ -39,6 +44,7 @@ def make_npy2apintstream_testcase(ndarray, dtype): shape_cpp_str = str(shape).replace("(", "{").replace(")", "}") test_app_string = [] test_app_string += ["#include <cstddef>"] + test_app_string += ["#define AP_INT_MAX_W 4096"] test_app_string += ['#include "ap_int.h"'] test_app_string += ['#include "stdint.h"'] test_app_string += ['#include "hls_stream.h"'] @@ -84,23 +90,6 @@ g++ -o test_npy2apintstream test.cpp /workspace/cnpy/cnpy.cpp \ assert success -test_shapes = [(1, 2, 4), (1, 1, 64), (2, 64)] - - -def test_npy2apintstream_binary(): - for test_shape in test_shapes: - dt = DataType.BINARY - W = cutil.gen_finn_dt_tensor(dt, test_shape) - make_npy2apintstream_testcase(W, dt) - - -def test_npy2apintstream_int2(): - for test_shape in test_shapes: - dt = DataType.INT2 - W = cutil.gen_finn_dt_tensor(dt, test_shape) - make_npy2apintstream_testcase(W, dt) - - def test_array2hexstring(): assert array2hexstring([1, 1, 1, 0], DataType.BINARY, 4) == "0xe" assert array2hexstring([1, 1, 1, 0], DataType.BINARY, 8) == "0x0e" -- GitLab