Commit f5960db7 authored by stefanow's avatar stefanow
Browse files

MCMC fix

parent 4ab21302
......@@ -2,6 +2,9 @@
#include <iostream>
#include <cassert>
#ifdef NDEBUG
#define DEBUG_MESSAGE(msg)
#else
......
/**************************************************************************************************
* DISCLAIMER AND LICENSE *
**************************************************************************************************/
/* MacOS implementations originally taken from
* http://www-personal.umich.edu/~williams/archive/computation/fe-handling-example.c (public domain)
* by David N. Williams
*/
/*
* Copyright 2017 Stefano Weidmann
* Author: Weidmann, Stefano
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**************************************************************************************************
* DESCRIPTION *
**************************************************************************************************/
/* Ports the GNU floating point exception API to MacOS and Windows:
* int feenableexcept(int excepts);
* int fedisableexcept(int excepts);
* int fegetexcept(void);
*/
/**************************************************************************************************
* USAGE *
**************************************************************************************************/
/* int excepts is a ORed composition of:
* FE_DIVBYZERO (Pole error: division by zero, or some other asymptotically infinite result (from finite arguments)),
* FE_INEXACT (Inexact: the result is not exact),
* FE_INVALID (Domain error: At least one of the arguments is a value for which the function is not defined)
* FE_OVERFLOW (Overflow range error: The result is too large in magnitude to be represented as a value of the return type)
* FE_UNDERFLOW (Underflow range error: The result is too small in magnitude to be represented as a value of the return type)
*
* FE_ALL_EXCEPT is a combination of all flags above
*
* int feenableexcept(int excepts) makes a unixoid OS send the signal SIGFPE to the process if any of the exceptions specified happen.
* The default action of a process is to abort, at least on MacOS and Linux.
* On Windows it throws an exception. You must enable SEH (structured exception handling) for it to work.
*
* int fedisableexcept(int excepts) disables the specified exceptions (counterpart to feenableexcept)
*
* int fegetexcept(void) returns the currently enables exceptions (format like int excepts)
*
*
* Example
* -----------
* feenableexcept(FE_ALL_EXCEPT & ~FE_INEXACT) // crash on everything except inexact math
*/
/**************************************************************************************************
* INCLUDES AND MACROS *
**************************************************************************************************/
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#ifdef _GNU_SOURCE
#define STEF_FENV_WAS_GNU_SOURCE 1
#else
#define _GNU_SOURCE
#define STEF_FENV_WAS_GNU_SOURCE 0
#endif
#include <stdlib.h>
#include <fenv.h>
#ifdef BSD
#pragma STDC FENV_ACCESS ON
#endif
/**************************************************************************************************
* MACOS PORT *
**************************************************************************************************/
#if defined(__APPLE__) && defined(__MACH__)
#if !defined(__i386__) && !defined(__x86_64__)
#error "Doesn't work on non-intel macs!"
#endif
static inline
int
fegetexcept(void) {
static fenv_t fenv;
return fegetenv(&fenv) ? -1 : (fenv.__control & FE_ALL_EXCEPT);
}
static inline
int
feenableexcept(int excepts) {
static fenv_t fenv;
int new_excepts = excepts & FE_ALL_EXCEPT, old_excepts; // previous masks
if (fegetenv(&fenv)) {
return -1;
}
old_excepts = fenv.__control & FE_ALL_EXCEPT;
// unmask
fenv.__control &= ~new_excepts;
fenv.__mxcsr &= ~(new_excepts << 7);
return (fesetenv(&fenv) ? -1 : old_excepts);
}
static inline
int
fedisableexcept(int excepts) {
static fenv_t fenv;
int new_excepts = excepts & FE_ALL_EXCEPT,
old_excepts; // all previous masks
if (fegetenv(&fenv)) {
return -1;
}
old_excepts = fenv.__control & FE_ALL_EXCEPT;
// mask
fenv.__control |= new_excepts;
fenv.__mxcsr |= new_excepts << 7;
return (fesetenv(&fenv) ? -1 : old_excepts);
}
#endif
/**************************************************************************************************
* WINDOWS PORT *
**************************************************************************************************/
// TODO Test it
#if defined(_WIN32) || defined(_WIN64)
#include <float.h>
#include <stdint.h>
// _EM_DENORMAL is left out because it doesn't have a fenv counterpart
const static int StefFenv_WindowsExceptionConstants_C[] = {_EM_INVALID, _EM_ZERODIVIDE, _EM_OVERFLOW, _EM_UNDERFLOW, _EM_INEXACT};
const static int StefFenv_FenvExceptionConstants_C[] = {FE_INVALID, FE_DIVBYZERO, FE_OVERFLOW, FE_UNDERFLOW, FE_INEXACT};
const static int StefFenv_NumberOfExceptionConstants_C = sizeof(StefFenv_WindowsExceptionConstants_C) / sizeof(int);
/* windows exception masks work the other way around than fenv masks
* the exception is enabled if the constant is NOT set
*/
static inline
int
StefFenv_TranslateWindowsExceptionMaskToFenv(const int windowsExceptionMask) {
int fenvExceptionMask = 0;
for (int i = 0; i < StefFenv_NumberOfExceptionConstants_C; ++i) {
if (windowsExceptionMask & ~StefFenv_WindowsExceptionConstants_C[i]) {
fenvExceptionMask |= StefFenv_FenvExceptionConstants_C[i];
}
}
return fenvExceptionMask;
}
static inline
int
StefFenv_TranslateFenvExceptionMaskToWindows(const int fenvExceptionMask) {
int windowsExceptionMask = -1;
for (int i = 0; i < StefFenv_NumberOfExceptionConstants_C; ++i) {
if (fenvExceptionMask & StefFenv_FenvExceptionConstants_C[i]) {
windowsExceptionMask &= ~StefFenv_WindowsExceptionConstants_C[i];
}
}
return windowsExceptionMask;
}
static inline
int
fegetexcept(void) {
_clearfp(); // Clearing enables exceptions afterwards
const int currentState = _controlfp(0, 0);
return StefFenv_TranslateWindowsExceptionMaskToFenv(currentState);
}
static inline
int
feenableexcept(int excepts) {
_clearfp(); // Clearing enables exceptions afterwards
const int stateBefore = fegetexcept();
const int windowsExceptionMask = StefFenv_TranslateFenvExceptionMaskToWindows(excepts);
_controlfp(windowsExceptionMask, _MCW_EM);
return stateBefore;
}
static inline
int
fedisableexcept(int excepts) {
_clearfp(); // Clearing enables exceptions afterwards
const int stateBefore = fegetexcept();
const int negatedWindowsExceptionMask = ~StefFenv_TranslateFenvExceptionMaskToWindows(excepts);
_controlfp(negatedWindowsExceptionMask, _MCW_EM);
return stateBefore;
}
#endif
/**************************************************************************************************
* PRETTIER CRASHING WITH POSIX *
**************************************************************************************************/
// Signal handler adapted from http://www-personal.umich.edu/~williams/archive/computation/fe-handling-example.c
static inline
const char*
StefFenv_GetFloatingPointExceptionDescription(
const int signalExceptionCode
) {
const static int signalExceptionCodes[] = {
FPE_INTDIV,
FPE_INTOVF,
FPE_FLTDIV,
FPE_FLTOVF,
FPE_FLTUND,
FPE_FLTRES,
FPE_FLTINV,
FPE_FLTSUB
};
const static char* exceptionDescriptions[] = {
"integer division by zero",
"integer overflow",
"floating point divide by zero",
"floating point overflow",
"floating point underflow",
"floating point inexact result",
"floating point invalid operation",
"subscript out of range"
};
const int numberOfExceptions = sizeof(signalExceptionCodes) / sizeof(signalExceptionCodes[0]);
for (int i = 0; i < numberOfExceptions; ++i) {
if (signalExceptionCode & signalExceptionCodes[i]) {
return exceptionDescriptions[i];
}
}
return "no floating point exception";
}
// TODO: Find better way to test if on unix
#if !(defined(_WIN32) || defined(_WIN64))
#define _POSIX_C_SOURCE 200112L
#include <signal.h>
#include <unistd.h>
#include <stdio.h>
#include <stdbool.h>
#if _POSIX_VERSION >= 200112L
static inline
void
StefFenv_FloatingPointExceptionHandler(
int signal,
siginfo_t* signalInfos,
void* somethingIdontCareAbout
) {
if (signal == SIGFPE) {
const char* exceptionDescription = StefFenv_GetFloatingPointExceptionDescription(signalInfos->si_code);
fputs("**************************************************************************************************\n", stderr);
fprintf(stderr, "SIGNAL SIGFPE CAUGHT\n%s\n", exceptionDescription);
fputs("**************************************************************************************************\n", stderr);
} else {
fputs("Should handle a floating point exception signal, but the signal sent wasn't SIGFPE!", stderr);
}
fflush(stderr);
abort();
}
static inline
void
StefFenv_RegisterFloatingPointExceptionHandler(
) {
static bool alreadyRegistered = false;
if (alreadyRegistered) {
return;
}
struct sigaction signalAction;
sigemptyset(&signalAction.sa_mask);
/* The SA_SIGINFO flag tells sigaction() to use the sa_sigaction field, not sa_handler. */
signalAction.sa_flags = SA_SIGINFO;
signalAction.sa_sigaction = StefFenv_FloatingPointExceptionHandler;
if (sigaction(SIGFPE, &signalAction, NULL) < 0) {
perror("error registering with sigaction");
}
}
#endif // POSIX_VERSION
#else
// not unix
static inline
void
StefFenv_RegisterFloatingPointExceptionHandler(
) {
#ifndef NDEBUG
fputs("Signal handling not supported on this platform!\n", stderr);
#endif
}
#endif // not windows
/**************************************************************************************************
* MACRO CLEANUP *
**************************************************************************************************/
#ifdef __cplusplus
}
#endif
#if !STEF_FENV_WAS_GNU_SOURCE
#undef _GNU_SOURCE
#endif
#undef STEF_FENV_WAS_GNU_SOURCE
......@@ -9,7 +9,16 @@
#include "likelihood.hpp"
#include <omp.h>
#ifndef NDEBUG
#include "stefFenv.h"
#endif
int main(){
#ifndef NDEBUG
//feenableexcept(FE_ALL_EXCEPT & ~FE_INEXACT & ~FE_UNDERFLOW);
//StefFenv_RegisterFloatingPointExceptionHandler();
#endif
/// Initialization
std::cout << "Using the " << PotentialName<POTENTIAL>::value << " potential\n";
std::cout << "Prior stddev = " << NORMAL_PRIOR_STDDEV << "\n";
......
......@@ -54,13 +54,6 @@ typedef Eigen::Matrix<numeric_t, N_DIM, -1> PopMatrix;
typedef double (*PotentialFunction)(const double r, const ThetaVector& theta);
/// Box size in each dimension
constexpr numeric_t LOW_BD = -10;
constexpr numeric_t UPPER_BD = 10;
constexpr numeric_t NORMAL_PRIOR_STDDEV = 10;
/// Beta
......
......@@ -10,6 +10,9 @@ template <typename T>
T normal_pdf_1D(T x, T m, T s){
static const T inv_sqrt_2pi = 0.3989422804014327;
T a = (x - m) / s;
if (std::fabs(a) > 10) {
return 1e-16;
}
return inv_sqrt_2pi / s * std::exp(-T(0.5) * a * a);
}
/// PDF of a Gamma distribution with \alpha = 'a' and \beta = 'b'
......@@ -126,39 +129,29 @@ inline static PopMatrix mcmc(const ThetaVector & theta, const index_t rck, const
PopMatrix new_thetas(N_DIM, rck);
ThetaVector particle, particle_prop;
std::uniform_real_distribution<numeric_t> acc_dist(0,1);
particle = theta;
numeric_t particle_likelihood = likelihood_theta;
for(index_t i = 0; i < rck; ++i){
particle = theta;
numeric_t particle_likelihood = likelihood_theta;
numeric_t particle_prop_likelihood;
for(index_t j = 0; j < rck; ++j){
while(true){
particle_prop = sample_mvn(1, particle, sample_cov_mat);
/*
std::cout << "Cov mat " << sample_cov_mat << "\n"
<< "cur part " << particle << "\n"
<< "prop part " << particle_prop << "\n"
<< "Like prop " << likelihood(particle_prop) << "\n"
<< "Prior prop " << prior(particle_prop) << "\n"
<< "Like cur " << likelihood(particle) << "\n"
<< "Prior cur " << prior(particle) << "\n"
<< "Acc prob " << std::pow(likelihood(particle_prop) / likelihood(particle),rho_curr) * prior(particle_prop) / prior(particle) << "\n";
*/
//if(particle_prop(N_DIM - 1) > 0) break;
//std::cout << "Retrying\n";
break;
};
particle_prop_likelihood = likelihood(particle_prop);
if(acc_dist(GEN) <= std::pow(particle_prop_likelihood / particle_likelihood, rho_curr) * prior(particle_prop) / prior(particle)){
particle = particle_prop;
particle_likelihood = particle_prop_likelihood;
//std::cout << "Accepted.\n";
} else {
//std::cout << "Not accepted.\n";
}
particle_prop = sample_mvn(1, particle, sample_cov_mat);
particle_prop_likelihood = likelihood(particle_prop);
if(acc_dist(GEN) <= std::pow(particle_prop_likelihood / particle_likelihood, rho_curr) * prior(particle_prop) / prior(particle)){
particle = particle_prop;
particle_likelihood = particle_prop_likelihood;
}
new_thetas.col(i) = particle;
/*
std::cout << "Cov mat " << sample_cov_mat << "\n"
<< "cur part " << particle << "\n"
<< "prop part " << particle_prop << "\n"
<< "Like prop " << likelihood(particle_prop) << "\n"
<< "Prior prop " << prior(particle_prop) << "\n"
<< "Like cur " << likelihood(particle) << "\n"
<< "Prior cur " << prior(particle) << "\n"
<< "Acc prob " << std::pow(likelihood(particle_prop) / likelihood(particle),rho_curr) * prior(particle_prop) / prior(particle) << "\n";
*/
}
return new_thetas;
};
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment