mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
refactor(langchain): refactor optional imports logic (#32813)
* Use `importlib` to load dynamically the classes * Removes missing package warnings --------- Co-authored-by: Mason Daugherty <github@mdrxy.com> Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
d46187201d
commit
85f1ba2351
@@ -2,9 +2,17 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import importlib
|
||||||
import warnings
|
import warnings
|
||||||
from importlib import util
|
from typing import (
|
||||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Literal,
|
||||||
|
TypeAlias,
|
||||||
|
cast,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_core.language_models import BaseChatModel, LanguageModelInput
|
from langchain_core.language_models import BaseChatModel, LanguageModelInput
|
||||||
from langchain_core.messages import AIMessage, AnyMessage
|
from langchain_core.messages import AIMessage, AnyMessage
|
||||||
@@ -14,6 +22,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
from langchain_core.runnables.schema import StreamEvent
|
from langchain_core.runnables.schema import StreamEvent
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
@@ -21,6 +30,127 @@ if TYPE_CHECKING:
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
def _call(cls: type[BaseChatModel], **kwargs: Any) -> BaseChatModel:
|
||||||
|
# TODO: replace with operator.call when lower bounding to Python 3.11
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., BaseChatModel]]] = {
|
||||||
|
"anthropic": ("langchain_anthropic", "ChatAnthropic", _call),
|
||||||
|
"azure_ai": ("langchain_azure_ai.chat_models", "AzureAIChatCompletionsModel", _call),
|
||||||
|
"azure_openai": ("langchain_openai", "AzureChatOpenAI", _call),
|
||||||
|
"bedrock": ("langchain_aws", "ChatBedrock", _call),
|
||||||
|
"bedrock_converse": ("langchain_aws", "ChatBedrockConverse", _call),
|
||||||
|
"cohere": ("langchain_cohere", "ChatCohere", _call),
|
||||||
|
"deepseek": ("langchain_deepseek", "ChatDeepSeek", _call),
|
||||||
|
"fireworks": ("langchain_fireworks", "ChatFireworks", _call),
|
||||||
|
"google_anthropic_vertex": (
|
||||||
|
"langchain_google_vertexai.model_garden",
|
||||||
|
"ChatAnthropicVertex",
|
||||||
|
_call,
|
||||||
|
),
|
||||||
|
"google_genai": ("langchain_google_genai", "ChatGoogleGenerativeAI", _call),
|
||||||
|
"google_vertexai": ("langchain_google_vertexai", "ChatVertexAI", _call),
|
||||||
|
"groq": ("langchain_groq", "ChatGroq", _call),
|
||||||
|
"huggingface": (
|
||||||
|
"langchain_huggingface",
|
||||||
|
"ChatHuggingFace",
|
||||||
|
lambda cls, model, **kwargs: cls.from_model_id(model_id=model, **kwargs),
|
||||||
|
),
|
||||||
|
"ibm": (
|
||||||
|
"langchain_ibm",
|
||||||
|
"ChatWatsonx",
|
||||||
|
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
|
||||||
|
),
|
||||||
|
"mistralai": ("langchain_mistralai", "ChatMistralAI", _call),
|
||||||
|
"nvidia": ("langchain_nvidia_ai_endpoints", "ChatNVIDIA", _call),
|
||||||
|
"ollama": ("langchain_ollama", "ChatOllama", _call),
|
||||||
|
"openai": ("langchain_openai", "ChatOpenAI", _call),
|
||||||
|
"perplexity": ("langchain_perplexity", "ChatPerplexity", _call),
|
||||||
|
"together": ("langchain_together", "ChatTogether", _call),
|
||||||
|
"upstage": ("langchain_upstage", "ChatUpstage", _call),
|
||||||
|
"xai": ("langchain_xai", "ChatXAI", _call),
|
||||||
|
}
|
||||||
|
"""Registry mapping provider names to their import configuration.
|
||||||
|
|
||||||
|
Each entry maps a provider key to a tuple of:
|
||||||
|
|
||||||
|
- `module_path`: The Python module path containing the chat model class.
|
||||||
|
|
||||||
|
This may be a submodule (e.g., `'langchain_azure_ai.chat_models'`) if the class is
|
||||||
|
not exported from the package root.
|
||||||
|
- `class_name`: The name of the chat model class to import.
|
||||||
|
- `creator_func`: A callable that instantiates the class with provided kwargs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _import_module(module: str) -> ModuleType:
|
||||||
|
"""Import a module by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: The fully qualified module name to import (e.g., `'langchain_openai'`).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The imported module.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If the module cannot be imported, with a message suggesting
|
||||||
|
the pip package to install.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return importlib.import_module(module)
|
||||||
|
except ImportError as e:
|
||||||
|
# Extract package name from module path (e.g., "langchain_azure_ai.chat_models"
|
||||||
|
# becomes "langchain-azure-ai")
|
||||||
|
pkg = module.split(".", maxsplit=1)[0].replace("_", "-")
|
||||||
|
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
|
||||||
|
raise ImportError(msg) from e
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
|
||||||
|
def _get_chat_model_creator(
|
||||||
|
provider: str,
|
||||||
|
) -> Callable[..., BaseChatModel]:
|
||||||
|
"""Return a factory function that creates a chat model for the given provider.
|
||||||
|
|
||||||
|
This function is cached to avoid repeated module imports.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: The name of the model provider (e.g., `'openai'`, `'anthropic'`).
|
||||||
|
|
||||||
|
Must be a key in `_SUPPORTED_PROVIDERS`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A callable that accepts model kwargs and returns a `BaseChatModel` instance for
|
||||||
|
the specified provider.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the provider is not in `_SUPPORTED_PROVIDERS`.
|
||||||
|
ImportError: If the provider's integration package is not installed.
|
||||||
|
"""
|
||||||
|
if provider not in _SUPPORTED_PROVIDERS:
|
||||||
|
supported = ", ".join(_SUPPORTED_PROVIDERS.keys())
|
||||||
|
msg = f"Unsupported {provider=}.\n\nSupported model providers are: {supported}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
pkg, class_name, creator_func = _SUPPORTED_PROVIDERS[provider]
|
||||||
|
try:
|
||||||
|
module = _import_module(pkg)
|
||||||
|
except ImportError as e:
|
||||||
|
if provider != "ollama":
|
||||||
|
raise
|
||||||
|
# For backwards compatibility
|
||||||
|
try:
|
||||||
|
module = _import_module("langchain_community.chat_models")
|
||||||
|
except ImportError:
|
||||||
|
# If both langchain-ollama and langchain-community aren't available,
|
||||||
|
# raise an error related to langchain-ollama
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
cls = getattr(module, class_name)
|
||||||
|
return functools.partial(creator_func, cls=cls)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def init_chat_model(
|
def init_chat_model(
|
||||||
model: str,
|
model: str,
|
||||||
@@ -334,154 +464,8 @@ def _init_chat_model_helper(
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseChatModel:
|
) -> BaseChatModel:
|
||||||
model, model_provider = _parse_model(model, model_provider)
|
model, model_provider = _parse_model(model, model_provider)
|
||||||
if model_provider == "openai":
|
creator_func = _get_chat_model_creator(model_provider)
|
||||||
_check_pkg("langchain_openai")
|
return creator_func(model=model, **kwargs)
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
return ChatOpenAI(model=model, **kwargs)
|
|
||||||
if model_provider == "anthropic":
|
|
||||||
_check_pkg("langchain_anthropic")
|
|
||||||
from langchain_anthropic import ChatAnthropic
|
|
||||||
|
|
||||||
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
|
||||||
if model_provider == "azure_openai":
|
|
||||||
_check_pkg("langchain_openai")
|
|
||||||
from langchain_openai import AzureChatOpenAI
|
|
||||||
|
|
||||||
return AzureChatOpenAI(model=model, **kwargs)
|
|
||||||
if model_provider == "azure_ai":
|
|
||||||
_check_pkg("langchain_azure_ai")
|
|
||||||
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
|
|
||||||
|
|
||||||
return AzureAIChatCompletionsModel(model=model, **kwargs)
|
|
||||||
if model_provider == "cohere":
|
|
||||||
_check_pkg("langchain_cohere")
|
|
||||||
from langchain_cohere import ChatCohere
|
|
||||||
|
|
||||||
return ChatCohere(model=model, **kwargs)
|
|
||||||
if model_provider == "google_vertexai":
|
|
||||||
_check_pkg("langchain_google_vertexai")
|
|
||||||
from langchain_google_vertexai import ChatVertexAI
|
|
||||||
|
|
||||||
return ChatVertexAI(model=model, **kwargs)
|
|
||||||
if model_provider == "google_genai":
|
|
||||||
_check_pkg("langchain_google_genai")
|
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
||||||
|
|
||||||
return ChatGoogleGenerativeAI(model=model, **kwargs)
|
|
||||||
if model_provider == "fireworks":
|
|
||||||
_check_pkg("langchain_fireworks")
|
|
||||||
from langchain_fireworks import ChatFireworks
|
|
||||||
|
|
||||||
return ChatFireworks(model=model, **kwargs)
|
|
||||||
if model_provider == "ollama":
|
|
||||||
try:
|
|
||||||
_check_pkg("langchain_ollama")
|
|
||||||
from langchain_ollama import ChatOllama
|
|
||||||
except ImportError:
|
|
||||||
# For backwards compatibility
|
|
||||||
try:
|
|
||||||
_check_pkg("langchain_community")
|
|
||||||
from langchain_community.chat_models import ChatOllama
|
|
||||||
except ImportError:
|
|
||||||
# If both langchain-ollama and langchain-community aren't available,
|
|
||||||
# raise an error related to langchain-ollama
|
|
||||||
_check_pkg("langchain_ollama")
|
|
||||||
|
|
||||||
return ChatOllama(model=model, **kwargs)
|
|
||||||
if model_provider == "together":
|
|
||||||
_check_pkg("langchain_together")
|
|
||||||
from langchain_together import ChatTogether
|
|
||||||
|
|
||||||
return ChatTogether(model=model, **kwargs)
|
|
||||||
if model_provider == "mistralai":
|
|
||||||
_check_pkg("langchain_mistralai")
|
|
||||||
from langchain_mistralai import ChatMistralAI
|
|
||||||
|
|
||||||
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
|
||||||
if model_provider == "huggingface":
|
|
||||||
_check_pkg("langchain_huggingface")
|
|
||||||
from langchain_huggingface import ChatHuggingFace
|
|
||||||
|
|
||||||
return ChatHuggingFace.from_model_id(model_id=model, **kwargs)
|
|
||||||
if model_provider == "groq":
|
|
||||||
_check_pkg("langchain_groq")
|
|
||||||
from langchain_groq import ChatGroq
|
|
||||||
|
|
||||||
return ChatGroq(model=model, **kwargs)
|
|
||||||
if model_provider == "bedrock":
|
|
||||||
_check_pkg("langchain_aws")
|
|
||||||
from langchain_aws import ChatBedrock
|
|
||||||
|
|
||||||
return ChatBedrock(model_id=model, **kwargs)
|
|
||||||
if model_provider == "bedrock_converse":
|
|
||||||
_check_pkg("langchain_aws")
|
|
||||||
from langchain_aws import ChatBedrockConverse
|
|
||||||
|
|
||||||
return ChatBedrockConverse(model=model, **kwargs)
|
|
||||||
if model_provider == "google_anthropic_vertex":
|
|
||||||
_check_pkg("langchain_google_vertexai")
|
|
||||||
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
|
||||||
|
|
||||||
return ChatAnthropicVertex(model=model, **kwargs)
|
|
||||||
if model_provider == "deepseek":
|
|
||||||
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
|
|
||||||
from langchain_deepseek import ChatDeepSeek
|
|
||||||
|
|
||||||
return ChatDeepSeek(model=model, **kwargs)
|
|
||||||
if model_provider == "nvidia":
|
|
||||||
_check_pkg("langchain_nvidia_ai_endpoints")
|
|
||||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
|
||||||
|
|
||||||
return ChatNVIDIA(model=model, **kwargs)
|
|
||||||
if model_provider == "ibm":
|
|
||||||
_check_pkg("langchain_ibm")
|
|
||||||
from langchain_ibm import ChatWatsonx
|
|
||||||
|
|
||||||
return ChatWatsonx(model_id=model, **kwargs)
|
|
||||||
if model_provider == "xai":
|
|
||||||
_check_pkg("langchain_xai")
|
|
||||||
from langchain_xai import ChatXAI
|
|
||||||
|
|
||||||
return ChatXAI(model=model, **kwargs)
|
|
||||||
if model_provider == "perplexity":
|
|
||||||
_check_pkg("langchain_perplexity")
|
|
||||||
from langchain_perplexity import ChatPerplexity
|
|
||||||
|
|
||||||
return ChatPerplexity(model=model, **kwargs)
|
|
||||||
if model_provider == "upstage":
|
|
||||||
_check_pkg("langchain_upstage")
|
|
||||||
from langchain_upstage import ChatUpstage
|
|
||||||
|
|
||||||
return ChatUpstage(model=model, **kwargs)
|
|
||||||
supported = ", ".join(_SUPPORTED_PROVIDERS)
|
|
||||||
msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
_SUPPORTED_PROVIDERS = {
|
|
||||||
"openai",
|
|
||||||
"anthropic",
|
|
||||||
"azure_openai",
|
|
||||||
"azure_ai",
|
|
||||||
"cohere",
|
|
||||||
"google_vertexai",
|
|
||||||
"google_genai",
|
|
||||||
"fireworks",
|
|
||||||
"ollama",
|
|
||||||
"together",
|
|
||||||
"mistralai",
|
|
||||||
"huggingface",
|
|
||||||
"groq",
|
|
||||||
"bedrock",
|
|
||||||
"bedrock_converse",
|
|
||||||
"google_anthropic_vertex",
|
|
||||||
"deepseek",
|
|
||||||
"ibm",
|
|
||||||
"xai",
|
|
||||||
"perplexity",
|
|
||||||
"upstage",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _attempt_infer_model_provider(model_name: str) -> str | None:
|
def _attempt_infer_model_provider(model_name: str) -> str | None:
|
||||||
@@ -582,13 +566,6 @@ def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
|
|||||||
return model, model_provider
|
return model, model_provider
|
||||||
|
|
||||||
|
|
||||||
def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
|
|
||||||
if not util.find_spec(pkg):
|
|
||||||
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
|
|
||||||
msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
|
|
||||||
raise ImportError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def _remove_prefix(s: str, prefix: str) -> str:
|
def _remove_prefix(s: str, prefix: str) -> str:
|
||||||
return s.removeprefix(prefix)
|
return s.removeprefix(prefix)
|
||||||
|
|
||||||
|
|||||||
@@ -1,27 +1,90 @@
|
|||||||
"""Factory functions for embeddings."""
|
"""Factory functions for embeddings."""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from importlib import util
|
import importlib
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
_SUPPORTED_PROVIDERS = {
|
|
||||||
"azure_openai": "langchain_openai",
|
def _call(cls: type[Embeddings], **kwargs: Any) -> Embeddings:
|
||||||
"bedrock": "langchain_aws",
|
return cls(**kwargs)
|
||||||
"cohere": "langchain_cohere",
|
|
||||||
"google_genai": "langchain_google_genai",
|
|
||||||
"google_vertexai": "langchain_google_vertexai",
|
_SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., Embeddings]]] = {
|
||||||
"huggingface": "langchain_huggingface",
|
"azure_openai": ("langchain_openai", "OpenAIEmbeddings", _call),
|
||||||
"mistralai": "langchain_mistralai",
|
"bedrock": (
|
||||||
"ollama": "langchain_ollama",
|
"langchain_aws",
|
||||||
"openai": "langchain_openai",
|
"BedrockEmbeddings",
|
||||||
|
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
|
||||||
|
),
|
||||||
|
"cohere": ("langchain_cohere", "CohereEmbeddings", _call),
|
||||||
|
"google_genai": ("langchain_google_genai", "GoogleGenerativeAIEmbeddings", _call),
|
||||||
|
"google_vertexai": ("langchain_google_vertexai", "VertexAIEmbeddings", _call),
|
||||||
|
"huggingface": (
|
||||||
|
"langchain_huggingface",
|
||||||
|
"HuggingFaceEmbeddings",
|
||||||
|
lambda cls, model, **kwargs: cls(model_name=model, **kwargs),
|
||||||
|
),
|
||||||
|
"mistralai": ("langchain_mistralai", "MistralAIEmbeddings", _call),
|
||||||
|
"ollama": ("langchain_ollama", "OllamaEmbeddings", _call),
|
||||||
|
"openai": ("langchain_openai", "OpenAIEmbeddings", _call),
|
||||||
}
|
}
|
||||||
|
"""Registry mapping provider names to their import configuration.
|
||||||
|
|
||||||
|
Each entry maps a provider key to a tuple of:
|
||||||
|
|
||||||
|
- `module_path`: The Python module path containing the embeddings class.
|
||||||
|
- `class_name`: The name of the embeddings class to import.
|
||||||
|
- `creator_func`: A callable that instantiates the class with provided kwargs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
|
||||||
|
def _get_embeddings_class_creator(provider: str) -> Callable[..., Embeddings]:
|
||||||
|
"""Return a factory function that creates an embeddings model for the given provider.
|
||||||
|
|
||||||
|
This function is cached to avoid repeated module imports.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: The name of the model provider (e.g., `'openai'`, `'cohere'`).
|
||||||
|
|
||||||
|
Must be a key in `_SUPPORTED_PROVIDERS`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A callable that accepts model kwargs and returns an `Embeddings` instance for
|
||||||
|
the specified provider.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the provider is not in `_SUPPORTED_PROVIDERS`.
|
||||||
|
ImportError: If the provider's integration package is not installed.
|
||||||
|
"""
|
||||||
|
if provider not in _SUPPORTED_PROVIDERS:
|
||||||
|
msg = (
|
||||||
|
f"Provider '{provider}' is not supported.\n"
|
||||||
|
f"Supported providers and their required packages:\n"
|
||||||
|
f"{_get_provider_list()}"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
module_name, class_name, creator_func = _SUPPORTED_PROVIDERS[provider]
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
except ImportError as e:
|
||||||
|
pkg = module_name.replace("_", "-")
|
||||||
|
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
|
||||||
|
raise ImportError(msg) from e
|
||||||
|
|
||||||
|
cls = getattr(module, class_name)
|
||||||
|
return functools.partial(creator_func, cls=cls)
|
||||||
|
|
||||||
|
|
||||||
def _get_provider_list() -> str:
|
def _get_provider_list() -> str:
|
||||||
"""Get formatted list of providers and their packages."""
|
"""Get formatted list of providers and their packages."""
|
||||||
return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items())
|
return "\n".join(
|
||||||
|
f" - {p}: {pkg[0].replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _parse_model_string(model_name: str) -> tuple[str, str]:
|
def _parse_model_string(model_name: str) -> tuple[str, str]:
|
||||||
@@ -50,7 +113,6 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if ":" not in model_name:
|
if ":" not in model_name:
|
||||||
providers = _SUPPORTED_PROVIDERS
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Invalid model format '{model_name}'.\n"
|
f"Invalid model format '{model_name}'.\n"
|
||||||
f"Model name must be in format 'provider:model-name'\n"
|
f"Model name must be in format 'provider:model-name'\n"
|
||||||
@@ -58,7 +120,7 @@ def _parse_model_string(model_name: str) -> tuple[str, str]:
|
|||||||
f" - openai:text-embedding-3-small\n"
|
f" - openai:text-embedding-3-small\n"
|
||||||
f" - bedrock:amazon.titan-embed-text-v1\n"
|
f" - bedrock:amazon.titan-embed-text-v1\n"
|
||||||
f" - cohere:embed-english-v3.0\n"
|
f" - cohere:embed-english-v3.0\n"
|
||||||
f"Supported providers: {providers}"
|
f"Supported providers: {_SUPPORTED_PROVIDERS.keys()}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
@@ -93,13 +155,12 @@ def _infer_model_and_provider(
|
|||||||
model_name = model
|
model_name = model
|
||||||
|
|
||||||
if not provider:
|
if not provider:
|
||||||
providers = _SUPPORTED_PROVIDERS
|
|
||||||
msg = (
|
msg = (
|
||||||
"Must specify either:\n"
|
"Must specify either:\n"
|
||||||
"1. A model string in format 'provider:model-name'\n"
|
"1. A model string in format 'provider:model-name'\n"
|
||||||
" Example: 'openai:text-embedding-3-small'\n"
|
" Example: 'openai:text-embedding-3-small'\n"
|
||||||
"2. Or explicitly set provider from: "
|
"2. Or explicitly set provider from: "
|
||||||
f"{providers}"
|
f"{_SUPPORTED_PROVIDERS.keys()}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
@@ -113,14 +174,6 @@ def _infer_model_and_provider(
|
|||||||
return provider, model_name
|
return provider, model_name
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
|
|
||||||
def _check_pkg(pkg: str) -> None:
|
|
||||||
"""Check if a package is installed."""
|
|
||||||
if not util.find_spec(pkg):
|
|
||||||
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
|
|
||||||
raise ImportError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def init_embeddings(
|
def init_embeddings(
|
||||||
model: str,
|
model: str,
|
||||||
*,
|
*,
|
||||||
@@ -197,51 +250,7 @@ def init_embeddings(
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
provider, model_name = _infer_model_and_provider(model, provider=provider)
|
provider, model_name = _infer_model_and_provider(model, provider=provider)
|
||||||
pkg = _SUPPORTED_PROVIDERS[provider]
|
return _get_embeddings_class_creator(provider)(model=model_name, **kwargs)
|
||||||
_check_pkg(pkg)
|
|
||||||
|
|
||||||
if provider == "openai":
|
|
||||||
from langchain_openai import OpenAIEmbeddings
|
|
||||||
|
|
||||||
return OpenAIEmbeddings(model=model_name, **kwargs)
|
|
||||||
if provider == "azure_openai":
|
|
||||||
from langchain_openai import AzureOpenAIEmbeddings
|
|
||||||
|
|
||||||
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
|
|
||||||
if provider == "google_genai":
|
|
||||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
|
||||||
|
|
||||||
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
|
|
||||||
if provider == "google_vertexai":
|
|
||||||
from langchain_google_vertexai import VertexAIEmbeddings
|
|
||||||
|
|
||||||
return VertexAIEmbeddings(model=model_name, **kwargs)
|
|
||||||
if provider == "bedrock":
|
|
||||||
from langchain_aws import BedrockEmbeddings
|
|
||||||
|
|
||||||
return BedrockEmbeddings(model_id=model_name, **kwargs)
|
|
||||||
if provider == "cohere":
|
|
||||||
from langchain_cohere import CohereEmbeddings
|
|
||||||
|
|
||||||
return CohereEmbeddings(model=model_name, **kwargs)
|
|
||||||
if provider == "mistralai":
|
|
||||||
from langchain_mistralai import MistralAIEmbeddings
|
|
||||||
|
|
||||||
return MistralAIEmbeddings(model=model_name, **kwargs)
|
|
||||||
if provider == "huggingface":
|
|
||||||
from langchain_huggingface import HuggingFaceEmbeddings
|
|
||||||
|
|
||||||
return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
|
||||||
if provider == "ollama":
|
|
||||||
from langchain_ollama import OllamaEmbeddings
|
|
||||||
|
|
||||||
return OllamaEmbeddings(model=model_name, **kwargs)
|
|
||||||
msg = (
|
|
||||||
f"Provider '{provider}' is not supported.\n"
|
|
||||||
f"Supported providers and their required packages:\n"
|
|
||||||
f"{_get_provider_list()}"
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_init_embedding_model(provider: str, model: str) -> None:
|
async def test_init_embedding_model(provider: str, model: str) -> None:
|
||||||
package = _SUPPORTED_PROVIDERS[provider]
|
package = _SUPPORTED_PROVIDERS[provider][0]
|
||||||
try:
|
try:
|
||||||
importlib.import_module(package)
|
importlib.import_module(package)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from langchain_core.runnables import RunnableConfig, RunnableSequence
|
|||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from langchain.chat_models import __all__, init_chat_model
|
from langchain.chat_models import __all__, init_chat_model
|
||||||
|
from langchain.chat_models.base import _SUPPORTED_PROVIDERS
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
@@ -57,10 +58,15 @@ def test_init_missing_dep() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_init_unknown_provider() -> None:
|
def test_init_unknown_provider() -> None:
|
||||||
with pytest.raises(ValueError, match="Unsupported model_provider='bar'"):
|
with pytest.raises(ValueError, match="Unsupported provider='bar'"):
|
||||||
init_chat_model("foo", model_provider="bar")
|
init_chat_model("foo", model_provider="bar")
|
||||||
|
|
||||||
|
|
||||||
|
def test_supported_providers_is_sorted() -> None:
|
||||||
|
"""Test that supported providers are sorted alphabetically."""
|
||||||
|
assert list(_SUPPORTED_PROVIDERS) == sorted(_SUPPORTED_PROVIDERS.keys())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("langchain_openai")
|
@pytest.mark.requires("langchain_openai")
|
||||||
@mock.patch.dict(
|
@mock.patch.dict(
|
||||||
os.environ,
|
os.environ,
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ def test_infer_model_and_provider_errors() -> None:
|
|||||||
)
|
)
|
||||||
def test_supported_providers_package_names(provider: str) -> None:
|
def test_supported_providers_package_names(provider: str) -> None:
|
||||||
"""Test that all supported providers have valid package names."""
|
"""Test that all supported providers have valid package names."""
|
||||||
package = _SUPPORTED_PROVIDERS[provider]
|
package = _SUPPORTED_PROVIDERS[provider][0]
|
||||||
assert "-" not in package
|
assert "-" not in package
|
||||||
assert package.startswith("langchain_")
|
assert package.startswith("langchain_")
|
||||||
assert package.islower()
|
assert package.islower()
|
||||||
|
|||||||
Reference in New Issue
Block a user