From 0d544550917b355cc51044dc8c45e2641909f979 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 20 Feb 2020 00:16:53 +0000
Subject: [PATCH] [Util] add inner dim reversal option to packing fxns and
 tests

---
 src/finn/util/data_packing.py   | 18 ++++++++++++------
 tests/util/test_data_packing.py |  7 +++++++
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/src/finn/util/data_packing.py b/src/finn/util/data_packing.py
index a6664045d..9f4367c78 100644
--- a/src/finn/util/data_packing.py
+++ b/src/finn/util/data_packing.py
@@ -9,7 +9,7 @@ from finn.core.datatype import DataType
 from finn.util.basic import roundup_to_integer_multiple
 
 
-def array2hexstring(array, dtype, pad_to_nbits, prefix="0x"):
+def array2hexstring(array, dtype, pad_to_nbits, prefix="0x", reverse=False):
     """
     Pack given one-dimensional NumPy array with FINN DataType dtype into a hex
     string.
@@ -17,11 +17,14 @@ def array2hexstring(array, dtype, pad_to_nbits, prefix="0x"):
     -1.
     pad_to_nbits is used to prepend leading zeros to ensure packed strings of
     fixed width. The minimum value for pad_to_nbits is 4, since a single hex
-    digit is four bits.
+    digit is four bits. reverse can be used to reverse the array prior to
+    packing.
 
     Examples:
-    array2hexstring([1, 1, 1, 0], DataType.BINARY, 4) = "e"
-    array2hexstring([1, 1, 1, 0], DataType.BINARY, 8) = "0e"
+    array2hexstring([1, 1, 1, 0], DataType.BINARY, 4) = "0xe"
+    array2hexstring([1, 1, 1, 0], DataType.BINARY, 8) = "0x0e"
+    array2hexstring([1, 1, 0, 1], DataType.BINARY, 4, reverse=True) = "0xb"
+    array2hexstring([1, 1, 1, 0], DataType.BINARY, 8, reverse=True) = "0x07"
     """
     if pad_to_nbits < 4:
         pad_to_nbits = 4
@@ -35,6 +38,9 @@ def array2hexstring(array, dtype, pad_to_nbits, prefix="0x"):
         # convert bipolar values to binary
         array = (array + 1) / 2
         dtype = DataType.BINARY
+    # reverse prior to packing, if desired
+    if reverse:
+        array = np.flip(array, -1)
     lineval = BitArray(length=0)
     bw = dtype.bitwidth()
     for val in array:
@@ -77,7 +83,7 @@ def npbytearray2hexstring(npbytearray, prefix="0x"):
     return prefix + binascii.hexlify(bytearray(npbytearray)).decode("utf-8")
 
 
-def pack_innermost_dim_as_hex_string(ndarray, dtype, pad_to_nbits):
+def pack_innermost_dim_as_hex_string(ndarray, dtype, pad_to_nbits, reverse_inner=False):
     """Pack the innermost dimension of the given numpy ndarray into hex
     strings using array2hexstring. Examples:
 
@@ -94,7 +100,7 @@ def pack_innermost_dim_as_hex_string(ndarray, dtype, pad_to_nbits):
         ndarray = np.asarray(ndarray, dtype=np.float32)
 
     def fun(x):
-        return array2hexstring(x, dtype, pad_to_nbits)
+        return array2hexstring(x, dtype, pad_to_nbits, reverse=reverse_inner)
 
     return np.apply_along_axis(fun, ndarray.ndim - 1, ndarray)
 
diff --git a/tests/util/test_data_packing.py b/tests/util/test_data_packing.py
index d96490497..20db7af49 100644
--- a/tests/util/test_data_packing.py
+++ b/tests/util/test_data_packing.py
@@ -100,6 +100,8 @@ def test_array2hexstring():
     assert array2hexstring([1, 1, 1, -1], DataType.INT4, 16) == "0x111f"
     assert array2hexstring([-1], DataType.FLOAT32, 32) == "0xbf800000"
     assert array2hexstring([17.125], DataType.FLOAT32, 32) == "0x41890000"
+    assert array2hexstring([1, 1, 0, 1], DataType.BINARY, 4, reverse=True) == "0xb"
+    assert array2hexstring([1, 1, 1, 0], DataType.BINARY, 8, reverse=True) == "0x07"
 
 
 def test_pack_innermost_dim_as_hex_string():
@@ -109,6 +111,11 @@ def test_pack_innermost_dim_as_hex_string():
     B = [[[3, 3], [3, 3]], [[1, 3], [3, 1]]]
     eB = np.asarray([["0x0f", "0x0f"], ["0x07", "0x0d"]])
     assert (pack_innermost_dim_as_hex_string(B, DataType.UINT2, 8) == eB).all()
+    C = [[[3, 3], [3, 3]], [[1, 3], [3, 1]]]
+    eC = np.asarray([["0x0f", "0x0f"], ["0x0d", "0x07"]])
+    assert (
+        pack_innermost_dim_as_hex_string(C, DataType.UINT2, 8, reverse_inner=True) == eC
+    ).all()
 
 
 def test_numpy_to_hls_code():
-- 
GitLab