#include "StaticCast.hpp" // for StaticCast, etc.
#include "TypeTraits.hpp" // for Require, etc.

#include <utility> // for std::declval, etc.

namespace nvcv::cuda::detail {

// clang-format off

// Metavariable to check if two types are compound and have the same number of components.
template<class T, class U, class = Require<HasTypeTraits<T, U>>>
constexpr bool IsSameCompound = IsCompound<T> && TypeTraits<T>::components == TypeTraits<U>::components;

// Metavariable to check that at least one type is of compound type out of two types.
// If both are compound type, then it is checked that both have the same number of components.
template<typename T, typename U, class = Require<HasTypeTraits<T, U>>>
constexpr bool OneIsCompound =
    (TypeTraits<T>::components == 0 && TypeTraits<U>::components >= 1) ||
    (TypeTraits<T>::components >= 1 && TypeTraits<U>::components == 0) ||
    IsSameCompound<T, U>;

// Metavariable to check if a type is of integral type.
template<typename T, class = Require<HasTypeTraits<T>>>
constexpr bool IsIntegral = std::is_integral_v<typename TypeTraits<T>::base_type>;

// Metavariable to require that at least one type is of compound type out of two integral types.
// If both are compound type, then it is required that both have the same number of components.
template<typename T, typename U, class = Require<HasTypeTraits<T, U>>>
constexpr bool OneIsCompoundAndBothAreIntegral = OneIsCompound<T, U> && IsIntegral<T> && IsIntegral<U>;

// Metavariable to require that a type is a CUDA compound of integral type.
template<typename T, class = Require<HasTypeTraits<T>>>
constexpr bool IsIntegralCompound = IsIntegral<T> && IsCompound<T>;

// clang-format on

} // namespace nvcv::cuda::detail

// Generates a component-wise unary `operator OPERATOR` for every type T for which REQUIREMENT<T> holds.
//
// The return type RT is ConvertBaseTypeTo<decltype(OPERATOR base), T>, where `base` is a value of T's base
// type, so whatever type the built-in operator yields on the base type is what the result is built from.
// Depending on NumElements<RT> (1 to 4), the result is brace-initialized from OPERATOR applied to
// a.x, a.y, a.z and a.w respectively.
//
// NOTE: the generated function body must be closed inside the macro; the last line intentionally has no
// line continuation.
#define NVCV_CUDA_UNARY_OPERATOR(OPERATOR, REQUIREMENT)                                                          \
    template<typename T, class = nvcv::cuda::Require<REQUIREMENT<T>>>                                            \
    inline __host__ __device__ auto operator OPERATOR(T a)                                                       \
    {                                                                                                            \
        using RT = nvcv::cuda::ConvertBaseTypeTo<decltype(OPERATOR std::declval<nvcv::cuda::BaseType<T>>()), T>; \
        if constexpr (nvcv::cuda::NumElements<RT> == 1)                                                          \
            return RT{OPERATOR a.x};                                                                             \
        else if constexpr (nvcv::cuda::NumElements<RT> == 2)                                                     \
            return RT{OPERATOR a.x, OPERATOR a.y};                                                               \
        else if constexpr (nvcv::cuda::NumElements<RT> == 3)                                                     \
            return RT{OPERATOR a.x, OPERATOR a.y, OPERATOR a.z};                                                 \
        else if constexpr (nvcv::cuda::NumElements<RT> == 4)                                                     \
            return RT{OPERATOR a.x, OPERATOR a.y, OPERATOR a.z, OPERATOR a.w};                                   \
    }

// Component-wise unary operators: negation and unary plus for any type satisfying IsCompound,
// bitwise NOT only for types satisfying IsIntegralCompound.
NVCV_CUDA_UNARY_OPERATOR(-, nvcv::cuda::IsCompound)
NVCV_CUDA_UNARY_OPERATOR(+, nvcv::cuda::IsCompound)
NVCV_CUDA_UNARY_OPERATOR(~, nvcv::cuda::detail::IsIntegralCompound)

// The generator macro is an implementation detail of this header; undefine it so that it does not
// leak into every translation unit that includes this file.
#undef NVCV_CUDA_UNARY_OPERATOR


// Generates two overloads for the binary operator OPERATOR:
//
// 1) `operator OPERATOR(T a, U b)`, enabled when REQUIREMENT<T, U> holds.
//    The return type RT is MakeType<B, N>, where B is the type of `base(T) OPERATOR base(U)` and N is
//    NumComponents<U> when T has zero components, NumComponents<T> otherwise. Three cases are handled:
//      - T has zero components: `a` is applied against each component of `b`;
//      - U has zero components: each component of `a` is applied against `b`;
//      - otherwise:             components of `a` and `b` are combined pair-wise.
//    In every case the result is brace-initialized with 1 to 4 values according to NumElements<RT>.
//
// 2) `operator OPERATOR=(T &a, U b)`, enabled when T satisfies IsCompound.
//    Evaluates `a OPERATOR b`, converts the result with StaticCast to T's base type, assigns it to `a`
//    and returns `a` by reference.
//
// NOTE: both generated function bodies must be closed inside the macro; the last line intentionally has
// no line continuation.
#define NVCV_CUDA_BINARY_OPERATOR(OPERATOR, REQUIREMENT)                                                        \
    template<typename T, typename U, class = nvcv::cuda::Require<REQUIREMENT<T, U>>>                            \
    inline __host__ __device__ auto operator OPERATOR(T a, U b)                                                 \
    {                                                                                                           \
        using RT = nvcv::cuda::MakeType<                                                                        \
            decltype(std::declval<nvcv::cuda::BaseType<T>>() OPERATOR std::declval<nvcv::cuda::BaseType<U>>()), \
            nvcv::cuda::NumComponents<T> == 0 ? nvcv::cuda::NumComponents<U> : nvcv::cuda::NumComponents<T>>;   \
        if constexpr (nvcv::cuda::NumComponents<T> == 0)                                                        \
        {                                                                                                       \
            if constexpr (nvcv::cuda::NumElements<RT> == 1)                                                     \
                return RT{a OPERATOR b.x};                                                                      \
            else if constexpr (nvcv::cuda::NumElements<RT> == 2)                                                \
                return RT{a OPERATOR b.x, a OPERATOR b.y};                                                      \
            else if constexpr (nvcv::cuda::NumElements<RT> == 3)                                                \
                return RT{a OPERATOR b.x, a OPERATOR b.y, a OPERATOR b.z};                                      \
            else if constexpr (nvcv::cuda::NumElements<RT> == 4)                                                \
                return RT{a OPERATOR b.x, a OPERATOR b.y, a OPERATOR b.z, a OPERATOR b.w};                      \
        }                                                                                                       \
        else if constexpr (nvcv::cuda::NumComponents<U> == 0)                                                   \
        {                                                                                                       \
            if constexpr (nvcv::cuda::NumElements<RT> == 1)                                                     \
                return RT{a.x OPERATOR b};                                                                      \
            else if constexpr (nvcv::cuda::NumElements<RT> == 2)                                                \
                return RT{a.x OPERATOR b, a.y OPERATOR b};                                                      \
            else if constexpr (nvcv::cuda::NumElements<RT> == 3)                                                \
                return RT{a.x OPERATOR b, a.y OPERATOR b, a.z OPERATOR b};                                      \
            else if constexpr (nvcv::cuda::NumElements<RT> == 4)                                                \
                return RT{a.x OPERATOR b, a.y OPERATOR b, a.z OPERATOR b, a.w OPERATOR b};                      \
        }                                                                                                       \
        else                                                                                                    \
        {                                                                                                       \
            if constexpr (nvcv::cuda::NumElements<RT> == 1)                                                     \
                return RT{a.x OPERATOR b.x};                                                                    \
            else if constexpr (nvcv::cuda::NumElements<RT> == 2)                                                \
                return RT{a.x OPERATOR b.x, a.y OPERATOR b.y};                                                  \
            else if constexpr (nvcv::cuda::NumElements<RT> == 3)                                                \
                return RT{a.x OPERATOR b.x, a.y OPERATOR b.y, a.z OPERATOR b.z};                                \
            else if constexpr (nvcv::cuda::NumElements<RT> == 4)                                                \
                return RT{a.x OPERATOR b.x, a.y OPERATOR b.y, a.z OPERATOR b.z, a.w OPERATOR b.w};              \
        }                                                                                                       \
    }                                                                                                           \
    template<typename T, typename U, class = nvcv::cuda::Require<nvcv::cuda::IsCompound<T>>>                    \
    inline __host__ __device__ T &operator OPERATOR##=(T &a, U b)                                               \
    {                                                                                                           \
        return a = nvcv::cuda::StaticCast<nvcv::cuda::BaseType<T>>(a OPERATOR b);                               \
    }

// Arithmetic operators (and their compound-assignment forms): enabled for operand pairs satisfying
// OneIsCompound.
NVCV_CUDA_BINARY_OPERATOR(-, nvcv::cuda::detail::OneIsCompound)
NVCV_CUDA_BINARY_OPERATOR(+, nvcv::cuda::detail::OneIsCompound)
NVCV_CUDA_BINARY_OPERATOR(*, nvcv::cuda::detail::OneIsCompound)
NVCV_CUDA_BINARY_OPERATOR(/, nvcv::cuda::detail::OneIsCompound)
// Modulo, bitwise and shift operators: additionally require both operands to have an integral base type.
NVCV_CUDA_BINARY_OPERATOR(%, nvcv::cuda::detail::OneIsCompoundAndBothAreIntegral)
NVCV_CUDA_BINARY_OPERATOR(&, nvcv::cuda::detail::OneIsCompoundAndBothAreIntegral)
NVCV_CUDA_BINARY_OPERATOR(|, nvcv::cuda::detail::OneIsCompoundAndBothAreIntegral)
NVCV_CUDA_BINARY_OPERATOR(^, nvcv::cuda::detail::OneIsCompoundAndBothAreIntegral)
NVCV_CUDA_BINARY_OPERATOR(<<, nvcv::cuda::detail::OneIsCompoundAndBothAreIntegral)
NVCV_CUDA_BINARY_OPERATOR(>>, nvcv::cuda::detail::OneIsCompoundAndBothAreIntegral)

// The generator macro is an implementation detail of this header; undefine it so that it does not
// leak into every translation unit that includes this file.
#undef NVCV_CUDA_BINARY_OPERATOR


// Component-wise equality for two values whose types satisfy IsSameCompound<T, U>.
//
// Components are compared in x, y, z, w order, limited at compile time by NumElements<T>; the function
// returns false at the first pair for which `!=` is true and true when no pair differs.
template<typename T, typename U, class = nvcv::cuda::Require<nvcv::cuda::detail::IsSameCompound<T, U>>>
inline __host__ __device__ bool operator==(T a, U b)
{
    if constexpr (nvcv::cuda::NumElements<T> >= 1)
        if (a.x != b.x)
            return false;
    if constexpr (nvcv::cuda::NumElements<T> >= 2)
        if (a.y != b.y)
            return false;
    if constexpr (nvcv::cuda::NumElements<T> >= 3)
        if (a.z != b.z)
            return false;
    if constexpr (nvcv::cuda::NumElements<T> == 4)
        if (a.w != b.w)
            return false;
    return true;
}

// Component-wise inequality for two values whose types satisfy IsSameCompound<T, U>.
// Defined as the negation of `a == b`, so it is true as soon as any compared component differs.
template<typename T, typename U, class = nvcv::cuda::Require<nvcv::cuda::detail::IsSameCompound<T, U>>>
inline __host__ __device__ bool operator!=(T a, U b)
{
    return !(a == b);
}

namespace nvcv::cuda {

// Dot product of two values whose types satisfy IsSameCompound<T, U>.
//
// PT is the type of the product of one base-type value of T with one of U; RT, the return type, is the
// type of the sum of two PT values. The result is the sum of the component-wise products over the
// 1 to 4 components selected at compile time by NumComponents<T>.
template<typename T, typename U, class = nvcv::cuda::Require<nvcv::cuda::detail::IsSameCompound<T, U>>>
inline __host__ __device__ auto dot(T a, U b)
{
    using PT = decltype(std::declval<nvcv::cuda::BaseType<T>>() * std::declval<nvcv::cuda::BaseType<U>>());
    using RT = decltype(std::declval<PT>() + std::declval<PT>());

    if constexpr (nvcv::cuda::NumComponents<T> == 1)
        return RT{a.x * b.x};
    else if constexpr (nvcv::cuda::NumComponents<T> == 2)
        return RT{a.x * b.x + a.y * b.y};
    else if constexpr (nvcv::cuda::NumComponents<T> == 3)
        return RT{a.x * b.x + a.y * b.y + a.z * b.z};
    else if constexpr (nvcv::cuda::NumComponents<T> == 4)
        return RT{a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w};
}

} // namespace nvcv::cuda