From 63c467c898f8a8ef5763dbf04400a20ec0be4ce9 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Sun, 3 Nov 2019 22:47:59 +0000 Subject: [PATCH] [Test] use LFC as a better test_sign_to_thres testcase --- tests/test_sign_to_thres.py | 48 +++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/tests/test_sign_to_thres.py b/tests/test_sign_to_thres.py index 665638283..4e18822cf 100644 --- a/tests/test_sign_to_thres.py +++ b/tests/test_sign_to_thres.py @@ -1,31 +1,39 @@ -import numpy as np -from onnx import TensorProto, helper +import os +from pkgutil import get_data + +import brevitas.onnx as bo +import onnx +import onnx.numpy_helper as nph +import torch +from models.LFC import LFC import finn.core.onnx_exec as oxe +import finn.transformation.fold_constants as fc import finn.transformation.infer_shapes as si import finn.transformation.streamline as sl from finn.core.modelwrapper import ModelWrapper +export_onnx_path = "test_output_lfc.onnx" +transformed_onnx_path = "test_output_lfc_transformed.onnx" +# TODO get from config instead, hardcoded to Docker path for now +trained_lfc_checkpoint = ( + "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar" +) + def test_sign_to_thres(): - out0 = helper.make_tensor_value_info("out0", TensorProto.FLOAT, [6, 3, 2, 2]) - graph_def = helper.make_graph( - nodes=[ - helper.make_node("Sign", ["v"], ["out0"]), - helper.make_node("Relu", ["out0"], ["out1"]), - ], - name="test-model", - inputs=[helper.make_tensor_value_info("v", TensorProto.FLOAT, [6, 3, 2, 2])], - value_info=[out0], - outputs=[ - helper.make_tensor_value_info("out1", TensorProto.FLOAT, [6, 3, 2, 2]) - ], - ) - model_def = helper.make_model(graph_def, producer_name="finn-test") - model = ModelWrapper(model_def) + lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) + checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + lfc.load_state_dict(checkpoint["state_dict"]) + bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) + model = ModelWrapper(export_onnx_path) model = model.transform_single(si.infer_shapes) - input_dict = {} - input_dict["v"] = np.random.randn(*[6, 3, 2, 2]).astype(np.float32) + model = model.transform_repeated(fc.fold_constants) new_model = model.transform_single(sl.convert_sign_to_thres) - assert new_model.graph.node[0].op_type == "MultiThreshold" + assert new_model.graph.node[3].op_type == "MultiThreshold" + # load one of the test vectors + raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") + input_tensor = onnx.load_tensor_from_string(raw_i) + input_dict = {"0": nph.to_array(input_tensor)} assert oxe.compare_execution(model, new_model, input_dict) + os.remove(export_onnx_path) -- GitLab