Program Listing for File ArrayDataAccess.hpp
↰ Return to documentation for file (nvcv_types/include/nvcv/ArrayDataAccess.hpp)
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef NVCV_ARRAYDATAACESSOR_HPP
#define NVCV_ARRAYDATAACESSOR_HPP
#include "ArrayData.hpp"
#include <type_traits>
namespace nvcv {
namespace detail {
// Portable stand-in for std::invoke_result / std::result_of.
//
// Takes the std::result_of-style argument: a function type `F(Args...)` that
// describes a call of an object of type F with arguments of types Args....
// The result type of that call is exposed as the nested `type`. The same
// spelling works whichever branch below is selected.
#ifdef __cpp_lib_is_invocable
// Primary template: kept for callers that pass a bare callable type, which
// std::invoke_result<F> treats as a zero-argument call.
template<typename TypeExpression>
struct invoke_result : public std::invoke_result<TypeExpression>
{
};

// std::invoke_result expects the callable and its argument types as separate
// template arguments. Unpack the result_of-style function type so that this
// branch accepts the same syntax as the std::result_of fallback below.
template<typename F, typename... Args>
struct invoke_result<F(Args...)> : public std::invoke_result<F, Args...>
{
};
#else // __cpp_lib_is_invocable
template<typename TypeExpression>
struct invoke_result : public std::result_of<TypeExpression>
{
};
#endif // __cpp_lib_is_invocable
/**
 * Common implementation shared by the array data accessors.
 *
 * Holds a copy of an array buffer description and exposes a view of it.
 * The view starts at element index m_idxShift, plus a byte offset m_memShift
 * inside that element, and spans m_length elements.
 *
 * @tparam ArrayDataType Array data type wrapped by the accessor; must derive from ArrayData.
 */
template<typename ArrayDataType,
         typename = typename std::enable_if<std::is_base_of<ArrayData, ArrayDataType>::value>::type>
class ArrayDataAccessImpl
{
    using traits = std::pointer_traits<Byte *>;

public:
    using ArrayType       = ArrayDataType;
    using pointer         = typename traits::pointer;
    using difference_type = typename traits::difference_type;

    ArrayDataAccessImpl() = delete;

    /// Number of elements covered by this view.
    int64_t length() const
    {
        return m_length;
    }

    /// Data type reported by the underlying array data.
    DataType dtype() const
    {
        return m_data.dtype();
    }

    /// Stride (in bytes) between consecutive elements, as reported by the underlying array data.
    int64_t stride() const
    {
        return m_data.stride();
    }

    /// Buffer kind reported by the underlying array data.
    NVCVArrayBufferType kind() const
    {
        return m_data.kind();
    }

    /**
     * Returns the address of the n-th element of the view.
     *
     * @param n Element index relative to the start of the view; must satisfy 0 <= n < length().
     * @return ptr() advanced by n strides.
     * @throws Exception with Status::ERROR_INVALID_ARGUMENT if n is outside the view.
     */
    pointer sampleData(int64_t n) const
    {
        // The bounds are view-relative: m_length is the view's length, not an
        // absolute end index into the underlying array.
        if (n < 0 || n >= m_length)
        {
            throw Exception(Status::ERROR_INVALID_ARGUMENT, "Requested index is out of bounds.");
        }
        auto result = m_data.basePtr();
        result += m_data.stride() * (m_idxShift + n);
        result += m_memShift;
        return result;
    }

    /**
     * Returns the start address of the view: basePtr() + stride * m_idxShift + m_memShift.
     *
     * For an accessor built from a start address, this is exactly that address.
     * The address was already validated against the array's capacity at construction.
     */
    pointer ptr() const
    {
        auto result = m_data.basePtr();
        result += m_data.stride() * m_idxShift;
        result += m_memShift;
        return result;
    }

protected:
    ArrayType m_data;

    /// Creates a view of data.length() elements starting at the array's base pointer.
    explicit ArrayDataAccessImpl(const ArrayType &data)
        : m_data{data}
        , m_length{data.length()}
        , m_idxShift{0}
        , m_memShift{0}
    {
    }

