mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 21:35:08 +00:00
This PR adds scaffolding for the langchain 1.0 entry package. Most contents have been removed.

Currently remaining entrypoints for:
* chat models
* embedding models
* memory -> trimming messages, filtering messages and counting tokens [we may remove this]
* prompts -> we may remove some prompts
* storage: primarily to support cache-backed embeddings, may remove the kv store
* tools -> report tool primitives

Things to be added:
* Selected agent implementations
* Selected workflows
* Common primitives: messages, Document
* Primitives for type hinting: BaseChatModel, BaseEmbeddings
* Selected retrievers
* Selected text splitters

Things to be removed:
* Globals need to be removed (needs an update in langchain core)

Todos:
* TBD indexing api (requires sqlalchemy, which we don't want as a dependency)
* Be explicit about public/private interfaces (e.g., likely rename chat_models.base.py to something more internal)
* Remove dockerfiles
* Update module doc-strings and README.md
112 lines
3.5 KiB
Python
112 lines
3.5 KiB
Python
"""Test embeddings base module."""
|
|
|
|
import pytest
|
|
|
|
from langchain.embeddings.base import (
|
|
_SUPPORTED_PROVIDERS,
|
|
_infer_model_and_provider,
|
|
_parse_model_string,
|
|
)
|
|
|
|
|
|
def test_parse_model_string() -> None:
    """Check that ``provider:model`` strings split into (provider, model)."""
    expectations = {
        "openai:text-embedding-3-small": ("openai", "text-embedding-3-small"),
        "bedrock:amazon.titan-embed-text-v1": (
            "bedrock",
            "amazon.titan-embed-text-v1",
        ),
        # The expected model name keeps its own colon: only the leading
        # ``huggingface`` segment is treated as the provider.
        "huggingface:BAAI/bge-base-en:v1.5": (
            "huggingface",
            "BAAI/bge-base-en:v1.5",
        ),
    }
    for model_string, expected_pair in expectations.items():
        assert _parse_model_string(model_string) == expected_pair
|
|
|
|
|
|
def test_parse_model_string_errors() -> None:
    """Test error cases for model string parsing.

    Each malformed input must raise ``ValueError`` with a message matching the
    given pattern. For an unknown provider, the message must additionally
    mention every key of ``_SUPPORTED_PROVIDERS``.
    """
    with pytest.raises(ValueError, match="Model name must be"):
        _parse_model_string("just-a-model-name")

    with pytest.raises(ValueError, match="Invalid model format "):
        _parse_model_string("")

    with pytest.raises(ValueError, match="is not supported"):
        _parse_model_string(":model-name")

    with pytest.raises(ValueError, match="Model name cannot be empty"):
        _parse_model_string("openai:")

    with pytest.raises(
        ValueError,
        match="Provider 'invalid-provider' is not supported",
    ) as exc_info:
        _parse_model_string("invalid-provider:model-name")

    # Check the provider list against the single captured error using plain
    # substring tests. This avoids re-raising the same error once per provider
    # and avoids interpreting provider names as regular expressions.
    error_message = str(exc_info.value)
    for provider in _SUPPORTED_PROVIDERS:
        assert provider in error_message
|
|
|
|
|
|
def test_infer_model_and_provider() -> None:
    """Verify provider/model resolution for each accepted calling form."""
    # Provider supplied inline as a ``provider:model`` prefix (positional).
    resolved = _infer_model_and_provider("openai:text-embedding-3-small")
    assert resolved == ("openai", "text-embedding-3-small")

    # Provider supplied through the separate ``provider`` keyword.
    resolved = _infer_model_and_provider(
        model="text-embedding-3-small",
        provider="openai",
    )
    assert resolved == ("openai", "text-embedding-3-small")

    # Explicit provider with a colon inside the model name: the model name is
    # expected back unchanged.
    resolved = _infer_model_and_provider(
        model="ft:text-embedding-3-small",
        provider="openai",
    )
    assert resolved == ("openai", "ft:text-embedding-3-small")

    # Inline provider followed by a model name that itself contains a colon.
    resolved = _infer_model_and_provider(model="openai:ft:text-embedding-3-small")
    assert resolved == ("openai", "ft:text-embedding-3-small")
|
|
|
|
|
|
def test_infer_model_and_provider_errors() -> None:
    """Test error cases for model and provider inference.

    Each invalid combination must raise ``ValueError`` with a message matching
    the given pattern. For an unsupported provider, the message must also
    mention every key of ``_SUPPORTED_PROVIDERS``.
    """
    # Test missing provider
    with pytest.raises(ValueError, match="Must specify either"):
        _infer_model_and_provider("text-embedding-3-small")

    # Test empty model
    with pytest.raises(ValueError, match="Model name cannot be empty"):
        _infer_model_and_provider("")

    # Test empty provider with model
    with pytest.raises(ValueError, match="Must specify either"):
        _infer_model_and_provider("model", provider="")

    # Test invalid provider
    # ``match`` is a regular expression, so the trailing period is escaped to
    # require a literal "." rather than any character.
    with pytest.raises(
        ValueError,
        match=r"Provider 'invalid' is not supported\.",
    ) as exc:
        _infer_model_and_provider("model", provider="invalid")

    # Test provider list is in error
    error_message = str(exc.value)
    for provider in _SUPPORTED_PROVIDERS:
        assert provider in error_message
|
|
|
|
|
|
@pytest.mark.parametrize("provider", sorted(_SUPPORTED_PROVIDERS))
def test_supported_providers_package_names(provider: str) -> None:
    """Ensure each provider maps to a well-formed package name.

    The mapped name must contain no hyphen, begin with ``langchain_`` and be
    entirely lowercase.
    """
    package_name = _SUPPORTED_PROVIDERS[provider]
    assert "-" not in package_name
    assert package_name.startswith("langchain_")
    assert package_name.islower()
|
|
|
|
|
|
def test_is_sorted() -> None:
    """Ensure the keys of ``_SUPPORTED_PROVIDERS`` are declared in sorted order."""
    declared_order = list(_SUPPORTED_PROVIDERS)
    assert declared_order == sorted(declared_order)
|