Template Class TensorWrap
Defined in File TensorWrap.hpp
Class Documentation
- template<typename T, int... Strides> class TensorWrap
TensorWrap class is a non-owning wrap of an N-D tensor used for easy access to its elements on a CUDA device.
TensorWrap is a wrapper of a multi-dimensional tensor that can have one or more of its N dimension strides, or pitches, defined either at compile time or at run time. Each pitch in Strides is a compile-time template parameter giving an offset in bytes. The pitches are applied from the first (slowest changing) dimension to the last (fastest changing) dimension of the tensor, in that order. Each dimension with a run-time pitch is specified as -1 in the Strides template parameter.
Template arguments:
T – type of the values inside the tensor
Strides – sequence of compile- or run-time pitches (-1 indicates run-time), made up of:
Y – number of compile-time pitches
X – number of run-time pitches
N – number of dimensions, where N = X + Y
For example, the code below defines a wrap for an NHWC 4D tensor. Each sample image in N has a run-time image pitch (the first -1 in the template arguments) and each row in H has a run-time row pitch (the second -1). A pixel in W has a compile-time constant pitch equal to the size of the pixel type, and a channel in C has a compile-time constant pitch equal to the size of the channel type.
using DataType = ...;
using ChannelType = BaseType<DataType>;
using TensorWrap = TensorWrap<ChannelType, -1, -1, sizeof(DataType), sizeof(ChannelType)>;
std::byte *imageData = ...;
int imgStride = ...;
int rowStride = ...;
TensorWrap tensorWrap(imageData, imgStride, rowStride);
// Elements may be accessed via operator[] using an int4 argument.
// They can also be accessed via pointer using the ptr method with up to 4 integer arguments.
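The address arithmetic implied by this example can be sketched on the host in plain C++. This is only an illustration of how byte pitches combine, not the library's implementation: `Pixel3`, `nhwcOffset` and `readChannel` are hypothetical names, and an 8-bit, 3-channel pixel type is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a 3-channel 8-bit pixel type (DataType in the example).
struct Pixel3
{
    std::uint8_t x, y, z;
};

// Byte offset of element (n, h, w, c) in an NHWC tensor.
// imgStride and rowStride stand for the two run-time (-1) pitches;
// sizeof(Pixel3) and sizeof(std::uint8_t) stand for the compile-time pitches.
inline std::ptrdiff_t nhwcOffset(int n, int h, int w, int c, int imgStride, int rowStride)
{
    return static_cast<std::ptrdiff_t>(n) * imgStride
         + static_cast<std::ptrdiff_t>(h) * rowStride
         + static_cast<std::ptrdiff_t>(w) * static_cast<std::ptrdiff_t>(sizeof(Pixel3))
         + static_cast<std::ptrdiff_t>(c) * static_cast<std::ptrdiff_t>(sizeof(std::uint8_t));
}

// Reads one channel value through the computed byte offset.
inline std::uint8_t readChannel(const std::uint8_t *base, int n, int h, int w, int c,
                                int imgStride, int rowStride)
{
    assert(base != nullptr);
    return *(base + nhwcOffset(n, h, w, c, imgStride, rowStride));
}
```

With a 64-byte image pitch and a 32-byte row pitch, element (n=1, h=1, w=2, c=1) lands at byte 64 + 32 + 6 + 1 = 103.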
Tensor wrapper class specialized for non-constant value type.
- Template Parameters:
T – Type (it can be const) of each element inside the tensor wrapper.
Strides – Pitch in bytes of each dimension, from first to last; use -1 for a pitch given at run time.
T – Type (non-const) of each element inside the tensor wrapper.
Strides – Pitch in bytes of each dimension, from first to last; use -1 for a pitch given at run time.
Public Functions
- TensorWrap() = default
- template<typename DataType, typename... Args> inline explicit __host__ __device__ TensorWrap (DataType *data, Args... strides)
Constructs a TensorWrap by wrapping a data pointer argument.
- Parameters:
data – [in] Pointer to the data that will be wrapped.
strides0..N – [in] Each run-time pitch in bytes from first to last dimension.
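One way such a variadic constructor could store its run-time pitches is sketched below in host-only C++17. `MiniWrap` and `runtimeStride` are hypothetical names used for illustration. The sketch assumes exactly one constructor argument per -1 entry in `Strides`, which the description above suggests but does not state outright.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical host-only sketch; MiniWrap is not part of the library.
template<typename T, int... Strides>
class MiniWrap
{
public:
    // Number of pitches marked as run-time (-1) in the template arguments.
    static constexpr int kVariableStrides = ((Strides == -1 ? 1 : 0) + ... + 0);

    // Assumption: one argument is passed per run-time pitch, in dimension order.
    template<typename DataType, typename... Args>
    explicit MiniWrap(DataType *data, Args... strides)
        : m_data(reinterpret_cast<const std::byte *>(data))
        , m_strides{static_cast<int>(strides)...}
    {
        static_assert(sizeof...(Args) == kVariableStrides, "one stride per -1 entry");
    }

    // Returns the i-th stored run-time pitch in bytes.
    int runtimeStride(int i) const
    {
        assert(i >= 0 && i < kVariableStrides);
        return m_strides[i];
    }

    const std::byte *data() const { return m_data; }

private:
    const std::byte *m_data;
    // At least one slot so the array is never zero-sized.
    int m_strides[kVariableStrides > 0 ? kVariableStrides : 1];
};
```

Only the run-time pitches need storage here; the compile-time ones stay in the type.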
- template<typename DataType, typename StrideType> inline explicit __host__ __device__ TensorWrap (DataType *data, StrideType *strides)
Constructs a TensorWrap by wrapping a const data pointer argument and copying the dynamic strides from a given buffer.
- Parameters:
data – [in] Pointer to the data that will be wrapped.
strides – [in] Pointer to stride data
- inline __host__ TensorWrap(const ImageDataStridedCuda &image)
Constructs a TensorWrap by wrapping an image argument.
- Parameters:
image – [in] Reference to the image that will be wrapped.
- inline __host__ TensorWrap(const TensorDataStridedCuda &tensor)
Constructs a TensorWrap by wrapping a tensor argument.
- Parameters:
tensor – [in] Reference to the tensor that will be wrapped.
- template<typename DimType, class = Require<std::is_same_v<int, BaseType<DimType>>>> inline __host__ __device__ T & operator[] (DimType c) const
Subscript operator for read-and-write access.
- Parameters:
c – [in] N-D coordinate (from last to first dimension) to be accessed.
- Returns:
Accessed reference.
- template<typename... Args> inline __host__ __device__ T * ptr (Args... c) const
Get a read-and-write proxy (as pointer) at the Dth dimension.
- Parameters:
c0..D – [in] Each coordinate from first to last dimension.
- Returns:
The pointer to the beginning of the Dth dimension.
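Passing fewer coordinates than dimensions yields a pointer to the start of a sub-tensor, for example the start of a row. The host-only sketch below illustrates that idea under stated assumptions: `ptrAt` is a hypothetical free function, and all pitches are passed in one array rather than being split between template arguments and stored members.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: applies one pitch per supplied coordinate, first dimension first.
template<std::size_t N, typename... Coords>
std::uint8_t *ptrAt(std::uint8_t *base, const int (&pitches)[N], Coords... c)
{
    static_assert(sizeof...(Coords) <= N, "more coordinates than dimensions");
    assert(base != nullptr);

    // The trailing 0 keeps the array non-empty when no coordinate is given.
    const int coords[sizeof...(Coords) + 1] = {static_cast<int>(c)..., 0};

    std::ptrdiff_t offset = 0;
    for (std::size_t i = 0; i < sizeof...(Coords); ++i)
    {
        offset += static_cast<std::ptrdiff_t>(coords[i]) * pitches[i];
    }
    return base + offset;
}
```

With pitches {64, 32, 3, 1}, `ptrAt(buf, pitches, 1, 1)` points at the start of row 1 of image 1 (byte 96), while supplying all four coordinates addresses a single channel value.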
Public Static Attributes
- static constexpr int kConstantStrides
- static constexpr int kNumDimensions
- static constexpr int kVariableStrides
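These three attributes are not described here, but the template-argument notes (N = X + Y) suggest how they relate. The C++17 sketch below derives the three counts from a `Strides` pack. `StrideCounts` is a hypothetical traits type, and the exact definitions in the library may differ.

```cpp
// Hypothetical traits type; StrideCounts is not part of the library.
template<int... Strides>
struct StrideCounts
{
    // N: one pitch per dimension.
    static constexpr int kNumDimensions = static_cast<int>(sizeof...(Strides));
    // X: pitches marked -1 are supplied at run time.
    static constexpr int kVariableStrides = ((Strides == -1 ? 1 : 0) + ... + 0);
    // Y: the remaining pitches are compile-time constants.
    static constexpr int kConstantStrides = kNumDimensions - kVariableStrides;
};

// NHWC layout with two run-time and two compile-time pitches.
using NhwcCounts = StrideCounts<-1, -1, 3, 1>;
static_assert(NhwcCounts::kNumDimensions == 4);
static_assert(NhwcCounts::kVariableStrides == 2);
static_assert(NhwcCounts::kConstantStrides == 2);
```

Because the counts are compile-time constants, a wrap can size its run-time stride storage from the number of -1 entries alone.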
Protected Functions
- template<typename... Args> inline __host__ __device__ T * doGetPtr (Args... c) const