From 2e41ac8c2c046e325e50de825a525d1030b3f6d9 Mon Sep 17 00:00:00 2001 From: auphelia <jakobapk@web.de> Date: Thu, 18 Jun 2020 14:28:36 +0100 Subject: [PATCH] [Test] Change input tensor to use values of smaller finn dtypes --- tests/brevitas/test_brevitas_avg_pool_export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/test_brevitas_avg_pool_export.py b/tests/brevitas/test_brevitas_avg_pool_export.py index 34b154c4e..24854a215 100644 --- a/tests/brevitas/test_brevitas_avg_pool_export.py +++ b/tests/brevitas/test_brevitas_avg_pool_export.py @@ -55,10 +55,9 @@ def test_brevitas_avg_pool_export( prefix = "INT" else: prefix = "UINT" - dt_name = prefix + str(input_bit_width) + dt_name = prefix + str(input_bit_width // 2) dtype = DataType[dt_name] model = model.transform(InferShapes()) - model.set_tensor_datatype(model.graph.input[0].name, dtype) model = model.transform(InferDataTypes()) # execution with input tensor using integers and scale = 1 @@ -91,6 +90,7 @@ def test_brevitas_avg_pool_export( bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) b_avgpool.eval() expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy() # finn execution -- GitLab