From 6518dcc8704c4b1a90848853eea65e696e91bc6d Mon Sep 17 00:00:00 2001 From: Martin Emde Date: Sat, 16 Aug 2025 09:45:41 -0700 Subject: [PATCH] Provider registry enables use of other providers This refactoring makes it possible to support more providers than the current two (openrouter and openai) without modifying the existing interface, fallback behavior, or expectations for consumers. --- Gemfile.lock | 2 +- README.md | 9 +- lib/raix/chat_completion.rb | 44 +-- lib/raix/configuration.rb | 52 +++- lib/raix/providers/open_router_provider.rb | 40 +++ lib/raix/providers/openai_provider.rb | 31 ++ spec/raix/chat_completion_spec.rb | 64 +++++ spec/raix/configuration_spec.rb | 270 ++++++++++++++++++ .../providers/open_router_provider_spec.rb | 64 +++++ spec/raix/providers/openai_provider_spec.rb | 67 +++++ 10 files changed, 607 insertions(+), 36 deletions(-) create mode 100644 lib/raix/providers/open_router_provider.rb create mode 100644 lib/raix/providers/openai_provider.rb create mode 100644 spec/raix/providers/open_router_provider_spec.rb create mode 100644 spec/raix/providers/openai_provider_spec.rb diff --git a/Gemfile.lock b/Gemfile.lock index 2591c43..060029c 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -1,7 +1,7 @@ PATH remote: . specs: - raix (1.0.2) + raix (1.0.3) activesupport (>= 6.0) faraday-retry (~> 2.0) open_router (~> 0.2) diff --git a/README.md b/README.md index 2cedff9..7f0f928 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ The second (optional) module that you can add to your Ruby classes after `ChatCo When the AI responds with tool function calls instead of a text message, Raix automatically: 1. Executes the requested tool functions -2. Adds the function results to the conversation transcript +2. Adds the function results to the conversation transcript 3. Sends the updated transcript back to the AI for another completion 4. Repeats this process until the AI responds with a regular text message @@ -739,6 +739,13 @@ You can add an initializer to your application's `config/initializers` directory You will also need to configure the OpenRouter API access token as per the instructions here: https://github.com/OlympiaAI/open_router?tab=readme-ov-file#quickstart +### Custom Providers + +You may register custom providers instead of either of the above by registering an object that responds +to `request(params:, model:, messages:)` and returns a OpenAI compatible chat completion response. + +See the [OpenAIProvider implementation](/OlympiaAI/raix/blob/main/lib/raix/providers/openai_provider.rb) for the OpenAI compatible provider implementation. + ### Global vs class level configuration You can either configure Raix globally or at the class level. The global configuration is set in the initializer as shown above. You can however also override all configuration options of the `Configuration` class on the class level with the diff --git a/lib/raix/chat_completion.rb b/lib/raix/chat_completion.rb index dd26ad4..e7dffa5 100644 --- a/lib/raix/chat_completion.rb +++ b/lib/raix/chat_completion.rb @@ -146,7 +146,7 @@ def chat_completion(params: {}, loop: false, json: false, raw: false, openai: ni response = if openai openai_request(params:, model: openai, messages:) else - openrouter_request(params:, model:, messages:) + provider_request(params:, model:, messages:) end retry_count = 0 content = nil @@ -174,7 +174,7 @@ def chat_completion(params: {}, loop: false, json: false, raw: false, openai: ni response = if openai openai_request(params:, model: openai, messages:) else - openrouter_request(params:, model:, messages:) + provider_request(params:, model:, messages:) end # Process the final response @@ -220,7 +220,7 @@ def chat_completion(params: {}, loop: false, json: false, raw: false, openai: ni response = if openai openai_request(params:, model: openai, messages:) else - openrouter_request(params:, model:, messages:) + provider_request(params:, model:, messages:) end content = response.dig("choices", 0, "message", "content") @@ -308,41 +308,23 @@ def filtered_tools(tool_names) end def openai_request(params:, model:, messages:) - if params[:prediction] - params.delete(:max_completion_tokens) - else - params[:max_completion_tokens] ||= params[:max_tokens] - params.delete(:max_tokens) - end - params[:stream] ||= stream.presence - params[:stream_options] = { include_usage: true } if params[:stream] - - params.delete(:temperature) if model.start_with?("o") || model.include?("gpt-5") + provider = configuration.provider(:openai) + raise "OpenAI provider not configured. Use configuration.openai_client = OpenAI::Client.new" unless provider - configuration.openai_client.chat(parameters: params.compact.merge(model:, messages:)) + provider.request(params:, model:, messages:) end - def openrouter_request(params:, model:, messages:) - # max_completion_tokens is not supported by OpenRouter - params.delete(:max_completion_tokens) + def provider_request(params:, model:, messages:) + params[:stream] ||= stream.presence - retry_count = 0 + # If openrouter is set, use it and pass provider as a parameter + # Otherwise, use provider to select the provider from the registry + provider = configuration.provider(:openrouter) || configuration.provider(params.delete(:provider)) - params.delete(:temperature) if model.start_with?("openai/o") || model.include?("gpt-5") + raise "No provider configured." unless provider - begin - configuration.openrouter_client.complete(messages, model:, extras: params.compact, stream:) - rescue OpenRouter::ServerError => e - if e.message.include?("retry") - warn "Retrying OpenRouter request... (#{retry_count} attempts) #{e.message}" - retry_count += 1 - sleep 1 * retry_count # backoff - retry if retry_count < 5 - end - - raise e - end + provider.request(params:, model:, messages:) end end end diff --git a/lib/raix/configuration.rb b/lib/raix/configuration.rb index 599cec5..fe85c83 100644 --- a/lib/raix/configuration.rb +++ b/lib/raix/configuration.rb @@ -1,5 +1,8 @@ # frozen_string_literal: true +require_relative "providers/open_router_provider" +require_relative "providers/openai_provider" + module Raix # The Configuration class holds the configuration options for the Raix gem. class Configuration @@ -30,11 +33,25 @@ def self.attr_accessor_with_fallback(method_name) # is normally set in each class that includes the ChatCompletion module. attr_accessor_with_fallback :model + attr_writer :openrouter_client, :openai_client + # The openrouter_client option determines the default client to use for communication. - attr_accessor_with_fallback :openrouter_client + def openrouter_client + value = @openrouter_client + return value if value + return unless fallback + + fallback.openrouter_client + end # The openai_client option determines the OpenAI client to use for communication. - attr_accessor_with_fallback :openai_client + def openai_client + value = @openai_client + return value if value + return unless fallback + + fallback.openai_client + end # The max_tool_calls option determines the maximum number of tool calls # before forcing a text response to prevent excessive function invocations. @@ -54,12 +71,41 @@ def initialize(fallback: nil) self.model = DEFAULT_MODEL self.max_tool_calls = DEFAULT_MAX_TOOL_CALLS self.fallback = fallback + @providers = {} end def client? - !!(openrouter_client || openai_client) + !!(openrouter_client || openai_client || @providers.any?) + end + + def register_provider(name, client) + @providers[name] = client end + # Find the provider to use based on the name, if given. + # Fall back to the next registered provider if no name is provided. + # We must use the openai_client and openrouter_client methods so that the + # previous fallback behavior is preserved. + def provider(name = nil) + # Prioritize use of registered providers before using openai_client or openrouter_client. + return @providers[name] if name && @providers.key?(name) + + # if openai is specified explicitly, use openai_client. + # if openrouter_client is set, use it for backwards compatibility. + # finally, use the named or first registered provider. + if name == :openai + openai_client ? Providers::OpenAIProvider.new(openai_client) : nil + elsif name == :openrouter || openrouter_client + openrouter_client ? Providers::OpenRouterProvider.new(openrouter_client) : nil + elsif @providers.any? + @providers.values.first + elsif fallback + fallback.provider(name) + end + end + + attr_reader :providers + private attr_accessor :fallback diff --git a/lib/raix/providers/open_router_provider.rb b/lib/raix/providers/open_router_provider.rb new file mode 100644 index 0000000..e18a21c --- /dev/null +++ b/lib/raix/providers/open_router_provider.rb @@ -0,0 +1,40 @@ +# frozen_string_literal: true + +module Raix + module Providers + # A wrapper around the OpenRouter client interface to make it compatible with the provider interface. + class OpenRouterProvider + attr_reader :client + + def initialize(client) + @client = client + end + + def request(params:, model:, messages:) + params = params.dup + + # max_completion_tokens is not supported by OpenRouter + params.delete(:max_completion_tokens) + + retry_count = 0 + + params.delete(:temperature) if model.start_with?("openai/o") || model.include?("gpt-5") + + stream = params.delete(:stream) + + begin + client.complete(messages, model:, extras: params.compact, stream:) + rescue ::OpenRouter::ServerError => e + if e.message.include?("retry") + warn "Retrying OpenRouter request... (#{retry_count} attempts) #{e.message}" + retry_count += 1 + sleep 1 * retry_count # backoff + retry if retry_count < 5 + end + + raise e + end + end + end + end +end diff --git a/lib/raix/providers/openai_provider.rb b/lib/raix/providers/openai_provider.rb new file mode 100644 index 0000000..c6ab9de --- /dev/null +++ b/lib/raix/providers/openai_provider.rb @@ -0,0 +1,31 @@ +# frozen_string_literal: true + +module Raix + module Providers + # A wrapper around the OpenAI client to make it compatible with the provider interface. + class OpenAIProvider + attr_reader :client + + def initialize(client) + @client = client + end + + def request(params:, model:, messages:) + params = params.dup + + if params[:prediction] + params.delete(:max_completion_tokens) + else + params[:max_completion_tokens] ||= params[:max_tokens] + params.delete(:max_tokens) + end + + params[:stream_options] = { include_usage: true } if params[:stream] + + params.delete(:temperature) if model.start_with?("o") || model.include?("gpt-5") + + client.chat(parameters: params.compact.merge(model:, messages:)) + end + end + end +end diff --git a/spec/raix/chat_completion_spec.rb b/spec/raix/chat_completion_spec.rb index 947ef07..0a428a3 100644 --- a/spec/raix/chat_completion_spec.rb +++ b/spec/raix/chat_completion_spec.rb @@ -22,6 +22,18 @@ def initialize end end +class TestOveriddenConfiguration + include Raix::ChatCompletion + + # Override the configuration accessor to make testing non-global + attr_accessor :configuration + + def initialize + self.model = "test-model" + transcript << { user: "What is the meaning of life?" } + end +end + RSpec.describe MeaningOfLife, :vcr do subject { described_class.new } @@ -67,3 +79,55 @@ def initialize subject.chat_completion end end + +RSpec.describe "Provider parameter behavior" do + context "when openrouter_client is set" do + it "passes provider as a parameter to openrouter" do + mock_openrouter = instance_double("OpenRouter::Client") + expect(mock_openrouter).to receive(:complete).with( + anything, + model: "test-model", + extras: hash_including(provider: "anthropic"), + stream: anything + ).and_return("choices" => [{ "message" => { "content" => "42" } }]) + + chat_client = TestOveriddenConfiguration.new + chat_client.configuration = Raix::Configuration.new + chat_client.configuration.openrouter_client = mock_openrouter + chat_client.provider = "anthropic" + + expect(chat_client.chat_completion).to eq("42") + end + end + + context "when openrouter_client is not set" do + it "uses provider parameter to select the registered provider" do + mock_provider = instance_double("CustomProvider") + expect(mock_provider).to receive(:request).with( + params: hash_not_including(:provider), + model: "test-model", + messages: anything + ).and_return("choices" => [{ "message" => { "content" => "42" } }]) + + chat_client = TestOveriddenConfiguration.new + chat_client.configuration = Raix::Configuration.new + chat_client.configuration.register_provider(:custom, mock_provider) + chat_client.provider = :custom + + expect(chat_client.chat_completion).to eq("42") + end + end + + context "when openrouter_client is not set and provider is not found" do + it "raises error" do + chat_client = TestOveriddenConfiguration.new + chat_client.configuration = Raix::Configuration.new + chat_client.provider = :nonexistent + + expect { chat_client.chat_completion }.to raise_error( + RuntimeError, + "No provider configured." + ) + end + end +end diff --git a/spec/raix/configuration_spec.rb b/spec/raix/configuration_spec.rb index 8663ac0..30a3c19 100644 --- a/spec/raix/configuration_spec.rb +++ b/spec/raix/configuration_spec.rb @@ -22,5 +22,275 @@ expect(configuration.client?).to eq false end end + + context "with a fallback with openai_client" do + it "returns true" do + configuration = Raix::Configuration.new + configuration.openrouter_client = OpenRouter::Client.new + child_config = Raix::Configuration.new(fallback: configuration) + + expect(child_config.client?).to be true + end + end + + context "with a fallback with openrouter_client" do + it "returns true" do + configuration = Raix::Configuration.new + configuration.openai_client = OpenAI::Client.new + child_config = Raix::Configuration.new(fallback: configuration) + + expect(child_config.client?).to be true + end + end + + context "with a fallback with neither client" do + it "returns false" do + configuration = Raix::Configuration.new + child_config = Raix::Configuration.new(fallback: configuration) + + expect(child_config.client?).to be false + end + end + end + + describe "registering providers" do + it "allows registering a provider with a name and client" do + configuration = Raix::Configuration.new + mock_client = double("MockClient") + + configuration.register_provider(:test_provider, mock_client) + + expect(configuration.providers[:test_provider]).to eq(mock_client) + end + + it "allows retrieving a registered provider" do + configuration = Raix::Configuration.new + mock_client = double("MockClient") + + configuration.register_provider(:test_provider, mock_client) + + expect(configuration.provider(:test_provider)).to eq(mock_client) + end + + it "returns nil for unregistered providers" do + configuration = Raix::Configuration.new + + expect(configuration.provider(:unknown)).to be_nil + end + + it "considers registered providers in client? check" do + configuration = Raix::Configuration.new + mock_client = double("MockClient") + + expect(configuration.client?).to be false + + configuration.register_provider(:test_provider, mock_client) + + expect(configuration.client?).to be true + end + end + + describe "internal provider storage" do + it "stores openai_client in the provider registry" do + configuration = Raix::Configuration.new + openai_client = double("OpenAI Client") + + configuration.openai_client = openai_client + + expect(configuration.provider(:openai)).to be_a(Raix::Providers::OpenAIProvider) + expect(configuration.provider(:openai).client).to eq(openai_client) + expect(configuration.openai_client).to eq(openai_client) + end + + it "stores openrouter_client in the provider registry" do + configuration = Raix::Configuration.new + openrouter_client = double("OpenRouter Client") + + configuration.openrouter_client = openrouter_client + + expect(configuration.provider(:openrouter)).to be_a(Raix::Providers::OpenRouterProvider) + expect(configuration.provider(:openrouter).client).to eq(openrouter_client) + expect(configuration.openrouter_client).to eq(openrouter_client) + end + + it "returns nil for provider when client is not set" do + configuration = Raix::Configuration.new + + expect(configuration.provider(:openai)).to be_nil + expect(configuration.provider(:openrouter)).to be_nil + end + end + + describe "#provider" do + let(:configuration) { Raix::Configuration.new } + + context "when a specific provider name is requested" do + context "and the provider is registered" do + it "returns the registered provider" do + mock_provider = double("MockProvider") + configuration.register_provider(:custom, mock_provider) + + expect(configuration.provider(:custom)).to eq(mock_provider) + end + end + + context "and the provider is not registered" do + context "when requesting :openai" do + it "returns OpenAIProvider if openai_client is set" do + openai_client = double("OpenAI Client") + configuration.openai_client = openai_client + + provider = configuration.provider(:openai) + expect(provider).to be_a(Raix::Providers::OpenAIProvider) + expect(provider.client).to eq(openai_client) + end + + it "returns nil if openai_client is not set" do + expect(configuration.provider(:openai)).to be_nil + end + end + + context "when requesting :openrouter" do + it "returns OpenRouterProvider if openrouter_client is set" do + openrouter_client = double("OpenRouter Client") + configuration.openrouter_client = openrouter_client + + provider = configuration.provider(:openrouter) + expect(provider).to be_a(Raix::Providers::OpenRouterProvider) + expect(provider.client).to eq(openrouter_client) + end + + it "returns nil if openrouter_client is not set" do + expect(configuration.provider(:openrouter)).to be_nil + end + end + + context "when requesting an unknown provider" do + it "returns nil" do + expect(configuration.provider(:unknown)).to be_nil + end + end + end + end + + context "when no provider name is specified" do + context "and openrouter_client is set" do + it "returns OpenRouterProvider for backwards compatibility" do + openrouter_client = double("OpenRouter Client") + configuration.openrouter_client = openrouter_client + + provider = configuration.provider + expect(provider).to be_a(Raix::Providers::OpenRouterProvider) + expect(provider.client).to eq(openrouter_client) + end + + it "prioritizes openrouter_client even if providers are registered" do + openrouter_client = double("OpenRouter Client") + configuration.openrouter_client = openrouter_client + configuration.register_provider(:custom, double("CustomProvider")) + + provider = configuration.provider + expect(provider).to be_a(Raix::Providers::OpenRouterProvider) + end + end + + context "and openrouter_client is not set but providers are registered" do + it "returns the first registered provider" do + first_provider = double("FirstProvider") + second_provider = double("SecondProvider") + + configuration.register_provider(:first, first_provider) + configuration.register_provider(:second, second_provider) + + expect(configuration.provider).to eq(first_provider) + end + end + + context "and no clients or providers are set" do + context "with a fallback configuration" do + let(:fallback) { Raix::Configuration.new } + let(:configuration) { Raix::Configuration.new(fallback:) } + + it "delegates to fallback's provider method" do + mock_provider = double("MockProvider") + fallback.register_provider(:fallback_provider, mock_provider) + + expect(configuration.provider).to eq(mock_provider) + end + + it "passes the name parameter to fallback" do + mock_provider = double("MockProvider") + fallback.register_provider(:custom, mock_provider) + + expect(configuration.provider(:custom)).to eq(mock_provider) + end + end + + context "without a fallback configuration" do + it "returns nil" do + expect(configuration.provider).to be_nil + end + end + end + end + + context "priority order verification" do + it "prioritizes registered providers over openai_client" do + mock_provider = double("MockProvider") + openai_client = double("OpenAI Client") + + configuration.openai_client = openai_client + configuration.register_provider(:openai, mock_provider) + + expect(configuration.provider(:openai)).to eq(mock_provider) + end + + it "prioritizes registered providers over openrouter_client" do + mock_provider = double("MockProvider") + openrouter_client = double("OpenRouter Client") + + configuration.openrouter_client = openrouter_client + configuration.register_provider(:openrouter, mock_provider) + + expect(configuration.provider(:openrouter)).to eq(mock_provider) + end + end + end + + describe "fallback behavior with provider registry" do + it "falls back to parent openai_client when not set" do + fallback = Raix::Configuration.new + openai_client = double("OpenAI Client") + fallback.openai_client = openai_client + + configuration = Raix::Configuration.new(fallback:) + + expect(configuration.openai_client).to eq(openai_client) + expect(configuration.provider(:openai).client).to eq(openai_client) + end + + it "falls back to parent openrouter_client when not set" do + fallback = Raix::Configuration.new + openrouter_client = double("OpenRouter Client") + fallback.openrouter_client = openrouter_client + + configuration = Raix::Configuration.new(fallback:) + + expect(configuration.openrouter_client).to eq(openrouter_client) + expect(configuration.provider(:openrouter).client).to eq(openrouter_client) + end + + it "uses local openai_client when set, ignoring parent" do + fallback = Raix::Configuration.new + openai_client = double("OpenAI Client") + fallback.openai_client = openai_client + + configuration = Raix::Configuration.new(fallback:) + local_openai_client = double("Local OpenAI Client") + configuration.openai_client = local_openai_client + + expect(configuration.openai_client).to eq(local_openai_client) + expect(configuration.provider(:openai).client).to eq(local_openai_client) + end end end diff --git a/spec/raix/providers/open_router_provider_spec.rb b/spec/raix/providers/open_router_provider_spec.rb new file mode 100644 index 0000000..1206485 --- /dev/null +++ b/spec/raix/providers/open_router_provider_spec.rb @@ -0,0 +1,64 @@ +# frozen_string_literal: true + +require "spec_helper" + +RSpec.describe Raix::Providers::OpenRouterProvider do + let(:openrouter_client) { double("OpenRouter Client") } + let(:messages) { [{ role: "user", content: "Hello" }] } + let(:provider) { described_class.new(openrouter_client) } + + it "wraps the OpenRouter client" do + expect(provider.client).to eq(openrouter_client) + end + + it "implements request method that calls the client's complete method" do + expect(openrouter_client).to receive(:complete).with( + messages, + model: "claude-3-opus", + extras: { temperature: 0.7 }, + stream: false + ).and_return({ "choices" => [{ "message" => { "content" => "Response" } }] }) + + result = provider.request( + params: { temperature: 0.7, max_completion_tokens: 1000, stream: false }, + model: "claude-3-opus", + messages: + ) + + expect(result).to eq({ "choices" => [{ "message" => { "content" => "Response" } }] }) + end + + it "removes max_completion_tokens parameter" do + expect(openrouter_client).to receive(:complete).with( + messages, + model: "claude-3-opus", + extras: { temperature: 0.7 }, + stream: false + ).and_return({ "choices" => [] }) + + provider.request( + params: { temperature: 0.7, max_completion_tokens: 1000, stream: false }, + model: "claude-3-opus", + messages: + ) + end + + it "handles retry logic for server errors" do + call_count = 0 + allow(openrouter_client).to receive(:complete) do + call_count += 1 + raise OpenRouter::ServerError, "Please retry" if call_count < 3 + + { "choices" => [{ "message" => { "content" => "Success" } }] } + end + + result = provider.request( + params: { stream: false }, + model: "claude-3-opus", + messages: + ) + + expect(result).to eq({ "choices" => [{ "message" => { "content" => "Success" } }] }) + expect(call_count).to eq(3) + end +end diff --git a/spec/raix/providers/openai_provider_spec.rb b/spec/raix/providers/openai_provider_spec.rb new file mode 100644 index 0000000..61ef75d --- /dev/null +++ b/spec/raix/providers/openai_provider_spec.rb @@ -0,0 +1,67 @@ +# frozen_string_literal: true + +require "spec_helper" + +RSpec.describe Raix::Providers::OpenAIProvider do + let(:openai_client) { double("OpenAI Client") } + let(:messages) { [{ role: "user", content: "Hello" }] } + let(:provider) { described_class.new(openai_client) } + + it "wraps the OpenAI client" do + expect(provider.client).to eq(openai_client) + end + + it "implements request method that calls the client's chat method" do + expected_params = { + temperature: 0.7, + max_completion_tokens: 1000, + stream: true, + stream_options: { include_usage: true }, + model: "gpt-4", + messages: + } + + expect(openai_client).to receive(:chat).with( + parameters: expected_params + ).and_return({ "choices" => [{ "message" => { "content" => "Response" } }] }) + + result = provider.request( + params: { temperature: 0.7, max_completion_tokens: 1000, stream: true }, + model: "gpt-4", + messages: + ) + + expect(result).to eq({ "choices" => [{ "message" => { "content" => "Response" } }] }) + end + + it "removes temperature for o-models" do + expect(openai_client).to receive(:chat).with( + parameters: { max_completion_tokens: 1000, model: "o1-preview", messages: } + ).and_return({ "choices" => [] }) + + provider.request( + params: { temperature: 0.7, max_completion_tokens: 1000 }, + model: "o1-preview", + messages: + ) + end + + it "handles prediction parameters correctly" do + prediction_params = { + prediction: { type: "content", content: "predicted text" }, + stream: false, + model: "gpt-4", + messages: + } + + expect(openai_client).to receive(:chat).with( + parameters: prediction_params + ).and_return({ "choices" => [] }) + + provider.request( + params: prediction_params.merge(max_completion_tokens: 1000), + model: "gpt-4", + messages: + ) + end +end