diff --git a/assembly/test/AssemblerGroup_Test.cc b/assembly/test/AssemblerGroup_Test.cc index 93d404c343e1295e1f3d41ee2c162873493f826b..f664ff61d4e834f56c90e59c52427d0086940ba2 100644 --- a/assembly/test/AssemblerGroup_Test.cc +++ b/assembly/test/AssemblerGroup_Test.cc @@ -19,6 +19,7 @@ #include "ae108/assembly/FeaturePlugin.h" #include "ae108/assembly/FeaturePlugins.h" #include "ae108/assembly/test/Element.h" +#include "ae108/cpppetsc/MeshDataProvider.h" #include <array> #include <gmock/gmock.h> @@ -117,8 +118,10 @@ TEST_F(AssemblerGroup_Test, accessing_elements_in_plugin_works) { const auto values = std::array<double, 2>({.7, .9}); - group.get<0>().emplaceElement(AssemblerWithoutPlugin::ElementView{&mesh, 0}); - group.get<1>().emplaceElement(AssemblerWithoutPlugin::ElementView{&mesh, 0}); + group.get<0>().emplaceElement(AssemblerWithoutPlugin::ElementView{ + createDataProviderFromMesh(&mesh), 0}); + group.get<1>().emplaceElement(AssemblerWithoutPlugin::ElementView{ + createDataProviderFromMesh(&mesh), 0}); EXPECT_CALL(group.get<0>().meshElements().begin()->instance(), computeEnergy(_, _)) diff --git a/cpppetsc/src/CMakeLists.txt b/cpppetsc/src/CMakeLists.txt index 96b388bca96b58d736a465594ecea340c53fd33b..c12d7bc2208f0c039c033a4bc3dee466ecd995fc 100644 --- a/cpppetsc/src/CMakeLists.txt +++ b/cpppetsc/src/CMakeLists.txt @@ -29,6 +29,7 @@ add_library(${PROJECT_NAME}-cpppetsc ./Matrix_fwd.cc ./Mesh.cc ./MeshBoundaryCondition.cc + ./MeshDataProvider.cc ./Mesh_fwd.cc ./NonlinearSolver.cc ./NonlinearSolverDivergedException.cc diff --git a/cpppetsc/src/MeshDataProvider.cc b/cpppetsc/src/MeshDataProvider.cc new file mode 100644 index 0000000000000000000000000000000000000000..b8546e327c37b17e16924d7c3ec7e197c1739470 --- /dev/null +++ b/cpppetsc/src/MeshDataProvider.cc @@ -0,0 +1,15 @@ +// © 2021 ETH Zurich, Mechanics and Materials Lab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ae108/cpppetsc/MeshDataProvider.h" diff --git a/cpppetsc/src/include/ae108/cpppetsc/LocalElementView.h b/cpppetsc/src/include/ae108/cpppetsc/LocalElementView.h index e548592ad7c87e8925fdd5fcefd230551c09f44a..2f7ae867f65cf3faed137754c355490b6f6ef8dd 100644 --- a/cpppetsc/src/include/ae108/cpppetsc/LocalElementView.h +++ b/cpppetsc/src/include/ae108/cpppetsc/LocalElementView.h @@ -1,4 +1,4 @@ -// © 2020 ETH Zurich, Mechanics and Materials Lab +// © 2020, 2021 ETH Zurich, Mechanics and Materials Lab // © 2020 California Institute of Technology // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,6 +15,7 @@ #pragma once +#include "ae108/cpppetsc/MeshDataProvider.h" #include "ae108/cpppetsc/TaggedVector.h" #include <vector> @@ -30,13 +31,13 @@ public: using vector_type = typename mesh_type::vector_type; using matrix_type = typename mesh_type::matrix_type; - explicit LocalElementView(); + using DataProvider = MeshDataProvider<mesh_type>; /** - * @param mesh A valid pointer to a mesh_type instance. + * @param data Provides mesh information to the view. * @param elementPointIndex A valid element point index. */ - explicit LocalElementView(const mesh_type *const mesh, + explicit LocalElementView(DataProvider data, const size_type elementPointIndex); /** @@ -105,14 +106,9 @@ public: void addElementMatrix(const std::vector<value_type> &data, matrix_type *const matrix) const; - friend bool operator==(const LocalElementView &lhs, - const LocalElementView &rhs) { - return lhs._mesh == rhs._mesh && lhs._entityIndex == rhs._entityIndex; - } - private: - const mesh_type *_mesh = nullptr; - size_type _entityIndex = 0; + DataProvider _data; + size_type _entityIndex; }; } // namespace cpppetsc } // namespace ae108 @@ -121,85 +117,85 @@ private: * implementations *******************************************************************/ +#include <cassert> + namespace ae108 { namespace cpppetsc { -template <class IteratorType> -LocalElementView<IteratorType>::LocalElementView() = default; - template <class IteratorType> LocalElementView<IteratorType>::LocalElementView( - const mesh_type *const mesh, const size_type elementPointIndex) - : _mesh(mesh), _entityIndex(elementPointIndex) {} + DataProvider data, const size_type elementPointIndex) + : _data(std::move(data)), _entityIndex(elementPointIndex) {} template <class IteratorType> LocalElementView<IteratorType> LocalElementView<IteratorType>::forClonedMesh( const mesh_type *const mesh) const { - return LocalElementView{mesh, _entityIndex}; + assert(mesh); + return LocalElementView{createDataProviderFromMesh(mesh), _entityIndex}; } template <class IteratorType> std::vector<typename LocalElementView<IteratorType>::size_type> LocalElementView<IteratorType>::vertexIndices() const { - return _mesh->vertices(_entityIndex); + return _data.vertices(_entityIndex); } template <class IteratorType> typename LocalElementView<IteratorType>::size_type LocalElementView<IteratorType>::index() const { - return _mesh->elementPointIndexToGlobalIndex(_entityIndex); + return _data.elementPointIndexToGlobalIndex(_entityIndex); } template <class IteratorType> typename LocalElementView<IteratorType>::size_type LocalElementView<IteratorType>::numberOfVertices() const { - return _mesh->numberOfVertices(_entityIndex); + return _data.numberOfVertices(_entityIndex); } template <class IteratorType> typename LocalElementView<IteratorType>::size_type LocalElementView<IteratorType>::numberOfDofs() const { - return _mesh->numberOfDofs(_entityIndex); + return _data.numberOfDofs(_entityIndex); } template <class IteratorType> std::pair<typename LocalElementView<IteratorType>::size_type, typename LocalElementView<IteratorType>::size_type> LocalElementView<IteratorType>::localDofLineRange() const { - return _mesh->localDofLineRange(_entityIndex); + return _data.localDofLineRange(_entityIndex); } template <class IteratorType> std::pair<typename LocalElementView<IteratorType>::size_type, typename LocalElementView<IteratorType>::size_type> LocalElementView<IteratorType>::globalDofLineRange() const { - return _mesh->globalDofLineRange(_entityIndex); + return _data.globalDofLineRange(_entityIndex); } template <class IteratorType> void LocalElementView<IteratorType>::copyElementData( const local<vector_type> &vector, std::vector<value_type> *data) const { - _mesh->copyEntityData(_entityIndex, vector, data); + _data.copyEntityData(_entityIndex, vector, data); } template <class IteratorType> void LocalElementView<IteratorType>::addElementData( const std::vector<value_type> &data, local<vector_type> *const vector) const { - _mesh->addEntityData(_entityIndex, data, vector); + _data.addEntityData(_entityIndex, data, vector); } template <class IteratorType> void LocalElementView<IteratorType>::setElementData( const std::vector<value_type> &data, local<vector_type> *const vector) const { - _mesh->setEntityData(_entityIndex, data, vector); + _data.setEntityData(_entityIndex, data, vector); } template <class IteratorType> void LocalElementView<IteratorType>::addElementMatrix( const std::vector<value_type> &data, matrix_type *const matrix) const { - _mesh->addEntityMatrix(_entityIndex, data, matrix); + _data.addEntityMatrix(_entityIndex, data, matrix); } } // namespace cpppetsc } // namespace ae108 \ No newline at end of file diff --git a/cpppetsc/src/include/ae108/cpppetsc/LocalVertexView.h b/cpppetsc/src/include/ae108/cpppetsc/LocalVertexView.h index 7331ad83f57c5f9986a2f7011aa01e97e72cebcb..38a5147c3e91df80ec232e6601525033fa1d010a 100644 --- a/cpppetsc/src/include/ae108/cpppetsc/LocalVertexView.h +++ b/cpppetsc/src/include/ae108/cpppetsc/LocalVertexView.h @@ -15,6 +15,7 @@ #pragma once +#include "ae108/cpppetsc/MeshDataProvider.h" #include "ae108/cpppetsc/TaggedVector.h" #include <utility> #include <vector> @@ -31,14 +32,13 @@ public: using vector_type = typename mesh_type::vector_type; using matrix_type = typename mesh_type::matrix_type; - explicit LocalVertexView(); + using DataProvider = MeshDataProvider<mesh_type>; /** - * @param mesh A valid pointer to a mesh_type instance. + * @param data Provides mesh information to the view. * @param vertexPointIndex A valid vertex point index. */ - explicit LocalVertexView(const mesh_type *const mesh, - const size_type vertexPointIndex); + explicit LocalVertexView(DataProvider data, const size_type vertexPointIndex); /** * @brief Converts the vertex view to a vertex view for a cloned mesh. @@ -98,14 +98,9 @@ public: void addVertexMatrix(const std::vector<value_type> &data, matrix_type *const matrix) const; - friend bool operator==(const LocalVertexView &lhs, - const LocalVertexView &rhs) { - return lhs._mesh == rhs._mesh && lhs._entityIndex == rhs._entityIndex; - } - private: - const mesh_type *_mesh = nullptr; - size_type _entityIndex = 0; + DataProvider _data; + size_type _entityIndex; }; } // namespace cpppetsc } // namespace ae108 @@ -114,72 +109,72 @@ private: * implementations *******************************************************************/ +#include <cassert> + namespace ae108 { namespace cpppetsc { template <class IteratorType> -LocalVertexView<IteratorType>::LocalVertexView() = default; - -template <class IteratorType> -LocalVertexView<IteratorType>::LocalVertexView(const mesh_type *const mesh, +LocalVertexView<IteratorType>::LocalVertexView(DataProvider data, const size_type vertexPointIndex) - : _mesh(mesh), _entityIndex(vertexPointIndex) {} + : _data(std::move(data)), _entityIndex(vertexPointIndex) {} template <class IteratorType> LocalVertexView<IteratorType> LocalVertexView<IteratorType>::forClonedMesh( const mesh_type *const mesh) const { - return LocalVertexView{mesh, _entityIndex}; + assert(mesh); + return LocalVertexView{createDataProviderFromMesh(mesh), _entityIndex}; } template <class IteratorType> typename LocalVertexView<IteratorType>::size_type LocalVertexView<IteratorType>::index() const { - return _mesh->vertexPointIndexToGlobalIndex(_entityIndex); + return _data.vertexPointIndexToGlobalIndex(_entityIndex); } template <class IteratorType> typename LocalVertexView<IteratorType>::size_type LocalVertexView<IteratorType>::numberOfDofs() const { - return _mesh->numberOfDofs(_entityIndex); + return _data.numberOfDofs(_entityIndex); } template <class IteratorType> std::pair<typename LocalVertexView<IteratorType>::size_type, typename LocalVertexView<IteratorType>::size_type> LocalVertexView<IteratorType>::localDofLineRange() const { - return _mesh->localDofLineRange(_entityIndex); + return _data.localDofLineRange(_entityIndex); } template <class IteratorType> std::pair<typename LocalVertexView<IteratorType>::size_type, typename LocalVertexView<IteratorType>::size_type> LocalVertexView<IteratorType>::globalDofLineRange() const { - return _mesh->globalDofLineRange(_entityIndex); + return _data.globalDofLineRange(_entityIndex); } template <class IteratorType> void LocalVertexView<IteratorType>::copyVertexData( const local<vector_type> &vector, std::vector<value_type> *data) const { - return _mesh->copyEntityData(_entityIndex, vector, data); + return _data.copyEntityData(_entityIndex, vector, data); } template <class IteratorType> void LocalVertexView<IteratorType>::addVertexData( const std::vector<value_type> &data, local<vector_type> *const vector) const { - return _mesh->addEntityData(_entityIndex, data, vector); + return _data.addEntityData(_entityIndex, data, vector); } template <class IteratorType> void LocalVertexView<IteratorType>::setVertexData( const std::vector<value_type> &data, local<vector_type> *const vector) const { - return _mesh->setEntityData(_entityIndex, data, vector); + return _data.setEntityData(_entityIndex, data, vector); } template <class IteratorType> void LocalVertexView<IteratorType>::addVertexMatrix( const std::vector<value_type> &data, matrix_type *const matrix) const { - return _mesh->addEntityMatrix(_entityIndex, data, matrix); + return _data.addEntityMatrix(_entityIndex, data, matrix); } } // namespace cpppetsc } // namespace ae108 \ No newline at end of file diff --git a/cpppetsc/src/include/ae108/cpppetsc/Mesh.h b/cpppetsc/src/include/ae108/cpppetsc/Mesh.h index 5dbecaac0af2e0419fd51fb998af0d28ba18225a..92ad5349da6d72e29785314e4181082b6576e2b3 100644 --- a/cpppetsc/src/include/ae108/cpppetsc/Mesh.h +++ b/cpppetsc/src/include/ae108/cpppetsc/Mesh.h @@ -43,10 +43,9 @@ namespace ae108 { namespace cpppetsc { template <class Policy> class Mesh { - friend LocalElementView<Mesh>; - friend LocalVertexView<Mesh>; - public: + using policy_type = Policy; + using size_type = PetscInt; using value_type = PetscScalar; using vector_type = Vector<Policy>; @@ -218,92 +217,6 @@ private: static void distributeMesh(Mesh *const); /** - * @brief Convert an internal element index to a 'canonical' element index. - */ - size_type elementPointIndexToGlobalIndex(const size_type pointIndex) const; - - /** - * @brief Convert an internal vertex index to a 'canonical' vertex index. - */ - size_type vertexPointIndexToGlobalIndex(const size_type pointIndex) const; - - /** - * @brief Returns the number of vertices of the element. - */ - size_type numberOfVertices(const size_type elementPointIndex) const; - - /** - * @brief Returns the vertex indices for the element. - */ - std::vector<size_type> vertices(const size_type elementPointIndex) const; - - /** - * @brief Returns the global line range that corresponds to the degrees of - * freedom of the vertex or element. - */ - std::pair<size_type, size_type> - localDofLineRange(const size_type entityPointIndex) const; - - /** - * @brief Returns the global line range that corresponds to the degrees of - * freedom of the vertex or element. - * - * @remark Returns negative line numbers (-begin - 1, -end - 1) if the point - * data is not owned locally (e.g. for ghost points). In particular, first >= - * end in this case. - */ - std::pair<size_type, size_type> - globalDofLineRange(const size_type entityPointIndex) const; - - size_type numberOfDofs(const size_type entityPointIndex) const; - - /** - * @brief Copies data corresponding to the specified entity from the vector - * to the data buffer. - * - * @param entityPointIndex Specifies the entity. - * @param localVector A mesh-local vector. - * @param data A valid pointer to a buffer object. - */ - void copyEntityData(const size_type entityPointIndex, - const local<vector_type> &localVector, - std::vector<value_type> *const data) const; - - /** - * @brief Add data corresponding to the specified entity from the data buffer - * to the vector. - * - * @param entityPointIndex Specifies the entity. - * @param data The data buffer. - * @param localVector A valid pointer to a mesh-local vector. - */ - void addEntityData(const size_type entityPointIndex, - const std::vector<value_type> &data, - local<vector_type> *const localVector) const; - - /** - * @brief Set data corresponding to the specified entity from the data buffer - * in the vector. - * - * @param entityPointIndex Specifies the entity. - * @param data The data buffer. - * @param localVector A valid pointer to a mesh-local vector. - */ - void setEntityData(const size_type entityPointIndex, - const std::vector<value_type> &data, - local<vector_type> *const localVector) const; - - /** - * @brief Add data corresponding to the specified entity from the data buffer - * to the matrix. - * - * @param entityPointIndex Specifies the entity. - * @param data The data buffer. - * @param matrix A valid pointer to a matrix. - */ - void addEntityMatrix(const size_type entityPointIndex, - const std::vector<value_type> &data, - matrix_type *const matrix) const; /** * @brief Returns the layout of a global vector associated with this mesh. @@ -354,6 +267,7 @@ extern template class Mesh<ParallelComputePolicy>; *******************************************************************/ #include "ae108/cpppetsc/Matrix.h" +#include "ae108/cpppetsc/MeshDataProvider.h" #include "ae108/cpppetsc/Vector.h" namespace ae108 { @@ -365,8 +279,9 @@ template <class Policy> auto Mesh<Policy>::localElements() const { Policy::handleError(DMPlexGetHeightStratum(_mesh.get(), 0, &start, &stop)); namespace rv = ranges::cpp20::views; - return rv::iota(start, stop) | rv::transform([&](const size_type id) { - return LocalElementView<Mesh>(this, id); + const auto provider = createDataProviderFromMesh(this); + return rv::iota(start, stop) | rv::transform([provider](const size_type id) { + return LocalElementView<Mesh>(provider, id); }); } @@ -376,8 +291,9 @@ template <class Policy> auto Mesh<Policy>::localVertices() const { Policy::handleError(DMPlexGetDepthStratum(_mesh.get(), 0, &start, &stop)); namespace rv = ranges::cpp20::views; - return rv::iota(start, stop) | rv::transform([&](const size_type id) { - return LocalVertexView<Mesh>(this, id); + const auto provider = createDataProviderFromMesh(this); + return rv::iota(start, stop) | rv::transform([provider](const size_type id) { + return LocalVertexView<Mesh>(provider, id); }); } @@ -756,127 +672,6 @@ Mesh<Policy>::fromCanonicalOrder(const distributed<vector_type> &vector) const { return result; } -template <class Policy> -typename Mesh<Policy>::size_type -Mesh<Policy>::elementPointIndexToGlobalIndex(const size_type pointIndex) const { - auto migration = PetscSF(); - Policy::handleError(DMPlexGetMigrationSF(_mesh.get(), &migration)); - if (migration) { - const PetscSFNode *nodes = nullptr; - Policy::handleError( - PetscSFGetGraph(migration, nullptr, nullptr, nullptr, &nodes)); - return nodes[pointIndex].index; - } - - return pointIndex; -} - -template <class Policy> -typename Mesh<Policy>::size_type -Mesh<Policy>::vertexPointIndexToGlobalIndex(const size_type pointIndex) const { - return elementPointIndexToGlobalIndex(pointIndex) - _totalNumberOfElements; -} - -template <class Policy> -typename Mesh<Policy>::size_type -Mesh<Policy>::numberOfVertices(const size_type elementPointIndex) const { - size_type size = 0; - Policy::handleError(DMPlexGetConeSize(_mesh.get(), elementPointIndex, &size)); - return size; -} - -template <class Policy> -std::vector<typename Mesh<Policy>::size_type> -Mesh<Policy>::vertices(const size_type elementPointIndex) const { - const auto size = numberOfVertices(elementPointIndex); - const size_type *cone = nullptr; - Policy::handleError(DMPlexGetCone(_mesh.get(), elementPointIndex, &cone)); - std::vector<size_type> output(size); - std::transform(cone, cone + size, output.begin(), - [this](const size_type entityPointIndex) { - return vertexPointIndexToGlobalIndex(entityPointIndex); - }); - return output; -} - -template <class Policy> -std::pair<typename Mesh<Policy>::size_type, typename Mesh<Policy>::size_type> -Mesh<Policy>::localDofLineRange(const size_type entityPointIndex) const { - auto output = std::pair<size_type, size_type>(); - Policy::handleError(DMPlexGetPointLocal(_mesh.get(), entityPointIndex, - &output.first, &output.second)); - return output; -} - -template <class Policy> -std::pair<typename Mesh<Policy>::size_type, typename Mesh<Policy>::size_type> -Mesh<Policy>::globalDofLineRange(const size_type entityPointIndex) const { - auto output = std::pair<size_type, size_type>(); - Policy::handleError(DMPlexGetPointGlobal(_mesh.get(), entityPointIndex, - &output.first, &output.second)); - return output; -} - -template <class Policy> -typename Mesh<Policy>::size_type -Mesh<Policy>::numberOfDofs(const size_type entityPointIndex) const { - auto output = size_type{0}; - auto section = PetscSection(); - Policy::handleError(DMGetSection(_mesh.get(), §ion)); - - Policy::handleError(PetscSectionGetDof(section, entityPointIndex, &output)); - - return output; -} - -template <class Policy> -void Mesh<Policy>::copyEntityData(const size_type entityPointIndex, - const local<vector_type> &localVector, - std::vector<value_type> *const data) const { - assertCorrectBaseMesh(localVector.unwrap()); - auto size = size_type{0}; - Policy::handleError(DMPlexVecGetClosure(_mesh.get(), nullptr, - localVector.unwrap().data(), - entityPointIndex, &size, nullptr)); - data->resize(size); - if (size == 0) - return; - - auto dataPtr = data->data(); - Policy::handleError(DMPlexVecGetClosure(_mesh.get(), nullptr, - localVector.unwrap().data(), - entityPointIndex, &size, &dataPtr)); -} - -template <class Policy> -void Mesh<Policy>::addEntityData(const size_type entityPointIndex, - const std::vector<value_type> &data, - local<vector_type> *const localVector) const { - assertCorrectBaseMesh(localVector->unwrap()); - Policy::handleError( - DMPlexVecSetClosure(_mesh.get(), nullptr, localVector->unwrap().data(), - entityPointIndex, data.data(), ADD_VALUES)); -} - -template <class Policy> -void Mesh<Policy>::setEntityData(const size_type entityPointIndex, - const std::vector<value_type> &data, - local<vector_type> *const localVector) const { - assertCorrectBaseMesh(localVector->unwrap()); - Policy::handleError( - DMPlexVecSetClosure(_mesh.get(), nullptr, localVector->unwrap().data(), - entityPointIndex, data.data(), INSERT_VALUES)); -} - -template <class Policy> -void Mesh<Policy>::addEntityMatrix(const size_type entityPointIndex, - const std::vector<value_type> &data, - matrix_type *const matrix) const { - Policy::handleError(DMPlexMatSetClosure(_mesh.get(), nullptr, nullptr, - matrix->data(), entityPointIndex, - data.data(), ADD_VALUES)); -} - template <class Policy> void Mesh<Policy>::assertCorrectBaseMesh(const vector_type &vector) const { #ifndef NDEBUG diff --git a/cpppetsc/src/include/ae108/cpppetsc/MeshDataProvider.h b/cpppetsc/src/include/ae108/cpppetsc/MeshDataProvider.h new file mode 100644 index 0000000000000000000000000000000000000000..6e12ca9ebb22489e4a8fdfa656ab3a977796590c --- /dev/null +++ b/cpppetsc/src/include/ae108/cpppetsc/MeshDataProvider.h @@ -0,0 +1,329 @@ +// © 2020, 2021 ETH Zurich, Mechanics and Materials Lab +// © 2020 California Institute of Technology +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "ae108/cpppetsc/SharedEntity.h" +#include "ae108/cpppetsc/TaggedVector.h" +#include <petscdm.h> +#include <utility> +#include <vector> + +namespace ae108 { +namespace cpppetsc { + +/** + * brief Provides the necessary data for local entity views. + */ +template <class MeshType_> class MeshDataProvider { +public: + using mesh_type = MeshType_; + + using policy_type = typename mesh_type::policy_type; + using size_type = typename mesh_type::size_type; + using value_type = typename mesh_type::value_type; + using vector_type = typename mesh_type::vector_type; + using matrix_type = typename mesh_type::matrix_type; + + explicit MeshDataProvider(SharedEntity<DM, policy_type> mesh, + const size_type totalNumberOfElements); + + /** + * @brief Convert an internal element index to a 'canonical' element index. + */ + size_type elementPointIndexToGlobalIndex(const size_type pointIndex) const; + + /** + * @brief Convert an internal vertex index to a 'canonical' vertex index. + */ + size_type vertexPointIndexToGlobalIndex(const size_type pointIndex) const; + + /** + * @brief Returns the number of vertices of the element. + */ + size_type numberOfVertices(const size_type elementPointIndex) const; + + /** + * @brief Returns the vertex indices for the element. + */ + std::vector<size_type> vertices(const size_type elementPointIndex) const; + + /** + * @brief Returns the global line range that corresponds to the degrees of + * freedom of the vertex or element. + */ + std::pair<size_type, size_type> + localDofLineRange(const size_type entityPointIndex) const; + + /** + * @brief Returns the global line range that corresponds to the degrees of + * freedom of the vertex or element. + * + * @remark Returns negative line numbers (-begin - 1, -end - 1) if the point + * data is not owned locally (e.g. for ghost points). In particular, first >= + * end in this case. + */ + std::pair<size_type, size_type> + globalDofLineRange(const size_type entityPointIndex) const; + + size_type numberOfDofs(const size_type entityPointIndex) const; + + /** + * @brief Copies data corresponding to the specified entity from the vector + * to the data buffer. + * + * @param entityPointIndex Specifies the entity. + * @param localVector A mesh-local vector. + * @param data A valid pointer to a buffer object. + */ + void copyEntityData(const size_type entityPointIndex, + const local<vector_type> &localVector, + std::vector<value_type> *const data) const; + + /** + * @brief Add data corresponding to the specified entity from the data buffer + * to the vector. + * + * @param entityPointIndex Specifies the entity. + * @param data The data buffer. + * @param localVector A valid pointer to a mesh-local vector. + */ + void addEntityData(const size_type entityPointIndex, + const std::vector<value_type> &data, + local<vector_type> *const localVector) const; + + /** + * @brief Set data corresponding to the specified entity from the data buffer + * in the vector. + * + * @param entityPointIndex Specifies the entity. + * @param data The data buffer. + * @param localVector A valid pointer to a mesh-local vector. + */ + void setEntityData(const size_type entityPointIndex, + const std::vector<value_type> &data, + local<vector_type> *const localVector) const; + + /** + * @brief Add data corresponding to the specified entity from the data buffer + * to the matrix. + * + * @param entityPointIndex Specifies the entity. + * @param data The data buffer. + * @param matrix A valid pointer to a matrix. + */ + void addEntityMatrix(const size_type entityPointIndex, + const std::vector<value_type> &data, + matrix_type *const matrix) const; + + /** + * @brief Asserts that the vector has been created with the associated mesh. + */ + void assertCorrectBaseMesh(const vector_type &vector) const; + +private: + SharedEntity<DM, policy_type> _mesh; + size_type _totalNumberOfElements; +}; + +/** + * @brief Creates a data provider from the given mesh. + * @param mesh Valid nonzero pointer. + */ +template <class MeshType> +MeshDataProvider<MeshType> +createDataProviderFromMesh(const MeshType *const mesh); + +} // namespace cpppetsc +} // namespace ae108 + +/******************************************************************** + * implementations + *******************************************************************/ + +#include <algorithm> +#include <cassert> +#include <petscdmplex.h> +#include <petscsf.h> + +namespace ae108 { +namespace cpppetsc { + +template <class MeshType_> +MeshDataProvider<MeshType_>::MeshDataProvider( + SharedEntity<DM, policy_type> mesh, const size_type totalNumberOfElements) + : _mesh(std::move(mesh)), _totalNumberOfElements(totalNumberOfElements) {} + +template <class MeshType_> +typename MeshDataProvider<MeshType_>::size_type +MeshDataProvider<MeshType_>::elementPointIndexToGlobalIndex( + const size_type pointIndex) const { + auto migration = PetscSF(); + policy_type::handleError(DMPlexGetMigrationSF(_mesh.get(), &migration)); + if (migration) { + const PetscSFNode *nodes = nullptr; + policy_type::handleError( + PetscSFGetGraph(migration, nullptr, nullptr, nullptr, &nodes)); + return nodes[pointIndex].index; + } + + return pointIndex; +} + +template <class MeshType_> +typename MeshDataProvider<MeshType_>::size_type +MeshDataProvider<MeshType_>::vertexPointIndexToGlobalIndex( + const size_type pointIndex) const { + return elementPointIndexToGlobalIndex(pointIndex) - _totalNumberOfElements; +} + +template <class MeshType_> +typename MeshDataProvider<MeshType_>::size_type +MeshDataProvider<MeshType_>::numberOfVertices( + const size_type elementPointIndex) const { + size_type size = 0; + policy_type::handleError( + DMPlexGetConeSize(_mesh.get(), elementPointIndex, &size)); + return size; +} + +template <class MeshType_> +std::vector<typename MeshDataProvider<MeshType_>::size_type> +MeshDataProvider<MeshType_>::vertices(const size_type elementPointIndex) const { + const auto size = numberOfVertices(elementPointIndex); + const size_type *cone = nullptr; + policy_type::handleError( + DMPlexGetCone(_mesh.get(), elementPointIndex, &cone)); + std::vector<size_type> output(size); + std::transform(cone, cone + size, output.begin(), + [this](const size_type entityPointIndex) { + return vertexPointIndexToGlobalIndex(entityPointIndex); + }); + return output; +} + +template <class MeshType_> +std::pair<typename MeshDataProvider<MeshType_>::size_type, + typename MeshDataProvider<MeshType_>::size_type> +MeshDataProvider<MeshType_>::localDofLineRange( + const size_type entityPointIndex) const { + auto output = std::pair<size_type, size_type>(); + policy_type::handleError(DMPlexGetPointLocal(_mesh.get(), entityPointIndex, + &output.first, &output.second)); + return output; +} + +template <class MeshType_> +std::pair<typename MeshDataProvider<MeshType_>::size_type, + typename MeshDataProvider<MeshType_>::size_type> +MeshDataProvider<MeshType_>::globalDofLineRange( + const size_type entityPointIndex) const { + auto output = std::pair<size_type, size_type>(); + policy_type::handleError(DMPlexGetPointGlobal(_mesh.get(), entityPointIndex, + &output.first, &output.second)); + return output; +} + +template <class MeshType_> +typename MeshDataProvider<MeshType_>::size_type +MeshDataProvider<MeshType_>::numberOfDofs( + const size_type entityPointIndex) const { + auto output = size_type{0}; + auto section = PetscSection(); + policy_type::handleError(DMGetSection(_mesh.get(), §ion)); + + policy_type::handleError( + PetscSectionGetDof(section, entityPointIndex, &output)); + + return output; +} + +template <class MeshType_> +void MeshDataProvider<MeshType_>::copyEntityData( + const size_type entityPointIndex, const local<vector_type> &localVector, + std::vector<value_type> *const data) const { + assertCorrectBaseMesh(localVector.unwrap()); + auto size = size_type{0}; + policy_type::handleError( + DMPlexVecGetClosure(_mesh.get(), nullptr, localVector.unwrap().data(), + entityPointIndex, &size, nullptr)); + data->resize(size); + if (size == 0) + return; + + auto dataPtr = data->data(); + policy_type::handleError( + DMPlexVecGetClosure(_mesh.get(), nullptr, localVector.unwrap().data(), + entityPointIndex, &size, &dataPtr)); +} + +template <class MeshType_> +void MeshDataProvider<MeshType_>::addEntityData( + const size_type entityPointIndex, const std::vector<value_type> &data, + local<vector_type> *const localVector) const { + assertCorrectBaseMesh(localVector->unwrap()); + policy_type::handleError( + DMPlexVecSetClosure(_mesh.get(), nullptr, localVector->unwrap().data(), + entityPointIndex, data.data(), ADD_VALUES)); +} + +template <class MeshType_> +void MeshDataProvider<MeshType_>::setEntityData( + const size_type entityPointIndex, const std::vector<value_type> &data, + local<vector_type> *const localVector) const { + assertCorrectBaseMesh(localVector->unwrap()); + policy_type::handleError( + DMPlexVecSetClosure(_mesh.get(), nullptr, localVector->unwrap().data(), + entityPointIndex, data.data(), INSERT_VALUES)); +} + +template <class MeshType_> +void MeshDataProvider<MeshType_>::addEntityMatrix( + const size_type entityPointIndex, const std::vector<value_type> &data, + matrix_type *const matrix) const { + policy_type::handleError(DMPlexMatSetClosure(_mesh.get(), nullptr, nullptr, + matrix->data(), entityPointIndex, + data.data(), ADD_VALUES)); +} + +template <class MeshType_> +void MeshDataProvider<MeshType_>::assertCorrectBaseMesh( + const vector_type &vector) const { +#ifndef NDEBUG + const auto extractDM = [](const vector_type &vector) { + auto dm = DM{}; + policy_type::handleError(VecGetDM(vector.data(), &dm)); + return dm; + }; + const auto dm = extractDM(vector); + assert((!dm || dm == _mesh.get()) && + "The vector was created using a different mesh."); +#else + static_cast<void>(vector); +#endif +} + +template <class MeshType> +MeshDataProvider<MeshType> +createDataProviderFromMesh(const MeshType *const mesh) { + assert(mesh); + return MeshDataProvider<MeshType>{ + makeSharedEntity<typename MeshType::policy_type>(mesh->data(), + OwnershipType::Share), + mesh->totalNumberOfElements(), + }; +} +} // namespace cpppetsc +} // namespace ae108 \ No newline at end of file diff --git a/cpppetsc/test/Mesh_Test.cc b/cpppetsc/test/Mesh_Test.cc index 7cae66a7455eeb6b2e8ff8ad86fcf648a251d552..b633c5148653dfb3a743b8ef6d9b49aa99f3fe96 100644 --- a/cpppetsc/test/Mesh_Test.cc +++ b/cpppetsc/test/Mesh_Test.cc @@ -493,11 +493,12 @@ TYPED_TEST(Mesh_Test, copying_to_global_vector_works) { TYPED_TEST(Mesh_Test, begin_element_iterator_starts_at_index_0) { using size_type = typename TestFixture::size_type; - auto iterator = this->mesh.localElements().begin(); + const auto range = this->mesh.localElements(); + auto iterator = range.begin(); auto result = std::numeric_limits<size_type>::max(); - if (iterator != this->mesh.localElements().end()) { + if (iterator != range.end()) { result = (*iterator).index(); } ASSERT_THAT(MPI_Allreduce(MPI_IN_PLACE, &result, 1, MPI_INT, MPI_MIN, @@ -509,11 +510,12 @@ TYPED_TEST(Mesh_Test, begin_element_iterator_starts_at_index_0) { TYPED_TEST(Mesh_Test, end_element_iterator_is_past_final_index) { using size_type = typename TestFixture::size_type; - auto iterator = this->mesh.localElements().end(); + const auto range = this->mesh.localElements(); + auto iterator = range.end(); auto result = std::numeric_limits<size_type>::min(); - if (iterator != this->mesh.localElements().begin()) { + if (iterator != range.begin()) { result = (*(--iterator)).index(); } ASSERT_THAT(MPI_Allreduce(MPI_IN_PLACE, &result, 1, MPI_INT, MPI_MAX, @@ -525,11 +527,12 @@ TYPED_TEST(Mesh_Test, end_element_iterator_is_past_final_index) { TYPED_TEST(Mesh_Test, begin_vertex_iterator_starts_at_index_0) { using size_type = typename TestFixture::size_type; - auto iterator = this->mesh.localVertices().begin(); + const auto range = this->mesh.localVertices(); + auto iterator = range.begin(); auto result = std::numeric_limits<size_type>::max(); - if (iterator != this->mesh.localVertices().end()) { + if (iterator != range.end()) { result = (*iterator).index(); } ASSERT_THAT(MPI_Allreduce(MPI_IN_PLACE, &result, 1, MPI_INT, MPI_MIN, @@ -541,11 +544,12 @@ TYPED_TEST(Mesh_Test, begin_vertex_iterator_starts_at_index_0) { TYPED_TEST(Mesh_Test, end_vertex_iterator_is_past_final_index) { using size_type = typename TestFixture::size_type; - auto iterator = this->mesh.localVertices().end(); + const auto range = this->mesh.localVertices(); + auto iterator = range.end(); auto result = std::numeric_limits<size_type>::min(); - if (iterator != this->mesh.localVertices().begin()) { + if (iterator != range.begin()) { result = (*(--iterator)).index(); } ASSERT_THAT(MPI_Allreduce(MPI_IN_PLACE, &result, 1, MPI_INT, MPI_MAX, @@ -918,6 +922,21 @@ TYPED_TEST(Mesh_Test, adding_vertex_0_to_matrix_works) { } } +TYPED_TEST(Mesh_Test, element_views_are_valid_after_moving_mesh) { + using mesh_type = typename TestFixture::mesh_type; + + auto mesh = mesh_type::fromConnectivity( + this->topologicalDimension, this->coordinateDimension, this->connectivity, + this->totalNumberOfVertices, this->dofPerVertex, this->dofPerElement); + const auto range = mesh.localElements(); + const auto movedMesh = std::move(mesh); + + auto begin = range.begin(); + for (auto &&element : movedMesh.localElements()) { + EXPECT_THAT((*begin++).index(), Eq(element.index())); + } +} + template <typename T> struct Mesh_IgnoreVertexTest : Test { using mesh_type = Mesh<T>;