diff --git a/tests/run.py b/tests/run.py index c4b3e92e36540c28a06690436e09078073bdf8ad..693eda6f2ff8271a3afbb414513340ab78a8aa51 100755 --- a/tests/run.py +++ b/tests/run.py @@ -29,6 +29,7 @@ import subprocess import tempfile import typing import unittest +import re ROOT_DIRECTORY = pathlib.Path(__file__).parent.parent @@ -241,6 +242,39 @@ def float_or_nan(value: str) -> float: return math.nan +def extract_numbers(text: typing.Iterable[str]) -> typing.Iterable[float]: + """ + Extracts the numbers (separated by ',', '(', ')', ']', '[', '=', or whitespace) from text + and returns an iterator. + + >>> list(extract_numbers(["1.0, 2.0"])) + [1.0, 2.0] + >>> list(extract_numbers(["x=1.0"])) + [1.0] + >>> list(extract_numbers(["[1.0]"])) + [1.0] + >>> list(extract_numbers(["1.e1"])) + [10.0] + >>> list(extract_numbers(["NaN"])) + [] + >>> list(extract_numbers(["1.0 (2.0)"])) + [1.0, 2.0] + >>> list(extract_numbers(["1.0 a 2.0"])) + [1.0, 2.0] + >>> list(extract_numbers(["1.0 a 2.0", "b 3.0"])) + [1.0, 2.0, 3.0] + """ + return filter( + lambda x: not math.isnan(x), + map( + float_or_nan, + itertools.chain.from_iterable( + map(lambda x: re.split(r"[,\(\)\[\]=\s]", x), text) + ), + ), + ) + + def diff_numeric_string( value: str, reference: str, case: unittest.TestCase = unittest.TestCase() ): @@ -250,10 +284,15 @@ def diff_numeric_string( >>> diff_numeric_string("a", "a") >>> diff_numeric_string("1", "1") + >>> diff_numeric_string("1 2", "1 2") >>> diff_numeric_string("1", "2") Traceback (most recent call last): ... AssertionError: 1.0 != 2.0 within 7 places (1.0 difference) + >>> diff_numeric_string("a 1", "b 2") + Traceback (most recent call last): + ... + AssertionError: 1.0 != 2.0 within 7 places (1.0 difference) >>> diff_numeric_string("a", "1") Traceback (most recent call last): ... @@ -261,17 +300,14 @@ def diff_numeric_string( >>> diff_numeric_string("1", "a") Traceback (most recent call last): ... - AssertionError: False is not true + AssertionError: 1.0 != nan within 7 places (nan difference) """ for float_value, float_reference in itertools.zip_longest( - map(float_or_nan, iter(value.splitlines())), - map(float_or_nan, iter(reference.splitlines())), + extract_numbers(iter(value.splitlines())), + extract_numbers(iter(reference.splitlines())), fillvalue=math.nan, ): - if math.isnan(float_reference): - case.assertTrue(math.isnan(float_value)) - else: - case.assertAlmostEqual(float_value, float_reference) + case.assertAlmostEqual(float_value, float_reference) def load_tests(