langchain[patch]: init_chat_model provider in model string (#28367)

```python
llm = init_chat_model("openai:gpt-4o")
```
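Per the accompanying unit test, the prefixed form is equivalent to passing `model_provider` explicitly. A minimal sketch of the two equivalent calls (the `api_key` value is a placeholder):

```python
from langchain.chat_models import init_chat_model

# Both calls construct the same chat model: the "openai:" prefix is split
# off the model string and used as the model_provider.
llm_prefixed = init_chat_model("openai:gpt-4o", api_key="sk-...")
llm_explicit = init_chat_model("gpt-4o", model_provider="openai", api_key="sk-...")

assert llm_prefixed.dict() == llm_explicit.dict()
```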
Author: Bagatur (committed by GitHub)
Date: 2024-11-27 00:20:25 -08:00
Commit: ffe7bd4832
Parent: 8adc4a5bcc
2 changed files with 26 additions and 10 deletions

```diff
@@ -328,13 +328,7 @@ def init_chat_model(
 def _init_chat_model_helper(
     model: str, *, model_provider: Optional[str] = None, **kwargs: Any
 ) -> BaseChatModel:
-    model_provider = model_provider or _attempt_infer_model_provider(model)
-    if not model_provider:
-        raise ValueError(
-            f"Unable to infer model provider for {model=}, please specify "
-            f"model_provider directly."
-        )
-    model_provider = model_provider.replace("-", "_").lower()
+    model, model_provider = _parse_model(model, model_provider)
     if model_provider == "openai":
         _check_pkg("langchain_openai")
         from langchain_openai import ChatOpenAI
@@ -461,6 +455,24 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
     return None
 
 
+def _parse_model(model: str, model_provider: Optional[str]) -> Tuple[str, str]:
+    if (
+        not model_provider
+        and ":" in model
+        and model.split(":")[0] in _SUPPORTED_PROVIDERS
+    ):
+        model_provider = model.split(":")[0]
+        model = ":".join(model.split(":")[1:])
+    model_provider = model_provider or _attempt_infer_model_provider(model)
+    if not model_provider:
+        raise ValueError(
+            f"Unable to infer model provider for {model=}, please specify "
+            f"model_provider directly."
+        )
+    model_provider = model_provider.replace("-", "_").lower()
+    return model, model_provider
+
+
 def _check_pkg(pkg: str) -> None:
     if not util.find_spec(pkg):
         pkg_kebab = pkg.replace("_", "-")
```
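A standalone sketch of the parsing behavior the new `_parse_model` helper implements (the provider set below is an illustrative subset, not the real `_SUPPORTED_PROVIDERS`): only the first colon-delimited segment is consumed, and only if it names a supported provider, so model strings whose names contain colons but no supported prefix pass through unchanged.

```python
from typing import Optional, Tuple

# Illustrative subset; the actual supported-provider list is larger.
SUPPORTED = {"openai", "anthropic", "fireworks", "groq"}


def parse_model(model: str, model_provider: Optional[str]) -> Tuple[str, Optional[str]]:
    """Split an optional 'provider:' prefix off the model string."""
    if not model_provider and ":" in model and model.split(":")[0] in SUPPORTED:
        model_provider = model.split(":")[0]
        model = ":".join(model.split(":")[1:])
    return model, model_provider


print(parse_model("openai:gpt-4o", None))     # ('gpt-4o', 'openai')
print(parse_model("gpt-4o", "openai"))        # ('gpt-4o', 'openai')
print(parse_model("accounts/foo:bar", None))  # ('accounts/foo:bar', None) -- no supported prefix
```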

```diff
@@ -1,4 +1,5 @@
 import os
+from typing import Optional
 from unittest import mock
 
 import pytest
@@ -26,7 +27,6 @@ def test_all_imports() -> None:
     "langchain_openai",
     "langchain_anthropic",
     "langchain_fireworks",
-    "langchain_mistralai",
     "langchain_groq",
 )
 @pytest.mark.parametrize(
@@ -38,10 +38,14 @@ def test_all_imports() -> None:
         ("mixtral-8x7b-32768", "groq"),
     ],
 )
-def test_init_chat_model(model_name: str, model_provider: str) -> None:
-    _: BaseChatModel = init_chat_model(
+def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
+    llm1: BaseChatModel = init_chat_model(
         model_name, model_provider=model_provider, api_key="foo"
     )
+    llm2: BaseChatModel = init_chat_model(
+        f"{model_provider}:{model_name}", api_key="foo"
+    )
+    assert llm1.dict() == llm2.dict()
 
 
 def test_init_missing_dep() -> None:
```
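As before this change, when no supported prefix is present, no `model_provider` is passed, and the provider cannot be inferred from the model name, a `ValueError` is raised. A hedged example of that error path (the model name below is made up):

```python
from langchain.chat_models import init_chat_model

try:
    # Matches no supported "provider:" prefix and no known model-name pattern.
    init_chat_model("totally-unknown-model")
except ValueError as e:
    print(e)  # asks you to specify model_provider directly
```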