From 82f871462f0d718c681a52b51b576495cb80057d Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Fri, 5 Jun 2020 15:43:42 +0100
Subject: [PATCH] [Docker & Test] Update brevitas version and extend avg pool
 export testing

---
 docker/finn_entrypoint.sh                     |  2 +-
 .../brevitas/test_brevitas_avg_pool_export.py | 35 +++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/docker/finn_entrypoint.sh b/docker/finn_entrypoint.sh
index 4baa12aa5..30dae7d2b 100644
--- a/docker/finn_entrypoint.sh
+++ b/docker/finn_entrypoint.sh
@@ -13,7 +13,7 @@ gecho () {
 
 # checkout the correct dependency repo commits
 # the repos themselves are cloned in the Dockerfile
-BREVITAS_COMMIT=d45ac15325c7f33de6a9d2d2f654ef48cb20c49d
+BREVITAS_COMMIT=093de7d138c6715dbcaf82a9e1d530069327ad98
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
 HLSLIB_COMMIT=6b88db826bb023937506913a23d964775a7606af
 PYVERILATOR_COMMIT=1d89cb0d4e0c97469cc6352c611f876ec13edfa6
diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index 3bf6a8ed6..17b2bc5aa 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -1,3 +1,5 @@
+import os
+
 import onnx  # noqa
 import torch
 import numpy as np
@@ -5,6 +7,12 @@ import brevitas.onnx as bo
 from brevitas.nn import QuantAvgPool2d
 from brevitas.quant_tensor import pack_quant_tensor
 from brevitas.core.quant import QuantType
+from finn.core.modelwrapper import ModelWrapper
+from finn.core.datatype import DataType
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
+from finn.util.basic import gen_finn_dt_tensor
+import finn.core.onnx_exec as oxe
 
 import pytest
 
@@ -36,3 +44,30 @@ def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width):
         tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
     )
     bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
+    model = ModelWrapper(export_onnx_path)
+    # set FINN datatype
+    if signed is True:
+        prefix = "INT"
+    else:
+        prefix = "UINT"
+    dt_name = prefix + str(bit_width)
+    dtype = DataType[dt_name]
+    model.set_tensor_datatype(model.graph.input[0].name, dtype)
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+
+    # calculate golden output
+    inp = gen_finn_dt_tensor(dtype, ishape)
+    input_tensor = torch.from_numpy(inp).float()
+    input_quant_tensor = pack_quant_tensor(
+        tensor=input_tensor, scale=output_scale, bit_width=ibw_tensor
+    )
+    b_avgpool.eval()
+    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
+    
+    # finn execution
+    idict = {model.graph.input[0].name : inp}
+    odict = oxe.execute_onnx(model, idict, True)
+    produced = odict[model.graph.output[0].name]
+
+    os.remove(export_onnx_path)
-- 
GitLab