diff --git a/tests/transformation/test_move_transpose_past_scalar_mul.py b/tests/transformation/test_move_transpose_past_scalar_mul.py index 42b6f22d0fca50f0c9f55b63c89d35413c2d8e31..2fef09a78cc2f1b532c767a6dfd360b34779ff34 100644 --- a/tests/transformation/test_move_transpose_past_scalar_mul.py +++ b/tests/transformation/test_move_transpose_past_scalar_mul.py @@ -51,7 +51,7 @@ def test_move_transpose_past_scalar_mul(perm, scalar): # compare execution before and after transformation inp_values = np.random.uniform(low=0, high=1, size=(1, 2, 3, 4)).astype(np.float32) - idict = {"inp": inp_values} + idict = {model.graph.input[0].name: inp_values} model_transformed = model.transform(MoveTransposePastScalarMul()) assert oxe.compare_execution(model, model_transformed, idict)