Skip to content
Snippets Groups Projects
Commit 880f0f4f authored by icolbert's avatar icolbert
Browse files

Fixing test_minimize_accumulator_width()

parent 5fc80734
No related branches found
No related tags found
No related merge requests found
...@@ -161,9 +161,9 @@ def test_minimize_weight_bit_width(wdt: DataType, rww: bool): ...@@ -161,9 +161,9 @@ def test_minimize_weight_bit_width(wdt: DataType, rww: bool):
@pytest.mark.parametrize("wdt", weight_data_types) @pytest.mark.parametrize("wdt", weight_data_types)
@pytest.mark.parametrize("adt", input_data_types) @pytest.mark.parametrize("idt", input_data_types)
@pytest.mark.parametrize("rww", [True, False]) @pytest.mark.parametrize("rww", [True, False])
def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool): def test_minimize_accumulator_width(wdt: DataType, idt:DataType, rww: bool):
"""Testing MinimizeAccumulatorWidth for VVAU and MVAU. """Testing MinimizeAccumulatorWidth for VVAU and MVAU.
:param wdt: (DataType) The data type that we are testing for the weights :param wdt: (DataType) The data type that we are testing for the weights
...@@ -191,7 +191,7 @@ def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool): ...@@ -191,7 +191,7 @@ def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool):
inst = getCustomOp(node) inst = getCustomOp(node)
if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)): if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)):
cur_adt = DataType[inst.get_nodeattr("accDataType")] cur_adt = DataType[inst.get_nodeattr("accDataType")]
cur_odt = DataType[inst.get_nodeattr("accDataType")] cur_odt = DataType[inst.get_nodeattr("outputDataType")]
# TODO - figure out how to calculate expected accDataType # TODO - figure out how to calculate expected accDataType
# exp_wdt = def_wdt if rww else wdt # exp_wdt = def_wdt if rww else wdt
# assert cur_adt.bitwidth() == exp_adt.bitwidth(), "Mismatched data types" # assert cur_adt.bitwidth() == exp_adt.bitwidth(), "Mismatched data types"
...@@ -199,4 +199,4 @@ def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool): ...@@ -199,4 +199,4 @@ def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool):
assert (cur_adt.bitwidth() % 8) == 0, "bit width of last node needs to be divisible by 8" assert (cur_adt.bitwidth() % 8) == 0, "bit width of last node needs to be divisible by 8"
assert cur_adt.bitwidth() == cur_odt.bitwidth(), "outputDataType and accDataType should be equal" assert cur_adt.bitwidth() == cur_odt.bitwidth(), "outputDataType and accDataType should be equal"
else: else:
assert cur_adt.bitwidth() == idt.bitwidth(), "outputDataType should not be changed" assert cur_odt.bitwidth() == idt.bitwidth(), "outputDataType should not be changed"
\ No newline at end of file \ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment