mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
langchain[patch]: init_chat_model provider in model string (#28367)
```python llm = init_chat_model("openai:gpt-4o") ```
This commit is contained in:
parent
8adc4a5bcc
commit
ffe7bd4832
@ -328,13 +328,7 @@ def init_chat_model(
|
|||||||
def _init_chat_model_helper(
|
def _init_chat_model_helper(
|
||||||
model: str, *, model_provider: Optional[str] = None, **kwargs: Any
|
model: str, *, model_provider: Optional[str] = None, **kwargs: Any
|
||||||
) -> BaseChatModel:
|
) -> BaseChatModel:
|
||||||
model_provider = model_provider or _attempt_infer_model_provider(model)
|
model, model_provider = _parse_model(model, model_provider)
|
||||||
if not model_provider:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unable to infer model provider for {model=}, please specify "
|
|
||||||
f"model_provider directly."
|
|
||||||
)
|
|
||||||
model_provider = model_provider.replace("-", "_").lower()
|
|
||||||
if model_provider == "openai":
|
if model_provider == "openai":
|
||||||
_check_pkg("langchain_openai")
|
_check_pkg("langchain_openai")
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
@ -461,6 +455,24 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_model(model: str, model_provider: Optional[str]) -> Tuple[str, str]:
|
||||||
|
if (
|
||||||
|
not model_provider
|
||||||
|
and ":" in model
|
||||||
|
and model.split(":")[0] in _SUPPORTED_PROVIDERS
|
||||||
|
):
|
||||||
|
model_provider = model.split(":")[0]
|
||||||
|
model = ":".join(model.split(":")[1:])
|
||||||
|
model_provider = model_provider or _attempt_infer_model_provider(model)
|
||||||
|
if not model_provider:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unable to infer model provider for {model=}, please specify "
|
||||||
|
f"model_provider directly."
|
||||||
|
)
|
||||||
|
model_provider = model_provider.replace("-", "_").lower()
|
||||||
|
return model, model_provider
|
||||||
|
|
||||||
|
|
||||||
def _check_pkg(pkg: str) -> None:
|
def _check_pkg(pkg: str) -> None:
|
||||||
if not util.find_spec(pkg):
|
if not util.find_spec(pkg):
|
||||||
pkg_kebab = pkg.replace("_", "-")
|
pkg_kebab = pkg.replace("_", "-")
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import Optional
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -26,7 +27,6 @@ def test_all_imports() -> None:
|
|||||||
"langchain_openai",
|
"langchain_openai",
|
||||||
"langchain_anthropic",
|
"langchain_anthropic",
|
||||||
"langchain_fireworks",
|
"langchain_fireworks",
|
||||||
"langchain_mistralai",
|
|
||||||
"langchain_groq",
|
"langchain_groq",
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -38,10 +38,14 @@ def test_all_imports() -> None:
|
|||||||
("mixtral-8x7b-32768", "groq"),
|
("mixtral-8x7b-32768", "groq"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_init_chat_model(model_name: str, model_provider: str) -> None:
|
def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
|
||||||
_: BaseChatModel = init_chat_model(
|
llm1: BaseChatModel = init_chat_model(
|
||||||
model_name, model_provider=model_provider, api_key="foo"
|
model_name, model_provider=model_provider, api_key="foo"
|
||||||
)
|
)
|
||||||
|
llm2: BaseChatModel = init_chat_model(
|
||||||
|
f"{model_provider}:{model_name}", api_key="foo"
|
||||||
|
)
|
||||||
|
assert llm1.dict() == llm2.dict()
|
||||||
|
|
||||||
|
|
||||||
def test_init_missing_dep() -> None:
|
def test_init_missing_dep() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user