Skip to content
Snippets Groups Projects
Commit 880f0f4f authored by icolbert's avatar icolbert
Browse files

Fixing test_minimize_accumulator_width()

parent 5fc80734
No related branches found
No related tags found
No related merge requests found
......@@ -161,9 +161,9 @@ def test_minimize_weight_bit_width(wdt: DataType, rww: bool):
@pytest.mark.parametrize("wdt", weight_data_types)
@pytest.mark.parametrize("adt", input_data_types)
@pytest.mark.parametrize("idt", input_data_types)
@pytest.mark.parametrize("rww", [True, False])
def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool):
def test_minimize_accumulator_width(wdt: DataType, idt:DataType, rww: bool):
"""Testing MinimizeAccumulatorWidth for VVAU and MVAU.
:param wdt: (DataType) The data type that we are testing for the weights
......@@ -191,7 +191,7 @@ def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool):
inst = getCustomOp(node)
if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)):
cur_adt = DataType[inst.get_nodeattr("accDataType")]
cur_odt = DataType[inst.get_nodeattr("accDataType")]
cur_odt = DataType[inst.get_nodeattr("outputDataType")]
# TODO - figure out how to calculate expected accDataType
# exp_wdt = def_wdt if rww else wdt
# assert cur_adt.bitwidth() == exp_adt.bitwidth(), "Mismatched data types"
......@@ -199,4 +199,4 @@ def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool):
assert (cur_adt.bitwidth() % 8) == 0, "bit width of last node needs to be divisible by 8"
assert cur_adt.bitwidth() == cur_odt.bitwidth(), "outputDataType and accDataType should be equal"
else:
assert cur_adt.bitwidth() == idt.bitwidth(), "outputDataType should not be changed"
\ No newline at end of file
assert cur_odt.bitwidth() == idt.bitwidth(), "outputDataType should not be changed"
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment