Bug: Dataset assumes 2D map structure
Short description of expected behaviour:
The implementation of BaseDataset implicitly assumes a 2D map shape. Generalizing it to 3D maps is advisable.
Code Snippet:
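original_shape = tf.shape(tensor)
# Merge the first two (spatial) axes into one and keep all remaining axes.
# This hard-codes a 2D map: a 3D map would keep its third spatial axis unmerged.
flat_tensor = tf.reshape(tensor, tf.concat(((-1,), original_shape[2:]), axis=0))
return_tuple += (flat_tensor,)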
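To make the effect of the hard-coded [2:] concrete, here is a small NumPy sketch that mirrors the reshape above. It assumes, as the issue implies, that the first two axes are the spatial map axes; the function name flatten_spatial_2d is hypothetical and not part of BaseDataset.

```python
import numpy as np


def flatten_spatial_2d(tensor: np.ndarray) -> np.ndarray:
    """Hypothetical mirror of the snippet: merge axes 0 and 1, keep the rest."""
    return tensor.reshape((-1,) + tensor.shape[2:])


# 2D map (H=4, W=5) with 3 values per pixel: spatial axes merge as intended.
map_2d = np.zeros((4, 5, 3))
print(flatten_spatial_2d(map_2d).shape)  # (20, 3)

# 3D map (H=4, W=5, D=6) with 3 values per voxel: the depth axis survives,
# so each flattened entry is a (6, 3) block instead of a 3-vector.
map_3d = np.zeros((4, 5, 6, 3))
print(flatten_spatial_2d(map_3d).shape)  # (20, 6, 3)
```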
Since the number of spatial axes is not fixed for parameter maps, 2D maps will need to be expanded to 3D maps as the general format, so that the flattening can index with original_shape[3:] instead.
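A minimal sketch of the proposed generalization, again in NumPy for illustration. It assumes a 2D map gets a singleton depth axis inserted at position 2, and it uses a hypothetical is_3d flag because the tensor rank alone cannot tell a 2D map with two trailing axes from a 3D map with one (both have rank 4). In TensorFlow the same steps would use tf.expand_dims(tensor, axis=2) followed by the existing tf.reshape with original_shape[3:].

```python
import numpy as np


def to_3d_map(tensor: np.ndarray, is_3d: bool) -> np.ndarray:
    """Hypothetical helper: give a 2D map a singleton depth axis at axis 2."""
    return tensor if is_3d else np.expand_dims(tensor, axis=2)


def flatten_spatial_3d(tensor: np.ndarray) -> np.ndarray:
    """Merge the three spatial axes and keep the per-voxel axes."""
    return tensor.reshape((-1,) + tensor.shape[3:])


# 2D map (4, 5, 3) -> (4, 5, 1, 3) -> (20, 3): same result as before.
map_2d = np.arange(60).reshape(4, 5, 3)
print(flatten_spatial_3d(to_3d_map(map_2d, is_3d=False)).shape)  # (20, 3)

# 3D map (4, 5, 6, 3) -> (120, 3): every voxel becomes one row.
map_3d = np.zeros((4, 5, 6, 3))
print(flatten_spatial_3d(to_3d_map(map_3d, is_3d=True)).shape)  # (120, 3)
```

Inserting a singleton axis does not change the row-major element order, so 2D maps flatten to exactly the same rows as with the current [2:] indexing.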
Error Stack-trace: