langchain/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py
Eugene Yurtsev 56dde3ade3
feat(langchain): v1 scaffolding (#32166)
This PR adds scaffolding for the langchain 1.0 entry package.

Most contents have been removed. 

Entrypoints currently remain for:

* chat models
* embedding models
* memory -> trimming messages, filtering messages and counting tokens
[we may remove this]
* prompts -> we may remove some prompts
* storage: primarily to support cache backed embeddings, may remove the
kv store
* tools -> report tool primitives

Things to be added:

* Selected agent implementations
* Selected workflows
* Common primitives: messages, Document
* Primitives for type hinting: BaseChatModel, BaseEmbeddings
* Selected retrievers
* Selected text splitters

Things to be removed:

* Globals need to be removed (needs an update in langchain core)


Todos: 

* TBD indexing api (requires sqlalchemy which we don't want as a
dependency)
* Be explicit about public/private interfaces (e.g., likely rename
chat_models.base.py to something more internal)
* Remove dockerfiles
* Update module doc-strings and README.md
2025-07-24 09:47:48 -04:00

112 lines
3.5 KiB
Python

"""Test embeddings base module."""
import pytest
from langchain.embeddings.base import (
_SUPPORTED_PROVIDERS,
_infer_model_and_provider,
_parse_model_string,
)
def test_parse_model_string() -> None:
    """Check that ``provider:model`` strings split into their two parts."""
    # Raw input -> (provider, model) pair the parser is expected to return.
    # The huggingface entry has a second colon, which is expected to stay
    # inside the model name rather than act as another separator.
    expected_by_input = {
        "openai:text-embedding-3-small": (
            "openai",
            "text-embedding-3-small",
        ),
        "bedrock:amazon.titan-embed-text-v1": (
            "bedrock",
            "amazon.titan-embed-text-v1",
        ),
        "huggingface:BAAI/bge-base-en:v1.5": (
            "huggingface",
            "BAAI/bge-base-en:v1.5",
        ),
    }

    for raw_name, expected_pair in expected_by_input.items():
        assert _parse_model_string(raw_name) == expected_pair
def test_parse_model_string_errors() -> None:
    """Test error cases for model string parsing.

    Covers a string with no ``provider:`` prefix, an empty string, an empty
    provider, an empty model name and an unknown provider. It also checks
    that the unknown-provider error mentions every supported provider.
    """
    # No colon, so there is no provider part to split off.
    with pytest.raises(ValueError, match="Model name must be"):
        _parse_model_string("just-a-model-name")

    # Empty input.
    with pytest.raises(ValueError, match="Invalid model format "):
        _parse_model_string("")

    # Colon present but nothing before it.
    with pytest.raises(ValueError, match="is not supported"):
        _parse_model_string(":model-name")

    # Known provider but nothing after the colon.
    with pytest.raises(ValueError, match="Model name cannot be empty"):
        _parse_model_string("openai:")

    # Unknown provider; keep the exception so its message can be inspected.
    with pytest.raises(
        ValueError,
        match="Provider 'invalid-provider' is not supported",
    ) as exc:
        _parse_model_string("invalid-provider:model-name")

    # Every supported provider must be named in that error. A plain substring
    # check on the single captured message avoids re-raising the same error
    # once per provider and avoids treating provider names as regex patterns.
    error_text = str(exc.value)
    for provider in _SUPPORTED_PROVIDERS:
        assert provider in error_text
def test_infer_model_and_provider() -> None:
    """Check provider/model inference for each accepted way of calling it."""
    # Combined "provider:model" string passed positionally.
    from_combined = _infer_model_and_provider("openai:text-embedding-3-small")
    assert from_combined == ("openai", "text-embedding-3-small")

    # Model and provider given as separate keyword arguments.
    from_keywords = _infer_model_and_provider(
        model="text-embedding-3-small",
        provider="openai",
    )
    assert from_keywords == ("openai", "text-embedding-3-small")

    # With an explicit provider, a colon inside the model name is expected to
    # be kept as part of the model.
    from_keywords_with_colon = _infer_model_and_provider(
        model="ft:text-embedding-3-small",
        provider="openai",
    )
    assert from_keywords_with_colon == ("openai", "ft:text-embedding-3-small")

    # Combined string with two colons: the expected split is at the first one.
    from_combined_with_colon = _infer_model_and_provider(
        model="openai:ft:text-embedding-3-small",
    )
    assert from_combined_with_colon == ("openai", "ft:text-embedding-3-small")
def test_infer_model_and_provider_errors() -> None:
    """Check the errors raised for invalid model/provider combinations."""
    # (expected message pattern, positional args, keyword args)
    failing_calls = [
        # Model name without any provider.
        ("Must specify either", ("text-embedding-3-small",), {}),
        # Empty model name.
        ("Model name cannot be empty", ("",), {}),
        # Provider passed explicitly but empty.
        ("Must specify either", ("model",), {"provider": ""}),
    ]
    for pattern, args, kwargs in failing_calls:
        with pytest.raises(ValueError, match=pattern):
            _infer_model_and_provider(*args, **kwargs)

    # Unknown provider; keep the exception so its message can be inspected.
    with pytest.raises(
        ValueError,
        match="Provider 'invalid' is not supported.",
    ) as exc_info:
        _infer_model_and_provider("model", provider="invalid")

    # The error message must name every supported provider.
    error_text = str(exc_info.value)
    for provider_name in _SUPPORTED_PROVIDERS:
        assert provider_name in error_text
@pytest.mark.parametrize("provider", sorted(_SUPPORTED_PROVIDERS))
def test_supported_providers_package_names(provider: str) -> None:
    """Check the package name registered for one supported provider.

    The name must contain no hyphens, start with ``langchain_`` and be
    entirely lowercase.
    """
    package_name = _SUPPORTED_PROVIDERS[provider]

    assert "-" not in package_name
    assert package_name.startswith("langchain_")
    assert package_name.islower()
def test_is_sorted() -> None:
    """Check that ``_SUPPORTED_PROVIDERS`` lists its keys in sorted order."""
    declared_order = list(_SUPPORTED_PROVIDERS.keys())
    assert declared_order == sorted(declared_order)