diff --git a/onnxruntime/core/providers/openvino/exported_symbols.lst b/onnxruntime/core/providers/openvino/exported_symbols.lst index f4c41412594af..6dc5905ae4550 100644 --- a/onnxruntime/core/providers/openvino/exported_symbols.lst +++ b/onnxruntime/core/providers/openvino/exported_symbols.lst @@ -1 +1,3 @@ _GetProvider +_CreateEpFactories +_ReleaseEpFactory diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index bad1d416eeda2..fda7ef6534197 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -403,6 +403,18 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { return ov_ep; } + // This is called during session creation when AppendExecutionProvider_V2 is used. + // This one is called because ParseProviderInfo / ParseConfigOptions, etc. are already + // performed in CreateIExecutionProvider, and so provider_info_ has already been populated. + std::unique_ptr CreateProvider_V2(const OrtSessionOptions& /*session_options*/, + const OrtLogger& session_logger) { + ProviderInfo provider_info = provider_info_; + auto ov_ep = std::make_unique(provider_info, shared_context_); + ov_ep->SetLogger(reinterpret_cast(&session_logger)); + return ov_ep; + } + + private: ProviderInfo provider_info_; std::shared_ptr shared_context_; @@ -433,6 +445,116 @@ struct OpenVINO_Provider : Provider { return std::make_shared(pi, SharedContext::Get()); } + Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* ep_metadata, + size_t num_devices, + ProviderOptions& provider_options, + const OrtSessionOptions& session_options, + const OrtLogger& logger, + std::unique_ptr& ep) override { + // Check if no devices are provided + if (num_devices == 0) { + return Status(common::ONNXRUNTIME, ORT_EP_FAIL, "No devices provided to CreateIExecutionProvider"); + } + + // For provider options that we don't support directly but are still supported through load_config, + // give some specific guidance & example about how to make use of the option through load_config. + const std::vector> block_and_advise_entries = { + {"cache_dir", "\"CACHE_DIR\": \"\""}, + {"precision", "\"INFERENCE_PRECISION_HINT\": \"F32\""}, + {"num_of_threads", "\"INFERENCE_NUM_THREADS\": \"1\""}, + {"num_streams", "\"NUM_STREAMS\": \"1\""}, + {"model_priority", "\"MODEL_PRIORITY\": \"LOW\""}, + {"enable_opencl_throttling", "\"GPU\": {\"PLUGIN_THROTTLE\": \"1\"}"}, + {"enable_qdq_optimizer", "\"NPU\": {\"NPU_QDQ_OPTIMIZATION\": \"YES\"}"} + }; + + for (auto& block_and_advise_entry : block_and_advise_entries) { + if (provider_options.find(block_and_advise_entry.first) != provider_options.end()) { + std::string message = "OpenVINO EP: Option '" + block_and_advise_entry.first + + "' cannot be set when using AppendExecutionProvider_V2. " + + "It can instead be enabled by a load_config key / value pair. For example: " + + block_and_advise_entry.second; + return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT, message); + } + } + + // For the rest of the disallowed provider options, give a generic error message. + const std::vector blocked_provider_keys = { + "device_type", "device_id", "device_luid", "context", "disable_dynamic_shapes"}; + + for (const auto& key : blocked_provider_keys) { + if (provider_options.find(key) != provider_options.end()) { + return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT, + "OpenVINO EP: Option '" + key + "' cannot be set when using AppendExecutionProvider_V2."); + } + } + + const char* ov_device_key = "ov_device"; + const char* ov_meta_device_key = "ov_meta_device"; + + // Create a unique list of ov_devices that were passed in. + std::unordered_set unique_ov_devices; + std::vector ordered_unique_ov_devices; + for (size_t i = 0; i < num_devices; ++i) { + const auto& device_meta_data = ep_metadata[i]; + auto ov_device_it = device_meta_data->Entries().find(ov_device_key); + if (ov_device_it == device_meta_data->Entries().end()) { + return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT, "OpenVINO EP device metadata not found."); + } + auto &ov_device = ov_device_it->second; + + // Add to ordered_unique only if not already present + if (unique_ov_devices.insert(ov_device).second) { + ordered_unique_ov_devices.push_back(ov_device); + } + } + + std::string ov_meta_device_type = "NONE"; + { + auto ov_meta_device_it = ep_metadata[0]->Entries().find(ov_meta_device_key); + if (ov_meta_device_it != ep_metadata[0]->Entries().end()) { + ov_meta_device_type = ov_meta_device_it->second; + } + } + + bool is_meta_device_factory = (ov_meta_device_type != "NONE"); + + if (ordered_unique_ov_devices.size() > 1 && !is_meta_device_factory) { + LOGS_DEFAULT(WARNING) << "[OpenVINO EP] Multiple devices were specified that are not OpenVINO meta devices. Using first ov_device only: " << ordered_unique_ov_devices.at(0); + ordered_unique_ov_devices.resize(1); // Use only the first device if not a meta device factory + } + + std::string ov_device_string; + if (is_meta_device_factory) { + // Build up a meta device string based on the devices that are passed in. E.g. AUTO:NPU,GPU.0,CPU + ov_device_string = ov_meta_device_type; + ov_device_string += ":"; + } + + bool prepend_comma = false; + for (const auto& ov_device : ordered_unique_ov_devices) { + if (prepend_comma) { + ov_device_string += ","; + } + ov_device_string += ov_device; + prepend_comma = true; + } + + provider_options["device_type"] = ov_device_string; + + // Parse provider info with the device type + ProviderInfo pi; + const auto& config_options = session_options.GetConfigOptions(); + ParseProviderInfo(provider_options, &config_options, pi); + ParseConfigOptions(pi); + + // Create and return the execution provider + auto factory = std::make_unique(pi, SharedContext::Get()); + ep = factory->CreateProvider_V2(session_options, logger); + return Status::OK(); + } + void Initialize() override { } diff --git a/onnxruntime/core/providers/openvino/ov_factory.cc b/onnxruntime/core/providers/openvino/ov_factory.cc new file mode 100644 index 0000000000000..e347bcf1b1aef --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_factory.cc @@ -0,0 +1,182 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include +#include +#include +#include +#include +#include +#include + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include "onnxruntime_c_api.h" +#include "ov_factory.h" +#include "openvino/openvino.hpp" +#include "ov_interface.h" + +using namespace onnxruntime::openvino_ep; +using ov_core_singleton = onnxruntime::openvino_ep::WeakSingleton; + +static void InitCxxApi(const OrtApiBase& ort_api_base) { + static std::once_flag init_api; + std::call_once(init_api, [&]() { + const OrtApi* ort_api = ort_api_base.GetApi(ORT_API_VERSION); + Ort::InitApi(ort_api); + }); +} + +OpenVINOEpPluginFactory::OpenVINOEpPluginFactory(ApiPtrs apis, const std::string& ov_metadevice_name, std::shared_ptr core) + : ApiPtrs{apis}, + ep_name_(ov_metadevice_name.empty() ? provider_name_ : std::string(provider_name_) + "." + ov_metadevice_name), + device_type_(ov_metadevice_name), + ov_core_(std::move(core)) { + OrtEpFactory::GetName = GetNameImpl; + OrtEpFactory::GetVendor = GetVendorImpl; + OrtEpFactory::GetVendorId = GetVendorIdImpl; + OrtEpFactory::GetSupportedDevices = GetSupportedDevicesImpl; + OrtEpFactory::GetVersion = GetVersionImpl; + OrtEpFactory::CreateDataTransfer = CreateDataTransferImpl; + + ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with. +} + +const std::vector& OpenVINOEpPluginFactory::GetOvDevices() { + static std::vector devices = ov_core_singleton::Get()->get_available_devices(); + return devices; +} + +const std::vector& OpenVINOEpPluginFactory::GetOvMetaDevices() { + static std::vector virtual_devices = [ov_core = ov_core_singleton::Get()] { + std::vector supported_virtual_devices{}; + for (const auto& meta_device : known_meta_devices_) { + try { + ov_core->get_property(meta_device, ov::supported_properties); + supported_virtual_devices.push_back(meta_device); + } catch (ov::Exception&) { + // meta device isn't supported. + } + } + return supported_virtual_devices; + }(); + + return virtual_devices; +} + +OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) { + size_t& num_ep_devices = *p_num_ep_devices; + + // Create a map for device type mapping + static const std::map ort_to_ov_device_name = { + {OrtHardwareDeviceType::OrtHardwareDeviceType_CPU, "CPU"}, + {OrtHardwareDeviceType::OrtHardwareDeviceType_GPU, "GPU"}, + {OrtHardwareDeviceType::OrtHardwareDeviceType_NPU, "NPU"}, + }; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (ort_api.HardwareDevice_VendorId(&device) != vendor_id_) { + // Not an Intel Device. + continue; + } + + auto device_type = ort_api.HardwareDevice_Type(&device); + auto device_it = ort_to_ov_device_name.find(device_type); + if (device_it == ort_to_ov_device_name.end()) { + // We don't know about this device type + continue; + } + + const auto& ov_device_type = device_it->second; + std::string ov_device_name; + auto get_pci_device_id = [&](const std::string& ov_device) { + try { + ov::device::PCIInfo pci_info = ov_core_->get_property(ov_device, ov::device::pci_info); + return pci_info.device; + } catch (ov::Exception&) { + return 0u; // If we can't get the PCI info, we won't have a device ID. + } + }; + + auto filtered_devices = GetOvDevices(ov_device_type); + auto matched_device = filtered_devices.begin(); + if (filtered_devices.size() > 1 && device_type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // If there are multiple devices of the same type, we need to match by device ID. + matched_device = std::find_if(filtered_devices.begin(), filtered_devices.end(), [&](const std::string& ov_device) { + uint32_t ort_device_id = ort_api.HardwareDevice_DeviceId(&device); + return ort_device_id == get_pci_device_id(ov_device); + }); + } + + if (matched_device == filtered_devices.end()) { + // We didn't find a matching OpenVINO device for the OrtHardwareDevice. + continue; + } + + // these can be returned as nullptr if you have nothing to add. + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + ort_api.CreateKeyValuePairs(&ep_metadata); + ort_api.AddKeyValuePair(ep_metadata, ov_device_key_, matched_device->c_str()); + + if (IsMetaDeviceFactory()) { + ort_api.AddKeyValuePair(ep_metadata, ov_meta_device_key_, device_type_.c_str()); + } + + // Create EP device + auto* status = ort_api.GetEpApi()->CreateEpDevice(this, &device, ep_metadata, ep_options, + &ep_devices[num_ep_devices++]); + + ort_api.ReleaseKeyValuePairs(ep_metadata); + ort_api.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return status; + } + } + + return nullptr; +} + +extern "C" { +// +// Public symbols +// +OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + InitCxxApi(*ort_api_base); + const ApiPtrs api_ptrs{Ort::GetApi(), Ort::GetEpApi(), Ort::GetModelEditorApi()}; + + // Get available devices from OpenVINO + auto ov_core = ov_core_singleton::Get(); + std::vector supported_factories = {""}; + const auto& meta_devices = OpenVINOEpPluginFactory::GetOvMetaDevices(); + supported_factories.insert(supported_factories.end(), meta_devices.begin(), meta_devices.end()); + + const size_t required_factories = supported_factories.size(); + if (max_factories < required_factories) { + return Ort::Status(std::format("Not enough space to return EP factories. Need at least {} factories.", required_factories).c_str(), ORT_INVALID_ARGUMENT); + } + + size_t factory_index = 0; + for (const auto& device_name : supported_factories) { + // Create a factory for this specific device + factories[factory_index++] = new OpenVINOEpPluginFactory(api_ptrs, device_name, ov_core); + } + + *num_factories = factory_index; + return nullptr; +} + +OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} +} diff --git a/onnxruntime/core/providers/openvino/ov_factory.h b/onnxruntime/core/providers/openvino/ov_factory.h new file mode 100644 index 0000000000000..37739f67323c1 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_factory.h @@ -0,0 +1,156 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include + +#include "core/providers/shared_library/provider_api.h" +#include "openvino/openvino.hpp" + +namespace onnxruntime { +namespace openvino_ep { + +struct ApiPtrs { + const OrtApi& ort_api; + const OrtEpApi& ep_api; + const OrtModelEditorApi& model_editor_api; +}; + +#define OVEP_DISABLE_MOVE(class_name) \ + class_name(class_name&&) = delete; \ + class_name& operator=(class_name&&) = delete; + +#define OVEP_DISABLE_COPY(class_name) \ + class_name(const class_name&) = delete; \ + class_name& operator=(const class_name&) = delete; + +#define OVEP_DISABLE_COPY_AND_MOVE(class_name) \ + OVEP_DISABLE_COPY(class_name) \ + OVEP_DISABLE_MOVE(class_name) + +template +static auto ApiEntry(Func&& func, std::optional> logger = std::nullopt) { + try { + return func(); + } catch (const Ort::Exception& ex) { + if (logger) { + ORT_CXX_LOG_NOEXCEPT(logger->get(), ORT_LOGGING_LEVEL_ERROR, ex.what()); + } + if constexpr (std::is_same_v) { + return Ort::Status(ex.what(), ex.GetOrtErrorCode()).release(); + } + } catch (const std::exception& ex) { + if (logger) { + ORT_CXX_LOG_NOEXCEPT(logger->get(), ORT_LOGGING_LEVEL_ERROR, ex.what()); + } + if constexpr (std::is_same_v) { + return Ort::Status(ex.what(), ORT_RUNTIME_EXCEPTION).release(); + } + } catch (...) { + if (logger) { + ORT_CXX_LOG_NOEXCEPT(logger->get(), ORT_LOGGING_LEVEL_ERROR, "Unknown exception occurred."); + } + if constexpr (std::is_same_v) { + return Ort::Status("Unknown exception occurred.", ORT_RUNTIME_EXCEPTION).release(); + } + } +} + +class OpenVINOEpPluginFactory : public OrtEpFactory, public ApiPtrs { + public: + OpenVINOEpPluginFactory(ApiPtrs apis, const std::string& ov_device, std::shared_ptr ov_core); + ~OpenVINOEpPluginFactory() = default; + + OVEP_DISABLE_COPY_AND_MOVE(OpenVINOEpPluginFactory) + + static const std::vector& GetOvDevices(); + + std::vector GetOvDevices(const std::string& device_type) { + std::vector filtered_devices; + const auto& devices = GetOvDevices(); + std::copy_if(devices.begin(), devices.end(), std::back_inserter(filtered_devices), + [&device_type](const std::string& device) { + return device.find(device_type) != std::string::npos; + }); + return filtered_devices; + } + + static const std::vector& GetOvMetaDevices(); + + // Member functions + const char* GetName() const { + return ep_name_.c_str(); + } + + const char* GetVendor() const { + return vendor_; + } + + OrtStatus* GetSupportedDevices(const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices); + + bool IsMetaDeviceFactory() const { + return known_meta_devices_.find(device_type_) != known_meta_devices_.end(); + } + + // Constants + static constexpr const char* vendor_ = "Intel"; + static constexpr uint32_t vendor_id_{0x8086}; // Intel's PCI vendor ID + static constexpr const char* ov_device_key_ = "ov_device"; + static constexpr const char* ov_meta_device_key_ = "ov_meta_device"; + static constexpr const char* provider_name_ = "OpenVINOExecutionProvider"; + + private: + std::string ep_name_; + std::string device_type_; + std::vector ov_devices_; + std::shared_ptr ov_core_; + inline static const std::set known_meta_devices_ = { + "AUTO"}; + + public: + // Static callback methods for the OrtEpFactory interface + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->GetName(); + } + + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->GetVendor(); + } + + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return OpenVINOEpPluginFactory::vendor_id_; + } + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + auto* factory = static_cast(this_ptr); + return ApiEntry([&]() { return factory->GetSupportedDevices(devices, num_devices, ep_devices, max_ep_devices, p_num_ep_devices); }); + } + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; // return nullptr to indicate that this EP does not support data transfer. + return nullptr; + } + + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory*) noexcept { + return ORT_VERSION; + } +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index ee35a3ebef7cb..6d1db4366410b 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -159,7 +159,7 @@ class OVInferRequest { ov::InferRequest& GetNewObj() { return ovInfReq; } - virtual void RewindKVCache(size_t index) {} + virtual void RewindKVCache([[maybe_unused]] size_t index) {} }; class StatefulOVInferRequest : public OVInferRequest { diff --git a/onnxruntime/core/providers/openvino/symbols.def b/onnxruntime/core/providers/openvino/symbols.def index 4ec2f7914c208..3afed01da1966 100644 --- a/onnxruntime/core/providers/openvino/symbols.def +++ b/onnxruntime/core/providers/openvino/symbols.def @@ -1,2 +1,4 @@ EXPORTS GetProvider + CreateEpFactories + ReleaseEpFactory diff --git a/onnxruntime/core/providers/openvino/version_script.lds b/onnxruntime/core/providers/openvino/version_script.lds index 094abb3329781..3600a4f8f4b51 100644 --- a/onnxruntime/core/providers/openvino/version_script.lds +++ b/onnxruntime/core/providers/openvino/version_script.lds @@ -1,7 +1,9 @@ #_init and _fini should be local VERS_1.0 { global: - GetProvider; + GetProvider; + CreateEpFactories; + ReleaseEpFactory; # Hide everything else. local: diff --git a/onnxruntime/test/providers/openvino/openvino_plugin.cc b/onnxruntime/test/providers/openvino/openvino_plugin.cc new file mode 100644 index 0000000000000..5abca55820a24 --- /dev/null +++ b/onnxruntime/test/providers/openvino/openvino_plugin.cc @@ -0,0 +1,302 @@ +#include +#include + +#include "gtest/gtest.h" +#include "core/common/common.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "onnxruntime_cxx_api.h" +#include "api_asserts.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +extern std::unique_ptr ort_env; + +struct OrtEpLibraryOv : public ::testing::Test { + static const inline std::filesystem::path library_path = +#if _WIN32 + "onnxruntime_providers_openvino.dll"; +#else + "libonnxruntime_providers_openvino.so"; +#endif + static const inline std::string registration_name = "OpenVINOExecutionProvider"; + + void SetUp() override { +#ifndef _WIN32 + GTEST_SKIP() << "Skipping OpenVINO EP tests as the OpenVINO plugin is not built."; +#endif + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + } + + void TearDown() override { +#ifndef _WIN32 + GTEST_SKIP() << "Skipping OpenVINO EP tests as the OpenVINO plugin is not built."; +#endif + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); + } + + void RunModelWithSession(Ort::Session& session) { + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + std::vector input0_data(6, 2.0f); + std::vector ort_inputs; + std::vector ort_input_names; + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_input_names.push_back("X"); + std::array output_names{"Y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); + } + + void RunModelWithPluginEp(Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + RunModelWithSession(session); + } + + void GenerateEpContextOnLegacyPath(std::filesystem::path epctx, bool embed_mode) { + Ort::SessionOptions session_options{}; + std::filesystem::remove(epctx); + // Add config option to enable EP context + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + session_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, epctx.string().c_str()); + session_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, embed_mode ? "1" : "0"); + session_options.AppendExecutionProvider_OpenVINO_V2({{"device_type", "CPU"}}); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + RunModelWithSession(session); + } + + void GenerateEpContextOnPluginPath(std::filesystem::path epctx, bool embed_mode) { + Ort::SessionOptions session_options{}; + std::filesystem::remove(epctx); + // Add config option to enable EP context + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + session_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, epctx.string().c_str()); + session_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, embed_mode ? "1" : "0"); + Ort::ConstEpDevice plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + RunModelWithSession(session); + } + + Ort::ConstEpDevice GetOvCpuEpDevice(std::string device_type = "CPU") { + auto ep_devices = ort_env->GetEpDevices(); + Ort::ConstEpDevice plugin_ep_device{}; + + for (Ort::ConstEpDevice& device : ep_devices) { + if (device.Device().Type() == OrtHardwareDeviceType_CPU && + std::string_view(device.EpName()).find(registration_name) != std::string::npos) { + const auto& meta_kv = device.EpMetadata().GetKeyValuePairs(); + auto device_type_it = meta_kv.find("ov_device"); + if (device_type_it != meta_kv.end()) { + if (device_type_it->second == device_type) { + plugin_ep_device = device; + break; + } + } + } + } + + return plugin_ep_device; + } +}; + +TEST_F(OrtEpLibraryOv, LoadUnloadPluginLibrary) { + auto ep_devices = ort_env->GetEpDevices(); + auto test_cpu_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(test_cpu_ep_device, nullptr); + ASSERT_STREQ(test_cpu_ep_device.EpVendor(), "Intel"); + Ort::ConstHardwareDevice device = test_cpu_ep_device.Device(); + ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU); + ASSERT_GE(device.VendorId(), 0); + ASSERT_GE(device.DeviceId(), 0); + ASSERT_NE(device.Vendor(), nullptr); + std::unordered_map ep_metadata_entries = test_cpu_ep_device.EpMetadata().GetKeyValuePairs(); + ASSERT_GT(ep_metadata_entries.size(), 0); + ASSERT_GT(ep_metadata_entries.count("ov_device"), 0); +} + +TEST_F(OrtEpLibraryOv, MetaDevicesAvailable) { + auto ep_devices = ort_env->GetEpDevices(); + auto expected_meta_devices = {"AUTO"}; + + for (auto& expected_meta_device : expected_meta_devices) { + std::string expected_ep_name = registration_name + "." + expected_meta_device; + auto it = std::find_if(ep_devices.begin(), ep_devices.end(), + [&](Ort::ConstEpDevice& device) { + return std::string_view(device.EpName()).find(expected_ep_name) != std::string::npos; + }); + bool meta_device_found = it != ep_devices.end(); + ASSERT_TRUE(meta_device_found) << "Expected to find " << expected_ep_name; + } +} + +TEST_F(OrtEpLibraryOv, RunSessionWithAllAUTODevices) { + auto ep_devices = ort_env->GetEpDevices(); + std::vector matching_devices; + + for (const auto& device : ep_devices) { + std::string ep_name = device.EpName(); + if (ep_name.find(registration_name) != std::string::npos && + (ep_name == registration_name + ".AUTO")) { + matching_devices.push_back(device); + } + } + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_V2(*ort_env, matching_devices, std::unordered_map{}); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); +} + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_MulInference) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + RunModelWithPluginEp(session_options); +} + +TEST_F(OrtEpLibraryOv, PluginEp_PreferCpu_MulInference) { + Ort::SessionOptions session_options; + session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); + RunModelWithPluginEp(session_options); +} + +struct EpCtxTestCases { + const ORTCHAR_T* ctx_filename; + bool embed_mode; +}; + +static const std::vector ep_context_cases = { + {ORT_TSTR("mul_1_ctx_cpu_embed1.onnx"), true}, + {ORT_TSTR("mul_1_ctx_cpu_embed0.onnx"), false}, + {ORT_TSTR("testdata/mul_1_ctx_cpu_embed0.onnx"), false}}; + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_cpu_epctx_variants) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + + for (const auto& test_case : ep_context_cases) { + GenerateEpContextOnLegacyPath(test_case.ctx_filename, test_case.embed_mode); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + Ort::Session session(*ort_env, test_case.ctx_filename, session_options); + RunModelWithSession(session); + } +} + +TEST_F(OrtEpLibraryOv, PluginEp_CheckV2DisallowedProviderOptions) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + std::vector> disallowed_provider_option_examples = { + {{"device_type", "CPU"}}, + {{"device_id", "CPU"}}, + {{"device_luid", "1234"}}, + {{"cache_dir", "cache"}}, + {{"precision", "F32"}}, + {{"context", "4"}}, + {{"num_of_threads", "1"}}, + {{"model_priority", "DEFAULT"}}, + {{"num_streams", "1"}}, + {{"enable_opencl_throttling", "true"}}, + {{"enable_qdq_optimizer", "true"}}, + {{"disable_dynamic_shapes", "true"}}, + }; + for (auto& example : disallowed_provider_option_examples) { + EXPECT_THROW({ + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, example); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); }, Ort::Exception); + } +} + +TEST_F(OrtEpLibraryOv, GenerateEpContextEmbedded) { + GenerateEpContextOnPluginPath(ORT_TSTR("mul_1_ctx_cpu_embed1.onnx"), true); +} + +TEST_F(OrtEpLibraryOv, GenerateEpContext) { + GenerateEpContextOnPluginPath(ORT_TSTR("mul_1_ctx_cpu_embed0.onnx"), false); +} + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_cpu_epctx_plugin_roundtrip_variants) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + + for (const auto& test_case : ep_context_cases) { + if (test_case.embed_mode) { + // TODO(ericcraw) Re-enable. + // Skip the embed mode until upstream fix. + continue; + } + + GenerateEpContextOnPluginPath(test_case.ctx_filename, test_case.embed_mode); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + Ort::Session session(*ort_env, test_case.ctx_filename, session_options); + RunModelWithSession(session); + } +} + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_cpu_epctx_plugin_roundtrip_variants_absolute) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + + for (const auto& test_case : ep_context_cases) { + if (test_case.embed_mode) { + // TODO(ericcraw) Re-enable. + // Skip the embed mode until upstream fix. + continue; + } + + auto absolute_path = std::filesystem::absolute(test_case.ctx_filename).native(); + GenerateEpContextOnPluginPath(absolute_path.c_str(), test_case.embed_mode); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + Ort::Session session(*ort_env, absolute_path.c_str(), session_options); + RunModelWithSession(session); + } +} + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_multiple_devices) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + + std::vector multi_device_list(2, plugin_ep_device); // 2 copies of cpu device. + + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_V2(*ort_env, multi_device_list, std::unordered_map{}); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); +} + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_mixed_factory_devices_throw_exception) { + auto ep_devices = ort_env->GetEpDevices(); + std::vector matching_devices; + + for (const auto& device : ep_devices) { + std::string ep_name = device.EpName(); + if (ep_name.find(registration_name) != std::string::npos && + (ep_name == registration_name || ep_name == registration_name + ".AUTO")) { + matching_devices.push_back(device); + } + } + + ASSERT_GT(matching_devices.size(), 1) << "Expected more than one matching EP device"; + + EXPECT_THROW({ + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_V2(*ort_env, matching_devices, std::unordered_map{}); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); }, Ort::Exception); +}