Program Listing for File TensorWrap.hpp

Return to documentation for file (nvcv_types/include/nvcv/cuda/TensorWrap.hpp)

/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NVCV_CUDA_TENSOR_WRAP_HPP
#define NVCV_CUDA_TENSOR_WRAP_HPP

#include "TypeTraits.hpp" // for HasTypeTraits, etc.

#include <nvcv/ImageData.hpp>        // for ImageDataStridedCuda, etc.
#include <nvcv/TensorData.hpp>       // for TensorDataStridedCuda, etc.
#include <nvcv/TensorDataAccess.hpp> // for TensorDataAccessStridedImagePlanar, etc.

#include <cassert> // for assert, etc.
#include <utility> // for forward, etc.

namespace nvcv::cuda {

// TensorWrap<T, Strides...>: a wrapper over strided memory that stores a raw base pointer (it
// never allocates or frees it). Each value in Strides is the byte stride of one dimension,
// listed outermost first; -1 marks a stride supplied at run time. The read-only specialization
// for `const T` holds the storage and address arithmetic; the mutable primary template is
// defined after it and derives from it.
template<typename T, int... Strides>
class TensorWrap;

/**
 * Read-only wrapper over strided memory holding elements of type T.
 *
 * Strides lists the byte stride of every dimension, outermost first. A value of -1 marks a
 * run-time stride that is stored in this object (in m_strides); any other value is a
 * compile-time stride. All run-time (-1) strides must come before every compile-time stride
 * (e.g. <-1, -1, sizeof(T)>). The TensorDataStridedCuda constructor and doGetPtr() index
 * m_strides by dimension position, so any other ordering would leave strides unset and
 * multiply coordinates by the literal -1. This ordering is now enforced at compile time.
 *
 * The wrapper stores a raw base pointer and does not own the memory it points to.
 *
 * NOTE: byte offsets are accumulated in 32-bit int (see doGetPtr), so the offset of every
 * addressed element from the base pointer must fit in int.
 *
 * @tparam T       Element type; must have type traits (HasTypeTraits<T>).
 * @tparam Strides Per-dimension byte strides, outermost first; -1 for run-time strides.
 */
template<typename T, int... Strides>
class TensorWrap<const T, Strides...>
{
    static_assert(HasTypeTraits<T>, "TensorWrap<T> can only be used if T has type traits");

public:
    using ValueType = const T;

    static constexpr int kNumDimensions   = sizeof...(Strides);
    static constexpr int kVariableStrides = ((Strides == -1) + ...);
    static constexpr int kConstantStrides = kNumDimensions - kVariableStrides;

    TensorWrap() = default;

    /**
     * Wraps @p data using run-time strides passed as individual arguments.
     *
     * @param data    Base pointer of the wrapped memory.
     * @param strides Exactly kVariableStrides values of type int, outermost first.
     */
    template<typename DataType, typename... Args>
    explicit __host__ __device__ TensorWrap(const DataType *data, Args... strides)
        : m_data(reinterpret_cast<const std::byte *>(data))
        , m_strides{std::forward<int>(strides)...}
    {
        static_assert(std::conjunction_v<std::is_same<int, Args>...>);
        static_assert(sizeof...(Args) == kVariableStrides);
    }

    /**
     * Wraps @p data using run-time strides read from an array.
     *
     * @param data    Base pointer of the wrapped memory.
     * @param strides Pointer to at least kVariableStrides values; each one is converted to int.
     */
    template<typename DataType, typename StrideType>
    explicit __host__ __device__ TensorWrap(const DataType *data, StrideType *strides)
        : m_data(reinterpret_cast<const std::byte *>(data))
    {
        for (int i = 0; i < kVariableStrides; ++i)
        {
            m_strides[i] = strides[i];
        }
    }

    /**
     * Wraps plane 0 of @p image as a 2D (row, column) tensor.
     *
     * Only valid for wrappers with two dimensions and exactly one run-time stride, which
     * receives the plane's row stride.
     */
    __host__ TensorWrap(const ImageDataStridedCuda &image)
    {
        static_assert(kVariableStrides == 1 && kNumDimensions == 2);

        m_data = reinterpret_cast<const std::byte *>(image.plane(0).basePtr);

        m_strides[0] = image.plane(0).rowStride;
    }

