diff --git a/libs/partners/ollama/tests/integration_tests/test_embeddings.py b/libs/partners/ollama/tests/integration_tests/test_embeddings.py index e0310bf3b0f..26722362e6b 100644 --- a/libs/partners/ollama/tests/integration_tests/test_embeddings.py +++ b/libs/partners/ollama/tests/integration_tests/test_embeddings.py @@ -1,20 +1,17 @@ """Test Ollama embeddings.""" +from typing import Type + +from langchain_standard_tests.integration_tests import EmbeddingsIntegrationTests + from langchain_ollama.embeddings import OllamaEmbeddings -def test_langchain_ollama_embedding_documents() -> None: - """Test cohere embeddings.""" - documents = ["foo bar"] - embedding = OllamaEmbeddings(model="llama3") - output = embedding.embed_documents(documents) - assert len(output) == 1 - assert len(output[0]) > 0 +class TestOllamaEmbeddings(EmbeddingsIntegrationTests): + @property + def embeddings_class(self) -> Type[OllamaEmbeddings]: + return OllamaEmbeddings - -def test_langchain_ollama_embedding_query() -> None: - """Test cohere embeddings.""" - document = "foo bar" - embedding = OllamaEmbeddings(model="llama3") - output = embedding.embed_query(document) - assert len(output) > 0 + @property + def embedding_model_params(self) -> dict: + return {"model": "llama3:latest"} diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/__init__.py b/libs/standard-tests/langchain_standard_tests/integration_tests/__init__.py index f304bff2381..ef9b6b05acf 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/__init__.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/__init__.py @@ -9,6 +9,7 @@ modules = [ "cache", "chat_models", "vectorstores", + "embeddings", ] for module in modules: @@ -19,7 +20,11 @@ for module in modules: from langchain_standard_tests.integration_tests.chat_models import ( ChatModelIntegrationTests, ) +from langchain_standard_tests.integration_tests.embeddings import ( + EmbeddingsIntegrationTests, +) __all__ = [ "ChatModelIntegrationTests", + "EmbeddingsIntegrationTests", ] diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/embeddings.py b/libs/standard-tests/langchain_standard_tests/integration_tests/embeddings.py new file mode 100644 index 00000000000..1de2aece780 --- /dev/null +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/embeddings.py @@ -0,0 +1,49 @@ +from typing import List + +from langchain_core.embeddings import Embeddings + +from langchain_standard_tests.unit_tests.embeddings import EmbeddingsTests + + +class EmbeddingsIntegrationTests(EmbeddingsTests): + def test_embed_query(self, model: Embeddings) -> None: + embedding_1 = model.embed_query("foo") + + assert isinstance(embedding_1, List) + assert isinstance(embedding_1[0], float) + + embedding_2 = model.embed_query("bar") + + assert len(embedding_1) > 0 + assert len(embedding_1) == len(embedding_2) + + def test_embed_documents(self, model: Embeddings) -> None: + documents = ["foo", "bar", "baz"] + embeddings = model.embed_documents(documents) + + assert len(embeddings) == len(documents) + assert all(isinstance(embedding, List) for embedding in embeddings) + assert all(isinstance(embedding[0], float) for embedding in embeddings) + assert len(embeddings[0]) > 0 + assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings) + + async def test_aembed_query(self, model: Embeddings) -> None: + embedding_1 = await model.aembed_query("foo") + + assert isinstance(embedding_1, List) + assert isinstance(embedding_1[0], float) + + embedding_2 = await model.aembed_query("bar") + + assert len(embedding_1) > 0 + assert len(embedding_1) == len(embedding_2) + + async def test_aembed_documents(self, model: Embeddings) -> None: + documents = ["foo", "bar", "baz"] + embeddings = await model.aembed_documents(documents) + + assert len(embeddings) == len(documents) + assert all(isinstance(embedding, List) for embedding in embeddings) + assert all(isinstance(embedding[0], float) for embedding in embeddings) + assert len(embeddings[0]) > 0 + assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings) diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/__init__.py b/libs/standard-tests/langchain_standard_tests/unit_tests/__init__.py index 418330b5dc1..14715260ea5 100644 --- a/libs/standard-tests/langchain_standard_tests/unit_tests/__init__.py +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/__init__.py @@ -6,6 +6,7 @@ import pytest # https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#assertion-rewriting modules = [ "chat_models", + "embeddings", ] for module in modules: @@ -13,4 +14,4 @@ for module in modules: from langchain_standard_tests.unit_tests.chat_models import ChatModelUnitTests -__all__ = ["ChatModelUnitTests"] +__all__ = ["ChatModelUnitTests", "EmbeddingsUnitTests"] diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/embeddings.py b/libs/standard-tests/langchain_standard_tests/unit_tests/embeddings.py new file mode 100644 index 00000000000..28e8c66bbd6 --- /dev/null +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/embeddings.py @@ -0,0 +1,28 @@ +from abc import abstractmethod +from typing import Type + +import pytest +from langchain_core.embeddings import Embeddings + +from langchain_standard_tests.base import BaseStandardTests + + +class EmbeddingsTests(BaseStandardTests): + @property + @abstractmethod + def embeddings_class(self) -> Type[Embeddings]: + ... + + @property + def embedding_model_params(self) -> dict: + return {} + + @pytest.fixture + def model(self) -> Embeddings: + return self.embeddings_class(**self.embedding_model_params) + + +class EmbeddingsUnitTests(EmbeddingsTests): + def test_init(self) -> None: + model = self.embeddings_class(**self.embedding_model_params) + assert model is not None