diff --git a/src/finn/core/execute_custom_node.py b/src/finn/core/execute_custom_node.py
index 6583f7cc9695f4a481883b9fb5bbba9233c59bcf..3ad448fc44c76b1b23e3702a002057163b151e0e 100644
--- a/src/finn/core/execute_custom_node.py
+++ b/src/finn/core/execute_custom_node.py
@@ -1,54 +1,16 @@
-# import onnx.helper as helper
-import sys
-import os
-import tempfile
-import numpy as np
-import finn.core.multithreshold as multiThresh
-import finn.core.utils as util
-import finn.backend.fpgadataflow.code_gen_for_single_node_execution as code_gen
+import finn.custom_op.registry as registry


 def execute_custom_node(node, context, graph):
     """Call custom implementation to execute a single custom node.
     Input/output provided via context."""
-
-    if (util.get_by_name(node.attribute, 'backend')) is not None:
-        if node.op_type == "StreamingMaxPool":
-            in_ind = 0
-            temp_files = []
-            for inputs in node.input:
-                np.save("input_{}.npy".format(in_ind), context[inputs])
-                temp_files.append("input_{}.npy".format(in_ind))
-                in_ind += 1
-
-            code_gen.execute(node, context, graph)
-
-            output = np.load("output.npy")
-            for i in range(output.shape[0]):
-                print(np.transpose(output[i]))
-
-
-
-
-            ## deleting temporary files
-            #for temp_file in temp_files:
-            #    os.remove(temp_file)
-            sys.exit(1)
-        else:
-            # exception if op_type is not supported
-            raise Exception("This hls lib custom node is currently not supported.")
-
-
-    else:
-        if node.op_type == "MultiThreshold":
-            # save inputs
-            v = context[node.input[0]]
-            thresholds = context[node.input[1]]
-            # calculate output
-            output = multiThresh.execute(v, thresholds)
-            # setting context according to output
-            context[node.output[0]] = output
-
-        else:
-            # exception if op_type is not supported
-            raise Exception("This custom node is currently not supported.")
+    op_type = node.op_type
+    try:
+        # lookup op_type in registry of CustomOps
+        inst = registry.custom_op[op_type]()
+    except KeyError:
+        # exception if op_type is not supported
+        raise Exception("Custom op_type %s is currently not supported." % op_type)
+    # execute outside the try block so a KeyError raised by the op itself
+    # (e.g. a missing tensor in context) is not misreported as unsupported
+    inst.execute_node(node, context, graph)
diff --git a/src/finn/core/multithreshold.py b/src/finn/core/multithreshold.py
deleted file mode 100755
index 009259c577879a8aa09ac44ace704af55ca2593d..0000000000000000000000000000000000000000
--- a/src/finn/core/multithreshold.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import numpy as np
-
-
-def compare(x, y):
-    if x >= y:
-        return 1.0
-    else:
-        return 0.0
-
-
-def execute(v, thresholds):
-
-    # the inputs are expected to be in the shape (N,C,H,W)
-    # N : Batch size
-    # C : Number of channels
-    # H : Heigth of the input images
-    # W : Width of the input images
-    #
-    # the thresholds are expected to be in the shape (C, B)
-    # C : Number of channels (must be the same value as C in input tensor or 1
-    #     if all channels use the same threshold value)
-    # B : Desired activation steps => i.e. 
for 4-bit activation, B=7 (2^(n)-1 and n=4) - - # assert threshold shape - is_global_threshold = thresholds.shape[0] == 1 - assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold - - # save the required shape sizes for the loops (N, C and B) - num_batch = v.shape[0] - num_channel = v.shape[1] - - num_act = thresholds.shape[1] - - # reshape inputs to enable channel-wise reading - vr = v.reshape((v.shape[0], v.shape[1], -1)) - - # save the new shape size of the images - num_img_elem = vr.shape[2] - - # initiate output tensor - ret = np.zeros_like(vr) - - # iterate over thresholds channel-wise - for t in range(num_channel): - channel_thresh = thresholds[0] if is_global_threshold else thresholds[t] - - # iterate over batches - for b in range(num_batch): - - # iterate over image elements on which the thresholds should be applied - for elem in range(num_img_elem): - - # iterate over the different thresholds that correspond to one channel - for a in range(num_act): - # apply successive thresholding to every element of one channel - ret[b][t][elem] += compare(vr[b][t][elem], channel_thresh[a]) - - return ret.reshape(v.shape) diff --git a/src/finn/custom_op/__init__.py b/src/finn/custom_op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd7024affedd15eb564b4e588a2b79343e95d371 --- /dev/null +++ b/src/finn/custom_op/__init__.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod + + +class CustomOp(ABC): + def __init__(self): + super().__init__() + + @abstractmethod + def make_shape_compatible_op(self, node): + pass + + @abstractmethod + def infer_node_datatype(self, node, model): + pass + + @abstractmethod + def execute_node(self, node, context, graph): + pass diff --git a/src/finn/custom_op/multithreshold.py b/src/finn/custom_op/multithreshold.py new file mode 100644 index 0000000000000000000000000000000000000000..ec8f1ad4ceded8a40e182af5d79ccd803e0b8c33 --- /dev/null +++ b/src/finn/custom_op/multithreshold.py @@ -0,0 +1,91 @@ +import numpy as np +import onnx.helper as helper + +from finn.core.datatype import DataType +from finn.core.utils import get_by_name +from finn.custom_op import CustomOp + + +class MultiThreshold(CustomOp): + def make_shape_compatible_op(self, node): + return helper.make_node("Relu", [node.input[0]], [node.output[0]]) + + def infer_node_datatype(self, node, model): + try: + odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8") + model.set_tensor_datatype(node.output[0], DataType[odt]) + except AttributeError: + # number of thresholds decides # output bits + # use get_smallest_possible, assuming unsigned + n_thres = model.get_tensor_shape(node.input[1])[1] + odtype = DataType.get_smallest_possible(n_thres) + model.set_tensor_datatype(node.output[0], odtype) + + def execute_node(self, node, context, graph): + # save inputs + v = context[node.input[0]] + thresholds = context[node.input[1]] + # retrieve attributes if output scaling is used + try: + out_scale = get_by_name(node.attribute, "out_scale").f + except AttributeError: + out_scale = None + try: + out_bias = get_by_name(node.attribute, "out_bias").f + except AttributeError: + out_bias = None + # calculate output + output = self._execute(v, thresholds, out_scale, out_bias) + # setting context according to output + context[node.output[0]] = output + + def _compare(self, x, y): + if x >= y: + return 1.0 + else: + return 0.0 + + def _execute(self, v, thresholds, out_scale=None, out_bias=None): + # the inputs are expected to be in the shape (N,C,H,W) + # N : Batch size 
+        # C : Number of channels
+        # H : Height of the input images
+        # W : Width of the input images
+        #
+        # the thresholds are expected to be in the shape (C, B)
+        # C : Number of channels (must be the same value as C in input tensor
+        #     or 1 if all channels use the same threshold value)
+        # B : Desired activation steps => i.e. for 3-bit activation,
+        #     B=7 (B = 2^n - 1 with n = 3)
+        # the output tensor will be scaled by out_scale and biased by out_bias
+        # assert threshold shape
+        is_global_threshold = thresholds.shape[0] == 1
+        assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
+        # save the required shape sizes for the loops (N, C and B)
+        num_batch = v.shape[0]
+        num_channel = v.shape[1]
+        num_act = thresholds.shape[1]
+        # reshape inputs to enable channel-wise reading
+        vr = v.reshape((v.shape[0], v.shape[1], -1))
+        # save the new shape size of the images
+        num_img_elem = vr.shape[2]
+        # initiate output tensor
+        ret = np.zeros_like(vr)
+        # iterate over thresholds channel-wise
+        for t in range(num_channel):
+            channel_thresh = thresholds[0] if is_global_threshold else thresholds[t]
+            # iterate over batches
+            for b in range(num_batch):
+                # iterate over image elements on which the thresholds will be applied
+                for elem in range(num_img_elem):
+                    # iterate over the different thresholds for one channel
+                    for a in range(num_act):
+                        # apply successive thresholding to every element
+                        ret[b][t][elem] += self._compare(
+                            vr[b][t][elem], channel_thresh[a]
+                        )
+        if out_scale is None:
+            out_scale = 1.0
+        if out_bias is None:
+            out_bias = 0.0
+        return out_scale * ret.reshape(v.shape) + out_bias
diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..61d7c8704dca247224502d479d1bfec2edfeddbd
--- /dev/null
+++ b/src/finn/custom_op/registry.py
@@ -0,0 +1,12 @@
+# make sure new CustomOp subclasses are imported and registered here so
+# that they plug in correctly into the infrastructure
+from finn.custom_op.multithreshold import MultiThreshold
+from finn.custom_op.streamingmaxpool import StreamingMaxPool
+from finn.custom_op.xnorpopcount import XnorPopcountMatMul
+
+# create a mapping of all known CustomOp names and classes
+custom_op = {}
+
+custom_op["MultiThreshold"] = MultiThreshold
+custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
+custom_op["StreamingMaxPool"] = StreamingMaxPool
diff --git a/src/finn/custom_op/streamingmaxpool.py b/src/finn/custom_op/streamingmaxpool.py
new file mode 100644
index 0000000000000000000000000000000000000000..e963aab67012bac6388e17413ecd7a3feb047111
--- /dev/null
+++ b/src/finn/custom_op/streamingmaxpool.py
@@ -0,0 +1,32 @@
+import os
+
+import numpy as np
+
+from finn.custom_op import CustomOp
+import finn.backend.fpgadataflow.code_gen_for_single_node_execution as code_gen
+
+
+class StreamingMaxPool(CustomOp):
+    def make_shape_compatible_op(self, node):
+        pass
+
+    def infer_node_datatype(self, node, model):
+        pass
+
+    def execute_node(self, node, context, graph):
+        # save inputs as .npy files for the code generation backend
+        in_ind = 0
+        temp_files = []
+        for inputs in node.input:
+            np.save("input_{}.npy".format(in_ind), context[inputs])
+            temp_files.append("input_{}.npy".format(in_ind))
+            in_ind += 1
+        # generate, compile and run code for this node
+        code_gen.execute(node, context, graph)
+        # load result and set context according to output name
+        output = np.load("output.npy")
+        context[node.output[0]] = output
+        # delete temporary files
+        temp_files.append("output.npy")
+        for temp_file in temp_files:
+            os.remove(temp_file)
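Note on the registry contract above: a new op plugs in by subclassing CustomOp and adding one entry to registry.custom_op. A minimal sketch of what that looks like (the op ScaleByTwo and its behavior are hypothetical, for illustration only, not part of this changeset):

```python
# hypothetical example of a minimal CustomOp subclass under the new contract
import onnx.helper as helper

from finn.custom_op import CustomOp


class ScaleByTwo(CustomOp):  # hypothetical op, not in this PR
    def make_shape_compatible_op(self, node):
        # elementwise op: any shape-preserving standard ONNX node works
        return helper.make_node("Relu", [node.input[0]], [node.output[0]])

    def infer_node_datatype(self, node, model):
        # keep the input datatype (illustrative simplification)
        idt = model.get_tensor_datatype(node.input[0])
        model.set_tensor_datatype(node.output[0], idt)

    def execute_node(self, node, context, graph):
        # read inputs from the context dict, write outputs back into it
        context[node.output[0]] = 2 * context[node.input[0]]


# registration in src/finn/custom_op/registry.py would then be:
# custom_op["ScaleByTwo"] = ScaleByTwo
```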
diff --git a/src/finn/custom_op/xnorpopcount.py b/src/finn/custom_op/xnorpopcount.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f78bc267d4a91422b2d44aecf82ae40c5659574
--- /dev/null
+++ b/src/finn/custom_op/xnorpopcount.py
@@ -0,0 +1,52 @@
+import numpy as np
+import onnx.helper as helper
+
+from finn.core.datatype import DataType
+from finn.custom_op import CustomOp
+
+
+class XnorPopcountMatMul(CustomOp):
+    def make_shape_compatible_op(self, node):
+        return helper.make_node(
+            "MatMul", [node.input[0], node.input[1]], [node.output[0]]
+        )
+
+    def infer_node_datatype(self, node, model):
+        # ensure inputs are binary
+        assert model.get_tensor_datatype(node.input[0]) == DataType["BINARY"]
+        assert model.get_tensor_datatype(node.input[1]) == DataType["BINARY"]
+        # XNOR-popcount produces unsigned integers, assume uint32
+        model.set_tensor_datatype(node.output[0], DataType["UINT32"])
+
+    def execute_node(self, node, context, graph):
+        # save inputs
+        inp0 = context[node.input[0]]
+        inp1 = context[node.input[1]]
+        # calculate output
+        output = self._execute(inp0, inp1)
+        # set context according to output name
+        context[node.output[0]] = output
+
+    def _execute(self, inp0, inp1):
+        # extract the operand shapes
+        (M, K0) = inp0.shape
+        (K1, N) = inp1.shape
+        # make sure shapes are compatible with matmul
+        assert K0 == K1
+        K = K0
+        # we simulate XNOR-popcount matrix multiplication as a regular bipolar
+        # matrix multiplication followed by some post processing
+        # first, convert binary inputs to bipolar
+        inp0_bipolar = 2.0 * inp0 - 1.0
+        inp1_bipolar = 2.0 * inp1 - 1.0
+        # call regular numpy matrix multiplication
+        out = np.matmul(inp0_bipolar, inp1_bipolar)
+        # XNOR-popcount does not produce the regular dot product result --
+        # it returns the number of +1s after XNOR. let p be the number of
+        # +1s and n the number of -1s (not the matrix dims M, N above).
+        # XNOR-popcount returns p, whereas the regular dot product result
+        # from numpy is p-n, so we need to apply a correction:
+        # out = p-n
+        # K = p+n
+        # out + K = 2p, so p = (out + K)/2
+        return (out + K) * 0.5
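The correction at the end of _execute is easy to sanity-check by hand. A standalone example (plain numpy, not part of the changeset):

```python
# standalone check of the p = (out + K) / 2 correction on one row of
# binary values
import numpy as np

inp0 = np.asarray([[1.0, 0.0, 1.0]])      # (M=1, K=3) binary
inp1 = np.asarray([[1.0], [1.0], [0.0]])  # (K=3, N=1) binary
K = inp0.shape[1]

# bipolar matmul gives p - n (XNOR matches minus mismatches), here -1
out = np.matmul(2.0 * inp0 - 1.0, 2.0 * inp1 - 1.0)

# XNOR-popcount counts positions where both bits agree: (1,1) or (0,0)
direct = np.matmul(inp0, inp1) + np.matmul(1.0 - inp0, 1.0 - inp1)

assert ((out + K) * 0.5 == direct).all()  # (-1 + 3) / 2 == 1
```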
diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index a311012fc6631e76e75a37b8dc4d1b99d21ce7c7..60ea43c97d0298969ab4b3a280ed9bd3f62cbab8 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -1,3 +1,4 @@
+import finn.custom_op.registry as registry
 from finn.core.datatype import DataType
 from finn.transformation import Transformation

@@ -7,31 +8,37 @@ def _infer_node_datatype(model, node):
     changes were made."""
     idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
     odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
-    if node.op_type == "MultiThreshold":
-        # number of thresholds decides # output buts, use get_smallest_possible
-        n_thres = model.get_tensor_shape(node.input[1])[1]
-        odtype = DataType.get_smallest_possible(n_thres)
-        model.set_tensor_datatype(node.output[0], odtype)
-    elif node.op_type == "Sign":
-        # always produces bipolar outputs
-        model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
-    elif node.op_type == "MatMul":
-        if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
-            # node has at least one float input, output is also float
-            model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
-        else:
-            # TODO compute minimum / maximum result to minimize bitwidth
-            # use (u)int32 accumulators for now
-            has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
-            if has_signed_inp:
-                odtype = DataType.INT32
-            else:
-                odtype = DataType.UINT32
-            model.set_tensor_datatype(node.output[0], odtype)
+    op_type = node.op_type
+    if node.domain == "finn":
+        # handle DataType inference for CustomOp
+        try:
+            # lookup op_type in registry of CustomOps
+            inst = registry.custom_op[op_type]()
+        except KeyError:
+            # exception if op_type is not supported
+            raise Exception("Custom op_type %s is currently not supported." % op_type)
+        inst.infer_node_datatype(node, model)
     else:
-        # unknown, assume node produces float32 outputs
-        for o in node.output:
-            model.set_tensor_datatype(o, DataType.FLOAT32)
+        if node.op_type == "Sign":
+            # always produces bipolar outputs
+            model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
+        elif node.op_type == "MatMul":
+            if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
+                # node has at least one float input, output is also float
+                model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
+            else:
+                # TODO compute minimum / maximum result to minimize bitwidth
+                # use (u)int32 accumulators for now
+                has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
+                if has_signed_inp:
+                    odtype = DataType.INT32
+                else:
+                    odtype = DataType.UINT32
+                model.set_tensor_datatype(node.output[0], odtype)
+        else:
+            # unknown, assume node produces float32 outputs
+            for o in node.output:
+                model.set_tensor_datatype(o, DataType.FLOAT32)
     # compare old and new output dtypes to see if anything changed
     new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
     graph_modified = new_odtypes != odtypes
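The MultiThreshold-specific fallback that used to live here now sits in MultiThreshold.infer_node_datatype; it ties output bitwidth to threshold count, since B thresholds per channel yield outputs in 0..B. A small illustration (assumes DataType.get_smallest_possible and DataType.allowed behave as defined in src/finn/core/datatype.py):

```python
# illustration: 7 thresholds -> outputs in {0, ..., 7} -> 3-bit unsigned
from finn.core.datatype import DataType

n_thres = 7
odt = DataType.get_smallest_possible(n_thres)
assert odt.allowed(0) and odt.allowed(n_thres)  # 0..7 all representable
assert not odt.signed()  # get_smallest_possible assumes unsigned here
```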
diff --git a/src/finn/transformation/infer_shapes.py b/src/finn/transformation/infer_shapes.py
index e92c6f81625a9d328bed3225a7730a943d6c9830..4938606704f2497f77f9c120e652f37eb25ad8a7 100644
--- a/src/finn/transformation/infer_shapes.py
+++ b/src/finn/transformation/infer_shapes.py
@@ -1,6 +1,6 @@
-import onnx.helper as helper
 import onnx.shape_inference as si

+import finn.custom_op.registry as registry
 from finn.core.modelwrapper import ModelWrapper
 from finn.transformation import Transformation

@@ -10,10 +10,14 @@ def _make_shape_compatible_op(node):
     """Return a shape-compatible non-FINN op for a given FINN op.
     Used for shape inference with custom ops."""
     assert node.domain == "finn"
-    if node.op_type == "MultiThreshold":
-        return helper.make_node("Relu", [node.input[0]], [node.output[0]])
-    else:
-        raise Exception("No known shape-compatible op for %s" % node.op_type)
+    op_type = node.op_type
+    try:
+        # lookup op_type in registry of CustomOps
+        inst = registry.custom_op[op_type]()
+    except KeyError:
+        # exception if op_type is not supported
+        raise Exception("Custom op_type %s is currently not supported." % op_type)
+    return inst.make_shape_compatible_op(node)


 def _hide_finn_ops(model):
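Why make_shape_compatible_op returns a stock ONNX node: onnx.shape_inference does not know finn-domain ops, so InferShapes temporarily swaps each one for a standard op with the same output shape (Relu for the elementwise MultiThreshold, MatMul for XnorPopcountMatMul), runs regular shape inference, then restores the originals. The per-node step is roughly (simplified sketch, not a verbatim excerpt of _hide_finn_ops):

```python
# simplified sketch of the hide step used by InferShapes; a matching
# restore step swaps the original finn nodes back in afterwards
import finn.custom_op.registry as registry


def shape_compatible_stand_in(node):
    # assumes node.domain == "finn" and a registered op_type
    inst = registry.custom_op[node.op_type]()
    # e.g. Relu for MultiThreshold, MatMul for XnorPopcountMatMul
    return inst.make_shape_compatible_op(node)
```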
diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index fb9e530d641063187dd75de30c8ca49936565bae..03d4cbc0802a8c1af9b960a98a11fe27ae1a8b7e 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -1,11 +1,94 @@
 import numpy as np
+from onnx import TensorProto
 from onnx import helper as oh

 from finn.core.datatype import DataType
+from finn.core.utils import get_by_name
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes


+class ConvertBipolarMatMulToXnorPopcount(Transformation):
+    """Convert MatMul nodes with all-bipolar inputs to XnorPopcountMatMul
+    and associated result correction."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "MatMul":
+                mm_input = n.input[0]
+                mm_weight = n.input[1]
+                mm_output = n.output[0]
+                i_bp = model.get_tensor_datatype(mm_input) == DataType.BIPOLAR
+                w_bp = model.get_tensor_datatype(mm_weight) == DataType.BIPOLAR
+                if i_bp and w_bp:
+                    graph_modified = True
+                    # change node type and domain
+                    n.op_type = "XnorPopcountMatMul"
+                    n.domain = "finn"
+                    # convert weights into binary (-1,+1) -> (0,1)
+                    Wbin = (model.get_initializer(mm_weight) + 1) / 2
+                    # extract vector length (common matrix dim)
+                    K = Wbin.shape[0]
+                    model.set_initializer(mm_weight, Wbin)
+                    model.set_tensor_datatype(mm_weight, DataType.BINARY)
+                    # find producing threshold node and adjust output to binary
+                    mt = model.find_producer(mm_input)
+                    if mt is not None and mt.op_type == "MultiThreshold":
+                        bin_dt_attr = "BINARY".encode("utf-8")
+                        get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr
+                        get_by_name(mt.attribute, "out_scale").f = 1.0
+                        get_by_name(mt.attribute, "out_bias").f = 0.0
+                        model.set_tensor_datatype(mm_input, DataType.BINARY)
+                    else:
+                        raise Exception(
+                            """Requires Bipolar2Binary, not yet
+                            implemented."""
+                        )
+                    # create new intermediate output tensor with correct shape
+                    mm_out_shape = model.get_tensor_shape(mm_output)
+                    xnorpcout = oh.make_tensor_value_info(
+                        model.make_new_valueinfo_name(), TensorProto.FLOAT, mm_out_shape
+                    )
+                    n.output[0] = xnorpcout.name
+                    model.set_tensor_datatype(xnorpcout.name, DataType.UINT32)
+                    # add mul-add nodes to produce correct dot product result
+                    # need to derive P-N from P and K = P+N
+                    # so we need 2*P-K
+                    A = np.asarray([2.0], dtype=np.float32)
+                    B = np.asarray([-K], dtype=np.float32)
+                    # create value_info and initializers for Mul and Add constants
+                    mul_const = oh.make_tensor_value_info(
+                        model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape
+                    )
+                    graph.value_info.append(mul_const)
+                    model.set_initializer(mul_const.name, A)
+                    mul_output = oh.make_tensor_value_info(
+                        model.make_new_valueinfo_name(), TensorProto.FLOAT, mm_out_shape
+                    )
+                    graph.value_info.append(mul_output)
+                    add_const = oh.make_tensor_value_info(
+                        model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape
+                    )
+                    graph.value_info.append(add_const)
+                    model.set_initializer(add_const.name, B)
+                    # create Mul and Add nodes that apply the popcount correction
+                    mul_node = oh.make_node(
+                        "Mul", [xnorpcout.name, mul_const.name], [mul_output.name]
+                    )
+                    add_node = oh.make_node(
+                        "Add", [mul_output.name, add_const.name], [mm_output]
+                    )
+                    # insert right after the converted node to keep topological order
+                    graph.node.insert(node_ind, mul_node)
+                    graph.node.insert(node_ind + 1, add_node)
+        model = model.transform(InferShapes())
+        return (model,
graph_modified) + + class ConvertSignToThres(Transformation): """Convert Sign node instances to MultiThreshold with threshold at 0.""" @@ -16,50 +99,30 @@ class ConvertSignToThres(Transformation): for n in graph.node: node_ind += 1 if n.op_type == "Sign": + sign_in_name = n.input[0] sign_out_name = n.output[0] # find consumer consumer = model.find_consumer(sign_out_name) assert consumer is not None - # change op type and create threshold - n.op_type = "MultiThreshold" + # create thresholds thres_param_name = model.make_new_valueinfo_name() thres_param = np.asarray([[0]], dtype=np.float32) - n.input.append(thres_param_name) - n.domain = "finn" model.set_initializer(thres_param_name, thres_param) - # convert 0,1 -> -1,+1 with 2*x-1 - out_shape = model.get_tensor_shape(sign_out_name) - # make a mul node - # note how set_initializer or set_tensor_shape is called before - # calling make_new_valueinfo_name again - mul_param_name = model.make_new_valueinfo_name() - model.set_initializer( - mul_param_name, np.asarray([[2]], dtype=np.float32) - ) - mul_out_name = model.make_new_valueinfo_name() - model.set_tensor_shape(mul_out_name, out_shape) - mul_node = oh.make_node( - "Mul", [sign_out_name, mul_param_name], [mul_out_name] - ) - # make an add node - add_param_name = model.make_new_valueinfo_name() - model.set_initializer( - add_param_name, np.asarray([[-1]], dtype=np.float32) - ) - add_out_name = model.make_new_valueinfo_name() - model.set_tensor_shape(add_out_name, out_shape) - add_node = oh.make_node( - "Add", [mul_out_name, add_param_name], [add_out_name] + # create a new node + mt_node = oh.make_node( + "MultiThreshold", + [sign_in_name, thres_param_name], + [sign_out_name], + domain="finn", + out_scale=2.0, + out_bias=-1.0, + out_dtype="BIPOLAR", ) - # add new nodes to graph at correct position - graph.node.insert(node_ind, mul_node) - graph.node.insert(node_ind + 1, add_node) - # rewrite consumer's input - consumer.input[0] = add_out_name + # remove old node, add new node to graph at correct position + graph.node.insert(node_ind, mt_node) + graph.node.remove(n) # add quantization annotations - model.set_tensor_datatype(sign_out_name, DataType.BINARY) - model.set_tensor_datatype(mul_out_name, DataType.UINT2) - model.set_tensor_datatype(add_out_name, DataType.BIPOLAR) + model.set_tensor_datatype(sign_out_name, DataType.BIPOLAR) graph_modified = True return (model, graph_modified) diff --git a/tests/test_custom_onnx_exec.py b/tests/test_custom_onnx_exec.py index 0d07b9888d1330753446d589619149c1f8f316cf..e1ff552572e8a6d4d55e204cd21f17e4984ce30d 100644 --- a/tests/test_custom_onnx_exec.py +++ b/tests/test_custom_onnx_exec.py @@ -211,3 +211,18 @@ def test_execute_custom_node_multithreshold(): ) assert (execution_context["out"] == outputs).all() + + # test the optional output scaling features on MultiThreshold + node_def = helper.make_node( + "MultiThreshold", + ["v", "thresholds"], + ["out"], + domain="finn", + out_scale=2.0, + out_bias=-1.0, + ) + + graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out]) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + outputs_scaled = 2.0 * outputs - 1.0 + assert (execution_context["out"] == outputs_scaled).all() diff --git a/tests/test_multi_thresholding.py b/tests/test_multi_thresholding.py index 49305e5572f35eb4a2e5f7678c73038777eb8b92..b16d1602a65284ba366e9c331370c15096981ece 100644 --- a/tests/test_multi_thresholding.py +++ b/tests/test_multi_thresholding.py @@ -1,6 +1,6 @@ import numpy as np -import 
finn.core.multithreshold as multi_thresh +from finn.custom_op.multithreshold import MultiThreshold def test_execute_multi_thresholding(): @@ -194,6 +194,11 @@ def test_execute_multi_thresholding(): ), ) - results = multi_thresh.execute(inputs, thresholds) + multi_thresh = MultiThreshold() + results = multi_thresh._execute(inputs, thresholds) assert (results == outputs).all() + + results_scaled = multi_thresh._execute(inputs, thresholds, 2.0, -1.0) + outputs_scaled = 2.0 * outputs - 1.0 + assert (results_scaled == outputs_scaled).all() diff --git a/tests/test_streamline.py b/tests/test_streamline.py index 75b0301072bd603e0c25357fa022b66d1df4e467..d0e83e8b5b75cf33b2b1cdedda46feee38ef339a 100644 --- a/tests/test_streamline.py +++ b/tests/test_streamline.py @@ -68,7 +68,7 @@ def test_streamline_lfc_w1a1(): model = model.transform(InferDataTypes()) produced_ctx = oxe.execute_onnx(model, input_dict, True) produced = produced_ctx[model.graph.output[0].name] - # model.save("%d-%s.onnx" % (trn_ind, trn.__name__)) + # model.save("%d-%s.onnx" % (trn_ind, trn.__class__.__name__)) assert np.isclose(expected, produced, atol=1e-3).all() trn_ind += 1 os.remove(export_onnx_path) diff --git a/tests/test_xnorpopcountmatmul.py b/tests/test_xnorpopcountmatmul.py new file mode 100644 index 0000000000000000000000000000000000000000..a416c2bfd3640eeb092007fdc6970f53647632dd --- /dev/null +++ b/tests/test_xnorpopcountmatmul.py @@ -0,0 +1,86 @@ +import os +from pkgutil import get_data + +import brevitas.onnx as bo +import numpy as np +import onnx +import onnx.helper as helper +import onnx.numpy_helper as nph +import torch +from models.LFC import LFC +from onnx import TensorProto + +import finn.core.onnx_exec as oxe +import finn.transformation.streamline as sl +from finn.core.datatype import DataType +from finn.core.modelwrapper import ModelWrapper +from finn.transformation.fold_constants import FoldConstants +from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames +from finn.transformation.infer_datatypes import InferDataTypes +from finn.transformation.infer_shapes import InferShapes + +export_onnx_path = "test_output_lfc.onnx" +# TODO get from config instead, hardcoded to Docker path for now +trained_lfc_checkpoint = ( + "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar" +) + + +def test_xnorpopcountmatmul(): + M = 1 + K = 3 + N = 3 + x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [M, K]) + W = helper.make_tensor_value_info("W", TensorProto.FLOAT, [K, N]) + out = helper.make_tensor_value_info("out", TensorProto.FLOAT, ["x", "y"]) + node_def = helper.make_node( + "XnorPopcountMatMul", ["x", "W"], ["out"], domain="finn" + ) + modelproto = helper.make_model( + helper.make_graph([node_def], "test_model", [x], [out], value_info=[W]) + ) + model = ModelWrapper(modelproto) + model.set_tensor_datatype("x", DataType.BINARY) + model.set_tensor_datatype("W", DataType.BINARY) + W_data = np.asarray([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) + model.set_initializer("W", W_data) + # test shape inference + model = model.transform(InferShapes()) + assert model.get_tensor_shape("out") == [M, N] + # test datatype inference + assert model.get_tensor_datatype("out") is DataType.FLOAT32 + model = model.transform(InferDataTypes()) + assert model.get_tensor_datatype("out") is DataType.UINT32 + # test execution + x_data = np.asarray([[1, 0, 0]], dtype=np.float32) + inp_dict = {"x": x_data} + out_dict = oxe.execute_onnx(model, inp_dict) + Wb = 2 * W_data - 1 + xb = 
2 * x_data - 1 + rb = np.matmul(xb, Wb) + assert (2 * out_dict["out"] - K == rb).all() + + +def test_convert_bipolar_matmul_to_xnorpopcountmatmul(): + lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) + checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + lfc.load_state_dict(checkpoint["state_dict"]) + bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) + model = model.transform(sl.ConvertSignToThres()) + # load one of the test vectors + raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") + input_tensor = onnx.load_tensor_from_string(raw_i) + # run using FINN-based execution + input_dict = {"global_in": nph.to_array(input_tensor)} + expected_ctx = oxe.execute_onnx(model, input_dict, True) + expected = expected_ctx[model.graph.output[0].name] + model = model.transform(sl.ConvertBipolarMatMulToXnorPopcount()) + produced_ctx = oxe.execute_onnx(model, input_dict, True) + produced = produced_ctx[model.graph.output[0].name] + assert np.isclose(expected, produced, atol=1e-3).all() + os.remove(export_onnx_path)
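A final standalone sanity check tying the pieces together: with a single zero threshold plus out_scale=2.0 and out_bias=-1.0, MultiThreshold reproduces Sign on nonzero inputs, which is exactly the node that the rewritten ConvertSignToThres now emits (plain numpy plus the new custom op; illustrative only, not part of the changeset):

```python
# standalone check: one zero threshold + scale 2.0 / bias -1.0 == Sign
import numpy as np

from finn.custom_op.multithreshold import MultiThreshold

x = np.asarray([[[[-0.7, 0.3]]]], dtype=np.float32)  # (N,C,H,W) = (1,1,1,2)
thres = np.asarray([[0.0]], dtype=np.float32)        # one global threshold

y = MultiThreshold()._execute(x, thres, out_scale=2.0, out_bias=-1.0)
assert (y == np.sign(x)).all()  # thresholding {0,1} is mapped to {-1,+1}
```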