add embeddings integration tests (#25508)

This commit is contained in:
Isaac Francisco 2024-08-16 13:20:37 -07:00 committed by GitHub
parent a06818a654
commit a2e90a5a43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 95 additions and 15 deletions

View File

@ -1,20 +1,17 @@
"""Test Ollama embeddings.""" """Test Ollama embeddings."""
from typing import Type
from langchain_standard_tests.integration_tests import EmbeddingsIntegrationTests
from langchain_ollama.embeddings import OllamaEmbeddings from langchain_ollama.embeddings import OllamaEmbeddings
def test_langchain_ollama_embedding_documents() -> None: class TestOllamaEmbeddings(EmbeddingsIntegrationTests):
"""Test cohere embeddings.""" @property
documents = ["foo bar"] def embeddings_class(self) -> Type[OllamaEmbeddings]:
embedding = OllamaEmbeddings(model="llama3") return OllamaEmbeddings
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) > 0
@property
def test_langchain_ollama_embedding_query() -> None: def embedding_model_params(self) -> dict:
"""Test cohere embeddings.""" return {"model": "llama3:latest"}
document = "foo bar"
embedding = OllamaEmbeddings(model="llama3")
output = embedding.embed_query(document)
assert len(output) > 0

View File

@ -9,6 +9,7 @@ modules = [
"cache", "cache",
"chat_models", "chat_models",
"vectorstores", "vectorstores",
"embeddings",
] ]
for module in modules: for module in modules:
@ -19,7 +20,11 @@ for module in modules:
from langchain_standard_tests.integration_tests.chat_models import ( from langchain_standard_tests.integration_tests.chat_models import (
ChatModelIntegrationTests, ChatModelIntegrationTests,
) )
from langchain_standard_tests.integration_tests.embeddings import (
EmbeddingsIntegrationTests,
)
__all__ = [ __all__ = [
"ChatModelIntegrationTests", "ChatModelIntegrationTests",
"EmbeddingsIntegrationTests",
] ]

View File

@ -0,0 +1,49 @@
from typing import List
from langchain_core.embeddings import Embeddings
from langchain_standard_tests.unit_tests.embeddings import EmbeddingsTests
class EmbeddingsIntegrationTests(EmbeddingsTests):
def test_embed_query(self, model: Embeddings) -> None:
embedding_1 = model.embed_query("foo")
assert isinstance(embedding_1, List)
assert isinstance(embedding_1[0], float)
embedding_2 = model.embed_query("bar")
assert len(embedding_1) > 0
assert len(embedding_1) == len(embedding_2)
def test_embed_documents(self, model: Embeddings) -> None:
documents = ["foo", "bar", "baz"]
embeddings = model.embed_documents(documents)
assert len(embeddings) == len(documents)
assert all(isinstance(embedding, List) for embedding in embeddings)
assert all(isinstance(embedding[0], float) for embedding in embeddings)
assert len(embeddings[0]) > 0
assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings)
async def test_aembed_query(self, model: Embeddings) -> None:
embedding_1 = await model.aembed_query("foo")
assert isinstance(embedding_1, List)
assert isinstance(embedding_1[0], float)
embedding_2 = await model.aembed_query("bar")
assert len(embedding_1) > 0
assert len(embedding_1) == len(embedding_2)
async def test_aembed_documents(self, model: Embeddings) -> None:
documents = ["foo", "bar", "baz"]
embeddings = await model.aembed_documents(documents)
assert len(embeddings) == len(documents)
assert all(isinstance(embedding, List) for embedding in embeddings)
assert all(isinstance(embedding[0], float) for embedding in embeddings)
assert len(embeddings[0]) > 0
assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings)

View File

@ -6,6 +6,7 @@ import pytest
# https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#assertion-rewriting # https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#assertion-rewriting
modules = [ modules = [
"chat_models", "chat_models",
"embeddings",
] ]
for module in modules: for module in modules:
@ -13,4 +14,4 @@ for module in modules:
from langchain_standard_tests.unit_tests.chat_models import ChatModelUnitTests from langchain_standard_tests.unit_tests.chat_models import ChatModelUnitTests
__all__ = ["ChatModelUnitTests"] __all__ = ["ChatModelUnitTests", "EmbeddingsUnitTests"]

View File

@ -0,0 +1,28 @@
from abc import abstractmethod
from typing import Type
import pytest
from langchain_core.embeddings import Embeddings
from langchain_standard_tests.base import BaseStandardTests
class EmbeddingsTests(BaseStandardTests):
@property
@abstractmethod
def embeddings_class(self) -> Type[Embeddings]:
...
@property
def embedding_model_params(self) -> dict:
return {}
@pytest.fixture
def model(self) -> Embeddings:
return self.embeddings_class(**self.embedding_model_params)
class EmbeddingsUnitTests(EmbeddingsTests):
def test_init(self) -> None:
model = self.embeddings_class(**self.embedding_model_params)
assert model is not None