From 05c17ac6f29f1286896252ff795b9046a6863470 Mon Sep 17 00:00:00 2001
From: Manuel Weberndorfer <manuel.weberndorfer@id.ethz.ch>
Date: Fri, 16 Jul 2021 13:15:02 +0000
Subject: [PATCH] compare multiple numbers per line for 'numeric'

---
 tests/run.py | 50 +++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 43 insertions(+), 7 deletions(-)

diff --git a/tests/run.py b/tests/run.py
index c4b3e92e..693eda6f 100755
--- a/tests/run.py
+++ b/tests/run.py
@@ -29,6 +29,7 @@ import subprocess
 import tempfile
 import typing
 import unittest
+import re
 
 
 ROOT_DIRECTORY = pathlib.Path(__file__).parent.parent
@@ -241,6 +242,39 @@ def float_or_nan(value: str) -> float:
         return math.nan
 
 
+def extract_numbers(text: typing.Iterable[str]) -> typing.Iterable[float]:
+    """
+    Extracts the numbers (separated by ',', '(', ')', ']', '[', '=', or whitespace) from text
+    and returns an iterator.
+
+    >>> list(extract_numbers(["1.0, 2.0"]))
+    [1.0, 2.0]
+    >>> list(extract_numbers(["x=1.0"]))
+    [1.0]
+    >>> list(extract_numbers(["[1.0]"]))
+    [1.0]
+    >>> list(extract_numbers(["1.e1"]))
+    [10.0]
+    >>> list(extract_numbers(["NaN"]))
+    []
+    >>> list(extract_numbers(["1.0 (2.0)"]))
+    [1.0, 2.0]
+    >>> list(extract_numbers(["1.0 a 2.0"]))
+    [1.0, 2.0]
+    >>> list(extract_numbers(["1.0 a 2.0", "b 3.0"]))
+    [1.0, 2.0, 3.0]
+    """
+    return filter(
+        lambda x: not math.isnan(x),
+        map(
+            float_or_nan,
+            itertools.chain.from_iterable(
+                map(lambda x: re.split(r"[,\(\)\[\]=\s]", x), text)
+            ),
+        ),
+    )
+
+
 def diff_numeric_string(
     value: str, reference: str, case: unittest.TestCase = unittest.TestCase()
 ):
@@ -250,10 +284,15 @@ def diff_numeric_string(
 
     >>> diff_numeric_string("a", "a")
     >>> diff_numeric_string("1", "1")
+    >>> diff_numeric_string("1 2", "1 2")
     >>> diff_numeric_string("1", "2")
     Traceback (most recent call last):
     ...
     AssertionError: 1.0 != 2.0 within 7 places (1.0 difference)
+    >>> diff_numeric_string("a 1", "b 2")
+    Traceback (most recent call last):
+    ...
+    AssertionError: 1.0 != 2.0 within 7 places (1.0 difference)
     >>> diff_numeric_string("a", "1")
     Traceback (most recent call last):
     ...
@@ -261,17 +300,14 @@ def diff_numeric_string(
     >>> diff_numeric_string("1", "a")
     Traceback (most recent call last):
     ...
-    AssertionError: False is not true
+    AssertionError: 1.0 != nan within 7 places (nan difference)
     """
     for float_value, float_reference in itertools.zip_longest(
-        map(float_or_nan, iter(value.splitlines())),
-        map(float_or_nan, iter(reference.splitlines())),
+        extract_numbers(iter(value.splitlines())),
+        extract_numbers(iter(reference.splitlines())),
         fillvalue=math.nan,
     ):
-        if math.isnan(float_reference):
-            case.assertTrue(math.isnan(float_value))
-        else:
-            case.assertAlmostEqual(float_value, float_reference)
+        case.assertAlmostEqual(float_value, float_reference)
 
 
 def load_tests(
-- 
GitLab