Program Listing for File TensorDataAccess.hpp

Return to documentation for file (nvcv_types/include/nvcv/TensorDataAccess.hpp)

/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NVCV_TENSORDATAACESSOR_HPP
#define NVCV_TENSORDATAACESSOR_HPP

#include "TensorData.hpp"
#include "TensorShapeInfo.hpp"

#include <cassert>
#include <cstddef>
#include <cstdint>

namespace nvcv {

// Design is similar to TensorShapeInfo hierarchy

namespace detail {

/**
 * Common state and sample-level queries shared by the strided tensor data accessors.
 *
 * Stores a copy of the strided tensor data together with its precomputed shape
 * information; derived accessors layer dimension-specific queries on top.
 *
 * @tparam ShapeInfo  Shape-info type describing the tensor (e.g. TensorShapeInfo).
 * @tparam LayoutInfo Layout-info type, defaulting to ShapeInfo::LayoutInfo.
 */
template<typename ShapeInfo, typename LayoutInfo = typename ShapeInfo::LayoutInfo>
class TensorDataAccessStridedImpl
{
public:
    /**
     * @param tdata     Strided tensor data to access (copied into the accessor).
     * @param infoShape Shape information for @p tdata (copied into the accessor).
     */
    TensorDataAccessStridedImpl(const TensorDataStrided &tdata, const ShapeInfo &infoShape)
        : m_tdata(tdata)
        , m_infoShape(infoShape)
    {
    }

    /// Number of samples, as reported by the shape info.
    TensorShape::DimType numSamples() const
    {
        return m_infoShape.numSamples();
    }

    /// Data type of the wrapped tensor data.
    DataType dtype() const
    {
        return m_tdata.dtype();
    }

    /// Layout of the wrapped tensor data.
    const TensorLayout &layout() const
    {
        return m_tdata.layout();
    }

    /// Shape of the wrapped tensor data.
    const TensorShape &shape() const
    {
        return m_tdata.shape();
    }

    /**
     * Stride between consecutive samples, in units of Byte.
     * @return the stride of the layout's sample dimension, or 0 if the layout has none.
     */
    int64_t sampleStride() const
    {
        int idx = this->infoLayout().idxSample();
        if (idx >= 0)
        {
            return m_tdata.stride(idx);
        }
        else
        {
            return 0;
        }
    }

    /// Pointer to sample @p n, relative to the tensor data's base pointer.
    Byte *sampleData(int n) const
    {
        return sampleData(n, m_tdata.basePtr());
    }

    /**
     * Pointer to sample @p n relative to @p base, computed as `base + sampleStride() * n`.
     * @pre 0 <= n < numSamples() (checked with assert when NDEBUG is not defined).
     */
    Byte *sampleData(int n, Byte *base) const
    {
        assert(0 <= n && n < this->numSamples());
        return base + this->sampleStride() * n;
    }

    /// Whether the shape info classifies the tensor as an image.
    bool isImage() const
    {
        return m_infoShape.isImage();
    }

    /// Shape information this accessor was built with.
    const ShapeInfo &infoShape() const
    {
        return m_infoShape;
    }

    /// Layout information obtained from the shape info.
    const LayoutInfo &infoLayout() const
    {
        return m_infoShape.infoLayout();
    }

protected:
    TensorDataStrided m_tdata;

    /**
     * Copies the tensor data of @p that, but uses @p infoShape as the shape info.
     *
     * The shape info parameter is typed as ShapeInfo (not a fixed TensorShapeInfo)
     * so that every instantiation can use this constructor. In particular,
     * TensorDataAccessStridedImageImpl forwards its own ShapeInfo here.
     */
    TensorDataAccessStridedImpl(const TensorDataAccessStridedImpl &that, const ShapeInfo &infoShape)
        : m_tdata(that.m_tdata)
        , m_infoShape(infoShape)
    {
    }

private:
    ShapeInfo m_infoShape;
};

/**
 * Image-oriented accessor layer on top of TensorDataAccessStridedImpl.
 *
 * Exposes image dimensions (rows, columns, channels, 2D size) and the strides of
 * the channel, width, height and depth dimensions. Querying the stride of a
 * dimension the layout does not contain yields 0.
 */
template<typename ShapeInfo>
class TensorDataAccessStridedImageImpl : public TensorDataAccessStridedImpl<ShapeInfo>
{
    using Base = detail::TensorDataAccessStridedImpl<ShapeInfo>;

public:
    TensorDataAccessStridedImageImpl(const TensorDataStrided &tdata, const ShapeInfo &infoShape)
        : Base(tdata, infoShape)
    {
    }

    /// Number of columns, as reported by the shape info.
    int32_t numCols() const
    {
        return this->infoShape().numCols();
    }

    /// Number of rows, as reported by the shape info.
    int32_t numRows() const
    {
        return this->infoShape().numRows();
    }

    /// Number of channels, as reported by the shape info.
    int32_t numChannels() const
    {
        return this->infoShape().numChannels();
    }

    /// 2D image size, as reported by the shape info.
    Size2D size() const
    {
        return this->infoShape().size();
    }

    /// Stride of the channel dimension, or 0 when the layout has no channel dimension.
    int64_t chStride() const
    {
        const int ichannel = this->infoLayout().idxChannel();
        if (ichannel < 0)
        {
            return 0;
        }
        return this->m_tdata.stride(ichannel);
    }

    /// Stride of the width (column) dimension, or 0 when the layout has none.
    int64_t colStride() const
    {
        const int iwidth = this->infoLayout().idxWidth();
        if (iwidth < 0)
        {
            return 0;
        }
        return this->m_tdata.stride(iwidth);
    }

    /// Stride of the height (row) dimension, or 0 when the layout has none.
    int64_t rowStride() const
    {
        const int iheight = this->infoLayout().idxHeight();
        if (iheight < 0)
        {
            return 0;
        }
        return this->m_tdata.stride(iheight);
    }

    /// Stride of the depth dimension, or 0 when the layout has none.
    int64_t depthStride() const
    {
        const int idepth = this->infoLayout().idxDepth();
        if (idepth < 0)
        {
            return 0;
        }
        return this->m_tdata.stride(idepth);
    }

    /// Pointer to row @p y, relative to the tensor data's base pointer.
    Byte *rowData(int y) const
    {
        Byte *const origin = this->m_tdata.basePtr();
        return rowData(y, origin);
    }

    /**
     * Pointer to row @p y relative to @p base (`base + rowStride() * y`).
     * @pre 0 <= y < numRows() (asserted).
     */
    Byte *rowData(int y, Byte *base) const
    {
        assert(0 <= y && y < this->numRows());
        const int64_t offset = this->rowStride() * y;
        return base + offset;
    }

    /// Pointer to channel @p c, relative to the tensor data's base pointer.
    Byte *chData(int c) const
    {
        Byte *const origin = this->m_tdata.basePtr();
        return chData(c, origin);
    }

    /**
     * Pointer to channel @p c relative to @p base (`base + chStride() * c`).
     * @pre 0 <= c < numChannels() (asserted).
     */
    Byte *chData(int c, Byte *base) const
    {
        assert(0 <= c && c < this->numChannels());
        const int64_t offset = this->chStride() * c;
        return base + offset;
    }

protected:
    /// Copies the tensor data of @p that while using @p infoShape as the shape info.
    TensorDataAccessStridedImageImpl(const TensorDataAccessStridedImageImpl &that, const ShapeInfo &infoShape)
        : Base(that, infoShape)
    {
    }
};

/**
 * Planar-image accessor layer on top of TensorDataAccessStridedImageImpl.
 *
 * Adds plane count, plane stride and per-plane pointers. The plane stride is the
 * channel stride for channel-first layouts and 0 otherwise.
 */
template<typename ShapeInfo>
class TensorDataAccessStridedImagePlanarImpl : public TensorDataAccessStridedImageImpl<ShapeInfo>
{
    using Base = TensorDataAccessStridedImageImpl<ShapeInfo>;

public:
    TensorDataAccessStridedImagePlanarImpl(const TensorDataStrided &tdata, const ShapeInfo &infoShape)
        : Base(tdata, infoShape)
    {
    }

    /// Number of planes, as reported by the shape info.
    int32_t numPlanes() const
    {
        return this->infoShape().numPlanes();
    }

    /// Stride between planes: the channel stride for channel-first layouts, else 0.
    int64_t planeStride() const
    {
        if (!this->infoLayout().isChannelFirst())
        {
            return 0;
        }

        // A channel-first layout is expected to have a channel dimension.
        const int ichannel = this->infoLayout().idxChannel();
        assert(ichannel >= 0);
        return this->m_tdata.stride(ichannel);
    }

    /// Pointer to plane @p p, relative to the tensor data's base pointer.
    Byte *planeData(int p) const
    {
        Byte *const origin = this->m_tdata.basePtr();
        return planeData(p, origin);
    }

    /**
     * Pointer to plane @p p relative to @p base (`base + planeStride() * p`).
     * @pre 0 <= p < numPlanes() (asserted).
     */
    Byte *planeData(int p, Byte *base) const
    {
        assert(0 <= p && p < this->numPlanes());
        const int64_t offset = this->planeStride() * p;
        return base + offset;
    }
};

} // namespace detail

/**
 * Accessor for generic strided tensor data.
 *
 * Instances are obtained through Create(), which returns NullOpt when the
 * tensor data cannot be cast to TensorDataStrided.
 */
class TensorDataAccessStrided : public detail::TensorDataAccessStridedImpl<TensorShapeInfo>
{
    using Base = detail::TensorDataAccessStridedImpl<TensorShapeInfo>;

public:
    /// Returns whether @p data is compatible with TensorDataStrided.
    static bool IsCompatible(const TensorData &data)
    {
        return data.IsCompatible<TensorDataStrided>();
    }

    /**
     * Creates an accessor for @p data.
     * @return the accessor, or NullOpt if @p data cannot be cast to TensorDataStrided.
     */
    static Optional<TensorDataAccessStrided> Create(const TensorData &data)
    {
        if (Optional<TensorDataStrided> dataStrided = data.cast<TensorDataStrided>())
        {
            return TensorDataAccessStrided(dataStrided.value());
        }
        else
        {
            return NullOpt;
        }
    }

private:
    // explicit: a TensorDataStrided must not silently convert into an accessor.
    // NOTE(review): the result of TensorShapeInfo::Create is dereferenced without a
    // check; this assumes it always succeeds for a strided tensor's shape — confirm.
    explicit TensorDataAccessStrided(const TensorDataStrided &data)
        : Base(data, *TensorShapeInfo::Create(data.shape()))
    {
    }
};

/**
 * Accessor for strided tensor data whose shape describes an image.
 *
 * Instances are obtained through Create(), which returns NullOpt unless
 * IsCompatible() holds for the given tensor data.
 */
class TensorDataAccessStridedImage : public detail::TensorDataAccessStridedImageImpl<TensorShapeInfoImage>
{
    using Base = detail::TensorDataAccessStridedImageImpl<TensorShapeInfoImage>;

public:
    /// Returns whether @p data is strided and its shape is accepted by TensorShapeInfoImage.
    static bool IsCompatible(const TensorData &data)
    {
        return TensorDataAccessStrided::IsCompatible(data) && TensorShapeInfoImage::IsCompatible(data.shape());
    }

    /**
     * Creates an image accessor for @p data.
     * @return the accessor, or NullOpt if IsCompatible(data) is false.
     */
    static Optional<TensorDataAccessStridedImage> Create(const TensorData &data)
    {
        if (IsCompatible(data))
        {
            return TensorDataAccessStridedImage(data.cast<TensorDataStrided>().value());
        }
        else
        {
            return NullOpt;
        }
    }

protected:
    // explicit: a TensorDataStrided must not silently convert into an accessor.
    // The shape-info dereference relies on Create() having checked compatibility first.
    explicit TensorDataAccessStridedImage(const TensorDataStrided &data)
        : Base(data, *TensorShapeInfoImage::Create(data.shape()))
    {
    }
};

/**
 * Accessor for strided tensor data whose shape describes a (possibly planar) image.
 *
 * Instances are obtained through Create(), which returns NullOpt unless
 * IsCompatible() holds for the given tensor data.
 */
class TensorDataAccessStridedImagePlanar
    : public detail::TensorDataAccessStridedImagePlanarImpl<TensorShapeInfoImagePlanar>
{
    using Base = detail::TensorDataAccessStridedImagePlanarImpl<TensorShapeInfoImagePlanar>;

public:
    /// Returns whether @p data is a compatible strided image and its shape is accepted by TensorShapeInfoImagePlanar.
    static bool IsCompatible(const TensorData &data)
    {
        return TensorDataAccessStridedImage::IsCompatible(data)
            && TensorShapeInfoImagePlanar::IsCompatible(data.shape());
    }

    /**
     * Creates a planar image accessor for @p data.
     * @return the accessor, or NullOpt if IsCompatible(data) is false.
     */
    static Optional<TensorDataAccessStridedImagePlanar> Create(const TensorData &data)
    {
        if (IsCompatible(data))
        {
            return TensorDataAccessStridedImagePlanar(data.cast<TensorDataStrided>().value());
        }
        else
        {
            return NullOpt;
        }
    }

protected:
    // explicit: a TensorDataStrided must not silently convert into an accessor.
    // The shape-info dereference relies on Create() having checked compatibility first.
    explicit TensorDataAccessStridedImagePlanar(const TensorDataStrided &data)
        : Base(data, *TensorShapeInfoImagePlanar::Create(data.shape()))
    {
    }
};

} // namespace nvcv

#endif // NVCV_TENSORDATAACESSOR_HPP