    /**
     * Wraps the first kNumDimensions dimensions of @p tensor.
     *
     * In debug builds (assert), each compile-time stride is checked against the tensor's
     * actual stride, and each run-time stride is checked to fit in int before it is stored.
     */
    __host__ TensorWrap(const TensorDataStridedCuda &tensor)
    {
        static_assert(RuntimeStridesAreLeading(),
                      "TensorWrap: run-time strides (-1) must precede all compile-time strides");

        constexpr int kStride[] = {std::forward<int>(Strides)...};

        assert(tensor.rank() >= kNumDimensions);

        m_data = reinterpret_cast<const std::byte *>(tensor.basePtr());

#pragma unroll
        for (int i = 0; i < kNumDimensions; ++i)
        {
            if (kStride[i] != -1)
            {
                assert(tensor.stride(i) == kStride[i]);
            }
            else if (i < kVariableStrides)
            {
                assert(tensor.stride(i) <= TypeTraits<int>::max);

                m_strides[i] = tensor.stride(i);
            }
        }
    }

    /// Returns the kVariableStrides run-time strides held by this wrapper (outermost first).
    __host__ __device__ const int *strides() const
    {
        return m_strides;
    }

    /**
     * Returns a reference to the element at coordinate @p c.
     *
     * @p c is either a plain int or an int-based vector with up to four components. Components
     * are read innermost first: x indexes the last (innermost) dimension, then y, z, w. For
     * example, c = {x, y, z} addresses ptr(z, y, x).
     */
    template<typename DimType, class = Require<std::is_same_v<int, BaseType<DimType>>>>
    inline const __host__ __device__ T &operator[](DimType c) const
    {
        if constexpr (NumElements<DimType> == 1)
        {
            if constexpr (NumComponents<DimType> == 0)
            {
                return *doGetPtr(c);
            }
            else
            {
                return *doGetPtr(c.x);
            }
        }
        else if constexpr (NumElements<DimType> == 2)
        {
            return *doGetPtr(c.y, c.x);
        }
        else if constexpr (NumElements<DimType> == 3)
        {
            return *doGetPtr(c.z, c.y, c.x);
        }
        else if constexpr (NumElements<DimType> == 4)
        {
            return *doGetPtr(c.w, c.z, c.y, c.x);
        }
    }

    /**
     * Returns a pointer to the element at int coordinates @p c, outermost first.
     *
     * If fewer than kNumDimensions coordinates are given, only the leading dimensions
     * contribute to the offset, so the result points at the start of that sub-tensor.
     */
    template<typename... Args>
    inline const __host__ __device__ T *ptr(Args... c) const
    {
        return doGetPtr(c...);
    }

protected:
    /**
     * Computes base + sum(coord[i] * stride[i]) in bytes for the given leading coordinates.
     *
     * Run-time strides occupy dimensions [0, kVariableStrides) and compile-time strides occupy
     * the rest. The static_assert below guarantees that split matches Strides.
     */
    template<typename... Args>
    inline const __host__ __device__ T *doGetPtr(Args... c) const
    {
        static_assert(std::conjunction_v<std::is_same<int, Args>...>);
        static_assert(sizeof...(Args) <= kNumDimensions);
        static_assert(RuntimeStridesAreLeading(),
                      "TensorWrap: run-time strides (-1) must precede all compile-time strides");

        constexpr int kArgSize  = sizeof...(Args);
        constexpr int kVarSize  = kArgSize < kVariableStrides ? kArgSize : kVariableStrides;
        constexpr int kDimSize  = kArgSize < kNumDimensions ? kArgSize : kNumDimensions;
        constexpr int kStride[] = {std::forward<int>(Strides)...};

        int coords[] = {std::forward<int>(c)...};

        // Computing offset first potentially postpones or avoids 64-bit math during addressing.
        // The trade-off: the byte offset must fit in a 32-bit int.
        int offset = 0;
#pragma unroll
        for (int i = 0; i < kVarSize; ++i)
        {
            offset += coords[i] * m_strides[i];
        }
#pragma unroll
        for (int i = kVariableStrides; i < kDimSize; ++i)
        {
            offset += coords[i] * kStride[i];
        }

        return reinterpret_cast<const T *>(m_data + offset);
    }

private:
    // True when no -1 (run-time) stride appears after position kVariableStrides - 1. Because
    // there are exactly kVariableStrides such entries, they must then all be leading.
    __host__ __device__ static constexpr bool RuntimeStridesAreLeading()
    {
        constexpr int kStride[] = {Strides...};
        for (int i = kVariableStrides; i < kNumDimensions; ++i)
        {
            if (kStride[i] == -1)
            {
                return false;
            }
        }
        return true;
    }

