langchain/libs/community/tests/integration_tests/embeddings/test_fastembed.py
Anush 51b15448cc
community: Fix FastEmbedEmbeddings (#24462)
## Description

This PR:
- Fixes the validation error in `FastEmbedEmbeddings`.
- Adds support for the `batch_size` and `parallel` params (usage sketch below).
- Removes support for very old FastEmbed versions.
- Updates the FastEmbed doc with the new params.
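
A minimal usage sketch of the new params (values are illustrative, not defaults; see the tests below for parametrized coverage):

```python
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings

embeddings = FastEmbedEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",
    batch_size=256,  # documents encoded per batch
    parallel=2,  # data-parallel encoding workers; leave as None to keep single-process encoding
)
vectors = embeddings.embed_documents(["foo bar", "bar foo"])
query_vector = embeddings.embed_query("foo bar")
```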

Associated Issues:
- Resolves #24039
- Resolves https://github.com/qdrant/fastembed/issues/296
2024-07-30 12:42:46 -04:00


"""Test FastEmbed embeddings."""
import pytest
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
@pytest.mark.parametrize(
"model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
)
@pytest.mark.parametrize("max_length", [50, 512])
@pytest.mark.parametrize("doc_embed_type", ["default", "passage"])
@pytest.mark.parametrize("threads", [0, 10])
@pytest.mark.parametrize("batch_size", [1, 10])
def test_fastembed_embedding_documents(
model_name: str, max_length: int, doc_embed_type: str, threads: int, batch_size: int
) -> None:
"""Test fastembed embeddings for documents."""
documents = ["foo bar", "bar foo"]
embedding = FastEmbedEmbeddings( # type: ignore[call-arg]
model_name=model_name,
max_length=max_length,
doc_embed_type=doc_embed_type, # type: ignore[arg-type]
threads=threads,
batch_size=batch_size,
)
output = embedding.embed_documents(documents)
assert len(output) == 2
assert len(output[0]) == 384
@pytest.mark.parametrize(
"model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
)
@pytest.mark.parametrize("max_length", [50, 512])
@pytest.mark.parametrize("batch_size", [1, 10])
def test_fastembed_embedding_query(
model_name: str, max_length: int, batch_size: int
) -> None:
"""Test fastembed embeddings for query."""
document = "foo bar"
embedding = FastEmbedEmbeddings(
model_name=model_name, max_length=max_length, batch_size=batch_size
) # type: ignore[call-arg]
output = embedding.embed_query(document)
assert len(output) == 384
@pytest.mark.parametrize(
"model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
)
@pytest.mark.parametrize("max_length", [50, 512])
@pytest.mark.parametrize("doc_embed_type", ["default", "passage"])
@pytest.mark.parametrize("threads", [0, 10])
async def test_fastembed_async_embedding_documents(
model_name: str, max_length: int, doc_embed_type: str, threads: int
) -> None:
"""Test fastembed embeddings for documents."""
documents = ["foo bar", "bar foo"]
embedding = FastEmbedEmbeddings( # type: ignore[call-arg]
model_name=model_name,
max_length=max_length,
doc_embed_type=doc_embed_type, # type: ignore[arg-type]
threads=threads,
)
output = await embedding.aembed_documents(documents)
assert len(output) == 2
assert len(output[0]) == 384
@pytest.mark.parametrize(
"model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
)
@pytest.mark.parametrize("max_length", [50, 512])
async def test_fastembed_async_embedding_query(
model_name: str, max_length: int
) -> None:
"""Test fastembed embeddings for query."""
document = "foo bar"
embedding = FastEmbedEmbeddings(model_name=model_name, max_length=max_length) # type: ignore[call-arg]
output = await embedding.aembed_query(document)
assert len(output) == 384