Skip to content
Snippets Groups Projects
Commit 669943de authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add test_convert_bipolar_matmul_to_xnorpopcountmatmul

parent 39531898
No related branches found
No related tags found
No related merge requests found
import os
from pkgutil import get_data
import brevitas.onnx as bo
import numpy as np
import onnx
import onnx.helper as helper
import onnx.numpy_helper as nph
import torch
from models.LFC import LFC
from onnx import TensorProto
import finn.core.onnx_exec as oxe
import finn.transformation.streamline as sl
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
export_onnx_path = "test_output_lfc.onnx"
# TODO get from config instead, hardcoded to Docker path for now
trained_lfc_checkpoint = (
"/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
)
def test_xnorpopcountmatmul():
M = 1
......@@ -42,3 +59,28 @@ def test_xnorpopcountmatmul():
xb = 2 * x_data - 1
rb = np.matmul(xb, Wb)
assert (2 * out_dict["out"] - K == rb).all()
def test_convert_bipolar_matmul_to_xnorpopcountmatmul():
lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
lfc.load_state_dict(checkpoint["state_dict"])
bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(sl.ConvertSignToThres())
# load one of the test vectors
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
# run using FINN-based execution
input_dict = {"global_in": nph.to_array(input_tensor)}
expected_ctx = oxe.execute_onnx(model, input_dict, True)
expected = expected_ctx[model.graph.output[0].name]
model = model.transform(sl.ConvertBipolarMatMulToXnorPopcount())
produced_ctx = oxe.execute_onnx(model, input_dict, True)
produced = produced_ctx[model.graph.output[0].name]
assert np.isclose(expected, produced, atol=1e-3).all()
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment