diff --git a/tensorrt_convert.py b/tensorrt_convert.py index e7e3e92..2502c01 100644 --- a/tensorrt_convert.py +++ b/tensorrt_convert.py @@ -5,6 +5,7 @@ import comfy.model_management import tensorrt as trt +from .tensorrt_error_recorder import TrTErrorRecorder import folder_paths from tqdm import tqdm @@ -284,6 +285,7 @@ def forward(self, x, timesteps, context, y=None): # TRT conversion starts here logger = trt.Logger(trt.Logger.INFO) builder = trt.Builder(logger) + builder.error_recorder = TrTErrorRecorder() network = builder.create_network( 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) diff --git a/tensorrt_error_recorder.py b/tensorrt_error_recorder.py new file mode 100644 index 0000000..86c2c81 --- /dev/null +++ b/tensorrt_error_recorder.py @@ -0,0 +1,47 @@ +import logging +import threading + +import tensorrt as trt + +# TensorRT Python API Docs: https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/ErrorRecorder.html +# +# NOTE: TensorRT does not attempt to marshall errors across threads, so IErrorRecorder implementations must be thread-safe to allow for the possibility of errors occurring on multiple threads. 
+class TrTErrorRecorder(trt.IErrorRecorder): + def __init__(self): + self.lock = threading.Lock() + self.error_list = [] + super().__init__() + + def clear(self): + with self.lock: + self.error_list = [] + + def get_error_code(self, error_index): + with self.lock: + if error_index >= len(self.error_list) or error_index < 0: + raise IndexError(f'Invalid error index "{error_index}"') + + return self.error_list[error_index]['error_code'] + + def get_error_desc(self, error_index): + with self.lock: + if error_index >= len(self.error_list) or error_index < 0: + raise IndexError(f'Invalid error index "{error_index}"') + + return self.error_list[error_index]['error_desc'] + + def has_overflowed(self): + return False + + def num_errors(self): + with self.lock: + return len(self.error_list) + + def report_error(self, error_code, error_desc): + logging.error(f"TensorRT has encountered an error. ErrorCode: {error_code}. ErrorDesc: {error_desc}") + + with self.lock: + self.error_list.append({'error_code': error_code, 'error_desc': error_desc}) + + # TODO Future: return True for errors we consider 'fatal', which hints to TensorRT to stop execution. + return False diff --git a/tensorrt_loader.py b/tensorrt_loader.py index 8ceff29..d03b489 100644 --- a/tensorrt_loader.py +++ b/tensorrt_loader.py @@ -18,11 +18,13 @@ [os.path.join(folder_paths.models_dir, "tensorrt")], {".engine"}) import tensorrt as trt +from .tensorrt_error_recorder import TrTErrorRecorder trt.init_libnvinfer_plugins(None, "") logger = trt.Logger(trt.Logger.INFO) runtime = trt.Runtime(logger) +runtime.error_recorder = TrTErrorRecorder() # Is there a function that already exists for this? 
def trt_datatype_to_torch(datatype): @@ -35,11 +37,28 @@ def trt_datatype_to_torch(datatype): elif datatype == trt.bfloat16: return torch.bfloat16 +def check_for_trt_errors(runtime): + num_deserialize_errors = runtime.error_recorder.num_errors() + if num_deserialize_errors == 0: + return + + error_string = '' + for error_index in range(num_deserialize_errors): + if error_index > 0: + error_string += "\n" + error_string += runtime.error_recorder.get_error_desc(error_index) + runtime.error_recorder.clear() + raise RuntimeError(f'Failed to load TensorRT engine: {error_string}') + class TrTUnet: def __init__(self, engine_path): with open(engine_path, "rb") as f: self.engine = runtime.deserialize_cuda_engine(f.read()) + check_for_trt_errors(runtime) + self.context = self.engine.create_execution_context() + check_for_trt_errors(runtime) + self.dtype = torch.float16 def set_bindings_shape(self, inputs, split_batch):