From 63c467c898f8a8ef5763dbf04400a20ec0be4ce9 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sun, 3 Nov 2019 22:47:59 +0000
Subject: [PATCH] [Test] use LFC as a better test_sign_to_thres testcase

---
 tests/test_sign_to_thres.py | 48 +++++++++++++++++++++----------------
 1 file changed, 28 insertions(+), 20 deletions(-)

diff --git a/tests/test_sign_to_thres.py b/tests/test_sign_to_thres.py
index 665638283..4e18822cf 100644
--- a/tests/test_sign_to_thres.py
+++ b/tests/test_sign_to_thres.py
@@ -1,31 +1,39 @@
-import numpy as np
-from onnx import TensorProto, helper
+import os
+from pkgutil import get_data
+
+import brevitas.onnx as bo
+import onnx
+import onnx.numpy_helper as nph
+import torch
+from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
+import finn.transformation.fold_constants as fc
 import finn.transformation.infer_shapes as si
 import finn.transformation.streamline as sl
 from finn.core.modelwrapper import ModelWrapper
 
+export_onnx_path = "test_output_lfc.onnx"
+transformed_onnx_path = "test_output_lfc_transformed.onnx"
+# TODO get from config instead, hardcoded to Docker path for now
+trained_lfc_checkpoint = (
+    "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
+)
+
 
 def test_sign_to_thres():
-    out0 = helper.make_tensor_value_info("out0", TensorProto.FLOAT, [6, 3, 2, 2])
-    graph_def = helper.make_graph(
-        nodes=[
-            helper.make_node("Sign", ["v"], ["out0"]),
-            helper.make_node("Relu", ["out0"], ["out1"]),
-        ],
-        name="test-model",
-        inputs=[helper.make_tensor_value_info("v", TensorProto.FLOAT, [6, 3, 2, 2])],
-        value_info=[out0],
-        outputs=[
-            helper.make_tensor_value_info("out1", TensorProto.FLOAT, [6, 3, 2, 2])
-        ],
-    )
-    model_def = helper.make_model(graph_def, producer_name="finn-test")
-    model = ModelWrapper(model_def)
+    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+    checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
+    lfc.load_state_dict(checkpoint["state_dict"])
+    bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
     model = model.transform_single(si.infer_shapes)
-    input_dict = {}
-    input_dict["v"] = np.random.randn(*[6, 3, 2, 2]).astype(np.float32)
+    model = model.transform_repeated(fc.fold_constants)
     new_model = model.transform_single(sl.convert_sign_to_thres)
-    assert new_model.graph.node[0].op_type == "MultiThreshold"
+    assert new_model.graph.node[3].op_type == "MultiThreshold"
+    # load one of the test vectors
+    raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
+    input_tensor = onnx.load_tensor_from_string(raw_i)
+    input_dict = {"0": nph.to_array(input_tensor)}
     assert oxe.compare_execution(model, new_model, input_dict)
+    os.remove(export_onnx_path)
-- 
GitLab