/*
 * This software is governed by the CeCILL-B license under French law and
 * abiding by the rules of distribution of free software.  You can  use, 
 * modify and/ or redistribute the software under the terms of the CeCILL-B
 * license as circulated by CEA, CNRS and INRIA at the following URL
 * "http://www.cecill.info" or the LICENCE.txt file present in this project.
*/

#ifndef CUDA_ASSERT_HPP__
#define CUDA_ASSERT_HPP__

#include <cassert>
#include <cstdio>

#include "memory_debug.hpp"

/**
 * @file cuda_assert.hpp
 * @brief Utilities to check errors with cuda API/Driver functions and kernels.
 * Defining the symbol TRACE_MEMORY before including this header will make
 * error reports print the device memory stack.
 *
 * @note Macros use the do{}while(0) trick -> it makes it easy to put a
 * semicolon at the end of the macro call without producing illegal code:
 * @code
 *  {
 *      if( cond )
 *          MACRO_CALL( args ); // <- valid code even with the semicolon!
 *  }
 * @endcode
 */

// =============================================================================
// Implementation helper not meant for external usage
// =============================================================================

// This is a shameless copy paste of thrust lib include:
// #include <thrust/detail/static_assert.h>
// Which is also a shameless copy paste from boost lib! see http://www.boost.org
// Don't be puzzled by the meaning of the code: it's just ugly metaprogramming

// Token-pasting helper: CUDA_UTILS_JOIN( A, B ) produces the single token AB.
// The actual `##` paste lives in the innermost macro; the outer layers only
// forward their arguments so that macros such as __LINE__ are expanded to
// their value before being pasted.
#define CUDA_UTILS_JOIN( A, B )        _CUDA_UTILS_DO_JOIN_( A, B )
#define _CUDA_UTILS_DO_JOIN_( A, B )   _CUDA_UTILS_DO_JOIN2_( A, B )
#define _CUDA_UTILS_DO_JOIN2_( A, B )  A##B

namespace Cuda_utils {
namespace Details {

/// Declared but intentionally left undefined: any use that needs the complete
/// type STATIC_ASSERTION_FAILURE<false> (e.g. sizeof) is a compile error.
/// The template value parameter is named because HP aCC cannot deal with
/// unnamed template value parameters.
template <bool condition>
struct STATIC_ASSERTION_FAILURE;

/// Only the `true` case is specialised (and therefore complete), so that the
/// `false` case is the one producing the error.
template <>
struct STATIC_ASSERTION_FAILURE<true>
{
    enum { value = 1 };
};

/// Empty type parameterised on an integer.
/// The template value parameter is named because HP aCC cannot deal with
/// unnamed template value parameters.
template <int size>
struct static_assert_test
{
};

/// Exposes the boolean `flag` as `value` through a class template that also
/// takes a type parameter which is otherwise unused.
/// NOTE(review): presumably meant to make a constant depend on a template
/// type so its evaluation waits for instantiation -- confirm with callers.
template <typename Unused, bool flag>
struct depend_on_instantiation
{
    static const bool value = flag;
};

} // namespace Details
} // namespace Cuda_utils
// =============================================================================
// END OF : Implementation helper not meant for external usage
// =============================================================================

#ifndef NDEBUG

/// @def CUDA_STATIC_ASSERT
/// @brief Static assertion (Checking done at compile time)
/// Compilation fails when the macro is called with a false boolean:
/// @code
/// {
///     // Next line won't compile if the result of sizeof is greater than 3
///     CUDA_STATIC_ASSERT( sizeof(a_type) <= 3 );
/// }
/// @endcode
/// When the compiler reports C++11 or newer (through __cplusplus) the macro
/// expands to the language's own static_assert, whose message contains the
/// text of the failed condition.
/// @warning On older compilers the static assertion is a hack: it produces
/// code which can't compile when the argument is false. This means the
/// compiler error will not be explicit. At least it will give the line of the
/// static assertion but nvcc will say for instance:
/// "error: incomplete type is not allowed". Always comment static assertions
/// otherwise people won't understand the "weird error" is part of the static
/// assertion mechanism.
#if defined(__cplusplus) && __cplusplus >= 201103L
#define CUDA_STATIC_ASSERT( B )                                                \
   static_assert( (bool)( B ), "CUDA_STATIC_ASSERT failed: " #B )
#else
#define CUDA_STATIC_ASSERT( B )                                                \
   typedef ::Cuda_utils::Details::static_assert_test<                          \
      sizeof(::Cuda_utils::Details::STATIC_ASSERTION_FAILURE< (bool)( B ) >)>  \
         CUDA_UTILS_JOIN(cuda_utils_static_assert_typedef_, __LINE__)
#endif

// -----------------------------------------------------------------------------

/// @def CUDA_SAFE_CALL
/// @brief Check a cuda API call.
/// assert(false) if a cudaXXX() does not return cudaSuccess:
/// usage:
/// @code
/// CUDA_SAFE_CALL(cudaMemcpy(dst, src, sizeof, cudaMemcpyDeviceToHost) );
/// // Will assert if cudaMemcpy fails.
/// @endcode
/// On failure the error string, file and line are written to stderr, then
/// Mem_debug::cuda_print_memory_trace() and Mem_debug::cuda_print_rusage()
/// are called before asserting.
/// @note The argument is evaluated exactly once. The temporary holding the
/// result has a macro-specific name so that an argument which itself mentions
/// a variable called `code` is not captured by the macro's own local.
#define CUDA_SAFE_CALL(x) do{                                              \
    const cudaError_t cuda_safe_call_code_ = (x);                          \
    if(cuda_safe_call_code_ != cudaSuccess){                               \
        fprintf(stderr,"CUDA error: %s at %s, line %d\n",                  \
        cudaGetErrorString(cuda_safe_call_code_), __FILE__, __LINE__);     \
        fflush(stderr);                                                    \
        Mem_debug::cuda_print_memory_trace();                              \
        Mem_debug::cuda_print_rusage();                                    \
        assert(false);                                                     \
    }                                                                      \
} while(0)

