Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,746 changes: 1,558 additions & 188 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ serde_json = "1"
toml = "0"
env_logger = "0"
reqwest = { version = "0", default-features = false, features = ["http2", "json", "blocking", "multipart", "rustls-tls"] }
aws-config = "1.8.3"
aws-sdk-bedrockruntime = "1.99.0"
tokio = "1.47.0"

[dev-dependencies]
tempfile = "3"
Expand Down
113 changes: 113 additions & 0 deletions src/client/aws.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use log::debug;

use crate::client::request_schemas::AnthropicPrompt;
use crate::config::api::{ApiClient, ApiConfig, ApiError};
use crate::config::prompt::{Message, Prompt};
use crate::Api;

use aws_config::BehaviorVersion;
use aws_sdk_bedrockruntime::{operation::converse::ConverseOutput, Client as BedrockClient};
use tokio::runtime::Runtime;

/// Client for AWS Bedrock's Converse API.
///
/// Owns a current-thread tokio runtime so the async AWS SDK can be driven
/// from the synchronous `ApiClient` interface.
pub struct AwsClient {
    api_config: ApiConfig,
    client: BedrockClient,
    prompt: Prompt,
    runtime: Runtime,
}

impl AwsClient {
    /// Builds an `AwsClient`: creates a single-threaded tokio runtime, then
    /// uses it to load the default AWS configuration (credentials, region,
    /// profile — resolved by the SDK's standard chain).
    ///
    /// # Panics
    /// Panics if the tokio runtime cannot be created.
    pub fn new(api_config: ApiConfig, prompt: Prompt) -> Self {
        // `expect` carries the io::Error in the panic payload; no need for a
        // hand-rolled `match … panic!` here.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("AwsClient failed to initialize tokio runtime");
        let config = runtime
            .block_on(async { aws_config::load_defaults(BehaviorVersion::v2025_01_17()).await });
        let client = BedrockClient::new(&config);

        AwsClient {
            api_config,
            client,
            prompt,
            runtime,
        }
    }

    /// Extracts the first text block from a `Converse` response, mapping each
    /// missing or mismatched layer (no output / not a message / empty content
    /// / non-text content) to an `ApiError`.
    fn get_converse_output_text(&self, output: ConverseOutput) -> Result<String, ApiError> {
        // Lazy error constructor: `ok_or_else`/`map_err` with this closure
        // avoids cloning the model name on the success path (the original
        // `ok_or(ApiError::new(...))` built the error eagerly every call).
        let err = |msg: &str| ApiError::new(self.prompt.model.clone(), msg.to_string());
        let text = output
            .output()
            .ok_or_else(|| err("no output"))?
            .as_message()
            .map_err(|_| err("output not a message"))?
            .content()
            .first()
            .ok_or_else(|| err("no content in message"))?
            .as_text()
            .map_err(|_| err("content is not text"))?
            .to_string();
        Ok(text)
    }
}

impl ApiClient for AwsClient {
    /// Sends the prompt to AWS Bedrock's `Converse` API and returns the
    /// assistant's reply as a `Message`.
    ///
    /// # Errors
    /// Returns an `ApiError` when no model id is configured, when the service
    /// call fails, or when the response contains no usable text.
    fn do_request(&self) -> Result<Message, ApiError> {
        let prompt_format = match self.prompt.api {
            Api::AWSBedrock => AnthropicPrompt::from(self.prompt.clone()),
            Api::AnotherApiForTests => panic!("This api is not made for actual use."),
            _ => unreachable!(),
        };

        // A missing model id is a configuration problem, not a bug: surface
        // it as an ApiError instead of panicking via `unwrap()`.
        let model_id = self.prompt.model.as_ref().ok_or_else(|| {
            ApiError::new(None, "no model configured for AWS Bedrock".to_string())
        })?;

        let result = self.runtime.block_on(async {
            let response = self
                .client
                .converse()
                .model_id(model_id.as_str())
                .set_messages(Some(prompt_format.into()))
                .send()
                .await;

            match response {
                Ok(output) => self.get_converse_output_text(output),
                Err(e) => {
                    use aws_sdk_bedrockruntime::error::DisplayErrorContext;
                    // DisplayErrorContext prints the full error source chain.
                    debug!("error: {}", DisplayErrorContext(&e));

                    Err(e
                        .as_service_error()
                        .map(|e| ApiError::new(self.prompt.model.clone(), e.to_string()))
                        .unwrap_or_else(|| {
                            ApiError::new(
                                self.prompt.model.clone(),
                                "Unknown service error".to_string(),
                            )
                        }))
                }
            }
        });

        // `map` replaces the original `match { Ok => Ok, Err => Err }`.
        result.map(|response| Message::assistant(response.as_str()))
    }
}
4 changes: 4 additions & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// API clients: `aws` drives Bedrock through the AWS SDK; `reqwest` covers the
// HTTP-based providers. The request/response schema modules stay crate-private.
pub mod aws;
mod request_schemas;
pub mod reqwest;
mod response_schemas;
14 changes: 14 additions & 0 deletions src/text/request_schemas.rs → src/client/request_schemas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,17 @@ impl From<Prompt> for AnthropicPrompt {
}
}
}

/// Wire format of an outgoing request body. Serialized `untagged`, so the
/// JSON is exactly the inner prompt's shape with no enum wrapper.
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
pub enum PromptFormat {
    OpenAi(OpenAiPrompt),
    Anthropic(AnthropicPrompt),
    // Bedrock reuses the Anthropic prompt shape; a separate variant keeps the
    // API kind distinguishable in code even though the JSON is identical.
    AWSBedrock(AnthropicPrompt),
}

/// Converts an Anthropic-style prompt into the Bedrock message list.
///
/// Implemented as `From` rather than `Into` (clippy `from_over_into`): the
/// blanket impl still gives callers `.into()` for free.
impl From<AnthropicPrompt> for Vec<aws_sdk_bedrockruntime::types::Message> {
    fn from(prompt: AnthropicPrompt) -> Self {
        // The prompt is consumed, so move the messages out directly instead
        // of the original `iter().cloned()`, which cloned every message.
        prompt.messages.into_iter().map(Into::into).collect()
    }
}
102 changes: 102 additions & 0 deletions src/client/reqwest.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use std::time::Duration;

use super::request_schemas::{AnthropicPrompt, OpenAiPrompt, PromptFormat};
use super::response_schemas::{AnthropicResponse, OllamaResponse, OpenAiResponse};
use crate::config::api::{ApiClient, ApiConfig, ApiError};
use crate::config::prompt::{Message, Prompt};
use crate::utils::handle_api_response;
use crate::Api;

/// Blocking HTTP client used for all HTTP-based providers (OpenAI, Azure,
/// Anthropic, Ollama, Mistral, Groq, Cerebras).
pub struct ReqwestClient {
    api_config: ApiConfig,
    client: reqwest::blocking::Client,
    prompt: Prompt,
}

impl ReqwestClient {
    /// Creates a blocking reqwest client, applying the optional request
    /// timeout from the API configuration.
    ///
    /// # Panics
    /// Panics if the underlying HTTP client cannot be built.
    pub fn new(api_config: ApiConfig, prompt: Prompt) -> Self {
        // `None` leaves reqwest's default timeout behavior in place.
        let timeout = api_config
            .timeout_seconds
            .map(|secs| Duration::from_secs(secs.into()));
        let client = reqwest::blocking::Client::builder()
            .timeout(timeout)
            .build()
            .expect("Unable to initialize reqwest HTTP client");

        Self {
            api_config,
            client,
            prompt,
        }
    }
}

impl ApiClient for ReqwestClient {
    /// Builds and sends the HTTP request for the configured API, then parses
    /// the provider-specific response into an assistant `Message`.
    ///
    /// # Errors
    /// Returns an `ApiError` when the HTTP request itself fails.
    ///
    /// # Panics
    /// Panics for APIs this client does not serve (Bedrock, the test API) and
    /// when the Anthropic config is missing its required `version` key.
    fn do_request(&self) -> Result<Message, ApiError> {
        let prompt_format = match self.prompt.api {
            Api::Ollama
            | Api::Openai
            | Api::AzureOpenai
            | Api::Mistral
            | Api::Groq
            | Api::Cerebras => PromptFormat::OpenAi(OpenAiPrompt::from(self.prompt.clone())),
            Api::Anthropic => PromptFormat::Anthropic(AnthropicPrompt::from(self.prompt.clone())),
            Api::AWSBedrock => PromptFormat::AWSBedrock(AnthropicPrompt::from(self.prompt.clone())),
            Api::AnotherApiForTests => panic!("This api is not made for actual use."),
        };

        // Guard before any network activity: these APIs are never served by
        // this client, and the original code only caught them after building
        // (and, had it been reachable, sending) the request.
        if matches!(self.prompt.api, Api::AWSBedrock | Api::AnotherApiForTests) {
            unreachable!();
        }

        let request = self
            .client
            .post(&self.api_config.url)
            .header("Content-Type", "application/json")
            .json(&prompt_format);

        // Cerebras rejects requests without a User-Agent header, see
        // https://stackoverflow.com/questions/77862683/rust-reqwest-cant-make-a-request
        let request = match self.prompt.api {
            Api::Cerebras => request.header("User-Agent", "CUSTOM_NAME/1.0"),
            _ => request,
        };

        // Add auth if necessary
        let request = match self.prompt.api {
            Api::Openai | Api::Mistral | Api::Groq | Api::Cerebras => request.header(
                "Authorization",
                &format!("Bearer {}", &self.api_config.get_api_key()),
            ),
            Api::AzureOpenai => request.header("api-key", &self.api_config.get_api_key()),
            Api::Anthropic => request
                .header("x-api-key", &self.api_config.get_api_key())
                .header(
                    "anthropic-version",
                    self.api_config.version.as_ref().expect(
                        "version required for Anthropic, please add version key to your api config",
                    ),
                ),
            _ => request,
        };

        // Send exactly once, up front; the original repeated the
        // `send().map_err(...)` chain in every match arm below.
        let response = request
            .send()
            .map_err(|e| ApiError::new(self.prompt.model.clone(), e.to_string()))?;

        let response_text: String = match self.prompt.api {
            Api::Ollama => handle_api_response::<OllamaResponse>(response),
            Api::Openai | Api::AzureOpenai | Api::Mistral | Api::Groq | Api::Cerebras => {
                handle_api_response::<OpenAiResponse>(response)
            }
            Api::Anthropic => handle_api_response::<AnthropicResponse>(response),
            // Ruled out by the guard above.
            Api::AWSBedrock | Api::AnotherApiForTests => unreachable!(),
        };

        Ok(Message::assistant(&response_text))
    }
}
File renamed without changes.
40 changes: 39 additions & 1 deletion src/config/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;

use super::{prompt::Prompt, resolve_config_path};
use super::{prompt::Message, prompt::Prompt, resolve_config_path};

const API_KEYS_FILE: &str = ".api_configs.toml";

