Skip to content
Snippets Groups Projects
Commit 261cc105 authored by webmanue's avatar webmanue
Browse files

add setCoordinates to Mesh to avoid copy

parent 39f01861
No related branches found
No related tags found
No related merge requests found
Pipeline #79460 passed
......@@ -15,6 +15,7 @@
#pragma once
#include "ae108/cpppetsc/InvalidParametersException.h"
#include "ae108/cpppetsc/IteratorRange.h"
#include "ae108/cpppetsc/LocalElementIterator.h"
#include "ae108/cpppetsc/LocalElementView.h"
......@@ -83,6 +84,17 @@ public:
Mesh cloneWithDofs(const size_type dofPerVertex,
const size_type dofPerElement) const;
/**
* @brief Set the coordinates of the vertices of the mesh.
*
* @param coordinates A vector with dimension number of degrees of freedom per
* vertex.
*
* @throw InvalidParametersException The vector was not created from a mesh
* describing its layout.
*/
void setCoordinates(distributed<vector_type> coordinates);
/**
* @brief Makes it possible to iterate over (views of) the local elements.
*/
......@@ -483,6 +495,23 @@ template <class Policy> UniqueEntity<DM> Mesh<Policy>::createDM() {
return makeUniqueEntity<Policy>(dm);
}
template <class Policy>
void Mesh<Policy>::setCoordinates(distributed<vector_type> coordinates) {
const auto coordinateDM = [&]() {
auto dm = DM{};
Policy::handleError(VecGetDM(coordinates.unwrap().data(), &dm));
return dm;
}();
if (!coordinateDM) {
throw InvalidParametersException();
}
Policy::handleError(DMSetCoordinateDM(_mesh.get(), coordinateDM));
Policy::handleError(
DMSetCoordinates(_mesh.get(), coordinates.unwrap().data()));
}
template <class Policy>
typename Mesh<Policy>::size_type Mesh<Policy>::totalNumberOfVertices() const {
return _totalNumberOfVertices;
......
......@@ -15,6 +15,7 @@
#pragma once
#include "ae108/cpppetsc/InvalidParametersException.h"
#include "ae108/cpppetsc/Mesh_fwd.h"
#include "ae108/cpppetsc/ParallelComputePolicy_fwd.h"
#include "ae108/cpppetsc/SequentialComputePolicy_fwd.h"
......@@ -28,35 +29,32 @@ namespace cpppetsc {
/**
* @brief Writes a mesh to a Viewer.
*
* @throw InvalidParametersException if the coordinates of `mesh` are not set.
*/
template <class Policy>
void writeToViewer(const Mesh<Policy> &mesh,
const distributed<Vector<Policy>> &coordinates,
const Viewer<Policy> &viewer);
void writeToViewer(const Mesh<Policy> &mesh, const Viewer<Policy> &viewer);
/**
* @brief Writes a vector to a Viewer. The concrete format depends on the file
* extension (vtk/vtu).
* @brief Writes a vector to a Viewer.
*
* @throw InvalidParametersException if the `data` vector has no associated mesh
* or if the coordinates of this mesh are not set.
*/
template <class Policy>
void writeToViewer(const distributed<Vector<Policy>> &data,
const distributed<Vector<Policy>> &coordinates,
const Viewer<Policy> &viewer);
extern template void
writeToViewer<SequentialComputePolicy>(const Mesh<SequentialComputePolicy> &,
const Viewer<SequentialComputePolicy> &);
extern template void
writeToViewer<ParallelComputePolicy>(const Mesh<ParallelComputePolicy> &,
const Viewer<ParallelComputePolicy> &);
extern template void writeToViewer<SequentialComputePolicy>(
const Mesh<SequentialComputePolicy> &,
const distributed<Vector<SequentialComputePolicy>> &,
const Viewer<SequentialComputePolicy> &);
extern template void writeToViewer<ParallelComputePolicy>(
const Mesh<ParallelComputePolicy> &,
const distributed<Vector<ParallelComputePolicy>> &,
const Viewer<ParallelComputePolicy> &);
extern template void writeToViewer<SequentialComputePolicy>(
const distributed<Vector<SequentialComputePolicy>> &,
const distributed<Vector<SequentialComputePolicy>> &,
const Viewer<SequentialComputePolicy> &);
extern template void writeToViewer<ParallelComputePolicy>(
const distributed<Vector<ParallelComputePolicy>> &,
const distributed<Vector<ParallelComputePolicy>> &,
const Viewer<ParallelComputePolicy> &);
......@@ -68,71 +66,47 @@ extern template void writeToViewer<ParallelComputePolicy>(
namespace ae108 {
namespace cpppetsc {
namespace detail {
template <class Policy> void throwIfNoCoordinateDM(const DM dm) {
auto coordinateDM = DM{};
Policy::handleError(DMGetCoordinateDM(dm, &coordinateDM));
if (!coordinateDM) {
throw InvalidParametersException();
}
}
template <class Policy>
void writeToViewer(const Mesh<Policy> &mesh,
const distributed<Vector<Policy>> &coordinates,
const Viewer<Policy> &viewer) {
const auto coordinateDM = [&]() {
auto dm = DM{};
Policy::handleError(VecGetDM(coordinates.unwrap().data(), &dm));
assert(dm);
return UniqueEntity<DM>(dm, [](DM) {});
}();
// copy the DM to be able to set the coordinate DM, which is
// required by DMView
auto dm = [&]() {
auto dm = DM{};
Policy::handleError(DMClone(mesh.data(), &dm));
return makeUniqueEntity<Policy>(dm);
}();
Policy::handleError(DMSetCoordinateDM(dm.get(), coordinateDM.get()));
Policy::handleError(DMSetCoordinates(dm.get(), coordinates.unwrap().data()));
template <class Policy> void throwIfNoCoordinates(const DM dm) {
auto coordinates = Vec{};
Policy::handleError(DMGetCoordinates(dm, &coordinates));
if (!coordinates) {
throw InvalidParametersException();
}
}
} // namespace detail
Policy::handleError(DMView(dm.get(), viewer.data()));
template <class Policy>
void writeToViewer(const Mesh<Policy> &mesh, const Viewer<Policy> &viewer) {
detail::throwIfNoCoordinateDM<Policy>(mesh.data());
detail::throwIfNoCoordinates<Policy>(mesh.data());
Policy::handleError(DMView(mesh.data(), viewer.data()));
}
template <class Policy>
void writeToViewer(const distributed<Vector<Policy>> &data,
const distributed<Vector<Policy>> &coordinates,
const Viewer<Policy> &viewer) {
const auto getDM = [](const Vector<Policy> &x) {
auto dm = DM{};
Policy::handleError(VecGetDM(x.data(), &dm));
assert(dm);
return UniqueEntity<DM>(dm, [](DM) {});
};
// clone DM that defines the layout of the data vector
// to be able to set the coordinate DM
const auto dm = [&]() {
auto dm = DM{};
const auto reference = getDM(data).get();
Policy::handleError(DMClone(reference, &dm));
return makeUniqueEntity<Policy>(dm);
Policy::handleError(VecGetDM(data.unwrap().data(), &dm));
return dm;
}();
// configure the default section to be able to create
// a copy of the data vector
auto section = PetscSection{};
Policy::handleError(DMGetDefaultSection(getDM(data).get(), &section));
Policy::handleError(DMSetDefaultSection(dm.get(), section));
// set the coordinate DM which is required by VecView
const auto coordinateDM = getDM(coordinates);
Policy::handleError(DMSetCoordinateDM(dm.get(), coordinateDM.get()));
Policy::handleError(DMSetCoordinates(dm.get(), coordinates.unwrap().data()));
// copy the data to the vector with the correct DM
const auto vec = [&]() {
auto vec = Vec{};
Policy::handleError(DMCreateGlobalVector(dm.get(), &vec));
return makeUniqueEntity<Policy>(vec);
}();
Policy::handleError(VecCopy(data.unwrap().data(), vec.get()));
if (!dm) {
throw InvalidParametersException();
}
detail::throwIfNoCoordinateDM<Policy>(dm);
detail::throwIfNoCoordinates<Policy>(dm);
Policy::handleError(VecView(vec.get(), viewer.data()));
Policy::handleError(VecView(data.unwrap().data(), viewer.data()));
}
} // namespace cpppetsc
......
......@@ -20,20 +20,16 @@
namespace ae108 {
namespace cpppetsc {
template void
writeToViewer<SequentialComputePolicy>(const Mesh<SequentialComputePolicy> &,
const Viewer<SequentialComputePolicy> &);
template void
writeToViewer<ParallelComputePolicy>(const Mesh<ParallelComputePolicy> &,
const Viewer<ParallelComputePolicy> &);
template void writeToViewer<SequentialComputePolicy>(
const Mesh<SequentialComputePolicy> &,
const distributed<Vector<SequentialComputePolicy>> &,
const Viewer<SequentialComputePolicy> &);
template void writeToViewer<ParallelComputePolicy>(
const Mesh<ParallelComputePolicy> &,
const distributed<Vector<ParallelComputePolicy>> &,
const Viewer<ParallelComputePolicy> &);
template void writeToViewer<SequentialComputePolicy>(
const distributed<Vector<SequentialComputePolicy>> &,
const distributed<Vector<SequentialComputePolicy>> &,
const Viewer<SequentialComputePolicy> &);
template void writeToViewer<ParallelComputePolicy>(
const distributed<Vector<ParallelComputePolicy>> &,
const distributed<Vector<ParallelComputePolicy>> &,
const Viewer<ParallelComputePolicy> &);
......
......@@ -113,9 +113,9 @@ template <typename T> struct Mesh_Test : Test {
using Connectivity = std::vector<std::array<size_type, verticesPerElement>>;
const Connectivity connectivity = {{0, 2}, {1, 2}, {0, 2}};
const mesh_type mesh = mesh_type::fromConnectivity(
dimension, connectivity, totalNumberOfVertices, dofPerVertex,
dofPerElement);
mesh_type mesh = mesh_type::fromConnectivity(dimension, connectivity,
totalNumberOfVertices,
dofPerVertex, dofPerElement);
};
template <class Policy, int VertexDof, int ElementDof> struct TestCase {
......@@ -139,6 +139,28 @@ TYPED_TEST(Mesh_Test, from_connectivity_generates_correct_number_of_entities) {
Eq(this->totalNumberOfVertices));
}
TYPED_TEST(Mesh_Test, coordinates_can_be_set) {
using size_type = typename TestFixture::size_type;
using vector_type = typename TestFixture::vector_type;
const auto coordinateMesh =
this->mesh.cloneWithDofs(this->dimension, size_type{0});
auto coordinates = vector_type::fromGlobalMesh(coordinateMesh);
EXPECT_NO_THROW(this->mesh.setCoordinates(std::move(coordinates)));
}
TYPED_TEST(Mesh_Test, setting_coordinates_throws_if_not_associated_with_mesh) {
using size_type = typename TestFixture::size_type;
using vector_type = typename TestFixture::vector_type;
auto coordinates = tag<DistributedTag>(
vector_type(this->totalNumberOfVertices * this->dofPerVertex));
EXPECT_THROW(this->mesh.setCoordinates(std::move(coordinates)),
InvalidParametersException);
}
TYPED_TEST(Mesh_Test, cloned_mesh_has_correct_number_of_entities) {
using size_type = typename TestFixture::size_type;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment