Classification Inference Using TensorRT

The classification sample in CVCUDA uses the ResNet50 deep learning model from the torchvision library. Since the model does not come with a softmax layer at the end, we add one. The following code snippet shows how the model is set up for the TensorRT inference use case.

TensorRT requires a serialized TensorRT engine to run the inference. Such an engine can be generated by first converting the existing PyTorch model to ONNX and then converting the ONNX model to a TensorRT engine. The serialized engine is only valid for the specific GPU and the maximum batch size it was built with. Since ONNX and TensorRT model generation is time consuming, we avoid repeating it on every run by first checking whether either file already exists (most likely from a previous run of this sample). If so, we simply reuse it rather than generating a new one.
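In this sample the ONNX-to-TensorRT step is wrapped in a helper function (described below), but a rough sketch of what building a serialized engine from an ONNX file looks like with the TensorRT Python API may help. The function name, workspace size, and profile shapes here are illustrative, this is not the sample's exact helper, and the builder-config calls vary slightly across TensorRT releases (assuming TensorRT 8.x here).

import tensorrt as trt

def build_engine_sketch(onnx_path, engine_path, max_batch_size, workspace_gib=1):
    # Illustrative sketch only: parse the ONNX file and build a serialized engine.
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            raise RuntimeError("Failed to parse the ONNX model.")

    config = builder.create_builder_config()
    # Older TRT releases use max_workspace_size; newer ones prefer
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, ...).
    config.max_workspace_size = workspace_gib << 30

    # The ONNX model was exported with a dynamic batch dimension, so an
    # optimization profile pins the min/opt/max batch sizes for this engine.
    shape = tuple(network.get_input(0).shape)  # (-1, C, H, W) with dynamic batch
    _, c, h, w = shape
    profile = builder.create_optimization_profile()
    profile.set_shape(
        "input", (1, c, h, w), (max_batch_size, c, h, w), (max_batch_size, c, h, w)
    )
    config.add_optimization_profile(profile)

    # Build and serialize the engine to disk so later runs can just load it.
    serialized_engine = builder.build_serialized_network(network, config)
    with open(engine_path, "wb") as f:
        f.write(serialized_engine)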

Finally, we set up the I/O bindings, allocating the output tensors in advance for TensorRT. Helper methods such as convert_onnx_to_tensorrt and setup_tensort_bindings are defined in the helper script samples/common/python/trt_utils.py.
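The output-binding allocation itself is small. A minimal sketch along the lines of the setup_tensort_bindings helper (not its exact implementation) could look like the following, assuming float32 outputs, which holds for the softmax scores here, and the pre-8.5 TensorRT binding API used elsewhere in this sample.

import torch

def allocate_output_bindings(trt_engine, batch_size, device_id, logger):
    # Hypothetical stand-in for the sample's setup_tensort_bindings helper:
    # pre-allocate one torch tensor per output binding so its device pointer
    # can be handed straight to TensorRT at inference time.
    output_tensors = []
    output_idx = -1
    for idx in range(trt_engine.num_bindings):
        if trt_engine.binding_is_input(idx):
            continue  # inputs come directly from the CVCUDA tensor
        if output_idx < 0:
            output_idx = idx  # remember the index of the first output binding
        shape = list(trt_engine.get_binding_shape(idx))
        shape[0] = batch_size  # resolve the dynamic batch dimension
        out = torch.zeros(*shape, dtype=torch.float32, device="cuda:%d" % device_id)
        output_tensors.append(out)
        logger.info("Allocated output binding %d with shape %s" % (idx, shape))
    return output_tensors, output_idx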

class ClassificationTensorRT:
    def __init__(
        self,
        output_dir,
        batch_size,
        image_size,
        device_id,
        cvcuda_perf,
    ):
        self.logger = logging.getLogger(__name__)
        self.output_dir = output_dir
        self.device_id = device_id
        self.cvcuda_perf = cvcuda_perf
        # For TensorRT, the process is the following:
        # We check if there already exists a TensorRT engine generated
        # previously. If not, we check if there exists an ONNX model generated
        # previously. If not, we will generate both of them one by one
        # and then use those.
        # The underlying PyTorch model that we use in the case of TensorRT
        # inference is the ResNet50 model from torchvision. It is only used during
        # the conversion process and not during the inference.
        onnx_file_path = os.path.join(
            self.output_dir,
            "model.%d.%d.%d.onnx"
            % (
                batch_size,
                image_size[1],
                image_size[0],
            ),
        )
        trt_engine_file_path = os.path.join(
            self.output_dir,
            "model.%d.%d.%d.trtmodel"
            % (
                batch_size,
                image_size[1],
                image_size[0],
            ),
        )

        with torch.cuda.stream(torch.cuda.ExternalStream(cvcuda.Stream.current.handle)):

            torch_model = torchvision_models.resnet50
            weights = torchvision_models.ResNet50_Weights.DEFAULT
            self.labels = weights.meta["categories"]
            # Save the list of labels so that the C++ sample can read it.
            with open(os.path.join(output_dir, "labels.txt"), "w") as f:
                for line in self.labels:
                    f.write("%s\n" % line)

            # Check if we have a previously generated model.
            if not os.path.isfile(trt_engine_file_path):
                if not os.path.isfile(onnx_file_path):
                    # First we use PyTorch to create a classification model.
                    with torch.no_grad():

                        class Resnet50_Softmax(torch.nn.Module):
                            def __init__(self, resnet50):
                                super(Resnet50_Softmax, self).__init__()
                                self.resnet50 = resnet50

                            def forward(self, x):
                                infer_output = self.resnet50(x)
                                return torch.nn.functional.softmax(infer_output, dim=1)

                        resnet_base = torch_model(weights=weights)
                        resnet_base.eval()
                        pyt_model = Resnet50_Softmax(resnet_base)
                        pyt_model.cuda(self.device_id)
                        pyt_model.eval()

                        # Allocate a dummy input to help generate an ONNX model.
                        dummy_x_in = torch.randn(
                            batch_size,
                            3,
                            image_size[1],
                            image_size[0],
                            requires_grad=False,
                        ).cuda(self.device_id)

                        # Generate an ONNX model using PyTorch's ONNX export.
                        torch.onnx.export(
                            pyt_model,
                            args=dummy_x_in,
                            f=onnx_file_path,
                            export_params=True,
                            opset_version=15,
                            do_constant_folding=True,
                            input_names=["input"],
                            output_names=["output"],
                            dynamic_axes={
                                "input": {0: "batch_size"},
                                "output": {0: "batch_size"},
                            },
                        )

                        # Remove the tensors and model after this.
                        del pyt_model
                        del dummy_x_in
                        torch.cuda.empty_cache()

                # Now that we have an ONNX model, we will continue generating a
                # serialized TensorRT engine from it.
                convert_onnx_to_tensorrt(
                    onnx_file_path,
                    trt_engine_file_path,
                    max_batch_size=batch_size,
                    max_workspace_size=1,
                )

            # Once the TensorRT engine generation is all done, we load it.
            trt_logger = trt.Logger(trt.Logger.ERROR)
            with open(trt_engine_file_path, "rb") as f, trt.Runtime(
                trt_logger
            ) as runtime:
                trt_model = runtime.deserialize_cuda_engine(f.read())

            # Create execution context.
            self.model = trt_model.create_execution_context()

            # Allocate the output bindings.
            self.output_tensors, self.output_idx = setup_tensort_bindings(
                trt_model,
                batch_size,
                self.device_id,
                self.logger,
            )

            self.logger.info("Using TensorRT as the inference engine.")
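Putting it together, constructing the backend could look roughly like the following. The paths and sizes are illustrative, and cvcuda_perf stands for the sample's performance-timing object (anything exposing push_range/pop_range would do for this sketch).

# Hypothetical construction; output_dir is where the ONNX model, the TensorRT
# engine, and labels.txt get cached between runs.
inference = ClassificationTensorRT(
    output_dir="/tmp/classification_out",
    batch_size=4,
    image_size=(224, 224),  # (width, height): the constructor indexes [1] for height, [0] for width
    device_id=0,
    cvcuda_perf=cvcuda_perf,  # the sample's timing helper (push_range/pop_range)
)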

To run the inference, the __call__ method is used. It sets up the correct I/O bindings and performs the forward inference pass on the current CUDA stream. When passing the inputs, we hand the data from the CVCUDA tensor directly to TensorRT without further conversions. Doing so involves accessing an internal member named __cuda_array_interface__, as shown in the code below.

def __call__(self, tensor):
    self.cvcuda_perf.push_range("inference.tensorrt")

    # Grab the data directly from the pre-allocated tensor.
    input_bindings = [tensor.cuda().__cuda_array_interface__["data"][0]]
    output_bindings = []
    for t in self.output_tensors:
        output_bindings.append(t.data_ptr())
    io_bindings = input_bindings + output_bindings

    # Must call this before inference
    binding_i = self.model.engine.get_binding_index("input")
    assert self.model.set_binding_shape(binding_i, tensor.shape)

    self.model.execute_async_v2(
        bindings=io_bindings, stream_handle=cvcuda.Stream.current.handle
    )

    # Since this model produces only 1 output, we can grab it now.
    classification_scores = self.output_tensors[0]

    self.cvcuda_perf.pop_range()
    return classification_scores
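For reference, __cuda_array_interface__ is the standard CUDA array interface dictionary; the only entry used above is "data", whose first element is the raw device pointer that TensorRT consumes as the input binding. Inspecting it on a CVCUDA tensor would show something like the following (the printed values are illustrative).

# Inspect the CUDA array interface exposed by a CVCUDA tensor's CUDA buffer.
iface = tensor.cuda().__cuda_array_interface__
print(iface["shape"])    # e.g. (4, 3, 224, 224) for an NCHW batch
print(iface["typestr"])  # e.g. "<f4" for float32
print(iface["data"])     # (device_pointer, read_only_flag)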