From 4537a31904491e5a0a63430b8b406c2f3e76476b Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Wed, 16 Oct 2019 16:42:09 +0100
Subject: [PATCH] [Test] add skeleton for exported Brevitas ONNX exec test

---
 tests/test_brevitas_export.py | 58 ++++++++++++++++++++++++++++++++++-
 1 file changed, 57 insertions(+), 1 deletion(-)

diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index dbdaba600..ea5287025 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -1,19 +1,34 @@
 import os
+import shutil
 from functools import reduce
 from operator import mul
 
 import brevitas.onnx as bo
+import numpy as np
 import onnx
 import onnx.numpy_helper as nph
+import torch
+import wget
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
 from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
 
+import finn.core.onnx_exec as oxe
+
 FC_OUT_FEATURES = [1024, 1024, 1024]
 INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
 LAST_FC_PER_OUT_CH_SCALING = False
 IN_DROPOUT = 0.2
 HIDDEN_DROPOUT = 0.2
 
+mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist"
+mnist_onnx_filename = "mnist.tar.gz"
+mnist_onnx_local_dir = "/tmp/mnist_onnx"
+export_onnx_path = "test_output_lfc.onnx"
+# TODO get from config instead, hardcoded to Docker path for now
+trained_lfc_checkpoint = (
+    "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
+)
+
 
 class LFC(Module):
     def __init__(
@@ -70,7 +85,6 @@ class LFC(Module):
 
 
 def test_brevitas_to_onnx_export():
-    export_onnx_path = "test_output_lfc.onnx"
     lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = onnx.load(export_onnx_path)
@@ -99,3 +113,45 @@ def test_brevitas_to_onnx_export():
     int_weights_onnx = nph.to_array(model.graph.initializer[init_ind])
     assert (int_weights_onnx == int_weights_pytorch).all()
     os.remove(export_onnx_path)
+
+
+def test_brevitas_to_onnx_export_and_exec():
+    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+    checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
+    lfc.load_state_dict(checkpoint["state_dict"])
+    bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
+    model = onnx.load(export_onnx_path)
+    dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp")
+    shutil.unpack_archive(dl_ret, mnist_onnx_local_dir)
+    # load one of the test vectors
+    input_tensor = onnx.TensorProto()
+    output_tensor = onnx.TensorProto()
+    with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
+        input_tensor.ParseFromString(f.read())
+    with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/output_0.pb", "rb") as f:
+        output_tensor.ParseFromString(f.read())
+    # run using FINN-based execution
+    input_dict = {"0": nph.to_array(input_tensor)}
+    output_dict = oxe.execute_onnx(model, input_dict)
+    assert np.isclose(nph.to_array(output_tensor), output_dict["53"], atol=1e-3).all()
+    # remove the downloaded model and extracted files
+    os.remove(dl_ret)
+    shutil.rmtree(mnist_onnx_local_dir)
+    os.remove(export_onnx_path)
+
+
+class objdict(dict):
+    def __getattr__(self, name):
+        if name in self:
+            return self[name]
+        else:
+            raise AttributeError("No such attribute: " + name)
+
+    def __setattr__(self, name, value):
+        self[name] = value
+
+    def __delattr__(self, name):
+        if name in self:
+            del self[name]
+        else:
+            raise AttributeError("No such attribute: " + name)
-- 
GitLab