From 261cc105bb14ff02cc518dc4c7cb88bf1ba4b86b Mon Sep 17 00:00:00 2001
From: Manuel Weberndorfer <manuel.weberndorfer@id.ethz.ch>
Date: Wed, 25 Nov 2020 18:25:56 +0000
Subject: [PATCH] add setCoordinates to Mesh to avoid copy

---
 cpppetsc/src/include/ae108/cpppetsc/Mesh.h    |  29 +++++
 .../include/ae108/cpppetsc/writeToViewer.h    | 112 +++++++-----------
 cpppetsc/src/writeToViewer.cc                 |  16 +--
 cpppetsc/test/Mesh_Test.cc                    |  28 ++++-
 4 files changed, 103 insertions(+), 82 deletions(-)

diff --git a/cpppetsc/src/include/ae108/cpppetsc/Mesh.h b/cpppetsc/src/include/ae108/cpppetsc/Mesh.h
index df317804..c177ecbc 100644
--- a/cpppetsc/src/include/ae108/cpppetsc/Mesh.h
+++ b/cpppetsc/src/include/ae108/cpppetsc/Mesh.h
@@ -15,6 +15,7 @@
 
 #pragma once
 
+#include "ae108/cpppetsc/InvalidParametersException.h"
 #include "ae108/cpppetsc/IteratorRange.h"
 #include "ae108/cpppetsc/LocalElementIterator.h"
 #include "ae108/cpppetsc/LocalElementView.h"
@@ -83,6 +84,17 @@ public:
   Mesh cloneWithDofs(const size_type dofPerVertex,
                      const size_type dofPerElement) const;
 
+  /**
+   * @brief Set the coordinates of the vertices of the mesh.
+   *
+   * @param coordinates A vector with dimension number of degrees of freedom per
+   * vertex.
+   *
+   * @throw InvalidParametersException The vector was not created from a mesh
+   * describing its layout.
+   */
+  void setCoordinates(distributed<vector_type> coordinates);
+
   /**
    * @brief Makes it possible to iterate over (views of) the local elements.
    */
@@ -483,6 +495,23 @@ template <class Policy> UniqueEntity<DM> Mesh<Policy>::createDM() {
   return makeUniqueEntity<Policy>(dm);
 }
 
