Skip to content
Snippets Groups Projects
Commit 8e8d19e9 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Util] add interleave_matrix_outer_dim_from_partitions

parent 0a894588
No related branches found
No related tags found
No related merge requests found
......@@ -107,3 +107,19 @@ def pack_innermost_dim_as_hex_string(ndarray, dtype, pad_to_nbits):
return array2hexstring(x, dtype, pad_to_nbits)
return np.apply_along_axis(fun, ndarray.ndim - 1, ndarray)
def interleave_matrix_outer_dim_from_partitions(matrix, n_partitions):
if type(matrix) != np.ndarray or matrix.dtype != np.float32:
# try to convert to a float numpy array (container dtype is float)
matrix = np.asarray(matrix, dtype=np.float32)
shp = matrix.shape
ndim = matrix.ndim
# ensure # partitions evenly divide the outermost dimension
assert shp[0] % n_partitions == 0
# only tested for matrices
assert ndim == 2
# interleave rows between PEs using reshape + transpose
matrix_r = matrix.reshape(-1, n_partitions, shp[1]).transpose((1, 0, 2))
matrix_r = matrix_r.reshape(n_partitions, -1, shp[1])
return matrix_r
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment