From ec09c481262e864cc19e9b1c582245774137a17a Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Tue, 26 Nov 2019 17:42:33 +0000
Subject: [PATCH] [Test] Added checking for shape and right FINN data type and
 added binary case

---
 src/finn/core/utils.py           |  4 +++-
 tests/test_gen_FINN_dt_tensor.py | 27 ++++++++++++++++++++++++---
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/src/finn/core/utils.py b/src/finn/core/utils.py
index c69bd1b16..2f416c4af 100644
--- a/src/finn/core/utils.py
+++ b/src/finn/core/utils.py
@@ -180,7 +180,9 @@ def gen_FINN_dt_tensor(FINN_dt, tensor_shape):
     if FINN_dt == DataType.BIPOLAR:
         tensor_values = np.random.randint(2, size=tensor_shape)
         tensor_values = 2 * tensor_values - 1
+    elif FINN_dt == DataType.BINARY:
+        tensor_values = np.random.randint(2, size=tensor_shape)
     else:
         raise ValueError("Datatype {} is not supported, no tensor could be generated".format(FINN_dt))
     
-    return[tensor_values]
+    return tensor_values
diff --git a/tests/test_gen_FINN_dt_tensor.py b/tests/test_gen_FINN_dt_tensor.py
index 5f4c89ea7..4321498b5 100644
--- a/tests/test_gen_FINN_dt_tensor.py
+++ b/tests/test_gen_FINN_dt_tensor.py
@@ -1,3 +1,4 @@
+import numpy as np
 import finn.core.utils as util
 from finn.core.datatype import DataType
 
@@ -6,8 +7,28 @@ def test_FINN_tensor_generator():
     shape_bp = [2,2]
     dt_bp = DataType.BIPOLAR
     tensor_bp = util.gen_FINN_dt_tensor(dt_bp, shape_bp)
-    import pdb; pdb.set_trace()
-    print(tensor_bp)
-
     # test shape
+    for i in range(len(shape_bp)):
+        assert shape_bp[i] == tensor_bp.shape[i], """Shape of generated tensor 
+            does not match the desired shape"""
+    # test if elements are FINN datatype
+    for value in tensor_bp.flatten():
+        assert dt_bp.allowed(value), """Data type of generated tensor
+            does not match the desired Data type"""
+    
+    # binary
+    shape_b = [4,2,3]
+    dt_b = DataType.BINARY
+    tensor_b = util.gen_FINN_dt_tensor(dt_b, shape_b)
+    # test shape
+    for i in range(len(shape_b)):
+        assert shape_b[i] == tensor_b.shape[i], """Shape of generated tensor
+            does not match the desired shape"""
     # test if elements are FINN datatype
+    for value in tensor_b.flatten():
+        assert dt_b.allowed(value), """Data type of generated tensor
+            does not match the desired Data type"""
+
+
+    #import pdb; pdb.set_trace()
+    
-- 
GitLab