Program Listing for File TensorDataImpl.hpp

Return to documentation for file (nvcv_types/include/nvcv/detail/TensorDataImpl.hpp)

/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NVCV_TENSORDATA_IMPL_HPP
#define NVCV_TENSORDATA_IMPL_HPP

#ifndef NVCV_TENSORDATA_HPP
#    error "You must not include this header directly"
#endif

#include <algorithm>

namespace nvcv {

// Implementation - TensorData -----------------------------

inline TensorData::TensorData(const NVCVTensorData &data)
    : m_data(data)
{
}

inline int TensorData::rank() const
{
    return this->cdata().rank;
}

// Returns the tensor shape (dims, rank and layout) as a TensorShape.
// The object is built lazily from cdata() on first use and cached in
// m_cacheShape; data() clears that cache because callers may edit the fields.
inline const TensorShape &TensorData::shape() const &
{
    // Fast path: shape already materialized.
    if (m_cacheShape)
    {
        return *m_cacheShape;
    }

    const NVCVTensorData &src = this->cdata();
    m_cacheShape.emplace(src.shape, src.rank, src.layout);
    return *m_cacheShape;
}

// Returns a reference to the extent of dimension `d`.
// Throws Exception(ERROR_INVALID_ARGUMENT) unless 0 <= d < rank().
inline const TensorShape::DimType &TensorData::shape(int d) const &
{
    const NVCVTensorData &tdata = this->cdata();

    const bool inRange = (d >= 0) && (d < tdata.rank);
    if (!inRange)
    {
        throw Exception(Status::ERROR_INVALID_ARGUMENT, "Index of shape dimension %d is out of bounds [0;%d]", d,
                        tdata.rank - 1);
    }

    return tdata.shape[d];
}

// Layout of the tensor, taken from the (lazily cached) TensorShape.
inline const TensorLayout &TensorData::layout() const &
{
    const TensorShape &cachedShape = this->shape();
    return cachedShape.layout();
}

// Element data type, wrapping the descriptor's C dtype in a DataType.
inline DataType TensorData::dtype() const
{
    return DataType{this->cdata().dtype};
}

// Read-only access to the wrapped C descriptor. Does not touch the shape cache.
inline const NVCVTensorData &TensorData::cdata() const &
{
    const NVCVTensorData &stored = this->m_data;
    return stored;
}

// Mutable access to the wrapped C descriptor.
// Because the caller may rewrite shape/rank/layout through the returned
// reference, the cached TensorShape is discarded first so shape() rebuilds it.
inline NVCVTensorData &TensorData::data() &
{
    this->m_cacheShape.reset();

    NVCVTensorData &mutableDescriptor = this->m_data;
    return mutableDescriptor;
}

// Reports whether this descriptor's buffer type is acceptable for `Derived`,
// as decided by Derived::IsCompatibleKind(). Used by cast() before it builds
// a Derived view.
// Marked `inline` for consistency with every other out-of-class definition in
// this header.
template<typename Derived>
inline bool TensorData::IsCompatible() const
{
    return Derived::IsCompatibleKind(m_data.bufferType);
}

// Attempts to view this TensorData as `Derived`.
// Returns NullOpt when the buffer type is not compatible with Derived;
// otherwise returns a Derived constructed from a copy of the descriptor.
// Derived must inherit from TensorData and add no data members, which is
// enforced at compile time below.
template<typename Derived>
inline Optional<Derived> TensorData::cast() const
{
    static_assert(std::is_base_of<TensorData, Derived>::value, "Cannot cast TensorData to an unrelated type");
    static_assert(sizeof(Derived) == sizeof(TensorData), "The derived type must not add new data members.");

    if (!IsCompatible<Derived>())
    {
        return NullOpt;
    }

    return Derived(m_data);
}

// Implementation - TensorDataStrided ----------------------------

// Base address of the strided buffer, reinterpreted as a Byte pointer.
inline Byte *TensorDataStrided::basePtr() const
{
    return reinterpret_cast<Byte *>(this->cdata().buffer.strided.basePtr);
}

// Returns a reference to the stride of dimension `d` from the strided buffer.
// Throws Exception(ERROR_INVALID_ARGUMENT) unless 0 <= d < rank().
// The error message now says "stride" (it previously said "pitch"), matching
// the function name and the `strides` array being indexed.
inline const int64_t &TensorDataStrided::stride(int d) const
{
    const NVCVTensorData &data = this->cdata();
    if (d < 0 || d >= data.rank)
    {
        throw Exception(Status::ERROR_INVALID_ARGUMENT, "Index of stride %d is out of bounds [0;%d]", d,
                        data.rank - 1);
    }

    return data.buffer.strided.strides[d];
}

// TensorDataStridedCuda implementation -----------------------

// Builds a descriptor with buffer type NVCV_TENSOR_BUFFER_STRIDED_CUDA from a
// shape (dims, rank, layout), an element type and a strided buffer description.
inline TensorDataStridedCuda::TensorDataStridedCuda(const TensorShape &tshape, const DataType &dtype,
                                                    const Buffer &buffer)
{
    // data() also drops any cached TensorShape, so shape() is rebuilt from the fields written below.
    NVCVTensorData &data = this->data();

    // Start from a value-initialized descriptor. Fields not written below (shape entries past the rank,
    // unused bytes of the buffer union) are then zero instead of whatever base construction left there.
    data = NVCVTensorData{};

    // Bind the dimension container once so begin()/end() are guaranteed to refer to the same object,
    // even if TensorShape::shape() were to return by value.
    const auto &dims = tshape.shape();
    std::copy(dims.begin(), dims.end(), data.shape);
    data.rank   = tshape.rank();
    data.dtype  = dtype;
    data.layout = tshape.layout();

    data.bufferType     = NVCV_TENSOR_BUFFER_STRIDED_CUDA;
    data.buffer.strided = buffer;
}

// Wraps an existing C descriptor.
// Throws Exception(ERROR_INVALID_ARGUMENT) when IsCompatibleKind() rejects the
// descriptor's buffer type.
inline TensorDataStridedCuda::TensorDataStridedCuda(const NVCVTensorData &data)
    : TensorDataStrided(data)
{
    const bool compatible = IsCompatibleKind(data.bufferType);
    if (compatible)
    {
        return;
    }

    throw Exception(Status::ERROR_INVALID_ARGUMENT, "Incompatible buffer type.");
}

} // namespace nvcv

#endif // NVCV_TENSORDATA_IMPL_HPP