Expand All @@ -20,6 +20,7 @@ pub enum Api {
Groq,
Mistral,
Openai,
AWSBedrock,
AzureOpenai,
Cerebras,
}
Expand All @@ -46,6 +47,7 @@ impl ToString for Api {
match self {
Api::Ollama => "ollama".to_string(),
Api::Openai => "openai".to_string(),
Api::AWSBedrock => "awsbedrock".to_string(),
Api::AzureOpenai => "azureopenai".to_string(),
Api::Mistral => "mistral".to_string(),
Api::Groq => "groq".to_string(),
Expand Down Expand Up @@ -138,6 +140,17 @@ impl ApiConfig {
}
}

    /// Default API config for AWS Bedrock. No API key or URL: authentication
    /// and endpoint resolution are delegated to the AWS SDK's own
    /// credential/region chain, so only a default model id is set here.
    pub(super) fn awsbedrock() -> Self {
        ApiConfig {
            api_key_command: None,
            api_key: None,
            url: String::from(""),
            default_model: Some(String::from("us.anthropic.claude-3-7-sonnet-20250219-v1:0")),
            version: None,
            timeout_seconds: None,
        }
    }

pub(super) fn azureopenai() -> Self {
ApiConfig {
api_key_command: None,
Expand Down Expand Up @@ -202,6 +215,7 @@ pub(super) fn generate_api_keys_file() -> std::io::Result<()> {
let mut api_config = HashMap::new();
api_config.insert(Api::Ollama.to_string(), ApiConfig::ollama());
api_config.insert(Api::Openai.to_string(), ApiConfig::openai());
api_config.insert(Api::AWSBedrock.to_string(), ApiConfig::awsbedrock());
api_config.insert(Api::AzureOpenai.to_string(), ApiConfig::azureopenai());
api_config.insert(Api::Mistral.to_string(), ApiConfig::mistral());
api_config.insert(Api::Groq.to_string(), ApiConfig::groq());
Expand Down Expand Up @@ -244,3 +258,27 @@ pub fn get_api_config(api: &str) -> ApiConfig {
)
})
}

/// A client that can run the configured prompt against its API and return
/// the assistant's reply.
pub trait ApiClient {
    // Performs one request/response round trip for the client's prompt.
    fn do_request(&self) -> Result<Message, ApiError>;
}

/// Error returned by API clients: the model that was targeted (if any known)
/// plus a human-readable reason.
#[derive(Debug)]
pub struct ApiError {
    pub model: Option<String>,
    pub error: String,
}

impl std::fmt::Display for ApiError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display is user-facing: show the bare model name rather than the
        // `Some("…")` Debug form the original `{:?}` produced, and a clear
        // placeholder when no model is set.
        let model = self.model.as_deref().unwrap_or("<unknown model>");
        write!(f, "Can't invoke '{}'. Reason: {}", model, self.error)
    }
}

impl std::error::Error for ApiError {}

impl ApiError {
    /// Builds an `ApiError` for `model` with the given reason.
    pub fn new(model: Option<String>, error: String) -> Self {
        ApiError { model, error }
    }
}
18 changes: 18 additions & 0 deletions src/config/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::io::Write;
use std::path::PathBuf;

use crate::config::{api::Api, resolve_config_path};
use aws_sdk_bedrockruntime::types::ConversationRole;

const PROMPT_FILE: &str = "prompts.toml";
const CONVERSATION_FILE: &str = "conversation.toml";
Expand Down Expand Up @@ -92,6 +93,23 @@ impl Message {
}
}

/// Converts an internal chat `Message` into the AWS Bedrock message type.
///
/// Implemented as `From` rather than `Into` (clippy `from_over_into`); the
/// blanket impl still gives callers `.into()`.
///
/// # Panics
/// Panics on the "system" role — Bedrock carries system prompts outside the
/// message list — and if the Bedrock builder rejects the input.
impl From<Message> for aws_sdk_bedrockruntime::types::Message {
    fn from(message: Message) -> Self {
        let role = match message.role.as_str() {
            "assistant" => ConversationRole::Assistant,
            "user" => ConversationRole::User,
            _ => panic!("system role not supported for bedrock messages"),
        };
        aws_sdk_bedrockruntime::types::Message::builder()
            .role(role)
            .content(aws_sdk_bedrockruntime::types::ContentBlock::Text(
                message.content,
            ))
            .build()
            // `expect` documents why this can't fail instead of a bare unwrap.
            .expect("bedrock message builder: role and content are always set")
    }
}

/// Path of the prompts TOML file inside the resolved config directory.
pub(super) fn prompts_path() -> PathBuf {
    resolve_config_path().join(PROMPT_FILE)
}
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod client;
mod config;
mod prompt_customization;
mod text;
Expand Down
Loading