# demo_trt.py
  1. import sys
  2. import os
  3. import time
  4. import argparse
  5. import numpy as np
  6. import cv2
  7. # from PIL import Image
  8. import tensorrt as trt
  9. import pycuda.driver as cuda
  10. import pycuda.autoinit
  11. from tool.utils import *
try:
    # Python 2 has no FileNotFoundError builtin; probe for it and alias it
    # to IOError so the `raise FileNotFoundError(...)` calls below work on
    # both interpreters.
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError
  17. def GiB(val):
  18. return val * 1 << 30
  19. def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]):
  20. '''
  21. Parses sample arguments.
  22. Args:
  23. description (str): Description of the sample.
  24. subfolder (str): The subfolder containing data relevant to this sample
  25. find_files (str): A list of filenames to find. Each filename will be replaced with an absolute path.
  26. Returns:
  27. str: Path of data directory.
  28. Raises:
  29. FileNotFoundError
  30. '''
  31. # Standard command-line arguments for all samples.
  32. kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
  33. parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  34. parser.add_argument("-d", "--datadir", help="Location of the TensorRT sample data directory.", default=kDEFAULT_DATA_ROOT)
  35. args, unknown_args = parser.parse_known_args()
  36. # If data directory is not specified, use the default.
  37. data_root = args.datadir
  38. # If the subfolder exists, append it to the path, otherwise use the provided path as-is.
  39. subfolder_path = os.path.join(data_root, subfolder)
  40. data_path = subfolder_path
  41. if not os.path.exists(subfolder_path):
  42. print("WARNING: " + subfolder_path + " does not exist. Trying " + data_root + " instead.")
  43. data_path = data_root
  44. # Make sure data directory exists.
  45. if not (os.path.exists(data_path)):
  46. raise FileNotFoundError(data_path + " does not exist. Please provide the correct data path with the -d option.")
  47. # Find all requested files.
  48. for index, f in enumerate(find_files):
  49. find_files[index] = os.path.abspath(os.path.join(data_path, f))
  50. if not os.path.exists(find_files[index]):
  51. raise FileNotFoundError(find_files[index] + " does not exist. Please provide the correct data path with the -d option.")
  52. return data_path, find_files
  53. # Simple helper data class that's a little nicer to use than a 2-tuple.
  54. class HostDeviceMem(object):
  55. def __init__(self, host_mem, device_mem):
  56. self.host = host_mem
  57. self.device = device_mem
  58. def __str__(self):
  59. return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
  60. def __repr__(self):
  61. return self.__str__()
  62. # Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
  63. def allocate_buffers(engine, batch_size):
  64. inputs = []
  65. outputs = []
  66. bindings = []
  67. stream = cuda.Stream()
  68. for binding in engine:
  69. size = trt.volume(engine.get_binding_shape(binding)) * batch_size
  70. dims = engine.get_binding_shape(binding)
  71. # in case batch dimension is -1 (dynamic)
  72. if dims[0] < 0:
  73. size *= -1
  74. dtype = trt.nptype(engine.get_binding_dtype(binding))
  75. # Allocate host and device buffers
  76. host_mem = cuda.pagelocked_empty(size, dtype)
  77. device_mem = cuda.mem_alloc(host_mem.nbytes)
  78. # Append the device buffer to device bindings.
  79. bindings.append(int(device_mem))
  80. # Append to the appropriate list.
  81. if engine.binding_is_input(binding):
  82. inputs.append(HostDeviceMem(host_mem, device_mem))
  83. else:
  84. outputs.append(HostDeviceMem(host_mem, device_mem))
  85. return inputs, outputs, bindings, stream
  86. # This function is generalized for multiple inputs/outputs.
  87. # inputs and outputs are expected to be lists of HostDeviceMem objects.
  88. def do_inference(context, bindings, inputs, outputs, stream):
  89. # Transfer input data to the GPU.
  90. [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
  91. # Run inference.
  92. context.execute_async(bindings=bindings, stream_handle=stream.handle)
  93. # Transfer predictions back from the GPU.
  94. [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
  95. # Synchronize the stream
  96. stream.synchronize()
  97. # Return only the host outputs.
  98. return [out.host for out in outputs]
# Shared logger passed to every trt.Runtime below.
TRT_LOGGER = trt.Logger()
  100. def main(engine_path, image_path, image_size):
  101. with get_engine(engine_path) as engine, engine.create_execution_context() as context:
  102. buffers = allocate_buffers(engine, 1)
  103. IN_IMAGE_H, IN_IMAGE_W = image_size
  104. context.set_binding_shape(0, (1, 3, IN_IMAGE_H, IN_IMAGE_W))
  105. image_src = cv2.imread(image_path)
  106. num_classes = 80
  107. for i in range(2): # This 'for' loop is for speed check
  108. # Because the first iteration is usually longer
  109. boxes = detect(context, buffers, image_src, image_size, num_classes)
  110. if num_classes == 20:
  111. namesfile = 'data/voc.names'
  112. elif num_classes == 80:
  113. namesfile = 'data/coco.names'
  114. else:
  115. namesfile = 'data/names'
  116. class_names = load_class_names(namesfile)
  117. plot_boxes_cv2(image_src, boxes[0], savename='predictions_trt.jpg', class_names=class_names)
  118. def get_engine(engine_path):
  119. # If a serialized engine exists, use it instead of building an engine.
  120. print("Reading engine from file {}".format(engine_path))
  121. with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
  122. return runtime.deserialize_cuda_engine(f.read())
  123. def detect(context, buffers, image_src, image_size, num_classes):
  124. IN_IMAGE_H, IN_IMAGE_W = image_size
  125. ta = time.time()
  126. # Input
  127. resized = cv2.resize(image_src, (IN_IMAGE_W, IN_IMAGE_H), interpolation=cv2.INTER_LINEAR)
  128. img_in = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
  129. img_in = np.transpose(img_in, (2, 0, 1)).astype(np.float32)
  130. img_in = np.expand_dims(img_in, axis=0)
  131. img_in /= 255.0
  132. img_in = np.ascontiguousarray(img_in)
  133. print("Shape of the network input: ", img_in.shape)
  134. # print(img_in)
  135. inputs, outputs, bindings, stream = buffers
  136. print('Length of inputs: ', len(inputs))
  137. inputs[0].host = img_in
  138. trt_outputs = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
  139. print('Len of outputs: ', len(trt_outputs))
  140. trt_outputs[0] = trt_outputs[0].reshape(1, -1, 1, 4)
  141. trt_outputs[1] = trt_outputs[1].reshape(1, -1, num_classes)
  142. tb = time.time()
  143. print('-----------------------------------')
  144. print(' TRT inference time: %f' % (tb - ta))
  145. print('-----------------------------------')
  146. boxes = post_processing(img_in, 0.4, 0.6, trt_outputs)
  147. return boxes
  148. if __name__ == '__main__':
  149. engine_path = sys.argv[1]
  150. image_path = sys.argv[2]
  151. if len(sys.argv) < 4:
  152. image_size = (416, 416)
  153. elif len(sys.argv) < 5:
  154. image_size = (int(sys.argv[3]), int(sys.argv[3]))
  155. else:
  156. image_size = (int(sys.argv[3]), int(sys.argv[4]))
  157. main(engine_path, image_path, image_size)