diff --git a/src/finn/core/utils.py b/src/finn/core/utils.py index 05e3adf1f0c2643c21a1185c7a3c681ba28fd845..6c7a1c82a6b9ae16c520b74f072ee2bdf69b81e1 100644 --- a/src/finn/core/utils.py +++ b/src/finn/core/utils.py @@ -183,11 +183,9 @@ def gen_finn_dt_tensor(finn_dt, tensor_shape): tensor_values = 2 * tensor_values - 1 elif finn_dt == DataType.BINARY: tensor_values = np.random.randint(2, size=tensor_shape) - elif finn_dt == DataType.TERNARY: - tensor_values = np.random.randint(-1, high=1, size=tensor_shape) - elif "INT" in finn_dt.name: + elif "INT" in finn_dt.name or finn_dt == DataType.TERNARY: tensor_values = np.random.randint( - finn_dt.min(), high=finn_dt.max(), size=tensor_shape + finn_dt.min(), high=finn_dt.max() + 1, size=tensor_shape ) else: raise ValueError(