    /**
     * Creates a view of the array starting at a given address.
     *
     * @param data    Array to view.
     * @param _length Number of elements in the view; 0 selects data.length(). Must not be negative.
     * @param _start  First byte of the view; nullptr selects the base pointer. Must lie in
     *                [basePtr, basePtr + stride * capacity).
     * @throws Exception with Status::ERROR_INVALID_ARGUMENT if the length is negative, if the
     *         start address is outside the array, or if the view would extend past capacity().
     */
    ArrayDataAccessImpl(const ArrayType &data, int64_t _length, const pointer _start)
        : ArrayDataAccessImpl{data}
    {
        if (_length < 0)
        {
            throw Exception(Status::ERROR_INVALID_ARGUMENT, "Requested array length is out of bounds.");
        }

        auto length       = _length == 0 ? m_data.length() : _length;
        auto start        = _start == nullptr ? m_data.basePtr() : _start;
        auto memLineRange = start - m_data.basePtr();

        auto itrEnd = m_data.basePtr();
        itrEnd += m_data.stride() * m_data.capacity();

        if (start && m_data.basePtr() <= start && start < itrEnd)
        {
            // Byte offset of the start address inside its element.
            m_memShift = memLineRange % m_data.stride();
        }
        else
        {
            throw Exception(Status::ERROR_INVALID_ARGUMENT, "Requested start address is out of bounds.");
        }

        // Index of the element containing the start address.
        auto itrAt = memLineRange / m_data.stride();
        if ((itrAt + length) <= m_data.capacity())
        {
            m_length   = length;
            m_idxShift = itrAt;
        }
        else
        {
            throw Exception(Status::ERROR_INVALID_ARGUMENT, "Requested array length is out of bounds.");
        }
    }

private:
    int64_t         m_length;   // number of elements in the view
    int64_t         m_idxShift; // index of the view's first element in the underlying array
    difference_type m_memShift; // byte offset of the view start inside that element
};
} // namespace detail
/**
 * Accessor over ArrayData that does not require a more specific ArrayData subtype.
 *
 * Instances can only be obtained through Create(); the constructor is private.
 */
class ArrayDataAccess : public detail::ArrayDataAccessImpl<ArrayData>
{
    using Base = detail::ArrayDataAccessImpl<ArrayData>;

public:
    /// Reports whether `data` is compatible with ArrayData.
    static bool IsCompatible(const ArrayData &data)
    {
        const bool compatible = data.IsCompatible<ArrayData>();
        return compatible;
    }

    /**
     * Builds an accessor over `data`.
     *
     * @param data   Array to access.
     * @param length Number of elements in the view; 0 means data.length().
     * @param start  First byte of the view; nullptr means the array's base pointer.
     * @return The accessor, or NullOpt when `data` cannot be cast to ArrayData.
     *         Out-of-bounds length/start values are reported by an exception
     *         thrown from the base-class constructor.
     */
    static Optional<ArrayDataAccess> Create(const ArrayData &data, int64_t length = 0, const pointer start = nullptr)
    {
        auto converted = data.cast<ArrayData>();
        if (!converted)
        {
            return NullOpt;
        }
        return ArrayDataAccess{converted.value(), length, start};
    }

private:
    ArrayDataAccess(const ArrayData &src, int64_t count, const pointer first)
        : Base{src, count, first}
    {
    }
};
/**
 * Accessor over array data that can be cast to ArrayDataHost.
 *
 * Instances can only be obtained through Create(); the constructor is private.
 */
class ArrayDataAccessHost : public detail::ArrayDataAccessImpl<ArrayDataHost>
{
    using Base = detail::ArrayDataAccessImpl<ArrayDataHost>;

public:
    /// Reports whether `data` is compatible with ArrayDataHost.
    static bool IsCompatible(const ArrayData &data)
    {
        const bool compatible = data.IsCompatible<ArrayDataHost>();
        return compatible;
    }

