diff --git a/tests/test_streamline.py b/tests/test_streamline.py index 6e9f14647b72fe2fc90d4fa0f54234bcd561f9c3..7b74fddc058b82b6688974fe73d04b67c3de20fa 100644 --- a/tests/test_streamline.py +++ b/tests/test_streamline.py @@ -43,20 +43,24 @@ def test_streamline_lfc_w1a1(): transforms = [ tg.convert_sub_to_add, ba.batchnorm_to_affine, + sl.convert_sign_to_thres, sl.move_scalar_add_past_matmul, sl.move_scalar_mul_past_matmul, sl.move_add_past_mul, sl.collapse_repeated_add, sl.collapse_repeated_mul, - sl.convert_sign_to_thres, sl.absorb_add_into_multi_threshold, + sl.factor_out_mul_sign_magnitude, sl.absorb_mul_into_multi_threshold, ] + trn_ind = 0 for trn in transforms: model = model.transform_repeated(trn) model = model.transform_single(tg.give_unique_node_names) model = model.transform_single(tg.give_readable_tensor_names) produced_ctx = oxe.execute_onnx(model, input_dict, True) produced = produced_ctx[model.graph.output[0].name] + # model.save("%d-%s.onnx" % (trn_ind, trn.__name__)) assert np.isclose(expected, produced, atol=1e-3).all() + trn_ind += 1 os.remove(export_onnx_path)