diff --git a/axengine/_providers.py b/axengine/_providers.py index dfab02e..1e4f904 100644 --- a/axengine/_providers.py +++ b/axengine/_providers.py @@ -6,6 +6,7 @@ # import ctypes.util as cutil +import os providers = [] axengine_provider_name = 'AxEngineExecutionProvider' @@ -14,13 +15,22 @@ _axengine_lib_name = 'ax_engine' _axclrt_lib_name = 'axcl_rt' -# check if axcl_rt is installed, so if available, it's the default provider -if cutil.find_library(_axclrt_lib_name) is not None: - providers.append(axclrt_provider_name) +_AXCL_HOST_DEVICE = '/dev/axcl_host' +# check if axcl_rt is installed, so if available, it's the default provider +_has_axclrt_lib = cutil.find_library(_axclrt_lib_name) is not None # check if ax_engine is installed -if cutil.find_library(_axengine_lib_name) is not None: - providers.append(axengine_provider_name) +_has_axengine_lib = cutil.find_library(_axengine_lib_name) is not None + +if _has_axclrt_lib and _has_axengine_lib: + if os.path.exists(_AXCL_HOST_DEVICE): + providers = [axclrt_provider_name, axengine_provider_name] + else: + providers = [axengine_provider_name, axclrt_provider_name] +elif _has_axclrt_lib: + providers = [axclrt_provider_name] +elif _has_axengine_lib: + providers = [axengine_provider_name] def get_all_providers():