From 85f1ba2351b7c75e5ce3a55dbed4bd061c3c1dd5 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sat, 27 Dec 2025 08:02:32 +0100 Subject: [PATCH] refactor(langchain): refactor optional imports logic (#32813) * Use `importlib` to load dynamically the classes * Removes missing package warnings --------- Co-authored-by: Mason Daugherty Co-authored-by: Mason Daugherty --- .../langchain/chat_models/base.py | 291 ++++++++---------- .../langchain_v1/langchain/embeddings/base.py | 147 ++++----- .../integration_tests/embeddings/test_base.py | 2 +- .../chat_models/test_chat_models.py | 8 +- .../tests/unit_tests/embeddings/test_base.py | 2 +- 5 files changed, 221 insertions(+), 229 deletions(-) diff --git a/libs/langchain_v1/langchain/chat_models/base.py b/libs/langchain_v1/langchain/chat_models/base.py index fb653de3b55..7f47772321e 100644 --- a/libs/langchain_v1/langchain/chat_models/base.py +++ b/libs/langchain_v1/langchain/chat_models/base.py @@ -2,9 +2,17 @@ from __future__ import annotations +import functools +import importlib import warnings -from importlib import util -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Literal, + TypeAlias, + cast, + overload, +) from langchain_core.language_models import BaseChatModel, LanguageModelInput from langchain_core.messages import AIMessage, AnyMessage @@ -14,6 +22,7 @@ from typing_extensions import override if TYPE_CHECKING: from collections.abc import AsyncIterator, Callable, Iterator, Sequence + from types import ModuleType from langchain_core.runnables.schema import StreamEvent from langchain_core.tools import BaseTool @@ -21,6 +30,127 @@ if TYPE_CHECKING: from pydantic import BaseModel +def _call(cls: type[BaseChatModel], **kwargs: Any) -> BaseChatModel: + # TODO: replace with operator.call when lower bounding to Python 3.11 + return cls(**kwargs) + + +_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., BaseChatModel]]] = { + 
"anthropic": ("langchain_anthropic", "ChatAnthropic", _call), + "azure_ai": ("langchain_azure_ai.chat_models", "AzureAIChatCompletionsModel", _call), + "azure_openai": ("langchain_openai", "AzureChatOpenAI", _call), + "bedrock": ("langchain_aws", "ChatBedrock", _call), + "bedrock_converse": ("langchain_aws", "ChatBedrockConverse", _call), + "cohere": ("langchain_cohere", "ChatCohere", _call), + "deepseek": ("langchain_deepseek", "ChatDeepSeek", _call), + "fireworks": ("langchain_fireworks", "ChatFireworks", _call), + "google_anthropic_vertex": ( + "langchain_google_vertexai.model_garden", + "ChatAnthropicVertex", + _call, + ), + "google_genai": ("langchain_google_genai", "ChatGoogleGenerativeAI", _call), + "google_vertexai": ("langchain_google_vertexai", "ChatVertexAI", _call), + "groq": ("langchain_groq", "ChatGroq", _call), + "huggingface": ( + "langchain_huggingface", + "ChatHuggingFace", + lambda cls, model, **kwargs: cls.from_model_id(model_id=model, **kwargs), + ), + "ibm": ( + "langchain_ibm", + "ChatWatsonx", + lambda cls, model, **kwargs: cls(model_id=model, **kwargs), + ), + "mistralai": ("langchain_mistralai", "ChatMistralAI", _call), + "nvidia": ("langchain_nvidia_ai_endpoints", "ChatNVIDIA", _call), + "ollama": ("langchain_ollama", "ChatOllama", _call), + "openai": ("langchain_openai", "ChatOpenAI", _call), + "perplexity": ("langchain_perplexity", "ChatPerplexity", _call), + "together": ("langchain_together", "ChatTogether", _call), + "upstage": ("langchain_upstage", "ChatUpstage", _call), + "xai": ("langchain_xai", "ChatXAI", _call), +} +"""Registry mapping provider names to their import configuration. + +Each entry maps a provider key to a tuple of: + +- `module_path`: The Python module path containing the chat model class. + + This may be a submodule (e.g., `'langchain_azure_ai.chat_models'`) if the class is + not exported from the package root. +- `class_name`: The name of the chat model class to import. 
+- `creator_func`: A callable that instantiates the class with provided kwargs. +""" + + +def _import_module(module: str) -> ModuleType: + """Import a module by name. + + Args: + module: The fully qualified module name to import (e.g., `'langchain_openai'`). + + Returns: + The imported module. + + Raises: + ImportError: If the module cannot be imported, with a message suggesting + the pip package to install. + """ + try: + return importlib.import_module(module) + except ImportError as e: + # Extract package name from module path (e.g., "langchain_azure_ai.chat_models" + # becomes "langchain-azure-ai") + pkg = module.split(".", maxsplit=1)[0].replace("_", "-") + msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`" + raise ImportError(msg) from e + + +@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS)) +def _get_chat_model_creator( + provider: str, +) -> Callable[..., BaseChatModel]: + """Return a factory function that creates a chat model for the given provider. + + This function is cached to avoid repeated module imports. + + Args: + provider: The name of the model provider (e.g., `'openai'`, `'anthropic'`). + + Must be a key in `_SUPPORTED_PROVIDERS`. + + Returns: + A callable that accepts model kwargs and returns a `BaseChatModel` instance for + the specified provider. + + Raises: + ValueError: If the provider is not in `_SUPPORTED_PROVIDERS`. + ImportError: If the provider's integration package is not installed. 
+ """ + if provider not in _SUPPORTED_PROVIDERS: + supported = ", ".join(_SUPPORTED_PROVIDERS.keys()) + msg = f"Unsupported {provider=}.\n\nSupported model providers are: {supported}" + raise ValueError(msg) + + pkg, class_name, creator_func = _SUPPORTED_PROVIDERS[provider] + try: + module = _import_module(pkg) + except ImportError as e: + if provider != "ollama": + raise + # For backwards compatibility + try: + module = _import_module("langchain_community.chat_models") + except ImportError: + # If both langchain-ollama and langchain-community aren't available, + # raise an error related to langchain-ollama + raise e from None + + cls = getattr(module, class_name) + return functools.partial(creator_func, cls=cls) + + @overload def init_chat_model( model: str, @@ -334,154 +464,8 @@ def _init_chat_model_helper( **kwargs: Any, ) -> BaseChatModel: model, model_provider = _parse_model(model, model_provider) - if model_provider == "openai": - _check_pkg("langchain_openai") - from langchain_openai import ChatOpenAI - - return ChatOpenAI(model=model, **kwargs) - if model_provider == "anthropic": - _check_pkg("langchain_anthropic") - from langchain_anthropic import ChatAnthropic - - return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore] - if model_provider == "azure_openai": - _check_pkg("langchain_openai") - from langchain_openai import AzureChatOpenAI - - return AzureChatOpenAI(model=model, **kwargs) - if model_provider == "azure_ai": - _check_pkg("langchain_azure_ai") - from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel - - return AzureAIChatCompletionsModel(model=model, **kwargs) - if model_provider == "cohere": - _check_pkg("langchain_cohere") - from langchain_cohere import ChatCohere - - return ChatCohere(model=model, **kwargs) - if model_provider == "google_vertexai": - _check_pkg("langchain_google_vertexai") - from langchain_google_vertexai import ChatVertexAI - - return ChatVertexAI(model=model, **kwargs) - if 
model_provider == "google_genai": - _check_pkg("langchain_google_genai") - from langchain_google_genai import ChatGoogleGenerativeAI - - return ChatGoogleGenerativeAI(model=model, **kwargs) - if model_provider == "fireworks": - _check_pkg("langchain_fireworks") - from langchain_fireworks import ChatFireworks - - return ChatFireworks(model=model, **kwargs) - if model_provider == "ollama": - try: - _check_pkg("langchain_ollama") - from langchain_ollama import ChatOllama - except ImportError: - # For backwards compatibility - try: - _check_pkg("langchain_community") - from langchain_community.chat_models import ChatOllama - except ImportError: - # If both langchain-ollama and langchain-community aren't available, - # raise an error related to langchain-ollama - _check_pkg("langchain_ollama") - - return ChatOllama(model=model, **kwargs) - if model_provider == "together": - _check_pkg("langchain_together") - from langchain_together import ChatTogether - - return ChatTogether(model=model, **kwargs) - if model_provider == "mistralai": - _check_pkg("langchain_mistralai") - from langchain_mistralai import ChatMistralAI - - return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore] - if model_provider == "huggingface": - _check_pkg("langchain_huggingface") - from langchain_huggingface import ChatHuggingFace - - return ChatHuggingFace.from_model_id(model_id=model, **kwargs) - if model_provider == "groq": - _check_pkg("langchain_groq") - from langchain_groq import ChatGroq - - return ChatGroq(model=model, **kwargs) - if model_provider == "bedrock": - _check_pkg("langchain_aws") - from langchain_aws import ChatBedrock - - return ChatBedrock(model_id=model, **kwargs) - if model_provider == "bedrock_converse": - _check_pkg("langchain_aws") - from langchain_aws import ChatBedrockConverse - - return ChatBedrockConverse(model=model, **kwargs) - if model_provider == "google_anthropic_vertex": - _check_pkg("langchain_google_vertexai") - from 
langchain_google_vertexai.model_garden import ChatAnthropicVertex - - return ChatAnthropicVertex(model=model, **kwargs) - if model_provider == "deepseek": - _check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek") - from langchain_deepseek import ChatDeepSeek - - return ChatDeepSeek(model=model, **kwargs) - if model_provider == "nvidia": - _check_pkg("langchain_nvidia_ai_endpoints") - from langchain_nvidia_ai_endpoints import ChatNVIDIA - - return ChatNVIDIA(model=model, **kwargs) - if model_provider == "ibm": - _check_pkg("langchain_ibm") - from langchain_ibm import ChatWatsonx - - return ChatWatsonx(model_id=model, **kwargs) - if model_provider == "xai": - _check_pkg("langchain_xai") - from langchain_xai import ChatXAI - - return ChatXAI(model=model, **kwargs) - if model_provider == "perplexity": - _check_pkg("langchain_perplexity") - from langchain_perplexity import ChatPerplexity - - return ChatPerplexity(model=model, **kwargs) - if model_provider == "upstage": - _check_pkg("langchain_upstage") - from langchain_upstage import ChatUpstage - - return ChatUpstage(model=model, **kwargs) - supported = ", ".join(_SUPPORTED_PROVIDERS) - msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}" - raise ValueError(msg) - - -_SUPPORTED_PROVIDERS = { - "openai", - "anthropic", - "azure_openai", - "azure_ai", - "cohere", - "google_vertexai", - "google_genai", - "fireworks", - "ollama", - "together", - "mistralai", - "huggingface", - "groq", - "bedrock", - "bedrock_converse", - "google_anthropic_vertex", - "deepseek", - "ibm", - "xai", - "perplexity", - "upstage", -} + creator_func = _get_chat_model_creator(model_provider) + return creator_func(model=model, **kwargs) def _attempt_infer_model_provider(model_name: str) -> str | None: @@ -582,13 +566,6 @@ def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]: return model, model_provider -def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None: - if not 
util.find_spec(pkg): - pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-") - msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`" - raise ImportError(msg) - - def _remove_prefix(s: str, prefix: str) -> str: return s.removeprefix(prefix) diff --git a/libs/langchain_v1/langchain/embeddings/base.py b/libs/langchain_v1/langchain/embeddings/base.py index 87a7c9f7ad0..bb012f2fabb 100644 --- a/libs/langchain_v1/langchain/embeddings/base.py +++ b/libs/langchain_v1/langchain/embeddings/base.py @@ -1,27 +1,90 @@ """Factory functions for embeddings.""" import functools -from importlib import util +import importlib +from collections.abc import Callable from typing import Any from langchain_core.embeddings import Embeddings -_SUPPORTED_PROVIDERS = { - "azure_openai": "langchain_openai", - "bedrock": "langchain_aws", - "cohere": "langchain_cohere", - "google_genai": "langchain_google_genai", - "google_vertexai": "langchain_google_vertexai", - "huggingface": "langchain_huggingface", - "mistralai": "langchain_mistralai", - "ollama": "langchain_ollama", - "openai": "langchain_openai", + +def _call(cls: type[Embeddings], **kwargs: Any) -> Embeddings: + return cls(**kwargs) + + +_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., Embeddings]]] = { + "azure_openai": ("langchain_openai", "AzureOpenAIEmbeddings", _call), + "bedrock": ( + "langchain_aws", + "BedrockEmbeddings", + lambda cls, model, **kwargs: cls(model_id=model, **kwargs), + ), + "cohere": ("langchain_cohere", "CohereEmbeddings", _call), + "google_genai": ("langchain_google_genai", "GoogleGenerativeAIEmbeddings", _call), + "google_vertexai": ("langchain_google_vertexai", "VertexAIEmbeddings", _call), + "huggingface": ( + "langchain_huggingface", + "HuggingFaceEmbeddings", + lambda cls, model, **kwargs: cls(model_name=model, **kwargs), + ), + "mistralai": ("langchain_mistralai", "MistralAIEmbeddings", _call), + "ollama": ("langchain_ollama", "OllamaEmbeddings", 
_call), + "openai": ("langchain_openai", "OpenAIEmbeddings", _call), } +"""Registry mapping provider names to their import configuration. + +Each entry maps a provider key to a tuple of: + +- `module_path`: The Python module path containing the embeddings class. +- `class_name`: The name of the embeddings class to import. +- `creator_func`: A callable that instantiates the class with provided kwargs. +""" + + +@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS)) +def _get_embeddings_class_creator(provider: str) -> Callable[..., Embeddings]: + """Return a factory function that creates an embeddings model for the given provider. + + This function is cached to avoid repeated module imports. + + Args: + provider: The name of the model provider (e.g., `'openai'`, `'cohere'`). + + Must be a key in `_SUPPORTED_PROVIDERS`. + + Returns: + A callable that accepts model kwargs and returns an `Embeddings` instance for + the specified provider. + + Raises: + ValueError: If the provider is not in `_SUPPORTED_PROVIDERS`. + ImportError: If the provider's integration package is not installed. + """ + if provider not in _SUPPORTED_PROVIDERS: + msg = ( + f"Provider '{provider}' is not supported.\n" + f"Supported providers and their required packages:\n" + f"{_get_provider_list()}" + ) + raise ValueError(msg) + + module_name, class_name, creator_func = _SUPPORTED_PROVIDERS[provider] + try: + module = importlib.import_module(module_name) + except ImportError as e: + pkg = module_name.replace("_", "-") + msg = f"Could not import {pkg} python package. 
Please install it with `pip install {pkg}`" + raise ImportError(msg) from e + + cls = getattr(module, class_name) + return functools.partial(creator_func, cls=cls) def _get_provider_list() -> str: """Get formatted list of providers and their packages.""" - return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()) + return "\n".join( + f" - {p}: {pkg[0].replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items() + ) def _parse_model_string(model_name: str) -> tuple[str, str]: @@ -50,7 +113,6 @@ def _parse_model_string(model_name: str) -> tuple[str, str]: """ if ":" not in model_name: - providers = _SUPPORTED_PROVIDERS msg = ( f"Invalid model format '{model_name}'.\n" f"Model name must be in format 'provider:model-name'\n" @@ -58,7 +120,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]: f" - openai:text-embedding-3-small\n" f" - bedrock:amazon.titan-embed-text-v1\n" f" - cohere:embed-english-v3.0\n" - f"Supported providers: {providers}" + f"Supported providers: {_SUPPORTED_PROVIDERS.keys()}" ) raise ValueError(msg) @@ -93,13 +155,12 @@ def _infer_model_and_provider( model_name = model if not provider: - providers = _SUPPORTED_PROVIDERS msg = ( "Must specify either:\n" "1. A model string in format 'provider:model-name'\n" " Example: 'openai:text-embedding-3-small'\n" "2. Or explicitly set provider from: " - f"{providers}" + f"{_SUPPORTED_PROVIDERS.keys()}" ) raise ValueError(msg) @@ -113,14 +174,6 @@ def _infer_model_and_provider( return provider, model_name -@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS)) -def _check_pkg(pkg: str) -> None: - """Check if a package is installed.""" - if not util.find_spec(pkg): - msg = f"Could not import {pkg} python package. 
Please install it with `pip install {pkg}`" - raise ImportError(msg) - - def init_embeddings( model: str, *, @@ -197,51 +250,7 @@ def init_embeddings( raise ValueError(msg) provider, model_name = _infer_model_and_provider(model, provider=provider) - pkg = _SUPPORTED_PROVIDERS[provider] - _check_pkg(pkg) - - if provider == "openai": - from langchain_openai import OpenAIEmbeddings - - return OpenAIEmbeddings(model=model_name, **kwargs) - if provider == "azure_openai": - from langchain_openai import AzureOpenAIEmbeddings - - return AzureOpenAIEmbeddings(model=model_name, **kwargs) - if provider == "google_genai": - from langchain_google_genai import GoogleGenerativeAIEmbeddings - - return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs) - if provider == "google_vertexai": - from langchain_google_vertexai import VertexAIEmbeddings - - return VertexAIEmbeddings(model=model_name, **kwargs) - if provider == "bedrock": - from langchain_aws import BedrockEmbeddings - - return BedrockEmbeddings(model_id=model_name, **kwargs) - if provider == "cohere": - from langchain_cohere import CohereEmbeddings - - return CohereEmbeddings(model=model_name, **kwargs) - if provider == "mistralai": - from langchain_mistralai import MistralAIEmbeddings - - return MistralAIEmbeddings(model=model_name, **kwargs) - if provider == "huggingface": - from langchain_huggingface import HuggingFaceEmbeddings - - return HuggingFaceEmbeddings(model_name=model_name, **kwargs) - if provider == "ollama": - from langchain_ollama import OllamaEmbeddings - - return OllamaEmbeddings(model=model_name, **kwargs) - msg = ( - f"Provider '{provider}' is not supported.\n" - f"Supported providers and their required packages:\n" - f"{_get_provider_list()}" - ) - raise ValueError(msg) + return _get_embeddings_class_creator(provider)(model=model_name, **kwargs) __all__ = [ diff --git a/libs/langchain_v1/tests/integration_tests/embeddings/test_base.py 
b/libs/langchain_v1/tests/integration_tests/embeddings/test_base.py index 4af0b24aa2c..d5f17af8b6c 100644 --- a/libs/langchain_v1/tests/integration_tests/embeddings/test_base.py +++ b/libs/langchain_v1/tests/integration_tests/embeddings/test_base.py @@ -18,7 +18,7 @@ from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings ], ) async def test_init_embedding_model(provider: str, model: str) -> None: - package = _SUPPORTED_PROVIDERS[provider] + package = _SUPPORTED_PROVIDERS[provider][0] try: importlib.import_module(package) except ImportError: diff --git a/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py b/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py index aa2ea0e96e5..a331945dd29 100644 --- a/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py +++ b/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py @@ -8,6 +8,7 @@ from langchain_core.runnables import RunnableConfig, RunnableSequence from pydantic import SecretStr from langchain.chat_models import __all__, init_chat_model +from langchain.chat_models.base import _SUPPORTED_PROVIDERS if TYPE_CHECKING: from langchain_core.language_models import BaseChatModel @@ -57,10 +58,15 @@ def test_init_missing_dep() -> None: def test_init_unknown_provider() -> None: - with pytest.raises(ValueError, match="Unsupported model_provider='bar'"): + with pytest.raises(ValueError, match="Unsupported provider='bar'"): init_chat_model("foo", model_provider="bar") +def test_supported_providers_is_sorted() -> None: + """Test that supported providers are sorted alphabetically.""" + assert list(_SUPPORTED_PROVIDERS) == sorted(_SUPPORTED_PROVIDERS.keys()) + + @pytest.mark.requires("langchain_openai") @mock.patch.dict( os.environ, diff --git a/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py b/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py index aa726d98f60..57efc4b6e63 100644 --- 
a/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py +++ b/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py @@ -102,7 +102,7 @@ def test_infer_model_and_provider_errors() -> None: ) def test_supported_providers_package_names(provider: str) -> None: """Test that all supported providers have valid package names.""" - package = _SUPPORTED_PROVIDERS[provider] + package = _SUPPORTED_PROVIDERS[provider][0] assert "-" not in package assert package.startswith("langchain_") assert package.islower()