diff --git a/docs/docs/integrations/text_embedding/textembed.ipynb b/docs/docs/integrations/text_embedding/textembed.ipynb
new file mode 100644
index 00000000000..af5eea42e2f
--- /dev/null
+++ b/docs/docs/integrations/text_embedding/textembed.ipynb
@@ -0,0 +1,174 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# TextEmbed - Embedding Inference Server\n",
+    "\n",
+    "TextEmbed is a high-throughput, low-latency REST API for serving vector embeddings. It supports a wide range of sentence-transformer models and frameworks, making it suitable for various applications in natural language processing.\n",
+    "\n",
+    "## Features\n",
+    "\n",
+    "- **High Throughput & Low Latency:** Designed to handle a large number of requests efficiently.\n",
+    "- **Flexible Model Support:** Works with various sentence-transformer models.\n",
+    "- **Scalable:** Easily integrates into larger systems and scales with demand.\n",
+    "- **Batch Processing:** Processes requests in batches for faster inference.\n",
+    "- **OpenAI-Compatible REST API:** Exposes an OpenAI-compatible REST API endpoint.\n",
+    "- **Single-Command Deployment:** Deploy multiple models with a single command.\n",
+    "- **Multiple Embedding Formats:** Supports binary, float16, and float32 embedding formats for faster retrieval.\n",
+    "\n",
+    "## Getting Started\n",
+    "\n",
+    "### Prerequisites\n",
+    "\n",
+    "Ensure you have Python 3.10 or higher installed. You will also need to install the required dependencies."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Installation via PyPI\n",
+    "\n",
+    "1. **Install the required dependencies:**\n",
+    "\n",
+    "   ```bash\n",
+    "   pip install -U textembed\n",
+    "   ```\n",
+    "\n",
+    "2. **Start the TextEmbed server with your desired models:**\n",
+    "\n",
+    "   ```bash\n",
+    "   python -m textembed.server --models sentence-transformers/all-MiniLM-L12-v2 --workers 4 --api-key TextEmbed\n",
+    "   ```\n",
+    "\n",
+    "For more information, please read the [documentation](https://github.com/kevaldekivadiya2415/textembed/blob/main/docs/setup.md).\n",
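+    "\n",
+    "Once the server is running, you can sanity-check it before wiring it into LangChain. The snippet below is a minimal sketch, assuming the default port (8000) and the API key from the command above; the `/embedding` route and OpenAI-style payload mirror what the LangChain client sends:\n",
+    "\n",
+    "```python\n",
+    "import requests\n",
+    "\n",
+    "response = requests.post(\n",
+    "    \"http://0.0.0.0:8000/v1/embedding\",\n",
+    "    headers={\"Authorization\": \"Bearer TextEmbed\"},\n",
+    "    json={\"model\": \"sentence-transformers/all-MiniLM-L12-v2\", \"input\": [\"Hello!\"]},\n",
+    ")\n",
+    "# A healthy server returns HTTP 200 and a JSON body with a \"data\" list\n",
+    "print(response.status_code)\n",
+    "```"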
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.embeddings import TextEmbedEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = TextEmbedEmbeddings(\n", + " model=\"sentence-transformers/all-MiniLM-L12-v2\",\n", + " api_url=\"http://0.0.0.0:8000/v1\",\n", + " api_key=\"TextEmbed\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Embed your documents" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# Define a list of documents\n", + "documents = [\n", + " \"Data science involves extracting insights from data.\",\n", + " \"Artificial intelligence is transforming various industries.\",\n", + " \"Cloud computing provides scalable computing resources over the internet.\",\n", + " \"Big data analytics helps in understanding large datasets.\",\n", + " \"India has a diverse cultural heritage.\",\n", + "]\n", + "\n", + "# Define a query\n", + "query = \"What is the cultural heritage of India?\"" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# Embed all documents\n", + "document_embeddings = embeddings.embed_documents(documents)\n", + "\n", + "# Embed the query\n", + "query_embedding = embeddings.embed_query(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Data science involves extracting insights from data.': 0.05121298956322118,\n", + " 'Artificial intelligence is transforming various industries.': -0.0060612142358469345,\n", + " 'Cloud computing provides scalable computing resources over the internet.': -0.04877402795301714,\n", + " 'Big data analytics helps in understanding large datasets.': 0.016582168576929422,\n", + " 'India has a diverse cultural heritage.': 0.7408992963028144}" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Compute Similarity\n", + "import numpy as np\n", + "\n", + "scores = np.array(document_embeddings) @ np.array(query_embedding).T\n", + "dict(zip(documents, scores))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "check10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index 0b4c9c35639..845d437e1ec 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -213,6 +213,9 @@ if TYPE_CHECKING: from langchain_community.embeddings.tensorflow_hub import ( TensorflowHubEmbeddings, ) + from langchain_community.embeddings.textembed import ( + TextEmbedEmbeddings, + ) from langchain_community.embeddings.titan_takeoff import ( TitanTakeoffEmbed, ) @@ -308,6 +311,7 @@ 
__all__ = [
     "SpacyEmbeddings",
     "SparkLLMTextEmbeddings",
     "TensorflowHubEmbeddings",
+    "TextEmbedEmbeddings",
     "TitanTakeoffEmbed",
     "VertexAIEmbeddings",
     "VolcanoEmbeddings",
@@ -392,6 +396,7 @@ _module_lookup = {
     "VolcanoEmbeddings": "langchain_community.embeddings.volcengine",
     "VoyageEmbeddings": "langchain_community.embeddings.voyageai",
     "XinferenceEmbeddings": "langchain_community.embeddings.xinference",
+    "TextEmbedEmbeddings": "langchain_community.embeddings.textembed",
     "TitanTakeoffEmbed": "langchain_community.embeddings.titan_takeoff",
     "PremAIEmbeddings": "langchain_community.embeddings.premai",
     "YandexGPTEmbeddings": "langchain_community.embeddings.yandex",
diff --git a/libs/community/langchain_community/embeddings/textembed.py b/libs/community/langchain_community/embeddings/textembed.py
new file mode 100644
index 00000000000..f83d172b91a
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/textembed.py
@@ -0,0 +1,356 @@
+"""
+TextEmbed: Embedding Inference Server
+
+TextEmbed provides a high-throughput, low-latency solution for serving embeddings.
+It supports various sentence-transformer models and can also serve image
+embedding models, offering flexibility and scalability for diverse applications.
+
+TextEmbed is maintained by Keval Dekivadiya and is licensed under the Apache-2.0 license.
+"""  # noqa: E501
+
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import aiohttp
+import numpy as np
+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+__all__ = ["TextEmbedEmbeddings"]
+
+
+class TextEmbedEmbeddings(BaseModel, Embeddings):
+    """
+    A class to handle embedding requests to the TextEmbed API.
+
+    Attributes:
+        model: The TextEmbed model ID to use for embeddings.
+        api_url: The base URL for the TextEmbed API.
+        api_key: The API key for authenticating with the TextEmbed API.
+        client: The TextEmbed client instance.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.embeddings import TextEmbedEmbeddings
+
+            embeddings = TextEmbedEmbeddings(
+                model="sentence-transformers/clip-ViT-B-32",
+                api_url="http://localhost:8000/v1",
+                api_key=""
+            )
+
+    For more information: https://github.com/kevaldekivadiya2415/textembed/blob/main/docs/setup.md
+    """  # noqa: E501
+
+    model: str
+    """Underlying TextEmbed model id."""
+
+    api_url: str = "http://localhost:8000/v1"
+    """Endpoint URL to use."""
+
+    api_key: str = "None"
+    """API Key for authentication."""
+
+    client: Any = None
+    """TextEmbed client."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator(pre=False, skip_on_failure=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the API key and URL exist in the environment.
+
+        Args:
+            values (Dict): Dictionary of values to validate.
+
+        Returns:
+            Dict: Validated values.
+        """
+        values["api_url"] = get_from_dict_or_env(values, "api_url", "API_URL")
+        values["api_key"] = get_from_dict_or_env(values, "api_key", "API_KEY")
+
+        values["client"] = AsyncOpenAITextEmbedEmbeddingClient(
+            host=values["api_url"], api_key=values["api_key"]
+        )
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Call out to TextEmbed's embedding endpoint.
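+
+        Texts are length-sorted, batched (up to 256 per request), and
+        embedded concurrently; results are returned in the input order.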
+
+        Args:
+            texts (List[str]): The list of texts to embed.
+
+        Returns:
+            List[List[float]]: List of embeddings, one for each text.
+        """
+        embeddings = self.client.embed(
+            model=self.model,
+            texts=texts,
+        )
+        return embeddings
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Async call out to TextEmbed's embedding endpoint.
+
+        Args:
+            texts (List[str]): The list of texts to embed.
+
+        Returns:
+            List[List[float]]: List of embeddings, one for each text.
+        """
+        embeddings = await self.client.aembed(
+            model=self.model,
+            texts=texts,
+        )
+        return embeddings
+
+    def embed_query(self, text: str) -> List[float]:
+        """Call out to TextEmbed's embedding endpoint for a single query.
+
+        Args:
+            text (str): The text to embed.
+
+        Returns:
+            List[float]: Embeddings for the text.
+        """
+        return self.embed_documents([text])[0]
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Async call out to TextEmbed's embedding endpoint for a single query.
+
+        Args:
+            text (str): The text to embed.
+
+        Returns:
+            List[float]: Embeddings for the text.
+        """
+        embeddings = await self.aembed_documents([text])
+        return embeddings[0]
+
+
+class AsyncOpenAITextEmbedEmbeddingClient:
+    """
+    A client to handle synchronous and asynchronous requests to the TextEmbed API.
+
+    Attributes:
+        host (str): The base URL for the TextEmbed API.
+        api_key (str): The API key for authenticating with the TextEmbed API.
+        aiosession (Optional[aiohttp.ClientSession]): The aiohttp session for async requests.
+        _batch_size (int): Maximum batch size for a single request.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        host: str = "http://localhost:8000/v1",
+        api_key: Union[str, None] = None,
+        aiosession: Optional[aiohttp.ClientSession] = None,
+    ) -> None:
+        self.host = host
+        self.api_key = api_key
+        self.aiosession = aiosession
+
+        if self.host is None or len(self.host) < 3:
+            raise ValueError("Parameter `host` must be set to a valid URL")
+        self._batch_size = 256
+
+    @staticmethod
+    def _permute(
+        texts: List[str], sorter: Callable = len
+    ) -> Tuple[List[str], Callable]:
+        """
+        Sorts texts in descending order of length and provides a function to
+        restore the original order.
+
+        Args:
+            texts (List[str]): List of texts to sort.
+            sorter (Callable, optional): Sorting key, defaults to length.
+
+        Returns:
+            Tuple[List[str], Callable]: Sorted texts and a function to restore
+            the original order.
+        """
+        if len(texts) == 1:
+            return texts, lambda t: t
+        length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
+        texts_sorted = [texts[idx] for idx in length_sorted_idx]
+
+        return texts_sorted, lambda unsorted_embeddings: [
+            unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
+        ]
+
+    def _batch(self, texts: List[str]) -> List[List[str]]:
+        """
+        Splits a list of texts into batches of at most `self._batch_size` texts.
+
+        Args:
+            texts (List[str]): List of texts to split.
+
+        Returns:
+            List[List[str]]: List of batches of texts.
+        """
+        if len(texts) == 1:
+            return [texts]
+        batches = []
+        for start_index in range(0, len(texts), self._batch_size):
+            batches.append(texts[start_index : start_index + self._batch_size])
+        return batches
+
+    @staticmethod
+    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
+        """
+        Merges batches of texts into a single list.
+
+        Args:
+            batch_of_texts (List[List[Any]]): List of batches of texts.
+
+        Returns:
+            List[Any]: Merged list of texts.
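+
+        Example:
+            >>> AsyncOpenAITextEmbedEmbeddingClient._unbatch([["a"], ["b"]])
+            ['a', 'b']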
+ """ + if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1: + return batch_of_texts[0] + texts = [] + for sublist in batch_of_texts: + texts.extend(sublist) + return texts + + def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]: + """ + Builds the kwargs for the POST request, used by sync method. + + Args: + model (str): The model to use for embedding. + texts (List[str]): List of texts to embed. + + Returns: + Dict[str, Any]: Dictionary of POST request parameters. + """ + return dict( + url=f"{self.host}/embedding", + headers={ + "accept": "application/json", + "content-type": "application/json", + "Authorization": f"Bearer {self.api_key}", + }, + json=dict( + input=texts, + model=model, + ), + ) + + def _sync_request_embed( + self, model: str, batch_texts: List[str] + ) -> List[List[float]]: + """ + Sends a synchronous request to the embedding endpoint. + + Args: + model (str): The model to use for embedding. + batch_texts (List[str]): Batch of texts to embed. + + Returns: + List[List[float]]: List of embeddings for the batch. + + Raises: + Exception: If the response status is not 200. + """ + response = requests.post( + **self._kwargs_post_request(model=model, texts=batch_texts) + ) + if response.status_code != 200: + raise Exception( + f"TextEmbed responded with an unexpected status message " + f"{response.status_code}: {response.text}" + ) + return [e["embedding"] for e in response.json()["data"]] + + def embed(self, model: str, texts: List[str]) -> List[List[float]]: + """ + Embeds a list of texts synchronously. + + Args: + model (str): The model to use for embedding. + texts (List[str]): List of texts to embed. + + Returns: + List[List[float]]: List of embeddings for the texts. + """ + perm_texts, unpermute_func = self._permute(texts) + perm_texts_batched = self._batch(perm_texts) + + # Request + map_args = ( + self._sync_request_embed, + [model] * len(perm_texts_batched), + perm_texts_batched, + ) + if len(perm_texts_batched) == 1: + embeddings_batch_perm = list(map(*map_args)) + else: + with ThreadPoolExecutor(32) as p: + embeddings_batch_perm = list(p.map(*map_args)) + + embeddings_perm = self._unbatch(embeddings_batch_perm) + embeddings = unpermute_func(embeddings_perm) + return embeddings + + async def _async_request( + self, session: aiohttp.ClientSession, **kwargs: Dict[str, Any] + ) -> List[List[float]]: + """ + Sends an asynchronous request to the embedding endpoint. + + Args: + session (aiohttp.ClientSession): The aiohttp session for the request. + kwargs (Dict[str, Any]): Dictionary of POST request parameters. + + Returns: + List[List[float]]: List of embeddings for the request. + + Raises: + Exception: If the response status is not 200. + """ + async with session.post(**kwargs) as response: # type: ignore + if response.status != 200: + raise Exception( + f"TextEmbed responded with an unexpected status message " + f"{response.status}: {response.text}" + ) + embedding = (await response.json())["data"] + return [e["embedding"] for e in embedding] + + async def aembed(self, model: str, texts: List[str]) -> List[List[float]]: + """ + Embeds a list of texts asynchronously. + + Args: + model (str): The model to use for embedding. + texts (List[str]): List of texts to embed. + + Returns: + List[List[float]]: List of embeddings for the texts. 
+ """ + perm_texts, unpermute_func = self._permute(texts) + perm_texts_batched = self._batch(perm_texts) + + async with aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=32) + ) as session: + embeddings_batch_perm = await asyncio.gather( + *[ + self._async_request( + session=session, + **self._kwargs_post_request(model=model, texts=t), + ) + for t in perm_texts_batched + ] + ) + + embeddings_perm = self._unbatch(embeddings_batch_perm) + embeddings = unpermute_func(embeddings_perm) + return embeddings diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index 7cbc1f6a331..6298b43464a 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -80,6 +80,7 @@ EXPECTED_ALL = [ "SolarEmbeddings", "AscendEmbeddings", "ZhipuAIEmbeddings", + "TextEmbedEmbeddings", ]