Skip to content
Snippets Groups Projects
Commit 6bae6d04 authored by webmanue's avatar webmanue
Browse files

use proximity to infer permutation

- sorting may not work as expected
- e.g. for small changes in first coordinate
parent fa1a6fc4
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,7 @@ RUN apt-get update && \
ninja-build \
python3 python3-pip \
python3-h5py \
python3-scipy \
gdb
RUN pip3 install \
......
......@@ -25,6 +25,7 @@ import argparse
import sys
from functools import partial
import vtk
import scipy.spatial
RESULT_GRID = vtk.vtkUnstructuredGrid()
REFERENCE_GRID = vtk.vtkUnstructuredGrid()
......@@ -39,6 +40,26 @@ def point_coordinates(
return list(grid.GetPoint(point_id))
def all_point_coordinates(
grid: vtk.vtkUnstructuredGrid,
) -> typing.List[typing.List[float]]:
"""
Returns a list of point coordinates.
"""
return list(map(partial(point_coordinates, grid), range(grid.GetNumberOfPoints())))
def point_index_permutation() -> typing.List[int]:
"""
Infers the result index of a point based on its coordinates.
"""
result = all_point_coordinates(RESULT_GRID)
reference = all_point_coordinates(REFERENCE_GRID)
tree = scipy.spatial.KDTree(result)
return list(tree.query(reference)[1])
def cell_connectivity(grid: vtk.vtkUnstructuredGrid, cell_id: int) -> typing.List[int]:
"""
Returns the connectivity for the cell with index `cell_id`.
......@@ -48,34 +69,48 @@ def cell_connectivity(grid: vtk.vtkUnstructuredGrid, cell_id: int) -> typing.Lis
return list(map(ids.GetId, range(ids.GetNumberOfIds())))
def all_point_coordinates(
def all_cell_connectivities(
grid: vtk.vtkUnstructuredGrid,
) -> typing.List[typing.List[float]]:
) -> typing.List[typing.List[int]]:
"""
Returns a list of point coordinates.
Returns the connectivity for all cells in `grid`.
"""
return list(map(partial(point_coordinates, grid), range(grid.GetNumberOfPoints())))
return list(map(partial(cell_connectivity, grid), range(grid.GetNumberOfCells())))
def all_cell_coordinates(
grid: vtk.vtkUnstructuredGrid,
) -> typing.List[typing.List[typing.List[float]]]:
def cell_index_permutation() -> typing.List[int]:
"""
Returns a list of point coordinates for every element.
Infers the result index of a cell based on its connectivity.
"""
connectivities = (
cell_connectivity(grid, i) for i in range(grid.GetCells().GetNumberOfCells())
point_permutation = point_index_permutation()
reference = map(
lambda c: tuple(map(point_permutation.__getitem__, c)),
all_cell_connectivities(REFERENCE_GRID),
)
result = {tuple(t): i for i, t in enumerate(all_cell_connectivities(RESULT_GRID))}
return list(map(result.__getitem__, reference))
T = typing.TypeVar("T")
def id_to_coordinates(i: int) -> typing.List[float]:
return point_coordinates(grid, i)
def connectivity_to_list_of_coordinates(
connectivity: typing.List[int],
) -> typing.List[typing.List[float]]:
return list(map(id_to_coordinates, connectivity))
def permute_point_data(data: typing.List[T]) -> typing.Iterable[T]:
"""
Permutes the items of `data` into the reference order inferred from the coordinates.
"""
permutation = point_index_permutation()
for i in permutation:
yield data[i]
return list(map(connectivity_to_list_of_coordinates, connectivities))
def permute_cell_data(data: typing.List[T]) -> typing.Iterable[T]:
"""
Permutes the items of `data` into the reference order inferred from the connectivities.
"""
permutation = cell_index_permutation()
for i in permutation:
yield data[i]
def all_point_data(
......@@ -150,22 +185,25 @@ class GridComparison(unittest.TestCase):
"""
The points of the grids have the same coordinates.
"""
for result_coordinates, reference_coordinates in zip(
sorted(all_point_coordinates(RESULT_GRID)),
sorted(all_point_coordinates(REFERENCE_GRID)),
for result, reference in zip(
permute_point_data(all_point_coordinates(RESULT_GRID)),
all_point_coordinates(REFERENCE_GRID),
):
for result_x, reference_x in zip(result_coordinates, reference_coordinates):
for result_x, reference_x in zip(result, reference):
self.assertAlmostEqual(result_x, reference_x)
def test_same_cell_coordinates(self):
def test_same_cell_connectivities(self):
"""
The cells of the grids have the vertices at the same coordinates.
The cells of the grids have the same vertex indices (up to permutation).
"""
for result_coordinates, reference_coordinates in zip(
sorted(all_cell_coordinates(RESULT_GRID)),
sorted(all_cell_coordinates(REFERENCE_GRID)),
permutation = point_index_permutation()
for result, reference in zip(
permute_cell_data(all_cell_connectivities(RESULT_GRID)),
all_cell_connectivities(REFERENCE_GRID),
):
for result_x, reference_x in zip(result_coordinates, reference_coordinates):
for result_x, reference_x in zip(
result, map(permutation.__getitem__, reference)
):
self.assertAlmostEqual(result_x, reference_x)
def test_same_cell_types(self):
......@@ -173,21 +211,10 @@ class GridComparison(unittest.TestCase):
The cells of the grid have the same types.
"""
for result, reference in zip(
sorted(
list(
zip(all_cell_coordinates(RESULT_GRID), all_cell_types(RESULT_GRID))
)
),
sorted(
list(
zip(
all_cell_coordinates(REFERENCE_GRID),
all_cell_types(REFERENCE_GRID),
)
)
),
permute_cell_data(all_cell_types(RESULT_GRID)),
all_cell_types(REFERENCE_GRID),
):
self.assertEqual(result[1], reference[1])
self.assertEqual(result, reference)
def test_same_number_of_point_data(self):
"""
......@@ -207,15 +234,10 @@ class GridComparison(unittest.TestCase):
sorted(all_point_data(REFERENCE_GRID), key=lambda x: x["name"]),
):
self.assertEqual(result["name"], reference["name"])
result_array = sorted(
list(zip(all_point_coordinates(RESULT_GRID), result["array"]))
)
reference_array = sorted(
list(zip(all_point_coordinates(REFERENCE_GRID), reference["array"]))
)
for result_item, reference_item in zip(result_array, reference_array):
for result_x, reference_x in zip(result_item[1], reference_item[1]):
for result_item, reference_item in zip(
permute_point_data(result["array"]), reference["array"]
):
for result_x, reference_x in zip(result_item, reference_item):
self.assertAlmostEqual(result_x, reference_x)
def test_same_number_of_cell_data(self):
......@@ -236,15 +258,10 @@ class GridComparison(unittest.TestCase):
sorted(all_cell_data(REFERENCE_GRID), key=lambda x: x["name"]),
):
self.assertEqual(result["name"], reference["name"])
result_array = sorted(
list(zip(all_cell_coordinates(RESULT_GRID), result["array"]))
)
reference_array = sorted(
list(zip(all_cell_coordinates(REFERENCE_GRID), reference["array"]))
)
for result_item, reference_item in zip(result_array, reference_array):
for result_x, reference_x in zip(result_item[1], reference_item[1]):
for result_item, reference_item in zip(
permute_cell_data(result["array"]), reference["array"]
):
for result_x, reference_x in zip(result_item, reference_item):
self.assertAlmostEqual(result_x, reference_x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment