diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py index 8fc1a221d200ebcfe4cca2fdb03645349874c4ef..29773f297364a1fa299fcf4e618c1c9875793ecd 100644 --- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py +++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py @@ -26,6 +26,9 @@ class StreamingFCLayer_Batch(HLSCustomOp): "inputDataType": ("s", True, ""), "weightDataType": ("s", True, ""), "outputDataType": ("s", True, ""), + # use xnor-popcount for binary weights/inputs, thus treating them + # as bipolar + "binaryXnorMode": ("i", False, 0), } my_attrs.update(super().get_nodeattr_types()) return my_attrs @@ -60,11 +63,15 @@ class StreamingFCLayer_Batch(HLSCustomOp): inp_is_binary = self.get_input_datatype() == DataType.BINARY out_is_binary = self.get_output_datatype() == DataType.BINARY wt_is_binary = self.get_weight_datatype() == DataType.BINARY - if inp_is_binary or wt_is_binary or out_is_binary: + bin_xnor_mode = self.get_nodeattr("binaryXnorMode") == 1 + if (inp_is_binary or wt_is_binary) and (not bin_xnor_mode): raise Exception("True binary (non-bipolar) inputs not yet supported") inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR out_is_bipolar = self.get_output_datatype() == DataType.BIPOLAR wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR + # reinterpret inp/wt as bipolar if bin_xnor_mode is iset + inp_is_bipolar = inp_is_bipolar or (inp_is_binary and bin_xnor_mode) + wt_is_bipolar = wt_is_bipolar or (wt_is_binary and bin_xnor_mode) # fill in TSrcI and TWeightI # TODO check these with Giulio # TODO handle non-bipolar binary inputs @@ -81,7 +88,7 @@ class StreamingFCLayer_Batch(HLSCustomOp): ret["TSrcI"] = "Slice<%s>" % inp_hls_str ret["TWeightI"] = "Identity" # fill in TDstI - if out_is_bipolar: + if out_is_bipolar or out_is_binary: ret["TDstI"] = "Identity" else: ret["TDstI"] = "Slice<%s>" % out_hls_str @@ -134,6 +141,12 @@ class StreamingFCLayer_Batch(HLSCustomOp): n_thres_steps = orig_thres_matrix.shape[1] inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR + # reinterpret inp/wt as bipolar if bin_xnor_mode is iset + inp_is_binary = self.get_input_datatype() == DataType.BINARY + wt_is_binary = self.get_weight_datatype() == DataType.BINARY + bin_xnor_mode = self.get_nodeattr("binaryXnorMode") == 1 + inp_is_bipolar = inp_is_bipolar or (inp_is_binary and bin_xnor_mode) + wt_is_bipolar = wt_is_bipolar or (wt_is_binary and bin_xnor_mode) if inp_is_bipolar and wt_is_bipolar: assert (orig_thres_matrix >= 0).all() ret = orig_thres_matrix @@ -193,6 +206,12 @@ class StreamingFCLayer_Batch(HLSCustomOp): # use UINT32 threshold export for bipolar times bipolar inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR + # reinterpret inp/wt as bipolar if bin_xnor_mode is iset + inp_is_binary = self.get_input_datatype() == DataType.BINARY + wt_is_binary = self.get_weight_datatype() == DataType.BINARY + bin_xnor_mode = self.get_nodeattr("binaryXnorMode") == 1 + inp_is_bipolar = inp_is_bipolar or (inp_is_binary and bin_xnor_mode) + wt_is_bipolar = wt_is_bipolar or (wt_is_binary and bin_xnor_mode) if inp_is_bipolar and wt_is_bipolar: tdt = DataType.UINT32 thresholds_hls_code = numpy_to_hls_code( diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py index a24886c6cca216594523294e21fca76bed09e3c7..afc2a2010fa41ba9431acbd03c18a6080cb296a0 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py @@ -76,6 +76,7 @@ class InferBinaryStreamingFCLayer(Transformation): weightDataType=wdt.name, outputDataType=odt.name, ActVal=actval, + binaryXnorMode=1, ) graph.node.insert(node_ind, new_node) # remove old nodes @@ -107,6 +108,7 @@ class InferBinaryStreamingFCLayer(Transformation): weightDataType=wdt.name, outputDataType=odt.name, ActVal=0, + binaryXnorMode=1, ) graph.node.insert(node_ind, new_node) # remove old node