Program Listing for File TensorLayout.hpp
↰ Return to documentation for file (nvcv_types/include/nvcv/TensorLayout.hpp)
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef NVCV_TENSOR_LAYOUT_HPP
#define NVCV_TENSOR_LAYOUT_HPP
#include "TensorLayout.h"
#include "detail/CheckError.hpp"
#include "detail/Concepts.hpp"
#include <cassert>
#include <iostream>
/// Equality of two C tensor layouts: true when nvcvTensorLayoutCompare reports no difference.
inline bool operator==(const NVCVTensorLayout &lhs, const NVCVTensorLayout &rhs)
{
    const auto order = nvcvTensorLayoutCompare(lhs, rhs);
    return order == 0;
}
/// Inequality of two C tensor layouts, defined as the negation of operator==.
inline bool operator!=(const NVCVTensorLayout &lhs, const NVCVTensorLayout &rhs)
{
    return !(lhs == rhs);
}
/// Ordering of two C tensor layouts, following the sign of nvcvTensorLayoutCompare.
inline bool operator<(const NVCVTensorLayout &lhs, const NVCVTensorLayout &rhs)
{
    const auto order = nvcvTensorLayoutCompare(lhs, rhs);
    return order < 0;
}
/// Writes the layout's name, as returned by nvcvTensorLayoutGetName, to the stream.
inline std::ostream &operator<<(std::ostream &out, const NVCVTensorLayout &layout)
{
    const auto name = nvcvTensorLayoutGetName(&layout);
    out << name;
    return out;
}
namespace nvcv {
/// Character labels identifying tensor dimensions, mirroring the C NVCV_TLABEL_* constants.
enum TensorLabel : char
{
    LABEL_BATCH   = NVCV_TLABEL_BATCH,   ///< batch dimension label
    LABEL_CHANNEL = NVCV_TLABEL_CHANNEL, ///< channel dimension label
    LABEL_FRAME   = NVCV_TLABEL_FRAME,   ///< frame dimension label
    LABEL_DEPTH   = NVCV_TLABEL_DEPTH,   ///< depth dimension label
    LABEL_HEIGHT  = NVCV_TLABEL_HEIGHT,  ///< height dimension label
    LABEL_WIDTH   = NVCV_TLABEL_WIDTH,   ///< width dimension label
};
class TensorLayout final
{
public:
using const_iterator = const char *;
using iterator = const_iterator;
using value_type = char;
TensorLayout() = default;
constexpr TensorLayout(const NVCVTensorLayout &layout)
: m_layout(layout)
{
}
explicit TensorLayout(const char *descr)
{
detail::CheckThrow(nvcvTensorLayoutMake(descr, &m_layout));
}
template<class IT, class = detail::IsRandomAccessIterator<IT>>
explicit TensorLayout(IT itbeg, IT itend)
{
detail::CheckThrow(nvcvTensorLayoutMakeRange(&*itbeg, &*itend, &m_layout));
}
constexpr char operator[](int idx) const;
constexpr int rank() const;
int find(char dimLabel, int start = 0) const;
bool startsWith(const TensorLayout &test) const
{
return nvcvTensorLayoutStartsWith(m_layout, test.m_layout) != 0;
}
bool endsWith(const TensorLayout &test) const
{
return nvcvTensorLayoutEndsWith(m_layout, test.m_layout) != 0;
}
TensorLayout subRange(int beg, int end) const
{
TensorLayout out;
detail::CheckThrow(nvcvTensorLayoutMakeSubRange(m_layout, beg, end, &out.m_layout));
return out;
}
TensorLayout first(int n) const
{
TensorLayout out;
detail::CheckThrow(nvcvTensorLayoutMakeFirst(m_layout, n, &out.m_layout));
return out;
}
TensorLayout last(int n) const
{
TensorLayout out;
detail::CheckThrow(nvcvTensorLayoutMakeLast(m_layout, n, &out.m_layout));
return out;
}
friend bool operator==(const TensorLayout &a, const TensorLayout &b);
bool operator!=(const TensorLayout &that) const;
bool operator<(const TensorLayout &that) const;
constexpr const_iterator begin() const;
constexpr const_iterator end() const;
constexpr const_iterator cbegin() const;
constexpr const_iterator cend() const;
constexpr operator const NVCVTensorLayout &() const;
friend std::ostream &operator<<(std::ostream &out, const TensorLayout &that);
// Public so that class is trivial but still the
// implicit ctors do the right thing
NVCVTensorLayout m_layout;
};
// Generates one namespace-scope constant per predefined layout:
// NVCV_DETAIL_DEF_TLAYOUT(X) expands to `constexpr const TensorLayout TENSOR_X{NVCV_TENSOR_X};`.
// TENSOR_NONE is defined explicitly here. The remaining constants (e.g. TENSOR_NCHW,
// used by GetImplicitTensorLayout) presumably come from TensorLayoutDef.inc invoking
// this macro; verify against that file.
// NOTE(review): const namespace-scope variables have internal linkage, so each
// translation unit gets its own copy. C++17 `inline constexpr` would yield a single
// entity, but only if this header may require C++17.
#define NVCV_DETAIL_DEF_TLAYOUT(LAYOUT) constexpr const TensorLayout TENSOR_##LAYOUT{NVCV_TENSOR_##LAYOUT};
NVCV_DETAIL_DEF_TLAYOUT(NONE)
#include "TensorLayoutDef.inc"
// The helper macro is only needed to generate the constants above.
#undef NVCV_DETAIL_DEF_TLAYOUT
/**
 * @brief Default layout for a tensor of the given rank.
 *
 * 1 -> W, 2 -> HW, 3 -> NHW, 4 -> NCHW, 5 -> NCDHW, 6 -> NCFDHW;
 * any other rank yields TENSOR_NONE.
 */
constexpr const TensorLayout &GetImplicitTensorLayout(int rank)
{
    // Kept as a single return expression, which C++11 requires of constexpr functions.
    // clang-format off
    return rank == 1 ? TENSOR_W
         : rank == 2 ? TENSOR_HW
         : rank == 3 ? TENSOR_NHW
         : rank == 4 ? TENSOR_NCHW
         : rank == 5 ? TENSOR_NCDHW
         : rank == 6 ? TENSOR_NCFDHW
         :             TENSOR_NONE;
    // clang-format on
}
// Out-of-range handling is whatever nvcvTensorLayoutGetLabel does; no check is made here.
constexpr char TensorLayout::operator[](int idx) const
{
    return nvcvTensorLayoutGetLabel(this->m_layout, idx);
}
constexpr int TensorLayout::rank() const
{
return nvcvTensorLayoutGetNumDim(m_layout);
}
inline int TensorLayout::find(char dimLabel, int start) const
{
return nvcvTensorLayoutFindDimIndex(m_layout, dimLabel, start);
}
// Exposes the wrapped C descriptor by reference (no copy).
constexpr TensorLayout::operator const NVCVTensorLayout &() const
{
    return this->m_layout;
}
// Two wrappers are equal when the C comparison of their descriptors reports no difference.
inline bool operator==(const TensorLayout &a, const TensorLayout &b)
{
    return nvcvTensorLayoutCompare(a.m_layout, b.m_layout) == 0;
}
// Negation of the non-member operator==.
inline bool TensorLayout::operator!=(const TensorLayout &that) const
{
    const bool same = (*this == that);
    return !same;
}
// Ordering follows the sign of the C comparison of the wrapped descriptors.
inline bool TensorLayout::operator<(const TensorLayout &that) const
{
    return nvcvTensorLayoutCompare(this->m_layout, that.m_layout) < 0;
}
constexpr auto TensorLayout::begin() const -> const_iterator
{
return nvcvTensorLayoutGetName(&m_layout);
}
constexpr inline auto TensorLayout::end() const -> const_iterator
{
return this->begin() + this->rank();
}
constexpr auto TensorLayout::cbegin() const -> const_iterator
{
return this->begin();
}
constexpr auto TensorLayout::cend() const -> const_iterator
{
return this->end();
}
inline std::ostream &operator<<(std::ostream &out, const TensorLayout &that)
{
return out << that.m_layout;
}
// Mixed TensorLayout / NVCVTensorLayout comparisons.
//
// TensorLayout converts implicitly both to and from NVCVTensorLayout, so a mixed
// `==` would otherwise need a user-defined conversion for either the global
// operator==(NVCVTensorLayout, NVCVTensorLayout) or the friend
// operator==(TensorLayout, TensorLayout). Before C++20 that is ambiguous.
// Exact-match overloads are therefore provided for BOTH argument orders; `!=` and
// `<` are provided in both orders too, for symmetry.
inline bool operator==(const TensorLayout &lhs, const NVCVTensorLayout &rhs)
{
    return nvcvTensorLayoutCompare(lhs.m_layout, rhs) == 0;
}

inline bool operator!=(const TensorLayout &lhs, const NVCVTensorLayout &rhs)
{
    return !operator==(lhs, rhs);
}

inline bool operator<(const TensorLayout &lhs, const NVCVTensorLayout &rhs)
{
    return nvcvTensorLayoutCompare(lhs.m_layout, rhs) < 0;
}

// Reversed argument order: `cLayout == tensorLayout` was ambiguous (pre-C++20)
// without this overload.
inline bool operator==(const NVCVTensorLayout &lhs, const TensorLayout &rhs)
{
    return nvcvTensorLayoutCompare(lhs, rhs.m_layout) == 0;
}

inline bool operator!=(const NVCVTensorLayout &lhs, const TensorLayout &rhs)
{
    return !operator==(lhs, rhs);
}

inline bool operator<(const NVCVTensorLayout &lhs, const TensorLayout &rhs)
{
    return nvcvTensorLayoutCompare(lhs, rhs.m_layout) < 0;
}
} // namespace nvcv
#endif // NVCV_TENSOR_LAYOUT_HPP