@@ -33,11 +33,12 @@ var (
3333)
3434
3535type Client struct {
36- url string
37- key string
3836 defaultModel string
3937 c * openai.Client
4038 cache * cache.Client
39+ invalidAuth bool
40+ cacheKeyBase string
41+ setSeed bool
4142}
4243
4344type Options struct {
@@ -47,6 +48,8 @@ type Options struct {
4748 APIType openai.APIType `usage:"OpenAI API Type (valid: OPEN_AI, AZURE, AZURE_AD)" name:"openai-api-type" env:"OPENAI_API_TYPE"`
4849 OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"`
4950 DefaultModel string `usage:"Default LLM model to use" default:"gpt-4-turbo-preview"`
51+ SetSeed bool `usage:"-"`
52+ CacheKey string `usage:"-"`
5053 Cache * cache.Client
5154}
5255
@@ -59,6 +62,8 @@ func complete(opts ...Options) (result Options, err error) {
5962 result .APIVersion = types .FirstSet (opt .APIVersion , result .APIVersion )
6063 result .APIType = types .FirstSet (opt .APIType , result .APIType )
6164 result .DefaultModel = types .FirstSet (opt .DefaultModel , result .DefaultModel )
65+ result .SetSeed = types .FirstSet (opt .SetSeed , result .SetSeed )
66+ result .CacheKey = types .FirstSet (opt .CacheKey , result .CacheKey )
6267 }
6368
6469 if result .Cache == nil {
@@ -75,10 +80,6 @@ func complete(opts ...Options) (result Options, err error) {
7580 result .APIKey = key
7681 }
7782
78- if result .APIKey == "" && result .BaseURL == "" {
79- return result , fmt .Errorf ("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable" )
80- }
81-
8283 return result , err
8384}
8485
@@ -112,13 +113,28 @@ func NewClient(opts ...Options) (*Client, error) {
112113 cfg .APIVersion = types .FirstSet (opt .APIVersion , cfg .APIVersion )
113114 cfg .APIType = types .FirstSet (opt .APIType , cfg .APIType )
114115
116+ cacheKeyBase := opt .CacheKey
117+ if cacheKeyBase == "" {
118+ cacheKeyBase = hash .ID (opt .APIKey , opt .BaseURL )
119+ }
120+
115121 return & Client {
116122 c : openai .NewClientWithConfig (cfg ),
117123 cache : opt .Cache ,
118124 defaultModel : opt .DefaultModel ,
125+ cacheKeyBase : cacheKeyBase ,
126+ invalidAuth : opt .APIKey == "" && opt .BaseURL == "" ,
127+ setSeed : opt .SetSeed ,
119128 }, nil
120129}
121130
131+ func (c * Client ) ValidAuth () error {
132+ if c .invalidAuth {
133+ return fmt .Errorf ("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable" )
134+ }
135+ return nil
136+ }
137+
122138func (c * Client ) Supports (ctx context.Context , modelName string ) (bool , error ) {
123139 models , err := c .ListModels (ctx )
124140 if err != nil {
@@ -133,6 +149,10 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
133149 return nil , nil
134150 }
135151
152+ if err := c .ValidAuth (); err != nil {
153+ return nil , err
154+ }
155+
136156 models , err := c .c .ListModels (ctx )
137157 if err != nil {
138158 return nil , err
@@ -146,8 +166,7 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
146166
147167func (c * Client ) cacheKey (request openai.ChatCompletionRequest ) string {
148168 return hash .Encode (map [string ]any {
149- "url" : c .url ,
150- "key" : c .key ,
169+ "base" : c .cacheKeyBase ,
151170 "request" : request ,
152171 })
153172}
@@ -277,6 +296,10 @@ func toMessages(request types.CompletionRequest) (result []openai.ChatCompletion
277296}
278297
279298func (c * Client ) Call (ctx context.Context , messageRequest types.CompletionRequest , status chan <- types.CompletionStatus ) (* types.CompletionMessage , error ) {
299+ if err := c .ValidAuth (); err != nil {
300+ return nil , err
301+ }
302+
280303 if messageRequest .Model == "" {
281304 messageRequest .Model = c .defaultModel
282305 }
@@ -296,10 +319,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
296319 }
297320
298321 if messageRequest .Temperature == nil {
299- // this is a hack because the field is marked as omitempty, so we need it to be set to a non-zero value but arbitrarily small
300- request .Temperature = 1e-08
322+ request .Temperature = new (float32 )
301323 } else {
302- request .Temperature = * messageRequest .Temperature
324+ request .Temperature = messageRequest .Temperature
303325 }
304326
305327 if messageRequest .JSONResponse {
@@ -330,7 +352,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
330352 }
331353
332354 var cacheResponse bool
333- request .Seed = ptr (c .seed (request ))
355+ if c .setSeed {
356+ request .Seed = ptr (c .seed (request ))
357+ }
334358 response , ok , err := c .fromCache (ctx , messageRequest , request )
335359 if err != nil {
336360 return nil , err
0 commit comments