From 2e41ac8c2c046e325e50de825a525d1030b3f6d9 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Thu, 18 Jun 2020 14:28:36 +0100
Subject: [PATCH] [Test] Change input tensor to use values of smaller finn
 dtypes

---
 tests/brevitas/test_brevitas_avg_pool_export.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py
index 34b154c4e..24854a215 100644
--- a/tests/brevitas/test_brevitas_avg_pool_export.py
+++ b/tests/brevitas/test_brevitas_avg_pool_export.py
@@ -55,10 +55,9 @@ def test_brevitas_avg_pool_export(
         prefix = "INT"
     else:
         prefix = "UINT"
-    dt_name = prefix + str(input_bit_width)
+    dt_name = prefix + str(input_bit_width // 2)
     dtype = DataType[dt_name]
     model = model.transform(InferShapes())
-    model.set_tensor_datatype(model.graph.input[0].name, dtype)
     model = model.transform(InferDataTypes())
 
     # execution with input tensor using integers and scale = 1
@@ -91,6 +90,7 @@ def test_brevitas_avg_pool_export(
     bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
     model = ModelWrapper(export_onnx_path)
     model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
     b_avgpool.eval()
     expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
     # finn execution
-- 
GitLab