Skip to content
Snippets Groups Projects
Commit 205528b6 authored by icolbert's avatar icolbert
Browse files

Adding unit test

parent 7c92c0f4
No related branches found
No related tags found
No related merge requests found
...@@ -89,6 +89,8 @@ from finn.transformation.streamline.reorder import ( ...@@ -89,6 +89,8 @@ from finn.transformation.streamline.reorder import (
MakeMaxPoolNHWC, MakeMaxPoolNHWC,
MoveScalarLinearPastInvariants, MoveScalarLinearPastInvariants,
) )
from finn.transformation.fpgadataflow.minimize_accumulator_width import MinimizeAccumulatorWidth
from finn.transformation.fpgadataflow.minimize_weight_bit_width import MinimizeWeightBitWidth
from finn.util.basic import get_finn_root from finn.util.basic import get_finn_root
from finn.util.gdrive import upload_to_end2end_dashboard from finn.util.gdrive import upload_to_end2end_dashboard
from finn.util.pytorch import ToTensor from finn.util.pytorch import ToTensor
...@@ -511,11 +513,23 @@ class TestEnd2End: ...@@ -511,11 +513,23 @@ class TestEnd2End:
model = folding_fxn(model) model = folding_fxn(model)
model.save(get_checkpoint_name(topology, wbits, abits, QONNX_export, "fold")) model.save(get_checkpoint_name(topology, wbits, abits, QONNX_export, "fold"))
def test_minimize_bit_width(self, topology, wbits, abits, QONNX_export):
prev_chkpt_name = get_checkpoint_name(
topology, wbits, abits, QONNX_export, "fold"
)
model = load_test_checkpoint_or_skip(prev_chkpt_name)
model = model.transform(MinimizeAccumulatorWidth())
model = model.transform(MinimizeWeightBitWidth())
curr_chkpt_name = get_checkpoint_name(
topology, wbits, abits, QONNX_export, "minimize_bit_width"
)
model.save(curr_chkpt_name)
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.vivado @pytest.mark.vivado
def test_cppsim(self, topology, wbits, abits, QONNX_export): def test_cppsim(self, topology, wbits, abits, QONNX_export):
prev_chkpt_name = get_checkpoint_name( prev_chkpt_name = get_checkpoint_name(
topology, wbits, abits, QONNX_export, "fold" topology, wbits, abits, QONNX_export, "minimize_bit_width"
) )
model = load_test_checkpoint_or_skip(prev_chkpt_name) model = load_test_checkpoint_or_skip(prev_chkpt_name)
model = model.transform(PrepareCppSim()) model = model.transform(PrepareCppSim())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment