Draft implementation of support for embeddings APIs #3252
Conversation
```python
from pydantic_ai.models.instrumented import InstrumentationSettings
from pydantic_ai.providers import infer_provider

KnownEmbeddingModelName = TypeAliasType(
```
Add a test like this one to verify this is up to date:
```python
def test_known_model_names():  # pragma: lax no cover
```
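A self-contained sketch of such a test, using `typing.get_args` on a stand-in alias; the two model names here are illustrative assumptions, not the real contents of `KnownEmbeddingModelName`:

```python
from typing import Literal, get_args

# Stand-in for the alias under review; the names are illustrative assumptions.
KnownEmbeddingModelName = Literal[
    'openai:text-embedding-3-small',
    'cohere:embed-english-v3.0',
]

def test_known_model_names():
    # Every known name should carry a 'provider:' prefix so it can be routed.
    names = get_args(KnownEmbeddingModelName)
    assert names, 'alias should not be empty'
    assert all(':' in name for name in names)

test_known_model_names()
```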
```python
if model_kind.startswith('gateway/'):
    model_kind = provider_name.removeprefix('gateway/')

# TODO: extend the following list for other providers as appropriate
```
We'll have to check which of the OpenAI-compatible APIs also support embeddings
```python
    return CohereEmbeddingModel(model_name, provider=provider)
else:
    raise UserError(f'Unknown embeddings model: {model}')  # pragma: no cover
```
https://github.com/ggozad/haiku.rag/tree/main/src/haiku/rag/embeddings has Ollama, vLLM and VoyageAI, which would be worth adding as well
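One shape the extended dispatch could take, sketched with plain strings so the routing logic is testable in isolation; the set of provider keys (including the Ollama/vLLM/VoyageAI ones suggested above) and the exception type are assumptions:

```python
KNOWN_EMBEDDING_PROVIDERS = {'openai', 'cohere', 'ollama', 'vllm', 'voyageai'}

def infer_embedding_provider(model: str) -> tuple[str, str]:
    """Split a 'provider:model-name' string and validate the provider key."""
    provider, sep, model_name = model.partition(':')
    if not sep or provider not in KNOWN_EMBEDDING_PROVIDERS:
        raise ValueError(f'Unknown embeddings model: {model!r}')
    return provider, model_name
```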
```python
    raise UserError(f'Unknown embeddings model: {model}')  # pragma: no cover


@dataclass
```
Suggested change:

```diff
-@dataclass
+@dataclass(init=False)
```
```python
    Args:
        model_name: The name of the Cohere model to use. List of model names
            available [here](https://docs.cohere.com/docs/models#command).
```
I'd prefer to move this to `__init__`
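For illustration, what `@dataclass(init=False)` plus an explicit `__init__` amounts to: the class keeps dataclass `repr`/`eq` behaviour while the hand-written initializer can normalize its arguments. The field names and the default provider string are assumptions:

```python
from dataclasses import dataclass

@dataclass(init=False)
class CohereEmbeddingModel:
    _model_name: str
    _provider: str

    def __init__(self, model_name: str, *, provider: str = 'cohere'):
        # An explicit __init__ can accept a provider name or instance and
        # resolve it, which the auto-generated dataclass __init__ cannot do.
        self._model_name = model_name
        self._provider = provider

m = CohereEmbeddingModel('embed-english-v3.0')
```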
```python
    Args:
        model_name: The name of the OpenAI model to use. List of model names
            available [here](https://docs.OpenAI.com/docs/models#command).
```
```python
        provider: The provider to use for authentication and API access. Can be either the string
            'OpenAI' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
            created using the other parameters.
        profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
```
Drop
```python
input_is_string = isinstance(documents, str)
if input_is_string:
    documents = [documents]
```
Not sure how I feel about every model implementation needing to repeat this
```python
    Supported by:

    * Cohere (See `cohere.EmbedInputType`)
```
Following the pattern in ModelSettings, we should move any options only supported by one model to {Cohere}EmbeddingSettings with a {cohere}_ prefix
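As a sketch of that convention applied to embeddings settings; the field names are assumptions:

```python
from typing import TypedDict

class EmbeddingSettings(TypedDict, total=False):
    # Options supported by more than one provider stay unprefixed.
    dimensions: int

class CohereEmbeddingSettings(EmbeddingSettings, total=False):
    # Cohere-only options get a cohere_ prefix, mirroring CohereModelSettings.
    cohere_input_type: str

settings: CohereEmbeddingSettings = {'dimensions': 256, 'cohere_input_type': 'search_document'}
```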
Thanks for starting this and please do let me know if you need help :) One thing you might want to support from the start is handling token limits: embedding models have a limit on how many tokens of input they can handle, and most providers will raise if you exceed it. All this is well explained here. I would not necessarily truncate like in the cookbook and would still just raise, but I would be grateful to have token counting available from the model side. The only difficulty I see with this is that not all providers expose their tokenizers; Ollama, for example, does not. But it would still be nice to have for the providers that do support it, as it's a crucial step when you are trying to chunk a document for embedding.

Edit: I am not suggesting that calling …
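A sketch of the shape this could take, with a whitespace split standing in for a real tokenizer (e.g. tiktoken for OpenAI); the class and method names and the 512-token limit are assumptions:

```python
class TokenLimitedModel:
    max_input_tokens = 512  # hypothetical per-model limit

    def count_tokens(self, text: str) -> int:
        # Stand-in: a real model would delegate to the provider's tokenizer.
        return len(text.split())

    def check_input(self, text: str) -> None:
        n = self.count_tokens(text)
        if n > self.max_input_tokens:
            # Raise rather than silently truncating the input.
            raise ValueError(f'input is {n} tokens; model limit is {self.max_input_tokens}')

model = TokenLimitedModel()
model.check_input('a short chunk')  # fits, no exception
```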
Started this in collaboration with @DouweM. I'd like to ensure consensus on the API design before adding the remaining providers, Logfire instrumentation, docs, and tests.

This is inspired by the approach in haiku.rag, though we adapted it to be a bit closer to how the Agent APIs are used (and how you can override model, settings, etc.).

Closes #58