diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py index ec5d31f637bdc89782ce3c8f76bec84ac9df4354..44a44ae1239afea2e5d80dbf8479228acc61ff5a 100644 --- a/src/finn/util/basic.py +++ b/src/finn/util/basic.py @@ -139,6 +139,8 @@ def pad_tensor_to_multiple_of(ndarray, pad_to_dims, val=0, distr_pad=False): def gen_finn_dt_tensor(finn_dt, tensor_shape): """Generates random tensor in given shape and with given FINN DataType""" + if type(tensor_shape) == list: + tensor_shape = tuple(tensor_shape) if finn_dt == DataType.BIPOLAR: tensor_values = np.random.randint(2, size=tensor_shape) tensor_values = 2 * tensor_values - 1