From a9d7d896820713778caf99757a6b7fbf1d4442ce Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Tue, 3 Sep 2019 15:33:36 +0100
Subject: [PATCH] use Brevitas ONNX export functions, add some assertions

---
 tests/test_brevitas_export.py | 119 +++++++++-------------------------
 1 file changed, 31 insertions(+), 88 deletions(-)

diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index c464fc21f..d7e1c5e73 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -1,12 +1,12 @@
+import os
 from functools import reduce
 from operator import mul
 
 import onnx
+import onnx.numpy_helper as nph
 import torch
 import torch.onnx
 from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
-from torch import nn
-from torch.autograd import Function
 from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
 
 
@@ -70,90 +70,33 @@ def test_brevitas_to_onnx_export():
             out = self.fc(x)
             return out
 
-    class objdict(dict):
-        def __getattr__(self, name):
-            if name in self:
-                return self[name]
-            else:
-                raise AttributeError("No such attribute: " + name)
-
-        def __setattr__(self, name, value):
-            self[name] = value
-
-        def __delattr__(self, name):
-            if name in self:
-                del self[name]
-            else:
-                raise AttributeError("No such attribute: " + name)
-
-    # TODO: <all this needs to mvoe into Brevitas>
-    quantization_annotation = dict()
-
-    class QuantizedLinearPlaceholderFunction(Function):
-        @staticmethod
-        def symbolic(g, W, x, bw, out_features):
-            # import pdb; pdb.set_trace()
-            quantization_annotation[W.uniqueName()] = str(bw)
-            return g.op("MatMul", W, x, domain_s="finn")
-
-        @staticmethod
-        def forward(ctx, W, x, bw, out_features):
-            return torch.empty(1, out_features, dtype=torch.float)
-
-    class QuantizedLinearPlaceholder(nn.Module):
-        def __init__(self, quantized_linear):
-            super(QuantizedLinearPlaceholder, self).__init__()
-            self.in_features = quantized_linear.in_features
-            self.out_features = quantized_linear.out_features
-            # compute the quantized weights
-            W, s, bitwidth = quantized_linear.weight_quant(quantized_linear.weight)
-            W = W.detach().numpy().reshape(self.out_features, self.in_features)
-            s = s.detach().numpy()
-            s = s.reshape(s.size, 1)
-            W = W / s
-            self.W = torch.from_numpy(W)
-            self.bitwidth = bitwidth.item()
-
-        def forward(self, x):
-            # return linear(self.W, x)
-            return QuantizedLinearPlaceholderFunction.apply(
-                self.W, x, self.bitwidth, self.out_features
-            )
-
-    class QuantizedHardTanhPlaceholderFunction(Function):
-        @staticmethod
-        def symbolic(g, input):
-            ret = g.op("QuantizedHardTanh", input, domain_s="finn")
-            # insert quantization annotation for the resulting tensor, TODO fix bitwidth
-            quantization_annotation[ret.uniqueName()] = "1"
-            return ret
-
-        @staticmethod
-        def forward(ctx, input):
-            return input.clamp(0)
-
-    class QuantizedHardTanhPlaceholder(nn.Module):
-        def __init__(self):
-            super(QuantizedHardTanhPlaceholder, self).__init__()
-
-        def forward(self, x):
-            return QuantizedHardTanhPlaceholderFunction.apply(x)
-
-    # TODO: </all this needs to mvoe into Brevitas>
     export_onnx_path = "test_output_lfc.onnx"
-    lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
-    for i in range(len(lfc.features)):
-        L = lfc.features[i]
-        if type(L).__name__ == "QuantLinear":
-            lfc.features[i] = QuantizedLinearPlaceholder(L)
-        elif type(L).__name__ == "QuantHardTanh":
-            lfc.features[i] = QuantizedHardTanhPlaceholder()
-    lfc.fc = QuantizedLinearPlaceholder(lfc.fc)
-    torch.onnx.export(lfc, torch.empty(784, dtype=torch.float), export_onnx_path)
-    model = onnx.load(export_onnx_path)
-    assert len(model.graph.input) == 16
-    assert len(model.graph.node) == 33
-    assert len(model.graph.output) == 1
-    assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
-    assert model.graph.node[13].op_type == "Constant"
-    assert model.graph.node[12].op_type == "QuantizedHardTanh"
+    with torch.no_grad():
+        lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
+        import brevitas.onnx as bo
+
+        bo.prepare_for_onnx_export(lfc, True)
+        torch.onnx.export(
+            lfc, torch.empty(784, dtype=torch.float), export_onnx_path, verbose=True
+        )
+        model = onnx.load(export_onnx_path)
+        # TODO the following way of testing is highly sensitive to small changes
+        # in PyTorch ONNX export: the order, names, count... of nodes could
+        # easily change between different versions, and break this test.
+        assert len(model.graph.input) == 32
+        assert len(model.graph.node) == 33
+        assert len(model.graph.output) == 1
+        assert model.graph.output[0].type.tensor_type.shape.dim[1].dim_value == 10
+        assert model.graph.node[12].op_type == "QuantizedHardTanh"
+        assert model.graph.node[13].op_type == "Constant"
+        assert model.graph.node[14].op_type == "MatMul"
+        assert model.graph.node[12].output[0] == model.graph.node[14].input[1]
+        assert model.graph.node[13].output[0] == model.graph.node[14].input[0]
+        int_weights_pytorch = lfc.features[2].int_weight.detach().numpy()
+        int_weights_onnx = nph.to_array(model.graph.node[13].attribute[0].t)
+        assert (int_weights_onnx == int_weights_pytorch).all()
+        assert model.graph.node[12].attribute[0].name == "activation_qnt"
+        assert model.graph.node[12].attribute[0].s.decode("utf-8") == "1"
+        assert model.graph.node[14].attribute[1].name == "weight_qnt"
+        assert model.graph.node[14].attribute[1].s.decode("utf-8") == "1"
+        os.remove(export_onnx_path)
-- 
GitLab