From 05f5598b8ad0224762ecfa0ce64f533378fe1308 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 6 Feb 2020 23:56:04 +0100
Subject: [PATCH] [Util] fix padding, handle 1D arrays to
 finnpy_to_packed_bytearray

---
 src/finn/util/data_packing.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/src/finn/util/data_packing.py b/src/finn/util/data_packing.py
index db60e4dba..b0b61c4fd 100644
--- a/src/finn/util/data_packing.py
+++ b/src/finn/util/data_packing.py
@@ -5,6 +5,7 @@ import numpy as np
 from bitstring import BitArray
 
 from finn.core.datatype import DataType
+from finn.util.basic import roundup_to_integer_multiple
 
 
 def array2hexstring(array, dtype, pad_to_nbits, prefix="0x"):
@@ -231,22 +232,35 @@ def finnpy_to_packed_bytearray(ndarray, dtype):
     """Given a numpy ndarray with FINN DataType dtype, pack the innermost
     dimension and return the packed representation as an ndarray of uint8.
     The packed innermost dimension will be padded to the nearest multiple
-    of 8 bits.
+    of 8 bits. The returned ndarray has the same number of dimensions as the
+    input.
     """
-
+    if type(ndarray) != np.ndarray or ndarray.dtype != np.float32:
+        # try to convert to a float numpy array (container dtype is float)
+        ndarray = np.asarray(ndarray, dtype=np.float32)
     # pack innermost dim to hex strings padded to 8 bits
-    packed_hexstring = pack_innermost_dim_as_hex_string(ndarray, dtype, 8)
+    bits = dtype.bitwidth() * ndarray.shape[-1]
+    bits_padded = roundup_to_integer_multiple(bits, 8)
+    packed_hexstring = pack_innermost_dim_as_hex_string(ndarray, dtype, bits_padded)
 
-    # convert hex strings to byte array
     def fn(x):
         return np.asarray(list(map(hexstring2npbytearray, x)))
 
-    return np.apply_along_axis(fn, packed_hexstring.ndim - 1, packed_hexstring)
+    if packed_hexstring.ndim == 0:
+        # scalar, call hexstring2npbytearray directly
+        return hexstring2npbytearray(np.asscalar(packed_hexstring))
+    else:
+        # convert ndarray of hex strings to byte array
+        return np.apply_along_axis(fn, packed_hexstring.ndim - 1, packed_hexstring)
 
 
 def packed_bytearray_to_finnpy(packed_bytearray, dtype):
     """Given a packed numpy uint8 ndarray, unpack it into a FINN array of
     given DataType."""
+    if type(packed_bytearray) != np.ndarray or packed_bytearray.dtype != np.uint8:
+        raise Exception("packed_bytearray_to_finnpy needs NumPy uint8 arrays")
+    if packed_bytearray.ndim == 0:
+        raise Exception("packed_bytearray_to_finnpy expects at least 1D ndarray")
     # TODO how to handle un-padding here?
     packed_dim = packed_bytearray.ndim - 1
     packed_bits = packed_bytearray.shape[packed_dim] * 8
-- 
GitLab