Image Classification

Overview

The Image Classification sample demonstrates an end-to-end GPU-accelerated deep learning inference pipeline using CV-CUDA for preprocessing and TensorRT for model inference. This sample showcases:

  • Loading and preprocessing images entirely on GPU

  • ImageNet normalization with mean and standard deviation

  • Integration with TensorRT for high-performance inference

  • Processing images through a ResNet50 classification model

  • Extracting top-K predictions

Usage

Process an image with the default ResNet50 model:

python3 classification.py -i image.jpg

The sample will:

  1. Download ResNet50 model weights (first run only)

  2. Export to ONNX format (first run only)

  3. Build TensorRT engine (first run only)

  4. Process the image and display top 5 predictions

Command-Line Arguments

Argument    Short Form    Default                             Description
--input     -i            tabby_tiger_cat.jpg                 Input image file path
--output    -o            cvcuda/.cache/cat_classified.jpg    Output file path for predictions
--width     (none)        224                                 Target width for model input
--height    (none)        224                                 Target height for model input
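
The arguments above map naturally onto an argparse parser. The following is a minimal sketch rather than the sample's actual parser: the function name build_parser is hypothetical, and the defaults are simply copied from the table.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented arguments."""
    parser = argparse.ArgumentParser(description="Image classification sample")
    parser.add_argument("--input", "-i", default="tabby_tiger_cat.jpg",
                        help="Input image file path")
    parser.add_argument("--output", "-o",
                        default="cvcuda/.cache/cat_classified.jpg",
                        help="Output file path for predictions")
    parser.add_argument("--width", type=int, default=224,
                        help="Target width for model input")
    parser.add_argument("--height", type=int, default=224,
                        help="Target height for model input")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["-i", "image.jpg", "--width", "256"])
print(args.input, args.width, args.height)
```

Arguments that are not passed keep their documented defaults, so only --input and --width change in this call.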

Implementation Details

The classification pipeline consists of:

  1. Model setup (ONNX export and TensorRT engine building, cached after first run)

  2. Image loading into GPU memory

  3. Preprocessing (resize, normalize, reformat)

  4. TensorRT inference

  5. Extracting and displaying top-K predictions

Code Walkthrough

Model Setup

# 1. Export the ONNX model (downloads the ResNet50 weights if not already cached)
onnx_model_path = get_cache_dir() / f"resnet50_{args.height}x{args.width}.onnx"
if not onnx_model_path.exists():
    import torchvision  # noqa: E402

    resnet50 = torchvision.models.resnet50(
        weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
    )
    export_classifier_onnx(
        resnet50, onnx_model_path, (3, args.height, args.width), verbose=False
    )

# 2. Build the TensorRT engine (if not already built)
trt_model_path = get_cache_dir() / f"resnet50_{args.height}x{args.width}.trtmodel"
if not trt_model_path.exists():
    engine_from_onnx(onnx_model_path, trt_model_path)
model = TRT(trt_model_path)

The model setup process:

  • ONNX Export: Exports PyTorch ResNet50 to ONNX format

  • TensorRT Build: Compiles ONNX to optimized TensorRT engine

  • Caching: Models are cached in the cvcuda/.cache/ directory

  • Automatic: Export and engine build run only when the cached files are missing (normally the first execution)
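
The build-once, reuse-afterwards behaviour can be sketched in isolation. This is an illustration of the pattern only: build_if_missing and fake_export are hypothetical names, and fake_export stands in for the expensive ONNX export or TensorRT build.

```python
import tempfile
from pathlib import Path
from typing import Callable


def build_if_missing(path: Path, build: Callable[[Path], None]) -> bool:
    """Run build(path) only when path does not exist yet.

    Returns True if the artifact was built, False if the cached file was reused.
    """
    if path.exists():
        return False
    path.parent.mkdir(parents=True, exist_ok=True)
    build(path)
    return True


def fake_export(path: Path) -> None:
    # Stand-in for an expensive step such as ONNX export or engine building.
    path.write_bytes(b"placeholder model bytes")


with tempfile.TemporaryDirectory() as tmp:
    # The file name is keyed by input size, as in the sample (224x224 here).
    model_path = Path(tmp) / ".cache" / "resnet50_224x224.onnx"
    first_run = build_if_missing(model_path, fake_export)
    second_run = build_if_missing(model_path, fake_export)

print(first_run, second_run)
```

Because the sample's cache file names include the height and width, running with a different --height or --width produces a different file name and therefore triggers a fresh export and engine build.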

Loading Input Image

# 3. Read the image
input_image: cvcuda.Tensor = read_image(args.input)

The image is loaded:

  • Directly into GPU memory

  • As a CV-CUDA tensor

  • In HWC (Height-Width-Channels) layout

Preprocessing Pipeline

# 4. Preprocess the image
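# 4.1 Allocate the static imagenet mean and std tensors
#     This is only needed once and can be reused for all images
#     Note: despite their names, `scale` / `scale_tensor` hold the per-channel
#     mean (the base that normalize subtracts); `std` / `std_tensor` hold the divisor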
scale: np.ndarray = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(
    (1, 1, 1, 3)
)
scale_tensor: cvcuda.Tensor = cvcuda.Tensor((1, 1, 1, 3), np.float32, "NHWC")
cuda_memcpy_h2d(scale, scale_tensor.cuda())

std: np.ndarray = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(
    (1, 1, 1, 3)
)
std_tensor: cvcuda.Tensor = cvcuda.Tensor((1, 1, 1, 3), np.float32, "NHWC")
cuda_memcpy_h2d(std, std_tensor.cuda())

# 4.2 Add a batch dimension
input_tensor: cvcuda.Tensor = cvcuda.stack([input_image])

# 4.3 Resize the image
resized_tensor: cvcuda.Tensor = cvcuda.resize(
    input_tensor, (1, args.height, args.width, 3), cvcuda.Interp.LINEAR
)

# 4.4 Convert to float32
float_tensor: cvcuda.Tensor = cvcuda.convertto(
    resized_tensor, np.float32, scale=1 / 255
)

# 4.5 Normalize the image using imagenet mean and std
normalized_tensor: cvcuda.Tensor = cvcuda.normalize(
    float_tensor,
    scale_tensor,
    std_tensor,
    cvcuda.NormalizeFlags.SCALE_IS_STDDEV,
)

# 4.6 Convert to NCHW layout
tensor: cvcuda.Tensor = cvcuda.reformat(normalized_tensor, "NCHW")

The preprocessing steps:

  1. Set Up Normalization Parameters: ImageNet mean and standard deviation

  2. Add Batch Dimension: Convert HWC → NHWC using cvcuda.stack()

  3. Resize: Scale to target size (default 224×224)

  4. Convert to Float: Convert uint8 [0,255] → float32 [0.0,1.0]

  5. Normalize: Apply ImageNet normalization: (x - mean) / std

  6. Reformat: Convert NHWC → NCHW for PyTorch-style models
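
For checking the GPU result on the CPU, the steps above can be mirrored in NumPy. This is a reference sketch under stated assumptions, not part of the sample: preprocess_reference is a hypothetical name, the resize step is omitted (the input is assumed to already be at the model's input size), and a synthetic image replaces a real one.

```python
import numpy as np

# ImageNet statistics from the walkthrough, shaped to broadcast over NHWC.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 1, 3)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 1, 3)


def preprocess_reference(image_hwc: np.ndarray) -> np.ndarray:
    """CPU reference for steps 2, 4, 5 and 6 (resize is omitted).

    Expects a uint8 HWC image already at the model's input size.
    """
    batch = image_hwc[np.newaxis, ...]          # HWC -> NHWC
    scaled = batch.astype(np.float32) / 255.0   # uint8 [0,255] -> float32 [0,1]
    normalized = (scaled - MEAN) / STD          # (x - mean) / std per channel
    # NHWC -> NCHW, copied so the result is contiguous in memory.
    return np.ascontiguousarray(normalized.transpose(0, 3, 1, 2))


# Synthetic all-white 224x224 RGB image standing in for a real input.
image = np.full((224, 224, 3), 255, dtype=np.uint8)
tensor = preprocess_reference(image)
print(tensor.shape, tensor.dtype)
```

Comparing such a reference against the GPU output with a small tolerance is one way to catch a swapped mean/std or a wrong layout early.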

ImageNet Normalization

The standard ImageNet normalization uses:

  • Mean: [0.485, 0.456, 0.406] for RGB channels

  • Std: [0.229, 0.224, 0.225] for RGB channels

This normalization is critical for pretrained models and is applied as:

normalized = (image / 255.0 - mean) / std
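
As a worked example of the formula, the sketch below normalizes a pure black and a pure white pixel, which bounds the value range the model sees (roughly -2.12 to 2.64). It also shows that the formula can be folded into a single multiply-add per channel. The function names are illustrative and not part of the sample.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Apply (value / 255 - mean) / std to one 8-bit RGB pixel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


def normalize_pixel_fused(rgb):
    """Same result as one multiply-add per channel: v * a + b,
    with a = 1 / (255 * std) and b = -mean / std."""
    return tuple(v * (1.0 / (255.0 * s)) + (-m / s)
                 for v, m, s in zip(rgb, MEAN, STD))


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(c, 3) for c in black])
print([round(c, 3) for c in white])
```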

Running Inference

# 5. Run the inference
# TRT takes list of tensors and outputs list of tensors
input_tensors: list[cvcuda.Tensor] = [tensor]
output_tensors: list[cvcuda.Tensor] = model(input_tensors)
output_tensor: cvcuda.Tensor = output_tensors[0]

TensorRT inference:

  • Takes CV-CUDA tensors as input via __cuda_array_interface__

  • Returns CV-CUDA tensors as output

  • Runs entirely on GPU

  • Output shape: [1, 1000] (1 batch, 1000 ImageNet classes)
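
The CUDA Array Interface is a plain dictionary attribute, so the shape and dtype of an input can be checked before it is handed to an inference wrapper. The sketch below is illustrative only: FakeDeviceArray is a hypothetical stand-in that lets the check run without a GPU, and check_model_input is not a function from the sample or from CV-CUDA.

```python
class FakeDeviceArray:
    """Stand-in exposing the CUDA Array Interface without needing a GPU."""

    def __init__(self, shape, typestr="<f4"):
        self.__cuda_array_interface__ = {
            "shape": tuple(shape),
            "typestr": typestr,   # "<f4" = little-endian float32
            "data": (0, False),   # (device pointer, read_only); dummy pointer here
            "version": 3,
        }


def check_model_input(array, expected_shape, expected_typestr="<f4"):
    """Raise ValueError if the exported shape or dtype does not match.

    Returns the raw device pointer from the interface dictionary.
    """
    iface = array.__cuda_array_interface__
    if tuple(iface["shape"]) != tuple(expected_shape):
        raise ValueError(f"expected shape {expected_shape}, got {iface['shape']}")
    if iface["typestr"] != expected_typestr:
        raise ValueError(f"expected dtype {expected_typestr}, got {iface['typestr']}")
    return iface["data"][0]


pointer = check_model_input(FakeDeviceArray((1, 3, 224, 224)), (1, 3, 224, 224))
print(pointer)
```

A layout mistake such as passing an NHWC tensor where NCHW is expected shows up here as a shape mismatch rather than as silently wrong predictions.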

Postprocessing Results

# 6. Postprocess the inference results
output: np.ndarray = np.zeros((1, 1000), dtype=np.float32)
cuda_memcpy_d2h(output_tensor.cuda(), output)

# 7. Print the top 5 predictions
indices = np.argsort(output)[0][::-1]
for i, index in enumerate(indices[:5]):
    print(f"  {i+1}. Class {index}: {output[0][index]}")

The postprocessing:

  • Copies results to CPU memory

  • Sorts predictions by confidence

  • Displays top 5 most confident classes
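
The printed values are whatever the model emits. If the engine outputs raw logits, which is typical for a torchvision ResNet50 export but is an assumption here, a softmax turns them into probabilities that sum to 1. The sketch below uses synthetic logits in place of the real (1, 1000) output; softmax and top_k are illustrative helpers, not functions from the sample.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def top_k(scores: np.ndarray, k: int = 5):
    """Return (class_index, score) pairs for the k highest scores of one row."""
    order = np.argsort(scores)[::-1][:k]
    return [(int(i), float(scores[i])) for i in order]


# Synthetic logits standing in for the (1, 1000) model output.
rng = np.random.default_rng(0)
output = rng.normal(size=(1, 1000)).astype(np.float32)
output[0, 282] = 12.0  # make one class clearly dominant for the demo

probs = softmax(output)
for rank, (index, prob) in enumerate(top_k(probs[0]), start=1):
    print(f"  {rank}. Class {index}: {prob:.3f}")
```

Softmax preserves the ordering of the scores, so the top-5 class indices are the same with or without it; only the displayed values change.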

Expected Output

Example Output

Processing image: tabby_tiger_cat.jpg
Top 5 Predictions (placeholder values):
  1. Class 282: 0.615 (Tiger Cat, if using tabby_tiger_cat)
  2. Class 281: 0.254
  3. Class 283: 0.024
  4. Class 284: 0.001
  5. Class 285: 0.000

Note

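Class indices correspond to ImageNet-1K classes. Classes 281-285 are domestic cats (for example, 281 is tabby and 282 is tiger cat), and classes 286-293 are wild cats such as cougar, lynx, and lion.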

Interpreting Results

The output shows:

  • Class Index: ImageNet class ID (0-999)

  • Confidence Score: The model's output value for that class, printed as-is; higher values indicate higher confidence

  • Relative Ranking: Sorted from most to least confident

For ImageNet class names, refer to the ImageNet class list.
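
Once a class list is available, mapping indices to names is a simple lookup. The sketch below assumes a hypothetical text file with one class name per line, where line N (0-based) names class index N; the file name and the placeholder names are invented for the demo and do not come from the sample.

```python
import tempfile
from pathlib import Path


def load_labels(path: Path) -> list[str]:
    """Read one class name per line; line N (0-based) names class index N."""
    lines = path.read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]


def describe(index: int, labels: list[str]) -> str:
    """Return the class name, or a generic fallback for unknown indices."""
    return labels[index] if 0 <= index < len(labels) else f"class {index}"


with tempfile.TemporaryDirectory() as tmp:
    # Tiny placeholder file; a real list would have 1000 lines.
    labels_path = Path(tmp) / "labels.txt"
    labels_path.write_text("name_0\nname_1\nname_2\n", encoding="utf-8")
    labels = load_labels(labels_path)

print(describe(1, labels), "|", describe(282, labels))
```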

CV-CUDA Operators Used

Operator              Purpose
cvcuda.stack()        Add batch dimension to single image
cvcuda.resize()       Resize image to model input size (configurable, default 224×224)
cvcuda.convertto()    Convert uint8 to float32 and scale to [0,1]
cvcuda.normalize()    Apply ImageNet mean and standard deviation normalization
cvcuda.reformat()     Convert NHWC layout to NCHW for model input