    /**
     * Builds a host accessor over `data`.
     *
     * @param data   Array to access.
     * @param length Number of elements in the view; 0 means data.length().
     * @param start  First byte of the view; nullptr means the array's base pointer.
     * @return The accessor, or NullOpt when `data` cannot be cast to ArrayDataHost.
     *         Out-of-bounds length/start values are reported by an exception
     *         thrown from the base-class constructor.
     */
    static Optional<ArrayDataAccessHost> Create(const ArrayData &data, int64_t length = 0,
                                                const pointer start = nullptr)
    {
        auto converted = data.cast<ArrayDataHost>();
        if (!converted)
        {
            return NullOpt;
        }
        return ArrayDataAccessHost{converted.value(), length, start};
    }

private:
    ArrayDataAccessHost(const ArrayDataHost &src, int64_t count, const pointer first)
        : Base{src, count, first}
    {
    }
};
/**
 * Accessor over array data that can be cast to ArrayDataHostPinned.
 *
 * Instances can only be obtained through Create(); the constructor is private.
 */
class ArrayDataAccessHostPinned : public detail::ArrayDataAccessImpl<ArrayDataHostPinned>
{
    using Base = detail::ArrayDataAccessImpl<ArrayDataHostPinned>;

public:
    /// Reports whether `data` is compatible with ArrayDataHostPinned.
    static bool IsCompatible(const ArrayData &data)
    {
        const bool compatible = data.IsCompatible<ArrayDataHostPinned>();
        return compatible;
    }

    /**
     * Builds a pinned-host accessor over `data`.
     *
     * @param data   Array to access.
     * @param length Number of elements in the view; 0 means data.length().
     * @param start  First byte of the view; nullptr means the array's base pointer.
     * @return The accessor, or NullOpt when `data` cannot be cast to ArrayDataHostPinned.
     *         Out-of-bounds length/start values are reported by an exception
     *         thrown from the base-class constructor.
     */
    static Optional<ArrayDataAccessHostPinned> Create(const ArrayData &data, int64_t length = 0,
                                                      const pointer start = nullptr)
    {
        auto converted = data.cast<ArrayDataHostPinned>();
        if (!converted)
        {
            return NullOpt;
        }
        return ArrayDataAccessHostPinned{converted.value(), length, start};
    }

private:
    ArrayDataAccessHostPinned(const ArrayDataHostPinned &src, int64_t count, const pointer first)
        : Base{src, count, first}
    {
    }
};
/**
 * Accessor over array data that can be cast to ArrayDataCuda.
 *
 * Instances can only be obtained through Create(); the constructor is private.
 */
class ArrayDataAccessCuda : public detail::ArrayDataAccessImpl<ArrayDataCuda>
{
    using Base = detail::ArrayDataAccessImpl<ArrayDataCuda>;

public:
    /// Reports whether `data` is compatible with ArrayDataCuda.
    static bool IsCompatible(const ArrayData &data)
    {
        const bool compatible = data.IsCompatible<ArrayDataCuda>();
        return compatible;
    }

    /**
     * Builds a CUDA accessor over `data`.
     *
     * @param data   Array to access.
     * @param length Number of elements in the view; 0 means data.length().
     * @param start  First byte of the view; nullptr means the array's base pointer.
     * @return The accessor, or NullOpt when `data` cannot be cast to ArrayDataCuda.
     *         Out-of-bounds length/start values are reported by an exception
     *         thrown from the base-class constructor.
     */
    static Optional<ArrayDataAccessCuda> Create(const ArrayData &data, int64_t length = 0,
                                                const pointer start = nullptr)
    {
        auto converted = data.cast<ArrayDataCuda>();
        if (!converted)
        {
            return NullOpt;
        }
        return ArrayDataAccessCuda{converted.value(), length, start};
    }

private:
    ArrayDataAccessCuda(const ArrayDataCuda &src, int64_t count, const pointer first)
        : Base{src, count, first}
    {
    }
};
} // namespace nvcv
#endif // NVCV_ARRAYDATAACESSOR_HPP