diff --git a/src/finn/core/utils.py b/src/finn/core/utils.py index 6f18fe14903a0682cda75320f9ce6662f279d835..30c9ba031b5542857a83a28193eb8e582fd790a8 100644 --- a/src/finn/core/utils.py +++ b/src/finn/core/utils.py @@ -184,6 +184,8 @@ def gen_FINN_dt_tensor(FINN_dt, tensor_shape): tensor_values = np.random.randint(2, size=tensor_shape) elif FINN_dt == DataType.TERNARY: tensor_values = np.random.randint(-1, high=1, size=tensor_shape) + elif FINN_dt == DataType.INT2: + tensor_values = np.random.randint(-2, high=1, size=tensor_shape) else: raise ValueError("Datatype {} is not supported, no tensor could be generated".format(FINN_dt)) diff --git a/tests/test_gen_FINN_dt_tensor.py b/tests/test_gen_FINN_dt_tensor.py index 0cb31a1698ba7774f69d0dcd1d9110421c5d252d..ac29783807533adbb9aff7af1cc26833ef5aec25 100644 --- a/tests/test_gen_FINN_dt_tensor.py +++ b/tests/test_gen_FINN_dt_tensor.py @@ -43,6 +43,19 @@ def test_FINN_tensor_generator(): assert dt_t.allowed(value), """Data type of generated tensor does not match the desired Data type""" + # int2 + shape_int2 = [7,4] + dt_int2 = DataType.INT2 + tensor_int2 = util.gen_FINN_dt_tensor(dt_int2, shape_int2) + # test shape + for i in range(len(shape_int2)): + assert shape_int2[i] == tensor_int2.shape[i], """Shape of generated tensor + does not match the desired shape""" + # test if elements are FINN datatype + for value in tensor_int2.flatten(): + assert dt_int2.allowed(value), """Data type of generated tensor + does not match the desired Data type""" + #import pdb; pdb.set_trace()