mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 02:53:16 +00:00
refactor(langchain): refactor optional imports logic (#32813)
* Use `importlib` to load dynamically the classes * Removes missing package warnings --------- Co-authored-by: Mason Daugherty <github@mdrxy.com> Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
d46187201d
commit
85f1ba2351
@@ -2,9 +2,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import warnings
|
||||
from importlib import util
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Literal,
|
||||
TypeAlias,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from langchain_core.language_models import BaseChatModel, LanguageModelInput
|
||||
from langchain_core.messages import AIMessage, AnyMessage
|
||||
@@ -14,6 +22,7 @@ from typing_extensions import override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
||||
from types import ModuleType
|
||||
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.tools import BaseTool
|
||||
@@ -21,6 +30,127 @@ if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _call(cls: type[BaseChatModel], **kwargs: Any) -> BaseChatModel:
|
||||
# TODO: replace with operator.call when lower bounding to Python 3.11
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., BaseChatModel]]] = {
|
||||
"anthropic": ("langchain_anthropic", "ChatAnthropic", _call),
|
||||
"azure_ai": ("langchain_azure_ai.chat_models", "AzureAIChatCompletionsModel", _call),
|
||||
"azure_openai": ("langchain_openai", "AzureChatOpenAI", _call),
|
||||
"bedrock": ("langchain_aws", "ChatBedrock", _call),
|
||||
"bedrock_converse": ("langchain_aws", "ChatBedrockConverse", _call),
|
||||
"cohere": ("langchain_cohere", "ChatCohere", _call),
|
||||
"deepseek": ("langchain_deepseek", "ChatDeepSeek", _call),
|
||||
"fireworks": ("langchain_fireworks", "ChatFireworks", _call),
|
||||
"google_anthropic_vertex": (
|
||||
"langchain_google_vertexai.model_garden",
|
||||
"ChatAnthropicVertex",
|
||||
_call,
|
||||
),
|
||||
"google_genai": ("langchain_google_genai", "ChatGoogleGenerativeAI", _call),
|
||||
"google_vertexai": ("langchain_google_vertexai", "ChatVertexAI", _call),
|
||||
"groq": ("langchain_groq", "ChatGroq", _call),
|
||||
"huggingface": (
|
||||
"langchain_huggingface",
|
||||
"ChatHuggingFace",
|
||||
lambda cls, model, **kwargs: cls.from_model_id(model_id=model, **kwargs),
|
||||
),
|
||||
"ibm": (
|
||||
"langchain_ibm",
|
||||
"ChatWatsonx",
|
||||
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
|
||||
),
|
||||
"mistralai": ("langchain_mistralai", "ChatMistralAI", _call),
|
||||
"nvidia": ("langchain_nvidia_ai_endpoints", "ChatNVIDIA", _call),
|
||||
"ollama": ("langchain_ollama", "ChatOllama", _call),
|
||||
"openai": ("langchain_openai", "ChatOpenAI", _call),
|
||||
"perplexity": ("langchain_perplexity", "ChatPerplexity", _call),
|
||||
"together": ("langchain_together", "ChatTogether", _call),
|
||||
"upstage": ("langchain_upstage", "ChatUpstage", _call),
|
||||
"xai": ("langchain_xai", "ChatXAI", _call),
|
||||
}
|
||||
"""Registry mapping provider names to their import configuration.
|
||||
|
||||
Each entry maps a provider key to a tuple of:
|
||||
|
||||
- `module_path`: The Python module path containing the chat model class.
|
||||
|
||||
This may be a submodule (e.g., `'langchain_azure_ai.chat_models'`) if the class is
|
||||
not exported from the package root.
|
||||
- `class_name`: The name of the chat model class to import.
|
||||
- `creator_func`: A callable that instantiates the class with provided kwargs.
|
||||
"""
|
||||
|
||||
|
||||
def _import_module(module: str) -> ModuleType:
|
||||
"""Import a module by name.
|
||||
|
||||
Args:
|
||||
module: The fully qualified module name to import (e.g., `'langchain_openai'`).
|
||||
|
||||
Returns:
|
||||
The imported module.
|
||||
|
||||
Raises:
|
||||
ImportError: If the module cannot be imported, with a message suggesting
|
||||
the pip package to install.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(module)
|
||||
except ImportError as e:
|
||||
# Extract package name from module path (e.g., "langchain_azure_ai.chat_models"
|
||||
# becomes "langchain-azure-ai")
|
||||
pkg = module.split(".", maxsplit=1)[0].replace("_", "-")
|
||||
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
|
||||
raise ImportError(msg) from e
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
|
||||
def _get_chat_model_creator(
|
||||
provider: str,
|
||||
) -> Callable[..., BaseChatModel]:
|
||||
"""Return a factory function that creates a chat model for the given provider.
|
||||
|
||||
This function is cached to avoid repeated module imports.
|
||||
|
||||
Args:
|
||||
provider: The name of the model provider (e.g., `'openai'`, `'anthropic'`).
|
||||
|
||||
Must be a key in `_SUPPORTED_PROVIDERS`.
|
||||
|
||||
Returns:
|
||||
A callable that accepts model kwargs and returns a `BaseChatModel` instance for
|
||||
the specified provider.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not in `_SUPPORTED_PROVIDERS`.
|
||||
ImportError: If the provider's integration package is not installed.
|
||||
"""
|
||||
if provider not in _SUPPORTED_PROVIDERS:
|
||||
supported = ", ".join(_SUPPORTED_PROVIDERS.keys())
|
||||
msg = f"Unsupported {provider=}.\n\nSupported model providers are: {supported}"
|
||||
raise ValueError(msg)
|
||||
|
||||
pkg, class_name, creator_func = _SUPPORTED_PROVIDERS[provider]
|
||||
try:
|
||||
module = _import_module(pkg)
|
||||
except ImportError as e:
|
||||
if provider != "ollama":
|
||||
raise
|
||||
# For backwards compatibility
|
||||
try:
|
||||
module = _import_module("langchain_community.chat_models")
|
||||
except ImportError:
|
||||
# If both langchain-ollama and langchain-community aren't available,
|
||||
# raise an error related to langchain-ollama
|
||||
raise e from None
|
||||
|
||||
cls = getattr(module, class_name)
|
||||
return functools.partial(creator_func, cls=cls)
|
||||
|
||||
|
||||
@overload
|
||||
def init_chat_model(
|
||||
model: str,
|
||||
@@ -334,154 +464,8 @@ def _init_chat_model_helper(
|
||||
**kwargs: Any,
|
||||
) -> BaseChatModel:
|
||||
model, model_provider = _parse_model(model, model_provider)
|
||||
if model_provider == "openai":
|
||||
_check_pkg("langchain_openai")
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
return ChatOpenAI(model=model, **kwargs)
|
||||
if model_provider == "anthropic":
|
||||
_check_pkg("langchain_anthropic")
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
||||
if model_provider == "azure_openai":
|
||||
_check_pkg("langchain_openai")
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
return AzureChatOpenAI(model=model, **kwargs)
|
||||
if model_provider == "azure_ai":
|
||||
_check_pkg("langchain_azure_ai")
|
||||
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
|
||||
|
||||
return AzureAIChatCompletionsModel(model=model, **kwargs)
|
||||
if model_provider == "cohere":
|
||||
_check_pkg("langchain_cohere")
|
||||
from langchain_cohere import ChatCohere
|
||||
|
||||
return ChatCohere(model=model, **kwargs)
|
||||
if model_provider == "google_vertexai":
|
||||
_check_pkg("langchain_google_vertexai")
|
||||
from langchain_google_vertexai import ChatVertexAI
|
||||
|
||||
return ChatVertexAI(model=model, **kwargs)
|
||||
if model_provider == "google_genai":
|
||||
_check_pkg("langchain_google_genai")
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
return ChatGoogleGenerativeAI(model=model, **kwargs)
|
||||
if model_provider == "fireworks":
|
||||
_check_pkg("langchain_fireworks")
|
||||
from langchain_fireworks import ChatFireworks
|
||||
|
||||
return ChatFireworks(model=model, **kwargs)
|
||||
if model_provider == "ollama":
|
||||
try:
|
||||
_check_pkg("langchain_ollama")
|
||||
from langchain_ollama import ChatOllama
|
||||
except ImportError:
|
||||
# For backwards compatibility
|
||||
try:
|
||||
_check_pkg("langchain_community")
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
except ImportError:
|
||||
# If both langchain-ollama and langchain-community aren't available,
|
||||
# raise an error related to langchain-ollama
|
||||
_check_pkg("langchain_ollama")
|
||||
|
||||
return ChatOllama(model=model, **kwargs)
|
||||
if model_provider == "together":
|
||||
_check_pkg("langchain_together")
|
||||
from langchain_together import ChatTogether
|
||||
|
||||
return ChatTogether(model=model, **kwargs)
|
||||
if model_provider == "mistralai":
|
||||
_check_pkg("langchain_mistralai")
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
|
||||
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
||||
if model_provider == "huggingface":
|
||||
_check_pkg("langchain_huggingface")
|
||||
from langchain_huggingface import ChatHuggingFace
|
||||
|
||||
return ChatHuggingFace.from_model_id(model_id=model, **kwargs)
|
||||
if model_provider == "groq":
|
||||
_check_pkg("langchain_groq")
|
||||
from langchain_groq import ChatGroq
|
||||
|
||||
return ChatGroq(model=model, **kwargs)
|
||||
if model_provider == "bedrock":
|
||||
_check_pkg("langchain_aws")
|
||||
from langchain_aws import ChatBedrock
|
||||
|
||||
return ChatBedrock(model_id=model, **kwargs)
|
||||
if model_provider == "bedrock_converse":
|
||||
_check_pkg("langchain_aws")
|
||||
from langchain_aws import ChatBedrockConverse
|
||||
|
||||
return ChatBedrockConverse(model=model, **kwargs)
|
||||
if model_provider == "google_anthropic_vertex":
|
||||
_check_pkg("langchain_google_vertexai")
|
||||
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
||||
|
||||
return ChatAnthropicVertex(model=model, **kwargs)
|
||||
if model_provider == "deepseek":
|
||||
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
return ChatDeepSeek(model=model, **kwargs)
|
||||
if model_provider == "nvidia":
|
||||
_check_pkg("langchain_nvidia_ai_endpoints")
|
||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||
|
||||
return ChatNVIDIA(model=model, **kwargs)
|
||||
if model_provider == "ibm":
|
||||
_check_pkg("langchain_ibm")
|
||||
from langchain_ibm import ChatWatsonx
|
||||
|
||||
return ChatWatsonx(model_id=model, **kwargs)
|
||||
if model_provider == "xai":
|
||||
_check_pkg("langchain_xai")
|
||||
from langchain_xai import ChatXAI
|
||||
|
||||
return ChatXAI(model=model, **kwargs)
|
||||
if model_provider == "perplexity":
|
||||
_check_pkg("langchain_perplexity")
|
||||
from langchain_perplexity import ChatPerplexity
|
||||
|
||||
return ChatPerplexity(model=model, **kwargs)
|
||||
if model_provider == "upstage":
|
||||
_check_pkg("langchain_upstage")
|
||||
from langchain_upstage import ChatUpstage
|
||||
|
||||
return ChatUpstage(model=model, **kwargs)
|
||||
supported = ", ".join(_SUPPORTED_PROVIDERS)
|
||||
msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
_SUPPORTED_PROVIDERS = {
|
||||
"openai",
|
||||
"anthropic",
|
||||
"azure_openai",
|
||||
"azure_ai",
|
||||
"cohere",
|
||||
"google_vertexai",
|
||||
"google_genai",
|
||||
"fireworks",
|
||||
"ollama",
|
||||
"together",
|
||||
"mistralai",
|
||||
"huggingface",
|
||||
"groq",
|
||||
"bedrock",
|
||||
"bedrock_converse",
|
||||
"google_anthropic_vertex",
|
||||
"deepseek",
|
||||
"ibm",
|
||||
"xai",
|
||||
"perplexity",
|
||||
"upstage",
|
||||
}
|
||||
creator_func = _get_chat_model_creator(model_provider)
|
||||
return creator_func(model=model, **kwargs)
|
||||
|
||||
|
||||
def _attempt_infer_model_provider(model_name: str) -> str | None:
|
||||
@@ -582,13 +566,6 @@ def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
|
||||
return model, model_provider
|
||||
|
||||
|
||||
def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
|
||||
if not util.find_spec(pkg):
|
||||
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
|
||||
msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
|
||||
raise ImportError(msg)
|
||||
|
||||
|
||||
def _remove_prefix(s: str, prefix: str) -> str:
|
||||
return s.removeprefix(prefix)
|
||||
|
||||
|
||||
@@ -1,27 +1,90 @@
|
||||
"""Factory functions for embeddings."""
|
||||
|
||||
import functools
|
||||
from importlib import util
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
_SUPPORTED_PROVIDERS = {
|
||||
"azure_openai": "langchain_openai",
|
||||
"bedrock": "langchain_aws",
|
||||
"cohere": "langchain_cohere",
|
||||
"google_genai": "langchain_google_genai",
|
||||
"google_vertexai": "langchain_google_vertexai",
|
||||
"huggingface": "langchain_huggingface",
|
||||
"mistralai": "langchain_mistralai",
|
||||
"ollama": "langchain_ollama",
|
||||
"openai": "langchain_openai",
|
||||
|
||||
def _call(cls: type[Embeddings], **kwargs: Any) -> Embeddings:
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., Embeddings]]] = {
|
||||
"azure_openai": ("langchain_openai", "OpenAIEmbeddings", _call),
|
||||
"bedrock": (
|
||||
"langchain_aws",
|
||||
"BedrockEmbeddings",
|
||||
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
|
||||
),
|
||||
"cohere": ("langchain_cohere", "CohereEmbeddings", _call),
|
||||
"google_genai": ("langchain_google_genai", "GoogleGenerativeAIEmbeddings", _call),
|
||||
"google_vertexai": ("langchain_google_vertexai", "VertexAIEmbeddings", _call),
|
||||
"huggingface": (
|
||||
"langchain_huggingface",
|
||||
"HuggingFaceEmbeddings",
|
||||
lambda cls, model, **kwargs: cls(model_name=model, **kwargs),
|
||||
),
|
||||
"mistralai": ("langchain_mistralai", "MistralAIEmbeddings", _call),
|
||||
"ollama": ("langchain_ollama", "OllamaEmbeddings", _call),
|
||||
"openai": ("langchain_openai", "OpenAIEmbeddings", _call),
|
||||
}
|
||||
"""Registry mapping provider names to their import configuration.
|
||||
|
||||
Each entry maps a provider key to a tuple of:
|
||||
|
||||
- `module_path`: The Python module path containing the embeddings class.
|
||||
- `class_name`: The name of the embeddings class to import.
|
||||
- `creator_func`: A callable that instantiates the class with provided kwargs.
|
||||
"""
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
|
||||
def _get_embeddings_class_creator(provider: str) -> Callable[..., Embeddings]:
|
||||
"""Return a factory function that creates an embeddings model for the given provider.
|
||||
|
||||
This function is cached to avoid repeated module imports.
|
||||
|
||||
Args:
|
||||
provider: The name of the model provider (e.g., `'openai'`, `'cohere'`).
|
||||
|
||||
Must be a key in `_SUPPORTED_PROVIDERS`.
|
||||
|
||||
Returns:
|
||||
A callable that accepts model kwargs and returns an `Embeddings` instance for
|
||||
the specified provider.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not in `_SUPPORTED_PROVIDERS`.
|
||||
ImportError: If the provider's integration package is not installed.
|
||||
"""
|
||||
if provider not in _SUPPORTED_PROVIDERS:
|
||||
msg = (
|
||||
f"Provider '{provider}' is not supported.\n"
|
||||
f"Supported providers and their required packages:\n"
|
||||
f"{_get_provider_list()}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
module_name, class_name, creator_func = _SUPPORTED_PROVIDERS[provider]
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
pkg = module_name.replace("_", "-")
|
||||
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
|
||||
raise ImportError(msg) from e
|
||||
|
||||
cls = getattr(module, class_name)
|
||||
return functools.partial(creator_func, cls=cls)
|
||||
|
||||
|
||||
def _get_provider_list() -> str:
|
||||
"""Get formatted list of providers and their packages."""
|
||||
return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items())
|
||||
return "\n".join(
|
||||
f" - {p}: {pkg[0].replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
|
||||
)
|
||||
|
||||
|
||||
def _parse_model_string(model_name: str) -> tuple[str, str]:
|
||||
@@ -50,7 +113,6 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
|
||||
|
||||
"""
|
||||
if ":" not in model_name:
|
||||
providers = _SUPPORTED_PROVIDERS
|
||||
msg = (
|
||||
f"Invalid model format '{model_name}'.\n"
|
||||
f"Model name must be in format 'provider:model-name'\n"
|
||||
@@ -58,7 +120,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
|
||||
f" - openai:text-embedding-3-small\n"
|
||||
f" - bedrock:amazon.titan-embed-text-v1\n"
|
||||
f" - cohere:embed-english-v3.0\n"
|
||||
f"Supported providers: {providers}"
|
||||
f"Supported providers: {_SUPPORTED_PROVIDERS.keys()}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -93,13 +155,12 @@ def _infer_model_and_provider(
|
||||
model_name = model
|
||||
|
||||
if not provider:
|
||||
providers = _SUPPORTED_PROVIDERS
|
||||
msg = (
|
||||
"Must specify either:\n"
|
||||
"1. A model string in format 'provider:model-name'\n"
|
||||
" Example: 'openai:text-embedding-3-small'\n"
|
||||
"2. Or explicitly set provider from: "
|
||||
f"{providers}"
|
||||
f"{_SUPPORTED_PROVIDERS.keys()}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -113,14 +174,6 @@ def _infer_model_and_provider(
|
||||
return provider, model_name
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
|
||||
def _check_pkg(pkg: str) -> None:
|
||||
"""Check if a package is installed."""
|
||||
if not util.find_spec(pkg):
|
||||
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
|
||||
raise ImportError(msg)
|
||||
|
||||
|
||||
def init_embeddings(
|
||||
model: str,
|
||||
*,
|
||||
@@ -197,51 +250,7 @@ def init_embeddings(
|
||||
raise ValueError(msg)
|
||||
|
||||
provider, model_name = _infer_model_and_provider(model, provider=provider)
|
||||
pkg = _SUPPORTED_PROVIDERS[provider]
|
||||
_check_pkg(pkg)
|
||||
|
||||
if provider == "openai":
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
return OpenAIEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "azure_openai":
|
||||
from langchain_openai import AzureOpenAIEmbeddings
|
||||
|
||||
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "google_genai":
|
||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||
|
||||
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "google_vertexai":
|
||||
from langchain_google_vertexai import VertexAIEmbeddings
|
||||
|
||||
return VertexAIEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "bedrock":
|
||||
from langchain_aws import BedrockEmbeddings
|
||||
|
||||
return BedrockEmbeddings(model_id=model_name, **kwargs)
|
||||
if provider == "cohere":
|
||||
from langchain_cohere import CohereEmbeddings
|
||||
|
||||
return CohereEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "mistralai":
|
||||
from langchain_mistralai import MistralAIEmbeddings
|
||||
|
||||
return MistralAIEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "huggingface":
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
||||
if provider == "ollama":
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
|
||||
return OllamaEmbeddings(model=model_name, **kwargs)
|
||||
msg = (
|
||||
f"Provider '{provider}' is not supported.\n"
|
||||
f"Supported providers and their required packages:\n"
|
||||
f"{_get_provider_list()}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return _get_embeddings_class_creator(provider)(model=model_name, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -18,7 +18,7 @@ from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings
|
||||
],
|
||||
)
|
||||
async def test_init_embedding_model(provider: str, model: str) -> None:
|
||||
package = _SUPPORTED_PROVIDERS[provider]
|
||||
package = _SUPPORTED_PROVIDERS[provider][0]
|
||||
try:
|
||||
importlib.import_module(package)
|
||||
except ImportError:
|
||||
|
||||
@@ -8,6 +8,7 @@ from langchain_core.runnables import RunnableConfig, RunnableSequence
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain.chat_models import __all__, init_chat_model
|
||||
from langchain.chat_models.base import _SUPPORTED_PROVIDERS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
@@ -57,10 +58,15 @@ def test_init_missing_dep() -> None:
|
||||
|
||||
|
||||
def test_init_unknown_provider() -> None:
|
||||
with pytest.raises(ValueError, match="Unsupported model_provider='bar'"):
|
||||
with pytest.raises(ValueError, match="Unsupported provider='bar'"):
|
||||
init_chat_model("foo", model_provider="bar")
|
||||
|
||||
|
||||
def test_supported_providers_is_sorted() -> None:
|
||||
"""Test that supported providers are sorted alphabetically."""
|
||||
assert list(_SUPPORTED_PROVIDERS) == sorted(_SUPPORTED_PROVIDERS.keys())
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai")
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
|
||||
@@ -102,7 +102,7 @@ def test_infer_model_and_provider_errors() -> None:
|
||||
)
|
||||
def test_supported_providers_package_names(provider: str) -> None:
|
||||
"""Test that all supported providers have valid package names."""
|
||||
package = _SUPPORTED_PROVIDERS[provider]
|
||||
package = _SUPPORTED_PROVIDERS[provider][0]
|
||||
assert "-" not in package
|
||||
assert package.startswith("langchain_")
|
||||
assert package.islower()
|
||||
|
||||
Reference in New Issue
Block a user