mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-29 12:25:37 +00:00
Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com> Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com> Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: ccurme <chester.curme@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com> Co-authored-by: ZhangShenao <15201440436@163.com> Co-authored-by: Friso H. Kingma <fhkingma@gmail.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Nuno Campos <nuno@langchain.dev> Co-authored-by: Morgante Pell <morgantep@google.com>
351 lines
11 KiB
Python
351 lines
11 KiB
Python
"""
|
|
TextEmbed: Embedding Inference Server
|
|
|
|
TextEmbed provides a high-throughput, low-latency solution for serving embeddings.
|
|
It supports various sentence-transformer models.
|
|
Now, it includes the ability to deploy image embedding models.
|
|
TextEmbed offers flexibility and scalability for diverse applications.
|
|
|
|
TextEmbed is maintained by Keval Dekivadiya and is licensed under the Apache-2.0 license.
|
|
""" # noqa: E501
|
|
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import aiohttp
|
|
import numpy as np
|
|
import requests
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.utils import from_env, secret_from_env
|
|
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
|
from typing_extensions import Self
|
|
|
|
__all__ = ["TextEmbedEmbeddings"]
|
|
|
|
|
|
class TextEmbedEmbeddings(BaseModel, Embeddings):
    """LangChain ``Embeddings`` implementation backed by a TextEmbed server.

    Attributes:
        model : The TextEmbed model ID used for embedding requests.
        api_url : Base URL of the TextEmbed API.
        api_key : API key used to authenticate against the TextEmbed API.
        client : The underlying TextEmbed client instance.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import TextEmbedEmbeddings

            embeddings = TextEmbedEmbeddings(
                model="sentence-transformers/clip-ViT-B-32",
                api_url="http://localhost:8000/v1",
                api_key="<API_KEY>"
            )

    For more information: https://github.com/kevaldekivadiya2415/textembed/blob/main/docs/setup.md
    """  # noqa: E501

    model: str
    """Underlying TextEmbed model id."""

    api_url: str = Field(
        default_factory=from_env(
            "TEXTEMBED_API_URL", default="http://localhost:8000/v1"
        )
    )
    """Endpoint URL to use."""

    api_key: SecretStr = Field(default_factory=secret_from_env("TEXTEMBED_API_KEY"))
    """API Key for authentication"""

    client: Any = None
    """TextEmbed client."""

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Build the request client once ``api_url``/``api_key`` are resolved."""
        self.client = AsyncOpenAITextEmbedEmbeddingClient(
            host=self.api_url, api_key=self.api_key.get_secret_value()
        )
        return self

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts through the TextEmbed endpoint.

        Args:
            texts (List[str]): The list of texts to embed.

        Returns:
            List[List[float]]: One embedding vector per input text.
        """
        return self.client.embed(model=self.model, texts=texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a batch of texts through the TextEmbed endpoint.

        Args:
            texts (List[str]): The list of texts to embed.

        Returns:
            List[List[float]]: One embedding vector per input text.
        """
        return await self.client.aembed(model=self.model, texts=texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: Embedding vector for the text.
        """
        # Delegate to the batch path with a one-element list.
        return self.embed_documents([text])[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single query string.

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: Embedding vector for the text.
        """
        result = await self.aembed_documents([text])
        return result[0]
|
|
|
|
|
|
class AsyncOpenAITextEmbedEmbeddingClient:
    """
    A client to handle synchronous and asynchronous requests to the TextEmbed API.

    Requests are length-sorted, batched (max ``_batch_size`` texts per HTTP
    call), dispatched (threads for sync, ``asyncio.gather`` for async), and the
    resulting embeddings are restored to the caller's original ordering.

    Attributes:
        host (str): The base URL for the TextEmbed API.
        api_key (str): The API key for authenticating with the TextEmbed API.
        aiosession (Optional[aiohttp.ClientSession]): The aiohttp session for async requests.
        _batch_size (int): Maximum batch size for a single request.
    """  # noqa: E501

    def __init__(
        self,
        host: str = "http://localhost:8000/v1",
        api_key: Union[str, None] = None,
        aiosession: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        self.host = host
        self.api_key = api_key
        self.aiosession = aiosession

        if self.host is None or len(self.host) < 3:
            raise ValueError("Parameter `host` must be set to a valid URL")
        # Maximum number of texts sent in a single HTTP request.
        self._batch_size = 256

    @staticmethod
    def _permute(
        texts: List[str], sorter: Callable = len
    ) -> Tuple[List[str], Callable]:
        """
        Sorts texts in descending order by `sorter` (length by default) and
        provides a function to restore the original order.

        Grouping similar-length texts improves server-side batching efficiency.

        Args:
            texts (List[str]): List of texts to sort.
            sorter (Callable, optional): Sorting key function, defaults to length.

        Returns:
            Tuple[List[str], Callable]: Sorted texts and a function that maps
            embeddings (computed in sorted order) back to the original order.
        """  # noqa: E501
        if len(texts) == 1:
            # Nothing to reorder; identity un-permute.
            return texts, lambda t: t
        length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
        texts_sorted = [texts[idx] for idx in length_sorted_idx]

        return texts_sorted, lambda unsorted_embeddings: [
            unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
        ]

    def _batch(self, texts: List[str]) -> List[List[str]]:
        """
        Splits a list of texts into batches of size max `self._batch_size`.

        Args:
            texts (List[str]): List of texts to split.

        Returns:
            List[List[str]]: List of batches of texts.
        """
        if len(texts) == 1:
            return [texts]
        batches = []
        for start_index in range(0, len(texts), self._batch_size):
            batches.append(texts[start_index : start_index + self._batch_size])
        return batches

    @staticmethod
    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
        """
        Merges batches of items back into a single flat list.

        Args:
            batch_of_texts (List[List[Any]]): List of batches.

        Returns:
            List[Any]: Merged flat list.
        """
        if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
            return batch_of_texts[0]
        texts = []
        for sublist in batch_of_texts:
            texts.extend(sublist)
        return texts

    def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
        """
        Builds the kwargs for the POST request (shared by sync and async paths).

        Args:
            model (str): The model to use for embedding.
            texts (List[str]): List of texts to embed.

        Returns:
            Dict[str, Any]: Dictionary of POST request parameters
            (url, headers, json payload).
        """
        return dict(
            url=f"{self.host}/embedding",
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            },
            json=dict(
                input=texts,
                model=model,
            ),
        )

    def _sync_request_embed(
        self, model: str, batch_texts: List[str]
    ) -> List[List[float]]:
        """
        Sends a synchronous request to the embedding endpoint.

        Args:
            model (str): The model to use for embedding.
            batch_texts (List[str]): Batch of texts to embed.

        Returns:
            List[List[float]]: List of embeddings for the batch.

        Raises:
            Exception: If the response status is not 200.
        """
        response = requests.post(
            **self._kwargs_post_request(model=model, texts=batch_texts)
        )
        if response.status_code != 200:
            raise Exception(
                f"TextEmbed responded with an unexpected status message "
                f"{response.status_code}: {response.text}"
            )
        return [e["embedding"] for e in response.json()["data"]]

    def embed(self, model: str, texts: List[str]) -> List[List[float]]:
        """
        Embeds a list of texts synchronously.

        Batches are fetched in parallel with a thread pool when there is more
        than one batch; ordering of the returned embeddings matches `texts`.

        Args:
            model (str): The model to use for embedding.
            texts (List[str]): List of texts to embed.

        Returns:
            List[List[float]]: List of embeddings for the texts.
        """
        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)

        # Request
        map_args = (
            self._sync_request_embed,
            [model] * len(perm_texts_batched),
            perm_texts_batched,
        )
        if len(perm_texts_batched) == 1:
            # Single batch: no need to spin up a thread pool.
            embeddings_batch_perm = list(map(*map_args))
        else:
            with ThreadPoolExecutor(32) as p:
                embeddings_batch_perm = list(p.map(*map_args))

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        embeddings = unpermute_func(embeddings_perm)
        return embeddings

    async def _async_request(
        self, session: aiohttp.ClientSession, **kwargs: Dict[str, Any]
    ) -> List[List[float]]:
        """
        Sends an asynchronous request to the embedding endpoint.

        Args:
            session (aiohttp.ClientSession): The aiohttp session for the request.
            kwargs (Dict[str, Any]): Dictionary of POST request parameters.

        Returns:
            List[List[float]]: List of embeddings for the request.

        Raises:
            Exception: If the response status is not 200.
        """
        async with session.post(**kwargs) as response:  # type: ignore
            if response.status != 200:
                # BUG FIX: aiohttp's `response.text` is a coroutine method; it
                # must be awaited, otherwise the error message interpolates a
                # bound-method repr instead of the response body.
                body = await response.text()
                raise Exception(
                    f"TextEmbed responded with an unexpected status message "
                    f"{response.status}: {body}"
                )
            embedding = (await response.json())["data"]
            return [e["embedding"] for e in embedding]

    async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
        """
        Embeds a list of texts asynchronously.

        All batches are issued concurrently over a single session (connection
        pool capped at 32); ordering of the returned embeddings matches `texts`.

        Args:
            model (str): The model to use for embedding.
            texts (List[str]): List of texts to embed.

        Returns:
            List[List[float]]: List of embeddings for the texts.
        """
        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)

        async with aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(limit=32)
        ) as session:
            embeddings_batch_perm = await asyncio.gather(
                *[
                    self._async_request(
                        session=session,
                        **self._kwargs_post_request(model=model, texts=t),
                    )
                    for t in perm_texts_batched
                ]
            )

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        embeddings = unpermute_func(embeddings_perm)
        return embeddings
|