diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index 23d8af3d1a8e466e21be64a2d46b50241dcb8850..b3aa98793302b10fce7ab4b086e46229bbbf7958 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -74,6 +74,7 @@ def test_brevitas_to_onnx_export(): export_onnx_path = "test_output_lfc.onnx" with torch.no_grad(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) + lfc = lfc.eval() bo.prepare_for_onnx_export(lfc, True) torch.onnx.export( lfc, torch.empty(784, dtype=torch.float), export_onnx_path, verbose=True