From 2e016d8973ede957978a5544aa78fc75ec37e565 Mon Sep 17 00:00:00 2001
From: Manuel Weberndorfer <manuel.weberndorfer@id.ethz.ch>
Date: Tue, 26 Oct 2021 16:32:44 +0000
Subject: [PATCH] support setting/getting local size in Vector

---
 cpppetsc/src/include/ae108/cpppetsc/Vector.h | 45 ++++++++++++++++----
 cpppetsc/test/Vector_Test.cc                 | 13 +++++-
 2 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/cpppetsc/src/include/ae108/cpppetsc/Vector.h b/cpppetsc/src/include/ae108/cpppetsc/Vector.h
index 7f4837fe..85536949 100644
--- a/cpppetsc/src/include/ae108/cpppetsc/Vector.h
+++ b/cpppetsc/src/include/ae108/cpppetsc/Vector.h
@@ -53,11 +53,21 @@ public:
    */
   explicit Vector(const size_type global_size);
 
+  using LocalRows = TaggedEntity<size_type, struct LocalRowsTag>;
+  using GlobalRows = TaggedEntity<size_type, struct GlobalRowsTag>;
+
+  /**
+   * @brief Allocates a vector with the provided layout.
+   *
+   * @remark Does not initialize the vector.
+   */
+  explicit Vector(const LocalRows localRows, GlobalRows globalRows);
+
   /**
-   * @brief Allocates a vector of length global_size and initializes the vector
+   * @brief Allocates a vector of length globalRows and initializes the vector
    * to value.
    */
-  explicit Vector(const size_type global_size, const value_type value);
+  explicit Vector(const size_type globalRows, const value_type value);
 
   /**
    * @brief Constructs a vector with the provided entries.
@@ -194,6 +204,11 @@ public:
    */
   size_type size() const;
 
+  /**
+   * @brief Returns the local size of the vector.
+   */
+  size_type localSize() const;
+
   /**
    * @brief Returns the local row indices in the format:
    *
@@ -224,7 +239,7 @@ public:
   Vec data() const;
 
 private:
-  static Vec createVec(const size_type global_size);
+  static Vec createVec(const LocalRows localRows, const GlobalRows globalRows);
 
   UniqueEntity<Vec> _vec;
 };
@@ -261,22 +276,27 @@ namespace ae108 {
 namespace cpppetsc {
 
 template <class Policy>
-Vec Vector<Policy>::createVec(const size_type global_size) {
+Vec Vector<Policy>::createVec(const LocalRows localRows,
+                              const GlobalRows globalRows) {
   Vec vec;
   Policy::handleError(
-      VecCreateMPI(Policy::communicator(), PETSC_DECIDE, global_size, &vec));
+      VecCreateMPI(Policy::communicator(), localRows, globalRows, &vec));
   return vec;
 }
 
 template <class Policy>
-Vector<Policy>::Vector(const size_type global_size)
-    : _vec(makeUniqueEntity<Policy>(createVec(global_size))) {
+Vector<Policy>::Vector(const size_type globalRows)
+    : Vector(LocalRows{PETSC_DECIDE}, GlobalRows{globalRows}) {}
+
+template <class Policy>
+Vector<Policy>::Vector(const LocalRows localRows, const GlobalRows globalRows)
+    : _vec(makeUniqueEntity<Policy>(createVec(localRows, globalRows))) {
   Policy::handleError(VecSetFromOptions(_vec.get()));
 }
 
 template <class Policy>
-Vector<Policy>::Vector(const size_type global_size, const value_type value)
-    : Vector(global_size) {
+Vector<Policy>::Vector(const size_type globalRows, const value_type value)
+    : Vector(globalRows) {
   fill(value);
 }
 
@@ -468,6 +488,13 @@ typename Vector<Policy>::size_type Vector<Policy>::size() const {
   return size;
 }
 
+template <class Policy>
+typename Vector<Policy>::size_type Vector<Policy>::localSize() const {
+  size_type size = 0;
+  Policy::handleError(VecGetLocalSize(_vec.get(), &size));
+  return size;
+}
+
 template <class Policy> void Vector<Policy>::setZero() {
   Policy::handleError(VecZeroEntries(_vec.get()));
 }
diff --git a/cpppetsc/test/Vector_Test.cc b/cpppetsc/test/Vector_Test.cc
index 0be060c0..874b3350 100644
--- a/cpppetsc/test/Vector_Test.cc
+++ b/cpppetsc/test/Vector_Test.cc
@@ -45,11 +45,20 @@ template <class Policy> struct Vector_Test : Test {
 using Policies = Types<SequentialComputePolicy, ParallelComputePolicy>;
 TYPED_TEST_CASE(Vector_Test, Policies);
 
-TYPED_TEST(Vector_Test, create_vector_of_length_1) {
-  typename TestFixture::vector_type vec(1);
+TYPED_TEST(Vector_Test, vector_with_1_global_row_has_global_size_1) {
+  const typename TestFixture::vector_type vec(1);
   EXPECT_THAT(vec.size(), Eq(1));
 }
 
+TYPED_TEST(Vector_Test, vector_with_2_local_rows_has_local_size_2) {
+  using vector_type = typename TestFixture::vector_type;
+
+  const typename TestFixture::vector_type vec(
+      typename vector_type::LocalRows{2},
+      typename vector_type::GlobalRows{PETSC_DETERMINE});
+  EXPECT_THAT(vec.localSize(), Eq(2));
+}
+
 TYPED_TEST(Vector_Test, local_range_works) {
   typename TestFixture::vector_type vec(3);
 
-- 
GitLab