diff --git a/src/finn/core/utils.py b/src/finn/core/utils.py index 9333e681fe6e35546e3e513bdca732cd433b1601..f1c9c459f9f4b2a0c600c78191eed6290485ec67 100644 --- a/src/finn/core/utils.py +++ b/src/finn/core/utils.py @@ -46,8 +46,8 @@ def array2hexstring(array, dtype, pad_to_nbits): string. Any BIPOLAR values will be converted to a single bit with a 0 representing -1. - pad_to_bits is used to prepend leading zeros to ensure packed strings of - fixed width. The minimum value for pad_to_bits is 4, since a single hex + pad_to_nbits is used to prepend leading zeros to ensure packed strings of + fixed width. The minimum value for pad_to_nbits is 4, since a single hex digit is four bits. Examples: @@ -84,3 +84,25 @@ def array2hexstring(array, dtype, pad_to_nbits): raise Exception("Number of bits is greater than pad_to_nbits") # represent as hex return lineval.hex + + +def pack_innermost_dim_as_hex_string(ndarray, dtype, pad_to_nbits): + """Pack the innermost dimension of the given numpy ndarray into hex + strings using array2hexstring. Examples: + + A = [[1, 1, 1, 0], [0, 1, 1, 0]] + eA = ["0e", "06"] + pack_innermost_dim_as_hex_string(A, DataType.BINARY, 8) == eA + B = [[[3, 3], [3, 3]], [[1, 3], [3, 1]]] + eB = [[ "0f", "0f"], ["07", "0d"]] + pack_innermost_dim_as_hex_string(B, DataType.UINT2, 8) == eB + """ + + if type(ndarray) != np.ndarray or ndarray.dtype != np.float32: + # try to convert to a float numpy array (container dtype is float) + ndarray = np.asarray(ndarray, dtype=np.float32) + + def fun(x): + return array2hexstring(x, dtype, pad_to_nbits) + + return np.apply_along_axis(fun, ndarray.ndim - 1, ndarray) diff --git a/tests/test_npy2hls.py b/tests/test_npy2hls.py index 39b594c2f08327358727b8912d9e1f8b32806123..e205c06f31e9773a16c3decbd21a314fcbd4b8dc 100644 --- a/tests/test_npy2hls.py +++ b/tests/test_npy2hls.py @@ -1,5 +1,7 @@ +import numpy as np + from finn.core.datatype import DataType -from finn.core.utils import array2hexstring +from finn.core.utils import array2hexstring, pack_innermost_dim_as_hex_string def test_array2hexstring(): @@ -12,3 +14,12 @@ def test_array2hexstring(): assert array2hexstring([1, 1, 1, -1], DataType.INT4, 16) == "111f" assert array2hexstring([-1], DataType.FLOAT32, 32) == "bf800000" assert array2hexstring([17.125], DataType.FLOAT32, 32) == "41890000" + + +def test_pack_innermost_dim_as_hex_string(): + A = [[1, 1, 1, 0], [0, 1, 1, 0]] + eA = np.asarray(["0e", "06"]) + assert (pack_innermost_dim_as_hex_string(A, DataType.BINARY, 8) == eA).all() + B = [[[3, 3], [3, 3]], [[1, 3], [3, 1]]] + eB = np.asarray([["0f", "0f"], ["07", "0d"]]) + assert (pack_innermost_dim_as_hex_string(B, DataType.UINT2, 8) == eB).all()