From 880f0f4f587c0fdf8e96890b730e09b2928871f4 Mon Sep 17 00:00:00 2001 From: icolbert <Ian.Colbert@amd.com> Date: Mon, 27 Feb 2023 16:55:11 -0800 Subject: [PATCH] Fixing test_minimize_accumulator_width() --- .../transformation/streamline/test_minimize_bit_width.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/transformation/streamline/test_minimize_bit_width.py b/tests/transformation/streamline/test_minimize_bit_width.py index 658481cc6..4995f45eb 100644 --- a/tests/transformation/streamline/test_minimize_bit_width.py +++ b/tests/transformation/streamline/test_minimize_bit_width.py @@ -161,9 +161,9 @@ def test_minimize_weight_bit_width(wdt: DataType, rww: bool): @pytest.mark.parametrize("wdt", weight_data_types) -@pytest.mark.parametrize("adt", input_data_types) +@pytest.mark.parametrize("idt", input_data_types) @pytest.mark.parametrize("rww", [True, False]) -def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool): +def test_minimize_accumulator_width(wdt: DataType, idt:DataType, rww: bool): """Testing MinimizeAccumulatorWidth for VVAU and MVAU. :param wdt: (DataType) The data type that we are testing for the weights @@ -191,7 +191,7 @@ def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool): inst = getCustomOp(node) if isinstance(inst, (MatrixVectorActivation, VectorVectorActivation)): cur_adt = DataType[inst.get_nodeattr("accDataType")] - cur_odt = DataType[inst.get_nodeattr("accDataType")] + cur_odt = DataType[inst.get_nodeattr("outputDataType")] # TODO - figure out how to calculate expected accDataType # exp_wdt = def_wdt if rww else wdt # assert cur_adt.bitwidth() == exp_adt.bitwidth(), "Mismatched data types" @@ -199,4 +199,4 @@ def test_minimize_weight_bit_width(wdt: DataType, idt:DataType, rww: bool): assert (cur_adt.bitwidth() % 8) == 0, "bit width of last node needs to be divisible by 8" assert cur_adt.bitwidth() == cur_odt.bitwidth(), "outputDataType and accDataType should be equal" else: - assert cur_adt.bitwidth() == idt.bitwidth(), "outputDataType should not be changed" \ No newline at end of file + assert cur_odt.bitwidth() == idt.bitwidth(), "outputDataType should not be changed" \ No newline at end of file -- GitLab