Skip to content
Snippets Groups Projects
Commit 64e0658d authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Util] fixes and doc for numpy_to_hls_code

parent bd57c1de
No related branches found
No related tags found
No related merge requests found
import numpy as np
import sys
from finn.core.utils import pack_innermost_dim_as_hex_string
import numpy as np
from finn.core.datatype import DataType
from finn.core.utils import pack_innermost_dim_as_hex_string
def numpy_to_hls_code(ndarray, dtype, hls_var_name, pack_innermost_dim=True):
"Return C++ code representation of a numpy ndarray."
"""Return C++ code representation of a numpy ndarray with FINN DataType
dtype, using hls_var_name as the resulting C++ variable name. If
pack_innermost_dim is specified, the innermost dimension of the ndarray
will be packed into a hex string using array2hexstring.
"""
hls_dtype = dtype.get_hls_datatype_str()
if pack_innermost_dim:
idimlen = ndarray.shape[-1]
ndarray = pack_innermost_dim_as_hex_string(ndarray, dtype, idimlen)
hls_dtype = "ap_uint<%d>" % (idimlen * dtype.bitwidth())
if type(ndarray) != np.ndarray or ndarray.dtype != np.float32:
# try to convert to a float numpy array (container dtype is float)
ndarray = np.asarray(ndarray, dtype=np.float32)
if pack_innermost_dim:
idimlen = ndarray.shape[-1]
idimbits = idimlen * dtype.bitwidth()
ndarray = pack_innermost_dim_as_hex_string(ndarray, dtype, idimbits)
hls_dtype = "ap_uint<%d>" % idimbits
ndims = ndarray.ndim
# add type string and variable name
# e.g. "const ap_uint<64>" "weightMem0"
......@@ -22,11 +30,12 @@ def numpy_to_hls_code(ndarray, dtype, hls_var_name, pack_innermost_dim=True):
ret += "[%d]" % ndarray.shape[d]
orig_printops = np.get_printoptions()
np.set_printoptions(threshold=sys.maxsize)
# define a function to convert a single element into a C++ init string
# a single element can be a hex string if we are using packing
def elem2str(x):
if type(x) == str:
return "%s(%s, 16)" % (hls_dtype, x)
if type(x) == str or type(x) == np.str_ or type(x) == np.str:
return '%s("%s", 16)' % (hls_dtype, x)
elif type(x) == np.float32:
if dtype == DataType.FLOAT32:
return str(x)
......@@ -34,7 +43,8 @@ def numpy_to_hls_code(ndarray, dtype, hls_var_name, pack_innermost_dim=True):
return str(int(x))
else:
raise Exception("Unsupported type for numpy_to_hls_code")
strarr = np.array2string(ndarray, separator=", ", formatter={'all': elem2str})
strarr = np.array2string(ndarray, separator=", ", formatter={"all": elem2str})
np.set_printoptions(**orig_printops)
strarr = strarr.replace("[", "{").replace("]", "}")
ret = ret + " = \n" + strarr + ";"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment