diff --git a/src/default-providers/ChromeAI/chat.ts b/src/default-providers/ChromeAI/chat.ts new file mode 100644 index 0000000..77e4775 --- /dev/null +++ b/src/default-providers/ChromeAI/chat.ts @@ -0,0 +1,48 @@ +import { + BaseChatModel, + BaseChatModelCallOptions +} from '@langchain/core/language_models/chat_models'; +import { + AIMessageChunk, + BaseMessage, + AIMessage +} from '@langchain/core/messages'; +import { ChromeAI as ChromeLLM } from '@langchain/community/experimental/llms/chrome_ai'; +import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'; +import { ChatResult, ChatGeneration } from '@langchain/core/outputs'; + +export interface IChromeChatCallOptions extends BaseChatModelCallOptions {} + +export class ChromeChatModel extends BaseChatModel< + IChromeChatCallOptions, + AIMessageChunk +> { + private llm: ChromeLLM; + + constructor(fields?: ConstructorParameters[0]) { + super(fields ?? {}); + this.llm = new ChromeLLM(fields ?? {}); + } + + _llmType() { + return 'chrome-chat'; + } + + async _generate( + messages: BaseMessage[], + options: IChromeChatCallOptions, + runManager?: CallbackManagerForLLMRun + ): Promise { + const text = messages.map(m => m.content).join('\n'); + const completion = await this.llm.invoke(text, options); + + const generations: ChatGeneration[] = [ + { + text: completion, + message: new AIMessage(completion) + } + ]; + + return { generations }; + } +} diff --git a/src/default-providers/index.ts b/src/default-providers/index.ts index b953648..3561970 100644 --- a/src/default-providers/index.ts +++ b/src/default-providers/index.ts @@ -6,7 +6,7 @@ import { Notification } from '@jupyterlab/apputils'; import { ChatAnthropic } from '@langchain/anthropic'; import { ChatWebLLM } from '@langchain/community/chat_models/webllm'; -import { ChromeAI } from '@langchain/community/experimental/llms/chrome_ai'; +import { ChromeChatModel } from '../default-providers/ChromeAI/chat'; import { ChatGoogleGenerativeAI } from '@langchain/google-genai'; import { ChatMistralAI } from '@langchain/mistralai'; import { ChatOllama } from '@langchain/ollama'; @@ -56,9 +56,7 @@ const AIProviders: IAIProvider[] = [ }, { name: 'ChromeAI', - // TODO: fix - // @ts-expect-error: missing properties - chat: ChromeAI, + chat: ChromeChatModel, completer: ChromeCompleter, instructions: ChromeAIInstructions, settingsSchema: ChromeAISettings,