From 261cc105bb14ff02cc518dc4c7cb88bf1ba4b86b Mon Sep 17 00:00:00 2001 From: Manuel Weberndorfer <manuel.weberndorfer@id.ethz.ch> Date: Wed, 25 Nov 2020 18:25:56 +0000 Subject: [PATCH] add setCoordinates to Mesh to avoid copy --- cpppetsc/src/include/ae108/cpppetsc/Mesh.h | 29 +++++ .../include/ae108/cpppetsc/writeToViewer.h | 112 +++++++----------- cpppetsc/src/writeToViewer.cc | 16 +-- cpppetsc/test/Mesh_Test.cc | 28 ++++- 4 files changed, 103 insertions(+), 82 deletions(-) diff --git a/cpppetsc/src/include/ae108/cpppetsc/Mesh.h b/cpppetsc/src/include/ae108/cpppetsc/Mesh.h index df317804..c177ecbc 100644 --- a/cpppetsc/src/include/ae108/cpppetsc/Mesh.h +++ b/cpppetsc/src/include/ae108/cpppetsc/Mesh.h @@ -15,6 +15,7 @@ #pragma once +#include "ae108/cpppetsc/InvalidParametersException.h" #include "ae108/cpppetsc/IteratorRange.h" #include "ae108/cpppetsc/LocalElementIterator.h" #include "ae108/cpppetsc/LocalElementView.h" @@ -83,6 +84,17 @@ public: Mesh cloneWithDofs(const size_type dofPerVertex, const size_type dofPerElement) const; + /** + * @brief Set the coordinates of the vertices of the mesh. + * + * @param coordinates A vector with dimension number of degrees of freedom per + * vertex. + * + * @throw InvalidParametersException The vector was not created from a mesh + * describing its layout. + */ + void setCoordinates(distributed<vector_type> coordinates); + /** * @brief Makes it possible to iterate over (views of) the local elements. */ @@ -483,6 +495,23 @@ template <class Policy> UniqueEntity<DM> Mesh<Policy>::createDM() { return makeUniqueEntity<Policy>(dm); } +template <class Policy> +void Mesh<Policy>::setCoordinates(distributed<vector_type> coordinates) { + const auto coordinateDM = [&]() { + auto dm = DM{}; + Policy::handleError(VecGetDM(coordinates.unwrap().data(), &dm)); + return dm; + }(); + + if (!coordinateDM) { + throw InvalidParametersException(); + } + + Policy::handleError(DMSetCoordinateDM(_mesh.get(), coordinateDM)); + Policy::handleError( + DMSetCoordinates(_mesh.get(), coordinates.unwrap().data())); +} + template <class Policy> typename Mesh<Policy>::size_type Mesh<Policy>::totalNumberOfVertices() const { return _totalNumberOfVertices; diff --git a/cpppetsc/src/include/ae108/cpppetsc/writeToViewer.h b/cpppetsc/src/include/ae108/cpppetsc/writeToViewer.h index 6ad9f0d1..10e4e82a 100644 --- a/cpppetsc/src/include/ae108/cpppetsc/writeToViewer.h +++ b/cpppetsc/src/include/ae108/cpppetsc/writeToViewer.h @@ -15,6 +15,7 @@ #pragma once +#include "ae108/cpppetsc/InvalidParametersException.h" #include "ae108/cpppetsc/Mesh_fwd.h" #include "ae108/cpppetsc/ParallelComputePolicy_fwd.h" #include "ae108/cpppetsc/SequentialComputePolicy_fwd.h" @@ -28,35 +29,32 @@ namespace cpppetsc { /** * @brief Writes a mesh to a Viewer. + * + * @throw InvalidParametersException if the coordinates of `mesh` are not set. */ template <class Policy> -void writeToViewer(const Mesh<Policy> &mesh, - const distributed<Vector<Policy>> &coordinates, - const Viewer<Policy> &viewer); +void writeToViewer(const Mesh<Policy> &mesh, const Viewer<Policy> &viewer); /** - * @brief Writes a vector to a Viewer. The concrete format depends on the file - * extension (vtk/vtu). + * @brief Writes a vector to a Viewer. + * + * @throw InvalidParametersException if the `data` vector has no associated mesh + * or if the coordinates of this mesh are not set. */ template <class Policy> void writeToViewer(const distributed<Vector<Policy>> &data, - const distributed<Vector<Policy>> &coordinates, const Viewer<Policy> &viewer); +extern template void +writeToViewer<SequentialComputePolicy>(const Mesh<SequentialComputePolicy> &, + const Viewer<SequentialComputePolicy> &); +extern template void +writeToViewer<ParallelComputePolicy>(const Mesh<ParallelComputePolicy> &, + const Viewer<ParallelComputePolicy> &); extern template void writeToViewer<SequentialComputePolicy>( - const Mesh<SequentialComputePolicy> &, - const distributed<Vector<SequentialComputePolicy>> &, - const Viewer<SequentialComputePolicy> &); -extern template void writeToViewer<ParallelComputePolicy>( - const Mesh<ParallelComputePolicy> &, - const distributed<Vector<ParallelComputePolicy>> &, - const Viewer<ParallelComputePolicy> &); -extern template void writeToViewer<SequentialComputePolicy>( - const distributed<Vector<SequentialComputePolicy>> &, const distributed<Vector<SequentialComputePolicy>> &, const Viewer<SequentialComputePolicy> &); extern template void writeToViewer<ParallelComputePolicy>( - const distributed<Vector<ParallelComputePolicy>> &, const distributed<Vector<ParallelComputePolicy>> &, const Viewer<ParallelComputePolicy> &); @@ -68,71 +66,47 @@ extern template void writeToViewer<ParallelComputePolicy>( namespace ae108 { namespace cpppetsc { +namespace detail { +template <class Policy> void throwIfNoCoordinateDM(const DM dm) { + auto coordinateDM = DM{}; + Policy::handleError(DMGetCoordinateDM(dm, &coordinateDM)); + if (!coordinateDM) { + throw InvalidParametersException(); + } +} -template <class Policy> -void writeToViewer(const Mesh<Policy> &mesh, - const distributed<Vector<Policy>> &coordinates, - const Viewer<Policy> &viewer) { - const auto coordinateDM = [&]() { - auto dm = DM{}; - Policy::handleError(VecGetDM(coordinates.unwrap().data(), &dm)); - assert(dm); - return UniqueEntity<DM>(dm, [](DM) {}); - }(); - - // copy the DM to be able to set the coordinate DM, which is - // required by DMView - auto dm = [&]() { - auto dm = DM{}; - Policy::handleError(DMClone(mesh.data(), &dm)); - return makeUniqueEntity<Policy>(dm); - }(); - Policy::handleError(DMSetCoordinateDM(dm.get(), coordinateDM.get())); - Policy::handleError(DMSetCoordinates(dm.get(), coordinates.unwrap().data())); +template <class Policy> void throwIfNoCoordinates(const DM dm) { + auto coordinates = Vec{}; + Policy::handleError(DMGetCoordinates(dm, &coordinates)); + if (!coordinates) { + throw InvalidParametersException(); + } +} +} // namespace detail - Policy::handleError(DMView(dm.get(), viewer.data())); +template <class Policy> +void writeToViewer(const Mesh<Policy> &mesh, const Viewer<Policy> &viewer) { + detail::throwIfNoCoordinateDM<Policy>(mesh.data()); + detail::throwIfNoCoordinates<Policy>(mesh.data()); + Policy::handleError(DMView(mesh.data(), viewer.data())); } template <class Policy> void writeToViewer(const distributed<Vector<Policy>> &data, - const distributed<Vector<Policy>> &coordinates, const Viewer<Policy> &viewer) { - const auto getDM = [](const Vector<Policy> &x) { - auto dm = DM{}; - Policy::handleError(VecGetDM(x.data(), &dm)); - assert(dm); - return UniqueEntity<DM>(dm, [](DM) {}); - }; - - // clone DM that defines the layout of the data vector - // to be able to set the coordinate DM const auto dm = [&]() { auto dm = DM{}; - const auto reference = getDM(data).get(); - Policy::handleError(DMClone(reference, &dm)); - return makeUniqueEntity<Policy>(dm); + Policy::handleError(VecGetDM(data.unwrap().data(), &dm)); + return dm; }(); - // configure the default section to be able to create - // a copy of the data vector - auto section = PetscSection{}; - Policy::handleError(DMGetDefaultSection(getDM(data).get(), §ion)); - Policy::handleError(DMSetDefaultSection(dm.get(), section)); - - // set the coordinate DM which is required by VecView - const auto coordinateDM = getDM(coordinates); - Policy::handleError(DMSetCoordinateDM(dm.get(), coordinateDM.get())); - Policy::handleError(DMSetCoordinates(dm.get(), coordinates.unwrap().data())); - - // copy the data to the vector with the correct DM - const auto vec = [&]() { - auto vec = Vec{}; - Policy::handleError(DMCreateGlobalVector(dm.get(), &vec)); - return makeUniqueEntity<Policy>(vec); - }(); - Policy::handleError(VecCopy(data.unwrap().data(), vec.get())); + if (!dm) { + throw InvalidParametersException(); + } + detail::throwIfNoCoordinateDM<Policy>(dm); + detail::throwIfNoCoordinates<Policy>(dm); - Policy::handleError(VecView(vec.get(), viewer.data())); + Policy::handleError(VecView(data.unwrap().data(), viewer.data())); } } // namespace cpppetsc diff --git a/cpppetsc/src/writeToViewer.cc b/cpppetsc/src/writeToViewer.cc index 4ba82311..54df94b4 100644 --- a/cpppetsc/src/writeToViewer.cc +++ b/cpppetsc/src/writeToViewer.cc @@ -20,20 +20,16 @@ namespace ae108 { namespace cpppetsc { +template void +writeToViewer<SequentialComputePolicy>(const Mesh<SequentialComputePolicy> &, + const Viewer<SequentialComputePolicy> &); +template void +writeToViewer<ParallelComputePolicy>(const Mesh<ParallelComputePolicy> &, + const Viewer<ParallelComputePolicy> &); template void writeToViewer<SequentialComputePolicy>( - const Mesh<SequentialComputePolicy> &, const distributed<Vector<SequentialComputePolicy>> &, const Viewer<SequentialComputePolicy> &); template void writeToViewer<ParallelComputePolicy>( - const Mesh<ParallelComputePolicy> &, - const distributed<Vector<ParallelComputePolicy>> &, - const Viewer<ParallelComputePolicy> &); -template void writeToViewer<SequentialComputePolicy>( - const distributed<Vector<SequentialComputePolicy>> &, - const distributed<Vector<SequentialComputePolicy>> &, - const Viewer<SequentialComputePolicy> &); -template void writeToViewer<ParallelComputePolicy>( - const distributed<Vector<ParallelComputePolicy>> &, const distributed<Vector<ParallelComputePolicy>> &, const Viewer<ParallelComputePolicy> &); diff --git a/cpppetsc/test/Mesh_Test.cc b/cpppetsc/test/Mesh_Test.cc index 06d4d646..b49925b3 100644 --- a/cpppetsc/test/Mesh_Test.cc +++ b/cpppetsc/test/Mesh_Test.cc @@ -113,9 +113,9 @@ template <typename T> struct Mesh_Test : Test { using Connectivity = std::vector<std::array<size_type, verticesPerElement>>; const Connectivity connectivity = {{0, 2}, {1, 2}, {0, 2}}; - const mesh_type mesh = mesh_type::fromConnectivity( - dimension, connectivity, totalNumberOfVertices, dofPerVertex, - dofPerElement); + mesh_type mesh = mesh_type::fromConnectivity(dimension, connectivity, + totalNumberOfVertices, + dofPerVertex, dofPerElement); }; template <class Policy, int VertexDof, int ElementDof> struct TestCase { @@ -139,6 +139,28 @@ TYPED_TEST(Mesh_Test, from_connectivity_generates_correct_number_of_entities) { Eq(this->totalNumberOfVertices)); } +TYPED_TEST(Mesh_Test, coordinates_can_be_set) { + using size_type = typename TestFixture::size_type; + using vector_type = typename TestFixture::vector_type; + + const auto coordinateMesh = + this->mesh.cloneWithDofs(this->dimension, size_type{0}); + auto coordinates = vector_type::fromGlobalMesh(coordinateMesh); + + EXPECT_NO_THROW(this->mesh.setCoordinates(std::move(coordinates))); +} + +TYPED_TEST(Mesh_Test, setting_coordinates_throws_if_not_associated_with_mesh) { + using size_type = typename TestFixture::size_type; + using vector_type = typename TestFixture::vector_type; + + auto coordinates = tag<DistributedTag>( + vector_type(this->totalNumberOfVertices * this->dofPerVertex)); + + EXPECT_THROW(this->mesh.setCoordinates(std::move(coordinates)), + InvalidParametersException); +} + TYPED_TEST(Mesh_Test, cloned_mesh_has_correct_number_of_entities) { using size_type = typename TestFixture::size_type; -- GitLab