From 2e016d8973ede957978a5544aa78fc75ec37e565 Mon Sep 17 00:00:00 2001 From: Manuel Weberndorfer <manuel.weberndorfer@id.ethz.ch> Date: Tue, 26 Oct 2021 16:32:44 +0000 Subject: [PATCH] support setting/getting local size in Vector --- cpppetsc/src/include/ae108/cpppetsc/Vector.h | 45 ++++++++++++++++---- cpppetsc/test/Vector_Test.cc | 13 +++++- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/cpppetsc/src/include/ae108/cpppetsc/Vector.h b/cpppetsc/src/include/ae108/cpppetsc/Vector.h index 7f4837fe..85536949 100644 --- a/cpppetsc/src/include/ae108/cpppetsc/Vector.h +++ b/cpppetsc/src/include/ae108/cpppetsc/Vector.h @@ -53,11 +53,21 @@ public: */ explicit Vector(const size_type global_size); + using LocalRows = TaggedEntity<size_type, struct LocalRowsTag>; + using GlobalRows = TaggedEntity<size_type, struct GlobalRowsTag>; + + /** + * @brief Allocates a vector with the provided layout. + * + * @remark Does not initialize the vector. + */ + explicit Vector(const LocalRows localRows, GlobalRows globalRows); + /** - * @brief Allocates a vector of length global_size and initializes the vector + * @brief Allocates a vector of length globalRows and initializes the vector * to value. */ - explicit Vector(const size_type global_size, const value_type value); + explicit Vector(const size_type globalRows, const value_type value); /** * @brief Constructs a vector with the provided entries. @@ -194,6 +204,11 @@ public: */ size_type size() const; + /** + * @brief Returns the local size of the vector. + */ + size_type localSize() const; + /** * @brief Returns the local row indices in the format: * @@ -224,7 +239,7 @@ public: Vec data() const; private: - static Vec createVec(const size_type global_size); + static Vec createVec(const LocalRows localRows, const GlobalRows globalRows); UniqueEntity<Vec> _vec; }; @@ -261,22 +276,27 @@ namespace ae108 { namespace cpppetsc { template <class Policy> -Vec Vector<Policy>::createVec(const size_type global_size) { +Vec Vector<Policy>::createVec(const LocalRows localRows, + const GlobalRows globalRows) { Vec vec; Policy::handleError( - VecCreateMPI(Policy::communicator(), PETSC_DECIDE, global_size, &vec)); + VecCreateMPI(Policy::communicator(), localRows, globalRows, &vec)); return vec; } template <class Policy> -Vector<Policy>::Vector(const size_type global_size) - : _vec(makeUniqueEntity<Policy>(createVec(global_size))) { +Vector<Policy>::Vector(const size_type globalRows) + : Vector(LocalRows{PETSC_DECIDE}, GlobalRows{globalRows}) {} + +template <class Policy> +Vector<Policy>::Vector(const LocalRows localRows, const GlobalRows globalRows) + : _vec(makeUniqueEntity<Policy>(createVec(localRows, globalRows))) { Policy::handleError(VecSetFromOptions(_vec.get())); } template <class Policy> -Vector<Policy>::Vector(const size_type global_size, const value_type value) - : Vector(global_size) { +Vector<Policy>::Vector(const size_type globalRows, const value_type value) + : Vector(globalRows) { fill(value); } @@ -468,6 +488,13 @@ typename Vector<Policy>::size_type Vector<Policy>::size() const { return size; } +template <class Policy> +typename Vector<Policy>::size_type Vector<Policy>::localSize() const { + size_type size = 0; + Policy::handleError(VecGetLocalSize(_vec.get(), &size)); + return size; +} + template <class Policy> void Vector<Policy>::setZero() { Policy::handleError(VecZeroEntries(_vec.get())); } diff --git a/cpppetsc/test/Vector_Test.cc b/cpppetsc/test/Vector_Test.cc index 0be060c0..874b3350 100644 --- a/cpppetsc/test/Vector_Test.cc +++ b/cpppetsc/test/Vector_Test.cc @@ -45,11 +45,20 @@ template <class Policy> struct Vector_Test : Test { using Policies = Types<SequentialComputePolicy, ParallelComputePolicy>; TYPED_TEST_CASE(Vector_Test, Policies); -TYPED_TEST(Vector_Test, create_vector_of_length_1) { - typename TestFixture::vector_type vec(1); +TYPED_TEST(Vector_Test, vector_with_1_global_row_has_global_size_1) { + const typename TestFixture::vector_type vec(1); EXPECT_THAT(vec.size(), Eq(1)); } +TYPED_TEST(Vector_Test, vector_with_2_local_rows_has_local_size_2) { + using vector_type = typename TestFixture::vector_type; + + const typename TestFixture::vector_type vec( + typename vector_type::LocalRows{2}, + typename vector_type::GlobalRows{PETSC_DETERMINE}); + EXPECT_THAT(vec.localSize(), Eq(2)); +} + TYPED_TEST(Vector_Test, local_range_works) { typename TestFixture::vector_type vec(3); -- GitLab