Program Listing for File TensorLayoutInfo.hpp

Return to documentation for file (nvcv_types/include/nvcv/TensorLayoutInfo.hpp)

/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NVCV_TENSOR_LAYOUT_INFO_HPP
#define NVCV_TENSOR_LAYOUT_INFO_HPP

#include "Optional.hpp"
#include "TensorLayout.hpp"

#include <algorithm>

namespace nvcv {

/**
 * @brief Read-only properties derived from a TensorLayout.
 *
 * Every property is computed once in the constructor and cached; the
 * accessors only return the cached values.
 */
class TensorLayoutInfo
{
public:
    /// @brief Reports whether @p layout can be described by this class.
    /// @return Always true: every layout is accepted at this level.
    static bool IsCompatible(const TensorLayout & /*layout*/)
    {
        return true;
    }

    /// @brief Builds the info object for @p layout.
    /// @return An Optional holding the info; creation does not fail at this level.
    static Optional<TensorLayoutInfo> Create(const TensorLayout &layout)
    {
        return TensorLayoutInfo{layout};
    }

    /// @brief The layout this info was built from (a copy owned by this object).
    constexpr const TensorLayout &layout() const
    {
        return m_layout;
    }

    /// @brief True when the layout has at least one label and its first label is LABEL_BATCH.
    constexpr bool isBatch() const
    {
        return m_cacheIsBatch;
    }

    /// @brief Index of the batch (sample) label: 0 for batch layouts, -1 otherwise.
    int idxSample() const
    {
        return m_cacheIdxSample;
    }

    /// @brief True when the layout is not TENSOR_NONE and contains LABEL_WIDTH.
    bool isImage() const
    {
        return m_cacheIsImage;
    }

protected:
    /// @brief Copies @p layout and caches all derived properties.
    TensorLayoutInfo(const TensorLayout &layout)
        : m_layout(layout)
        // Batch layouts are non-empty and lead with the batch label.
        , m_cacheIsBatch(m_layout.rank() > 0 && m_layout[0] == LABEL_BATCH)
        // TENSOR_NONE is never an image; otherwise an image needs a width label.
        , m_cacheIsImage(m_layout != TENSOR_NONE && m_layout.find(LABEL_WIDTH) >= 0)
        // Safe to read m_cacheIsBatch here: it is declared (hence initialized) earlier.
        , m_cacheIdxSample(m_cacheIsBatch ? 0 : -1)
    {
    }

private:
    TensorLayout m_layout;
    bool         m_cacheIsBatch;
    bool         m_cacheIsImage;
    int          m_cacheIdxSample;
};

class TensorLayoutInfoImage : public TensorLayoutInfo
{
public:
    static bool IsCompatible(const TensorLayout &layout)
    {
        if (auto info = TensorLayoutInfo::Create(layout))
        {
            return info->isImage();
        }
        else
        {
            return false;
        }
    }

    static Optional<TensorLayoutInfoImage> Create(const TensorLayout &layout)
    {
        if (IsCompatible(layout))
        {
            return TensorLayoutInfoImage{layout};
        }
        else
        {
            return NullOpt;
        }
    }

    int numSpatialDims() const
    {
        return m_cacheNumSpatialDims;
    }

    bool isRowMajor() const
    {
        return m_cacheIsRowMajor;
    }

    int idxChannel() const
    {
        return m_cacheIdxChannel;
    }

    int idxWidth() const
    {
        return m_cacheIdxWidth;
    }

    int idxHeight() const
    {
        return m_cacheIdxHeight;
    }

    int idxDepth() const
    {
        return m_cacheIdxDepth;
    }

    bool hasChannel() const
    {
        return m_cacheHasChannel;
    }

    bool isChannelFirst() const
    {
        return m_cacheIsChannelFirst;
    }

    bool isChannelLast() const
    {
        return m_cacheIsChannelLast;
    }

protected:
    TensorLayoutInfoImage(const TensorLayout &layout)
        : TensorLayoutInfo(layout)
    {
        m_cacheNumSpatialDims = std::count_if(layout.begin(), layout.end(),
                                              [](char v)
                                              {
                                                  switch (v)
                                                  {
                                                  case LABEL_WIDTH:
                                                  case LABEL_HEIGHT:
                                                  case LABEL_DEPTH:
                                                      return true;
                                                  default:
                                                      return false;
                                                  }
                                              });

        m_cacheIsRowMajor = layout.endsWith(TENSOR_W) || layout.endsWith(TENSOR_WC);
        m_cacheIdxChannel = layout.find(LABEL_CHANNEL);
        m_cacheIdxWidth   = layout.find(LABEL_WIDTH);
        m_cacheIdxHeight  = layout.find(LABEL_HEIGHT);
        m_cacheIdxDepth   = layout.find(LABEL_DEPTH);
        m_cacheHasChannel = m_cacheIdxChannel >= 0;

        // isChannelFirst --------------
        if (layout != TENSOR_NONE)
        {
            if (this->isBatch())
            {
                m_cacheIsChannelFirst = layout[1] == LABEL_CHANNEL;
            }
            else
            {
                m_cacheIsChannelFirst = layout[0] == LABEL_CHANNEL;
            }
        }
        else
        {
            m_cacheIsChannelFirst = false;
        }

        // isChannelLast --------------
        if (layout != TENSOR_NONE)
        {
            m_cacheIsChannelLast = layout[layout.rank() - 1] == LABEL_CHANNEL || !this->hasChannel();
        }
        else
        {
            m_cacheIsChannelLast = false;
        }
    }

private:
    int  m_cacheNumSpatialDims;
    bool m_cacheIsRowMajor;
    int  m_cacheIdxChannel;
    int  m_cacheIdxWidth;
    int  m_cacheIdxHeight;
    int  m_cacheIdxDepth;
    bool m_cacheHasChannel;
    bool m_cacheIsChannelFirst;
    bool m_cacheIsChannelLast;
};

} // namespace nvcv

#endif // NVCV_TENSOR_LAYOUT_INFO_HPP