-
Couldn't load subscription status.
- Fork 1.3k
Infer provider and model type #3185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| if isinstance(model, Model): | ||
| return model | ||
| elif model == 'test': | ||
| def infer_provider_model_class(model: KnownModelName | str) -> tuple[type[Model], type[Provider]]: # noqa: C901 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tuple value order should match the method name. I think Provider, Model makes the most sense because it matches the hierarchy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note also that we need tests for this method.
| if provider == 'gateway': | ||
| from ..providers.gateway import infer_model as infer_model_from_gateway | ||
|
|
||
| return infer_model_from_gateway(model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need these special lines for 'gateway', if infer_provider_model_class already checks it as well?
Edit: It's possible that we do need it because it sets a bunch of values on the provider.
|
|
||
| def infer_model(model_name: str) -> Model: | ||
| """Infer the model class that will be used to make requests to the gateway. | ||
| def infer_provider_model_class(model_name: str) -> tuple[type[Model], type[Provider[Any]]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't work as expected, as the provider needs arguments to work. If/since we can't make infer_provider_model_class work properly for gateway, I think we should just not support it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so just raise error for gateway?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@slkoo-cc Yeah let's do that, because in that case the class alone is not useful
| from .openai import OpenAIResponsesModel | ||
|
|
||
| return OpenAIResponsesModel(model_name, provider='openai') | ||
| return OpenAIResponsesModel, infer_provider_class('openai') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can use provider here as well, right, if we add support for openai-responses in infer_provider_class?
| from .huggingface import HuggingFaceModel | ||
|
|
||
| return HuggingFaceModel(model_name, provider=provider) | ||
| return HuggingFaceModel, infer_provider_class(provider) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the second tuple argument is always infer_provider_class(provider), can we not repeat it but just have it once?
|
Does it need to be a tuple? What if at some point we make this infer the API as well? Can we make this a dataclass/typeddict? |
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
|
Big updated from @DouweM included in pr. |
|
@slkoo-cc Thanks, can you have a look at the failing linting please? |
|
The lint type error doesn't seem to be fixable.... |
Fix #3163