refactor(langchain): refactor optional imports logic (#32813)

* Use `importlib` to load dynamically the classes
* Removes missing package warnings

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet
2025-12-27 08:02:32 +01:00
committed by GitHub
parent d46187201d
commit 85f1ba2351
5 changed files with 221 additions and 229 deletions

View File

@@ -2,9 +2,17 @@
from __future__ import annotations from __future__ import annotations
import functools
import importlib
import warnings import warnings
from importlib import util from typing import (
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload TYPE_CHECKING,
Any,
Literal,
TypeAlias,
cast,
overload,
)
from langchain_core.language_models import BaseChatModel, LanguageModelInput from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import AIMessage, AnyMessage from langchain_core.messages import AIMessage, AnyMessage
@@ -14,6 +22,7 @@ from typing_extensions import override
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator, Callable, Iterator, Sequence from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from types import ModuleType
from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@@ -21,6 +30,127 @@ if TYPE_CHECKING:
from pydantic import BaseModel from pydantic import BaseModel
def _call(cls: type[BaseChatModel], **kwargs: Any) -> BaseChatModel:
# TODO: replace with operator.call when lower bounding to Python 3.11
return cls(**kwargs)
_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., BaseChatModel]]] = {
"anthropic": ("langchain_anthropic", "ChatAnthropic", _call),
"azure_ai": ("langchain_azure_ai.chat_models", "AzureAIChatCompletionsModel", _call),
"azure_openai": ("langchain_openai", "AzureChatOpenAI", _call),
"bedrock": ("langchain_aws", "ChatBedrock", _call),
"bedrock_converse": ("langchain_aws", "ChatBedrockConverse", _call),
"cohere": ("langchain_cohere", "ChatCohere", _call),
"deepseek": ("langchain_deepseek", "ChatDeepSeek", _call),
"fireworks": ("langchain_fireworks", "ChatFireworks", _call),
"google_anthropic_vertex": (
"langchain_google_vertexai.model_garden",
"ChatAnthropicVertex",
_call,
),
"google_genai": ("langchain_google_genai", "ChatGoogleGenerativeAI", _call),
"google_vertexai": ("langchain_google_vertexai", "ChatVertexAI", _call),
"groq": ("langchain_groq", "ChatGroq", _call),
"huggingface": (
"langchain_huggingface",
"ChatHuggingFace",
lambda cls, model, **kwargs: cls.from_model_id(model_id=model, **kwargs),
),
"ibm": (
"langchain_ibm",
"ChatWatsonx",
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
),
"mistralai": ("langchain_mistralai", "ChatMistralAI", _call),
"nvidia": ("langchain_nvidia_ai_endpoints", "ChatNVIDIA", _call),
"ollama": ("langchain_ollama", "ChatOllama", _call),
"openai": ("langchain_openai", "ChatOpenAI", _call),
"perplexity": ("langchain_perplexity", "ChatPerplexity", _call),
"together": ("langchain_together", "ChatTogether", _call),
"upstage": ("langchain_upstage", "ChatUpstage", _call),
"xai": ("langchain_xai", "ChatXAI", _call),
}
"""Registry mapping provider names to their import configuration.
Each entry maps a provider key to a tuple of:
- `module_path`: The Python module path containing the chat model class.
This may be a submodule (e.g., `'langchain_azure_ai.chat_models'`) if the class is
not exported from the package root.
- `class_name`: The name of the chat model class to import.
- `creator_func`: A callable that instantiates the class with provided kwargs.
"""
def _import_module(module: str) -> ModuleType:
"""Import a module by name.
Args:
module: The fully qualified module name to import (e.g., `'langchain_openai'`).
Returns:
The imported module.
Raises:
ImportError: If the module cannot be imported, with a message suggesting
the pip package to install.
"""
try:
return importlib.import_module(module)
except ImportError as e:
# Extract package name from module path (e.g., "langchain_azure_ai.chat_models"
# becomes "langchain-azure-ai")
pkg = module.split(".", maxsplit=1)[0].replace("_", "-")
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
raise ImportError(msg) from e
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _get_chat_model_creator(provider: str) -> Callable[..., BaseChatModel]:
    """Resolve *provider* to a factory that builds its chat model.

    Cached so each provider's integration module is imported at most once.

    Args:
        provider: The name of the model provider (e.g., `'openai'`,
            `'anthropic'`). Must be a key in `_SUPPORTED_PROVIDERS`.

    Returns:
        A callable that accepts model kwargs and returns a `BaseChatModel`
        instance for the specified provider.

    Raises:
        ValueError: If the provider is not in `_SUPPORTED_PROVIDERS`.
        ImportError: If the provider's integration package is not installed.
    """
    entry = _SUPPORTED_PROVIDERS.get(provider)
    if entry is None:
        supported = ", ".join(_SUPPORTED_PROVIDERS)
        msg = f"Unsupported {provider=}.\n\nSupported model providers are: {supported}"
        raise ValueError(msg)
    pkg, class_name, creator_func = entry
    try:
        module = _import_module(pkg)
    except ImportError as import_err:
        if provider != "ollama":
            raise
        # Backwards compatibility: older installs shipped ChatOllama in
        # langchain-community rather than langchain-ollama.
        try:
            module = _import_module("langchain_community.chat_models")
        except ImportError:
            # Neither package is available; surface the langchain-ollama error.
            raise import_err from None
    model_cls = getattr(module, class_name)
    return functools.partial(creator_func, cls=model_cls)