// -----------------------------------------------------------------------------

/// @def CUDA_CHECK_ERRORS
/// Use this macro to check the latest cuda error. assert(false) when the test
/// fails. This is useful when you want to check kernel errors. You just
/// need to put CUDA_CHECK_ERRORS(); right after the kernel call:
/// @code
/// a_kernel<<<grid_size, block_size>>>( args, ... );
/// CUDA_CHECK_ERRORS(); // Will assert if the kernel fails
/// @endcode
/// The macro calls cudaDeviceSynchronize() (the replacement of the deprecated
/// cudaThreadSynchronize()) and then reads cudaGetLastError(). The last error
/// is reported when there is one; otherwise a failure returned by the
/// synchronization itself is reported. On failure the error string, file and
/// line are written to stderr, then Mem_debug::cuda_print_memory_trace() and
/// Mem_debug::cuda_print_rusage() are called before asserting.
#define CUDA_CHECK_ERRORS()                                                    \
    do{                                                                        \
        const cudaError_t cuda_check_sync_ = cudaDeviceSynchronize();          \
        cudaError_t cuda_check_code_ = cudaGetLastError();                     \
        if(cuda_check_code_ == cudaSuccess)                                    \
            cuda_check_code_ = cuda_check_sync_;                               \
        if(cuda_check_code_ != cudaSuccess){                                   \
            fprintf(stderr,"CUDA error: %s at %s, line %d\n",                  \
            cudaGetErrorString(cuda_check_code_), __FILE__, __LINE__);         \
            fflush(stderr);                                                    \
            Mem_debug::cuda_print_memory_trace();                              \
            Mem_debug::cuda_print_rusage();                                    \
            assert(false);                                                     \
        }                                                                      \
    } while(0)

// -----------------------------------------------------------------------------

/// @def CU_SAFE_CALL
/// @brief Check a cuda Driver call.
/// assert(false) if a cuXXX() does not return CUDA_SUCCESS:
/// usage:
/// @code
/// CU_SAFE_CALL( cuMemcpy(dst, src, sizeof) );
/// // Will assert if cuMemcpy fails.
/// @endcode
/// On failure the numeric error code (in hexadecimal), file and line are
/// written to stderr before asserting.
/// @note The argument is evaluated exactly once. The temporary holding the
/// result has a macro-specific name so that an argument which itself mentions
/// a variable called `err` is not captured by the macro's own local.
#define CU_SAFE_CALL(call) do{                                               \
    const CUresult cu_safe_call_err_ = (call);                               \
    if( CUDA_SUCCESS != cu_safe_call_err_) {                                 \
        fprintf(stderr, "Cuda driver error %x in file '%s' in line %i.\n",   \
                (unsigned int)cu_safe_call_err_, __FILE__, __LINE__ );       \
        fflush(stderr);                                                      \
        assert(false);                                                       \
    } }while(0)

// -----------------------------------------------------------------------------

#else

    // Disable macros in release mode (NDEBUG defined).
    // - CUDA_CHECK_ERRORS() does nothing.
    // - CUDA_SAFE_CALL / CU_SAFE_CALL still evaluate their argument exactly
    //   once but perform no check; the argument is parenthesised so the
    //   expansion stays a single expression whatever surrounds the call.
    // - CUDA_STATIC_ASSERT never fails: the condition only appears inside
    //   sizeof, so it is still compiled but neither evaluated nor enforced.
    //   It expands to a declaration (not a statement) so that, like the
    //   debug version, it can be used at namespace, class or function scope.
    #define CUDA_CHECK_ERRORS() do{}while(0)
    #if defined(__cplusplus) && __cplusplus >= 201103L
        #define CUDA_STATIC_ASSERT( B )                                        \
            static_assert( sizeof( (bool)( B ) ) != 0, "" )
    #else
        #define CUDA_STATIC_ASSERT( B )                                        \
            typedef ::Cuda_utils::Details::static_assert_test<                 \
                sizeof( (bool)( B ) ) >                                        \
                    CUDA_UTILS_JOIN(cuda_utils_static_assert_typedef_, __LINE__)
    #endif
    #define CUDA_SAFE_CALL(code) (code)
    #define CU_SAFE_CALL( call ) (call)

#endif

#endif // CUDA_ASSERT_HPP__
