Program Listing for File TensorBatchWrap.hpp

Return to documentation for file (nvcv_types/include/nvcv/cuda/TensorBatchWrap.hpp)

/*
 * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NVCV_CUDA_TENSOR_BATCH_WRAP_HPP
#define NVCV_CUDA_TENSOR_BATCH_WRAP_HPP

#include "TensorWrap.hpp" // for TensorWrap, etc.
#include "TypeTraits.hpp" // for HasTypeTraits, etc.

#include <nvcv/TensorBatchData.hpp> // for TensorBatchDataStridedCuda, etc.

#include <cassert>     // for assert, etc.
#include <type_traits> // for std::is_same_v, etc.

namespace nvcv::cuda {

template<typename T, int... Strides>
class TensorBatchWrap;

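// TensorBatchWrap is a non-owning, host/device-accessible view over a batch of
// N-dimensional strided tensors, where N = sizeof...(Strides). Each stride is
// expressed in bytes; a compile-time stride of -1 marks a dimension whose stride
// is read at run time from the per-tensor metadata, and the indexing code assumes
// that all run-time (-1) strides precede the compile-time ones. This specialization
// for const-qualified value types provides read-only access.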
template<typename T, int... Strides>
class TensorBatchWrap<const T, Strides...>
{
    static_assert(HasTypeTraits<T>, "TensorBatchWrap<T> can only be used if T has type traits");

public:
    // The type provided as a template parameter is the value type, i.e. the type of each element inside this wrapper.
    using ValueType = const T;

    static constexpr int kNumDimensions   = sizeof...(Strides);
    static constexpr int kVariableStrides = ((Strides == -1) + ...);
    static constexpr int kConstantStrides = kNumDimensions - kVariableStrides;

    TensorBatchWrap() = default;

    __host__ TensorBatchWrap(const TensorBatchDataStridedCuda &data)
        : TensorBatchWrap(data.cdata())
    {
    }

    __host__ __device__ TensorBatchWrap(const NVCVTensorBatchData &data)
        : m_numTensors(data.numTensors)
        , m_tensors(data.buffer.strided.tensors)
    {
    }

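    // Returns a pointer to the element of tensor t addressed by coordinates c...,
    // given outermost-first; omitted trailing coordinates behave as zero.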
    template<typename... Coords>
    inline const __host__ __device__ T *ptr(int t, Coords... c) const
    {
        return doGetPtr(t, c...);
    }

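    // Accesses a single element through one coordinate vector: c.x is the tensor
    // index and the remaining components are the element coordinates, innermost
    // dimension first (e.g. for a 3D tensor batch, c = {tensor, dim2, dim1, dim0}).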
    template<typename DimType, class = Require<std::is_same_v<int, BaseType<DimType>>>>
    inline const __host__ __device__ T &operator[](DimType c) const
    {
        static_assert(NumElements<DimType> == kNumDimensions + 1,
                      "Coordinates in the subscript operator must be (N+1)-dimensional, "
                      "where N is the dimensionality of a single tensor in the batch.");
        if constexpr (NumElements<DimType> == 1)
        {
            return *doGetPtr(c.x);
        }
        else if constexpr (NumElements<DimType> == 2)
        {
            return *doGetPtr(c.x, c.y);
        }
        else if constexpr (NumElements<DimType> == 3)
        {
            return *doGetPtr(c.x, c.z, c.y);
        }
        else if constexpr (NumElements<DimType> == 4)
        {
            return *doGetPtr(c.x, c.w, c.z, c.y);
        }
    }

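    // Returns a TensorWrap view of the t-th tensor, built from its base pointer
    // and run-time stride array.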
    inline const __host__ __device__ auto tensor(int t) const
    {
        return TensorWrap<ValueType, Strides...>(doGetPtr(t), strides(t));
    }

    inline __host__ __device__ int32_t numTensors() const
    {
        return m_numTensors;
    }

    inline const __host__ __device__ int64_t *shape(int t) const
    {
        assert(t >= 0 && t < m_numTensors);
        return m_tensors[t].shape;
    }

    inline const __host__ __device__ int64_t *strides(int t) const
    {
        assert(t >= 0 && t < m_numTensors);
        return m_tensors[t].stride;
    }

protected:
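    // Computes the address of an element of tensor t from up to kNumDimensions
    // coordinates given outermost-first, combining the tensor's run-time strides
    // with the compile-time strides from the template parameter pack.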
    template<typename... Args>
    inline __host__ __device__ T *doGetPtr(int t, Args... c) const
    {
        static_assert(std::conjunction_v<std::is_same<int, Args>...>);
        static_assert(sizeof...(Args) <= kNumDimensions);

        constexpr int kArgSize  = sizeof...(Args);
        constexpr int kVarSize  = kArgSize < kVariableStrides ? kArgSize : kVariableStrides;
        constexpr int kDimSize  = kArgSize < kNumDimensions ? kArgSize : kNumDimensions;
        constexpr int kStride[] = {std::forward<int>(Strides)...};

        // Computing offset first potentially postpones or avoids 64-bit math during addressing
        int offset = 0;
        if constexpr (kArgSize > 0)
        {
            int            coords[] = {std::forward<int>(c)...};
            const int64_t *strides  = m_tensors[t].stride;

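            // Dimensions with run-time strides (compile-time value -1) come first and
            // use the strides stored in the per-tensor metadata.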
#pragma unroll
            for (int i = 0; i < kVarSize; ++i)
            {
                offset += coords[i] * strides[i];
            }
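            // The remaining dimensions use the compile-time strides from the template
            // parameter pack.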
#pragma unroll
            for (int i = kVariableStrides; i < kDimSize; ++i)
            {
                offset += coords[i] * kStride[i];
            }
        }

        NVCVByte *dataPtr = m_tensors[t].data;
        return reinterpret_cast<T *>(dataPtr + offset);
    }

    int32_t                           m_numTensors;
    NVCVTensorBatchElementStridedRec *m_tensors;
};

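// Non-const TensorBatchWrap: extends the read-only specialization above with
// mutable access to the elements through ptr(), tensor() and operator[].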
template<typename T, int... Strides>
class TensorBatchWrap : public TensorBatchWrap<const T, Strides...>
{
    using Base = TensorBatchWrap<const T, Strides...>;

public:
    using ValueType = T;
    using Base::doGetPtr;
    using Base::kNumDimensions;
    using Base::m_tensors;
    using Base::strides;

    __host__ TensorBatchWrap(const TensorBatchDataStridedCuda &data)
        : Base(data)
    {
    }

    __host__ __device__ TensorBatchWrap(NVCVTensorBatchData &data)
        : Base(data)
    {
    }

    template<typename... Coords>
    inline __host__ __device__ T *ptr(int t, Coords... c) const
    {
        return doGetPtr(t, c...);
    }

    inline __host__ __device__ auto tensor(int t) const
    {
        return TensorWrap<ValueType, Strides...>(doGetPtr(t), strides(t));
    }

    template<typename DimType, class = Require<std::is_same_v<int, BaseType<DimType>>>>
    inline __host__ __device__ T &operator[](DimType c) const
    {
        static_assert(NumElements<DimType> == kNumDimensions + 1,
                      "Coordinates in the subscript operator must be (N+1)-dimensional, "
                      "where N is the dimensionality of a single tensor in the batch.");
        if constexpr (NumElements<DimType> == 1)
        {
            return *doGetPtr(c.x);
        }
        else if constexpr (NumElements<DimType> == 2)
        {
            return *doGetPtr(c.x, c.y);
        }
        else if constexpr (NumElements<DimType> == 3)
        {
            return *doGetPtr(c.x, c.z, c.y);
        }
        else if constexpr (NumElements<DimType> == 4)
        {
            return *doGetPtr(c.x, c.w, c.z, c.y);
        }
    }
};

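// Convenience aliases for batches of 1D to 5D tensors with packed innermost
// elements: every stride is a run-time value except the innermost one, which is
// fixed at compile time to sizeof(T).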
template<typename T>
using TensorBatch1DWrap = TensorBatchWrap<T, sizeof(T)>;

template<typename T>
using TensorBatch2DWrap = TensorBatchWrap<T, -1, sizeof(T)>;

template<typename T>
using TensorBatch3DWrap = TensorBatchWrap<T, -1, -1, sizeof(T)>;

template<typename T>
using TensorBatch4DWrap = TensorBatchWrap<T, -1, -1, -1, sizeof(T)>;

template<typename T>
using TensorBatch5DWrap = TensorBatchWrap<T, -1, -1, -1, -1, sizeof(T)>;

template<typename T, int N>
using TensorBatchNDWrap = std::conditional_t<
    N == 1, TensorBatch1DWrap<T>,
    std::conditional_t<N == 2, TensorBatch2DWrap<T>,
                       std::conditional_t<N == 3, TensorBatch3DWrap<T>,
                                          std::conditional_t<N == 4, TensorBatch4DWrap<T>,
                                                             std::conditional_t<N == 5, TensorBatch5DWrap<T>, void>>>>>;
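
// Usage sketch (illustrative only, not part of this header). It assumes a batch of
// 3D HWC float tensors wrapped on the host from a TensorBatchDataStridedCuda export,
// so that the per-tensor metadata is device accessible; the kernel and variable
// names below are hypothetical.
//
//   __global__ void PlusOne(nvcv::cuda::TensorBatch3DWrap<float> batch)
//   {
//       int sample = blockIdx.z;
//       int x      = blockIdx.x * blockDim.x + threadIdx.x; // column (dim 1)
//       int y      = blockIdx.y * blockDim.y + threadIdx.y; // row    (dim 0)
//       if (sample >= batch.numTensors() || y >= batch.shape(sample)[0] || x >= batch.shape(sample)[1])
//       {
//           return;
//       }
//       for (int c = 0; c < batch.shape(sample)[2]; ++c) // channels (dim 2)
//       {
//           // ptr() takes the tensor index followed by coordinates, outermost-first.
//           *batch.ptr(sample, y, x, c) += 1.f;
//       }
//   }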
} // namespace nvcv::cuda

#endif // NVCV_CUDA_TENSOR_BATCH_WRAP_HPP