Program Listing for File ImageBatchVarShapeWrap.hpp

Return to documentation for file (nvcv_types/include/nvcv/cuda/ImageBatchVarShapeWrap.hpp)

/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NVCV_CUDA_IMAGE_BATCH_VAR_SHAPE_WRAP_HPP
#define NVCV_CUDA_IMAGE_BATCH_VAR_SHAPE_WRAP_HPP

#include "TypeTraits.hpp" // for HasTypeTraits, etc.

#include <nvcv/ImageBatchData.hpp> // for ImageBatchVarShapeDataStridedCuda, etc.

namespace nvcv::cuda {

// Primary template of the image batch var-shape wrapper.  It is only declared at this point so that
// the partial specialization for const-qualified element types can be defined first.
template<typename T>
class ImageBatchVarShapeWrap;

// Read-only image batch var-shape wrapper (specialization for const-qualified element types).
//
// It keeps the image list pointer obtained from an ImageBatchVarShapeDataStridedCuda and offers
// per-image, per-plane metadata queries plus element access addressed by sample index (s), plane
// index (p), row (y) and column (x).  None of the accessors validates its indices, and a
// default-constructed wrapper holds a null image list and must not be dereferenced.
template<typename T>
class ImageBatchVarShapeWrap<const T>
{
    // It is a requirement of this class that its type has type traits.
    static_assert(HasTypeTraits<T>, "ImageBatchVarShapeWrap<T> can only be used if T has type traits");

public:
    // The type provided as template parameter is the value type, i.e. the type of each element inside this wrapper.
    using ValueType = const T;

    // The number of dimensions is fixed as 4: sample (image in the batch), plane, row and column.
    static constexpr int kNumDimensions = 4;
    // The number of variable strides is fixed as 3, meaning it uses 3 run-time strides.
    static constexpr int kVariableStrides = 3;
    // The number of constant strides is fixed as 1, meaning it uses 1 compile-time stride.
    static constexpr int kConstantStrides = 1;

    // Builds an empty wrapper: the image list pointer is null.
    ImageBatchVarShapeWrap() = default;

    // Builds a wrapper around the image list reported by `images`.  Only the pointer is stored, so
    // the wrapped data has to stay valid for as long as the wrapper is used.
    __host__ ImageBatchVarShapeWrap(const ImageBatchVarShapeDataStridedCuda &images)
        : m_imageList(images.imageList())
    {
    }

    // Returns a copy of the plane descriptor of plane `p` (first plane by default) of image `s`.
    inline const __host__ __device__ NVCVImagePlaneStrided plane(int s, int p = 0) const
    {
        return m_imageList[s].planes[p];
    }

    // Returns the width stored in the descriptor of plane `p` (first plane by default) of image `s`.
    inline __host__ __device__ int width(int s, int p = 0) const
    {
        return m_imageList[s].planes[p].width;
    }

    // Returns the height stored in the descriptor of plane `p` (first plane by default) of image `s`.
    inline __host__ __device__ int height(int s, int p = 0) const
    {
        return m_imageList[s].planes[p].height;
    }

    // Returns the row stride stored in the descriptor of plane `p` (first plane by default) of
    // image `s`.  The address computation in doGetPtr uses this value as a byte count.
    inline __host__ __device__ int rowStride(int s, int p = 0) const
    {
        return m_imageList[s].planes[p].rowStride;
    }

    // Read-only element access with 4D coordinates: c.w is the sample, c.z the plane, c.y the row
    // and c.x the column.
    inline const __host__ __device__ T &operator[](int4 c) const
    {
        return *doGetPtr(c.w, c.z, c.y, c.x);
    }

    // Read-only element access with 3D coordinates on the first plane: c.z is the sample, c.y the
    // row and c.x the column.
    inline const __host__ __device__ T &operator[](int3 c) const
    {
        return *doGetPtr(c.z, 0, c.y, c.x);
    }

    // Returns a read-only pointer to the element at sample `s`, plane `p`, row `y`, column `x`.
    inline const __host__ __device__ T *ptr(int s, int p, int y, int x) const
    {
        return doGetPtr(s, p, y, x);
    }

    // Returns a read-only pointer to the element at sample `s`, row `y`, column `x` of the first plane.
    inline const __host__ __device__ T *ptr(int s, int y, int x) const
    {
        return doGetPtr(s, 0, y, x);
    }

    // Returns a read-only pointer to the first element of row `y` in the first plane of sample `s`.
    inline const __host__ __device__ T *ptr(int s, int y) const
    {
        return doGetPtr(s, 0, y, 0);
    }

protected:
    // Computes the address of the element at sample `s`, plane `p`, row `y`, column `x` as the
    // plane base pointer plus y * rowStride + x * sizeof(T).
    // NOTE(review): this assumes basePtr is a byte-sized pointer type -- confirm against the
    // NVCVImagePlaneStrided declaration.
    inline const __host__ __device__ T *doGetPtr(int s, int p, int y, int x) const
    {
        // The offset is accumulated in 64 bits on purpose: evaluated in 32-bit int, the product
        // y * rowStride overflows (undefined behavior) as soon as a plane spans 2 GiB or more.
        const long long rowBytes  = this->rowStride(s, p);
        const long long elemBytes = static_cast<long long>(sizeof(T));
        const long long offset    = y * rowBytes + x * elemBytes;

        return reinterpret_cast<const T *>(m_imageList[s].planes[p].basePtr + offset);
    }

private:
    // Array of strided image buffers, one entry per image in the batch; not owned by the wrapper.
    const NVCVImageBufferStrided *m_imageList = nullptr;
};

// Read-and-write image batch var-shape wrapper.
//
// It reuses the read-only wrapper for storage, metadata queries and address computation, and
// re-declares the element accessors so that they hand out non-const references and pointers.
template<typename T>
class ImageBatchVarShapeWrap : public ImageBatchVarShapeWrap<const T>
{
    using Base = ImageBatchVarShapeWrap<const T>;

public:
    // Elements are mutable through this wrapper.
    using ValueType = T;

    // Dimension and stride constants are the ones defined by the read-only wrapper.
    using Base::kNumDimensions;
    using Base::kVariableStrides;
    using Base::kConstantStrides;

    // Metadata queries are taken unchanged from the read-only wrapper.
    using Base::plane;
    using Base::width;
    using Base::height;
    using Base::rowStride;

    // Builds an empty wrapper (see the read-only wrapper's default constructor).
    ImageBatchVarShapeWrap() = default;

    // Builds a wrapper around `images` by forwarding it to the read-only wrapper.
    __host__ ImageBatchVarShapeWrap(const ImageBatchVarShapeDataStridedCuda &images)
        : Base(images)
    {
    }

    // Read-and-write element access with 4D coordinates: c.w is the sample, c.z the plane, c.y the
    // row and c.x the column.
    inline __host__ __device__ T &operator[](int4 c) const
    {
        T *element = doGetPtr(c.w, c.z, c.y, c.x);
        return *element;
    }

    // Read-and-write element access with 3D coordinates on the first plane: c.z is the sample,
    // c.y the row and c.x the column.
    inline __host__ __device__ T &operator[](int3 c) const
    {
        T *element = doGetPtr(c.z, 0, c.y, c.x);
        return *element;
    }

    // Returns a writable pointer to the element at sample `s`, plane `p`, row `y`, column `x`.
    inline __host__ __device__ T *ptr(int s, int p, int y, int x) const
    {
        return doGetPtr(s, p, y, x);
    }

    // Returns a writable pointer to the element at sample `s`, row `y`, column `x` of the first plane.
    inline __host__ __device__ T *ptr(int s, int y, int x) const
    {
        return ptr(s, 0, y, x);
    }

    // Returns a writable pointer to the first element of row `y` in the first plane of sample `s`.
    inline __host__ __device__ T *ptr(int s, int y) const
    {
        return ptr(s, 0, y, 0);
    }

protected:
    // Asks the read-only wrapper for the element address and strips the const qualifier from it.
    inline __host__ __device__ T *doGetPtr(int s, int p, int y, int x) const
    {
        const T *readOnly = Base::doGetPtr(s, p, y, x);
        return const_cast<T *>(readOnly);
    }
};

// Image batch var-shape wrapper that takes the number of channels as a run-time value.
//
// Elements of type T are addressed by sample (s), row (y), column (x) and channel (c); the channels
// of one pixel are treated as `numChannels` consecutive T elements.  Only the first plane of each
// image is ever accessed.  The base wrapper is inherited privately, so just the members re-exported
// below are reachable from outside.
template<typename T>
class ImageBatchVarShapeWrapNHWC : private ImageBatchVarShapeWrap<T>
{
    using Base = ImageBatchVarShapeWrap<T>;

public:
    using ValueType = T;

    // Dimension and stride constants re-exported from the base wrapper.
    using Base::kNumDimensions;
    using Base::kVariableStrides;
    using Base::kConstantStrides;

    // Metadata queries re-exported from the base wrapper.
    using Base::plane;
    using Base::width;
    using Base::height;
    using Base::rowStride;

    // Builds an empty wrapper with the channel count left at its default of 1.
    ImageBatchVarShapeWrapNHWC() = default;

    // Builds a wrapper around `images` and records `numChannels` as given (it is not validated).
    // When NDEBUG is not defined, every image format in the batch is asserted to have one plane.
    __host__ ImageBatchVarShapeWrapNHWC(const ImageBatchVarShapeDataStridedCuda &images, int numChannels)
        : Base(images)
        , m_numChannels(numChannels)
    {
#ifndef NDEBUG
        auto formatList = images.hostFormatList();
        for (int imageIdx = 0; imageIdx < images.numImages(); ++imageIdx)
        {
            assert(1 == nvcv::ImageFormat{formatList[imageIdx]}.numPlanes()
                   && "This wrapper class is only for single-plane images in batch");
        }
#endif
    }

    // Returns a reference to the channel count stored in this wrapper.
    inline __host__ __device__ const int &numChannels() const
    {
        return m_numChannels;
    }

    // Element access with 4D coordinates: c.x is the sample, c.y the row, c.z the column and c.w
    // the channel.  Note that this component order differs from the one used by the base wrapper.
    inline __host__ __device__ T &operator[](int4 c) const
    {
        T *element = doGetPtr(c.x, c.y, c.z, c.w);
        return *element;
    }

    // Element access with 3D coordinates on the first channel: c.x is the sample, c.y the row and
    // c.z the column.
    inline __host__ __device__ T &operator[](int3 c) const
    {
        T *element = doGetPtr(c.x, c.y, c.z, 0);
        return *element;
    }

    // Returns a pointer to channel `c` of the pixel at sample `s`, row `y`, column `x`.
    inline __host__ __device__ T *ptr(int s, int y, int x, int c) const
    {
        return doGetPtr(s, y, x, c);
    }

    // Returns a pointer to the first channel of the pixel at sample `s`, row `y`, column `x`.
    inline __host__ __device__ T *ptr(int s, int y, int x) const
    {
        return ptr(s, y, x, 0);
    }

    // Returns a pointer to the first channel of the first pixel of row `y` in sample `s`.
    inline __host__ __device__ T *ptr(int s, int y) const
    {
        return ptr(s, y, 0, 0);
    }

private:
    // Locates the start of row `y` in the first plane of sample `s` through the base wrapper, then
    // steps `x` pixels of m_numChannels elements each and finally `c` elements.
    inline __host__ __device__ T *doGetPtr(int s, int y, int x, int c) const
    {
        T *rowStart = Base::doGetPtr(s, 0, y, 0);
        T *pixel    = rowStart + x * m_numChannels;
        return pixel + c;
    }

    // Number of T elements per pixel.
    int m_numChannels = 1;
};

} // namespace nvcv::cuda

#endif // NVCV_CUDA_IMAGE_BATCH_VAR_SHAPE_WRAP_HPP