@overload @overload
def init_chat_model( def init_chat_model(
model: str, model: str,
@@ -334,154 +464,8 @@ def _init_chat_model_helper(
**kwargs: Any, **kwargs: Any,
) -> BaseChatModel: ) -> BaseChatModel:
model, model_provider = _parse_model(model, model_provider) model, model_provider = _parse_model(model, model_provider)
if model_provider == "openai": creator_func = _get_chat_model_creator(model_provider)
_check_pkg("langchain_openai") return creator_func(model=model, **kwargs)
from langchain_openai import ChatOpenAI
return ChatOpenAI(model=model, **kwargs)
if model_provider == "anthropic":
_check_pkg("langchain_anthropic")
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "azure_openai":
_check_pkg("langchain_openai")
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(model=model, **kwargs)
if model_provider == "azure_ai":
_check_pkg("langchain_azure_ai")
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
return AzureAIChatCompletionsModel(model=model, **kwargs)
if model_provider == "cohere":
_check_pkg("langchain_cohere")
from langchain_cohere import ChatCohere
return ChatCohere(model=model, **kwargs)
if model_provider == "google_vertexai":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai import ChatVertexAI
return ChatVertexAI(model=model, **kwargs)
if model_provider == "google_genai":
_check_pkg("langchain_google_genai")
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(model=model, **kwargs)
if model_provider == "fireworks":
_check_pkg("langchain_fireworks")
from langchain_fireworks import ChatFireworks
return ChatFireworks(model=model, **kwargs)
if model_provider == "ollama":
try:
_check_pkg("langchain_ollama")
from langchain_ollama import ChatOllama
except ImportError:
# For backwards compatibility
try:
_check_pkg("langchain_community")
from langchain_community.chat_models import ChatOllama
except ImportError:
# If both langchain-ollama and langchain-community aren't available,
# raise an error related to langchain-ollama
_check_pkg("langchain_ollama")
return ChatOllama(model=model, **kwargs)
if model_provider == "together":
_check_pkg("langchain_together")
from langchain_together import ChatTogether
return ChatTogether(model=model, **kwargs)
if model_provider == "mistralai":
_check_pkg("langchain_mistralai")
from langchain_mistralai import ChatMistralAI
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
return ChatHuggingFace.from_model_id(model_id=model, **kwargs)
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
return ChatGroq(model=model, **kwargs)
if model_provider == "bedrock":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrock
return ChatBedrock(model_id=model, **kwargs)
if model_provider == "bedrock_converse":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrockConverse
return ChatBedrockConverse(model=model, **kwargs)
if model_provider == "google_anthropic_vertex":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
return ChatAnthropicVertex(model=model, **kwargs)
if model_provider == "deepseek":
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(model=model, **kwargs)
if model_provider == "nvidia":
_check_pkg("langchain_nvidia_ai_endpoints")
from langchain_nvidia_ai_endpoints import ChatNVIDIA
return ChatNVIDIA(model=model, **kwargs)
if model_provider == "ibm":
_check_pkg("langchain_ibm")
from langchain_ibm import ChatWatsonx
return ChatWatsonx(model_id=model, **kwargs)
if model_provider == "xai":
_check_pkg("langchain_xai")
from langchain_xai import ChatXAI
return ChatXAI(model=model, **kwargs)
if model_provider == "perplexity":
_check_pkg("langchain_perplexity")
from langchain_perplexity import ChatPerplexity
return ChatPerplexity(model=model, **kwargs)
if model_provider == "upstage":
_check_pkg("langchain_upstage")
from langchain_upstage import ChatUpstage
return ChatUpstage(model=model, **kwargs)
supported = ", ".join(_SUPPORTED_PROVIDERS)
msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
raise ValueError(msg)
_SUPPORTED_PROVIDERS = {
"openai",
"anthropic",
"azure_openai",
"azure_ai",
"cohere",
"google_vertexai",
"google_genai",
"fireworks",
"ollama",
"together",
"mistralai",
"huggingface",
"groq",
"bedrock",
"bedrock_converse",
"google_anthropic_vertex",
"deepseek",
"ibm",
"xai",
"perplexity",
"upstage",
}
def _attempt_infer_model_provider(model_name: str) -> str | None: def _attempt_infer_model_provider(model_name: str) -> str | None:
@@ -582,13 +566,6 @@ def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
return model, model_provider return model, model_provider
def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
if not util.find_spec(pkg):
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
raise ImportError(msg)
def _remove_prefix(s: str, prefix: str) -> str: def _remove_prefix(s: str, prefix: str) -> str:
return s.removeprefix(prefix) return s.removeprefix(prefix)

View File

@@ -1,27 +1,90 @@
"""Factory functions for embeddings.""" """Factory functions for embeddings."""
import functools import functools
from importlib import util import importlib
from collections.abc import Callable
from typing import Any from typing import Any
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
_SUPPORTED_PROVIDERS = {
"azure_openai": "langchain_openai", def _call(cls: type[Embeddings], **kwargs: Any) -> Embeddings:
"bedrock": "langchain_aws", return cls(**kwargs)
"cohere": "langchain_cohere",
"google_genai": "langchain_google_genai",
"google_vertexai": "langchain_google_vertexai", _SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., Embeddings]]] = {
"huggingface": "langchain_huggingface", "azure_openai": ("langchain_openai", "OpenAIEmbeddings", _call),
"mistralai": "langchain_mistralai", "bedrock": (
"ollama": "langchain_ollama", "langchain_aws",
"openai": "langchain_openai", "BedrockEmbeddings",
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
),
"cohere": ("langchain_cohere", "CohereEmbeddings", _call),
"google_genai": ("langchain_google_genai", "GoogleGenerativeAIEmbeddings", _call),
"google_vertexai": ("langchain_google_vertexai", "VertexAIEmbeddings", _call),
"huggingface": (
"langchain_huggingface",
"HuggingFaceEmbeddings",
lambda cls, model, **kwargs: cls(model_name=model, **kwargs),
),
"mistralai": ("langchain_mistralai", "MistralAIEmbeddings", _call),
"ollama": ("langchain_ollama", "OllamaEmbeddings", _call),
"openai": ("langchain_openai", "OpenAIEmbeddings", _call),
} }
"""Registry mapping provider names to their import configuration.
Each entry maps a provider key to a tuple of:
- `module_path`: The Python module path containing the embeddings class.
- `class_name`: The name of the embeddings class to import.
- `creator_func`: A callable that instantiates the class with provided kwargs.
"""
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _get_embeddings_class_creator(provider: str) -> Callable[..., Embeddings]:
    """Resolve *provider* to a factory that builds its embeddings model.

    Cached so each provider's integration module is imported at most once.

    Args:
        provider: The name of the model provider (e.g., `'openai'`,
            `'cohere'`). Must be a key in `_SUPPORTED_PROVIDERS`.

    Returns:
        A callable that accepts model kwargs and returns an `Embeddings`
        instance for the specified provider.

    Raises:
        ValueError: If the provider is not in `_SUPPORTED_PROVIDERS`.
        ImportError: If the provider's integration package is not installed.
    """
    entry = _SUPPORTED_PROVIDERS.get(provider)
    if entry is None:
        msg = (
            f"Provider '{provider}' is not supported.\n"
            f"Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
        raise ValueError(msg)
    module_name, class_name, creator_func = entry
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        pip_name = module_name.replace("_", "-")
        msg = (
            f"Could not import {pip_name} python package. "
            f"Please install it with `pip install {pip_name}`"
        )
        raise ImportError(msg) from e
    embeddings_cls = getattr(module, class_name)
    return functools.partial(creator_func, cls=embeddings_cls)
def _get_provider_list() -> str: def _get_provider_list() -> str:
"""Get formatted list of providers and their packages.""" """Get formatted list of providers and their packages."""
return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()) return "\n".join(
f" - {p}: {pkg[0].replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
)
def _parse_model_string(model_name: str) -> tuple[str, str]: def _parse_model_string(model_name: str) -> tuple[str, str]:
@@ -50,7 +113,6 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
""" """
if ":" not in model_name: if ":" not in model_name:
providers = _SUPPORTED_PROVIDERS
msg = ( msg = (
f"Invalid model format '{model_name}'.\n" f"Invalid model format '{model_name}'.\n"
f"Model name must be in format 'provider:model-name'\n" f"Model name must be in format 'provider:model-name'\n"
@@ -58,7 +120,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
f" - openai:text-embedding-3-small\n" f" - openai:text-embedding-3-small\n"
f" - bedrock:amazon.titan-embed-text-v1\n" f" - bedrock:amazon.titan-embed-text-v1\n"
f" - cohere:embed-english-v3.0\n" f" - cohere:embed-english-v3.0\n"
f"Supported providers: {providers}" f"Supported providers: {_SUPPORTED_PROVIDERS.keys()}"
) )
raise ValueError(msg) raise ValueError(msg)
@@ -93,13 +155,12 @@ def _infer_model_and_provider(
model_name = model model_name = model
if not provider: if not provider:
providers = _SUPPORTED_PROVIDERS
msg = ( msg = (
"Must specify either:\n" "Must specify either:\n"
"1. A model string in format 'provider:model-name'\n" "1. A model string in format 'provider:model-name'\n"
" Example: 'openai:text-embedding-3-small'\n" " Example: 'openai:text-embedding-3-small'\n"
"2. Or explicitly set provider from: " "2. Or explicitly set provider from: "
f"{providers}" f"{_SUPPORTED_PROVIDERS.keys()}"
) )
raise ValueError(msg) raise ValueError(msg)
@@ -113,14 +174,6 @@ def _infer_model_and_provider(
return provider, model_name return provider, model_name
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _check_pkg(pkg: str) -> None:
"""Check if a package is installed."""
if not util.find_spec(pkg):
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
raise ImportError(msg)
def init_embeddings( def init_embeddings(
model: str, model: str,
*, *,
@@ -197,51 +250,7 @@ def init_embeddings(
raise ValueError(msg) raise ValueError(msg)
provider, model_name = _infer_model_and_provider(model, provider=provider) provider, model_name = _infer_model_and_provider(model, provider=provider)
pkg = _SUPPORTED_PROVIDERS[provider] return _get_embeddings_class_creator(provider)(model=model_name, **kwargs)
_check_pkg(pkg)
if provider == "openai":
from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings(model=model_name, **kwargs)
if provider == "azure_openai":
from langchain_openai import AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
if provider == "google_genai":
from langchain_google_genai import GoogleGenerativeAIEmbeddings
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings
return VertexAIEmbeddings(model=model_name, **kwargs)
if provider == "bedrock":
from langchain_aws import BedrockEmbeddings
return BedrockEmbeddings(model_id=model_name, **kwargs)
if provider == "cohere":
from langchain_cohere import CohereEmbeddings
return CohereEmbeddings(model=model_name, **kwargs)
if provider == "mistralai":
from langchain_mistralai import MistralAIEmbeddings
return MistralAIEmbeddings(model=model_name, **kwargs)
if provider == "huggingface":
from langchain_huggingface import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
if provider == "ollama":
from langchain_ollama import OllamaEmbeddings
return OllamaEmbeddings(model=model_name, **kwargs)
msg = (
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)
raise ValueError(msg)
__all__ = [ __all__ = [

View File

@@ -18,7 +18,7 @@ from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings
], ],
) )
async def test_init_embedding_model(provider: str, model: str) -> None: async def test_init_embedding_model(provider: str, model: str) -> None:
package = _SUPPORTED_PROVIDERS[provider] package = _SUPPORTED_PROVIDERS[provider][0]
try: try:
importlib.import_module(package) importlib.import_module(package)
except ImportError: except ImportError:

View File

@@ -8,6 +8,7 @@ from langchain_core.runnables import RunnableConfig, RunnableSequence
from pydantic import SecretStr from pydantic import SecretStr
from langchain.chat_models import __all__, init_chat_model from langchain.chat_models import __all__, init_chat_model
from langchain.chat_models.base import _SUPPORTED_PROVIDERS
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
@@ -57,10 +58,15 @@ def test_init_missing_dep() -> None:
def test_init_unknown_provider() -> None: def test_init_unknown_provider() -> None:
with pytest.raises(ValueError, match="Unsupported model_provider='bar'"): with pytest.raises(ValueError, match="Unsupported provider='bar'"):
init_chat_model("foo", model_provider="bar") init_chat_model("foo", model_provider="bar")
def test_supported_providers_is_sorted() -> None:
"""Test that supported providers are sorted alphabetically."""
assert list(_SUPPORTED_PROVIDERS) == sorted(_SUPPORTED_PROVIDERS.keys())
@pytest.mark.requires("langchain_openai") @pytest.mark.requires("langchain_openai")
@mock.patch.dict( @mock.patch.dict(
os.environ, os.environ,

View File

@@ -102,7 +102,7 @@ def test_infer_model_and_provider_errors() -> None:
) )
def test_supported_providers_package_names(provider: str) -> None: def test_supported_providers_package_names(provider: str) -> None:
"""Test that all supported providers have valid package names.""" """Test that all supported providers have valid package names."""
package = _SUPPORTED_PROVIDERS[provider] package = _SUPPORTED_PROVIDERS[provider][0]
assert "-" not in package assert "-" not in package
assert package.startswith("langchain_") assert package.startswith("langchain_")
assert package.islower() assert package.islower()