refactor(langchain): refactor optional imports logic (#32813)

* Use `importlib` to load dynamically the classes
* Removes missing package warnings

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet
2025-12-27 08:02:32 +01:00
committed by GitHub
parent d46187201d
commit 85f1ba2351
5 changed files with 221 additions and 229 deletions

View File

@@ -2,9 +2,17 @@
from __future__ import annotations
import functools
import importlib
import warnings
from importlib import util
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypeAlias,
cast,
overload,
)
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import AIMessage, AnyMessage
@@ -14,6 +22,7 @@ from typing_extensions import override
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from types import ModuleType
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
@@ -21,6 +30,127 @@ if TYPE_CHECKING:
from pydantic import BaseModel
def _call(cls: type[BaseChatModel], **kwargs: Any) -> BaseChatModel:
# TODO: replace with operator.call when lower bounding to Python 3.11
return cls(**kwargs)
_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., BaseChatModel]]] = {
"anthropic": ("langchain_anthropic", "ChatAnthropic", _call),
"azure_ai": ("langchain_azure_ai.chat_models", "AzureAIChatCompletionsModel", _call),
"azure_openai": ("langchain_openai", "AzureChatOpenAI", _call),
"bedrock": ("langchain_aws", "ChatBedrock", _call),
"bedrock_converse": ("langchain_aws", "ChatBedrockConverse", _call),
"cohere": ("langchain_cohere", "ChatCohere", _call),
"deepseek": ("langchain_deepseek", "ChatDeepSeek", _call),
"fireworks": ("langchain_fireworks", "ChatFireworks", _call),
"google_anthropic_vertex": (
"langchain_google_vertexai.model_garden",
"ChatAnthropicVertex",
_call,
),
"google_genai": ("langchain_google_genai", "ChatGoogleGenerativeAI", _call),
"google_vertexai": ("langchain_google_vertexai", "ChatVertexAI", _call),
"groq": ("langchain_groq", "ChatGroq", _call),
"huggingface": (
"langchain_huggingface",
"ChatHuggingFace",
lambda cls, model, **kwargs: cls.from_model_id(model_id=model, **kwargs),
),
"ibm": (
"langchain_ibm",
"ChatWatsonx",
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
),
"mistralai": ("langchain_mistralai", "ChatMistralAI", _call),
"nvidia": ("langchain_nvidia_ai_endpoints", "ChatNVIDIA", _call),
"ollama": ("langchain_ollama", "ChatOllama", _call),
"openai": ("langchain_openai", "ChatOpenAI", _call),
"perplexity": ("langchain_perplexity", "ChatPerplexity", _call),
"together": ("langchain_together", "ChatTogether", _call),
"upstage": ("langchain_upstage", "ChatUpstage", _call),
"xai": ("langchain_xai", "ChatXAI", _call),
}
"""Registry mapping provider names to their import configuration.
Each entry maps a provider key to a tuple of:
- `module_path`: The Python module path containing the chat model class.
This may be a submodule (e.g., `'langchain_azure_ai.chat_models'`) if the class is
not exported from the package root.
- `class_name`: The name of the chat model class to import.
- `creator_func`: A callable that instantiates the class with provided kwargs.
"""
def _import_module(module: str) -> ModuleType:
"""Import a module by name.
Args:
module: The fully qualified module name to import (e.g., `'langchain_openai'`).
Returns:
The imported module.
Raises:
ImportError: If the module cannot be imported, with a message suggesting
the pip package to install.
"""
try:
return importlib.import_module(module)
except ImportError as e:
# Extract package name from module path (e.g., "langchain_azure_ai.chat_models"
# becomes "langchain-azure-ai")
pkg = module.split(".", maxsplit=1)[0].replace("_", "-")
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
raise ImportError(msg) from e
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _get_chat_model_creator(
    provider: str,
) -> Callable[..., BaseChatModel]:
    """Resolve *provider* to a cached factory that builds its chat model.

    Results are memoized (one cache slot per supported provider) so each
    provider's integration module is imported at most once.

    Args:
        provider: Key into `_SUPPORTED_PROVIDERS` (e.g. `'openai'`,
            `'anthropic'`).

    Returns:
        A callable that accepts model kwargs and returns a `BaseChatModel`
        instance for the specified provider.

    Raises:
        ValueError: If `provider` is not a key of `_SUPPORTED_PROVIDERS`.
        ImportError: If the provider's integration package is not installed.
    """
    if provider not in _SUPPORTED_PROVIDERS:
        supported = ", ".join(_SUPPORTED_PROVIDERS.keys())
        msg = f"Unsupported {provider=}.\n\nSupported model providers are: {supported}"
        raise ValueError(msg)

    module_path, class_name, factory = _SUPPORTED_PROVIDERS[provider]
    if provider == "ollama":
        try:
            module = _import_module(module_path)
        except ImportError as primary_exc:
            # For backwards compatibility, fall back to langchain-community.
            try:
                module = _import_module("langchain_community.chat_models")
            except ImportError:
                # Neither langchain-ollama nor langchain-community is
                # available: surface the error about langchain-ollama.
                raise primary_exc from None
    else:
        module = _import_module(module_path)
    return functools.partial(factory, cls=getattr(module, class_name))
@overload
def init_chat_model(
model: str,
@@ -334,154 +464,8 @@ def _init_chat_model_helper(
**kwargs: Any,
) -> BaseChatModel:
model, model_provider = _parse_model(model, model_provider)
if model_provider == "openai":
_check_pkg("langchain_openai")
from langchain_openai import ChatOpenAI
return ChatOpenAI(model=model, **kwargs)
if model_provider == "anthropic":
_check_pkg("langchain_anthropic")
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "azure_openai":
_check_pkg("langchain_openai")
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(model=model, **kwargs)
if model_provider == "azure_ai":
_check_pkg("langchain_azure_ai")
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
return AzureAIChatCompletionsModel(model=model, **kwargs)
if model_provider == "cohere":
_check_pkg("langchain_cohere")
from langchain_cohere import ChatCohere
return ChatCohere(model=model, **kwargs)
if model_provider == "google_vertexai":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai import ChatVertexAI
return ChatVertexAI(model=model, **kwargs)
if model_provider == "google_genai":
_check_pkg("langchain_google_genai")
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(model=model, **kwargs)
if model_provider == "fireworks":
_check_pkg("langchain_fireworks")
from langchain_fireworks import ChatFireworks
return ChatFireworks(model=model, **kwargs)
if model_provider == "ollama":
try:
_check_pkg("langchain_ollama")
from langchain_ollama import ChatOllama
except ImportError:
# For backwards compatibility
try:
_check_pkg("langchain_community")
from langchain_community.chat_models import ChatOllama
except ImportError:
# If both langchain-ollama and langchain-community aren't available,
# raise an error related to langchain-ollama
_check_pkg("langchain_ollama")
return ChatOllama(model=model, **kwargs)
if model_provider == "together":
_check_pkg("langchain_together")
from langchain_together import ChatTogether
return ChatTogether(model=model, **kwargs)
if model_provider == "mistralai":
_check_pkg("langchain_mistralai")
from langchain_mistralai import ChatMistralAI
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
return ChatHuggingFace.from_model_id(model_id=model, **kwargs)
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
return ChatGroq(model=model, **kwargs)
if model_provider == "bedrock":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrock
return ChatBedrock(model_id=model, **kwargs)
if model_provider == "bedrock_converse":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrockConverse
return ChatBedrockConverse(model=model, **kwargs)
if model_provider == "google_anthropic_vertex":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
return ChatAnthropicVertex(model=model, **kwargs)
if model_provider == "deepseek":
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(model=model, **kwargs)
if model_provider == "nvidia":
_check_pkg("langchain_nvidia_ai_endpoints")
from langchain_nvidia_ai_endpoints import ChatNVIDIA
return ChatNVIDIA(model=model, **kwargs)
if model_provider == "ibm":
_check_pkg("langchain_ibm")
from langchain_ibm import ChatWatsonx
return ChatWatsonx(model_id=model, **kwargs)
if model_provider == "xai":
_check_pkg("langchain_xai")
from langchain_xai import ChatXAI
return ChatXAI(model=model, **kwargs)
if model_provider == "perplexity":
_check_pkg("langchain_perplexity")
from langchain_perplexity import ChatPerplexity
return ChatPerplexity(model=model, **kwargs)
if model_provider == "upstage":
_check_pkg("langchain_upstage")
from langchain_upstage import ChatUpstage
return ChatUpstage(model=model, **kwargs)
supported = ", ".join(_SUPPORTED_PROVIDERS)
msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
raise ValueError(msg)
_SUPPORTED_PROVIDERS = {
"openai",
"anthropic",
"azure_openai",
"azure_ai",
"cohere",
"google_vertexai",
"google_genai",
"fireworks",
"ollama",
"together",
"mistralai",
"huggingface",
"groq",
"bedrock",
"bedrock_converse",
"google_anthropic_vertex",
"deepseek",
"ibm",
"xai",
"perplexity",
"upstage",
}
creator_func = _get_chat_model_creator(model_provider)
return creator_func(model=model, **kwargs)
def _attempt_infer_model_provider(model_name: str) -> str | None:
@@ -582,13 +566,6 @@ def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
return model, model_provider
def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
if not util.find_spec(pkg):
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
raise ImportError(msg)
def _remove_prefix(s: str, prefix: str) -> str:
return s.removeprefix(prefix)

View File

@@ -1,27 +1,90 @@
"""Factory functions for embeddings."""
import functools
from importlib import util
import importlib
from collections.abc import Callable
from typing import Any
from langchain_core.embeddings import Embeddings
_SUPPORTED_PROVIDERS = {
"azure_openai": "langchain_openai",
"bedrock": "langchain_aws",
"cohere": "langchain_cohere",
"google_genai": "langchain_google_genai",
"google_vertexai": "langchain_google_vertexai",
"huggingface": "langchain_huggingface",
"mistralai": "langchain_mistralai",
"ollama": "langchain_ollama",
"openai": "langchain_openai",
def _call(cls: type[Embeddings], **kwargs: Any) -> Embeddings:
return cls(**kwargs)
_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., Embeddings]]] = {
"azure_openai": ("langchain_openai", "OpenAIEmbeddings", _call),
"bedrock": (
"langchain_aws",
"BedrockEmbeddings",
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
),
"cohere": ("langchain_cohere", "CohereEmbeddings", _call),
"google_genai": ("langchain_google_genai", "GoogleGenerativeAIEmbeddings", _call),
"google_vertexai": ("langchain_google_vertexai", "VertexAIEmbeddings", _call),
"huggingface": (
"langchain_huggingface",
"HuggingFaceEmbeddings",
lambda cls, model, **kwargs: cls(model_name=model, **kwargs),
),
"mistralai": ("langchain_mistralai", "MistralAIEmbeddings", _call),
"ollama": ("langchain_ollama", "OllamaEmbeddings", _call),
"openai": ("langchain_openai", "OpenAIEmbeddings", _call),
}
"""Registry mapping provider names to their import configuration.
Each entry maps a provider key to a tuple of:
- `module_path`: The Python module path containing the embeddings class.
- `class_name`: The name of the embeddings class to import.
- `creator_func`: A callable that instantiates the class with provided kwargs.
"""
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _get_embeddings_class_creator(provider: str) -> Callable[..., Embeddings]:
    """Resolve *provider* to a cached factory that builds its embeddings model.

    Memoized (one cache slot per supported provider) so each provider's
    integration module is imported at most once.

    Args:
        provider: Key into `_SUPPORTED_PROVIDERS` (e.g. `'openai'`, `'cohere'`).

    Returns:
        A callable that accepts model kwargs and returns an `Embeddings`
        instance for the specified provider.

    Raises:
        ValueError: If `provider` is not a key of `_SUPPORTED_PROVIDERS`.
        ImportError: If the provider's integration package is not installed.
    """
    entry = _SUPPORTED_PROVIDERS.get(provider)
    if entry is None:
        msg = (
            f"Provider '{provider}' is not supported.\n"
            f"Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
        raise ValueError(msg)

    mod_path, cls_name, factory = entry
    try:
        mod = importlib.import_module(mod_path)
    except ImportError as exc:
        dist = mod_path.replace("_", "-")
        msg = (
            f"Could not import {dist} python package. "
            f"Please install it with `pip install {dist}`"
        )
        raise ImportError(msg) from exc
    return functools.partial(factory, cls=getattr(mod, cls_name))
def _get_provider_list() -> str:
    """Return a formatted bullet list of providers and their pip package names.

    Used to build error messages. Each registry value is a
    ``(module_path, class_name, creator)`` tuple, so the package name is the
    first element, rendered in kebab case.
    """
    # FIX: the block contained two return statements — a stale pre-refactor
    # one-liner (which called `.replace` on the tuple value and would raise
    # AttributeError) followed by an unreachable tuple-aware version. Keep
    # only the tuple-aware implementation.
    return "\n".join(
        f"  - {p}: {pkg[0].replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
    )
def _parse_model_string(model_name: str) -> tuple[str, str]:
@@ -50,7 +113,6 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
"""
if ":" not in model_name:
providers = _SUPPORTED_PROVIDERS
msg = (
f"Invalid model format '{model_name}'.\n"
f"Model name must be in format 'provider:model-name'\n"
@@ -58,7 +120,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
f" - openai:text-embedding-3-small\n"
f" - bedrock:amazon.titan-embed-text-v1\n"
f" - cohere:embed-english-v3.0\n"
f"Supported providers: {providers}"
f"Supported providers: {_SUPPORTED_PROVIDERS.keys()}"
)
raise ValueError(msg)
@@ -93,13 +155,12 @@ def _infer_model_and_provider(
model_name = model
if not provider:
providers = _SUPPORTED_PROVIDERS
msg = (
"Must specify either:\n"
"1. A model string in format 'provider:model-name'\n"
" Example: 'openai:text-embedding-3-small'\n"
"2. Or explicitly set provider from: "
f"{providers}"
f"{_SUPPORTED_PROVIDERS.keys()}"
)
raise ValueError(msg)
@@ -113,14 +174,6 @@ def _infer_model_and_provider(
return provider, model_name
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _check_pkg(pkg: str) -> None:
    """Raise an ImportError with a pip install hint if *pkg* is not installed.

    Cached so repeated checks for the same package do one spec lookup.
    """
    if util.find_spec(pkg) is not None:
        return
    msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
    raise ImportError(msg)
def init_embeddings(
model: str,
*,
@@ -197,51 +250,7 @@ def init_embeddings(
raise ValueError(msg)
provider, model_name = _infer_model_and_provider(model, provider=provider)
pkg = _SUPPORTED_PROVIDERS[provider]
_check_pkg(pkg)
if provider == "openai":
from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings(model=model_name, **kwargs)
if provider == "azure_openai":
from langchain_openai import AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
if provider == "google_genai":
from langchain_google_genai import GoogleGenerativeAIEmbeddings
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings
return VertexAIEmbeddings(model=model_name, **kwargs)
if provider == "bedrock":
from langchain_aws import BedrockEmbeddings
return BedrockEmbeddings(model_id=model_name, **kwargs)
if provider == "cohere":
from langchain_cohere import CohereEmbeddings
return CohereEmbeddings(model=model_name, **kwargs)
if provider == "mistralai":
from langchain_mistralai import MistralAIEmbeddings
return MistralAIEmbeddings(model=model_name, **kwargs)
if provider == "huggingface":
from langchain_huggingface import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
if provider == "ollama":
from langchain_ollama import OllamaEmbeddings
return OllamaEmbeddings(model=model_name, **kwargs)
msg = (
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)
raise ValueError(msg)
return _get_embeddings_class_creator(provider)(model=model_name, **kwargs)
__all__ = [

View File

@@ -18,7 +18,7 @@ from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings
],
)
async def test_init_embedding_model(provider: str, model: str) -> None:
package = _SUPPORTED_PROVIDERS[provider]
package = _SUPPORTED_PROVIDERS[provider][0]
try:
importlib.import_module(package)
except ImportError:

View File

@@ -8,6 +8,7 @@ from langchain_core.runnables import RunnableConfig, RunnableSequence
from pydantic import SecretStr
from langchain.chat_models import __all__, init_chat_model
from langchain.chat_models.base import _SUPPORTED_PROVIDERS
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
@@ -57,10 +58,15 @@ def test_init_missing_dep() -> None:
def test_init_unknown_provider() -> None:
    """An unknown provider raises ValueError naming the bad provider key."""
    # FIX: the block retained both the old and the new `pytest.raises`
    # expectations nested inside each other (diff residue); the old pattern
    # "Unsupported model_provider='bar'" no longer matches the refactored
    # error message. Keep only the current expectation.
    with pytest.raises(ValueError, match="Unsupported provider='bar'"):
        init_chat_model("foo", model_provider="bar")
def test_supported_providers_is_sorted() -> None:
    """Ensure the provider registry keys stay alphabetically ordered."""
    providers = list(_SUPPORTED_PROVIDERS)
    assert providers == sorted(providers)
@pytest.mark.requires("langchain_openai")
@mock.patch.dict(
os.environ,

View File

@@ -102,7 +102,7 @@ def test_infer_model_and_provider_errors() -> None:
)
def test_supported_providers_package_names(provider: str) -> None:
    """Test that all supported providers have valid package names."""
    # FIX: a stale pre-refactor assignment (`_SUPPORTED_PROVIDERS[provider]`
    # without the tuple index) was retained above the current one — a dead
    # statement immediately overwritten. Registry values are now
    # (module_path, class_name, creator) tuples; the package is element 0.
    package = _SUPPORTED_PROVIDERS[provider][0]
    assert "-" not in package
    assert package.startswith("langchain_")
    assert package.islower()