From 01a68e4b9aecdd1be298d33dca18c2caf67fde50 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Sun, 10 Oct 2021 22:50:01 +0200
Subject: [PATCH] [Pack, Test] fix npy<>stream packing for fixed pt, add test

---
 src/finn/qnn-data/cpp/npy2apintstream.hpp | 9 +++++++--
 tests/util/test_data_packing_hls.py       | 3 ++-
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/finn/qnn-data/cpp/npy2apintstream.hpp b/src/finn/qnn-data/cpp/npy2apintstream.hpp
index f3afbc5bf..bb17f11c4 100644
--- a/src/finn/qnn-data/cpp/npy2apintstream.hpp
+++ b/src/finn/qnn-data/cpp/npy2apintstream.hpp
@@ -3,6 +3,7 @@
 #include "hls_stream.h"
 #include "ap_int.h"
 #include <vector>
+#include <stdio.h>
 
 #ifdef DEBUG
 #define DEBUG_NPY2APINTSTREAM(x) std::cout << "[npy2apintstream] " << x << std::endl;
@@ -34,7 +35,7 @@ void npy2apintstream(const char * npy_path, hls::stream<PackedT> & out_stream, b
         NpyT loaded_elem_npyt = *loaded_data;
         ElemT loaded_elem = (ElemT) loaded_elem_npyt;
         DEBUG_NPY2APINTSTREAM("NpyT " << loaded_elem_npyt << " elem " << loaded_elem)
-        packed_elem((i+1)*ElemBits-1, i*ElemBits) = loaded_elem;
+        packed_elem((i+1)*ElemBits-1, i*ElemBits) = *reinterpret_cast<ap_uint<ElemBits>*>(&loaded_elem);
         loaded_data++;
       }
       DEBUG_NPY2APINTSTREAM("packed hls elem " << std::hex << packed_elem << std::dec)
@@ -59,7 +60,11 @@ void apintstream2npy(hls::stream<PackedT> & in_stream, const std::vector<size_t>
       DEBUG_APINTSTREAM2NPY("packed hls elem " << std::hex << packed_elem << std::dec)
       for(size_t ii = 0; ii < inner_dim_elems; ii++) {
         size_t i = reverse_inner ? inner_dim_elems-ii-1 : ii;
-        ElemT elem = packed_elem((i+1)*ElemBits-1, i*ElemBits);
+        ap_uint<ElemBits> tmp_elem = packed_elem((i+1)*ElemBits-1, i*ElemBits);
+        // important: don't init elem = reinterpret_cast.. directly here
+        // this causes weird behavior for conversion to NpyT afterwards
+        ElemT elem;
+        elem = reinterpret_cast<ElemT&>(tmp_elem);
         NpyT npyt = (NpyT) elem;
         DEBUG_APINTSTREAM2NPY("elem " << elem << " NpyT " << npyt)
         data_to_save.push_back(npyt);
diff --git a/tests/util/test_data_packing_hls.py b/tests/util/test_data_packing_hls.py
index 9c47bb293..897d9df96 100644
--- a/tests/util/test_data_packing_hls.py
+++ b/tests/util/test_data_packing_hls.py
@@ -39,7 +39,8 @@ from finn.util.data_packing import numpy_to_hls_code
 
 
 @pytest.mark.parametrize(
-    "dtype", [DataType["BINARY"], DataType["INT2"], DataType["INT32"]]
+    "dtype",
+    [DataType["BINARY"], DataType["INT2"], DataType["INT32"], DataType["FIXED<9,6>"]],
 )
 @pytest.mark.parametrize("test_shape", [(1, 2, 4), (1, 1, 64), (2, 64)])
 @pytest.mark.vivado
-- 
GitLab