Skip to content
Snippets Groups Projects
Unverified Commit 396cb0ad authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

Merge pull request #339 from Xilinx/feature/brevitas_bias_fix

Upgrade Brevitas to get bias export fixes
parents 23ca60b8 d732bbd8
No related branches found
No related tags found
No related merge requests found
...@@ -11,7 +11,7 @@ jobs: ...@@ -11,7 +11,7 @@ jobs:
test: test:
name: Run quicktest on PR branch name: Run quicktest on PR branch
runs-on: ubuntu-16.04 runs-on: ubuntu-18.04
steps: steps:
- name: checkout - name: checkout
......
...@@ -14,7 +14,7 @@ gecho () { ...@@ -14,7 +14,7 @@ gecho () {
# the repos themselves are cloned in the Dockerfile # the repos themselves are cloned in the Dockerfile
FINN_BASE_COMMIT=ac0b86a63eb937b869bfa453a996a8a8b8506546 FINN_BASE_COMMIT=ac0b86a63eb937b869bfa453a996a8a8b8506546
FINN_EXP_COMMIT=e9f97dcdb4db2f889b0f36af079a6a1792b7d4de FINN_EXP_COMMIT=e9f97dcdb4db2f889b0f36af079a6a1792b7d4de
BREVITAS_COMMIT=14abbe1e7ef82485d79415871fcf5766b0a40a00 BREVITAS_COMMIT=d7ded80fa9557da2998ea310669edee7fb2d9526
CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
HLSLIB_COMMIT=4d74baefa79df48b5a0348d63f39a26df075de51 HLSLIB_COMMIT=4d74baefa79df48b5a0348d63f39a26df075de51
PYVERILATOR_COMMIT=e2ff74030de3992dcac54bf1b6aad2915946e8cb PYVERILATOR_COMMIT=e2ff74030de3992dcac54bf1b6aad2915946e8cb
......
...@@ -6,6 +6,7 @@ future==0.18.2 ...@@ -6,6 +6,7 @@ future==0.18.2
gspread==3.6.0 gspread==3.6.0
numpy==1.18.0 numpy==1.18.0
onnx==1.7.0 onnx==1.7.0
onnxoptimizer==0.2.6
onnxruntime==1.4.0 onnxruntime==1.4.0
pre-commit==2.6.0 pre-commit==2.6.0
scipy==1.5.2 scipy==1.5.2
......
...@@ -50,8 +50,6 @@ export_onnx_path = "test_brevitas_conv.onnx" ...@@ -50,8 +50,6 @@ export_onnx_path = "test_brevitas_conv.onnx"
@pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("in_channels", [32]) @pytest.mark.parametrize("in_channels", [32])
def test_brevitas_QConv2d(dw, bias, in_channels): def test_brevitas_QConv2d(dw, bias, in_channels):
if bias:
pytest.xfail("bias export bug")
ishape = (1, 32, 111, 111) ishape = (1, 32, 111, 111)
if dw is True: if dw is True:
groups = in_channels groups = in_channels
......
...@@ -48,8 +48,6 @@ export_onnx_path = "test_brevitas_qlinear.onnx" ...@@ -48,8 +48,6 @@ export_onnx_path = "test_brevitas_qlinear.onnx"
@pytest.mark.parametrize("w_bits", [4]) @pytest.mark.parametrize("w_bits", [4])
@pytest.mark.parametrize("i_dtype", [DataType.UINT4]) @pytest.mark.parametrize("i_dtype", [DataType.UINT4])
def test_brevitas_qlinear(bias, out_features, in_features, w_bits, i_dtype): def test_brevitas_qlinear(bias, out_features, in_features, w_bits, i_dtype):
if bias:
pytest.xfail("bias export bug")
i_shape = (1, in_features) i_shape = (1, in_features)
w_shape = (out_features, in_features) w_shape = (out_features, in_features)
b_linear = QuantLinear( b_linear = QuantLinear(
......
...@@ -347,6 +347,8 @@ class TestEnd2End: ...@@ -347,6 +347,8 @@ class TestEnd2End:
assert os.path.isfile(chkpt_preproc_name) assert os.path.isfile(chkpt_preproc_name)
# join preprocessing and core model # join preprocessing and core model
pre_model = ModelWrapper(chkpt_preproc_name) pre_model = ModelWrapper(chkpt_preproc_name)
pre_model = pre_model.transform(InferShapes())
pre_model = pre_model.transform(FoldConstants())
model = model.transform(MergeONNXModels(pre_model)) model = model.transform(MergeONNXModels(pre_model))
# add input quantization annotation: UINT8 for all BNN-PYNQ models # add input quantization annotation: UINT8 for all BNN-PYNQ models
global_inp_name = model.graph.input[0].name global_inp_name = model.graph.input[0].name
......
...@@ -101,6 +101,7 @@ def test_end2end_mobilenet_export(): ...@@ -101,6 +101,7 @@ def test_end2end_mobilenet_export():
# set input finn datatype to UINT8 # set input finn datatype to UINT8
preproc_model.set_tensor_datatype(preproc_model.graph.input[0].name, DataType.UINT8) preproc_model.set_tensor_datatype(preproc_model.graph.input[0].name, DataType.UINT8)
preproc_model = preproc_model.transform(InferShapes()) preproc_model = preproc_model.transform(InferShapes())
preproc_model = preproc_model.transform(FoldConstants())
preproc_model = preproc_model.transform(GiveUniqueNodeNames()) preproc_model = preproc_model.transform(GiveUniqueNodeNames())
preproc_model = preproc_model.transform(GiveUniqueParameterTensors()) preproc_model = preproc_model.transform(GiveUniqueParameterTensors())
preproc_model = preproc_model.transform(GiveReadableTensorNames()) preproc_model = preproc_model.transform(GiveReadableTensorNames())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment