Program Listing for File TensorDataAccess.hpp
↰ Return to documentation for file (nvcv_types/include/nvcv/TensorDataAccess.hpp)
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef NVCV_TENSORDATAACESSOR_HPP
#define NVCV_TENSORDATAACESSOR_HPP
#include "TensorData.hpp"
#include "TensorShapeInfo.hpp"
#include <cassert>
#include <cstddef>
#include <cstdint>
namespace nvcv {
// Design is similar to TensorShapeInfo hierarchy
namespace detail {
/**
 * Shared implementation of the strided tensor data accessors.
 *
 * Stores a copy of a TensorDataStrided together with a shape-info object and
 * answers sample-level queries (count, stride, data pointer) by combining both.
 *
 * @tparam ShapeInfo  Shape descriptor type. This class calls numSamples(),
 *                    isImage() and infoLayout() on it.
 * @tparam LayoutInfo Type returned by ShapeInfo::infoLayout(). This class calls
 *                    idxSample() on it. Defaults to ShapeInfo::LayoutInfo.
 */
template<typename ShapeInfo, typename LayoutInfo = typename ShapeInfo::LayoutInfo>
class TensorDataAccessStridedImpl
{
public:
    /// Builds an accessor over a copy of @p tdata, described by @p infoShape.
    TensorDataAccessStridedImpl(const TensorDataStrided &tdata, const ShapeInfo &infoShape)
        : m_tdata(tdata)
        , m_infoShape(infoShape)
    {
    }

    /// Number of samples, as reported by the shape info.
    TensorShape::DimType numSamples() const
    {
        return m_infoShape.numSamples();
    }

    /// Data type of the wrapped tensor data.
    DataType dtype() const
    {
        return m_tdata.dtype();
    }

    /// Layout of the wrapped tensor data.
    const TensorLayout &layout() const
    {
        return m_tdata.layout();
    }

    /// Shape of the wrapped tensor data.
    const TensorShape &shape() const
    {
        return m_tdata.shape();
    }

    /**
     * Stride of the sample dimension.
     *
     * @return m_tdata.stride(idx) where idx is infoLayout().idxSample(), or 0
     *         when that index is negative (presumably: the layout has no
     *         sample dimension - confirm against the layout-info type).
     */
    int64_t sampleStride() const
    {
        const int idx = this->infoLayout().idxSample();
        if (idx < 0)
        {
            return 0;
        }
        return m_tdata.stride(idx);
    }

    /// Pointer to sample @p n, computed from the tensor's own base pointer.
    Byte *sampleData(int n) const
    {
        return sampleData(n, m_tdata.basePtr());
    }

    /**
     * Pointer to sample @p n relative to a caller-supplied base pointer.
     *
     * @param n    Sample index; asserted to be in [0, numSamples()).
     * @param base Pointer the sample offset is added to.
     * @return @p base advanced by sampleStride() * n.
     */
    Byte *sampleData(int n, Byte *base) const
    {
        assert(0 <= n && n < this->numSamples());
        return base + this->sampleStride() * n;
    }

    /// Forwards to the shape info's isImage().
    bool isImage() const
    {
        return m_infoShape.isImage();
    }

    /// Shape descriptor this accessor was built with.
    const ShapeInfo &infoShape() const
    {
        return m_infoShape;
    }

    /// Layout descriptor obtained from the shape info.
    const LayoutInfo &infoLayout() const
    {
        return m_infoShape.infoLayout();
    }

protected:
    TensorDataStrided m_tdata;

    /**
     * Builds an accessor that copies the tensor data of @p that but uses
     * @p infoShape as its shape descriptor.
     *
     * The parameter is typed as ShapeInfo (the template argument) rather than
     * the concrete TensorShapeInfo so that it matches m_infoShape for every
     * instantiation; derived classes forward their own ShapeInfo here.
     */
    TensorDataAccessStridedImpl(const TensorDataAccessStridedImpl &that, const ShapeInfo &infoShape)
        : m_tdata(that.m_tdata)
        , m_infoShape(infoShape)
    {
    }

private:
    ShapeInfo m_infoShape;
};
/**
 * Strided accessor implementation that adds image-oriented queries
 * (columns, rows, channels and their strides) on top of the sample-level base.
 *
 * Every stride getter asks the layout info for the index of the relevant
 * dimension and returns that dimension's stride from the tensor data, or 0
 * when the index is negative.
 *
 * @tparam ShapeInfo Shape descriptor type. This class calls numCols(),
 *                   numRows(), numChannels() and size() on it.
 */
template<typename ShapeInfo>
class TensorDataAccessStridedImageImpl : public TensorDataAccessStridedImpl<ShapeInfo>
{
    using Base = detail::TensorDataAccessStridedImpl<ShapeInfo>;

public:
    /// Builds an accessor over a copy of @p tdata, described by @p infoShape.
    TensorDataAccessStridedImageImpl(const TensorDataStrided &tdata, const ShapeInfo &infoShape)
        : Base(tdata, infoShape)
    {
    }

    /// Number of columns, as reported by the shape info.
    int32_t numCols() const
    {
        return this->infoShape().numCols();
    }

    /// Number of rows, as reported by the shape info.
    int32_t numRows() const
    {
        return this->infoShape().numRows();
    }

    /// Number of channels, as reported by the shape info.
    int32_t numChannels() const
    {
        return this->infoShape().numChannels();
    }

    /// Image size, as reported by the shape info.
    Size2D size() const
    {
        return this->infoShape().size();
    }

    /// Stride of the dimension at idxChannel(), or 0 if that index is negative.
    int64_t chStride() const
    {
        const int dim = this->infoLayout().idxChannel();
        if (dim < 0)
        {
            return 0;
        }
        return this->m_tdata.stride(dim);
    }

    /// Stride of the dimension at idxWidth(), or 0 if that index is negative.
    int64_t colStride() const
    {
        const int dim = this->infoLayout().idxWidth();
        if (dim < 0)
        {
            return 0;
        }
        return this->m_tdata.stride(dim);
    }

    /// Stride of the dimension at idxHeight(), or 0 if that index is negative.
    int64_t rowStride() const
    {
        const int dim = this->infoLayout().idxHeight();
        if (dim < 0)
        {
            return 0;
        }
        return this->m_tdata.stride(dim);
    }

    /// Stride of the dimension at idxDepth(), or 0 if that index is negative.
    int64_t depthStride() const
    {
        const int dim = this->infoLayout().idxDepth();
        if (dim < 0)
        {
            return 0;
        }
        return this->m_tdata.stride(dim);
    }

    /// Pointer to row @p y, computed from the tensor's own base pointer.
    Byte *rowData(int y) const
    {
        Byte *const origin = this->m_tdata.basePtr();
        return rowData(y, origin);
    }

    /**
     * Pointer to row @p y relative to a caller-supplied base pointer.
     *
     * @param y    Row index; asserted to be in [0, numRows()).
     * @param base Pointer the row offset is added to.
     * @return @p base advanced by rowStride() * y.
     */
    Byte *rowData(int y, Byte *base) const
    {
        assert(y >= 0 && y < this->numRows());
        const int64_t pitch = this->rowStride();
        return base + pitch * y;
    }

    /// Pointer to channel @p c, computed from the tensor's own base pointer.
    Byte *chData(int c) const
    {
        Byte *const origin = this->m_tdata.basePtr();
        return chData(c, origin);
    }

    /**
     * Pointer to channel @p c relative to a caller-supplied base pointer.
     *
     * @param c    Channel index; asserted to be in [0, numChannels()).
     * @param base Pointer the channel offset is added to.
     * @return @p base advanced by chStride() * c.
     */
    Byte *chData(int c, Byte *base) const
    {
        assert(c >= 0 && c < this->numChannels());
        const int64_t pitch = this->chStride();
        return base + pitch * c;
    }

protected:
    /// Copies the tensor data of @p that while substituting @p infoShape as the shape descriptor.
    TensorDataAccessStridedImageImpl(const TensorDataAccessStridedImageImpl &that, const ShapeInfo &infoShape)
        : Base(that, infoShape)
    {
    }
};
/**
 * Strided image accessor implementation that adds plane-level queries.
 *
 * @tparam ShapeInfo Shape descriptor type. This class calls numPlanes() on it
 *                   and isChannelFirst() / idxChannel() on its layout info.
 */
template<typename ShapeInfo>
class TensorDataAccessStridedImagePlanarImpl : public TensorDataAccessStridedImageImpl<ShapeInfo>
{
    using Base = TensorDataAccessStridedImageImpl<ShapeInfo>;

public:
    /// Builds an accessor over a copy of @p tdata, described by @p infoShape.
    TensorDataAccessStridedImagePlanarImpl(const TensorDataStrided &tdata, const ShapeInfo &infoShape)
        : Base(tdata, infoShape)
    {
    }

    /// Number of planes, as reported by the shape info.
    int32_t numPlanes() const
    {
        return this->infoShape().numPlanes();
    }

    /**
     * Stride between consecutive planes.
     *
     * @return The stride of the channel dimension when the layout info reports
     *         isChannelFirst(); 0 otherwise.
     */
    int64_t planeStride() const
    {
        if (!this->infoLayout().isChannelFirst())
        {
            return 0;
        }

        // A channel-first layout is expected to carry a valid channel index.
        const int channelDim = this->infoLayout().idxChannel();
        assert(channelDim >= 0);
        return this->m_tdata.stride(channelDim);
    }

    /// Pointer to plane @p p, computed from the tensor's own base pointer.
    Byte *planeData(int p) const
    {
        Byte *const origin = this->m_tdata.basePtr();
        return planeData(p, origin);
    }

