Semantic Segmentation deployed using Triton
NVIDIA’s Triton Inference Server enables teams to deploy, run, and scale trained AI models from any framework on any GPU- or CPU-based infrastructure. It offers backend support for most machine learning (ML) frameworks, as well as custom C++ and Python backends. In this tutorial, we go over an example of taking a CV-CUDA accelerated inference workload and deploying it using Triton’s custom Python backend.

Note that for the video modality, our Triton implementation supports a non-streamed mode and a streamed mode. In non-streamed mode, decoded (uncompressed) frames are sent over the network, and video decoding and encoding are performed on the client side. In streamed mode, compressed video packets are sent over the network, and the entire decoding-preprocessing-inference-postprocessing-encoding pipeline is offloaded to the server side. Performance benchmarks indicate that streamed mode performs significantly better than non-streamed mode for video workloads, so we highly recommend enabling it with the --stream_video (-sv) flag. Refer to the Segmentation sample documentation to understand the details of the pipeline.
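To put the difference between the two modes in numbers, the sketch below computes how many bytes one batch of uncompressed frames occupies, which is roughly what the client sends per request in non-streamed mode. It assumes 8-bit frames with three channels (one byte per channel); raw_batch_bytes is a hypothetical helper, not part of the sample. In streamed mode, only the compressed packets for the same frames cross the network.

```python
def raw_batch_bytes(width: int, height: int, batch_size: int, channels: int = 3) -> int:
    """Bytes needed to send a batch of uncompressed 8-bit frames (non-streamed mode)."""
    if min(width, height, batch_size, channels) <= 0:
        raise ValueError("all dimensions must be positive")
    # Assumption: one byte per channel per pixel (uint8 frames).
    return width * height * channels * batch_size


if __name__ == "__main__":
    # A batch of 4 Full HD frames.
    size = raw_batch_bytes(1920, 1080, 4)
    print(f"{size} bytes (~{size / 2**20:.1f} MiB) per batch")
```

Every batch in non-streamed mode pays this cost in both directions of the network hop, which is the main reason the streamed mode is recommended for video.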
Terminology
Triton Server
Manages and deploys models at scale. Refer to the Triton documentation to review all the features Triton has to offer.
Triton model repository
A Triton model represents an inference workload that needs to be deployed. The Triton server loads the model repository when it starts.
Triton Client
Triton client libraries facilitate communication with Triton using a Python or C++ API. In this example, we demonstrate how to use the Python API to communicate with Triton using gRPC requests.
Tutorial
Download the Triton server and client Docker images. To download the images from NGC, the following are required:
nvidia-docker v2.11.0
A working NVIDIA NGC account (visit https://ngc.nvidia.com/setup to get started), set up by following the NGC documentation at https://docs.nvidia.com/ngc/ngc-catalog-user-guide/index.html#ngc-image-prerequisites
The docker CLI logged in to nvcr.io (NGC’s Docker registry) so that it can pull Docker images.
```bash
docker pull nvcr.io/nvidia/tritonserver:<xx.yy>-py3
docker pull nvcr.io/nvidia/tritonserver:<xx.yy>-py3-sdk
```
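where <xx.yy> is the Triton release version. The -py3 image contains the Triton server, and the -py3-sdk image contains the Triton client libraries.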
Create the model repository
Triton loads the model repository using the following command:
```bash
tritonserver --model-repository=<model-repository-path>
```
The model repository must conform to the following layout:
```
<model-repository-path>/
  <model-name>/
    <version>/
      <model-definition-file>
    config.pbtxt
```
For the segmentation sample, we create a model.py that defines a TritonPythonModel class, which runs the preprocessing, inference, and postprocessing stages. We copy the necessary files and modules for these stages from the segmentation sample and create the following folder structure:
```
triton_models/
  fcn_resnet101/
    1/
      model.py
    config.pbtxt
```
Each model in the model repository must include a model configuration that provides the required and optional information about the model. Typically, this configuration is provided in a config.pbtxt file.
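The layout above can also be created and checked from a script. The following sketch uses only the Python standard library; create_model_skeleton and check_model_repository are hypothetical helpers, not part of the sample. The files it creates are empty placeholders that you still need to fill in, and the check covers only the rules described here (a config.pbtxt per model, at least one numeric version directory, and a model.py in each version for Python-backend models), not everything Triton validates when loading.

```python
import tempfile
from pathlib import Path
from typing import List


def create_model_skeleton(repo: Path, model_name: str, version: int = 1) -> Path:
    """Create <repo>/<model_name>/<version>/model.py and <repo>/<model_name>/config.pbtxt."""
    version_dir = repo / model_name / str(version)
    version_dir.mkdir(parents=True, exist_ok=True)
    (version_dir / "model.py").touch()  # empty placeholder
    (repo / model_name / "config.pbtxt").touch()  # empty placeholder
    return version_dir


def check_model_repository(repo: Path) -> List[str]:
    """Return layout problems for Python-backend models under repo (empty list if none)."""
    problems = []
    for model_dir in sorted(p for p in repo.iterdir() if p.is_dir()):
        if not (model_dir / "config.pbtxt").is_file():
            problems.append(f"{model_dir.name}: missing config.pbtxt")
        versions = [p for p in model_dir.iterdir() if p.is_dir() and p.name.isdigit()]
        if not versions:
            problems.append(f"{model_dir.name}: no numeric version directory")
        for version_dir in versions:
            if not (version_dir / "model.py").is_file():
                problems.append(f"{model_dir.name}/{version_dir.name}: missing model.py")
    return problems


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        repo = Path(tmp) / "triton_models"
        create_model_skeleton(repo, "fcn_resnet101")
        print(check_model_repository(repo))  # []
```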
The segmentation configuration for non-streamed mode is shown below:
```
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "fcn_resnet101"
backend: "python"
max_batch_size: 32
input [
  {
    name: "inputrgb"
    data_type: TYPE_UINT8
    dims: [ -1, -1, -1 ]
  }
]
output [
  {
    name: "outputrgb"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1 ]
  }
]
parameters: {
  key: "network_width"
  value: {string_value:"224"}
}
parameters: {
  key: "network_height"
  value: {string_value:"224"}
}
parameters: {
  key: "device_id"
  value: {string_value:"0"}
}
parameters: {
  key: "visualization_class_name"
  value: {string_value:"cat"}
}
parameters: {
  key: "inference_backend"
  value: {string_value:"tensorrt"}
}
instance_group {
  kind: KIND_GPU
  count: 1
}
```
The configuration for streamed mode is shown below:
```
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Packet Metadata Description
# META0:
#   - batch size, frame width, frame height
#   - FPS, total frames, pixel format
#   - codec, color space, color range
# META1:
#   - keyframe flag, presentation timestamp, decode timestamp
# META2:
#   - offset, bitstream length, duration

name: "fcn_resnet101_streaming"
backend: "python"
max_batch_size: 0
model_transaction_policy {
  decoupled: True
}
input [
  {
    name: "PACKET_IN"
    data_type: TYPE_UINT8
    dims: [ -1 ]
  },
  {
    name: "META0"
    data_type: TYPE_INT32
    dims: [ 9 ]
    optional: true
  },
  {
    name: "META1"
    data_type: TYPE_INT64
    dims: [ 3 ]
  },
  {
    name: "META2"
    data_type: TYPE_UINT64
    dims: [ 3 ]
  },
  {
    name: "FIRST_PACKET"
    data_type: TYPE_BOOL
    dims: [ 1 ]
  },
  {
    name: "LAST_PACKET"
    data_type: TYPE_BOOL
    dims: [ 1 ]
  }
]
output [
  {
    name: "PACKET_OUT"
    data_type: TYPE_UINT8
    dims: [ -1 ]
  },
  {
    name: "FRAME_SIZE"
    data_type: TYPE_UINT64
    dims: [ 2 ]
  },
  {
    name: "LAST_PACKET"
    data_type: TYPE_BOOL
    dims: [ 1 ]
  }
]
parameters: {
  key: "network_width"
  value: {string_value:"224"}
}
parameters: {
  key: "network_height"
  value: {string_value:"224"}
}
parameters: {
  key: "device_id"
  value: {string_value:"0"}
}
parameters: {
  key: "visualization_class_name"
  value: {string_value:"cat"}
}
parameters: {
  key: "inference_backend"
  value: {string_value:"tensorrt"}
}
parameters: {
  key: "max_batch_size_trt_engine"
  value: {string_value:"32"}
}
instance_group {
  kind: KIND_GPU
  count: 1
}
```
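Inside model.py, the parameters declared in config.pbtxt become available when the model is initialized: Triton's Python backend passes the model configuration to TritonPythonModel.initialize as a JSON string in args["model_config"], and each parameter appears under "parameters" with a "string_value" field. The sketch below shows only this parameter-parsing step; read_string_parameters is a hypothetical helper, the defaults and int conversions are assumptions for illustration, and the sample's real model.py also sets up the preprocessing, inference, and postprocessing stages.

```python
import json
from typing import Dict


def read_string_parameters(model_config_json: str) -> Dict[str, str]:
    """Flatten the `parameters` map of a Triton model config into {key: string_value}."""
    config = json.loads(model_config_json)
    params = config.get("parameters", {})
    return {key: value["string_value"] for key, value in params.items()}


class TritonPythonModel:
    """Minimal sketch: only the parameter handling of initialize() is shown."""

    def initialize(self, args):
        # Triton passes the model configuration as a JSON string.
        params = read_string_parameters(args["model_config"])
        # Assumed defaults mirror the values in the sample configs.
        self.network_width = int(params.get("network_width", "224"))
        self.network_height = int(params.get("network_height", "224"))
        self.device_id = int(params.get("device_id", "0"))
        self.visualization_class_name = params.get("visualization_class_name", "cat")
        self.inference_backend = params.get("inference_backend", "tensorrt")
```

Because every value arrives as a string, numeric parameters such as network_width have to be converted explicitly.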
Triton client (non-streamed mode)
Triton receives the frames as input (in batches) and returns the segmentation output. These are represented as the input and output layers of the Triton model. Additional parameters for initializing the model can be specified in config.pbtxt as well.
We use the Triton Python API over the gRPC protocol to communicate with Triton. Below is an example of how to create a Python Triton client for the segmentation sample.
Create the Triton gRPC client and set the names of the model's input and output tensors:
```python
import tritonclient.grpc as tritongrpcclient

# Create the gRPC Triton client (url is the server's gRPC endpoint, e.g. "localhost:8001")
try:
    triton_client = tritongrpcclient.InferenceServerClient(url=url)
except Exception as e:
    raise Exception("Unable to create Triton GRPC Client: " + str(e)) from e

# Set input and output Triton buffer names (must match config.pbtxt)
input_name = "inputrgb"
output_name = "outputrgb"
```
The client takes a set of images or a video as input and decodes it into batched tensors. We first initialize the decoder and encoder based on the data modality:
```python
if os.path.splitext(input_path)[1] == ".jpg" or os.path.isdir(input_path):
    # Treat this as data modality of images
    decoder = ImageBatchDecoder(
        input_path,
        batch_size,
        device_id,
        cuda_ctx,
        cvcuda_stream,
        cvcuda_perf,
    )

    encoder = ImageBatchEncoder(
        output_dir,
        device_id=device_id,
        cvcuda_perf=cvcuda_perf,
    )
else:
    # Treat this as data modality of videos.
    # Check if the user wanted to use streaming video or not.
    if should_stream_video:
        decoder = VideoBatchStreamingDecoderVPF(
            "client",
            None,
            cvcuda_perf,
            input_path,
            model_name,
            model_version,
        )
        file_name = os.path.splitext(os.path.basename(input_path))[0]
        video_id = uuid.uuid4()  # unique video id
        video_output_path = os.path.join(
            output_dir, f"{file_name}_output_{video_id}.mp4"
        )
        encoder = VideoBatchStreamingEncoderVPF(
            "client",
            None,
            cvcuda_perf,
            video_output_path,
            decoder.decoder.fps,
        )
    else:
        decoder = VideoBatchDecoder(
            input_path,
            batch_size,
            device_id,
            cuda_ctx,
            cvcuda_stream,
            cvcuda_perf,
        )

        encoder = VideoBatchEncoder(
            output_dir, decoder.fps, device_id, cuda_ctx, cvcuda_stream, cvcuda_perf
        )

# Fire up encoder/decoder
decoder.start()
encoder.start()
```
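The modality check at the top of this snippet can be factored into a small function, which makes the dispatch easy to test in isolation. This is a hypothetical refactoring (select_modality is not part of the sample) that mirrors the condition above exactly: a path ending in .jpg or a directory is treated as images, and anything else as video.

```python
import os


def select_modality(input_path: str, should_stream_video: bool) -> str:
    """Mirror the sample's dispatch: return 'images', 'video_streamed', or 'video'."""
    if os.path.splitext(input_path)[1] == ".jpg" or os.path.isdir(input_path):
        return "images"
    # The streaming flag only matters for video inputs.
    return "video_streamed" if should_stream_video else "video"
```

Note that, as written, the check is case-sensitive and recognizes only the .jpg extension, so a single file named cat.JPG or cat.png would be routed to the video decoder.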
This completes the initialization. We now iterate over the video or images and run the pipeline. The decoder returns a batch of frames:
```python
# Stage 1: decode
batch = decoder()
if batch is None:
    cvcuda_perf.pop_range(total_items=0)  # for batch
    break  # No more frames to decode
assert batch_idx == batch.batch_idx
```
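The stage snippets in this section run inside a per-batch loop, which is why Stage 1 can break and check batch_idx. A stripped-down sketch of that loop structure, with the decode, infer, and encode steps passed in as hypothetical callables (the real sample also manages the CUDA context and performance ranges):

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class Batch:
    batch_idx: int
    data: Any


def run_pipeline(
    decode: Callable[[], Optional[Batch]],
    infer: Callable[[Any], Any],
    encode: Callable[[Batch], None],
) -> int:
    """Decode, infer, and encode batches until the decoder is exhausted.

    Returns the number of batches processed.
    """
    batch_idx = 0
    while True:
        batch = decode()  # Stage 1: decode
        if batch is None:
            break  # No more frames to decode
        assert batch_idx == batch.batch_idx, "decoder returned batches out of order"
        batch.data = infer(batch.data)  # Stages 2-5: request, infer, parse
        encode(batch)  # Stage 6: encode
        batch_idx += 1
    return batch_idx
```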
Create a Triton inference request by setting the name, data, shape, and data type of the input. We also create an InferRequestedOutput to tell the server which output to return:
```python
# Stage 2: Create Triton Input request
inputs = []
outputs = []

cvcuda_perf.push_range("io_prep")
torch_arr = torch.as_tensor(
    batch.data.cuda(), device="cuda:%d" % device_id
)
numpy_arr = torch_arr.cpu().numpy()
inputs.append(
    tritongrpcclient.InferInput(
        input_name, numpy_arr.shape, "UINT8"
    )
)
outputs.append(tritongrpcclient.InferRequestedOutput(output_name))
inputs[0].set_data_from_numpy(numpy_arr)
cvcuda_perf.pop_range()
```
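Because the non-streamed config sets max_batch_size: 32, Triton treats the first dimension of inputrgb as the batch dimension: the full input shape is the batch size followed by the three declared dims, and a request may hold at most 32 frames. A pre-flight check along these lines (check_input_shape is a hypothetical helper, not part of the sample) can catch a too-large -b value or a wrong array rank before the request is sent, for example check_input_shape(numpy_arr.shape, numpy_arr.dtype.name).

```python
from typing import Sequence

MAX_BATCH_SIZE = 32  # max_batch_size in the non-streamed config.pbtxt
FRAME_RANK = 3  # inputrgb dims: [ -1, -1, -1 ] (per frame, batch dimension excluded)


def check_input_shape(shape: Sequence[int], dtype_name: str) -> None:
    """Raise ValueError if a batch would not match the inputrgb declaration."""
    if dtype_name != "uint8":
        raise ValueError(f"inputrgb is TYPE_UINT8, got {dtype_name}")
    if len(shape) != FRAME_RANK + 1:
        raise ValueError(f"expected rank {FRAME_RANK + 1} (batch + 3 dims), got {len(shape)}")
    if not 1 <= shape[0] <= MAX_BATCH_SIZE:
        raise ValueError(f"batch size {shape[0]} outside 1..{MAX_BATCH_SIZE}")
```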
Send an asynchronous inference request to the server:
```python
# Stage 3 : Run async Inference
cvcuda_perf.push_range("async_infer")  # popped in Stage 4 once the response arrives
response = []
triton_client.async_infer(
    model_name=model_name,
    inputs=inputs,
    callback=partial(callback, response),
    model_version=model_version,
    outputs=outputs,
)
```
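The callback passed to async_infer is not shown in this excerpt. The tritonclient gRPC client invokes the callback with two trailing arguments, result and error, so binding the response list first with functools.partial leads to a callback like the sketch below. This is an assumption about what the sample's callback does, chosen to be consistent with how Stage 5 inspects response[0].

```python
from functools import partial


def callback(user_data, result, error):
    """Store the InferenceServerException if the request failed, else the InferResult."""
    if error is not None:
        user_data.append(error)
    else:
        user_data.append(result)


if __name__ == "__main__":
    # Simulate tritonclient invoking the callback for a successful request.
    response = []
    partial(callback, response)("<InferResult>", None)
    print(response)  # ['<InferResult>']
```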
Wait for the response from the server, verify that no exception was returned, and parse the output data from the InferResult returned by the server:
```python
# Stage 4 : Wait until the results are available
while len(response) == 0:
    time.sleep(0.001)
cvcuda_perf.pop_range()

# Stage 5 : Parse received Infer response
cvcuda_perf.push_range("parse_response")
if len(response) == 1:
    # Check for the errors
    if isinstance(response[0], InferenceServerException):
        cuda_ctx.pop()
        raise response[0]
    else:
        seg_output = response[0].as_numpy(output_name)

cvcuda_perf.pop_range()
```
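Stage 4 polls the response list every millisecond. An alternative, sketched below, is to have the callback put the result on a queue.Queue and block on it with a timeout; this avoids busy-waiting and turns a hung request into an explicit error. This is a design alternative, not what the sample does, and queue_callback, wait_for_response, and the 30-second default timeout are hypothetical choices.

```python
import queue
import threading
from functools import partial


def queue_callback(results: "queue.Queue", result, error):
    """async_infer-style callback: enqueue the error or the result."""
    results.put(error if error is not None else result)


def wait_for_response(results: "queue.Queue", timeout_s: float = 30.0):
    """Block until the callback delivers a response; re-raise server errors."""
    try:
        response = results.get(timeout=timeout_s)
    except queue.Empty:
        raise TimeoutError(f"no response from Triton within {timeout_s} s") from None
    if isinstance(response, Exception):
        raise response
    return response


if __name__ == "__main__":
    results = queue.Queue()
    cb = partial(queue_callback, results)
    # Simulate the client library invoking the callback from another thread.
    threading.Timer(0.05, cb, kwargs={"result": "seg_output", "error": None}).start()
    print(wait_for_response(results))  # seg_output
```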
Encode the output based on the data modality
```python
# Stage 6: encode output data
cvcuda_perf.push_range("encode_output")
seg_output = torch.as_tensor(seg_output)
batch.data = seg_output.cuda()
encoder(batch)
cvcuda_perf.pop_range()
```
Triton client (streamed mode)
For the streamed mode of the video modality, the client workflow is further simplified because all GPU workloads are offloaded to the server side. Triton receives compressed video packets with metadata instead of decompressed frame data, and it returns the output frames as compressed data as well.
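Each compressed packet travels with the metadata tensors declared in the streaming config (META0, META1, META2, FIRST_PACKET, LAST_PACKET). Because that config sets max_batch_size: 0, Triton adds no batch dimension, so the tensor shapes are exactly the listed dims. The sketch below shows one way a client could lay out those values so their lengths match the config; the field order follows the Packet Metadata Description comment, but the StreamInfo and PacketInfo classes, the integer codes for pixel format, codec, color space, and color range, and the choice to send META0 only with the first packet are hypothetical, and the sample's decoder defines its own encoding.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class StreamInfo:
    """Per-video values sent in META0 (dims [ 9 ], TYPE_INT32)."""
    batch_size: int
    width: int
    height: int
    fps: int
    total_frames: int
    pixel_format: int  # hypothetical integer code
    codec: int  # hypothetical integer code
    color_space: int  # hypothetical integer code
    color_range: int  # hypothetical integer code


@dataclass
class PacketInfo:
    """Per-packet values sent in META1 (TYPE_INT64) and META2 (TYPE_UINT64)."""
    keyframe: bool
    pts: int
    dts: int
    offset: int
    size: int  # bitstream length in bytes
    duration: int


def packet_metadata(stream: StreamInfo, pkt: PacketInfo, first: bool, last: bool) -> Dict[str, List]:
    """Build the metadata inputs for one packet, keyed by the config's input names."""
    meta = {
        "META1": [int(pkt.keyframe), pkt.pts, pkt.dts],
        "META2": [pkt.offset, pkt.size, pkt.duration],
        "FIRST_PACKET": [first],
        "LAST_PACKET": [last],
    }
    # META0 is optional in the config; this sketch sends it only with the first packet.
    if first:
        meta["META0"] = [
            stream.batch_size, stream.width, stream.height, stream.fps,
            stream.total_frames, stream.pixel_format, stream.codec,
            stream.color_space, stream.color_range,
        ]
    return meta
```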
Demux the input video and stream the packets to the server for decoding:
```python
# Stage 1: Begin streaming input data to the server for decoding
decoder(triton_client, tritongrpcclient, video_id, batch_size)
```
Asynchronously receive the output data from the server:
```python
# Stage 2: Begin receiving the output from server
packet_count = 0
packet_size = 0
while True:
    response = user_data._completed_requests.get()

    # If the response was an error, we must raise it.
    if isinstance(response, Exception):
        raise response
    else:
        # The response was a data item.
        data_item = response

    packet = data_item.as_numpy("PACKET_OUT")
    height, width = data_item.as_numpy("FRAME_SIZE")

    packet_count += 1
    packet_size += len(packet)
    if packet_count % 50 == 0:
        logger.debug(
            "Received packet No. %d, size %d"
            % (packet_count, len(packet))
        )

    # Check if this was the last packet by checking the terminal flag
    if data_item.as_numpy("LAST_PACKET")[0]:
        break
```
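The user_data._completed_requests queue in Stage 2 is filled by the callback registered on the client's bidirectional gRPC stream; the tritonclient gRPC client provides start_stream(callback=...) for decoupled models such as fcn_resnet101_streaming. A minimal sketch of that plumbing, modeled on the pattern used in Triton's client examples, together with a receive loop that stops at LAST_PACKET. UserData, stream_callback, and receive_packets are illustrative names, not the sample's exact code.

```python
import queue


class UserData:
    """Holds responses delivered by the stream callback."""

    def __init__(self):
        self._completed_requests = queue.Queue()


def stream_callback(user_data, result, error):
    """Callback for start_stream(): enqueue the error or the InferResult."""
    if error is not None:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)


def receive_packets(user_data, on_packet):
    """Drain responses until one carries LAST_PACKET == True; return the packet count."""
    count = 0
    while True:
        response = user_data._completed_requests.get()
        if isinstance(response, Exception):
            raise response
        packet = response.as_numpy("PACKET_OUT")
        height, width = response.as_numpy("FRAME_SIZE")
        count += 1
        on_packet(packet, height, width)
        if response.as_numpy("LAST_PACKET")[0]:
            return count


# With a real client, register the callback before sending any packets, e.g.:
#   user_data = UserData()
#   triton_client.start_stream(callback=functools.partial(stream_callback, user_data))
```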
Pass the output data received from the server to the encoder for muxing:
```python
# Stage 3: Begin streaming output data from server for encoding.
encoder(packet, height, width)
```
Running the Sample
Follow the instructions in README.md to set up and run the sample.
[//]: # "SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved."
[//]: # "SPDX-License-Identifier: Apache-2.0"
[//]: # ""
[//]: # "Licensed under the Apache License, Version 2.0 (the 'License');"
[//]: # "you may not use this file except in compliance with the License."
[//]: # "You may obtain a copy of the License at"
[//]: # "http://www.apache.org/licenses/LICENSE-2.0"
[//]: # ""
[//]: # "Unless required by applicable law or agreed to in writing, software"
[//]: # "distributed under the License is distributed on an 'AS IS' BASIS"
[//]: # "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied."
[//]: # "See the License for the specific language governing permissions and"
[//]: # "limitations under the License."
# Semantic Segmentation: Locally and Using Triton
## Prerequisites
- Recommended Linux distros:
  - Ubuntu >= 20.04 (tested with 20.04 and 22.04)
  - WSL2 with Ubuntu >= 20.04 (tested with 20.04)
- CUDA driver >= 11.7
- Triton server and client Docker images >= 22.07
- Refer to the [Samples README](../README.md) for the prerequisites to run the segmentation pipeline
# Instructions to run the sample without Triton
1. Launch the Docker container
```bash
docker run -ti --gpus=all -v <local mount path>:/cvcuda -w /cvcuda nvcr.io/nvidia/tensorrt:22.09-py3
```
2. Install the dependencies
```bash
./samples/scripts/install_dependencies.sh
```
3. Run the segmentation sample for different data modalities
a. Run segmentation on a single image
```bash
python3 ./samples/segmentation/python/main.py -i ./samples/assets/images/tabby_tiger_cat.jpg -b 1
```
b. Run segmentation on a folder containing images with the PyTorch backend
```bash
python3 ./samples/segmentation/python/main.py -i ./samples/assets/images -b 2 -bk pytorch
```
c. Run segmentation on a video file with the TensorRT backend
```bash
python3 ./samples/segmentation/python/main.py -i ./samples/assets/videos/pexels-ilimdar-avgezer-7081456.mp4 -b 4 -bk tensorrt
```
4. To benchmark this run, use `benchmark.py` as follows. The command below launches 1 process, ignores 1 batch at the start and 1 batch at the end as warm-up batches, and saves per-process and overall numbers as JSON files in the `/tmp` directory. To learn more about performance benchmarking in CV-CUDA, refer to the [Performance Benchmarking README](../../scripts/README.md).
```bash
python3 ./samples/scripts/benchmark.py -np 1 -w 1 -o /tmp ./samples/segmentation/python/main.py -b 4 -i ./samples/assets/videos/pexels-ilimdar-avgezer-7081456.mp4
```
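For intuition about what trimming warm-up batches does to the numbers, the following sketch drops a fixed number of batches from each end of a list of per-batch timings and computes throughput from the rest. It is a hypothetical helper (`steady_state_throughput` is not part of the samples), and `benchmark.py` may compute its numbers differently; see the Performance Benchmarking README for what it actually reports.

```python
from typing import Sequence


def steady_state_throughput(batch_times_s: Sequence[float], batch_size: int, warmup: int = 1) -> float:
    """Items per second over the batches left after trimming `warmup` batches from each end."""
    if warmup < 0:
        raise ValueError("warmup must be non-negative")
    kept = batch_times_s[warmup: len(batch_times_s) - warmup]
    if not kept:
        raise ValueError("not enough batches left after removing warm-up batches")
    return batch_size * len(kept) / sum(kept)
```

With timings `[0.5, 0.1, 0.1, 0.3]` seconds and a batch size of 4, the slow first and last batches are excluded and the result reflects only the two steady-state batches.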
# Instructions to run the sample with Triton
## Triton Server instructions
Triton has different public [Docker images](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver): `-py3-sdk` for the Triton client libraries, `-py3` for the Triton server with the TensorRT, ONNX, PyTorch, and TensorFlow backends, and `-pyt-python-py3` for the Triton server with only the PyTorch and Python backends.
1. Launch the Triton server container
```bash
docker run --shm-size=1g --ulimit memlock=-1 -p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 -ti --gpus=all -v <local mount path>:/cvcuda -w /cvcuda nvcr.io/nvidia/tritonserver:22.12-py3
```
2. Install the dependencies
```bash
./samples/scripts/install_dependencies.sh
pip3 install tensorrt
```
3. Install the CV-CUDA packages. Pre-built `.deb`, `.tar.xz`, and `.whl` packages are available only on GitHub, so download them from there; otherwise, build CV-CUDA from source. Because the container above ships with Python 3.8.10, we install the `cvcuda-python3.8` package below. If your container uses a different Python version, use the matching cvcuda-python package instead. The same applies to the wheel: a wheel installs only on the Python version in its tag, and the `cp310` wheel below targets Python 3.10, so substitute the wheel built for your interpreter's version.
```bash
wget https://github.com/CVCUDA/CV-CUDA/releases/download/v0.6.0-beta/cvcuda-lib-0.6.0_beta-cuda11-x86_64-linux.deb \
https://github.com/CVCUDA/CV-CUDA/releases/download/v0.6.0-beta/cvcuda-python3.8-0.6.0_beta-cuda11-x86_64-linux.deb \
https://github.com/CVCUDA/CV-CUDA/releases/download/v0.6.0-beta/cvcuda_cu11-0.6.0b0-cp310-cp310-linux_x86_64.whl \
-P /tmp/cvcuda && \
apt-get install -y /tmp/cvcuda/*.deb && \
pip3 install /tmp/cvcuda/*.whl
```
4. Start the Triton server.
Set the `inference_backend` parameter in config.pbtxt to "pytorch" or "tensorrt" to select the inference backend. The default backend is "tensorrt".
```bash
tritonserver --model-repository `pwd`/samples/segmentation/python/triton_models [--log-info=1]
```
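If you prefer to switch the backend from a script rather than editing config.pbtxt by hand, a small regex-based helper like the one below can rewrite the `inference_backend` value. This is a hypothetical convenience, not part of the sample. It assumes the parameter is written as in the sample configs (`key: "inference_backend"` followed by `value: {string_value:"..."}`, on one line or several). Each model directory has its own config.pbtxt, so update the one for the model you serve, and do it before starting `tritonserver`, since the parameters are read when Triton loads the model.

```python
import re
from pathlib import Path

# Matches: key: "inference_backend" value: {string_value:"<anything>"} (whitespace/newlines allowed)
_BACKEND_RE = re.compile(
    r'(key:\s*"inference_backend"\s*value:\s*\{\s*string_value:\s*")[^"]*(")'
)
VALID_BACKENDS = ("pytorch", "tensorrt")


def set_inference_backend(config_text: str, backend: str) -> str:
    """Return config_text with the inference_backend parameter set to `backend`."""
    if backend not in VALID_BACKENDS:
        raise ValueError(f"backend must be one of {VALID_BACKENDS}, got {backend!r}")
    new_text, count = _BACKEND_RE.subn(rf"\g<1>{backend}\g<2>", config_text)
    if count != 1:
        raise ValueError(f"expected one inference_backend parameter, found {count}")
    return new_text


def update_config_file(path: Path, backend: str) -> None:
    """Rewrite the inference_backend parameter of a config.pbtxt file in place."""
    path.write_text(set_inference_backend(path.read_text(), backend))
```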
## Triton Client instructions
1. Launch the Triton client container
```bash
docker run -ti --net host --gpus=all -v <local_mount_path>:/cvcuda -w /cvcuda nvcr.io/nvidia/tritonserver:22.12-py3-sdk /bin/bash
```
If the client and server run on the same machine, you can simply reuse the server image (or even `docker exec` into the same container) and install the Triton client utilities:
```bash
pip3 install "tritonclient[all]"
```
To convert a local video file to stream data, the client needs [PyAV](https://github.com/PyAV-Org/PyAV), the Pythonic bindings for the FFmpeg libraries. It is already included in `install_dependencies.sh`; to install it manually:
```bash
pip3 install av
```
2. Install the dependencies
```bash
cd /cvcuda
./samples/scripts/install_dependencies.sh
```
3. Run client script for different data modalities
a. Run segmentation on a single image
```bash
python3 ./samples/segmentation/python/triton_client.py -i ./samples/assets/images/tabby_tiger_cat.jpg -b 1
```
b. Run segmentation on a folder containing images
```bash
python3 ./samples/segmentation/python/triton_client.py -i ./samples/assets/images -b 2
```
c. Run segmentation on a video file
```bash
python3 ./samples/segmentation/python/triton_client.py -i ./samples/assets/videos/pexels-ilimdar-avgezer-7081456.mp4 -b 4
```
d. Run segmentation on a video file with streamed encoding/decoding by passing `--stream_video` (or `-sv`). This mode is highly recommended because it greatly improves performance.
```bash
python3 ./samples/segmentation/python/triton_client.py -i ./samples/assets/videos/pexels-ilimdar-avgezer-7081456.mp4 -o /tmp -b 4 -sv [--log_level=debug]
```
4. To benchmark this client run, use `benchmark.py` as follows. The command below launches 1 process, ignores 1 batch at the start and 1 batch at the end as warm-up batches, and saves per-process and overall numbers as JSON files in the `/tmp` directory. To learn more about performance benchmarking in CV-CUDA, refer to the [Performance Benchmarking README](../../scripts/README.md).
```bash
python3 ./samples/scripts/benchmark.py -np 1 -w 1 -o /tmp ./samples/segmentation/python/triton_client.py -i ./samples/assets/videos/pexels-ilimdar-avgezer-7081456.mp4 -b 4 -sv
```
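If the client cannot connect, a quick first check is whether the server's gRPC port is reachable from the client container. Triton listens for gRPC on port 8001 by default, which is why the server command above maps `-p 8001:8001`. The standard-library sketch below (`port_is_open` is a hypothetical helper) only tests TCP connectivity; it does not confirm that the model has loaded, which the tritonclient `is_server_ready()` and `is_model_ready()` calls can check.

```python
import socket


def port_is_open(host: str = "localhost", port: int = 8001, timeout_s: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout_s):
            return True
    except OSError:
        return False


if __name__ == "__main__":
    print("gRPC port reachable:", port_is_open())
```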