From 967aef7d802c9ab09afbd5a8a97b6bb223115036 Mon Sep 17 00:00:00 2001
From: Manuel Weberndorfer <manuel.weberndorfer@id.ethz.ch>
Date: Tue, 1 Mar 2022 19:01:07 +0000
Subject: [PATCH] add script to compare xdmf files

- compares properties of unstructured grids
---
 .gitlab-ci.yml     |   2 +-
 tests/xdmf_diff.py | 254 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 255 insertions(+), 1 deletion(-)
 create mode 100755 tests/xdmf_diff.py

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index e3b825f0..502b8adb 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -93,7 +93,7 @@ check-script:
     image: $CI_REGISTRY_IMAGE/dev-real:$CI_COMMIT_REF_SLUG
     script:
         - find . -name "*.py" -print0 | xargs -0 python3 -m black --check
-        - find . -name "*.py" -print0 | xargs -0 python3 -m pylint
+        - find . -name "*.py" -print0 | xargs -0 python3 -m pylint --generated-members=vtk.*
         - find . -name "*.py" -print0 | xargs -0 python3 -m mypy --python-version 3.7 --ignore-missing
         - find . -name "*.py" -print0 | xargs -0 python3 -m doctest
 
diff --git a/tests/xdmf_diff.py b/tests/xdmf_diff.py
new file mode 100755
index 00000000..5fa1c14b
--- /dev/null
+++ b/tests/xdmf_diff.py
@@ -0,0 +1,254 @@
+#!/usr/bin/env python3
+
+# © 2022 ETH Zurich, Mechanics and Materials Lab
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Compares the contents of two XDMF files.
+"""
+
+import unittest
+import typing
+import pathlib
+import argparse
+import sys
+from functools import partial
+import vtk
+
+RESULT_GRID = vtk.vtkUnstructuredGrid()
+REFERENCE_GRID = vtk.vtkUnstructuredGrid()
+
+
+def point_coordinates(
+    grid: vtk.vtkUnstructuredGrid, point_id: int
+) -> typing.List[float]:
+    """
+    Returns the coordinates of the cell with index `point_id`.
+    """
+    return list(grid.GetPoint(point_id))
+
+
+def cell_connectivity(grid: vtk.vtkUnstructuredGrid, cell_id: int) -> typing.List[int]:
+    """
+    Returns the connectivity for the cell with index `cell_id`.
+    """
+    ids = vtk.vtkIdList()
+    grid.GetCells().GetCellAtId(cell_id, ids)
+    return list(map(ids.GetId, range(ids.GetNumberOfIds())))
+
+
+def all_point_coordinates(
+    grid: vtk.vtkUnstructuredGrid,
+) -> typing.List[typing.List[float]]:
+    """
+    Returns a list of point coordinates.
+    """
+    return list(map(partial(point_coordinates, grid), range(grid.GetNumberOfPoints())))
+
+
+def all_cell_coordinates(
+    grid: vtk.vtkUnstructuredGrid,
+) -> typing.List[typing.List[typing.List[float]]]:
+    """
+    Returns a list of point coordinates for every element.
+    """
+    connectivities = (
+        cell_connectivity(grid, i) for i in range(grid.GetCells().GetNumberOfCells())
+    )
+
+    def id_to_coordinates(i: int) -> typing.List[float]:
+        return point_coordinates(grid, i)
+
+    def connectivity_to_list_of_coordinates(
+        connectivity: typing.List[int],
+    ) -> typing.List[typing.List[float]]:
+        return list(map(id_to_coordinates, connectivity))
+
+    return list(map(connectivity_to_list_of_coordinates, connectivities))
+
+
+def all_point_data(
+    grid: vtk.vtkUnstructuredGrid,
+) -> typing.List[typing.Dict[str, typing.Any]]:
+    """
+    Returns a list of point data.
+    """
+    data = grid.GetPointData()
+    return [
+        {
+            "name": data.GetArrayName(i),
+            "array": [
+                list(data.GetArray(i).GetTuple(j))
+                for j in range(data.GetArray(i).GetNumberOfTuples())
+            ],
+        }
+        for i in range(data.GetNumberOfArrays())
+    ]
+
+
+def all_cell_data(
+    grid: vtk.vtkUnstructuredGrid,
+) -> typing.List[typing.Dict[str, typing.Any]]:
+    """
+    Returns a list of cell data.
+    """
+    data = grid.GetCellData()
+    return [
+        {
+            "name": data.GetArrayName(i),
+            "array": [
+                list(data.GetArray(i).GetTuple(j))
+                for j in range(data.GetArray(i).GetNumberOfTuples())
+            ],
+        }
+        for i in range(data.GetNumberOfArrays())
+    ]
+
+
+class GridComparison(unittest.TestCase):
+    """
+    Compares two grids.
+    """
+
+    def test_same_number_of_points(self):
+        """
+        The grids have the same number of points.
+        """
+        self.assertEqual(
+            RESULT_GRID.GetNumberOfPoints(), REFERENCE_GRID.GetNumberOfPoints()
+        )
+
+    def test_same_number_of_cells(self):
+        """
+        The grids have the same number of cells.
+        """
+        self.assertEqual(
+            RESULT_GRID.GetNumberOfCells(), REFERENCE_GRID.GetNumberOfCells()
+        )
+
+    def test_same_point_coordinates(self):
+        """
+        The points of the grids have the same coordinates.
+        """
+        for result_coordinates, reference_coordinates in zip(
+            sorted(all_point_coordinates(RESULT_GRID)),
+            sorted(all_point_coordinates(REFERENCE_GRID)),
+        ):
+            for result_x, reference_x in zip(result_coordinates, reference_coordinates):
+                self.assertAlmostEqual(result_x, reference_x)
+
+    def test_same_cell_coordinates(self):
+        """
+        The cells of the grids have the vertices at the same coordinates.
+        """
+        for result_coordinates, reference_coordinates in zip(
+            sorted(all_cell_coordinates(RESULT_GRID)),
+            sorted(all_cell_coordinates(REFERENCE_GRID)),
+        ):
+            for result_x, reference_x in zip(result_coordinates, reference_coordinates):
+                self.assertAlmostEqual(result_x, reference_x)
+
+    def test_same_number_of_point_data(self):
+        """
+        The points of the grid have the same number of data.
+        """
+        self.assertEqual(
+            len(all_point_data(RESULT_GRID)), len(all_point_data(REFERENCE_GRID))
+        )
+
+    def test_same_point_data_arrays(self):
+        """
+        The points of the grid have the same point data at the same coordinates.
+        """
+
+        for result, reference in zip(
+            sorted(all_point_data(RESULT_GRID), key=lambda x: x["name"]),
+            sorted(all_point_data(REFERENCE_GRID), key=lambda x: x["name"]),
+        ):
+            self.assertEqual(result["name"], reference["name"])
+            result_array = sorted(
+                list(zip(all_point_coordinates(RESULT_GRID), result["array"]))
+            )
+            reference_array = sorted(
+                list(zip(all_point_coordinates(REFERENCE_GRID), reference["array"]))
+            )
+
+            for result_item, reference_item in zip(result_array, reference_array):
+                for result_x, reference_x in zip(result_item[1], reference_item[1]):
+                    self.assertAlmostEqual(result_x, reference_x)
+
+    def test_same_number_of_cell_data(self):
+        """
+        The cells of the grid have the same number of data.
+        """
+        self.assertEqual(
+            len(all_cell_data(RESULT_GRID)), len(all_cell_data(REFERENCE_GRID))
+        )
+
+    def test_same_cell_data_arrays(self):
+        """
+        The cells of the grid have the same cell data.
+        """
+
+        for result, reference in zip(
+            sorted(all_cell_data(RESULT_GRID), key=lambda x: x["name"]),
+            sorted(all_cell_data(REFERENCE_GRID), key=lambda x: x["name"]),
+        ):
+            self.assertEqual(result["name"], reference["name"])
+            result_array = sorted(
+                list(zip(all_cell_coordinates(RESULT_GRID), result["array"]))
+            )
+            reference_array = sorted(
+                list(zip(all_cell_coordinates(REFERENCE_GRID), reference["array"]))
+            )
+
+            for result_item, reference_item in zip(result_array, reference_array):
+                for result_x, reference_x in zip(result_item[1], reference_item[1]):
+                    self.assertAlmostEqual(result_x, reference_x)
+
+
+def read_xdmf(path: pathlib.Path) -> vtk.vtkUnstructuredGrid:
+    """
+    Reads the XDMF file at path and returns the result.
+    """
+    reader = vtk.vtkXdmfReader()
+    reader.SetFileName(str(path))
+    reader.Update()
+    return reader.GetOutput()
+
+
+def main():
+    """
+    Reads the file names from the command options and runs the tests.
+    """
+    parser = argparse.ArgumentParser(
+        description="Compares the contents of two XDMF files."
+    )
+    parser.add_argument("result", type=pathlib.Path, help="the result *.xdmf file")
+    parser.add_argument(
+        "reference", type=pathlib.Path, help="the reference *.xdmf file"
+    )
+    arguments, other = parser.parse_known_args()
+
+    global RESULT_GRID  # pylint: disable=W0603
+    RESULT_GRID = read_xdmf(arguments.result)
+
+    global REFERENCE_GRID  # pylint: disable=W0603
+    REFERENCE_GRID = read_xdmf(arguments.reference)
+
+    return unittest.main(argv=sys.argv[:1] + other)
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab