Program Listing for File TensorShapeInfo.hpp

Return to the documentation for this file (nvcv_types/include/nvcv/TensorShapeInfo.hpp)

/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NVCV_TENSORSHAPEINFO_HPP
#define NVCV_TENSORSHAPEINFO_HPP

#include "Size.hpp"
#include "TensorLayoutInfo.hpp"
#include "TensorShape.hpp"

#include <cassert>

namespace nvcv {

// The criteria followed in the design of the TensorShapeInfo hierarchy are as follows:
// - no virtual dispatches
// - no heap allocation
// - minimal memory footprint
// - fast!
//
// To achieve that, the tensor layout info object for each TensorShapeInfo child
// has to be passed to the parent's constructor, instead of being obtained via a
// virtual method. This allows the parent ctor to also use the layout info, which
// wouldn't be possible if a virtual method were used to retrieve it.
// Care was taken to handle object copies, as the parent subobject of the new
// object must use the tensor layout info of the new object, not the old one.

namespace detail {

template<typename LAYOUT_INFO>
class TensorShapeInfoImpl
{
public:
    using LayoutInfo = LAYOUT_INFO;

    // Stores the shape and its layout info, caching the sample count up
    // front so that numSamples() is a plain member read.
    TensorShapeInfoImpl(const TensorShape &shape, const LayoutInfo &infoLayout)
        : m_shape(shape)
        , m_infoLayout(infoLayout)
    {
        const int sampleDim = m_infoLayout.idxSample();
        if (sampleDim >= 0)
        {
            // The layout has an explicit sample ('N') dimension.
            m_cacheNumSamples = m_shape[sampleDim];
        }
        else
        {
            // No explicit sample dimension: a non-empty layout is treated as
            // holding a single implicit sample, an empty layout holds none.
            m_cacheNumSamples = m_shape.layout() != TENSOR_NONE ? 1 : 0;
        }
    }

    // The tensor shape being described.
    const TensorShape &shape() const
    {
        return m_shape;
    }

    // Layout of the described shape.
    const TensorLayout &layout() const
    {
        return m_shape.layout();
    }

    // Detailed information about the shape's layout.
    const LayoutInfo &infoLayout() const
    {
        return m_infoLayout;
    }

    // Number of samples in the tensor (cached at construction).
    TensorShape::DimType numSamples() const
    {
        return m_cacheNumSamples;
    }

    // Whether the tensor represents an image (delegates to the layout info).
    bool isImage() const
    {
        return m_infoLayout.isImage();
    }

protected:
    TensorShape m_shape;
    LayoutInfo  m_infoLayout;
    int         m_cacheNumSamples;
};

} // namespace detail

// Shape information for tensors with any layout.
class TensorShapeInfo : public detail::TensorShapeInfoImpl<TensorLayoutInfo>
{
    using Base = detail::TensorShapeInfoImpl<TensorLayoutInfo>;

public:
    // Returns whether the given tensor shape can be described by this class.
    // Every shape is accepted.
    static bool IsCompatible(const TensorShape &tshape)
    {
        (void)tshape;
        return true;
    }

    // Creates the shape info for the given tensor shape.
    // Never returns NullOpt, as all shapes are compatible.
    static Optional<TensorShapeInfo> Create(const TensorShape &tshape)
    {
        return TensorShapeInfo(tshape);
    }

private:
    TensorShapeInfo(const TensorShape &tshape)
        : Base(tshape, *TensorLayoutInfo::Create(tshape.layout()))
    {
        // NOTE(review): the dereference above assumes TensorLayoutInfo::Create
        // succeeds for every layout — consistent with IsCompatible always
        // returning true; confirm against TensorLayoutInfo.hpp.
    }

    // FIX: removed the unused 'Optional<TensorLayoutInfo> m_infoLayout'
    // member. It was never initialized or read, shadowed the base class's
    // m_infoLayout (already stored by value in Base), and only inflated the
    // object's footprint — contradicting the "minimal memory footprint"
    // design goal of this hierarchy.
};

// Shape information for tensors whose layout represents an image
// (i.e. the layout has at least a Width dimension).
class TensorShapeInfoImage : public detail::TensorShapeInfoImpl<TensorLayoutInfoImage>
{
    using Base = detail::TensorShapeInfoImpl<TensorLayoutInfoImage>;

public:
    // Returns whether the given tensor shape represents an image.
    static bool IsCompatible(const TensorShape &tshape)
    {
        // BUGFIX: must validate against the *image* layout info
        // (TensorLayoutInfoImage), not the generic TensorLayoutInfo, which
        // accepts any layout. With the generic check, Create() would accept
        // non-image layouts and the private ctor below would then dereference
        // the NullOpt returned by TensorLayoutInfoImage::Create — UB.
        return TensorShapeInfo::IsCompatible(tshape) && TensorLayoutInfoImage::IsCompatible(tshape.layout());
    }

    // Creates the image shape info, or NullOpt if the shape isn't an image.
    static Optional<TensorShapeInfoImage> Create(const TensorShape &tshape)
    {
        if (IsCompatible(tshape))
        {
            return TensorShapeInfoImage(tshape);
        }
        else
        {
            return NullOpt;
        }
    }

    // Number of channels (1 if the layout has no channel dimension).
    int32_t numChannels() const
    {
        return m_cacheNumChannels;
    }

    // Image width, taken from the layout's Width dimension.
    int32_t numCols() const
    {
        return m_cacheSize.w;
    }

    // Image height (1 if the layout has no Height dimension).
    int32_t numRows() const
    {
        return m_cacheSize.h;
    }

    // Image size (width x height), cached at construction.
    const Size2D &size() const
    {
        return m_cacheSize;
    }

protected:
    TensorShapeInfoImage(const TensorShape &tshape)
        : TensorShapeInfoImage(tshape, *TensorLayoutInfoImage::Create(tshape.layout()))
    {
    }

    TensorShapeInfoImage(const TensorShape &shape, const TensorLayoutInfoImage &infoLayout)
        : Base(shape, infoLayout)
    {
        // Channel count: default to 1 when the layout has no 'C' dimension.
        int idx = this->infoLayout().idxChannel();
        if (idx >= 0)
        {
            m_cacheNumChannels = m_shape[idx];
        }
        else
        {
            m_cacheNumChannels = 1;
        }

        // Width is mandatory for an image shape.
        idx = this->infoLayout().idxWidth();
        if (idx < 0)
        {
            throw Exception(Status::ERROR_INVALID_ARGUMENT, "Image shape must have a Width dimension");
        }
        m_cacheSize.w = m_shape[idx];

        // Height: default to 1 when the layout has no 'H' dimension.
        idx = this->infoLayout().idxHeight();
        if (idx >= 0)
        {
            m_cacheSize.h = m_shape[idx];
        }
        else
        {
            m_cacheSize.h = 1;
        }
    }

    Size2D m_cacheSize;
    int    m_cacheNumChannels;
};

// Shape information for image tensors whose planes are laid out contiguously
// in memory (row-major, channel-first or channel-last).
class TensorShapeInfoImagePlanar : public TensorShapeInfoImage
{
public:
    // Returns whether the given tensor shape represents a planar image.
    static bool IsCompatible(const TensorShape &tshape)
    {
        auto infoLayout = TensorLayoutInfoImage::Create(tshape.layout());
        if (!infoLayout)
        {
            return false;
        }

        // Planar images must be row-major with channels at one end.
        if (!infoLayout->isRowMajor() || !(infoLayout->isChannelFirst() || infoLayout->isChannelLast()))
        {
            return false;
        }

        const TensorLayout &layout = tshape.layout();

        int iheight = infoLayout->idxHeight();
        if (iheight < 0)
        {
            // No explicit height: [^HC]*W, [^H]*CW, [^H]*WC
            int ichannel = infoLayout->idxChannel();
            return ichannel == -1 || ichannel >= layout.rank() - 2;
        }

        // Explicit height: accept *HWC, [^C]*HW, *CHW
        assert(iheight + 1 < layout.rank());
        return layout[iheight + 1] == LABEL_WIDTH
            && (iheight == 0 || infoLayout->isChannelLast() || layout[iheight - 1] == LABEL_CHANNEL);
    }

    // Creates the planar image shape info, or NullOpt if incompatible.
    static Optional<TensorShapeInfoImagePlanar> Create(const TensorShape &tshape)
    {
        if (!IsCompatible(tshape))
        {
            return NullOpt;
        }
        return TensorShapeInfoImagePlanar(tshape);
    }

    // Number of planes (cached at construction).
    int32_t numPlanes() const
    {
        return m_cacheNumPlanes;
    }

private:
    int m_cacheNumPlanes;

    TensorShapeInfoImagePlanar(const TensorShape &tshape)
        : TensorShapeInfoImage(tshape)
    {
        if (this->infoLayout().isChannelLast())
        {
            // Interleaved channels: everything lives in a single plane.
            m_cacheNumPlanes = 1;
        }
        else
        {
            // Channel-first: one plane per channel; no channel dim -> 1 plane.
            int chDim = this->infoLayout().idxChannel();
            m_cacheNumPlanes = chDim >= 0 ? m_shape[chDim] : 1;
        }
    }
};

} // namespace nvcv

#endif // NVCV_TENSORSHAPEINFO_HPP