Skip to content

add support for token provider #1587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 1 addition & 50 deletions src/azure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import * as Errors from './error';
import { FinalRequestOptions } from './internal/request-options';
import { isObj, readEnv } from './internal/utils';
import { ClientOptions, OpenAI } from './client';
import { buildHeaders, NullableHeaders } from './internal/headers';

/** API Client for interfacing with the Azure OpenAI API. */
export interface AzureClientOptions extends ClientOptions {
Expand Down Expand Up @@ -37,7 +36,6 @@ export interface AzureClientOptions extends ClientOptions {

/** API Client for interfacing with the Azure OpenAI API. */
export class AzureOpenAI extends OpenAI {
private _azureADTokenProvider: (() => Promise<string>) | undefined;
deploymentName: string | undefined;
apiVersion: string = '';

Expand Down Expand Up @@ -90,9 +88,6 @@ export class AzureOpenAI extends OpenAI {
);
}

// define a sentinel value to avoid any typing issues
apiKey ??= API_KEY_SENTINEL;

opts.defaultQuery = { ...opts.defaultQuery, 'api-version': apiVersion };

if (!baseURL) {
Expand All @@ -114,13 +109,12 @@ export class AzureOpenAI extends OpenAI {
}

super({
apiKey,
apiKey: azureADTokenProvider ?? apiKey,
baseURL,
...opts,
...(dangerouslyAllowBrowser !== undefined ? { dangerouslyAllowBrowser } : {}),
});

this._azureADTokenProvider = azureADTokenProvider;
this.apiVersion = apiVersion;
this.deploymentName = deployment;
}
Expand All @@ -140,47 +134,6 @@ export class AzureOpenAI extends OpenAI {
}
return super.buildRequest(options, props);
}

async _getAzureADToken(): Promise<string | undefined> {
if (typeof this._azureADTokenProvider === 'function') {
const token = await this._azureADTokenProvider();
if (!token || typeof token !== 'string') {
throw new Errors.OpenAIError(
`Expected 'azureADTokenProvider' argument to return a string but it returned ${token}`,
);
}
return token;
}
return undefined;
}

protected override async authHeaders(opts: FinalRequestOptions): Promise<NullableHeaders | undefined> {
return;
}

protected override async prepareOptions(opts: FinalRequestOptions): Promise<void> {
opts.headers = buildHeaders([opts.headers]);

/**
* The user should provide a bearer token provider if they want
* to use Azure AD authentication. The user shouldn't set the
* Authorization header manually because the header is overwritten
* with the Azure AD token if a bearer token provider is provided.
*/
if (opts.headers.values.get('Authorization') || opts.headers.values.get('api-key')) {
return super.prepareOptions(opts);
}

const token = await this._getAzureADToken();
if (token) {
opts.headers.values.set('Authorization', `Bearer ${token}`);
} else if (this.apiKey !== API_KEY_SENTINEL) {
opts.headers.values.set('api-key', this.apiKey);
} else {
throw new Errors.OpenAIError('Unable to handle auth');
}
return super.prepareOptions(opts);
}
}

const _deployments_endpoints = new Set([
Expand All @@ -194,5 +147,3 @@ const _deployments_endpoints = new Set([
'/batches',
'/images/edits',
]);

const API_KEY_SENTINEL = '<Missing Key>';
31 changes: 20 additions & 11 deletions src/beta/realtime/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {
const dangerouslyAllowBrowser =
props.dangerouslyAllowBrowser ??
(client as any)?._options?.dangerouslyAllowBrowser ??
(client?.apiKey.startsWith('ek_') ? true : null);

(typeof (client as any)?._options?.apiKey === 'string' && client?.apiKey?.startsWith('ek_') ?
true
: null);
if (!dangerouslyAllowBrowser && isRunningInBrowser()) {
throw new OpenAIError(
"It looks like you're running in a browser-like environment.\n\nThis is disabled by default, as it risks exposing your secret API credentials to attackers.\n\nYou can avoid this error by creating an ephemeral session token:\nhttps://platform.openai.com/docs/api-reference/realtime-sessions\n",
Expand All @@ -49,6 +50,10 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {

client ??= new OpenAI({ dangerouslyAllowBrowser });

if (typeof (client as any)?._options?.apiKey !== 'string') {
throw new Error('Call the create method instead to construct the client');
}

this.url = buildRealtimeURL(client, props.model);
props.onURL?.(this.url);

Expand Down Expand Up @@ -94,20 +99,24 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {
}
}

static async create(
client: Pick<OpenAI, 'apiKey' | 'baseURL' | '_setApiKey'>,
props: { model: string; dangerouslyAllowBrowser?: boolean },
): Promise<OpenAIRealtimeWebSocket> {
await client._setApiKey();
return new OpenAIRealtimeWebSocket(props, client);
}

static async azure(
client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
client: Pick<AzureOpenAI, '_setApiKey' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
options: { deploymentName?: string; dangerouslyAllowBrowser?: boolean } = {},
): Promise<OpenAIRealtimeWebSocket> {
const token = await client._getAzureADToken();
const isToken = await client._setApiKey();
function onURL(url: URL) {
if (client.apiKey !== '<Missing Key>') {
url.searchParams.set('api-key', client.apiKey);
if (isToken) {
url.searchParams.set('Authorization', `Bearer ${client.apiKey}`);
} else {
if (token) {
url.searchParams.set('Authorization', `Bearer ${token}`);
} else {
throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
}
url.searchParams.set('api-key', client.apiKey);
}
}
const deploymentName = options.deploymentName ?? client.deploymentName;
Expand Down
28 changes: 17 additions & 11 deletions src/beta/realtime/ws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
) {
super();
client ??= new OpenAI();

if (typeof (client as any)._options.apiKey !== 'string') {
throw new Error('Call the create method instead to construct the client');
}
this.url = buildRealtimeURL(client, props.model);
this.socket = new WS.WebSocket(this.url, {
...props.options,
Expand Down Expand Up @@ -51,8 +53,16 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
});
}

static async create(
client: Pick<OpenAI, 'apiKey' | 'baseURL' | '_setApiKey'>,
props: { model: string; options?: WS.ClientOptions | undefined },
): Promise<OpenAIRealtimeWS> {
await client._setApiKey();
return new OpenAIRealtimeWS(props, client);
}

static async azure(
client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
client: Pick<AzureOpenAI, '_setApiKey' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
options: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {},
): Promise<OpenAIRealtimeWS> {
const deploymentName = options.deploymentName ?? client.deploymentName;
Expand Down Expand Up @@ -82,15 +92,11 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
}
}

async function getAzureHeaders(client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiKey'>) {
if (client.apiKey !== '<Missing Key>') {
return { 'api-key': client.apiKey };
async function getAzureHeaders(client: Pick<AzureOpenAI, '_setApiKey' | 'apiKey'>) {
const isToken = await client._setApiKey();
if (isToken) {
return { Authorization: `Bearer ${isToken}` };
} else {
const token = await client._getAzureADToken();
if (token) {
return { Authorization: `Bearer ${token}` };
} else {
throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
}
return { 'api-key': client.apiKey };
}
}
39 changes: 34 additions & 5 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,13 @@ import {
} from './internal/utils/log';
import { isEmptyObj } from './internal/utils/values';

export type ApiKeySetter = () => Promise<string>;

export interface ClientOptions {
/**
* Defaults to process.env['OPENAI_API_KEY'].
*/
apiKey?: string | undefined;

apiKey?: string | ApiKeySetter | undefined;
/**
* Defaults to process.env['OPENAI_ORG_ID'].
*/
Expand Down Expand Up @@ -349,7 +350,7 @@ export class OpenAI {
}: ClientOptions = {}) {
if (apiKey === undefined) {
throw new Errors.OpenAIError(
"The OPENAI_API_KEY environment variable is missing or empty; either provide it, or instantiate the OpenAI client with an apiKey option, like new OpenAI({ apiKey: 'My API Key' }).",
'Missing credentials. Please pass an `apiKey`, or set the `OPENAI_API_KEY` environment variable.',
);
}

Expand Down Expand Up @@ -385,7 +386,7 @@ export class OpenAI {

this._options = options;

this.apiKey = apiKey;
this.apiKey = typeof apiKey === 'string' ? apiKey : 'Missing Key';
this.organization = organization;
this.project = project;
this.webhookSecret = webhookSecret;
Expand Down Expand Up @@ -453,6 +454,32 @@ export class OpenAI {
return Errors.APIError.generate(status, error, message, headers);
}

async _setApiKey(): Promise<boolean> {
const apiKey = this._options.apiKey;
if (typeof apiKey === 'function') {
try {
const token = await apiKey();
if (!token || typeof token !== 'string') {
throw new Errors.OpenAIError(
`Expected 'apiKey' function argument to return a string but it returned ${token}`,
);
}
this.apiKey = token;
return true;
} catch (err: any) {
if (err instanceof Errors.OpenAIError) {
throw err;
}
throw new Errors.OpenAIError(
`Failed to get token from 'apiKey' function: ${err.message}`,
// @ts-ignore
{ cause: err },
);
}
}
return false;
}

buildURL(
path: string,
query: Record<string, unknown> | null | undefined,
Expand All @@ -479,7 +506,9 @@ export class OpenAI {
/**
* Used as a callback for mutating the given `FinalRequestOptions` object.
*/
protected async prepareOptions(options: FinalRequestOptions): Promise<void> {}
protected async prepareOptions(options: FinalRequestOptions): Promise<void> {
await this._setApiKey();
}

/**
* Used as a callback for mutating the given `RequestInit` object.
Expand Down
77 changes: 77 additions & 0 deletions tests/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -719,4 +719,81 @@ describe('retries', () => {
).toEqual(JSON.stringify({ a: 1 }));
expect(count).toEqual(3);
});

describe('auth', () => {
test('apiKey', async () => {
const client = new OpenAI({
baseURL: 'http://localhost:5000/',
apiKey: 'My API Key',
});
const { req } = await client.buildRequest({ path: '/foo', method: 'get' });
expect(req.headers.get('authorization')).toEqual('Bearer My API Key');
});

test('token', async () => {
const testFetch = async (url: any, { headers }: RequestInit = {}): Promise<Response> => {
return new Response(JSON.stringify({}), { headers: headers ?? [] });
};
const client = new OpenAI({
baseURL: 'http://localhost:5000/',
apiKey: async () => 'my token',
fetch: testFetch,
});
expect(
(await client.request({ method: 'post', path: 'https://example.com' }).asResponse()).headers.get(
'authorization',
),
).toEqual('Bearer my token');
});

test('token is refreshed', async () => {
let fail = true;
const testFetch = async (url: any, { headers }: RequestInit = {}): Promise<Response> => {
if (fail) {
fail = false;
return new Response(undefined, {
status: 429,
headers: {
'Retry-After': '0.1',
},
});
}
return new Response(JSON.stringify({}), {
headers: headers ?? [],
});
};
let counter = 0;
async function apiKey() {
return `token-${counter++}`;
}
const client = new OpenAI({
baseURL: 'http://localhost:5000/',
apiKey,
fetch: testFetch,
});
expect(
(
await client.chat.completions
.create({
model: '',
messages: [{ role: 'system', content: 'Hello' }],
})
.asResponse()
).headers.get('authorization'),
).toEqual('Bearer token-1');
});

test('at least one', () => {
try {
new OpenAI({
baseURL: 'http://localhost:5000/',
});
} catch (error: any) {
expect(error).toBeInstanceOf(Error);
expect(error.message).toEqual(
'Missing credentials. Please pass one of `apiKey` and `tokenProvider`, or set the `OPENAI_API_KEY` environment variable.',
);
}
});
});
});
Loading
Loading