mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
add embeddings integration tests (#25508)
This commit is contained in:
parent
a06818a654
commit
a2e90a5a43
@ -1,20 +1,17 @@
|
|||||||
"""Test Ollama embeddings."""
|
"""Test Ollama embeddings."""
|
||||||
|
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from langchain_standard_tests.integration_tests import EmbeddingsIntegrationTests
|
||||||
|
|
||||||
from langchain_ollama.embeddings import OllamaEmbeddings
|
from langchain_ollama.embeddings import OllamaEmbeddings
|
||||||
|
|
||||||
|
|
||||||
def test_langchain_ollama_embedding_documents() -> None:
|
class TestOllamaEmbeddings(EmbeddingsIntegrationTests):
|
||||||
"""Test cohere embeddings."""
|
@property
|
||||||
documents = ["foo bar"]
|
def embeddings_class(self) -> Type[OllamaEmbeddings]:
|
||||||
embedding = OllamaEmbeddings(model="llama3")
|
return OllamaEmbeddings
|
||||||
output = embedding.embed_documents(documents)
|
|
||||||
assert len(output) == 1
|
|
||||||
assert len(output[0]) > 0
|
|
||||||
|
|
||||||
|
@property
|
||||||
def test_langchain_ollama_embedding_query() -> None:
|
def embedding_model_params(self) -> dict:
|
||||||
"""Test cohere embeddings."""
|
return {"model": "llama3:latest"}
|
||||||
document = "foo bar"
|
|
||||||
embedding = OllamaEmbeddings(model="llama3")
|
|
||||||
output = embedding.embed_query(document)
|
|
||||||
assert len(output) > 0
|
|
||||||
|
@ -9,6 +9,7 @@ modules = [
|
|||||||
"cache",
|
"cache",
|
||||||
"chat_models",
|
"chat_models",
|
||||||
"vectorstores",
|
"vectorstores",
|
||||||
|
"embeddings",
|
||||||
]
|
]
|
||||||
|
|
||||||
for module in modules:
|
for module in modules:
|
||||||
@ -19,7 +20,11 @@ for module in modules:
|
|||||||
from langchain_standard_tests.integration_tests.chat_models import (
|
from langchain_standard_tests.integration_tests.chat_models import (
|
||||||
ChatModelIntegrationTests,
|
ChatModelIntegrationTests,
|
||||||
)
|
)
|
||||||
|
from langchain_standard_tests.integration_tests.embeddings import (
|
||||||
|
EmbeddingsIntegrationTests,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ChatModelIntegrationTests",
|
"ChatModelIntegrationTests",
|
||||||
|
"EmbeddingsIntegrationTests",
|
||||||
]
|
]
|
||||||
|
@ -0,0 +1,49 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
|
from langchain_standard_tests.unit_tests.embeddings import EmbeddingsTests
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsIntegrationTests(EmbeddingsTests):
|
||||||
|
def test_embed_query(self, model: Embeddings) -> None:
|
||||||
|
embedding_1 = model.embed_query("foo")
|
||||||
|
|
||||||
|
assert isinstance(embedding_1, List)
|
||||||
|
assert isinstance(embedding_1[0], float)
|
||||||
|
|
||||||
|
embedding_2 = model.embed_query("bar")
|
||||||
|
|
||||||
|
assert len(embedding_1) > 0
|
||||||
|
assert len(embedding_1) == len(embedding_2)
|
||||||
|
|
||||||
|
def test_embed_documents(self, model: Embeddings) -> None:
|
||||||
|
documents = ["foo", "bar", "baz"]
|
||||||
|
embeddings = model.embed_documents(documents)
|
||||||
|
|
||||||
|
assert len(embeddings) == len(documents)
|
||||||
|
assert all(isinstance(embedding, List) for embedding in embeddings)
|
||||||
|
assert all(isinstance(embedding[0], float) for embedding in embeddings)
|
||||||
|
assert len(embeddings[0]) > 0
|
||||||
|
assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings)
|
||||||
|
|
||||||
|
async def test_aembed_query(self, model: Embeddings) -> None:
|
||||||
|
embedding_1 = await model.aembed_query("foo")
|
||||||
|
|
||||||
|
assert isinstance(embedding_1, List)
|
||||||
|
assert isinstance(embedding_1[0], float)
|
||||||
|
|
||||||
|
embedding_2 = await model.aembed_query("bar")
|
||||||
|
|
||||||
|
assert len(embedding_1) > 0
|
||||||
|
assert len(embedding_1) == len(embedding_2)
|
||||||
|
|
||||||
|
async def test_aembed_documents(self, model: Embeddings) -> None:
|
||||||
|
documents = ["foo", "bar", "baz"]
|
||||||
|
embeddings = await model.aembed_documents(documents)
|
||||||
|
|
||||||
|
assert len(embeddings) == len(documents)
|
||||||
|
assert all(isinstance(embedding, List) for embedding in embeddings)
|
||||||
|
assert all(isinstance(embedding[0], float) for embedding in embeddings)
|
||||||
|
assert len(embeddings[0]) > 0
|
||||||
|
assert all(len(embedding) == len(embeddings[0]) for embedding in embeddings)
|
@ -6,6 +6,7 @@ import pytest
|
|||||||
# https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#assertion-rewriting
|
# https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#assertion-rewriting
|
||||||
modules = [
|
modules = [
|
||||||
"chat_models",
|
"chat_models",
|
||||||
|
"embeddings",
|
||||||
]
|
]
|
||||||
|
|
||||||
for module in modules:
|
for module in modules:
|
||||||
@ -13,4 +14,4 @@ for module in modules:
|
|||||||
|
|
||||||
from langchain_standard_tests.unit_tests.chat_models import ChatModelUnitTests
|
from langchain_standard_tests.unit_tests.chat_models import ChatModelUnitTests
|
||||||
|
|
||||||
__all__ = ["ChatModelUnitTests"]
|
__all__ = ["ChatModelUnitTests", "EmbeddingsUnitTests"]
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
|
from langchain_standard_tests.base import BaseStandardTests
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsTests(BaseStandardTests):
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def embeddings_class(self) -> Type[Embeddings]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embedding_model_params(self) -> dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model(self) -> Embeddings:
|
||||||
|
return self.embeddings_class(**self.embedding_model_params)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsUnitTests(EmbeddingsTests):
|
||||||
|
def test_init(self) -> None:
|
||||||
|
model = self.embeddings_class(**self.embedding_model_params)
|
||||||
|
assert model is not None
|
Loading…
Reference in New Issue
Block a user