feat: FastEmbed embedding provider (#13109)

## Description:
This PR adds
[Qdrant/FastEmbed](https://qdrant.github.io/fastembed/) as a local
embeddings provider, along with associated tests and documentation.

**Documentation preview:**
https://langchain-git-fork-anush008-master-langchain.vercel.app/docs/integrations/text_embedding/fastembed

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Anush 2023-11-11 00:21:52 +05:30 committed by GitHub
parent b0e8cbe0b3
commit 52f34de9b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 341 additions and 0 deletions

View File

@ -0,0 +1,154 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Qdrant FastEmbed\n",
"\n",
"[FastEmbed](https://qdrant.github.io/fastembed/) is a lightweight, fast, Python library built for embedding generation. \n",
"\n",
"- Quantized model weights\n",
"- ONNX Runtime, no PyTorch dependency\n",
"- CPU-first design\n",
"- Data-parallelism for encoding of large datasets."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2a773d8d",
"metadata": {},
"source": [
"## Dependencies\n",
"\n",
"To use FastEmbed with LangChain, install the `fastembed` Python package."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91ea14ce-831d-409a-a88f-30353acdabd1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%pip install fastembed"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "426f1156",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3f5dc9d7-65e3-4b5b-9086-3327d016cfe0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.embeddings.fastembed import FastEmbedEmbeddings"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instantiating FastEmbed\n",
" \n",
"### Parameters\n",
"- `model_name: str` (default: \"BAAI/bge-small-en-v1.5\")\n",
" > Name of the FastEmbedding model to use. You can find the list of supported models [here](https://qdrant.github.io/fastembed/examples/Supported_Models/).\n",
"\n",
"- `max_length: int` (default: 512)\n",
" > The maximum number of tokens. Unknown behavior for values > 512.\n",
"\n",
"- `cache_dir: Optional[str]`\n",
" > The path to the cache directory. Defaults to `local_cache` in the parent directory.\n",
"\n",
"- `threads: Optional[int]`\n",
" > The number of threads a single onnxruntime session can use. Defaults to None.\n",
"\n",
"- `doc_embed_type: Literal[\"default\", \"passage\"]` (default: \"default\")\n",
" > \"default\": Uses FastEmbed's default embedding method.\n",
" \n",
" > \"passage\": Prefixes the text with \"passage\" before embedding."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fb585dd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"embeddings = FastEmbedEmbeddings()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage\n",
"\n",
"### Generating document embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"document_embeddings = embeddings.embed_documents([\"This is a document\", \"This is some other document\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generating query embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query_embeddings = embeddings.embed_query(\"This is a query\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -32,6 +32,7 @@ from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
from langchain.embeddings.embaas import EmbaasEmbeddings from langchain.embeddings.embaas import EmbaasEmbeddings
from langchain.embeddings.ernie import ErnieEmbeddings from langchain.embeddings.ernie import ErnieEmbeddings
from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings
from langchain.embeddings.fastembed import FastEmbedEmbeddings
from langchain.embeddings.google_palm import GooglePalmEmbeddings from langchain.embeddings.google_palm import GooglePalmEmbeddings
from langchain.embeddings.gpt4all import GPT4AllEmbeddings from langchain.embeddings.gpt4all import GPT4AllEmbeddings
from langchain.embeddings.gradient_ai import GradientEmbeddings from langchain.embeddings.gradient_ai import GradientEmbeddings
@ -77,6 +78,7 @@ __all__ = [
"ClarifaiEmbeddings", "ClarifaiEmbeddings",
"CohereEmbeddings", "CohereEmbeddings",
"ElasticsearchEmbeddings", "ElasticsearchEmbeddings",
"FastEmbedEmbeddings",
"HuggingFaceEmbeddings", "HuggingFaceEmbeddings",
"HuggingFaceInferenceAPIEmbeddings", "HuggingFaceInferenceAPIEmbeddings",
"GradientEmbeddings", "GradientEmbeddings",

View File

@ -0,0 +1,108 @@
from typing import Any, Dict, List, Literal, Optional
import numpy as np
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema.embeddings import Embeddings
class FastEmbedEmbeddings(BaseModel, Embeddings):
    """Qdrant FastEmbedding models.

    FastEmbed is a lightweight, fast, Python library built for embedding generation.
    See more documentation at:
    * https://github.com/qdrant/fastembed/
    * https://qdrant.github.io/fastembed/

    To use this class, you must install the `fastembed` Python package.

    `pip install fastembed`

    Example:
        from langchain.embeddings import FastEmbedEmbeddings
        fastembed = FastEmbedEmbeddings()
    """

    model_name: str = "BAAI/bge-small-en-v1.5"
    """Name of the FastEmbedding model to use
    Defaults to "BAAI/bge-small-en-v1.5"
    Find the list of supported models at
    https://qdrant.github.io/fastembed/examples/Supported_Models/
    """

    max_length: int = 512
    """The maximum number of tokens. Defaults to 512.
    Unknown behavior for values > 512.
    """

    cache_dir: Optional[str] = None
    """The path to the cache directory.
    Defaults to `local_cache` in the parent directory
    """

    threads: Optional[int] = None
    """The number of threads a single onnxruntime session can use.
    Defaults to None
    """

    doc_embed_type: Literal["default", "passage"] = "default"
    """Type of embedding to use for documents
    "default": Uses FastEmbed's default embedding method
    "passage": Prefixes the text with "passage" before embedding.
    """

    _model: Any  # : :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that FastEmbed is installed and build the underlying model.

        Args:
            values: The field values gathered for this instance.

        Returns:
            The same mapping, with the constructed model stored under
            the ``"_model"`` key.

        Raises:
            ImportError: If the `fastembed` package cannot be imported.
        """
        # Guard only the import: an ImportError raised while the model itself
        # is being constructed should surface unchanged instead of being
        # reported as a missing `fastembed` installation.
        try:
            from fastembed.embedding import FlagEmbedding
        except ImportError as ie:
            raise ImportError(
                "Could not import 'fastembed' Python package. "
                "Please install it with `pip install fastembed`."
            ) from ie

        values["_model"] = FlagEmbedding(
            model_name=values.get("model_name"),
            max_length=values.get("max_length"),
            cache_dir=values.get("cache_dir"),
            threads=values.get("threads"),
        )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for documents using FastEmbed.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # "passage" selects the model's passage-specific method; any other
        # value falls back to the plain embedding method.
        if self.doc_embed_type == "passage":
            embed = self._model.passage_embed
        else:
            embed = self._model.embed
        return [vector.tolist() for vector in embed(texts)]

    def embed_query(self, text: str) -> List[float]:
        """Generate query embeddings using FastEmbed.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        # `query_embed` is consumed as an iterator; only its first item is used.
        query_embeddings: np.ndarray = next(self._model.query_embed(text))
        return query_embeddings.tolist()

View File

@ -0,0 +1,76 @@
"""Test FastEmbed embeddings."""
import pytest
from langchain.embeddings.fastembed import FastEmbedEmbeddings
@pytest.mark.parametrize(
    "model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
)
@pytest.mark.parametrize("max_length", [50, 512])
@pytest.mark.parametrize("doc_embed_type", ["default", "passage"])
@pytest.mark.parametrize("threads", [0, 10])
def test_fastembed_embedding_documents(
    model_name: str, max_length: int, doc_embed_type: str, threads: int
) -> None:
    """Embed two documents synchronously and check the result's shape."""
    texts = ["foo bar", "bar foo"]
    settings = {
        "model_name": model_name,
        "max_length": max_length,
        "doc_embed_type": doc_embed_type,
        "threads": threads,
    }
    embedder = FastEmbedEmbeddings(**settings)

    vectors = embedder.embed_documents(texts)

    # One vector per input text, each expected to have 384 components.
    assert len(vectors) == 2
    assert len(vectors[0]) == 384
@pytest.mark.parametrize(
    "model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
)
@pytest.mark.parametrize("max_length", [50, 512])
def test_fastembed_embedding_query(model_name: str, max_length: int) -> None:
    """Embed a single query synchronously and check the vector length."""
    query = "foo bar"
    embedder = FastEmbedEmbeddings(model_name=model_name, max_length=max_length)

    vector = embedder.embed_query(query)

    # The query embedding is expected to have 384 components.
    assert len(vector) == 384
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
)
@pytest.mark.parametrize("max_length", [50, 512])
@pytest.mark.parametrize("doc_embed_type", ["default", "passage"])
@pytest.mark.parametrize("threads", [0, 10])
async def test_fastembed_async_embedding_documents(
    model_name: str, max_length: int, doc_embed_type: str, threads: int
) -> None:
    """Embed two documents via the async API and check the result's shape."""
    texts = ["foo bar", "bar foo"]
    settings = {
        "model_name": model_name,
        "max_length": max_length,
        "doc_embed_type": doc_embed_type,
        "threads": threads,
    }
    embedder = FastEmbedEmbeddings(**settings)

    vectors = await embedder.aembed_documents(texts)

    # One vector per input text, each expected to have 384 components.
    assert len(vectors) == 2
    assert len(vectors[0]) == 384
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name", ["sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5"]
)
@pytest.mark.parametrize("max_length", [50, 512])
async def test_fastembed_async_embedding_query(
    model_name: str, max_length: int
) -> None:
    """Embed a single query via the async API and check the vector length."""
    query = "foo bar"
    embedder = FastEmbedEmbeddings(model_name=model_name, max_length=max_length)

    vector = await embedder.aembed_query(query)

    # The query embedding is expected to have 384 components.
    assert len(vector) == 384

View File

@ -7,6 +7,7 @@ EXPECTED_ALL = [
"ClarifaiEmbeddings", "ClarifaiEmbeddings",
"CohereEmbeddings", "CohereEmbeddings",
"ElasticsearchEmbeddings", "ElasticsearchEmbeddings",
"FastEmbedEmbeddings",
"HuggingFaceEmbeddings", "HuggingFaceEmbeddings",
"HuggingFaceInferenceAPIEmbeddings", "HuggingFaceInferenceAPIEmbeddings",
"GradientEmbeddings", "GradientEmbeddings",