From 4dad79922b3ca6d1e0bf8a1f5c1affeb44dbd360 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 19 Feb 2020 23:31:41 +0000
Subject: [PATCH] [Test] parametrize npy2aptintstream, test for INT32

---
 tests/util/test_data_packing.py | 25 +++++++------------------
 1 file changed, 7 insertions(+), 18 deletions(-)

diff --git a/tests/util/test_data_packing.py b/tests/util/test_data_packing.py
index 2c175953e..d96490497 100644
--- a/tests/util/test_data_packing.py
+++ b/tests/util/test_data_packing.py
@@ -2,6 +2,8 @@ import os
 import shutil
 import subprocess
 
+import pytest
+
 import numpy as np
 
 import finn.util.basic as cutil
@@ -15,7 +17,10 @@ from finn.util.data_packing import (
 )
 
 
-def make_npy2apintstream_testcase(ndarray, dtype):
+@pytest.mark.parametrize("dtype", [DataType.BINARY, DataType.INT2, DataType.INT32])
+@pytest.mark.parametrize("test_shape", [(1, 2, 4), (1, 1, 64), (2, 64)])
+def test_npy2apintstream(test_shape, dtype):
+    ndarray = cutil.gen_finn_dt_tensor(dtype, test_shape)
     test_dir = cutil.make_build_dir(prefix="test_npy2apintstream_")
     shape = ndarray.shape
     elem_bits = dtype.bitwidth()
@@ -39,6 +44,7 @@ def make_npy2apintstream_testcase(ndarray, dtype):
     shape_cpp_str = str(shape).replace("(", "{").replace(")", "}")
     test_app_string = []
     test_app_string += ["#include <cstddef>"]
+    test_app_string += ["#define AP_INT_MAX_W 4096"]
     test_app_string += ['#include "ap_int.h"']
     test_app_string += ['#include "stdint.h"']
     test_app_string += ['#include "hls_stream.h"']
@@ -84,23 +90,6 @@ g++ -o test_npy2apintstream test.cpp /workspace/cnpy/cnpy.cpp \
     assert success
 
 
-test_shapes = [(1, 2, 4), (1, 1, 64), (2, 64)]
-
-
-def test_npy2apintstream_binary():
-    for test_shape in test_shapes:
-        dt = DataType.BINARY
-        W = cutil.gen_finn_dt_tensor(dt, test_shape)
-        make_npy2apintstream_testcase(W, dt)
-
-
-def test_npy2apintstream_int2():
-    for test_shape in test_shapes:
-        dt = DataType.INT2
-        W = cutil.gen_finn_dt_tensor(dt, test_shape)
-        make_npy2apintstream_testcase(W, dt)
-
-
 def test_array2hexstring():
     assert array2hexstring([1, 1, 1, 0], DataType.BINARY, 4) == "0xe"
     assert array2hexstring([1, 1, 1, 0], DataType.BINARY, 8) == "0x0e"
-- 
GitLab