Object Detection

Overview

The Object Detection sample demonstrates GPU-accelerated object detection using CV-CUDA for preprocessing and TensorRT with EfficientNMS for inference. This sample showcases:

  • End-to-end object detection pipeline on GPU

  • RetinaNet model with ResNet50-FPN backbone

  • EfficientNMS TensorRT plugin for fast post-processing

  • Bounding box drawing on detected objects

  • Integration between CV-CUDA, TensorRT, and visualization

Usage

Detect objects in an image:

python3 object_detection.py -i image.jpg

The sample will:

  1. Download RetinaNet model weights (first run only)

  2. Export model with EfficientNMS to ONNX (first run only)

  3. Build TensorRT engine (first run only)

  4. Detect objects and draw bounding boxes

  5. Save output as cvcuda/.cache/cat_detections.jpg

Specify custom output path:

python3 object_detection.py -i street.jpg -o detections.jpg

Command-Line Arguments

Argument    Short Form    Default                            Description
--input     -i            tabby_tiger_cat.jpg                Input image file path
--output    -o            cvcuda/.cache/cat_detections.jpg   Output image file path with drawn boxes
--width                   224                                Target width for model input
--height                  224                                Target height for model input
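
For orientation, the corresponding argument parsing could be declared roughly as follows. This is a sketch; the exact output default and help strings are illustrative, not the sample's actual code:

import argparse
from pathlib import Path

parser = argparse.ArgumentParser(description="CV-CUDA + TensorRT object detection sample")
parser.add_argument("-i", "--input", type=Path, default=Path("tabby_tiger_cat.jpg"),
                    help="Input image file path")
parser.add_argument("-o", "--output", type=Path, default=None,
                    help="Output image path; resolved into the cache directory when omitted")
parser.add_argument("--width", type=int, default=224, help="Target width for model input")
parser.add_argument("--height", type=int, default=224, help="Target height for model input")
args = parser.parse_args()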

Implementation Details

The object detection pipeline consists of:

  1. Model setup (RetinaNet+EfficientNMS export and TensorRT engine building)

  2. Image loading into GPU

  3. Preprocessing (resize and normalize)

  4. TensorRT detection

  5. Drawing bounding boxes

  6. Saving annotated image

Code Walkthrough

Model Setup and Export

# 1. Export the ONNX model (RetinaNet backbone + head + EfficientNMS plugin)
onnx_model_path = get_cache_dir() / f"retinanet_{args.height}x{args.width}.onnx"
if not onnx_model_path.exists():
    import torchvision  # noqa: E402

    retinanet = torchvision.models.detection.retinanet_resnet50_fpn(
        weights=torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights.DEFAULT,
    )
    export_retinanet_onnx(
        retinanet,
        onnx_model_path,
        (3, args.height, args.width),
        verbose=False,
    )

# 2. Build the TensorRT engine
trt_model_path = get_cache_dir() / f"retinanet_{args.height}x{args.width}.trtmodel"
if not trt_model_path.exists():
    engine_from_onnx(onnx_model_path, trt_model_path, use_fp16=False)
model = TRT(trt_model_path)

The model export process:

  • RetinaNet: Loads pretrained RetinaNet with ResNet50-FPN backbone

  • EfficientNMS: Adds TensorRT EfficientNMS plugin to model graph

  • ONNX Export: Exports complete detection pipeline

  • TensorRT Build: Compiles to optimized engine
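
engine_from_onnx is a helper from the sample's common utilities, and its internals are not shown in this walkthrough. A minimal equivalent using the TensorRT Python API might look like the following sketch (error handling abbreviated):

import tensorrt as trt

def engine_from_onnx(onnx_path, engine_path, use_fp16=False):
    logger = trt.Logger(trt.Logger.WARNING)
    # Register TensorRT's bundled plugins so EfficientNMS_TRT can be resolved.
    trt.init_libnvinfer_plugins(logger, "")

    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            raise RuntimeError(parser.get_error(0))

    config = builder.create_builder_config()
    if use_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # Serialize the optimized engine to disk for reuse on later runs.
    with open(engine_path, "wb") as f:
        f.write(builder.build_serialized_network(network, config))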

Note

EfficientNMS performs Non-Maximum Suppression (NMS) on GPU, eliminating the need for CPU post-processing.
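
The sample's export_retinanet_onnx helper attaches the plugin during export; the exact mechanism is not shown here. One common approach is to append an EfficientNMS_TRT node with onnx-graphsurgeon. The attribute names below are the plugin's documented ones, while the helper function itself is illustrative:

import numpy as np
import onnx_graphsurgeon as gs

def append_efficient_nms(graph: gs.Graph, boxes, scores, max_detections=100):
    # Declare the four outputs in the shapes the plugin contract specifies.
    num_dets = gs.Variable("num_detections", np.int32, [1, 1])
    det_boxes = gs.Variable("detection_boxes", np.float32, [1, max_detections, 4])
    det_scores = gs.Variable("detection_scores", np.float32, [1, max_detections])
    det_classes = gs.Variable("detection_classes", np.int32, [1, max_detections])

    graph.layer(
        op="EfficientNMS_TRT",
        name="efficient_nms",
        inputs=[boxes, scores],
        outputs=[num_dets, det_boxes, det_scores, det_classes],
        attrs={
            "score_threshold": 0.25,   # drop low-confidence candidates
            "iou_threshold": 0.5,      # overlap threshold for suppression
            "max_output_boxes": max_detections,
            "background_class": -1,    # no dedicated background class
            "score_activation": 0,     # scores are already probabilities
            "box_coding": 0,           # boxes arrive as (x1, y1, x2, y2) corners
        },
    )
    graph.outputs = [num_dets, det_boxes, det_scores, det_classes]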

Loading Input Image

# 3. Read the image
input_image: cvcuda.Tensor = read_image(args.input)

The image is decoded directly into GPU memory. Its original dimensions are kept so the detected boxes can be scaled back to image coordinates later.
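
read_image also comes from the sample's common utilities. One plausible GPU-side implementation decodes with nvImageCodec and wraps the result without a copy; this is a sketch, not the sample's actual code:

import cvcuda
from nvidia import nvimgcodec

def read_image(path) -> cvcuda.Tensor:
    # Decode the file straight into device memory (no host round trip).
    decoder = nvimgcodec.Decoder()
    nv_image = decoder.read(str(path))  # exposes __cuda_array_interface__
    # Zero-copy wrap of the decoded surface as an HWC tensor.
    return cvcuda.as_tensor(nv_image, "HWC")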

Preprocessing Pipeline

# 4. Preprocess the image
# 4.1 Add a batch dimension
input_tensor: cvcuda.Tensor = cvcuda.stack([input_image])

# 4.2 Resize the image
resized_tensor: cvcuda.Tensor = cvcuda.resize(
    input_tensor, (1, args.height, args.width, 3), cvcuda.Interp.LINEAR
)

# 4.3 Convert to float32
float_tensor: cvcuda.Tensor = cvcuda.convertto(
    resized_tensor, np.float32, scale=1 / 255
)

# 4.4 Convert to NCHW layout
tensor: cvcuda.Tensor = cvcuda.reformat(float_tensor, "NCHW")

Preprocessing steps:

  1. Add Batch Dimension: HWC → NHWC using cvcuda.stack()

  2. Resize: Scale to target model input size (default 224×224)

  3. Normalize: Convert to float32 [0,1] range

  4. Reformat: NHWC → NCHW for model input
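
If frames are already resident on the GPU, e.g. as a PyTorch CUDA tensor from an upstream stage, the same pipeline can start from a zero-copy wrap instead of a file read (illustrative; the frame below is synthetic):

import cvcuda
import torch

# Hypothetical: a decoded HWC uint8 frame already on the GPU.
frame = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8, device="cuda")

input_image = cvcuda.as_tensor(frame, "HWC")  # zero-copy via __cuda_array_interface__
input_tensor = cvcuda.stack([input_image])    # HWC -> NHWC, batch of one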

Running Inference

# 5. Run the inference
input_tensors: list[cvcuda.Tensor] = [tensor]
output_tensors: list[cvcuda.Tensor] = model(input_tensors)

# EfficientNMS outputs: [num_detections, boxes, scores, classes]
num_detections_tensor = output_tensors[0]  # [1, 1] int32
boxes_tensor = output_tensors[1]  # [1, max_detections, 4] float32
scores_tensor = output_tensors[2]  # [1, max_detections] float32
classes_tensor = output_tensors[3]  # [1, max_detections] int32

Inference outputs from EfficientNMS:

  • num_detections: [1, 1] - Number of valid detections

  • boxes: [1, max_detections, 4] - Bounding boxes [x1, y1, x2, y2]

  • scores: [1, max_detections] - Confidence scores

  • classes: [1, max_detections] - Class indices

All outputs are already filtered and sorted by EfficientNMS.
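
Only the leading num_detections entries of each output are meaningful; the remainder is padding up to max_detections. Once the results have been copied to the host (step 6 below), the valid rows can be sliced like this:

n = int(num_detections[0, 0])
valid_boxes = boxes[0, :n]      # (n, 4) boxes in (x1, y1, x2, y2) corner format
valid_scores = scores[0, :n]    # (n,) confidence scores, sorted descending
valid_classes = classes[0, :n]  # (n,) COCO class indices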

Postprocessing and Visualization

# 6. Copy results to host
num_detections = np.zeros((1, 1), dtype=np.int32)
boxes = np.zeros((1, 100, 4), dtype=np.float32)
scores = np.zeros((1, 100), dtype=np.float32)
classes = np.zeros((1, 100), dtype=np.int32)

cuda_memcpy_d2h(num_detections_tensor.cuda(), num_detections)
cuda_memcpy_d2h(boxes_tensor.cuda(), boxes)
cuda_memcpy_d2h(scores_tensor.cuda(), scores)
cuda_memcpy_d2h(classes_tensor.cuda(), classes)

# 7. Draw the detections on the image
n = num_detections[0, 0]
orig_h, orig_w = input_image.shape[:2]
scale_x = orig_w / float(args.width)
scale_y = orig_h / float(args.height)

# Create list of bounding boxes
bboxes: list[cvcuda.BndBoxI] = []
for idx, box in enumerate(boxes[0]):
    # only assess boxes from the top n detections
    if idx >= n:
        break
    x1 = int(box[0] * scale_x)
    y1 = int(box[1] * scale_y)
    x2 = int(box[2] * scale_x)
    y2 = int(box[3] * scale_y)
    # CV-CUDA boxes are (x, y, width, height)
    bbox = (
        x1,
        y1,
        x2 - x1,
        y2 - y1,
    )
    print(f"Box {idx}: {bbox}")

    # create each cvcuda bounding box
    cvcuda_box = cvcuda.BndBoxI(
        box=bbox,
        thickness=2,
        borderColor=(255, 0, 0),
        fillColor=(0, 0, 0, 0),
    )
    bboxes.append(cvcuda_box)

bndboxes = cvcuda.BndBoxesI(boxes=[bboxes])
output_image = cvcuda.bndbox(input_image, bndboxes)
write_image(output_image, args.output)

# 8. Verify output file exists
assert args.output.exists()

Postprocessing:

  1. Copy to Host: Transfer detection results to CPU

  2. Scale Boxes: Scale from model input size to original image size

  3. Create Bounding Boxes: Build CV-CUDA bounding box objects

  4. Draw Boxes: Use cvcuda.bndbox() to draw on GPU

  5. Save Result: Write annotated image
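
cuda_memcpy_d2h is another sample utility. A minimal equivalent built on cuda-python could look like this sketch; it assumes the device buffer exposes __cuda_array_interface__, which the .cuda() view used above does:

import numpy as np
from cuda import cudart

def cuda_memcpy_d2h(device_buffer, host_array: np.ndarray) -> None:
    # Pull the raw device pointer from the CUDA array interface.
    src_ptr = device_buffer.__cuda_array_interface__["data"][0]
    (err,) = cudart.cudaMemcpy(
        host_array.ctypes.data,  # destination: host memory
        src_ptr,                 # source: device memory
        host_array.nbytes,       # size must match the device buffer
        cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
    )
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaMemcpy failed: {err}")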

Expected Output

Console output shows detected bounding boxes:

Box 0: (45, 67, 312, 389)
Box 1: (150, 200, 280, 350)
...

Each box is printed in CV-CUDA's (x, y, width, height) format, with coordinates scaled back to the original image space.

The output image will have red bounding boxes drawn around detected objects (e.g., cat).

[Image: tabby_tiger_cat.jpg] Original input image

[Image: cat_detections.jpg] Output with detected objects

Understanding Detection Output

  • Bounding Box Format: EfficientNMS outputs corner format, with (x1, y1) top-left and (x2, y2) bottom-right; the sample converts these to CV-CUDA's (x, y, width, height) before drawing

  • Confidence Scores: Range [0, 1] where 1 is highest confidence

  • Class Labels: RetinaNet is trained on COCO dataset with 80 classes
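
The class indices can be mapped to human-readable names using the metadata shipped with the torchvision weights. This is illustrative, reusing the host-side arrays from step 6:

import torchvision

weights = torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights.DEFAULT
categories = weights.meta["categories"]  # COCO category names, index-aligned

for idx in range(int(num_detections[0, 0])):
    name = categories[int(classes[0, idx])]
    print(f"Box {idx}: {name} ({scores[0, idx]:.2f})")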

CV-CUDA Operators Used

Operator             Purpose
cvcuda.stack()       Add batch dimension
cvcuda.resize()      Resize to model input size (configurable, default 224×224)
cvcuda.convertto()   Convert to float32 and normalize to [0, 1]
cvcuda.reformat()    Convert NHWC to NCHW
cvcuda.bndbox()      Draw bounding boxes on GPU

Common Utilities Used

This sample relies on several helpers from the samples' common utilities, as seen in the code above:

  • get_cache_dir() - Resolves the sample cache directory

  • read_image() / write_image() - Decode and encode images in GPU memory

  • export_retinanet_onnx() - Exports RetinaNet plus EfficientNMS to ONNX

  • engine_from_onnx() - Builds a TensorRT engine from an ONNX model

  • TRT - Wrapper class that runs a serialized TensorRT engine

  • cuda_memcpy_d2h() - Copies a device buffer into a host NumPy array