+template <class Policy>
+void Mesh<Policy>::setCoordinates(distributed<vector_type> coordinates) {
+  const auto coordinateDM = [&]() {
+    auto dm = DM{};
+    Policy::handleError(VecGetDM(coordinates.unwrap().data(), &dm));
+    return dm;
+  }();
+
+  if (!coordinateDM) {
+    throw InvalidParametersException();
+  }
+
+  Policy::handleError(DMSetCoordinateDM(_mesh.get(), coordinateDM));
+  Policy::handleError(
+      DMSetCoordinates(_mesh.get(), coordinates.unwrap().data()));
+}
+
 template <class Policy>
 typename Mesh<Policy>::size_type Mesh<Policy>::totalNumberOfVertices() const {
   return _totalNumberOfVertices;
diff --git a/cpppetsc/src/include/ae108/cpppetsc/writeToViewer.h b/cpppetsc/src/include/ae108/cpppetsc/writeToViewer.h
index 6ad9f0d1..10e4e82a 100644
--- a/cpppetsc/src/include/ae108/cpppetsc/writeToViewer.h
+++ b/cpppetsc/src/include/ae108/cpppetsc/writeToViewer.h
@@ -15,6 +15,7 @@
 
 #pragma once
 
+#include "ae108/cpppetsc/InvalidParametersException.h"
 #include "ae108/cpppetsc/Mesh_fwd.h"
 #include "ae108/cpppetsc/ParallelComputePolicy_fwd.h"
 #include "ae108/cpppetsc/SequentialComputePolicy_fwd.h"
@@ -28,35 +29,32 @@ namespace cpppetsc {
 
 /**
  * @brief Writes a mesh to a Viewer.
+ *
+ * @throw InvalidParametersException if the coordinates of `mesh` are not set.
  */
 template <class Policy>
-void writeToViewer(const Mesh<Policy> &mesh,
-                   const distributed<Vector<Policy>> &coordinates,
-                   const Viewer<Policy> &viewer);
+void writeToViewer(const Mesh<Policy> &mesh, const Viewer<Policy> &viewer);
 
 /**
- * @brief Writes a vector to a Viewer. The concrete format depends on the file
- * extension (vtk/vtu).
+ * @brief Writes a vector to a Viewer.
+ *
+ * @throw InvalidParametersException if the `data` vector has no associated mesh
+ * or if the coordinates of this mesh are not set.
  */
 template <class Policy>
 void writeToViewer(const distributed<Vector<Policy>> &data,
-                   const distributed<Vector<Policy>> &coordinates,
                    const Viewer<Policy> &viewer);
 
+extern template void
+writeToViewer<SequentialComputePolicy>(const Mesh<SequentialComputePolicy> &,
+                                       const Viewer<SequentialComputePolicy> &);
+extern template void
+writeToViewer<ParallelComputePolicy>(const Mesh<ParallelComputePolicy> &,
+                                     const Viewer<ParallelComputePolicy> &);
 extern template void writeToViewer<SequentialComputePolicy>(
-    const Mesh<SequentialComputePolicy> &,
-    const distributed<Vector<SequentialComputePolicy>> &,
-    const Viewer<SequentialComputePolicy> &);
-extern template void writeToViewer<ParallelComputePolicy>(
-    const Mesh<ParallelComputePolicy> &,
-    const distributed<Vector<ParallelComputePolicy>> &,
-    const Viewer<ParallelComputePolicy> &);
-extern template void writeToViewer<SequentialComputePolicy>(
-    const distributed<Vector<SequentialComputePolicy>> &,
     const distributed<Vector<SequentialComputePolicy>> &,
     const Viewer<SequentialComputePolicy> &);
 extern template void writeToViewer<ParallelComputePolicy>(
-    const distributed<Vector<ParallelComputePolicy>> &,
     const distributed<Vector<ParallelComputePolicy>> &,
     const Viewer<ParallelComputePolicy> &);
 
@@ -68,71 +66,47 @@ extern template void writeToViewer<ParallelComputePolicy>(
 
 namespace ae108 {
 namespace cpppetsc {
+namespace detail {
+template <class Policy> void throwIfNoCoordinateDM(const DM dm) {
+  auto coordinateDM = DM{};
+  Policy::handleError(DMGetCoordinateDM(dm, &coordinateDM));
+  if (!coordinateDM) {
+    throw InvalidParametersException();
+  }
+}
 
-template <class Policy>
-void writeToViewer(const Mesh<Policy> &mesh,
-                   const distributed<Vector<Policy>> &coordinates,
-                   const Viewer<Policy> &viewer) {
-  const auto coordinateDM = [&]() {
-    auto dm = DM{};
-    Policy::handleError(VecGetDM(coordinates.unwrap().data(), &dm));
-    assert(dm);
-    return UniqueEntity<DM>(dm, [](DM) {});
-  }();
-
-  // copy the DM to be able to set the coordinate DM, which is
-  // required by DMView
-  auto dm = [&]() {
-    auto dm = DM{};
-    Policy::handleError(DMClone(mesh.data(), &dm));
-    return makeUniqueEntity<Policy>(dm);
-  }();
-  Policy::handleError(DMSetCoordinateDM(dm.get(), coordinateDM.get()));
-  Policy::handleError(DMSetCoordinates(dm.get(), coordinates.unwrap().data()));
+template <class Policy> void throwIfNoCoordinates(const DM dm) {
+  auto coordinates = Vec{};
+  Policy::handleError(DMGetCoordinates(dm, &coordinates));
+  if (!coordinates) {
+    throw InvalidParametersException();
+  }
+}
+} // namespace detail
 
-  Policy::handleError(DMView(dm.get(), viewer.data()));
+template <class Policy>
+void writeToViewer(const Mesh<Policy> &mesh, const Viewer<Policy> &viewer) {
+  detail::throwIfNoCoordinateDM<Policy>(mesh.data());
+  detail::throwIfNoCoordinates<Policy>(mesh.data());
+  Policy::handleError(DMView(mesh.data(), viewer.data()));
 }
 
 template <class Policy>
 void writeToViewer(const distributed<Vector<Policy>> &data,
-                   const distributed<Vector<Policy>> &coordinates,
                    const Viewer<Policy> &viewer) {
-  const auto getDM = [](const Vector<Policy> &x) {
-    auto dm = DM{};
-    Policy::handleError(VecGetDM(x.data(), &dm));
-    assert(dm);
-    return UniqueEntity<DM>(dm, [](DM) {});
-  };
-
-  // clone DM that defines the layout of the data vector
-  // to be able to set the coordinate DM
   const auto dm = [&]() {
     auto dm = DM{};
-    const auto reference = getDM(data).get();
-    Policy::handleError(DMClone(reference, &dm));
-    return makeUniqueEntity<Policy>(dm);
+    Policy::handleError(VecGetDM(data.unwrap().data(), &dm));
+    return dm;
   }();
 
-  // configure the default section to be able to create
-  // a copy of the data vector
-  auto section = PetscSection{};
-  Policy::handleError(DMGetDefaultSection(getDM(data).get(), &section));
-  Policy::handleError(DMSetDefaultSection(dm.get(), section));
-
-  // set the coordinate DM which is required by VecView
-  const auto coordinateDM = getDM(coordinates);
-  Policy::handleError(DMSetCoordinateDM(dm.get(), coordinateDM.get()));
-  Policy::handleError(DMSetCoordinates(dm.get(), coordinates.unwrap().data()));
-
-  // copy the data to the vector with the correct DM
-  const auto vec = [&]() {
-    auto vec = Vec{};
-    Policy::handleError(DMCreateGlobalVector(dm.get(), &vec));
-    return makeUniqueEntity<Policy>(vec);
-  }();
-  Policy::handleError(VecCopy(data.unwrap().data(), vec.get()));
+  if (!dm) {
+    throw InvalidParametersException();
+  }
+  detail::throwIfNoCoordinateDM<Policy>(dm);
+  detail::throwIfNoCoordinates<Policy>(dm);
 
-  Policy::handleError(VecView(vec.get(), viewer.data()));
+  Policy::handleError(VecView(data.unwrap().data(), viewer.data()));
 }
 
 } // namespace cpppetsc
diff --git a/cpppetsc/src/writeToViewer.cc b/cpppetsc/src/writeToViewer.cc
index 4ba82311..54df94b4 100644
--- a/cpppetsc/src/writeToViewer.cc
+++ b/cpppetsc/src/writeToViewer.cc
@@ -20,20 +20,16 @@
 namespace ae108 {
 namespace cpppetsc {
 
+template void
+writeToViewer<SequentialComputePolicy>(const Mesh<SequentialComputePolicy> &,
+                                       const Viewer<SequentialComputePolicy> &);
+template void
+writeToViewer<ParallelComputePolicy>(const Mesh<ParallelComputePolicy> &,
+                                     const Viewer<ParallelComputePolicy> &);
 template void writeToViewer<SequentialComputePolicy>(
-    const Mesh<SequentialComputePolicy> &,
     const distributed<Vector<SequentialComputePolicy>> &,
     const Viewer<SequentialComputePolicy> &);
 template void writeToViewer<ParallelComputePolicy>(
-    const Mesh<ParallelComputePolicy> &,
-    const distributed<Vector<ParallelComputePolicy>> &,
-    const Viewer<ParallelComputePolicy> &);
-template void writeToViewer<SequentialComputePolicy>(
-    const distributed<Vector<SequentialComputePolicy>> &,
-    const distributed<Vector<SequentialComputePolicy>> &,
-    const Viewer<SequentialComputePolicy> &);
-template void writeToViewer<ParallelComputePolicy>(
-    const distributed<Vector<ParallelComputePolicy>> &,
     const distributed<Vector<ParallelComputePolicy>> &,
     const Viewer<ParallelComputePolicy> &);
 
diff --git a/cpppetsc/test/Mesh_Test.cc b/cpppetsc/test/Mesh_Test.cc
index 06d4d646..b49925b3 100644
--- a/cpppetsc/test/Mesh_Test.cc
+++ b/cpppetsc/test/Mesh_Test.cc
@@ -113,9 +113,9 @@ template <typename T> struct Mesh_Test : Test {
   using Connectivity = std::vector<std::array<size_type, verticesPerElement>>;
   const Connectivity connectivity = {{0, 2}, {1, 2}, {0, 2}};
 
-  const mesh_type mesh = mesh_type::fromConnectivity(
-      dimension, connectivity, totalNumberOfVertices, dofPerVertex,
-      dofPerElement);
+  mesh_type mesh = mesh_type::fromConnectivity(dimension, connectivity,
+                                               totalNumberOfVertices,
+                                               dofPerVertex, dofPerElement);
 };
 
 template <class Policy, int VertexDof, int ElementDof> struct TestCase {
@@ -139,6 +139,28 @@ TYPED_TEST(Mesh_Test, from_connectivity_generates_correct_number_of_entities) {
               Eq(this->totalNumberOfVertices));
 }
 
+TYPED_TEST(Mesh_Test, coordinates_can_be_set) {
+  using size_type = typename TestFixture::size_type;
+  using vector_type = typename TestFixture::vector_type;
+
+  const auto coordinateMesh =
+      this->mesh.cloneWithDofs(this->dimension, size_type{0});
+  auto coordinates = vector_type::fromGlobalMesh(coordinateMesh);
+
+  EXPECT_NO_THROW(this->mesh.setCoordinates(std::move(coordinates)));
+}
+
+TYPED_TEST(Mesh_Test, setting_coordinates_throws_if_not_associated_with_mesh) {
+  using size_type = typename TestFixture::size_type;
+  using vector_type = typename TestFixture::vector_type;
+
+  auto coordinates = tag<DistributedTag>(
+      vector_type(this->totalNumberOfVertices * this->dofPerVertex));
+
+  EXPECT_THROW(this->mesh.setCoordinates(std::move(coordinates)),
+               InvalidParametersException);
+}
+
 TYPED_TEST(Mesh_Test, cloned_mesh_has_correct_number_of_entities) {
   using size_type = typename TestFixture::size_type;
 
-- 
GitLab