mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 06:13:36 +00:00
Init embeddings (#28370)
This commit is contained in:
parent
ffe7bd4832
commit
585da22752
@ -14,6 +14,7 @@ import logging
|
|||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from langchain._api import create_importer
|
from langchain._api import create_importer
|
||||||
|
from langchain.embeddings.base import init_embeddings
|
||||||
from langchain.embeddings.cache import CacheBackedEmbeddings
|
from langchain.embeddings.cache import CacheBackedEmbeddings
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -221,4 +222,5 @@ __all__ = [
|
|||||||
"VertexAIEmbeddings",
|
"VertexAIEmbeddings",
|
||||||
"VoyageEmbeddings",
|
"VoyageEmbeddings",
|
||||||
"XinferenceEmbeddings",
|
"XinferenceEmbeddings",
|
||||||
|
"init_embeddings",
|
||||||
]
|
]
|
||||||
|
@ -1,4 +1,224 @@
|
|||||||
from langchain_core.embeddings import Embeddings
|
import functools
|
||||||
|
from importlib import util
|
||||||
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
# This is for backwards compatibility
|
from langchain_core._api import beta
|
||||||
__all__ = ["Embeddings"]
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
|
||||||
|
# Maps each provider key accepted by ``init_embeddings`` to the import name of
# the integration package that is checked for (and imported from) when that
# provider is selected. Keys are kept in alphabetical order; the unit tests
# assert this ordering.
_SUPPORTED_PROVIDERS = {
    "azure_openai": "langchain_openai",
    "bedrock": "langchain_aws",
    "cohere": "langchain_cohere",
    "google_vertexai": "langchain_google_vertexai",
    "huggingface": "langchain_huggingface",
    "mistralai": "langchain_mistralai",
    "openai": "langchain_openai",
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_provider_list() -> str:
    """Render the supported providers as a newline-separated bullet list.

    Each bullet has the form ``- <provider>: <package>``, where the package
    is shown with underscores replaced by hyphens.
    """
    bullets = []
    for name, module in _SUPPORTED_PROVIDERS.items():
        hyphenated = module.replace("_", "-")
        bullets.append(f" - {name}: {hyphenated}")
    return "\n".join(bullets)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_model_string(model_name: str) -> Tuple[str, str]:
    """Parse a model string into provider and model name components.

    The model string should be in the format 'provider:model-name', where provider
    is one of the supported providers. Only the first ':' separates the provider
    from the model, so the model part may itself contain colons. The provider is
    lower-cased, and both parts are stripped of surrounding whitespace.

    Args:
        model_name: A model string in the format 'provider:model-name'

    Returns:
        A tuple of (provider, model_name)

    .. code-block:: python

        _parse_model_string("openai:text-embedding-3-small")
        # Returns: ("openai", "text-embedding-3-small")

        _parse_model_string("bedrock:amazon.titan-embed-text-v1")
        # Returns: ("bedrock", "amazon.titan-embed-text-v1")

    Raises:
        ValueError: If the model string is not in the correct format, the
            provider is unsupported, or the model part is empty
    """
    if ":" not in model_name:
        # List provider names only; interpolating the mapping itself would dump
        # a raw dict repr into a user-facing message.
        providers = ", ".join(_SUPPORTED_PROVIDERS)
        raise ValueError(
            f"Invalid model format '{model_name}'.\n"
            "Model name must be in format 'provider:model-name'\n"
            "Example valid model strings:\n"
            " - openai:text-embedding-3-small\n"
            " - bedrock:amazon.titan-embed-text-v1\n"
            " - cohere:embed-english-v3.0\n"
            f"Supported providers: {providers}"
        )

    # maxsplit=1 keeps any further colons inside the model part
    # (e.g. "openai:ft:some-model").
    provider, model = model_name.split(":", 1)
    provider = provider.lower().strip()
    model = model.strip()

    # The provider is validated before the model, so ":model" is reported as an
    # unsupported provider rather than as an empty model.
    if provider not in _SUPPORTED_PROVIDERS:
        raise ValueError(
            f"Provider '{provider}' is not supported.\n"
            "Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
    if not model:
        raise ValueError("Model name cannot be empty")
    return provider, model
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_model_and_provider(
    model: str, *, provider: Optional[str] = None
) -> Tuple[str, str]:
    """Resolve the provider and model name from the user-supplied arguments.

    If ``provider`` is None and ``model`` contains a ':', the model string is
    parsed as 'provider:model-name'. Otherwise ``model`` is used as-is and the
    provider must be given explicitly. An explicit provider is lower-cased and
    stripped so that it is matched the same way as a provider embedded in the
    model string.

    Args:
        model: Either a 'provider:model-name' string or a bare model name.
        provider: Optional explicit provider name. When given, ``model`` is
            never split, even if it contains a ':'.

    Returns:
        A tuple of (provider, model_name)

    Raises:
        ValueError: If the model is empty or whitespace, no provider can be
            determined, or the provider is not supported
    """
    if not model.strip():
        raise ValueError("Model name cannot be empty")

    if provider is None and ":" in model:
        provider, model_name = _parse_model_string(model)
    else:
        model_name = model
        if provider is not None:
            # Normalise like _parse_model_string does for the embedded form.
            provider = provider.lower().strip()

    if not provider:
        # List provider names only instead of dumping the mapping's dict repr.
        providers = ", ".join(_SUPPORTED_PROVIDERS)
        raise ValueError(
            "Must specify either:\n"
            "1. A model string in format 'provider:model-name'\n"
            " Example: 'openai:text-embedding-3-small'\n"
            "2. Or explicitly set provider from: "
            f"{providers}"
        )

    if provider not in _SUPPORTED_PROVIDERS:
        raise ValueError(
            f"Provider '{provider}' is not supported.\n"
            "Supported providers and their required packages:\n"
            f"{_get_provider_list()}"
        )
    return provider, model_name
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _check_pkg(pkg: str) -> None:
    """Check if a package is installed.

    Args:
        pkg: Import name of the package to look up (e.g. ``langchain_openai``).

    Raises:
        ImportError: If no import spec can be found for ``pkg``.
    """
    if not util.find_spec(pkg):
        # The install hint uses the hyphenated name, matching how packages are
        # presented by _get_provider_list.
        pkg_kebab = pkg.replace("_", "-")
        raise ImportError(
            f"Could not import {pkg} python package. "
            f"Please install it with `pip install {pkg_kebab}`"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@beta()
def init_embeddings(
    model: str,
    *,
    provider: Optional[str] = None,
    **kwargs: Any,
) -> Union[Embeddings, Runnable[Any, List[float]]]:
    """Initialize an embeddings model from a model name and optional provider.

    **Note:** Must have the integration package corresponding to the model provider
    installed.

    Args:
        model: Name of the model to use. Can be either:
            - A model string like "openai:text-embedding-3-small"
            - Just the model name if provider is specified
        provider: Optional explicit provider name. If not specified,
            will attempt to parse from the model string. Supported providers
            and their required packages:

            - azure_openai: langchain-openai
            - bedrock: langchain-aws
            - cohere: langchain-cohere
            - google_vertexai: langchain-google-vertexai
            - huggingface: langchain-huggingface
            - mistralai: langchain-mistralai
            - openai: langchain-openai

        **kwargs: Additional model-specific parameters passed to the embedding model.
            These vary by provider, see the provider-specific documentation for details.

    Returns:
        An Embeddings instance that can generate embeddings for text.

    Raises:
        ValueError: If the model provider is not supported or cannot be determined
        ImportError: If the required provider package is not installed

    .. dropdown:: Example Usage
        :open:

        .. code-block:: python

            # Using a model string
            model = init_embeddings("openai:text-embedding-3-small")
            model.embed_query("Hello, world!")

            # Using explicit provider
            model = init_embeddings(
                model="text-embedding-3-small",
                provider="openai"
            )
            model.embed_documents(["Hello, world!", "Goodbye, world!"])

            # With additional parameters
            model = init_embeddings(
                "openai:text-embedding-3-small",
                api_key="sk-..."
            )

    .. versionadded:: 0.3.9
    """
    # NOTE: the provider list in the docstring is written out by hand because a
    # docstring is not an f-string; keep it in sync with _SUPPORTED_PROVIDERS.
    if not model:
        raise ValueError(
            "Must specify model name. "
            f"Supported providers are: {', '.join(_SUPPORTED_PROVIDERS)}"
        )

    provider, model_name = _infer_model_and_provider(model, provider=provider)
    pkg = _SUPPORTED_PROVIDERS[provider]
    _check_pkg(pkg)

    # Integration packages are imported lazily so that only the package for the
    # selected provider needs to be installed.
    if provider == "openai":
        from langchain_openai import OpenAIEmbeddings

        return OpenAIEmbeddings(model=model_name, **kwargs)
    if provider == "azure_openai":
        from langchain_openai import AzureOpenAIEmbeddings

        return AzureOpenAIEmbeddings(model=model_name, **kwargs)
    if provider == "google_vertexai":
        from langchain_google_vertexai import VertexAIEmbeddings

        return VertexAIEmbeddings(model=model_name, **kwargs)
    if provider == "bedrock":
        from langchain_aws import BedrockEmbeddings

        # Bedrock takes the model identifier as ``model_id``.
        return BedrockEmbeddings(model_id=model_name, **kwargs)
    if provider == "cohere":
        from langchain_cohere import CohereEmbeddings

        return CohereEmbeddings(model=model_name, **kwargs)
    if provider == "mistralai":
        from langchain_mistralai import MistralAIEmbeddings

        return MistralAIEmbeddings(model=model_name, **kwargs)
    if provider == "huggingface":
        from langchain_huggingface import HuggingFaceEmbeddings

        # HuggingFace takes the model identifier as ``model_name``.
        return HuggingFaceEmbeddings(model_name=model_name, **kwargs)

    # Defensive guard: only reached if a provider is listed in
    # _SUPPORTED_PROVIDERS but has no branch above.
    raise ValueError(
        f"Provider '{provider}' is not supported.\n"
        "Supported providers and their required packages:\n"
        f"{_get_provider_list()}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Names exported by this module.
__all__ = [
    "init_embeddings",
    # Imported from langchain_core and listed here for backwards compatibility.
    "Embeddings",
]
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
"""Test embeddings base module."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
|
from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "provider, model",
    [
        ("openai", "text-embedding-3-large"),
        ("google_vertexai", "text-embedding-gecko@003"),
        ("bedrock", "amazon.titan-embed-text-v1"),
        ("cohere", "embed-english-v2.0"),
    ],
)
async def test_init_embedding_model(provider: str, model: str) -> None:
    """Build a model both ways and check that each embeds a query to floats.

    The case is skipped when the provider's integration package cannot be
    imported.
    """
    pkg_name = _SUPPORTED_PROVIDERS[provider]
    try:
        importlib.import_module(pkg_name)
    except ImportError:
        pytest.skip(f"Package {pkg_name} is not installed")

    # Construct via the combined "provider:model" string ...
    from_string = init_embeddings(f"{provider}:{model}")
    assert isinstance(from_string, Embeddings)

    # ... and via the explicit provider keyword.
    from_kwargs = init_embeddings(model=model, provider=provider)
    assert isinstance(from_kwargs, Embeddings)

    query = "Hello world"
    for embedder in (from_string, from_kwargs):
        vector = await embedder.aembed_query(query)
        assert isinstance(vector, list)
        assert all(isinstance(value, float) for value in vector)
|
111
libs/langchain/tests/unit_tests/embeddings/test_base.py
Normal file
111
libs/langchain/tests/unit_tests/embeddings/test_base.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
"""Test embeddings base module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.embeddings.base import (
|
||||||
|
_SUPPORTED_PROVIDERS,
|
||||||
|
_infer_model_and_provider,
|
||||||
|
_parse_model_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_model_string() -> None:
    """Valid 'provider:model' strings split into the expected pair."""
    expected_by_input = [
        ("openai:text-embedding-3-small", ("openai", "text-embedding-3-small")),
        (
            "bedrock:amazon.titan-embed-text-v1",
            ("bedrock", "amazon.titan-embed-text-v1"),
        ),
        # Only the first colon separates provider from model.
        (
            "huggingface:BAAI/bge-base-en:v1.5",
            ("huggingface", "BAAI/bge-base-en:v1.5"),
        ),
    ]
    for model_string, expected in expected_by_input:
        assert _parse_model_string(model_string) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_model_string_errors() -> None:
    """Malformed model strings raise ValueError with the expected message."""
    failing_cases = [
        ("just-a-model-name", "Model name must be"),
        ("", "Invalid model format "),
        (":model-name", "is not supported"),
        ("openai:", "Model name cannot be empty"),
        (
            "invalid-provider:model-name",
            "Provider 'invalid-provider' is not supported",
        ),
    ]
    for bad_input, message_pattern in failing_cases:
        with pytest.raises(ValueError, match=message_pattern):
            _parse_model_string(bad_input)

    # The unsupported-provider error must name every supported provider.
    for provider in _SUPPORTED_PROVIDERS:
        with pytest.raises(ValueError, match=provider):
            _parse_model_string("invalid-provider:model-name")
|
||||||
|
|
||||||
|
|
||||||
|
def test_infer_model_and_provider() -> None:
    """Provider and model resolve correctly for each supported call shape."""
    # Provider embedded in the model string.
    resolved = _infer_model_and_provider("openai:text-embedding-3-small")
    assert resolved == ("openai", "text-embedding-3-small")

    # Provider passed explicitly.
    resolved = _infer_model_and_provider(
        model="text-embedding-3-small", provider="openai"
    )
    assert resolved == ("openai", "text-embedding-3-small")

    # With an explicit provider, a colon in the model name is left untouched.
    resolved = _infer_model_and_provider(
        model="ft:text-embedding-3-small", provider="openai"
    )
    assert resolved == ("openai", "ft:text-embedding-3-small")

    # Without one, only the first colon is treated as the separator.
    resolved = _infer_model_and_provider(model="openai:ft:text-embedding-3-small")
    assert resolved == ("openai", "ft:text-embedding-3-small")
|
||||||
|
|
||||||
|
|
||||||
|
def test_infer_model_and_provider_errors() -> None:
    """Invalid model/provider combinations raise ValueError as expected."""
    # (positional model, keyword arguments, expected message pattern)
    failing_calls = [
        ("text-embedding-3-small", {}, "Must specify either"),  # missing provider
        ("", {}, "Model name cannot be empty"),  # empty model
        ("model", {"provider": ""}, "Must specify either"),  # empty provider
        ("model", {"provider": "invalid"}, "is not supported"),  # bad provider
    ]
    for model_arg, extra_kwargs, message_pattern in failing_calls:
        with pytest.raises(ValueError, match=message_pattern):
            _infer_model_and_provider(model_arg, **extra_kwargs)

    # The unsupported-provider error lists every supported provider.
    with pytest.raises(ValueError) as exc:
        _infer_model_and_provider("model", provider="invalid")
    error_text = str(exc.value)
    for provider in _SUPPORTED_PROVIDERS:
        assert provider in error_text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("provider", sorted(_SUPPORTED_PROVIDERS))
def test_supported_providers_package_names(provider: str) -> None:
    """Each provider maps to a lowercase, hyphen-free ``langchain_*`` name."""
    pkg_name = _SUPPORTED_PROVIDERS[provider]
    has_hyphen = "-" in pkg_name
    assert not has_hyphen
    assert pkg_name.startswith("langchain_")
    assert pkg_name.islower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_sorted() -> None:
    """The provider mapping declares its keys in alphabetical order."""
    declared_order = list(_SUPPORTED_PROVIDERS)
    assert declared_order == sorted(declared_order)
|
@ -55,6 +55,7 @@ EXPECTED_ALL = [
|
|||||||
"JohnSnowLabsEmbeddings",
|
"JohnSnowLabsEmbeddings",
|
||||||
"VoyageEmbeddings",
|
"VoyageEmbeddings",
|
||||||
"BookendEmbeddings",
|
"BookendEmbeddings",
|
||||||
|
"init_embeddings",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user