This post collects the code needed to convert a Torch model to TensorRT.

Requirements

  • pycuda
  • cuda-python
  • tensorrt
  • onnx
  • torch

1. Convert Torch to Onnx

import onnx
import torch

DEVICE = "cuda:0"
MODEL_NAME = "TEST"

MODEL_PATH = f"{MODEL_NAME}.pt"
ONNX_PATH = f"{MODEL_NAME}.onnx"

input_shapes = [1, 3, 256, 144] # Use Custom Input
dummy_input = torch.ones(input_shapes, dtype=torch.float32).to(DEVICE)

model = TorchModel() # Use Custom Model
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)  # move the model to the same device as dummy_input
model.eval()
print("Load Torch Model Weight", MODEL_PATH)

with torch.no_grad():
    torch.onnx.export(model,  # 실행될 모델
                      dummy_input,  # 모델 입력값 (튜플 또는 여러 입력값들도 가능)
                      ONNX_PATH,  # 모델 저장 경로 (파일 또는 파일과 유사한 객체 모두 가능)
                      export_params=True,  # 모델 파일 안에 학습된 모델 가중치를 저장할지의 여부
                      opset_version=17,  # 모델을 변환할 때 사용할 ONNX 버전
                      do_constant_folding=True,  # 최적화시 상수폴딩을 사용할지의 여부
                      input_names=['input'],  # 모델의 입력값을 가리키는 이름
                      output_names=['output'],  # 모델의 출력값을 가리키는 이름
                      dynamic_axes={'input': {0: 'batch_size'},  # 가변적인 길이를 가진 차원
                                    'output': {0: 'batch_size'}})

onnx_model = onnx.load(ONNX_PATH)
onnx.checker.check_model(onnx_model)
print("Convert Torch to Onnx", ONNX_PATH)

2. Onnx to TensorRT Engine

import tensorrt as trt
import pycuda.driver as cuda

TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
# Explicit batch flag (required for networks with a dynamic input batch)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
# Memory budget TensorRT may use while optimizing the network
MAX_GB = 1
MAX_WORKSPACE_SIZE = MAX_GB * (1 << 30)
MAX_BATCH = 512
FP16_MODE = False

def build_engine(onnx_file_path, trt_file_path, fp16_mode, max_workspace_size):
    
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
        # Create a BuilderConfig object
        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace_size)
        if fp16_mode:
            config.set_flag(trt.BuilderFlag.FP16)

        # Parse the ONNX model
        with open(onnx_file_path, 'rb') as model:
            if not parser.parse(model.read()):
                print('Failed to parse the ONNX file.')
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return False

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]

        for input in inputs:
            print(f"Model {input.name} shape: {input.shape} {input.dtype}")
        for output in outputs:
            print(f"Model {output.name} shape: {output.shape} {output.dtype}")

        profile = builder.create_optimization_profile()
        min_shape = [1] + input_shapes[1:]
        opt_shape = [int(MAX_BATCH // 2)] + input_shapes[1:]
        max_shape = [MAX_BATCH] + input_shapes[1:]
        for input in inputs:
            profile.set_shape(input.name, min_shape, opt_shape, max_shape)
        config.add_optimization_profile(profile)

        # Build and serialize the TensorRT engine
        engine_bytes = builder.build_serialized_network(network, config)
        if engine_bytes is None:
            print('Failed to build the TensorRT engine.')
            return False
        with open(trt_file_path, 'wb') as f:
            f.write(engine_bytes)

        # Deprecated alternative (TensorRT < 9):
        # engine = builder.build_engine(network, config)
        # with open(trt_file_path, 'wb') as f:
        #     f.write(engine.serialize())

        return True
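
The function above only defines the build; a call like the following (TRT_PATH is an assumed output name) parses the ONNX file and writes the serialized engine to disk:

TRT_PATH = f"{MODEL_NAME}.trt"
build_engine(ONNX_PATH, TRT_PATH, FP16_MODE, MAX_WORKSPACE_SIZE)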

3. TRT Inference Code

import tensorrt as trt
from cuda import cudart, cuda
import numpy as np
import ctypes
from typing import Optional, List, Union
import locale
# Workaround for environments (e.g. Colab) where the preferred encoding is misreported
locale.getpreferredencoding = lambda: "UTF-8"


def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))

def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res


class HostDeviceMem:
    """Pair of host and device memory, where the host memory is wrapped in a numpy array"""
    def __init__(self, size: int, dtype: Optional[np.dtype] = None):
        dtype = dtype or np.dtype(np.uint8)
        nbytes = size * dtype.itemsize
        host_mem = cuda_call(cudart.cudaMallocHost(nbytes))
        pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))

        self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,))
        self._device = cuda_call(cudart.cudaMalloc(nbytes))
        self._nbytes = nbytes

    @property
    def host(self) -> np.ndarray:
        return self._host

    @host.setter
    def host(self, data: Union[np.ndarray, bytes]):
        if isinstance(data, np.ndarray):
            if data.size > self.host.size:
                raise ValueError(
                    f"Tried to fit an array of size {data.size} into host memory of size {self.host.size}"
                )
            np.copyto(self.host[:data.size], data.flat, casting='safe')
        else:
            assert self.host.dtype == np.uint8
            self.host[:self.nbytes] = np.frombuffer(data, dtype=np.uint8)

    @property
    def device(self) -> int:
        return self._device

    @property
    def nbytes(self) -> int:
        return self._nbytes

    def __str__(self):
        return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n"

    def __repr__(self):
        return self.__str__()

    def free(self):
        cuda_call(cudart.cudaFree(self.device))
        cuda_call(cudart.cudaFreeHost(self.host.ctypes.data))


def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice))

def free_buffers(inputs: List[HostDeviceMem], outputs: List[HostDeviceMem], stream: cudart.cudaStream_t):
    for mem in inputs + outputs:
        mem.free()
    cuda_call(cudart.cudaStreamDestroy(stream))

def _do_inference_base(inputs, outputs, stream, execute_async_func, dim=128):
    # Transfer input data to the GPU.
    kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    [cuda_call(cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)) for inp in inputs]
    # Run inference.
    execute_async_func()
    # Transfer predictions back from the GPU.
    kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
    [cuda_call(cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)) for out in outputs]
    # Synchronize the stream
    cuda_call(cudart.cudaStreamSynchronize(stream))
    # Return only the host outputs.
    return [np.array(out.host).reshape(-1, dim) for out in outputs]


# This function is generalized for multiple inputs/outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, engine, bindings, inputs, outputs, stream, dim):
    def execute_async_func():
        context.execute_async_v3(stream_handle=stream)
    # Setup context tensor address.
    num_io = engine.num_io_tensors
    for i in range(num_io):
        context.set_tensor_address(engine.get_tensor_name(i), bindings[i])

    return _do_inference_base(inputs, outputs, stream, execute_async_func, dim)


class TRTInference:
    def __init__(self, engine_path, dim=128):
        # TensorRT 로거 설정
        self._TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
        self._engine = self._load_engine(engine_path)
        self._context = self._engine.create_execution_context()
        self._dim = dim


    def _load_engine(self, engine_path: str) -> trt.ICudaEngine:
        runtime = trt.Runtime(self._TRT_LOGGER)
        with open(engine_path, "rb") as plan:
            engine = runtime.deserialize_cuda_engine(plan.read())
        print("Load TRT Engine Successfully", engine_path)
        return engine

    def allocate_buffer(self, batch_size: int):

        stream = cuda_call(cudart.cudaStreamCreate())

        inputs = []
        outputs = []
        bindings = []
        for i in range(self._engine.num_io_tensors):
            tensor_name = self._engine.get_tensor_name(i)

            if self._engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
                shape = self._engine.get_tensor_profile_shape(tensor_name, 0)[-1]

            else:
                shape = self._engine.get_tensor_shape(tensor_name)
            # Replace the dynamic batch dim with the actual input batch size
            shape[0] = batch_size
            # Number of elements in the buffer
            size = trt.volume(shape)
            trt_type = self._engine.get_tensor_dtype(tensor_name)

            # Allocate host and device buffers
            if trt.nptype(trt_type):
                dtype = np.dtype(trt.nptype(trt_type))
                bindingMemory = HostDeviceMem(size, dtype)
            else:  # no numpy support: create a byte array instead (BF16, FP8, INT4)
                size = int(size * trt_type.itemsize)
                bindingMemory = HostDeviceMem(size)

            # Append the device buffer to device bindings
            bindings.append(int(bindingMemory.device))

            # Append to the appropriate list
            if self._engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
                inputs.append(bindingMemory)
            else:
                outputs.append(bindingMemory)

        return stream, inputs, outputs, bindings

    def infer(self, input_data: np.ndarray):

        shape = input_data.shape
        stream, inputs, outputs, bindings = self.allocate_buffer(shape[0])

        self._context.set_input_shape("input", shape)
        # Stage the input in pinned host memory; do_inference copies it to the device
        np.copyto(inputs[0].host, input_data.ravel())

        results = do_inference(self._context, self._engine, bindings, inputs, outputs, stream, self._dim)

        free_buffers(inputs, outputs, stream)
        return results[0]
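
A quick smoke test of the wrapper, assuming the engine built above and an arbitrary batch of 4 (dim must match the model's output feature size):

import numpy as np

engine = TRTInference(f"{MODEL_NAME}.trt", dim=128)
dummy = np.random.rand(4, 3, 256, 144).astype(np.float32)
print("Output shape:", engine.infer(dummy).shape)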

4. Check Accuracy & Time Difference


import time

import numpy as np
import torch

TORCH_DEVICE = "cuda:0"

trt_model = TRTInference("test.trt")
torch_model = TorchModel() # Use Custom Model
torch_model.load_state_dict(torch.load("test.pth", map_location=TORCH_DEVICE))
torch_model.to(TORCH_DEVICE)
torch_model.eval()

STEP = 50
MAX_BATCH_SIZE = 512

for batch in [1, MAX_BATCH_SIZE//2, MAX_BATCH_SIZE]:
    diffs = []
    torch_time = 0
    trt_time = 0
    for i in range(STEP):
        
        shape = [batch, 3, 256, 144]
        input_data = np.random.rand(*shape).astype("float32")
        torch_input = torch.from_numpy(input_data).to(TORCH_DEVICE)
    
        # Synchronize so the measurement reflects the actual GPU work
        torch.cuda.synchronize()
        start = time.time()
        with torch.no_grad():
            torch_out = torch_model(torch_input)
        torch.cuda.synchronize()
        torch_time += time.time() - start
    
        start = time.time()
        trt_out = trt_model.infer(input_data)
        trt_time += time.time() - start

        max_diff = np.max(np.abs(trt_out-torch_out.cpu().detach().numpy()))
        diffs.append(max_diff)

    print(f"BATCH {batch} | Torch: {torch_time/STEP:.5f} sec | Trt: {trt_time/STEP:.5f} sec | Max Diff: {np.max(np.array(diffs))}")