From d8ea5e149d546ce2117e754514f0edacfa443e22 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Tue, 26 Nov 2019 18:06:33 +0000
Subject: [PATCH] [Tensor generator] Added option (+test) for int2 data type

---
 src/finn/core/utils.py           |  2 ++
 tests/test_gen_FINN_dt_tensor.py | 13 +++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/src/finn/core/utils.py b/src/finn/core/utils.py
index 6f18fe149..30c9ba031 100644
--- a/src/finn/core/utils.py
+++ b/src/finn/core/utils.py
@@ -184,6 +184,8 @@ def gen_FINN_dt_tensor(FINN_dt, tensor_shape):
         tensor_values = np.random.randint(2, size=tensor_shape)
     elif FINN_dt == DataType.TERNARY:
         tensor_values = np.random.randint(-1, high=1, size=tensor_shape)
+    elif FINN_dt == DataType.INT2:
+        tensor_values = np.random.randint(-2, high=1, size=tensor_shape)
     else:
         raise ValueError("Datatype {} is not supported, no tensor could be generated".format(FINN_dt))
     
diff --git a/tests/test_gen_FINN_dt_tensor.py b/tests/test_gen_FINN_dt_tensor.py
index 0cb31a169..ac2978380 100644
--- a/tests/test_gen_FINN_dt_tensor.py
+++ b/tests/test_gen_FINN_dt_tensor.py
@@ -43,6 +43,19 @@ def test_FINN_tensor_generator():
         assert dt_t.allowed(value), """Data type of generated tensor
             does not match the desired Data type"""
 
+    # int2 
+    shape_int2 = [7,4]
+    dt_int2 = DataType.INT2
+    tensor_int2 = util.gen_FINN_dt_tensor(dt_int2, shape_int2)
+    # test shape
+    for i in range(len(shape_int2)):
+        assert shape_int2[i] == tensor_int2.shape[i], """Shape of generated tensor
+            does not match the desired shape"""
+    # test if elements are FINN datatype
+    for value in tensor_int2.flatten():
+        assert dt_int2.allowed(value), """Data type of generated tensor
+            does not match the desired Data type"""
+
 
     #import pdb; pdb.set_trace()
     
-- 
GitLab