    const std::byte *m_data                      = nullptr;
    int              m_strides[kVariableStrides] = {};
};

/**
 * Mutable view over strided memory holding elements of type T.
 *
 * Storage (base pointer and run-time strides) and all address arithmetic are inherited from
 * TensorWrap<const T, Strides...>. This class only re-exposes element access returning
 * non-const T, by removing the constness of the pointer produced by the base.
 *
 * @tparam T       Element type (non-const).
 * @tparam Strides Per-dimension byte strides, outermost first; -1 for run-time strides.
 */
template<typename T, int... Strides>
class TensorWrap : public TensorWrap<const T, Strides...>
{
    using Base = TensorWrap<const T, Strides...>;

public:
    using ValueType = T;

    using Base::kConstantStrides;
    using Base::kNumDimensions;
    using Base::kVariableStrides;

    TensorWrap() = default;

    // Base pointer plus one int argument per run-time stride; forwarded unchanged to the base.
    template<typename DataType, typename... Args>
    explicit __host__ __device__ TensorWrap(DataType *data, Args... strides)
        : Base(data, strides...)
    {
    }

    // Base pointer plus an array holding the run-time strides; forwarded unchanged to the base.
    template<typename DataType, typename StrideType>
    explicit __host__ __device__ TensorWrap(DataType *data, StrideType *strides)
        : Base(data, strides)
    {
    }

    // Wraps plane 0 of an image; the base constructor enforces the 2D / one-run-time-stride shape.
    __host__ TensorWrap(const ImageDataStridedCuda &image)
        : Base(image)
    {
    }

    // Wraps the leading kNumDimensions dimensions of a strided CUDA tensor via the base.
    __host__ TensorWrap(const TensorDataStridedCuda &tensor)
        : Base(tensor)
    {
    }

    /**
     * Writable element access by an int or an int-based vector coordinate.
     *
     * Vector components are consumed from the outermost one present (w, then z, y) down to x,
     * which indexes the innermost dimension.
     */
    template<typename DimType, class = Require<std::is_same_v<int, BaseType<DimType>>>>
    inline __host__ __device__ T &operator[](DimType c) const
    {
        if constexpr (NumElements<DimType> == 4)
        {
            return *doGetPtr(c.w, c.z, c.y, c.x);
        }
        else if constexpr (NumElements<DimType> == 3)
        {
            return *doGetPtr(c.z, c.y, c.x);
        }
        else if constexpr (NumElements<DimType> == 2)
        {
            return *doGetPtr(c.y, c.x);
        }
        else if constexpr (NumElements<DimType> == 1)
        {
            // A plain int has no .x member; a one-component vector does.
            if constexpr (NumComponents<DimType> != 0)
            {
                return *doGetPtr(c.x);
            }
            else
            {
                return *doGetPtr(c);
            }
        }
    }

