Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/apis/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@ use std::collections::HashMap;

use crate::requests::Requests;
use crate::*;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Serialize, Serializer};

fn serialize_f32_two_decimals<S>(value: &Option<f32>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(v) => serializer.serialize_f64((*v as f64 * 100.0).round() / 100.0),
None => serializer.serialize_none(),
}
}


use super::{completions::Completion, CHAT_COMPLETION_CREATE};

Expand All @@ -23,14 +34,14 @@ pub struct ChatBody {
/// while lower values like 0.2 will make it more focused and deterministic.
/// We generally recommend altering this or top_p but not both.
/// Defaults to 1
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(skip_serializing_if = "Option::is_none", serialize_with = "serialize_f32_two_decimals")]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling,
/// where the model considers the results of the tokens with top_p probability mass.
/// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
/// We generally recommend altering this or temperature but not both.
/// Defaults to 1
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(skip_serializing_if = "Option::is_none", serialize_with = "serialize_f32_two_decimals")]
pub top_p: Option<f32>,
/// How many chat completion choices to generate for each input message.
/// Defaults to 1
Expand All @@ -55,13 +66,13 @@ pub struct ChatBody {
/// Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics.
/// Defaults to 0
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(skip_serializing_if = "Option::is_none", serialize_with = "serialize_f32_two_decimals")]
pub presence_penalty: Option<f32>,
/// Number between -2.0 and 2.0.
/// Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
/// Defaults to 0
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(skip_serializing_if = "Option::is_none", serialize_with = "serialize_f32_two_decimals")]
pub frequency_penalty: Option<f32>,
/// Modify the likelihood of specified tokens appearing in the completion.
/// Accepts a json object that maps tokens (specified by their token ID in the tokenizer)
Expand Down
7 changes: 6 additions & 1 deletion src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ fn deal_response(response: Result<ureq::Response, ureq::Error>, sub_url: &str) -
},
Err(err) => match err {
ureq::Error::Status(status, response) => {
let error_msg = response.into_json::<Json>().unwrap();
let mut error_msg = response
.into_json::<Json>()
.unwrap_or_else(|x| serde_json::Value::String(x.to_string()));
if let serde_json::Value::String(ref mut s) = error_msg {
*s = format!("status: {}, msg: {}", status, s);
}
error!("<== ❌\n\tError api: {sub_url}, status: {status}, error: {error_msg}");
return Err(Error::ApiError(format!("{error_msg}")));
},
Expand Down