diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
index 07bdbf1cab6848dc25b7bb556000829617620747..28e46daef136854faec98a726309bc6e9aff108c 100644
--- a/src/finn/custom_op/registry.py
+++ b/src/finn/custom_op/registry.py
@@ -1,8 +1,10 @@
 # make sure new CustomOp subclasses are imported here so that they get
 # registered and plug in correctly into the infrastructure
 from finn.custom_op.multithreshold import MultiThreshold
+from finn.custom_op.xnorpopcount import XnorPopcountMatMul
 
 # create a mapping of all known CustomOp names and classes
 custom_op = {}
 
 custom_op["MultiThreshold"] = MultiThreshold
+custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
diff --git a/src/finn/custom_op/xnorpopcount.py b/src/finn/custom_op/xnorpopcount.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f78bc267d4a91422b2d44aecf82ae40c5659574
--- /dev/null
+++ b/src/finn/custom_op/xnorpopcount.py
@@ -0,0 +1,74 @@
+import numpy as np
+import onnx.helper as helper
+
+from finn.core.datatype import DataType
+from finn.custom_op import CustomOp
+
+
+class XnorPopcountMatMul(CustomOp):
+    """Custom op for XNOR-popcount matrix multiplication of binary tensors."""
+
+    def make_shape_compatible_op(self, node):
+        """Return a standard MatMul node with this node's inputs and output.
+
+        Presumably used as a stand-in for this custom op during shape
+        inference, as the method name suggests; confirm against CustomOp.
+        """
+        return helper.make_node(
+            "MatMul", [node.input[0], node.input[1]], [node.output[0]]
+        )
+
+    def infer_node_datatype(self, node, model):
+        """Check that both inputs are BINARY and mark the output as UINT32.
+
+        Raises AssertionError if either input tensor is not BINARY.
+        """
+        # ensure inputs are binary
+        for inp_name in (node.input[0], node.input[1]):
+            inp_dtype = model.get_tensor_datatype(inp_name)
+            assert inp_dtype == DataType["BINARY"], (
+                "XnorPopcountMatMul input %s must be BINARY" % inp_name
+            )
+        # XNOR-popcount produces unsigned integers, assume uint32
+        model.set_tensor_datatype(node.output[0], DataType["UINT32"])
+
+    def execute_node(self, node, context, graph):
+        """Compute the output from the inputs in context and store it there.
+
+        The graph argument is not used.
+        """
+        # look up both operands by tensor name
+        inp0 = context[node.input[0]]
+        inp1 = context[node.input[1]]
+        # store the result under the output tensor name
+        context[node.output[0]] = self._execute(inp0, inp1)
+
+    def _execute(self, inp0, inp1):
+        """Return the XNOR-popcount matrix product of inp0 and inp1.
+
+        Both operands are arrays of 0/1 values with at least two dimensions;
+        any leading (batch) dimensions are handled by np.matmul. Raises
+        AssertionError if an operand has fewer than two dimensions or the
+        inner dimensions differ.
+        """
+        assert inp0.ndim >= 2 and inp1.ndim >= 2, "operands must be >= 2-D"
+        # np.matmul contracts the last axis of inp0 with the second-to-last
+        # axis of inp1; K is the length of that contraction
+        K = inp0.shape[-1]
+        assert K == inp1.shape[-2], "inner dimensions of operands must match"
+        # we simulate XNOR-popcount matrix multiplication as a regular bipolar
+        # matrix multiplication followed by some post processing
+        # first, convert binary inputs to bipolar
+        inp0_bipolar = 2.0 * inp0 - 1.0
+        inp1_bipolar = 2.0 * inp1 - 1.0
+        # call regular numpy matrix multiplication
+        out = np.matmul(inp0_bipolar, inp1_bipolar)
+        # XNOR-popcount does not produce the regular dot product result --
+        # it returns the number of +1s after XNOR. let P be the number of +1s
+        # and N be the number of -1s. XNOR-popcount returns P, whereas the
+        # regular dot product result from numpy is P-N, so we need to apply
+        # some correction.
+        # out = P-N
+        # K = P+N
+        # out + K = 2P, so P = (out + K)/2
+        return (out + K) * 0.5