    // Writable pointer to the element at int coordinates given outermost first.
    template<typename... Args>
    inline __host__ __device__ T *ptr(Args... c) const
    {
        T *const element = doGetPtr(c...);
        return element;
    }

protected:
    template<typename... Args>
    inline __host__ __device__ T *doGetPtr(Args... c) const
    {
        // The base only hands out pointer-to-const; this is the single spot where that
        // constness is removed to give the mutable wrapper write access.
        const T *const readOnly = Base::doGetPtr(c...);
        return const_cast<T *>(readOnly);
    }
};

// Fixed-rank convenience aliases. In each of them, the innermost (last) dimension has the
// compile-time stride sizeof(T), so consecutive elements along it are adjacent T values.
// Every outer dimension has a run-time (-1) stride that must be supplied on construction.

// 1D: a contiguous run of T; no run-time strides.
template<typename T>
using Tensor1DWrap = TensorWrap<T, sizeof(T)>;

// 2D: one run-time stride (outer dimension) + packed inner dimension.
template<typename T>
using Tensor2DWrap = TensorWrap<T, -1, sizeof(T)>;

// 3D: two run-time strides + packed inner dimension.
template<typename T>
using Tensor3DWrap = TensorWrap<T, -1, -1, sizeof(T)>;

// 4D: three run-time strides + packed inner dimension.
template<typename T>
using Tensor4DWrap = TensorWrap<T, -1, -1, -1, sizeof(T)>;

// 5D: four run-time strides + packed inner dimension.
template<typename T>
using Tensor5DWrap = TensorWrap<T, -1, -1, -1, -1, sizeof(T)>;

template<typename T, int N>
using TensorNDWrap = std::conditional_t<
    N == 1, Tensor1DWrap<T>,
    std::conditional_t<N == 2, Tensor2DWrap<T>,
                       std::conditional_t<N == 3, Tensor3DWrap<T>,
                                          std::conditional_t<N == 4, Tensor4DWrap<T>,
                                                             std::conditional_t<N == 5, Tensor5DWrap<T>, void>>>>>;

/**
 * Creates a 3D wrapper over an image-like tensor, indexed as (sample, row, column).
 *
 * The sample and row strides are read from @p tensor and stored as run-time int strides. The
 * column stride is the compile-time value sizeof(T) baked into Tensor3DWrap, so each column of
 * the tensor must hold exactly one T. This is asserted in debug builds, consistent with the
 * compile-time stride checks in the TensorWrap(const TensorDataStridedCuda &) constructor.
 *
 * @tparam T      Element type seen through the wrapper (one T per column).
 * @param  tensor Strided CUDA tensor; must be accepted by TensorDataAccessStridedImagePlanar.
 * @return Tensor3DWrap<T> addressed as wrap[{x, y, n}] or wrap.ptr(n, y, x).
 */
template<typename T, class = Require<HasTypeTraits<T>>>
__host__ auto CreateTensorWrapNHW(const TensorDataStridedCuda &tensor)
{
    auto tensorAccess = TensorDataAccessStridedImagePlanar::Create(tensor);
    assert(tensorAccess);
    // The run-time strides are stored as int inside the wrapper.
    assert(tensorAccess->sampleStride() <= TypeTraits<int>::max);
    assert(tensorAccess->rowStride() <= TypeTraits<int>::max);
    // The wrapper hard-codes the column stride as sizeof(T). A mismatch (e.g. T is the channel
    // type of an interleaved multi-channel tensor) would silently address the wrong elements.
    assert(tensorAccess->colStride() == static_cast<int>(sizeof(T)));

    return Tensor3DWrap<T>(tensor.basePtr(), static_cast<int>(tensorAccess->sampleStride()),
                           static_cast<int>(tensorAccess->rowStride()));
}

/**
 * Creates a 4D wrapper over an image-like tensor, indexed as (sample, row, column, channel).
 *
 * Sample, row and column strides come from @p tensor and are stored as run-time int strides.
 * The innermost stride is the compile-time sizeof(T) from Tensor4DWrap.
 *
 * @tparam T      Per-channel element type seen through the wrapper.
 * @param  tensor Strided CUDA tensor; must be accepted by TensorDataAccessStridedImagePlanar.
 * @return Tensor4DWrap<T> addressed as wrap[{c, x, y, n}] or wrap.ptr(n, y, x, c).
 */
template<typename T, class = Require<HasTypeTraits<T>>>
__host__ auto CreateTensorWrapNHWC(const TensorDataStridedCuda &tensor)
{
    auto access = TensorDataAccessStridedImagePlanar::Create(tensor);
    assert(access);

    const auto sampleStride = access->sampleStride();
    const auto rowStride    = access->rowStride();
    const auto colStride    = access->colStride();

    // Each run-time stride is narrowed to int inside the wrapper, so it must fit.
    assert(sampleStride <= TypeTraits<int>::max);
    assert(rowStride <= TypeTraits<int>::max);
    assert(colStride <= TypeTraits<int>::max);

    const int sampleStrideI32 = static_cast<int>(sampleStride);
    const int rowStrideI32    = static_cast<int>(rowStride);
    const int colStrideI32    = static_cast<int>(colStride);

    return Tensor4DWrap<T>(tensor.basePtr(), sampleStrideI32, rowStrideI32, colStrideI32);
}

} // namespace nvcv::cuda

#endif // NVCV_CUDA_TENSOR_WRAP_HPP