Semantic Segmentation Inference Using TensorRT
The semantic segmentation sample in CVCUDA uses the fcn_resnet101 deep learning model from the torchvision library. Since the model does not come with a softmax layer at the end, we add one. The following code snippet shows how the model is set up for the inference use case with TensorRT.
TensorRT requires a serialized TensorRT engine to run the inference. One can generate such an engine by first converting an existing PyTorch model to ONNX and then converting the ONNX model to a TensorRT engine. A serialized TensorRT engine is only valid for the specific GPU and the maximum batch size it was given at creation time. Since ONNX and TensorRT model generation is a time-consuming operation, we avoid doing it every time by first checking whether those files already exist (most likely from a previous run of this sample). If so, we simply use the existing models rather than generating new ones.
Finally, we take care of setting up the I/O bindings. We allocate the output tensors in advance for TensorRT. Helper methods such as convert_onnx_to_tensorrt and setup_tensort_bindings are defined in the helper script samples/common/python/trt_utils.py.
class SegmentationTensorRT:
    def __init__(
        self,
        output_dir,
        seg_class_name,
        batch_size,
        image_size,
        device_id,
        cvcuda_perf,
    ):
        self.logger = logging.getLogger(__name__)
        self.output_dir = output_dir
        self.device_id = device_id
        self.cvcuda_perf = cvcuda_perf
        # For TensorRT, the process is the following:
        # We check if there already exists a TensorRT engine generated
        # previously. If not, we check if there exists an ONNX model generated
        # previously. If not, we will generate both of them one by one
        # and then use those.
        # The underlying PyTorch model that we use in case of TensorRT
        # inference is the FCN model from torchvision. It is only used during
        # the conversion process and not during the inference.
        onnx_file_path = os.path.join(
            self.output_dir,
            "model.%d.%d.%d.onnx"
            % (
                batch_size,
                image_size[1],
                image_size[0],
            ),
        )
        trt_engine_file_path = os.path.join(
            self.output_dir,
            "model.%d.%d.%d.trtmodel"
            % (
                batch_size,
                image_size[1],
                image_size[0],
            ),
        )

        with torch.cuda.stream(
            torch.cuda.ExternalStream(cvcuda.Stream.current.handle)
        ):
            torch_model = segmentation_models.fcn_resnet101
            weights = segmentation_models.FCN_ResNet101_Weights.DEFAULT

            try:
                self.class_index = weights.meta["categories"].index(seg_class_name)
            except ValueError:
                raise ValueError(
                    "Requested segmentation class '%s' is not supported by the "
                    "fcn_resnet101 model. All supported class names are: %s"
                    % (seg_class_name, ", ".join(weights.meta["categories"]))
                )

            # Check if we have a previously generated model.
            if not os.path.isfile(trt_engine_file_path):
                if not os.path.isfile(onnx_file_path):
                    # First we use PyTorch to create a segmentation model.
                    with torch.no_grad():
                        fcn_base = torch_model(weights=weights)

                        class FCN_Softmax(torch.nn.Module):
                            def __init__(self, fcn):
                                super(FCN_Softmax, self).__init__()
                                self.fcn = fcn

                            def forward(self, x):
                                infer_output = self.fcn(x)["out"]
                                return torch.nn.functional.softmax(infer_output, dim=1)

                        fcn_base.eval()
                        pyt_model = FCN_Softmax(fcn_base)
                        pyt_model.cuda(self.device_id)
                        pyt_model.eval()

                        # Allocate a dummy input to help generate an ONNX model.
                        dummy_x_in = torch.randn(
                            batch_size,
                            3,
                            image_size[1],
                            image_size[0],
                            requires_grad=False,
                        ).cuda(self.device_id)

                        # Generate an ONNX model using PyTorch's ONNX export.
                        torch.onnx.export(
                            pyt_model,
                            args=dummy_x_in,
                            f=onnx_file_path,
                            export_params=True,
                            opset_version=15,
                            do_constant_folding=True,
                            input_names=["input"],
                            output_names=["output"],
                            dynamic_axes={
                                "input": {0: "batch_size"},
                                "output": {0: "batch_size"},
                            },
                        )

                        # Remove the tensors and model after this.
                        del pyt_model
                        del dummy_x_in
                        torch.cuda.empty_cache()

                # Now that we have an ONNX model, we will continue generating a
                # serialized TensorRT engine from it.
                convert_onnx_to_tensorrt(
                    onnx_file_path,
                    trt_engine_file_path,
                    max_batch_size=batch_size,
                    max_workspace_size=1,
                )

            # Once the TensorRT engine generation is all done, we load it.
            trt_logger = trt.Logger(trt.Logger.ERROR)
            with open(trt_engine_file_path, "rb") as f, trt.Runtime(
                trt_logger
            ) as runtime:
                trt_model = runtime.deserialize_cuda_engine(f.read())

            # Create execution context.
            self.model = trt_model.create_execution_context()

            # Allocate the output bindings.
            self.output_tensors, self.output_idx = setup_tensort_bindings(
                trt_model,
                batch_size,
                self.device_id,
                self.logger,
            )

            self.logger.info("Using TensorRT as the inference engine.")
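The convert_onnx_to_tensorrt helper called above turns the ONNX file into a serialized engine; its actual implementation lives in samples/common/python/trt_utils.py. As a rough illustration only, a minimal version of such a conversion with the TensorRT Python API could look like the sketch below. The signature mirrors the call in __init__, but everything else here is an assumption rather than the sample's code.

import tensorrt as trt


def convert_onnx_to_tensorrt(
    onnx_file_path, trt_engine_file_path, max_batch_size, max_workspace_size
):
    # Hypothetical sketch only: parse the ONNX file and build a serialized engine.
    trt_logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, trt_logger)

    with open(onnx_file_path, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            return False

    config = builder.create_builder_config()
    # The workspace limit is taken to be in GiB in this sketch.
    config.set_memory_pool_limit(
        trt.MemoryPoolType.WORKSPACE, max_workspace_size * (1 << 30)
    )

    # The ONNX model was exported with a dynamic batch axis, so an optimization
    # profile must state the minimum, optimal, and maximum input shapes.
    model_input = network.get_input(0)
    c, h, w = model_input.shape[1], model_input.shape[2], model_input.shape[3]
    profile = builder.create_optimization_profile()
    profile.set_shape(
        model_input.name,
        (1, c, h, w),
        (max_batch_size, c, h, w),
        (max_batch_size, c, h, w),
    )
    config.add_optimization_profile(profile)

    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        return False

    with open(trt_engine_file_path, "wb") as f:
        f.write(serialized_engine)
    return True

Because the ONNX model carries a dynamic batch axis, it is the optimization profile that ties the resulting engine to the maximum batch size mentioned earlier.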
To run the inference, the __call__ method is used. It sets up the correct I/O bindings and makes sure the forward inference pass runs on the current CUDA stream. When passing the inputs, we pass the data from the CVCUDA tensor directly, without further conversions. The API to do so involves accessing the tensor's __cuda_array_interface__ member, as shown in the code below.
def __call__(self, tensor):
    self.cvcuda_perf.push_range("inference.tensorrt")

    # Grab the data directly from the pre-allocated tensor.
    input_bindings = [tensor.cuda().__cuda_array_interface__["data"][0]]
    output_bindings = []
    for t in self.output_tensors:
        output_bindings.append(t.data_ptr())
    io_bindings = input_bindings + output_bindings

    # Must set the dynamic input binding's shape before running the inference.
    binding_i = self.model.engine.get_binding_index("input")
    assert self.model.set_binding_shape(binding_i, tensor.shape)

    self.model.execute_async_v2(
        bindings=io_bindings, stream_handle=cvcuda.Stream.current.handle
    )

    segmented = self.output_tensors[self.output_idx]

    self.cvcuda_perf.pop_range()
    return segmented
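The output_tensors and output_idx used above come from the setup_tensort_bindings helper called in __init__, which pre-allocates one device buffer per output binding of the engine and records which binding holds the result we care about. Its actual implementation is in samples/common/python/trt_utils.py; the version below is only a minimal sketch, assuming PyTorch tensors are used as the output buffers and that the outputs (softmax scores) are float32.

import torch


def setup_tensort_bindings(trt_engine, batch_size, device_id, logger):
    # Hypothetical sketch only: allocate one PyTorch tensor per output binding.
    output_tensors = []
    output_idx = 0

    for i in range(trt_engine.num_bindings):
        # Input bindings are bound directly to the CVCUDA tensor's memory in
        # __call__, so only the outputs need pre-allocated buffers here.
        if trt_engine.binding_is_input(i):
            continue

        # Resolve the output shape, filling in the dynamic batch dimension.
        shape = list(trt_engine.get_binding_shape(i))
        if shape[0] == -1:
            shape[0] = batch_size

        logger.info(
            "Allocating output binding '%s' with shape %s"
            % (trt_engine.get_binding_name(i), shape)
        )
        # The exported model ends in a softmax, so float32 outputs are assumed.
        output_tensors.append(
            torch.zeros(shape, dtype=torch.float32, device="cuda:%d" % device_id)
        )

    # With a single output binding, index 0 holds the segmentation scores.
    return output_tensors, output_idx

Pre-allocating the outputs once in the constructor keeps their device pointers stable, so __call__ can rebuild the same bindings list on every batch without allocating any memory in the inference path.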