/*
 * This software is governed by the CeCILL-B license under French law and
 * abiding by the rules of distribution of free software.  You can  use, 
 * modify and/ or redistribute the software under the terms of the CeCILL-B
 * license as circulated by CEA, CNRS and INRIA at the following URL
 * "http://www.cecill.info" or the LICENCE.txt file present in this project.
*/

#ifndef CUDA_UTILS_THRUST_HPP__
#define CUDA_UTILS_THRUST_HPP__

#include <thrust/scan.h>
#include <thrust/device_ptr.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include <cassert>

#include "cuda_utils.hpp"

/**

  @file cuda_utils_thrust.hpp
  @brief interoperability between cuda utils and thrust

  To prevent re-inventing the wheel we use thrust to launch standard parallel
  algorithms on device arrays


*/

// TODO: change the interface of these functions. Here the 'end' index is the
// last index to be treated (inclusive), but callers more often have a size at
// hand and must do a tedious -1 each time they call one of these procedures.

// =============================================================================
namespace Cuda_utils{
// =============================================================================

/// In-place inclusive scan (prefix sum) of a sub-range of a device array.
/// Every treated element is replaced by the sum of itself and of all the
/// treated elements located before it. Elements outside [start, end] are not
/// written.
/// example (start = 0, end = 5):
/// array input  : 2 1 1 2 3 -1
/// array output : 2 3 4 6 9  8
/// @param start : first index to treat
/// @param end : last index to treat (inclusive). end == start-1 denotes an
/// empty range.
/// @param array : device array scanned in place
template<class T>
void inclusive_scan(int start,
                    int end,
                    Cuda_utils::Device::Array<T>& array)
{
    assert(start >= 0      );
    assert(start <= end + 1);
    // Cast so the comparison stays signed whatever size() returns.
    assert(end   <  static_cast<int>(array.size()));
    thrust::device_ptr<T> d_ptr = thrust::device_pointer_cast( array.ptr() );

    // The output iterator must be the same as the input start: thrust only
    // allows the input and output ranges to overlap when they are identical.
    // Writing to 'd_ptr' (index 0) would not be in place when start > 0.
    thrust::inclusive_scan(d_ptr+start, d_ptr+end+1, d_ptr+start);
}

// -----------------------------------------------------------------------------

/// In-place inclusive scan (prefix sum) of the elements start..end (both
/// included) of a raw array. Same contract as the overload taking a
/// Cuda_utils::Device::Array, except that no upper bound can be checked.
/// @param start : first index to treat
/// @param end : last index to treat (inclusive). end == start-1 denotes an
/// empty range.
/// @param d_array : pointer handed to thrust::device_pointer_cast(); it is
/// expected to address device memory holding at least end+1 elements.
template<class T>
void inclusive_scan(int start, int end, T* d_array)
{
    assert(start >= 0      );
    assert(start <= end + 1);
    thrust::device_ptr<T> d_ptr = thrust::device_pointer_cast( d_array );

    // Output iterator == input start so the scan is really done in place
    // (thrust forbids partially overlapping input/output ranges).
    thrust::inclusive_scan(d_ptr+start, d_ptr+end+1, d_ptr+start);
}

// -----------------------------------------------------------------------------

/// Placeholder: the body is empty, so calling this function currently leaves
/// 'array' untouched.
/// NOTE(review): presumably intended to pack/compact the array, but nothing is
/// implemented here -- confirm the intended semantics before relying on it.
/// @param array : device array (currently unused)
template<class T>
void pack(Cuda_utils::Device::Array<T>& array)
{

}

// -----------------------------------------------------------------------------

/// Reduce a device array using a bin operator
/// @tparam T : type of the array to be reduced
/// @tparam BinOp : type of the bin operator used to reduce the array. It must
/// implement the operator (x,y) tagged as host device function like below:
/// @code
/// template <typename T>
/// struct plus
/// {
///     __host__ __device__
///     T operator()(const T& x, const T& y) const { return x + y; }
/// };
/// @endcode
/// You can find predefined bin op in <thrust/functional.h> here is the
/// (not exhaustive) list of operators you can find:
/// thrust::plus<T>, thrust::minus<T>, thrust::multiplies<T>,
/// thrust::divides<T>, thrust::modulus<T>, thrust::equal_to<T>,
/// thrust::not_equal_to<T>, thrust::greater<T>, thrust::less<T>,
/// thrust::greater_equal<T>, thrust::less_equal<T>, thrust::logical_and<T>
/// thrust::logical_or<T>, thrust::logical_not<T>, thrust::bit_and<T>,
/// thrust::bit_or<T>, thrust::bit_xor<T>, thrust::maximum<T>,
/// thrust::minimum<T>, thrust::project1st<T>, thrust::project2nd<T>
/// @param start first index to treat
/// @param end : last index to treat
/// @param init : initial value of the for the bin operator using to reduce the
/// array
template<typename T, typename BinOp>
T reduce(int start,
         int end,
         Cuda_utils::Device::Array<T>& array,
         T init,
         BinOp op = thrust::plus<T>() )
{
    thrust::device_ptr<T> d_ptr = thrust::device_pointer_cast( array.ptr() );
    return thrust::reduce(d_ptr+start, d_ptr+end+1, init, op);
}

// -----------------------------------------------------------------------------

/// Sum all the elements of 'array'.
/// Equivalent to calling reduce() on the indices 0..size()-1 with T(0) as
/// initial value and thrust::plus<T> as binary operator.
/// @param array : device array whose elements are summed
/// @return the value returned by reduce() for the whole array
template<typename T>
T reduce_sum(Cuda_utils::Device::Array<T>& array)
{
    // reduce() takes an inclusive last index, not a size.
    const int last_idx = array.size() - 1;
    const T   zero     = T(0);

    return reduce(0, last_idx, array, zero, thrust::plus<T>());
}

}// END Cuda_utils =============================================================

#endif // CUDA_UTILS_THRUST_HPP__