    /**
     * Pointer to plane @p p relative to a caller-supplied base pointer.
     *
     * @param p    Plane index; asserted to be in [0, numPlanes()).
     * @param base Pointer the plane offset is added to.
     * @return @p base advanced by planeStride() * p.
     */
    Byte *planeData(int p, Byte *base) const
    {
        assert(p >= 0 && p < this->numPlanes());
        const int64_t pitch = this->planeStride();
        return base + pitch * p;
    }
};
} // namespace detail
/**
 * Accessor for strided tensor data using the generic TensorShapeInfo.
 *
 * Instances are obtained through Create(); the constructor is private.
 */
class TensorDataAccessStrided : public detail::TensorDataAccessStridedImpl<TensorShapeInfo>
{
    using Base = detail::TensorDataAccessStridedImpl<TensorShapeInfo>;

public:
    /// Tells whether @p data reports itself compatible with TensorDataStrided.
    static bool IsCompatible(const TensorData &data)
    {
        return data.IsCompatible<TensorDataStrided>();
    }

    /**
     * Tries to build an accessor for @p data.
     *
     * @return An accessor when @p data can be cast to TensorDataStrided,
     *         NullOpt otherwise.
     */
    static Optional<TensorDataAccessStrided> Create(const TensorData &data)
    {
        Optional<TensorDataStrided> strided = data.cast<TensorDataStrided>();
        if (!strided)
        {
            return NullOpt;
        }
        return TensorDataAccessStrided(strided.value());
    }

private:
    // The result of TensorShapeInfo::Create is dereferenced without a check.
    TensorDataAccessStrided(const TensorDataStrided &data)
        : Base(data, *TensorShapeInfo::Create(data.shape()))
    {
    }
};
/**
 * Accessor for strided tensor data whose shape is compatible with
 * TensorShapeInfoImage.
 *
 * Instances are obtained through Create(); the constructor is protected.
 */
class TensorDataAccessStridedImage : public detail::TensorDataAccessStridedImageImpl<TensorShapeInfoImage>
{
    using Base = detail::TensorDataAccessStridedImageImpl<TensorShapeInfoImage>;

public:
    /**
     * Tells whether @p data can be wrapped by this accessor: it must satisfy
     * TensorDataAccessStrided::IsCompatible and its shape must satisfy
     * TensorShapeInfoImage::IsCompatible.
     */
    static bool IsCompatible(const TensorData &data)
    {
        if (!TensorDataAccessStrided::IsCompatible(data))
        {
            return false;
        }
        return TensorShapeInfoImage::IsCompatible(data.shape());
    }

    /**
     * Tries to build an accessor for @p data.
     *
     * @return An accessor when IsCompatible(data) holds, NullOpt otherwise.
     */
    static Optional<TensorDataAccessStridedImage> Create(const TensorData &data)
    {
        if (!IsCompatible(data))
        {
            return NullOpt;
        }
        return TensorDataAccessStridedImage(data.cast<TensorDataStrided>().value());
    }

protected:
    // The result of TensorShapeInfoImage::Create is dereferenced without a
    // check; Create() above only reaches this after IsCompatible() succeeded.
    TensorDataAccessStridedImage(const TensorDataStrided &data)
        : Base(data, *TensorShapeInfoImage::Create(data.shape()))
    {
    }
};
/**
 * Accessor for strided tensor data whose shape is compatible with
 * TensorShapeInfoImagePlanar.
 *
 * Instances are obtained through Create(); the constructor is protected.
 */
class TensorDataAccessStridedImagePlanar
    : public detail::TensorDataAccessStridedImagePlanarImpl<TensorShapeInfoImagePlanar>
{
    using Base = detail::TensorDataAccessStridedImagePlanarImpl<TensorShapeInfoImagePlanar>;

public:
    /**
     * Tells whether @p data can be wrapped by this accessor: it must satisfy
     * TensorDataAccessStridedImage::IsCompatible and its shape must satisfy
     * TensorShapeInfoImagePlanar::IsCompatible.
     */
    static bool IsCompatible(const TensorData &data)
    {
        if (!TensorDataAccessStridedImage::IsCompatible(data))
        {
            return false;
        }
        return TensorShapeInfoImagePlanar::IsCompatible(data.shape());
    }

    /**
     * Tries to build an accessor for @p data.
     *
     * @return An accessor when IsCompatible(data) holds, NullOpt otherwise.
     */
    static Optional<TensorDataAccessStridedImagePlanar> Create(const TensorData &data)
    {
        if (!IsCompatible(data))
        {
            return NullOpt;
        }
        return TensorDataAccessStridedImagePlanar(data.cast<TensorDataStrided>().value());
    }

protected:
    // The result of TensorShapeInfoImagePlanar::Create is dereferenced without
    // a check; Create() above only reaches this after IsCompatible() succeeded.
    TensorDataAccessStridedImagePlanar(const TensorDataStrided &data)
        : Base(data, *TensorShapeInfoImagePlanar::Create(data.shape()))
    {
    }
};
} // namespace nvcv
#endif // NVCV_TENSORDATAACESSOR_HPP