diff --git a/tests/transformation/test_infer_datatypes.py b/tests/transformation/test_infer_datatypes.py index 77b6a94f8ed891a4fe761fe864a6e18d35e84382..e3db40289c4318894cf5ad41c2f67b3bff501db9 100644 --- a/tests/transformation/test_infer_datatypes.py +++ b/tests/transformation/test_infer_datatypes.py @@ -54,8 +54,8 @@ def test_infer_datatypes(): assert model.get_tensor_datatype("MatMul_1_out0") == DataType.INT32 assert model.get_tensor_datatype("MatMul_2_out0") == DataType.INT32 assert model.get_tensor_datatype("MatMul_3_out0") == DataType.INT32 - assert model.get_tensor_datatype("Sign_0_out0") == DataType.BIPOLAR - assert model.get_tensor_datatype("Sign_1_out0") == DataType.BIPOLAR - assert model.get_tensor_datatype("Sign_2_out0") == DataType.BIPOLAR - assert model.get_tensor_datatype("Sign_3_out0") == DataType.BIPOLAR + assert model.get_tensor_datatype("MultiThreshold_0_out0") == DataType.BIPOLAR + assert model.get_tensor_datatype("MultiThreshold_1_out0") == DataType.BIPOLAR + assert model.get_tensor_datatype("MultiThreshold_2_out0") == DataType.BIPOLAR + assert model.get_tensor_datatype("MultiThreshold_3_out0") == DataType.BIPOLAR os.remove(export_onnx_path)