Skip to content
Snippets Groups Projects
Commit 7f9c0dc6 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[StreamingFC] add i/w/o dtype attributes, infer HLS dtypes

parent bdc91582
No related branches found
No related tags found
No related merge requests found
......@@ -23,6 +23,10 @@ class StreamingFCLayer_Batch(HLSCustomOp):
"MH": ("i", True, 0),
"resType": ("s", True, ""),
"resDataType": ("s", True, ""),
# FINN DataTypes for inputs, weights, outputs
"inputDataType": ("s", True, ""),
"weightDataType": ("s", True, ""),
"outputDataType": ("s", True, ""),
}
def make_shape_compatible_op(self):
......@@ -31,6 +35,48 @@ class StreamingFCLayer_Batch(HLSCustomOp):
def infer_node_datatype(self, model):
pass
def get_input_datatype(self):
return DataType[self.get_nodeattr("inputDataType")]
def get_weight_datatype(self):
return DataType[self.get_nodeattr("weightDataType")]
def get_output_datatype(self):
return DataType[self.get_nodeattr("outputDataType")]
def get_template_param_values(self):
ret = dict()
inp_hls_str = self.get_input_datatype().get_hls_datatype_str()
wt_hls_str = self.get_weight_datatype().get_hls_datatype_str()
out_hls_str = self.get_output_datatype().get_hls_datatype_str()
inp_is_binary = self.get_input_datatype() == DataType.BINARY
out_is_binary = self.get_output_datatype() == DataType.BINARY
wt_is_binary = self.get_weight_datatype() == DataType.BINARY
if inp_is_binary or wt_is_binary or out_is_binary:
raise Exception("True binary (non-bipolar) inputs not yet supported")
inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
out_is_bipolar = self.get_output_datatype() == DataType.BIPOLAR
wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
# fill in TSrcI and TWeightI
if inp_is_bipolar and wt_is_bipolar:
ret["TSrcI"] = "Recast<XnorMul>"
ret["TWeightI"] = "Identity"
elif (not inp_is_bipolar) and wt_is_bipolar:
ret["TSrcI"] = "Slice<%s>" % inp_hls_str
ret["TWeightI"] = "Recast<Binary>"
elif inp_is_bipolar and (not wt_is_bipolar):
ret["TSrcI"] = "Recast<Binary>"
ret["TWeightI"] = "Slice<%s>" % wt_hls_str
elif (not inp_is_bipolar) and (not wt_is_bipolar):
ret["TSrcI"] = "Slice<%s>" % inp_hls_str
ret["TWeightI"] = "Slice<%s>" % wt_hls_str
# fill in TDstI
if out_is_bipolar:
ret["TDstI"] = "Identity"
else:
ret["TDstI"] = "Slice<%s>" % out_hls_str
return ret
def execute_node(self, context, graph):
node = self.onnx_node
# make temporary directory for generated files
......@@ -56,7 +102,6 @@ class StreamingFCLayer_Batch(HLSCustomOp):
weights = context[inputs]
# transpose and expand the weights to get the right shape
# for the code generation
self.set_nodeattr("WMEM", weights.shape[1])
weights = np.expand_dims(weights, 0)
weights = numpy_to_hls_code(
weights, DataType.BINARY, "weights", True, True
......@@ -77,7 +122,6 @@ class StreamingFCLayer_Batch(HLSCustomOp):
else:
thresholds = context[inputs]
self.set_nodeattr("TMEM", thresholds.shape[1])
thresholds = np.expand_dims(thresholds, 0)
thresholds = numpy_to_hls_code(
thresholds, DataType.BINARY, "thresholds", True, True
......
......@@ -58,7 +58,12 @@ def make_single_fclayer_modelwrapper(W, pe, simd, wdt, idt, odt, T=None, tdt=Non
MH=mh,
SIMD=simd,
PE=pe,
WMEM=wmem,
TMEM=tmem,
resDataType=rdt,
inputDataType=idt.name,
weightDataType=wdt.name,
outputDataType=odt.name,
)
graph = helper.make_graph(
nodes=[FCLayer_node], name="fclayer_graph", inputs=[inp], outputs=[outp],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment