Skip to content
Snippets Groups Projects
Commit dadae12d authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

Merge branch 'feature/xnorpopcountmatmul' into dev

parents 15d76acb 669943de
No related branches found
No related tags found
No related merge requests found
# make sure new CustomOp subclasses are imported here so that they get
# registered and plug in correctly into the infrastructure
from finn.custom_op.multithreshold import MultiThreshold
from finn.custom_op.xnorpopcount import XnorPopcountMatMul
# create a mapping of all known CustomOp names and classes
custom_op = {}
custom_op["MultiThreshold"] = MultiThreshold
custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
import numpy as np
import onnx.helper as helper
from finn.core.datatype import DataType
from finn.custom_op import CustomOp
class XnorPopcountMatMul(CustomOp):
def make_shape_compatible_op(self, node):
return helper.make_node(
"MatMul", [node.input[0], node.input[1]], [node.output[0]]
)
def infer_node_datatype(self, node, model):
# ensure inputs are binary
assert model.get_tensor_datatype(node.input[0]) == DataType["BINARY"]
assert model.get_tensor_datatype(node.input[1]) == DataType["BINARY"]
# XNOR-popcount produces unsigned integers, assume uint32
model.set_tensor_datatype(node.output[0], DataType["UINT32"])
def execute_node(self, node, context, graph):
# save inputs
inp0 = context[node.input[0]]
inp1 = context[node.input[1]]
# calculate output
output = self._execute(inp0, inp1)
# set context according to output name
context[node.output[0]] = output
def _execute(self, inp0, inp1):
# extract the operand shapes
(M, K0) = inp0.shape
(K1, N) = inp1.shape
# make sure shapes are compatible with matmul
assert K0 == K1
K = K0
# we simulate XNOR-popcount matrix multiplication as a regular bipolar
# matrix multiplication followed by some post processing
# first, convert binary inputs to bipolar
inp0_bipolar = 2.0 * inp0 - 1.0
inp1_bipolar = 2.0 * inp1 - 1.0
# call regular numpy matrix multiplication
out = np.matmul(inp0_bipolar, inp1_bipolar)
# XNOR-popcount does not produce the regular dot product result --
# it returns the number of +1s after XNOR. let P be the number of +1s
# and N be the number of -1s. XNOR-popcount returns P, whereas the
# regular dot product result from numpy is P-N, so we need to apply
# some correction.
# out = P-N
# K = P+N
# out + K = 2P, so P = (out + K)/2
return (out + K) * 0.5
import numpy as np
from onnx import TensorProto
from onnx import helper as oh
from finn.core.datatype import DataType
from finn.core.utils import get_by_name
from finn.transformation import Transformation
from finn.transformation.infer_shapes import InferShapes
class ConvertBipolarMatMulToXnorPopcount(Transformation):
"""Convert MatMul nodes with all-bipolar inputs to XnorPopcountMatMul
and associated result correction."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "MatMul":
mm_input = n.input[0]
mm_weight = n.input[1]
mm_output = n.output[0]
i_bp = model.get_tensor_datatype(mm_input) == DataType.BIPOLAR
w_bp = model.get_tensor_datatype(mm_weight) == DataType.BIPOLAR
if i_bp and w_bp:
graph_modified = True
# change node type and domain
n.op_type = "XnorPopcountMatMul"
n.domain = "finn"
# convert weights into binary (-1,+1) -> (0,1)
Wbin = (model.get_initializer(mm_weight) + 1) / 2
# extract vector length (common matrix dim)
K = Wbin.shape[0]
model.set_initializer(mm_weight, Wbin)
model.set_tensor_datatype(mm_weight, DataType.BINARY)
# find producing threshold node and adjust output to binary
mt = model.find_producer(mm_input)
if mt is not None and mt.op_type == "MultiThreshold":
bin_dt_attr = "BINARY".encode("utf-8")
get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr
get_by_name(mt.attribute, "out_scale").f = 1.0
get_by_name(mt.attribute, "out_bias").f = 0
model.set_tensor_datatype(mm_input, DataType.BINARY)
else:
raise Exception(
"""Requires Bipolar2Binary, not yet
implemented."""
)
# make new output node with correct shape
mm_out_shape = model.get_tensor_shape(mm_output)
xnorpcout = oh.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, mm_out_shape
)
n.output[0] = xnorpcout.name
model.set_tensor_datatype(xnorpcout.name, DataType.UINT32)
# add mul-add nodes to produce correct dot product result
# need to derive P-N from P and K = P+N
# so we need 2*P-K
A = np.asarray([2.0], dtype=np.float32)
B = np.asarray([-K], dtype=np.float32)
# create value_info and initializers for Mul and Add constants
mul_const = oh.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape
)
graph.value_info.append(mul_const)
model.set_initializer(mul_const.name, A)
mul_output = oh.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, mm_out_shape
)
graph.value_info.append(mul_output)
add_const = oh.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape
)
graph.value_info.append(add_const)
model.set_initializer(add_const.name, B)
# create Mul and Add nodes to replace the batchnorm
mul_node = oh.make_node(
"Mul", [xnorpcout.name, mul_const.name], [mul_output.name]
)
add_node = oh.make_node(
"Add", [mul_output.name, add_const.name], [mm_output]
)
# insert where the batchnorm is to preserve topological ordering
graph.node.insert(node_ind, mul_node)
graph.node.insert(node_ind + 1, add_node)
model = model.transform(InferShapes())
return (model, graph_modified)
class ConvertSignToThres(Transformation):
"""Convert Sign node instances to MultiThreshold with threshold at 0."""
......
import os
from pkgutil import get_data
import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.helper as helper
import onnx.numpy_helper as nph
import torch
from models.LFC import LFC
from onnx import TensorProto
import finn.core.onnx_exec as oxe
import finn.transformation.streamline as sl
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
export_onnx_path = "test_output_lfc.onnx"
# TODO get from config instead, hardcoded to Docker path for now
trained_lfc_checkpoint = (
"/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
)
def test_xnorpopcountmatmul():
M = 1
K = 3
N = 3
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [M, K])
W = helper.make_tensor_value_info("W", TensorProto.FLOAT, [K, N])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT, ["x", "y"])
node_def = helper.make_node(
"XnorPopcountMatMul", ["x", "W"], ["out"], domain="finn"
)
modelproto = helper.make_model(
helper.make_graph([node_def], "test_model", [x], [out], value_info=[W])
)
model = ModelWrapper(modelproto)
model.set_tensor_datatype("x", DataType.BINARY)
model.set_tensor_datatype("W", DataType.BINARY)
W_data = np.asarray([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
model.set_initializer("W", W_data)
# test shape inference
model = model.transform(InferShapes())
assert model.get_tensor_shape("out") == [M, N]
# test datatype inference
assert model.get_tensor_datatype("out") is DataType.FLOAT32
model = model.transform(InferDataTypes())
assert model.get_tensor_datatype("out") is DataType.UINT32
# test execution
x_data = np.asarray([[1, 0, 0]], dtype=np.float32)
inp_dict = {"x": x_data}
out_dict = oxe.execute_onnx(model, inp_dict)
Wb = 2 * W_data - 1
xb = 2 * x_data - 1
rb = np.matmul(xb, Wb)
assert (2 * out_dict["out"] - K == rb).all()
def test_convert_bipolar_matmul_to_xnorpopcountmatmul():
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
lfc.load_state_dict(checkpoint["state_dict"])
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(sl.ConvertSignToThres())
# load one of the test vectors
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
# run using FINN-based execution
input_dict = {"global_in": nph.to_array(input_tensor)}
expected_ctx = oxe.execute_onnx(model, input_dict, True)
expected = expected_ctx[model.graph.output[0].name]
model = model.transform(sl.ConvertBipolarMatMulToXnorPopcount())
produced_ctx = oxe.execute_onnx(model, input_dict, True)
produced = produced_ctx[model.graph.output[0].name]
assert np.isclose(expected, produced, atol=1e-3).all()
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment