mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
langchain[minor]: Adding infinity embedding integration. (#13928)
This adds an integration for https://github.com/michaelfeil/infinity. Users requested it in https://github.com/michaelfeil/infinity/issues/36 @saatvikshah. It follows my implementation of gradient.ai.

Feedback 1: Well done. I love your CI / repo / poetry setup, and I adapted a lot of it in https://github.com/michaelfeil/infinity.

Feedback 2: Not so good. The openai integration contains too much reverse engineering. In general, projects such as michaelfeil/infinity and huggingface/text-embeddings-inference are compatible with the `pip install openai` package. Reverse engineering like this is really hindering the use for me:
8e88ba16a8/libs/langchain/langchain/embeddings/openai.py (L347)
8e88ba16a8/libs/langchain/langchain/embeddings/openai.py (L351)
It prevents 3rd-party providers from using the same URL, and it uses OpenAI interfaces that are not publicly documented.
This commit is contained in:
parent
10a6e7cbb6
commit
686162670e
11
docs/docs/integrations/providers/infinity.mdx
Normal file
@@ -0,0 +1,11 @@
# Infinity

>[Infinity](https://github.com/michaelfeil/infinity) allows the creation of text embeddings.

## Text Embedding Model

There exists an infinity Embedding model, which you can access with
```python
from langchain.embeddings import InfinityEmbeddings
```
For a more detailed walkthrough of this, see [this notebook](/docs/integrations/text_embedding/infinity)
191
docs/docs/integrations/text_embedding/infinity.ipynb
Normal file
@@ -0,0 +1,191 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Infinity\n",
    "\n",
    "`Infinity` allows you to create `Embeddings` using an MIT-licensed Embedding Server. \n",
    "\n",
    "This notebook goes over how to use Langchain with Embeddings with the [Infinity Github Project](https://github.com/michaelfeil/infinity).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.embeddings import InfinityEmbeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optional: Make sure to start the Infinity instance\n",
    "\n",
    "To install infinity use the following command. For further details check out the [Docs on Github](https://github.com/michaelfeil/infinity).\n",
    "```bash\n",
    "pip install infinity_emb[all]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: infinity_emb[cli] in /home/michi/langchain/.venv/lib/python3.10/site-packages (0.0.8)\n",
      "\u001b[33mWARNING: infinity-emb 0.0.8 does not provide the extra 'cli'\u001b[0m\u001b[33m\n",
      "\u001b[0mRequirement already satisfied: numpy>=1.20.0 in /home/michi/langchain/.venv/lib/python3.10/site-packages (from infinity_emb[cli]) (1.24.4)\n",
      "\u001b[33mWARNING: There was an error checking the latest version of pip.\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "# Install the infinity package\n",
    "!pip install infinity_emb[cli,torch]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start up the server - best to be done from a separate terminal, not inside Jupyter Notebook\n",
    "\n",
    "```bash\n",
    "model=sentence-transformers/all-MiniLM-L6-v2\n",
    "port=7797\n",
    "infinity_emb --port $port --model-name-or-path $model\n",
    "```\n",
    "\n",
    "or alternatively just use docker:\n",
    "```bash\n",
    "model=sentence-transformers/all-MiniLM-L6-v2\n",
    "port=7797\n",
    "docker run -it --gpus all -p $port:$port michaelf34/infinity:latest --model-name-or-path $model --port $port\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embed your documents using your Infinity instance "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "documents = [\n",
    "    \"Baguette is a dish.\",\n",
    "    \"Paris is the capital of France.\",\n",
    "    \"numpy is a lib for linear algebra\",\n",
    "    \"You escaped what I've escaped - You'd be in Paris getting fucked up too\",\n",
    "]\n",
    "query = \"Where is Paris?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "embeddings created successful\n"
     ]
    }
   ],
   "source": [
    "#\n",
    "infinity_api_url = \"http://localhost:7797/v1\"\n",
    "# model is currently not validated.\n",
    "embeddings = InfinityEmbeddings(\n",
    "    model=\"sentence-transformers/all-MiniLM-L6-v2\", infinity_api_url=infinity_api_url\n",
    ")\n",
    "try:\n",
    "    documents_embedded = embeddings.embed_documents(documents)\n",
    "    query_result = embeddings.embed_query(query)\n",
    "    print(\"embeddings created successful\")\n",
    "except Exception as ex:\n",
    "    print(\n",
    "        \"Make sure the infinity instance is running. Verify by clicking on \"\n",
    "        f\"{infinity_api_url.replace('v1','docs')} Exception: {ex}. \"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Baguette is a dish.': 0.31344215908661155,\n",
       " 'Paris is the capital of France.': 0.8148670296896388,\n",
       " 'numpy is a lib for linear algebra': 0.004429399861302009,\n",
       " \"You escaped what I've escaped - You'd be in Paris getting fucked up too\": 0.5088476180154582}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (demo) compute similarity\n",
    "import numpy as np\n",
    "\n",
    "scores = np.array(documents_embedded) @ np.array(query_result).T\n",
    "dict(zip(documents, scores))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
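The notebook's last cell scores documents with a raw dot product, which equals cosine similarity only when the vectors have unit length. The following is a minimal, stdlib-only sketch of an explicitly normalized ranking, assuming nothing about whether the server normalizes its output. The helper names `cosine_similarity` and `rank_documents` and the toy 2-d vectors are illustrative and not part of the notebook or of LangChain.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # convention chosen here for zero vectors
    return dot / (norm_a * norm_b)


def rank_documents(
    documents: Sequence[str],
    document_vectors: Sequence[Sequence[float]],
    query_vector: Sequence[float],
) -> List[Tuple[str, float]]:
    """Pair each document with its score, best match first."""
    scored = [
        (doc, cosine_similarity(vec, query_vector))
        for doc, vec in zip(documents, document_vectors)
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Toy 2-d vectors standing in for real embeddings (made up for illustration).
docs = ["about Paris", "about numpy"]
vectors = [[3.0, 4.0], [4.0, -3.0]]
ranking = rank_documents(docs, vectors, query_vector=[6.0, 8.0])
```

With real data, `vectors` would be the result of `embed_documents` and the query vector the result of `embed_query`.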
@@ -43,6 +43,7 @@ from langchain.embeddings.huggingface import (
     HuggingFaceInstructEmbeddings,
 )
 from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
+from langchain.embeddings.infinity import InfinityEmbeddings
 from langchain.embeddings.javelin_ai_gateway import JavelinAIGatewayEmbeddings
 from langchain.embeddings.jina import JinaEmbeddings
 from langchain.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
@@ -81,6 +82,7 @@ __all__ = [
     "FastEmbedEmbeddings",
     "HuggingFaceEmbeddings",
     "HuggingFaceInferenceAPIEmbeddings",
+    "InfinityEmbeddings",
     "GradientEmbeddings",
     "JinaEmbeddings",
     "LlamaCppEmbeddings",
323
libs/langchain/langchain/embeddings/infinity.py
Normal file
@@ -0,0 +1,323 @@
"""written under MIT Licence, Michael Feil 2023."""

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Tuple

import aiohttp
import numpy as np
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

from langchain.utils import get_from_dict_or_env

__all__ = ["InfinityEmbeddings"]


class InfinityEmbeddings(BaseModel, Embeddings):
    """Embedding models for self-hosted https://github.com/michaelfeil/infinity
    This should also work for text-embeddings-inference and other
    self-hosted openai-compatible servers.

    Infinity is a class to interact with Embedding Models on https://github.com/michaelfeil/infinity


    Example:
        .. code-block:: python

            from langchain.embeddings import InfinityEmbeddings
            InfinityEmbeddings(
                model="BAAI/bge-small",
                infinity_api_url="http://localhost:7797/v1",
            )
    """

    model: str
    "Underlying Infinity model id."

    infinity_api_url: str = "http://localhost:7797/v1"
    """Endpoint URL to use."""

    client: Any = None  #: :meta private:
    """Infinity client."""

    # LLM call kwargs
    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API URL (argument or `INFINITY_API_URL`) and build the client."""

        values["infinity_api_url"] = get_from_dict_or_env(
            values, "infinity_api_url", "INFINITY_API_URL"
        )

        values["client"] = TinyAsyncOpenAIInfinityEmbeddingClient(
            host=values["infinity_api_url"],
        )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Infinity's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = self.client.embed(
            model=self.model,
            texts=texts,
        )
        return embeddings

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = await self.client.aembed(
            model=self.model,
            texts=texts,
        )
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]


class TinyAsyncOpenAIInfinityEmbeddingClient:  #: :meta private:
    """A helper tool to embed Infinity. Not part of Langchain's stable API,
    direct use discouraged.

    Example:
        .. code-block:: python


            mini_client = TinyAsyncOpenAIInfinityEmbeddingClient(
            )
            embeds = mini_client.embed(
                model="BAAI/bge-small",
                texts=["doc1", "doc2"]
            )
            # or
            embeds = await mini_client.aembed(
                model="BAAI/bge-small",
                texts=["doc1", "doc2"]
            )

    """

    def __init__(
        self,
        host: str = "http://localhost:7797/v1",
        aiosession: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        self.host = host
        self.aiosession = aiosession

        if self.host is None or len(self.host) < 3:
            raise ValueError(" param `host` must be set to a valid url")
        self._batch_size = 128

    @staticmethod
    def _permute(
        texts: List[str], sorter: Callable = len
    ) -> Tuple[List[str], Callable]:
        """Sort texts in descending order of `sorter` (length by default) and
        return a function that restores the original order of a same-length list.
        https://github.com/UKPLab/sentence-transformers/blob/
        c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/SentenceTransformer.py#L156

        Args:
            texts (List[str]): texts to sort.
            sorter (Callable, optional): sort key. Defaults to len.

        Returns:
            Tuple[List[str], Callable]: the sorted texts and the restoring function.

        Example:
            ```
            texts = ["one","three","four"]
            perm_texts, undo = self._permute(texts)
            texts == undo(perm_texts)
            ```
        """

        if len(texts) == 1:
            # special case query
            return texts, lambda t: t
        length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
        texts_sorted = [texts[idx] for idx in length_sorted_idx]

        return texts_sorted, lambda unsorted_embeddings: [  # noqa E731
            unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
        ]

    def _batch(self, texts: List[str]) -> List[List[str]]:
        """
        Split a list of texts into batches of at most `self._batch_size` items,
        e.g. when encoding a vector database.

        Args:
            texts (List[str]): List of sentences

        Returns:
            List[List[str]]: Batches of List of sentences
        """
        if len(texts) == 1:
            # special case query
            return [texts]
        batches = []
        for start_index in range(0, len(texts), self._batch_size):
            batches.append(texts[start_index : start_index + self._batch_size])
        return batches

    @staticmethod
    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
        if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
            # special case query
            return batch_of_texts[0]
        texts = []
        for sublist in batch_of_texts:
            texts.extend(sublist)
        return texts

    def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
        """Build the kwargs for the Post request, used by sync and async requests

        Args:
            model (str): id of the embedding model
            texts (List[str]): batch of texts to embed

        Returns:
            Dict[str, Any]: `url`, `headers` and `json` kwargs for the request
        """
        return dict(
            url=f"{self.host}/embeddings",
            headers={
                # "accept": "application/json",
                "content-type": "application/json",
            },
            json=dict(
                input=texts,
                model=model,
            ),
        )

    def _sync_request_embed(
        self, model: str, batch_texts: List[str]
    ) -> List[List[float]]:
        response = requests.post(
            **self._kwargs_post_request(model=model, texts=batch_texts)
        )
        if response.status_code != 200:
            raise Exception(
                f"Infinity returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )
        return [e["embedding"] for e in response.json()["data"]]

    def embed(self, model: str, texts: List[str]) -> List[List[float]]:
        """call the embedding of model

        Args:
            model (str): id of the embedding model
            texts (List[str]): List of sentences to embed.

        Returns:
            List[List[float]]: List of vectors for each sentence
        """
        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)

        # Request
        map_args = (
            self._sync_request_embed,
            [model] * len(perm_texts_batched),
            perm_texts_batched,
        )
        if len(perm_texts_batched) == 1:
            embeddings_batch_perm = list(map(*map_args))
        else:
            with ThreadPoolExecutor(32) as p:
                embeddings_batch_perm = list(p.map(*map_args))

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        embeddings = unpermute_func(embeddings_perm)
        return embeddings

    async def _async_request(
        self, session: aiohttp.ClientSession, kwargs: Dict[str, Any]
    ) -> List[List[float]]:
        async with session.post(**kwargs) as response:
            if response.status != 200:
                # `text` is a coroutine method on aiohttp responses
                body = await response.text()
                raise Exception(
                    f"Infinity returned an unexpected response with status "
                    f"{response.status}: {body}"
                )
            # same response shape as the sync path: vectors live under "data"
            embedding = (await response.json())["data"]
            return [e["embedding"] for e in embedding]

    async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
        """call the embedding of model, async method

        Args:
            model (str): id of the embedding model
            texts (List[str]): List of sentences to embed.

        Returns:
            List[List[float]]: List of vectors for each sentence
        """
        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)

        # Request
        if self.aiosession is None:
            self.aiosession = aiohttp.ClientSession(
                trust_env=True, connector=aiohttp.TCPConnector(limit=32)
            )
        # NOTE: leaving this block closes the session, so a later call cannot reuse it
        async with self.aiosession as session:
            embeddings_batch_perm = await asyncio.gather(
                *[
                    self._async_request(
                        session=session,
                        # `_async_request` takes the request kwargs as one dict
                        kwargs=self._kwargs_post_request(model=model, texts=t),
                    )
                    for t in perm_texts_batched
                ]
            )

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        embeddings = unpermute_func(embeddings_perm)
        return embeddings
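The client above sorts texts longest-first, splits them into batches, and then restores the caller's order after the responses come back. The following stdlib-only sketch shows that round trip in isolation. It is an illustration under stated assumptions, not the code from this commit: the function names `permute_by_length`, `batch` and `unbatch` are hypothetical, `sorted` stands in for `np.argsort`, and the "embedding" is a stand-in that maps each text to its length.

```python
from typing import Any, Callable, List, Sequence, Tuple


def permute_by_length(
    texts: Sequence[str],
) -> Tuple[List[str], Callable[[Sequence[Any]], List[Any]]]:
    """Return texts sorted longest-first plus a function that undoes the sort."""
    order = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
    sorted_texts = [texts[i] for i in order]

    def restore(items: Sequence[Any]) -> List[Any]:
        # items[k] belongs to the text that originally sat at index order[k]
        restored: List[Any] = [None] * len(items)
        for sorted_pos, original_pos in enumerate(order):
            restored[original_pos] = items[sorted_pos]
        return restored

    return sorted_texts, restore


def batch(items: Sequence[Any], batch_size: int) -> List[List[Any]]:
    """Split items into consecutive chunks of at most batch_size elements."""
    return [list(items[i : i + batch_size]) for i in range(0, len(items), batch_size)]


def unbatch(batches: Sequence[Sequence[Any]]) -> List[Any]:
    """Flatten a list of batches back into one list."""
    return [item for one_batch in batches for item in one_batch]


texts = ["one", "three", "four"]
sorted_texts, restore = permute_by_length(texts)
batches = batch(sorted_texts, batch_size=2)
# Stand-in for the server: "embed" every text as its length.
fake_embeddings = [[len(t) for t in one_batch] for one_batch in batches]
result = restore(unbatch(fake_embeddings))
```

Grouping texts of similar length into the same batch is the design choice borrowed from the linked sentence-transformers code; the restore step keeps that reordering invisible to the caller.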
@@ -10,6 +10,7 @@ EXPECTED_ALL = [
     "FastEmbedEmbeddings",
     "HuggingFaceEmbeddings",
     "HuggingFaceInferenceAPIEmbeddings",
+    "InfinityEmbeddings",
     "GradientEmbeddings",
     "JinaEmbeddings",
     "LlamaCppEmbeddings",
101
libs/langchain/tests/unit_tests/embeddings/test_infinity.py
Normal file
@@ -0,0 +1,101 @@
from typing import Dict

from pytest_mock import MockerFixture

from langchain.embeddings import InfinityEmbeddings

_MODEL_ID = "BAAI/bge-small"
_INFINITY_BASE_URL = "https://localhost/api"
_DOCUMENTS = [
    "pizza",
    "another pizza",
    "a document",
    "another pizza",
    "super long document with many tokens",
]


class MockResponse:
    def __init__(self, json_data: Dict, status_code: int):
        self.json_data = json_data
        self.status_code = status_code

    def json(self) -> Dict:
        return self.json_data


def mocked_requests_post(
    url: str,
    headers: dict,
    json: dict,
) -> MockResponse:
    assert url.startswith(_INFINITY_BASE_URL)
    assert "model" in json and _MODEL_ID in json["model"]
    assert json
    assert headers

    assert "input" in json and isinstance(json["input"], list)
    embeddings = []
    for inp in json["input"]:
        # verify correct ordering
        if "pizza" in inp:
            v = [1.0, 0.0, 0.0]
        elif "document" in inp:
            v = [0.0, 0.9, 0.0]
        else:
            v = [0.0, 0.0, -1.0]
        if len(inp) > 10:
            v[2] += 0.1
        embeddings.append({"embedding": v})

    return MockResponse(
        json_data={"data": embeddings},
        status_code=200,
    )


def test_infinity_emb_sync(
    mocker: MockerFixture,
) -> None:
    mocker.patch("requests.post", side_effect=mocked_requests_post)

    embedder = InfinityEmbeddings(model=_MODEL_ID, infinity_api_url=_INFINITY_BASE_URL)

    assert embedder.infinity_api_url == _INFINITY_BASE_URL
    assert embedder.model == _MODEL_ID

    response = embedder.embed_documents(_DOCUMENTS)
    want = [
        [1.0, 0.0, 0.0],  # pizza
        [1.0, 0.0, 0.1],  # pizza + long
        [0.0, 0.9, 0.0],  # doc
        [1.0, 0.0, 0.1],  # pizza + long
        [0.0, 0.9, 0.1],  # doc + long
    ]

    assert response == want


def test_infinity_large_batch_size(
    mocker: MockerFixture,
) -> None:
    mocker.patch("requests.post", side_effect=mocked_requests_post)

    embedder = InfinityEmbeddings(
        infinity_api_url=_INFINITY_BASE_URL,
        model=_MODEL_ID,
    )

    assert embedder.infinity_api_url == _INFINITY_BASE_URL
    assert embedder.model == _MODEL_ID

    response = embedder.embed_documents(_DOCUMENTS * 1024)
    want = [
        [1.0, 0.0, 0.0],  # pizza
        [1.0, 0.0, 0.1],  # pizza + long
        [0.0, 0.9, 0.0],  # doc
        [1.0, 0.0, 0.1],  # pizza + long
        [0.0, 0.9, 0.1],  # doc + long
    ] * 1024

    assert response == want
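The tests above patch `requests.post` and return a body shaped like `{"data": [{"embedding": [...]}]}`. As a rough sketch of the wire format they exercise, the snippet below builds the same request kwargs and parses the same response shape, with the HTTP call injected as a plain callable so it runs without a server. All names here (`build_embedding_request`, `parse_embedding_response`, `embed_with`, `fake_post`) are hypothetical and exist only for illustration; the real client calls `requests.post` and `aiohttp` directly.

```python
from typing import Any, Callable, Dict, List


def build_embedding_request(host: str, model: str, texts: List[str]) -> Dict[str, Any]:
    """Keyword arguments for an HTTP POST to an OpenAI-style embeddings route."""
    return {
        "url": f"{host}/embeddings",
        "headers": {"content-type": "application/json"},
        "json": {"input": texts, "model": model},
    }


def parse_embedding_response(payload: Dict[str, Any]) -> List[List[float]]:
    """Extract vectors from a body shaped like {"data": [{"embedding": [...]}]}."""
    return [item["embedding"] for item in payload["data"]]


def embed_with(
    post: Callable[..., Dict[str, Any]], host: str, model: str, texts: List[str]
) -> List[List[float]]:
    """`post` accepts url/headers/json and returns the decoded JSON body."""
    return parse_embedding_response(post(**build_embedding_request(host, model, texts)))


def fake_post(url: str, headers: Dict[str, str], json: Dict[str, Any]) -> Dict[str, Any]:
    """Stand-in server: a one-dimensional 'embedding' holding each input's length."""
    assert url.endswith("/embeddings")
    return {"data": [{"embedding": [float(len(text))]} for text in json["input"]]}


vectors = embed_with(
    fake_post, "http://localhost:7797/v1", "BAAI/bge-small", ["pizza", "a document"]
)
```

Injecting the transport like this is one way to test request building and response parsing separately from the network; the commit's tests reach the same goal by patching `requests.post`.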