diff --git a/tests/core/test_modelwrapper.py b/tests/core/test_modelwrapper.py index c04deb616625d45358011db9f2b78b4d245f514b..d1da6934a5db07aabe41a9ca40b5de497b6460a1 100644 --- a/tests/core/test_modelwrapper.py +++ b/tests/core/test_modelwrapper.py @@ -31,7 +31,7 @@ import onnx from collections import Counter import brevitas.onnx as bo import numpy as np -import finn.core.data_layout as data_layout +import finn.core.data_layout as DataLayout from finn.core.modelwrapper import ModelWrapper from finn.util.test import get_test_model_trained @@ -70,7 +70,7 @@ def test_modelwrapper(): assert out_prod.op_type == "MultiThreshold" inp_layout = model.get_tensor_layout(inp_name) assert inp_layout is None - inp_layout = data_layout.NCHW + inp_layout = DataLayout.NCHW model.set_tensor_layout(inp_name, inp_layout) assert model.get_tensor_layout(inp_name) == inp_layout os.remove(export_onnx_path)