Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions axengine/_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#

import ctypes.util as cutil
import os

providers = []
axengine_provider_name = 'AxEngineExecutionProvider'
Expand All @@ -14,13 +15,22 @@
_axengine_lib_name = 'ax_engine'
_axclrt_lib_name = 'axcl_rt'

# check if axcl_rt is installed, so if available, it's the default provider
if cutil.find_library(_axclrt_lib_name) is not None:
providers.append(axclrt_provider_name)
_AXCL_HOST_DEVICE = '/dev/axcl_host'

# check if axcl_rt is installed, so if available, it's the default provider
_has_axclrt_lib = cutil.find_library(_axclrt_lib_name) is not None
# check if ax_engine is installed
if cutil.find_library(_axengine_lib_name) is not None:
providers.append(axengine_provider_name)
_has_axengine_lib = cutil.find_library(_axengine_lib_name) is not None

if _has_axclrt_lib and _has_axengine_lib:
if os.path.exists(_AXCL_HOST_DEVICE):
providers = [axclrt_provider_name, axengine_provider_name]
else:
providers = [axengine_provider_name, axclrt_provider_name]
elif _has_axclrt_lib:
providers = [axclrt_provider_name]
elif _has_axengine_lib:
providers = [axengine_provider_name]


def get_all_providers():
Expand Down
Loading