diff --git a/tests/transformation/streamline/test_streamline_cnv.py b/tests/transformation/streamline/test_streamline_cnv.py index 56dcd26076ec0a5fba6e9be6acac7f5e13572c3d..103967dfb6b86cc6e2ce2bc9ab78249d8945d47d 100644 --- a/tests/transformation/streamline/test_streamline_cnv.py +++ b/tests/transformation/streamline/test_streamline_cnv.py @@ -44,9 +44,9 @@ from finn.transformation.double_to_single_float import DoubleToSingleFloat export_onnx_path = make_build_dir("test_streamline_cnv_") # act bits -@pytest.mark.parametrize("abits", [1]) +@pytest.mark.parametrize("abits", [1, 2]) # weight bits -@pytest.mark.parametrize("wbits", [1]) +@pytest.mark.parametrize("wbits", [1, 2]) # network topology / size @pytest.mark.parametrize("size", ["CNV"]) def test_streamline_cnv(size, wbits, abits): @@ -74,6 +74,7 @@ def test_streamline_cnv(size, wbits, abits): # model.save("orig_cnv.onnx") model = model.transform(Streamline()) # model.save("streamlined_cnv.onnx") + assert len(model.graph.node) == 23 produced_ctx = oxe.execute_onnx(model, input_dict, True) produced = produced_ctx[model.graph.output[0].name] assert np.isclose(expected, produced, atol=1e-3).all()