Commit 24cb42f6 authored by Nicolas Winkler

implementing blocked LU for the new Matrix class

parent 860d277c
@@ -2,6 +2,7 @@
#include <random>
#include <chrono>
#include <omp.h>
#include <iomanip>
#include "matrix_operations.h"
@@ -28,6 +29,18 @@ void print_matrix(int n, Scalar *A)
}
}
void print_matrix(const Matrix& A)
{
for (int i = 0; i < A.n; i++)
{
for (int j = 0; j < A.m; j++)
{
std::cout << A(i, j) << " ";
}
std::cout << std::endl;
}
}
void test_matrix(int n, Scalar *A, Scalar *original)
{
std::vector<Scalar> L(n * n, 0);
@@ -49,7 +62,7 @@ void test_matrix(int n, Scalar *A, Scalar *original)
}
std::vector<Scalar> result(n * n, 0);
mat_mult(n, n, n, L.data(), U.data(), result.data());
mat_mult(n, L.data(), U.data(), result.data());
for (int i = 0; i < n; i++)
{
@@ -57,15 +70,17 @@ void test_matrix(int n, Scalar *A, Scalar *original)
{
if (std::abs(result[i * n + k] - original[i * n + k]) > 0.1)
{
std::cout << "NOT CORRECT!" << std::endl;
return;
std::cout << "\033[1;31m" << std::setw(7) << result[i*n + k] << "\033[0m ";
std::cout << "\033[1;32m" << std::setw(7) << original[i*n + k] << "\033[0m ";
//std::cout << "NOT CORRECT!" << std::endl;
//return;
}
else
{
// std::cout << result[i*n + k] << " ";
std::cout << result[i*n + k] << " ";
}
}
// std::cout << std::endl;
std::cout << std::endl;
}
std::cout << "CORRECT!" << std::endl;
}
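For context, the check in test_matrix relies on the usual packed-LU convention: after an in-place factorization, the strict lower triangle of A holds L (with an implicit unit diagonal) and the diagonal plus upper triangle hold U. A minimal sketch of that unpacking step (unpack_lu is an illustrative helper, not a function from this commit):

    // Unpack packed LU storage into explicit L and U factors.
    void unpack_lu(int n, const Scalar *A, Scalar *L, Scalar *U)
    {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (j < i)       { L[i*n + j] = A[i*n + j]; U[i*n + j] = 0; } // below diagonal: L
                else if (j == i) { L[i*n + j] = 1; U[i*n + j] = A[i*n + j]; } // diagonal: unit L, real U
                else             { L[i*n + j] = 0; U[i*n + j] = A[i*n + j]; } // above diagonal: U
            }
        }
    }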
@@ -115,8 +130,16 @@ void subtract_submatrix(int i, int j, int n, int m, int original_m, Scalar *A, c
}
}
void blocked_lu(int n, Scalar *A)
void blocked_lu(Matrix& A)
{
int n = A.n;
// assume the matrix is square
if (n != A.m) {
std::cerr << "matrix not square" << std::endl;
exit(1);
}
int already_done = 0;
// std::cout << "starting blocked lu" << std::endl;
@@ -128,55 +151,30 @@ void blocked_lu(int n, Scalar *A)
bs = n - already_done;
// std::cout << "doing the rest sequentially; remaining n: " << bs << std::endl;
if (bs == 1)
{
if (bs == 1) { // the LU decomposition of a 1x1 matrix is trivially itself
break;
}
std::vector<Scalar> block(bs * bs);
extract_submatrix(already_done, already_done, bs, bs, n, A, block.data());
lu_simple(bs, block.data());
insert_submatrix(already_done, already_done, bs, bs, n, A, block.data());
Matrix block = A.submatrix(already_done, already_done, bs, bs);
lu_simple(block);
break;
}
int A_n = n - already_done;
std::vector<Scalar> block(bs * bs);
std::vector<Scalar> A01(bs * (A_n - bs));
std::vector<Scalar> A10(bs * (A_n - bs));
extract_submatrix(already_done, already_done, bs, bs, n, A, block.data());
extract_submatrix(already_done, already_done + bs, bs, A_n - bs, n, A, A01.data());
extract_submatrix(already_done + bs, already_done, A_n - bs, bs, n, A, A10.data());
lu_simple(bs, block.data());
insert_submatrix(already_done, already_done, bs, bs, n, A, block.data());
std::vector<Scalar> ul(bs * bs);
simple_transpose(bs, bs, block.data(), ul.data());
Matrix A00 = A.submatrix(already_done, already_done, bs, bs);
Matrix A01 = A.submatrix(already_done, already_done + bs, bs, A_n - bs);
Matrix A10 = A.submatrix(already_done + bs, already_done, A_n - bs, bs);
Matrix A11 = A.submatrix(already_done + bs, already_done + bs, A_n - bs, A_n - bs);
std::vector<Scalar> A10_transposed(bs * (A_n - bs));
simple_transpose(A_n - bs, bs, A10.data(), A10_transposed.data());
// set diagonal to 1 for l
for (int i = 0; i < bs; i++)
{
block[i * bs + i] = 1;
}
lu_simple(A00);
#pragma omp parallel
{
simple_trsm(bs, A_n - bs, block.data(), A01.data());
simple_trsm(bs, A_n - bs, ul.data(), A10_transposed.data());
trsm(A00, A01);
trsm_trans(A00, A10);
}
simple_transpose(bs, A_n - bs, A10_transposed.data(), A10.data());
insert_submatrix(already_done, already_done + bs, bs, A_n - bs, n, A, A01.data());
insert_submatrix(already_done + bs, already_done, A_n - bs, bs, n, A, A10.data());
std::vector<Scalar> to_subtract((A_n - bs) * (A_n - bs), 0.0);
mat_mult((A_n - bs), bs, (A_n - bs), A10.data(), A01.data(), to_subtract.data());
subtract_submatrix(already_done + bs, already_done + bs, A_n - bs, A_n - bs, n, A, to_subtract.data());
mat_mult_subtract(A10, A01, A11);
already_done += bs;
// std::cout << "done one step" << std::endl;
...
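For reference, each pass of blocked_lu above implements the standard block step. Partition the trailing submatrix as

    [ A00 A01 ]
    [ A10 A11 ]

with A00 of size bs x bs; then

    A00 = L00 * U00          (lu_simple on the diagonal block, in place)
    U01 = L00^-1 * A01       (trsm: forward substitution with unit-diagonal L00)
    L10 = A10 * U00^-1       (trsm_trans: row-wise solve against U00)
    A11 = A11 - L10 * U01    (mat_mult_subtract: the trailing update)

and the loop advances already_done by bs to recurse on the updated A11.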
@@ -3,6 +3,7 @@
#include <vector>
#include <omp.h>
#include "matrix_operations.h"
#include "lu.cpp"
#include "bench_util.h"
@@ -19,8 +20,10 @@ int bench_mpi_omp(int run_id, BenchUtil &bench, int num_runs, int matrix_size, i
printf(".");
fflush(stdout);
// generate matrix
std::vector<Scalar> A(matrix_size * matrix_size);
randomize_matrix(matrix_size, A.data());
std::vector<Scalar> matrix_data(matrix_size * matrix_size);
randomize_matrix(matrix_size, matrix_data.data());
std::vector<Scalar> original = matrix_data;
Matrix A(matrix_size, matrix_data.data());
bench.bench_param(run_id, "type", "OMP");
bench.bench_param(run_id, "num_threads", to_string(num_threads));
bench.bench_param(run_id, "run", to_string(i));
@@ -28,7 +31,7 @@ int bench_mpi_omp(int run_id, BenchUtil &bench, int num_runs, int matrix_size, i
bench.bench_param(run_id, "block_size", to_string(block_size));
bench.bench_start(run_id, "execution_time");
blocked_lu(matrix_size, A.data());
blocked_lu(A);
/* Record and complete the current measure */
bench.bench_stop();
...
#include <immintrin.h>
#include <omp.h>
#include <iostream>
#include "matrix_operations.h"
Matrix Matrix::submatrix(size_t i, size_t j, size_t n, size_t m)
{
Matrix m;
m.data = this->operator()(i, j);
m.stride = this->stride;
return m;
Matrix mat;
mat.data = &this->operator()(i, j);
mat.stride = this->stride;
mat.n = n;
mat.m = m;
return mat;
}
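Note that submatrix returns a view, not a copy: it aliases the parent's buffer and inherits its stride, so writes through the view (as blocked_lu relies on) mutate the original matrix. A minimal usage sketch, assuming the Matrix struct from matrix_operations.h and its non-const operator() overload (elided from this diff):

    std::vector<double> buf(16, 0.0);
    Matrix A(4, buf.data());                 // 4 x 4 view, stride 4
    Matrix block = A.submatrix(2, 2, 2, 2);  // bottom-right 2 x 2 view
    block(0, 0) = 5.0;                       // writes buf[2 * 4 + 2], i.e. A(2, 2)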
@@ -34,8 +38,7 @@ void mat_mult(int n, int m, int p, const double* A, const double* B, double* result)
}
}
#include <iostream>
void mat_mult(int n, int m, int p, const double* A, const double* B, double* result)
void mat_mult_subtract(const Matrix& A, const Matrix& B, Matrix& result)
{
// block size for larger "L1" blocks
// these can be tweaked, but need to be multiples of 4
@@ -49,151 +52,76 @@ void mat_mult(int n, int m, int p, const double* A, const double* B, double* result)
const int bsj2 = 4;
const int bsk2 = 4;
auto is_aligned = [](const double* ptr, uintptr_t alignment) -> bool { return ((uintptr_t)(void*) ptr) % alignment == 0; };
bool all_aligned =
n % 4 == 0
&& m % 4 == 0
&& p % 4 == 0
&& is_aligned(A, 32)
&& is_aligned(B, 32)
&& is_aligned(result, 32);
int n = A.n;
int p = B.m;
int m = A.m;
// the parallel for must run over the i (or k) loop; parallelizing over j would race on result(i, k)
#pragma omp parallel for
for (int bi1 = 0; bi1 < n; bi1 += bsi1) {
std::cout << "in matmul " << omp_get_thread_num() << std::endl;
//std::cout << "in matmul " << omp_get_thread_num() << std::endl;
for (int bj1 = 0; bj1 < m; bj1 += bsj1) {
for (int bk1 = 0; bk1 < p; bk1 += bsk1) {
if (all_aligned) {
for (int bj2 = bj1; bj2 < std::min(bj1+bsj1, m); bj2 += bsj2) {
for (int bi2 = bi1; bi2 < std::min(bi1+bsi1, n); bi2 += bsi2) {
for (int bk2 = bk1; bk2 < std::min(bk1+bsk1, p); bk2 += bsk2) {
if (bj2+bsj2 > m || bi2+bsi2 > n || bk2+bsk2 > p) {
for (int j = bj2; j < std::min(bj2+bsj2, m); j += 1) {
for (int i = bi2; i < std::min(bi2+bsi2, n); i += 1) {
//std::cout << "i: " << i << std::endl;
for (int k = bk2; k < std::min(bk2+bsk2, p); k += 1) {
result[i*p + k] += A[i*m + j] * B[j*p + k];
}
}
}
}
else {
__m256d a00 = _mm256_broadcast_sd(&A[(bi2 + 0)*m + (bj2 + 0)]);
__m256d a01 = _mm256_broadcast_sd(&A[(bi2 + 0)*m + (bj2 + 1)]);
__m256d a02 = _mm256_broadcast_sd(&A[(bi2 + 0)*m + (bj2 + 2)]);
__m256d a03 = _mm256_broadcast_sd(&A[(bi2 + 0)*m + (bj2 + 3)]);
__m256d a10 = _mm256_broadcast_sd(&A[(bi2 + 1)*m + (bj2 + 0)]);
__m256d a11 = _mm256_broadcast_sd(&A[(bi2 + 1)*m + (bj2 + 1)]);
__m256d a12 = _mm256_broadcast_sd(&A[(bi2 + 1)*m + (bj2 + 2)]);
__m256d a13 = _mm256_broadcast_sd(&A[(bi2 + 1)*m + (bj2 + 3)]);
__m256d a20 = _mm256_broadcast_sd(&A[(bi2 + 2)*m + (bj2 + 0)]);
__m256d a21 = _mm256_broadcast_sd(&A[(bi2 + 2)*m + (bj2 + 1)]);
__m256d a22 = _mm256_broadcast_sd(&A[(bi2 + 2)*m + (bj2 + 2)]);
__m256d a23 = _mm256_broadcast_sd(&A[(bi2 + 2)*m + (bj2 + 3)]);
__m256d a30 = _mm256_broadcast_sd(&A[(bi2 + 3)*m + (bj2 + 0)]);
__m256d a31 = _mm256_broadcast_sd(&A[(bi2 + 3)*m + (bj2 + 1)]);
__m256d a32 = _mm256_broadcast_sd(&A[(bi2 + 3)*m + (bj2 + 2)]);
__m256d a33 = _mm256_broadcast_sd(&A[(bi2 + 3)*m + (bj2 + 3)]);
__m256d b0 = _mm256_load_pd(&B[(bj2 + 0)*p + bk2]);
__m256d b1 = _mm256_load_pd(&B[(bj2 + 1)*p + bk2]);
__m256d b2 = _mm256_load_pd(&B[(bj2 + 2)*p + bk2]);
__m256d b3 = _mm256_load_pd(&B[(bj2 + 3)*p + bk2]);
__m256d r0 = _mm256_load_pd(&result[(bi2 + 0)*p + bk2]);
__m256d r1 = _mm256_load_pd(&result[(bi2 + 1)*p + bk2]);
__m256d r2 = _mm256_load_pd(&result[(bi2 + 2)*p + bk2]);
__m256d r3 = _mm256_load_pd(&result[(bi2 + 3)*p + bk2]);
r0 = _mm256_add_pd(r0, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a00, b0), _mm256_mul_pd(a01, b1)), _mm256_add_pd(_mm256_mul_pd(a02, b2), _mm256_mul_pd(a03, b3))));
r1 = _mm256_add_pd(r1, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a10, b0), _mm256_mul_pd(a11, b1)), _mm256_add_pd(_mm256_mul_pd(a12, b2), _mm256_mul_pd(a13, b3))));
r2 = _mm256_add_pd(r2, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a20, b0), _mm256_mul_pd(a21, b1)), _mm256_add_pd(_mm256_mul_pd(a22, b2), _mm256_mul_pd(a23, b3))));
r3 = _mm256_add_pd(r3, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a30, b0), _mm256_mul_pd(a31, b1)), _mm256_add_pd(_mm256_mul_pd(a32, b2), _mm256_mul_pd(a33, b3))));
_mm256_store_pd(&result[(bi2 + 0)*p + bk2], r0);
_mm256_store_pd(&result[(bi2 + 1)*p + bk2], r1);
_mm256_store_pd(&result[(bi2 + 2)*p + bk2], r2);
_mm256_store_pd(&result[(bi2 + 3)*p + bk2], r3);
}
for (int bj2 = bj1; bj2 < std::min(bj1+bsj1, m); bj2 += bsj2) {
for (int bi2 = bi1; bi2 < std::min(bi1+bsi1, n); bi2 += bsi2) {
for (int bk2 = bk1; bk2 < std::min(bk1+bsk1, p); bk2 += bsk2) {
}
}
}
}
else {
for (int bj2 = bj1; bj2 < std::min(bj1+bsj1, m); bj2 += bsj2) {
for (int bi2 = bi1; bi2 < std::min(bi1+bsi1, n); bi2 += bsi2) {
for (int bk2 = bk1; bk2 < std::min(bk1+bsk1, p); bk2 += bsk2) {
if (bj2+bsj2 > m || bi2+bsi2 > n || bk2+bsk2 > p) {
for (int j = bj2; j < std::min(bj2+bsj2, m); j += 1) {
for (int i = bi2; i < std::min(bi2+bsi2, n); i += 1) {
//std::cout << "i: " << i << std::endl;
for (int k = bk2; k < std::min(bk2+bsk2, p); k += 1) {
result[i*p + k] += A[i*m + j] * B[j*p + k];
}
if (bj2+bsj2 > m || bi2+bsi2 > n || bk2+bsk2 > p) {
for (int j = bj2; j < std::min(bj2+bsj2, m); j += 1) {
for (int i = bi2; i < std::min(bi2+bsi2, n); i += 1) {
//std::cout << "i: " << i << std::endl;
for (int k = bk2; k < std::min(bk2+bsk2, p); k += 1) {
result(i, k) -= A(i, j) * B(j, k);
}
}
}
else {
__m256d a00 = _mm256_broadcast_sd(&A[(bi2 + 0)*m + (bj2 + 0)]);
__m256d a01 = _mm256_broadcast_sd(&A[(bi2 + 0)*m + (bj2 + 1)]);
__m256d a02 = _mm256_broadcast_sd(&A[(bi2 + 0)*m + (bj2 + 2)]);
__m256d a03 = _mm256_broadcast_sd(&A[(bi2 + 0)*m + (bj2 + 3)]);
__m256d a10 = _mm256_broadcast_sd(&A[(bi2 + 1)*m + (bj2 + 0)]);
__m256d a11 = _mm256_broadcast_sd(&A[(bi2 + 1)*m + (bj2 + 1)]);
__m256d a12 = _mm256_broadcast_sd(&A[(bi2 + 1)*m + (bj2 + 2)]);
__m256d a13 = _mm256_broadcast_sd(&A[(bi2 + 1)*m + (bj2 + 3)]);
__m256d a20 = _mm256_broadcast_sd(&A[(bi2 + 2)*m + (bj2 + 0)]);
__m256d a21 = _mm256_broadcast_sd(&A[(bi2 + 2)*m + (bj2 + 1)]);
__m256d a22 = _mm256_broadcast_sd(&A[(bi2 + 2)*m + (bj2 + 2)]);
__m256d a23 = _mm256_broadcast_sd(&A[(bi2 + 2)*m + (bj2 + 3)]);
__m256d a30 = _mm256_broadcast_sd(&A[(bi2 + 3)*m + (bj2 + 0)]);
__m256d a31 = _mm256_broadcast_sd(&A[(bi2 + 3)*m + (bj2 + 1)]);
__m256d a32 = _mm256_broadcast_sd(&A[(bi2 + 3)*m + (bj2 + 2)]);
__m256d a33 = _mm256_broadcast_sd(&A[(bi2 + 3)*m + (bj2 + 3)]);
__m256d b0 = _mm256_loadu_pd(&B[(bj2 + 0)*p + bk2]);
__m256d b1 = _mm256_loadu_pd(&B[(bj2 + 1)*p + bk2]);
__m256d b2 = _mm256_loadu_pd(&B[(bj2 + 2)*p + bk2]);
__m256d b3 = _mm256_loadu_pd(&B[(bj2 + 3)*p + bk2]);
__m256d r0 = _mm256_loadu_pd(&result[(bi2 + 0)*p + bk2]);
__m256d r1 = _mm256_loadu_pd(&result[(bi2 + 1)*p + bk2]);
__m256d r2 = _mm256_loadu_pd(&result[(bi2 + 2)*p + bk2]);
__m256d r3 = _mm256_loadu_pd(&result[(bi2 + 3)*p + bk2]);
r0 = _mm256_add_pd(r0, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a00, b0), _mm256_mul_pd(a01, b1)), _mm256_add_pd(_mm256_mul_pd(a02, b2), _mm256_mul_pd(a03, b3))));
r1 = _mm256_add_pd(r1, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a10, b0), _mm256_mul_pd(a11, b1)), _mm256_add_pd(_mm256_mul_pd(a12, b2), _mm256_mul_pd(a13, b3))));
r2 = _mm256_add_pd(r2, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a20, b0), _mm256_mul_pd(a21, b1)), _mm256_add_pd(_mm256_mul_pd(a22, b2), _mm256_mul_pd(a23, b3))));
r3 = _mm256_add_pd(r3, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a30, b0), _mm256_mul_pd(a31, b1)), _mm256_add_pd(_mm256_mul_pd(a32, b2), _mm256_mul_pd(a33, b3))));
_mm256_storeu_pd(&result[(bi2 + 0)*p + bk2], r0);
_mm256_storeu_pd(&result[(bi2 + 1)*p + bk2], r1);
_mm256_storeu_pd(&result[(bi2 + 2)*p + bk2], r2);
_mm256_storeu_pd(&result[(bi2 + 3)*p + bk2], r3);
}
}
else {
__m256d a00 = _mm256_broadcast_sd(&A(bi2 + 0, bj2 + 0));
__m256d a01 = _mm256_broadcast_sd(&A(bi2 + 0, bj2 + 1));
__m256d a02 = _mm256_broadcast_sd(&A(bi2 + 0, bj2 + 2));
__m256d a03 = _mm256_broadcast_sd(&A(bi2 + 0, bj2 + 3));
__m256d a10 = _mm256_broadcast_sd(&A(bi2 + 1, bj2 + 0));
__m256d a11 = _mm256_broadcast_sd(&A(bi2 + 1, bj2 + 1));
__m256d a12 = _mm256_broadcast_sd(&A(bi2 + 1, bj2 + 2));
__m256d a13 = _mm256_broadcast_sd(&A(bi2 + 1, bj2 + 3));
__m256d a20 = _mm256_broadcast_sd(&A(bi2 + 2, bj2 + 0));
__m256d a21 = _mm256_broadcast_sd(&A(bi2 + 2, bj2 + 1));
__m256d a22 = _mm256_broadcast_sd(&A(bi2 + 2, bj2 + 2));
__m256d a23 = _mm256_broadcast_sd(&A(bi2 + 2, bj2 + 3));
__m256d a30 = _mm256_broadcast_sd(&A(bi2 + 3, bj2 + 0));
__m256d a31 = _mm256_broadcast_sd(&A(bi2 + 3, bj2 + 1));
__m256d a32 = _mm256_broadcast_sd(&A(bi2 + 3, bj2 + 2));
__m256d a33 = _mm256_broadcast_sd(&A(bi2 + 3, bj2 + 3));
__m256d b0 = _mm256_loadu_pd(&B(bj2 + 0, bk2));
__m256d b1 = _mm256_loadu_pd(&B(bj2 + 1, bk2));
__m256d b2 = _mm256_loadu_pd(&B(bj2 + 2, bk2));
__m256d b3 = _mm256_loadu_pd(&B(bj2 + 3, bk2));
__m256d r0 = _mm256_loadu_pd(&result(bi2 + 0, bk2));
__m256d r1 = _mm256_loadu_pd(&result(bi2 + 1, bk2));
__m256d r2 = _mm256_loadu_pd(&result(bi2 + 2, bk2));
__m256d r3 = _mm256_loadu_pd(&result(bi2 + 3, bk2));
r0 = _mm256_sub_pd(r0, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a00, b0), _mm256_mul_pd(a01, b1)), _mm256_add_pd(_mm256_mul_pd(a02, b2), _mm256_mul_pd(a03, b3))));
r1 = _mm256_sub_pd(r1, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a10, b0), _mm256_mul_pd(a11, b1)), _mm256_add_pd(_mm256_mul_pd(a12, b2), _mm256_mul_pd(a13, b3))));
r2 = _mm256_sub_pd(r2, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a20, b0), _mm256_mul_pd(a21, b1)), _mm256_add_pd(_mm256_mul_pd(a22, b2), _mm256_mul_pd(a23, b3))));
r3 = _mm256_sub_pd(r3, _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a30, b0), _mm256_mul_pd(a31, b1)), _mm256_add_pd(_mm256_mul_pd(a32, b2), _mm256_mul_pd(a33, b3))));
_mm256_storeu_pd(&result(bi2 + 0, bk2), r0);
_mm256_storeu_pd(&result(bi2 + 1, bk2), r1);
_mm256_storeu_pd(&result(bi2 + 2, bk2), r2);
_mm256_storeu_pd(&result(bi2 + 3, bk2), r3);
}
}
}
}
}
}
}
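The blocked and vectorized kernel above is equivalent to this scalar reference (a sketch useful for testing, not code from this commit):

    // result -= A * B, with A n x m, B m x p, result n x p
    void mat_mult_subtract_ref(const Matrix &A, const Matrix &B, Matrix &result)
    {
        for (size_t i = 0; i < A.n; i++)
            for (size_t j = 0; j < A.m; j++)
                for (size_t k = 0; k < B.m; k++)
                    result(i, k) -= A(i, j) * B(j, k);
    }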
@@ -301,11 +229,13 @@ void mat_mult(int n, int m, int p, const float* A, const float* B, float* result)
}
void simple_transpose(int n, int m, const double* A, double* A_trans)
void simple_transpose(const Matrix& A, Matrix& A_trans)
{
int n = A.n;
int m = A.m;
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
A_trans[j*n + i] = A[i*m + j];
A_trans(j, i) = A(i, j);
}
}
}
@@ -353,22 +283,73 @@ void simple_trsm(int n, int m, const float* L, float* A)
}
///
/// solves L * X = A in place by forward substitution, assuming L is a
/// lower triangular matrix with unit diagonal
///
void trsm(const Matrix& L, Matrix& A)
{
int n = L.n;
int m = A.m;
#pragma omp parallel for
for (int k = 0; k < m; k++) {
// for each row
for (int i = 0; i < n; i++) {
double val = 0.0;
for (int j = 0; j < i; j++) {
val += L(i, j) * A(j, k);
}
A(i, k) = (A(i, k) - val);
}
}
}
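Since L has a unit diagonal, no divisions are needed. For a single column a of A with n = 2, the update above reduces to (a worked instance, not from the source):

    L = [ 1    0 ]        x0 = a0
        [ l10  1 ]        x1 = a1 - l10 * x0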
///
/// solves X * U = A in place (the transposed triangular solve), assuming U
/// is an upper triangular matrix
///
void trsm_trans(const Matrix& U, Matrix& A)
{
int n = U.n;
int m = A.n;
#pragma omp parallel for
for (int k = 0; k < m; k++) {
// for each row
for (int i = 0; i < n; i++) {
double val = 0.0;
for (int j = 0; j < i; j++) {
val += U(j, i) * A(k, j);
}
A(k, i) = (A(k, i) - val) / U(i, i);
}
}
}
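Here U carries the actual diagonal, hence the division by U(i, i). For one row a of A with n = 2, the solve works out to (again a worked instance):

    U = [ u00  u01 ]      x0 = a0 / u00
        [ 0    u11 ]      x1 = (a1 - u01 * x0) / u11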
// matrices in row-major order
//
// n(n-1)/2 divisions
// (n-1)n(2n-1)/6 FMAs
void lu_simple(int n, double* A)
void lu_simple(Matrix& A)
{
int n = A.n;
// assume the matrix is square
if (n != A.m) {
std::cerr << "matrix not square" << std::endl;
exit(1);
}
// eliminate below the diagonal, one column i at a time
for (int i = 0; i < n - 1; i++) {
// j = row being updated
#pragma omp parallel for
for (int j = i + 1; j < n; j++) {
double mult = A[j*n + i] / A[i*n + i];
double mult = A(j, i) / A(i, i);
for (int k = i + 1; k < n; k++) {
A[j*n + k] -= mult * A[i*n + k];
A(j, k) -= mult * A(i, k);
}
A[j*n + i] = mult;
A(j, i) = mult;
}
}
...
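As a quick sanity check of the counts in the comment above, take n = 3: column i = 0 eliminates rows j = 1, 2 (two divisions, two FMAs each) and column i = 1 eliminates row j = 2 (one division, one FMA), for a total of n(n-1)/2 = 3 divisions and (n-1)n(2n-1)/6 = 5 FMAs.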
@@ -6,10 +6,12 @@
///
struct Matrix {
double* data;
size_t n, m;
size_t stride;
Matrix(void) = default;
inline Matrix(size_t n, double* data) :
data{ data }, stride{ n } {}
data{ data }, n{ n }, m{ n }, stride{ n } {}
inline const double& operator() (size_t i, size_t j) const {
return data[i*stride + j];
@@ -28,7 +30,7 @@ void mat_mult(int n, const double* A, const double* B, double* result);
/// multiplies the n x m matrix A with the m x p matrix B and subtracts the
/// product from result
///
void mat_mult(int n, int m, int p, const double* A, const double* B, double* result);
void mat_mult_subtract(const Matrix& A, const Matrix& B, Matrix& result);
void mat_mult(int n, int m, int p, const float* A, const float* B, float* result);
void mat_subtract(int n, int m, double* A_result, const double* B);
@@ -48,11 +50,14 @@ void mat_subtract(int n, int m, float* A_result, const float* B);
void simple_trsm(int n, int m, const double* L, double* A);
void simple_trsm(int n, int m, const float* L, float* A);
void trsm(const Matrix& L, Matrix& A);
void trsm_trans(const Matrix& U, Matrix& A);
/// transposes A into A_trans
void simple_transpose(int n, int m, const double* A, double* A_trans);
void simple_transpose(const Matrix& A, Matrix& A_trans);
void simple_transpose(int n, int m, const float* A, float* A_trans);
void lu_simple(int n, double* A);
void lu_simple(Matrix& A);
void lu_simple(int n, float* A);
#endif // MATRIX_OPERATIONS_H_
...
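Putting the pieces together, the new interface is driven as in the benchmark above (a condensed sketch; randomize_matrix and Scalar come from the surrounding code):

    std::vector<Scalar> data(n * n);
    randomize_matrix(n, data.data());
    Matrix A(n, data.data()); // n x n view over the flat buffer
    blocked_lu(A);            // in place: L below the diagonal, U on and above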