Program Listing for File TensorImpl.hpp

Return to documentation for file (nvcv_types/include/nvcv/detail/TensorImpl.hpp)

/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NVCV_TENSOR_IMPL_HPP
#define NVCV_TENSOR_IMPL_HPP

#ifndef NVCV_TENSOR_HPP
#    error "You must not include this header directly"
#endif

namespace nvcv {

// Tensor implementation -------------------------------------

/// Builds a TensorShape (extents + layout) describing this tensor.
///
/// The shape is fetched in two passes: first with a null extents buffer,
/// which is used here only to learn the rank, then again into a buffer sized
/// to that rank. Any failing C-API call is turned into an exception by
/// detail::CheckThrow.
inline TensorShape Tensor::shape() const
{
    const NVCVTensorHandle htensor = this->handle();

    // Pass 1: obtain the rank only (no extents buffer supplied).
    int32_t ndim = 0;
    detail::CheckThrow(nvcvTensorGetShape(htensor, &ndim, nullptr));

    NVCVTensorLayout cLayout;
    detail::CheckThrow(nvcvTensorGetLayout(htensor, &cLayout));

    // Pass 2: fill a buffer of exactly `ndim` extents.
    TensorShape::ShapeType extents(ndim);
    detail::CheckThrow(nvcvTensorGetShape(htensor, &ndim, extents.begin()));

    return {extents, cLayout};
}

/// Returns the number of dimensions of this tensor.
///
/// Queries nvcvTensorGetShape without an extents buffer, so only the rank is
/// requested; errors are raised through detail::CheckThrow.
inline int Tensor::rank() const
{
    int32_t numDims = 0;
    const auto status = nvcvTensorGetShape(this->handle(), &numDims, nullptr);
    detail::CheckThrow(status);
    return static_cast<int>(numDims);
}

/// Returns the tensor's layout (e.g. the dimension labels) as a C++ TensorLayout.
///
/// The C layout struct is read via nvcvTensorGetLayout and then converted.
inline TensorLayout Tensor::layout() const
{
    const auto htensor = this->handle();

    NVCVTensorLayout cLayout;
    detail::CheckThrow(nvcvTensorGetLayout(htensor, &cLayout));

    return static_cast<TensorLayout>(cLayout);
}

/// Returns the element data type of this tensor.
///
/// Reads the raw NVCVDataType through the C API and wraps it in DataType.
inline DataType Tensor::dtype() const
{
    const auto htensor = this->handle();

    NVCVDataType cType;
    detail::CheckThrow(nvcvTensorGetDataType(htensor, &cType));

    return DataType{cType};
}

/// Exports a description of this tensor's underlying buffer.
///
/// @return TensorData wrapping the exported NVCVTensorData.
/// @throws Exception(Status::ERROR_INVALID_OPERATION) if the handle is null, or
///         if the exported buffer type is not NVCV_TENSOR_BUFFER_STRIDED_CUDA
///         (the only buffer type this wrapper accepts).
/// @throws whatever detail::CheckThrow raises if nvcvTensorExportData fails.
inline TensorData Tensor::exportData() const
{
    auto h = this->handle();
    if (h == nullptr)
        throw Exception(Status::ERROR_INVALID_OPERATION, "The tensor handle is null.");

    // Value-initialize so the struct never carries indeterminate bytes.
    NVCVTensorData data{};
    // Use the handle that was just validated rather than re-querying it.
    detail::CheckThrow(nvcvTensorExportData(h, &data));

    if (data.bufferType != NVCV_TENSOR_BUFFER_STRIDED_CUDA)
    {
        throw Exception(Status::ERROR_INVALID_OPERATION, "Tensor data cannot be exported, buffer type not supported");
    }

    return TensorData(data);
}

/// Associates an opaque, caller-owned pointer with this tensor.
///
/// @param ptr Value stored via nvcvTensorSetUserPointer; retrievable later
///            with userPointer().
inline void Tensor::setUserPointer(void *ptr)
{
    const auto htensor = this->handle();
    const auto status  = nvcvTensorSetUserPointer(htensor, ptr);
    detail::CheckThrow(status);
}

/// Returns the opaque pointer previously attached with setUserPointer().
///
/// @return The stored user pointer as reported by nvcvTensorGetUserPointer.
/// @throws whatever detail::CheckThrow raises if the C call fails.
inline void *Tensor::userPointer() const
{
    // Initialized so the local never holds an indeterminate value.
    void *ptr = nullptr;
    detail::CheckThrow(nvcvTensorGetUserPointer(this->handle(), &ptr));
    return ptr;
}

/// Creates a new Tensor object viewing this tensor with a different shape.
///
/// @param new_shape Target extents and layout; forwarded to nvcvTensorReshape.
/// @return A Tensor that takes ownership of the handle produced by the C API.
///         NOTE(review): whether the result shares storage with *this is
///         decided by nvcvTensorReshape — confirm against its documentation.
/// @throws whatever detail::CheckThrow raises if the reshape is rejected.
inline Tensor Tensor::reshape(const TensorShape &new_shape)
{
    NVCVTensorHandle out_handle = nullptr;
    // Explicit layout conversion, consistent with CalcRequirements.
    detail::CheckThrow(nvcvTensorReshape(this->handle(), new_shape.rank(), &new_shape.shape()[0],
                                         static_cast<NVCVTensorLayout>(new_shape.layout()), &out_handle));
    // The Tensor constructor takes the handle by rvalue, signalling ownership transfer.
    return Tensor(std::move(out_handle));
}

/// Computes the allocation requirements for a tensor of the given shape and type.
///
/// @param shape    Extents and layout of the tensor.
/// @param dtype    Element data type.
/// @param bufAlign Base-address and row-address alignment constraints.
/// @return Requirements filled in by nvcvTensorCalcRequirements.
inline auto Tensor::CalcRequirements(const TensorShape &shape, DataType dtype, const MemAlignment &bufAlign)
    -> Requirements
{
    const auto cLayout = static_cast<NVCVTensorLayout>(shape.layout());

    Requirements result;
    const auto status = nvcvTensorCalcRequirements(shape.size(), &shape[0], dtype, cLayout, bufAlign.baseAddr(),
                                                   bufAlign.rowAddr(), &result);
    detail::CheckThrow(status);
    return result;
}

/// Computes the allocation requirements for a batch of images stored as one tensor.
///
/// @param numImages Number of images in the batch.
/// @param imgSize   Width/height of each image.
/// @param fmt       Image format of every image.
/// @param bufAlign  Base-address and row-address alignment constraints.
/// @return Requirements filled in by nvcvTensorCalcRequirementsForImages.
inline auto Tensor::CalcRequirements(int numImages, Size2D imgSize, ImageFormat fmt, const MemAlignment &bufAlign)
    -> Requirements
{
    const auto width  = imgSize.w;
    const auto height = imgSize.h;

    Requirements result;
    const auto status = nvcvTensorCalcRequirementsForImages(numImages, width, height, fmt, bufAlign.baseAddr(),
                                                            bufAlign.rowAddr(), &result);
    detail::CheckThrow(status);
    return result;
}

/// Constructs a tensor from precomputed requirements using the given allocator.
///
/// @param reqs  Requirements, typically produced by CalcRequirements.
/// @param alloc Allocator whose handle is passed to nvcvTensorConstruct.
inline Tensor::Tensor(const Requirements &reqs, const Allocator &alloc)
{
    NVCVTensorHandle created;
    const auto status = nvcvTensorConstruct(&reqs, alloc.handle(), &created);
    detail::CheckThrow(status);
    // Hand ownership of the freshly created handle to this object.
    this->reset(std::move(created));
}

/// Constructs a tensor holding a batch of `numImages` images of size `imgSize`
/// and format `fmt`; requirements are computed, then allocation is delegated.
inline Tensor::Tensor(int numImages, Size2D imgSize, ImageFormat fmt, const MemAlignment &bufAlign,
                      const Allocator &alloc)
    : Tensor(Tensor::CalcRequirements(numImages, imgSize, fmt, bufAlign), alloc)
{
    // All work is done by the delegated-to (Requirements, Allocator) constructor.
}

/// Constructs a tensor of the given shape and element type; requirements are
/// computed, then allocation is delegated.
inline Tensor::Tensor(const TensorShape &shape, DataType dtype, const MemAlignment &bufAlign, const Allocator &alloc)
    : Tensor(Tensor::CalcRequirements(shape, dtype, bufAlign), alloc)
{
    // All work is done by the delegated-to (Requirements, Allocator) constructor.
}

// Factory functions --------------------------------------------------

/// Wraps externally allocated tensor data in a Tensor without copying it.
///
/// @param data    Description of the existing buffer.
/// @param cleanup Callback registered with nvcvTensorWrapDataConstruct. It is
///                released from `cleanup` only after construction succeeds;
///                if construction throws, `cleanup` keeps it.
/// @return A Tensor owning the newly created handle.
inline Tensor TensorWrapData(const TensorData &data, TensorDataCleanupCallback &&cleanup)
{
    NVCVTensorHandle wrapped;
    const auto status
        = nvcvTensorWrapDataConstruct(&data.cdata(), cleanup.targetFunc(), cleanup.targetHandle(), &wrapped);
    detail::CheckThrow(status);

    // Construction succeeded: the tensor now owns the callback target.
    cleanup.release();

    return Tensor(std::move(wrapped));
}

/// Creates a Tensor that wraps an existing Image via nvcvTensorWrapImageConstruct.
///
/// @param img Image whose handle is wrapped.
/// @return A Tensor owning the handle returned by the C API.
inline Tensor TensorWrapImage(const Image &img)
{
    const auto himg = img.handle();

    NVCVTensorHandle wrapped;
    detail::CheckThrow(nvcvTensorWrapImageConstruct(himg, &wrapped));

    return Tensor(std::move(wrapped));
}

} // namespace nvcv

#endif // NVCV_TENSOR_IMPL_HPP