diff --git a/tests/transformation/streamline/test_minimize_bit_width.py b/tests/transformation/streamline/test_minimize_bit_width.py index 658481cc6d170bc8bc855e12722361928f060549..4995f45ebae840a1df6ffa7bbebfb85534158657 100644 --- a/tests/transformation/streamline/test_minimize_bit_width.py +++ b/tests/transformation/streamline/test_minimize_bit_width.py @@ -161,9 +161,9 @@ def test_minimize_weight_bit_width(wdt: DataType, rww: bool): @pytest.mark.parametrize("wdt", weight_data_types) -@pytest.mark.parametrize("adt", input_data_types) +@pytest.mark.parametrize("idt", input_data_types) @pytest.mark.parametrize("rww", [True, False]) -def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool): +def test_minimize_accumulator_width(wdt: DataType, idt:DataType, rww: bool): """Testing MinimizeAccumulatorWidth for VVAU and MVAU. :param wdt: (DataType) The data type that we are testing for the weights @@ -191,7 +191,7 @@ def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool): inst = getCustomOp(node) if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)): cur_adt = DataType[inst.get_nodeattr("accDataType")] - cur_odt = DataType[inst.get_nodeattr("accDataType")] + cur_odt = DataType[inst.get_nodeattr("outputDataType")] # TODO - figure out how to calculate expected accDataType # exp_wdt = def_wdt if rww else wdt # assert cur_adt.bitwidth() == exp_adt.bitwidth(), "Mismatched data types" @@ -199,4 +199,4 @@ def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool): assert (cur_adt.bitwidth() % 8) == 0, "bit width of last node needs to be divisible by 8" assert cur_adt.bitwidth() == cur_odt.bitwidth(), "outputDataType and accDataType should be equal" else: - assert cur_adt.bitwidth() == idt.bitwidth(), "outputDataType should not be changed" \ No newline at end of file + assert cur_odt.bitwidth() == idt.bitwidth(), "outputDataType should not be changed" \ No newline at end of file