diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index d7e1c5e732e29be300e17620d3ca5ea792c5c477..d0df91f74350bd11f9a1d2309e7aa6c6fe16b161 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -2,6 +2,7 @@ import os from functools import reduce from operator import mul +import brevitas.onnx as bo import onnx import onnx.numpy_helper as nph import torch @@ -73,8 +74,6 @@ def test_brevitas_to_onnx_export(): export_onnx_path = "test_output_lfc.onnx" with torch.no_grad(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) - import brevitas.onnx as bo - bo.prepare_for_onnx_export(lfc, True) torch.onnx.export( lfc, torch.empty(784, dtype=torch.float), export_onnx_path, verbose=True