Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-06 21:20:33 +00:00)
community[patch]: update Gradient embeddings (#14846)
- **Description:** Going forward, we have our own API package, `pip install gradientai`. We are therefore gradually removing the self-built clients in llama-index, haystack, and langchain.
- **Issue:** None.
- **Dependencies:** `pip install gradientai`
- **Tag maintainer:** @michaelfeil
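For context, a minimal usage sketch of the embeddings class after this change; the model slug and credentials below are placeholders (borrowed from the old docstring example), not real values:

```python
# Sketch only: assumes `pip install gradientai` plus valid Gradient credentials.
from langchain_community.embeddings import GradientEmbeddings

embedder = GradientEmbeddings(
    model="bge-large",  # placeholder model slug
    gradient_access_token="gradientai-access_token",  # or env GRADIENT_ACCESS_TOKEN
    gradient_workspace_id="12345614fc0_workspace",  # or env GRADIENT_WORKSPACE_ID
)
doc_vectors = embedder.embed_documents(["doc1", "doc2"])
query_vector = embedder.embed_query("doc1")
```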
This commit is contained in: parent 6cc3c2452c, commit 7b96de3d5d.

libs/community/langchain_community/embeddings/gradient_ai.py
```diff
@@ -1,15 +1,9 @@
-import asyncio
-import logging
-import os
-from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
-import aiohttp
-import numpy as np
-import requests
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain_core.utils import get_from_dict_or_env
+from packaging.version import parse
 
 __all__ = ["GradientEmbeddings"]
 
@@ -49,6 +43,9 @@ class GradientEmbeddings(BaseModel, Embeddings):
     gradient_api_url: str = "https://api.gradient.ai/api"
     """Endpoint URL to use."""
 
+    query_prompt_for_retrieval: Optional[str] = None
+    """Query pre-prompt"""
+
     client: Any = None  #: :meta private:
     """Gradient client."""
 
```
```diff
@@ -72,21 +69,24 @@ class GradientEmbeddings(BaseModel, Embeddings):
         values["gradient_api_url"] = get_from_dict_or_env(
             values, "gradient_api_url", "GRADIENT_API_URL"
         )
+        try:
+            import gradientai
+        except ImportError:
+            raise ImportError(
+                'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.'
+            )
 
-        values["client"] = TinyAsyncGradientEmbeddingClient(
+        if parse(gradientai.__version__) < parse("1.4.0"):
+            raise ImportError(
+                'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.'
+            )
+
+        gradient = gradientai.Gradient(
             access_token=values["gradient_access_token"],
             workspace_id=values["gradient_workspace_id"],
             host=values["gradient_api_url"],
         )
-        try:
-            import gradientai  # noqa
-        except ImportError:
-            logging.warning(
-                "DeprecationWarning: `GradientEmbeddings` will use "
-                "`pip install gradientai` in future releases of langchain."
-            )
-        except Exception:
-            pass
+        values["client"] = gradient.get_embeddings_model(slug=values["model"])
 
         return values
 
```
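The new version gate relies on `packaging` doing PEP 440 comparison rather than string comparison; a quick illustration of why `parse` is used here:

```python
from packaging.version import parse

# Version objects compare numerically, so 1.10.0 is newer than 1.4.0;
# a plain string comparison would get this wrong.
assert parse("1.10.0") > parse("1.4.0")
assert "1.10.0" < "1.4.0"  # lexicographic comparison, hence parse()
assert parse("1.3.9") < parse("1.4.0")  # this case triggers the ImportError above
```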
```diff
@@ -99,11 +99,11 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = self.client.embed(
-            model=self.model,
-            texts=texts,
-        )
-        return embeddings
+        inputs = [{"input": text} for text in texts]
+
+        result = self.client.embed(inputs=inputs).embeddings
+
+        return [e.embedding for e in result]
 
     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         """Async call out to Gradient's embedding endpoint.
```
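The response shape the new code unpacks can be mimicked with a small stand-in object (the real one comes from the `gradientai` model returned by `get_embeddings_model`):

```python
# Stand-in illustrating the attribute shape only: result.embeddings[i].embedding
# holds the vector for input i.
from types import SimpleNamespace

result = SimpleNamespace(
    embeddings=[
        SimpleNamespace(embedding=[1.0, 0.0, 0.0]),
        SimpleNamespace(embedding=[0.0, 0.9, 0.0]),
    ]
)
assert [e.embedding for e in result.embeddings] == [
    [1.0, 0.0, 0.0],
    [0.0, 0.9, 0.0],
]
```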
```diff
@@ -114,11 +114,11 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = await self.client.aembed(
-            model=self.model,
-            texts=texts,
-        )
-        return embeddings
+        inputs = [{"input": text} for text in texts]
+
+        result = (await self.client.aembed(inputs=inputs)).embeddings
+
+        return [e.embedding for e in result]
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to Gradient's embedding endpoint.
@@ -129,7 +129,12 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
         """
-        return self.embed_documents([text])[0]
+        query = (
+            f"{self.query_prompt_for_retrieval} {text}"
+            if self.query_prompt_for_retrieval
+            else text
+        )
+        return self.embed_documents([query])[0]
 
     async def aembed_query(self, text: str) -> List[float]:
         """Async call out to Gradient's embedding endpoint.
```
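The new `query_prompt_for_retrieval` field only affects queries, never documents; a sketch with a made-up instruction prompt:

```python
# Sketch only: the prompt string is an arbitrary example, not a recommended value.
embedder = GradientEmbeddings(
    model="bge-large",  # placeholder slug
    gradient_access_token="...",
    gradient_workspace_id="...",
    query_prompt_for_retrieval="Represent this sentence for searching:",
)
# embed_query() embeds "Represent this sentence for searching: what is pizza?",
# while embed_documents() still embeds the raw texts unchanged.
query_vector = embedder.embed_query("what is pizza?")
```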
```diff
@@ -140,240 +145,22 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
         """
-        embeddings = await self.aembed_documents([text])
+        query = (
+            f"{self.query_prompt_for_retrieval} {text}"
+            if self.query_prompt_for_retrieval
+            else text
+        )
+        embeddings = await self.aembed_documents([query])
         return embeddings[0]
 
 
 class TinyAsyncGradientEmbeddingClient:  #: :meta private:
-    """A helper tool to embed Gradient. Not part of Langchain's or Gradients stable API,
-    direct use discouraged.
-
-    To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
-    API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
-    or alternatively provide them as keywords to the constructor of this class.
-
-    Example:
-        .. code-block:: python
-
-            mini_client = TinyAsyncGradientEmbeddingClient(
-                workspace_id="12345614fc0_workspace",
-                access_token="gradientai-access_token",
-            )
-            embeds = mini_client.embed(
-                model="bge-large",
-                text=["doc1", "doc2"]
-            )
-            # or
-            embeds = await mini_client.aembed(
-                model="bge-large",
-                text=["doc1", "doc2"]
-            )
-
+    """Deprecated, TinyAsyncGradientEmbeddingClient was removed.
+
+    This class is just for backwards compatibility with older versions
+    of langchain_community.
+    It might be entirely removed in the future.
     """
 
-    def __init__(
-        self,
-        access_token: Optional[str] = None,
-        workspace_id: Optional[str] = None,
-        host: str = "https://api.gradient.ai/api",
-        aiosession: Optional[aiohttp.ClientSession] = None,
-    ) -> None:
-        self.access_token = access_token or os.environ.get(
-            "GRADIENT_ACCESS_TOKEN", None
-        )
-        self.workspace_id = workspace_id or os.environ.get(
-            "GRADIENT_WORKSPACE_ID", None
-        )
-        self.host = host
-        self.aiosession = aiosession
-
-        if self.access_token is None or len(self.access_token) < 10:
-            raise ValueError(
-                "env variable `GRADIENT_ACCESS_TOKEN` or "
-                " param `access_token` must be set "
-            )
-
-        if self.workspace_id is None or len(self.workspace_id) < 3:
-            raise ValueError(
-                "env variable `GRADIENT_WORKSPACE_ID` or "
-                " param `workspace_id` must be set"
-            )
-
-        if self.host is None or len(self.host) < 3:
-            raise ValueError(" param `host` must be set to a valid url")
-        self._batch_size = 128
-
-    @staticmethod
-    def _permute(
-        texts: List[str], sorter: Callable = len
-    ) -> Tuple[List[str], Callable]:
-        """Sort texts in ascending order, and
-        delivers a lambda expr, which can sort a same length list
-        https://github.com/UKPLab/sentence-transformers/blob/
-        c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/SentenceTransformer.py#L156
-
-        Args:
-            texts (List[str]): _description_
-            sorter (Callable, optional): _description_. Defaults to len.
-
-        Returns:
-            Tuple[List[str], Callable]: _description_
-
-        Example:
-            ```
-            texts = ["one","three","four"]
-            perm_texts, undo = self._permute(texts)
-            texts == undo(perm_texts)
-            ```
-        """
-
-        if len(texts) == 1:
-            # special case query
-            return texts, lambda t: t
-        length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
-        texts_sorted = [texts[idx] for idx in length_sorted_idx]
-
-        return texts_sorted, lambda unsorted_embeddings: [  # noqa E731
-            unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
-        ]
-
-    def _batch(self, texts: List[str]) -> List[List[str]]:
-        """
-        splits Lists of text parts into batches of size max `self._batch_size`
-        When encoding vector database,
-
-        Args:
-            texts (List[str]): List of sentences
-            self._batch_size (int, optional): max batch size of one request.
-
-        Returns:
-            List[List[str]]: Batches of List of sentences
-        """
-        if len(texts) == 1:
-            # special case query
-            return [texts]
-        batches = []
-        for start_index in range(0, len(texts), self._batch_size):
-            batches.append(texts[start_index : start_index + self._batch_size])
-        return batches
-
-    @staticmethod
-    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
-        if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
-            # special case query
-            return batch_of_texts[0]
-        texts = []
-        for sublist in batch_of_texts:
-            texts.extend(sublist)
-        return texts
-
-    def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
-        """Build the kwargs for the Post request, used by sync
-
-        Args:
-            model (str): _description_
-            texts (List[str]): _description_
-
-        Returns:
-            Dict[str, Collection[str]]: _description_
-        """
-        return dict(
-            url=f"{self.host}/embeddings/{model}",
-            headers={
-                "authorization": f"Bearer {self.access_token}",
-                "x-gradient-workspace-id": f"{self.workspace_id}",
-                "accept": "application/json",
-                "content-type": "application/json",
-            },
-            json=dict(
-                inputs=[{"input": i} for i in texts],
-            ),
-        )
-
-    def _sync_request_embed(
-        self, model: str, batch_texts: List[str]
-    ) -> List[List[float]]:
-        response = requests.post(
-            **self._kwargs_post_request(model=model, texts=batch_texts)
-        )
-        if response.status_code != 200:
-            raise Exception(
-                f"Gradient returned an unexpected response with status "
-                f"{response.status_code}: {response.text}"
-            )
-        return [e["embedding"] for e in response.json()["embeddings"]]
-
-    def embed(self, model: str, texts: List[str]) -> List[List[float]]:
-        """call the embedding of model
-
-        Args:
-            model (str): to embedding model
-            texts (List[str]): List of sentences to embed.
-
-        Returns:
-            List[List[float]]: List of vectors for each sentence
-        """
-        perm_texts, unpermute_func = self._permute(texts)
-        perm_texts_batched = self._batch(perm_texts)
-
-        # Request
-        map_args = (
-            self._sync_request_embed,
-            [model] * len(perm_texts_batched),
-            perm_texts_batched,
-        )
-        if len(perm_texts_batched) == 1:
-            embeddings_batch_perm = list(map(*map_args))
-        else:
-            with ThreadPoolExecutor(32) as p:
-                embeddings_batch_perm = list(p.map(*map_args))
-
-        embeddings_perm = self._unbatch(embeddings_batch_perm)
-        embeddings = unpermute_func(embeddings_perm)
-        return embeddings
-
-    async def _async_request(
-        self, session: aiohttp.ClientSession, kwargs: Dict[str, Any]
-    ) -> List[List[float]]:
-        async with session.post(**kwargs) as response:
-            if response.status != 200:
-                raise Exception(
-                    f"Gradient returned an unexpected response with status "
-                    f"{response.status}: {response.text}"
-                )
-            embedding = (await response.json())["embeddings"]
-            return [e["embedding"] for e in embedding]
-
-    async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
-        """call the embedding of model, async method
-
-        Args:
-            model (str): to embedding model
-            texts (List[str]): List of sentences to embed.
-
-        Returns:
-            List[List[float]]: List of vectors for each sentence
-        """
-        perm_texts, unpermute_func = self._permute(texts)
-        perm_texts_batched = self._batch(perm_texts)
-
-        # Request
-        if self.aiosession is None:
-            self.aiosession = aiohttp.ClientSession(
-                trust_env=True, connector=aiohttp.TCPConnector(limit=32)
-            )
-        async with self.aiosession as session:
-            embeddings_batch_perm = await asyncio.gather(
-                *[
-                    self._async_request(
-                        session=session,
-                        **self._kwargs_post_request(model=model, texts=t),
-                    )
-                    for t in perm_texts_batched
-                ]
-            )
-
-        embeddings_perm = self._unbatch(embeddings_batch_perm)
-        embeddings = unpermute_func(embeddings_perm)
-        return embeddings
+    def __init__(self, *args, **kwargs) -> None:
+        raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.")
```
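With the stub above, any leftover direct use of the old helper now fails fast instead of silently posting to the raw REST endpoint:

```python
from langchain_community.embeddings.gradient_ai import TinyAsyncGradientEmbeddingClient

# Every construction raises after this commit:
# ValueError: Deprecated,TinyAsyncGradientEmbeddingClient was removed.
try:
    TinyAsyncGradientEmbeddingClient()
except ValueError as err:
    print(err)
```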
libs/community/poetry.lock (generated, 37 changed lines)
```diff
@@ -1,5 +1,17 @@
 # This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
 
+[[package]]
+name = "aenum"
+version = "3.1.15"
+description = "Advanced Enumerations (compatible with Python's stdlib Enum), NamedTuples, and NamedConstants"
+optional = true
+python-versions = "*"
+files = [
+    {file = "aenum-3.1.15-py2-none-any.whl", hash = "sha256:27b1710b9d084de6e2e695dab78fe9f269de924b51ae2850170ee7e1ca6288a5"},
+    {file = "aenum-3.1.15-py3-none-any.whl", hash = "sha256:e0dfaeea4c2bd362144b87377e2c61d91958c5ed0b4daf89cb6f45ae23af6288"},
+    {file = "aenum-3.1.15.tar.gz", hash = "sha256:8cbd76cd18c4f870ff39b24284d3ea028fbe8731a58df3aa581e434c575b9559"},
+]
+
 [[package]]
 name = "aiodns"
 version = "3.1.1"
@@ -2352,6 +2364,23 @@ test = ["aiofiles", "aiohttp (>=3.7.1,<3.9.0)", "botocore (>=1.21,<2)", "mock (=
 test-no-transport = ["aiofiles", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==6.2.5)", "pytest-asyncio (==0.16.0)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "vcrpy (==4.0.2)"]
 websockets = ["websockets (>=10,<11)", "websockets (>=9,<10)"]
 
+[[package]]
+name = "gradientai"
+version = "1.4.0"
+description = "Gradient AI API"
+optional = true
+python-versions = ">=3.8.1,<4.0.0"
+files = [
+    {file = "gradientai-1.4.0-py3-none-any.whl", hash = "sha256:58b74151e4bee534d438509303bcca3a9b84d17dafff31c206353489b54fcbfa"},
+    {file = "gradientai-1.4.0.tar.gz", hash = "sha256:98b9e0894530c6b7c675a113010dca7f7f7c399e02c46c0fb5532bf9fc1609f4"},
+]
+
+[package.dependencies]
+aenum = ">=3.1.11"
+pydantic = ">=1.10.5,<2.0.0"
+python-dateutil = ">=2.8.2"
+urllib3 = ">=1.25.3"
+
 [[package]]
 name = "graphql-core"
 version = "3.2.3"
@@ -3023,7 +3052,6 @@ files = [
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
     {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
-    {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
@@ -3421,7 +3449,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.1.0"
+version = "0.1.1"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -5529,6 +5557,7 @@ files = [
     {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"},
     {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"},
     {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"},
+    {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"},
@@ -8480,9 +8509,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 
 [extras]
 cli = ["typer"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "e3bacf389a13d283c4dd29e3a673e1863826b4e98785c666fefc10cf714c2f6f"
+content-hash = "ab4b1efe33110b575d2fb65bd5ecb90e92d1bd83dd5eac87080e4d07268df72f"
```
libs/community/pyproject.toml

```diff
@@ -28,6 +28,7 @@ openai = {version = "<2", optional = true}
 arxiv = {version = "^1.4", optional = true}
 pypdf = {version = "^3.4.0", optional = true}
 aleph-alpha-client = {version="^2.15.0", optional = true}
+gradientai = {version="^1.4.0", optional = true}
 pgvector = {version = "^0.1.6", optional = true}
 atlassian-python-api = {version = "^3.36.0", optional=true}
 html2text = {version="^2020.1.16", optional=true}
@@ -203,6 +204,7 @@ extended_testing = [
     "telethon",
     "psychicapi",
     "gql",
+    "gradientai",
     "requests-toolbelt",
     "html2text",
     "numexpr",
```
libs/community/tests/unit_tests/embeddings/test_gradient_ai.py

```diff
@@ -1,7 +1,8 @@
-from typing import Dict
+import sys
+from typing import Any, Dict, List
+from unittest.mock import MagicMock, patch
 
 import pytest
-from pytest_mock import MockerFixture
 
 from langchain_community.embeddings import GradientEmbeddings
 
@@ -11,38 +12,22 @@ _GRADIENT_WORKSPACE_ID = "valid_workspace_12345"
 _GRADIENT_BASE_URL = "https://api.gradient.ai/api"
 _DOCUMENTS = [
     "pizza",
-    "another pizza",
+    "another long pizza",
     "a document",
-    "another pizza",
+    "another long pizza",
     "super long document with many tokens",
 ]
 
 
-class MockResponse:
-    def __init__(self, json_data: Dict, status_code: int):
-        self.json_data = json_data
-        self.status_code = status_code
-
-    def json(self) -> Dict:
-        return self.json_data
-
-
-def mocked_requests_post(
-    url: str,
-    headers: dict,
-    json: dict,
-) -> MockResponse:
-    assert url.startswith(_GRADIENT_BASE_URL)
-    assert _MODEL_ID in url
-    assert json
-    assert headers
-
-    assert headers.get("authorization") == f"Bearer {_GRADIENT_SECRET}"
-    assert headers.get("x-gradient-workspace-id") == f"{_GRADIENT_WORKSPACE_ID}"
-
-    assert "inputs" in json and "input" in json["inputs"][0]
-    embeddings = []
-    for inp in json["inputs"]:
+class GradientEmbeddingsModel(MagicMock):
+    """MockGradientModel."""
+
+    def embed(self, inputs: List[Dict[str, str]]) -> Any:
+        """Just duplicate the query m times."""
+        output = MagicMock()
+
+        embeddings = []
+        for i, inp in enumerate(inputs):
             # verify correct ordering
             inp = inp["input"]
             if "pizza" in inp:
```
```diff
@@ -53,19 +38,43 @@
             v = [0.0, 0.0, -1.0]
             if len(inp) > 10:
                 v[2] += 0.1
-        embeddings.append({"embedding": v})
-
-    return MockResponse(
-        json_data={"embeddings": embeddings},
-        status_code=200,
-    )
-
-
-def test_gradient_llm_sync(
-    mocker: MockerFixture,
-) -> None:
-    mocker.patch("requests.post", side_effect=mocked_requests_post)
-
-    embedder = GradientEmbeddings(
-        gradient_api_url=_GRADIENT_BASE_URL,
-        gradient_access_token=_GRADIENT_SECRET,
+            output_inner = MagicMock()
+            output_inner.embedding = v
+            embeddings.append(output_inner)
+
+        output.embeddings = embeddings
+        return output
+
+    async def aembed(self, *args) -> Any:
+        return self.embed(*args)
+
+
+class MockGradient(MagicMock):
+    """Mock Gradient package."""
+
+    def __init__(self, access_token: str, workspace_id, host):
+        assert access_token == _GRADIENT_SECRET
+        assert workspace_id == _GRADIENT_WORKSPACE_ID
+        assert host == _GRADIENT_BASE_URL
+
+    def get_embeddings_model(self, slug: str) -> GradientEmbeddingsModel:
+        assert slug == _MODEL_ID
+        return GradientEmbeddingsModel()
+
+    def close(self) -> None:
+        """Mock Gradient close."""
+        return
+
+
+class MockGradientaiPackage(MagicMock):
+    """Mock Gradientai package."""
+
+    Gradient = MockGradient
+    __version__ = "1.4.0"
+
+
+def test_gradient_llm_sync() -> None:
+    with patch.dict(sys.modules, {"gradientai": MockGradientaiPackage()}):
+        embedder = GradientEmbeddings(
+            gradient_api_url=_GRADIENT_BASE_URL,
+            gradient_access_token=_GRADIENT_SECRET,
```
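The tests swap the whole `gradientai` module via `patch.dict(sys.modules, ...)`, so the lazy `import gradientai` inside the root validator resolves to the mock package; the trick in miniature:

```python
import sys
from unittest.mock import MagicMock, patch

fake_pkg = MagicMock()
fake_pkg.__version__ = "1.4.0"  # satisfies the >=1.4.0 gate in the validator

with patch.dict(sys.modules, {"gradientai": fake_pkg}):
    import gradientai  # resolves to fake_pkg while the patch is active

    assert gradientai.__version__ == "1.4.0"
```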
```diff
@@ -89,39 +98,7 @@
     assert response == want
 
 
-def test_gradient_llm_large_batch_size(
-    mocker: MockerFixture,
-) -> None:
-    mocker.patch("requests.post", side_effect=mocked_requests_post)
-
-    embedder = GradientEmbeddings(
-        gradient_api_url=_GRADIENT_BASE_URL,
-        gradient_access_token=_GRADIENT_SECRET,
-        gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
-        model=_MODEL_ID,
-    )
-    assert embedder.gradient_access_token == _GRADIENT_SECRET
-    assert embedder.gradient_api_url == _GRADIENT_BASE_URL
-    assert embedder.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
-    assert embedder.model == _MODEL_ID
-
-    response = embedder.embed_documents(_DOCUMENTS * 1024)
-    want = [
-        [1.0, 0.0, 0.0],  # pizza
-        [1.0, 0.0, 0.1],  # pizza + long
-        [0.0, 0.9, 0.0],  # doc
-        [1.0, 0.0, 0.1],  # pizza + long
-        [0.0, 0.9, 0.1],  # doc + long
-    ] * 1024
-
-    assert response == want
-
-
-def test_gradient_wrong_setup(
-    mocker: MockerFixture,
-) -> None:
-    mocker.patch("requests.post", side_effect=mocked_requests_post)
-
+def test_gradient_wrong_setup() -> None:
     with pytest.raises(Exception):
         GradientEmbeddings(
             gradient_api_url=_GRADIENT_BASE_URL,
@@ -130,6 +107,8 @@
             model=_MODEL_ID,
         )
 
+
+def test_gradient_wrong_setup2() -> None:
     with pytest.raises(Exception):
         GradientEmbeddings(
             gradient_api_url=_GRADIENT_BASE_URL,
@@ -138,6 +117,8 @@
             model=_MODEL_ID,
         )
 
+
+def test_gradient_wrong_setup3() -> None:
     with pytest.raises(Exception):
         GradientEmbeddings(
             gradient_api_url="-",  # empty
```
libs/langchain/langchain/embeddings/gradient_ai.py

```diff
@@ -1,6 +1,3 @@
-from langchain_community.embeddings.gradient_ai import (
-    GradientEmbeddings,
-    TinyAsyncGradientEmbeddingClient,
-)
+from langchain_community.embeddings.gradient_ai import GradientEmbeddings
 
-__all__ = ["GradientEmbeddings", "TinyAsyncGradientEmbeddingClient"]
+__all__ = ["GradientEmbeddings"]
```