Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-06 13:18:12 +00:00.
community[patch]: update Gradient embeddings (#14846)
- **Description:** Going forward, we have our own API package, `pip install gradientai`, so we are gradually removing the self-built clients in llamaindex, haystack, and langchain.
- **Issue:** None.
- **Dependencies:** `pip install gradientai`
- **Tag maintainer:** @michaelfeil
This commit is contained in: parent 6cc3c2452c · commit 7b96de3d5d
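For orientation before the diff: after this patch, `GradientEmbeddings` delegates to the official SDK instead of the hand-rolled HTTP client removed below. A minimal usage sketch (the model slug and credentials are placeholders; assumes `gradientai>=1.4.0` and the patched `langchain_community` are installed):

    from langchain_community.embeddings import GradientEmbeddings

    # Placeholder credentials; these can also come from the
    # GRADIENT_ACCESS_TOKEN and GRADIENT_WORKSPACE_ID environment variables.
    embedder = GradientEmbeddings(
        model="bge-large",
        gradient_access_token="my-access-token",
        gradient_workspace_id="my-workspace-id",
    )
    doc_vectors = embedder.embed_documents(["doc1", "doc2"])
    query_vector = embedder.embed_query("what is doc1 about?")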
libs/community/langchain_community/embeddings/gradient_ai.py

@@ -1,15 +1,9 @@
-import asyncio
-import logging
-import os
-from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
-import aiohttp
-import numpy as np
-import requests
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain_core.utils import get_from_dict_or_env
+from packaging.version import parse
 
 __all__ = ["GradientEmbeddings"]
 
@@ -49,6 +43,9 @@ class GradientEmbeddings(BaseModel, Embeddings):
     gradient_api_url: str = "https://api.gradient.ai/api"
     """Endpoint URL to use."""
 
+    query_prompt_for_retrieval: Optional[str] = None
+    """Query pre-prompt"""
+
     client: Any = None  #: :meta private:
     """Gradient client."""
 
@@ -72,21 +69,24 @@ class GradientEmbeddings(BaseModel, Embeddings):
         values["gradient_api_url"] = get_from_dict_or_env(
             values, "gradient_api_url", "GRADIENT_API_URL"
         )
+        try:
+            import gradientai
+        except ImportError:
+            raise ImportError(
+                'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.'
+            )
 
-        values["client"] = TinyAsyncGradientEmbeddingClient(
+        if parse(gradientai.__version__) < parse("1.4.0"):
+            raise ImportError(
+                'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.'
+            )
+
+        gradient = gradientai.Gradient(
             access_token=values["gradient_access_token"],
             workspace_id=values["gradient_workspace_id"],
             host=values["gradient_api_url"],
         )
-        try:
-            import gradientai  # noqa
-        except ImportError:
-            logging.warning(
-                "DeprecationWarning: `GradientEmbeddings` will use "
-                "`pip install gradientai` in future releases of langchain."
-            )
-        except Exception:
-            pass
+        values["client"] = gradient.get_embeddings_model(slug=values["model"])
 
         return values
 
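A note on the version gate above: `packaging.version.parse` compares versions semantically, so "1.10.0" sorts above "1.4.0" where a plain string comparison would not. The same check in isolation (function name illustrative):

    from packaging.version import parse

    def check_min_version(installed: str, minimum: str = "1.4.0") -> None:
        # parse() returns Version objects with semantic ordering.
        if parse(installed) < parse(minimum):
            raise ImportError(f"gradientai>={minimum} required, found {installed}")

    check_min_version("1.10.0")  # ok: 1.10.0 > 1.4.0
    assert parse("1.10.0") > parse("1.4.0") > parse("1.3.9")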
@@ -99,11 +99,11 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = self.client.embed(
-            model=self.model,
-            texts=texts,
-        )
-        return embeddings
+        inputs = [{"input": text} for text in texts]
+
+        result = self.client.embed(inputs=inputs).embeddings
+
+        return [e.embedding for e in result]
 
     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         """Async call out to Gradient's embedding endpoint.
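As the hunk above shows, the SDK call now takes a list of `{"input": ...}` dicts and returns an object whose `.embeddings` holds one result per input, in order. The shape, sketched with plain data (`FakeResult` is a stand-in for the SDK's per-text result object, mirroring how the tests later in this commit mock it):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class FakeResult:  # stand-in for one per-text result object
        embedding: List[float]

    texts = ["doc1", "doc2"]
    inputs = [{"input": t} for t in texts]      # request payload shape
    results = [FakeResult([0.1, 0.2]), FakeResult([0.3, 0.4])]
    vectors = [e.embedding for e in results]    # response unwrapping
    assert vectors == [[0.1, 0.2], [0.3, 0.4]]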
@@ -114,11 +114,11 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = await self.client.aembed(
-            model=self.model,
-            texts=texts,
-        )
-        return embeddings
+        inputs = [{"input": text} for text in texts]
+
+        result = (await self.client.aembed(inputs=inputs)).embeddings
+
+        return [e.embedding for e in result]
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to Gradient's embedding endpoint.
@@ -129,7 +129,12 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
         """
-        return self.embed_documents([text])[0]
+        query = (
+            f"{self.query_prompt_for_retrieval} {text}"
+            if self.query_prompt_for_retrieval
+            else text
+        )
+        return self.embed_documents([query])[0]
 
     async def aembed_query(self, text: str) -> List[float]:
         """Async call out to Gradient's embedding endpoint.
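The `query_prompt_for_retrieval` field added earlier comes into play here: retrieval-tuned embedding models often expect queries, but not documents, to carry an instruction prefix. The transformation is a plain string prefix; with illustrative values:

    prompt = "Represent this question for searching relevant passages:"
    text = "what is pizza?"
    query = f"{prompt} {text}" if prompt else text
    # Documents passed through embed_documents() stay unprefixed.
    assert query == (
        "Represent this question for searching relevant passages: what is pizza?"
    )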
@@ -140,240 +145,22 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
         """
-        embeddings = await self.aembed_documents([text])
+        query = (
+            f"{self.query_prompt_for_retrieval} {text}"
+            if self.query_prompt_for_retrieval
+            else text
+        )
+        embeddings = await self.aembed_documents([query])
         return embeddings[0]
 
 
 class TinyAsyncGradientEmbeddingClient:  #: :meta private:
-    """A helper tool to embed Gradient. Not part of Langchain's or Gradients stable API,
-    direct use discouraged.
-
-    To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
-    API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
-    or alternatively provide them as keywords to the constructor of this class.
-
-    Example:
-        .. code-block:: python
-
-
-            mini_client = TinyAsyncGradientEmbeddingClient(
-                workspace_id="12345614fc0_workspace",
-                access_token="gradientai-access_token",
-            )
-            embeds = mini_client.embed(
-                model="bge-large",
-                text=["doc1", "doc2"]
-            )
-            # or
-            embeds = await mini_client.aembed(
-                model="bge-large",
-                text=["doc1", "doc2"]
-            )
-    """
-
-    def __init__(
-        self,
-        access_token: Optional[str] = None,
-        workspace_id: Optional[str] = None,
-        host: str = "https://api.gradient.ai/api",
-        aiosession: Optional[aiohttp.ClientSession] = None,
-    ) -> None:
-        self.access_token = access_token or os.environ.get(
-            "GRADIENT_ACCESS_TOKEN", None
-        )
-        self.workspace_id = workspace_id or os.environ.get(
-            "GRADIENT_WORKSPACE_ID", None
-        )
-        self.host = host
-        self.aiosession = aiosession
-
-        if self.access_token is None or len(self.access_token) < 10:
-            raise ValueError(
-                "env variable `GRADIENT_ACCESS_TOKEN` or "
-                " param `access_token` must be set "
-            )
-
-        if self.workspace_id is None or len(self.workspace_id) < 3:
-            raise ValueError(
-                "env variable `GRADIENT_WORKSPACE_ID` or "
-                " param `workspace_id` must be set"
-            )
-
-        if self.host is None or len(self.host) < 3:
-            raise ValueError(" param `host` must be set to a valid url")
-        self._batch_size = 128
-
-    @staticmethod
-    def _permute(
-        texts: List[str], sorter: Callable = len
-    ) -> Tuple[List[str], Callable]:
-        """Sort texts in ascending order, and
-        delivers a lambda expr, which can sort a same length list
-        https://github.com/UKPLab/sentence-transformers/blob/
-        c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/SentenceTransformer.py#L156
-
-        Args:
-            texts (List[str]): _description_
-            sorter (Callable, optional): _description_. Defaults to len.
-
-        Returns:
-            Tuple[List[str], Callable]: _description_
-
-        Example:
-            ```
-            texts = ["one","three","four"]
-            perm_texts, undo = self._permute(texts)
-            texts == undo(perm_texts)
-            ```
-        """
-
-        if len(texts) == 1:
-            # special case query
-            return texts, lambda t: t
-        length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
-        texts_sorted = [texts[idx] for idx in length_sorted_idx]
-
-        return texts_sorted, lambda unsorted_embeddings: [  # noqa E731
-            unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
-        ]
-
-    def _batch(self, texts: List[str]) -> List[List[str]]:
-        """
-        splits Lists of text parts into batches of size max `self._batch_size`
-        When encoding vector database,
-
-        Args:
-            texts (List[str]): List of sentences
-            self._batch_size (int, optional): max batch size of one request.
-
-        Returns:
-            List[List[str]]: Batches of List of sentences
-        """
-        if len(texts) == 1:
-            # special case query
-            return [texts]
-        batches = []
-        for start_index in range(0, len(texts), self._batch_size):
-            batches.append(texts[start_index : start_index + self._batch_size])
-        return batches
-
-    @staticmethod
-    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
-        if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
-            # special case query
-            return batch_of_texts[0]
-        texts = []
-        for sublist in batch_of_texts:
-            texts.extend(sublist)
-        return texts
-
-    def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
-        """Build the kwargs for the Post request, used by sync
-
-        Args:
-            model (str): _description_
-            texts (List[str]): _description_
-
-        Returns:
-            Dict[str, Collection[str]]: _description_
-        """
-        return dict(
-            url=f"{self.host}/embeddings/{model}",
-            headers={
-                "authorization": f"Bearer {self.access_token}",
-                "x-gradient-workspace-id": f"{self.workspace_id}",
-                "accept": "application/json",
-                "content-type": "application/json",
-            },
-            json=dict(
-                inputs=[{"input": i} for i in texts],
-            ),
-        )
-
-    def _sync_request_embed(
-        self, model: str, batch_texts: List[str]
-    ) -> List[List[float]]:
-        response = requests.post(
-            **self._kwargs_post_request(model=model, texts=batch_texts)
-        )
-        if response.status_code != 200:
-            raise Exception(
-                f"Gradient returned an unexpected response with status "
-                f"{response.status_code}: {response.text}"
-            )
-        return [e["embedding"] for e in response.json()["embeddings"]]
-
-    def embed(self, model: str, texts: List[str]) -> List[List[float]]:
-        """call the embedding of model
-
-        Args:
-            model (str): to embedding model
-            texts (List[str]): List of sentences to embed.
-
-        Returns:
-            List[List[float]]: List of vectors for each sentence
-        """
-        perm_texts, unpermute_func = self._permute(texts)
-        perm_texts_batched = self._batch(perm_texts)
-
-        # Request
-        map_args = (
-            self._sync_request_embed,
-            [model] * len(perm_texts_batched),
-            perm_texts_batched,
-        )
-        if len(perm_texts_batched) == 1:
-            embeddings_batch_perm = list(map(*map_args))
-        else:
-            with ThreadPoolExecutor(32) as p:
-                embeddings_batch_perm = list(p.map(*map_args))
-
-        embeddings_perm = self._unbatch(embeddings_batch_perm)
-        embeddings = unpermute_func(embeddings_perm)
-        return embeddings
-
-    async def _async_request(
-        self, session: aiohttp.ClientSession, kwargs: Dict[str, Any]
-    ) -> List[List[float]]:
-        async with session.post(**kwargs) as response:
-            if response.status != 200:
-                raise Exception(
-                    f"Gradient returned an unexpected response with status "
-                    f"{response.status}: {response.text}"
-                )
-            embedding = (await response.json())["embeddings"]
-            return [e["embedding"] for e in embedding]
-
-    async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
-        """call the embedding of model, async method
-
-        Args:
-            model (str): to embedding model
-            texts (List[str]): List of sentences to embed.
-
-        Returns:
-            List[List[float]]: List of vectors for each sentence
-        """
-        perm_texts, unpermute_func = self._permute(texts)
-        perm_texts_batched = self._batch(perm_texts)
-
-        # Request
-        if self.aiosession is None:
-            self.aiosession = aiohttp.ClientSession(
-                trust_env=True, connector=aiohttp.TCPConnector(limit=32)
-            )
-        async with self.aiosession as session:
-            embeddings_batch_perm = await asyncio.gather(
-                *[
-                    self._async_request(
-                        session=session,
-                        **self._kwargs_post_request(model=model, texts=t),
-                    )
-                    for t in perm_texts_batched
-                ]
-            )
-
-        embeddings_perm = self._unbatch(embeddings_batch_perm)
-        embeddings = unpermute_func(embeddings_perm)
-        return embeddings
+    """Deprecated, TinyAsyncGradientEmbeddingClient was removed.
+
+    This class is just for backwards compatibility with older versions
+    of langchain_community.
+    It might be entirely removed in the future.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.")
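An aside on the deleted client: `_permute`/`_batch` sorted texts by length before batching (so each request held similar-length texts) and handed back an undo function restoring the caller's order, a trick its docstring credits to sentence-transformers. Its round-trip invariant checks out in isolation:

    import numpy as np

    def permute(texts):
        # Indices sorted by descending text length.
        order = np.argsort([-len(t) for t in texts])
        # argsort(order) is the inverse permutation: it maps results
        # computed in sorted order back to the caller's order.
        undo = lambda results: [results[i] for i in np.argsort(order)]
        return [texts[i] for i in order], undo

    texts = ["one", "three", "four"]
    perm_texts, undo = permute(texts)
    assert undo(perm_texts) == texts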
libs/community/poetry.lock (generated, 37 lines changed)
@@ -1,5 +1,17 @@
 # This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
 
+[[package]]
+name = "aenum"
+version = "3.1.15"
+description = "Advanced Enumerations (compatible with Python's stdlib Enum), NamedTuples, and NamedConstants"
+optional = true
+python-versions = "*"
+files = [
+    {file = "aenum-3.1.15-py2-none-any.whl", hash = "sha256:27b1710b9d084de6e2e695dab78fe9f269de924b51ae2850170ee7e1ca6288a5"},
+    {file = "aenum-3.1.15-py3-none-any.whl", hash = "sha256:e0dfaeea4c2bd362144b87377e2c61d91958c5ed0b4daf89cb6f45ae23af6288"},
+    {file = "aenum-3.1.15.tar.gz", hash = "sha256:8cbd76cd18c4f870ff39b24284d3ea028fbe8731a58df3aa581e434c575b9559"},
+]
+
 [[package]]
 name = "aiodns"
 version = "3.1.1"
@@ -2352,6 +2364,23 @@ test = ["aiofiles", "aiohttp (>=3.7.1,<3.9.0)", "botocore (>=1.21,<2)", "mock (=
 test-no-transport = ["aiofiles", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==6.2.5)", "pytest-asyncio (==0.16.0)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "vcrpy (==4.0.2)"]
 websockets = ["websockets (>=10,<11)", "websockets (>=9,<10)"]
 
+[[package]]
+name = "gradientai"
+version = "1.4.0"
+description = "Gradient AI API"
+optional = true
+python-versions = ">=3.8.1,<4.0.0"
+files = [
+    {file = "gradientai-1.4.0-py3-none-any.whl", hash = "sha256:58b74151e4bee534d438509303bcca3a9b84d17dafff31c206353489b54fcbfa"},
+    {file = "gradientai-1.4.0.tar.gz", hash = "sha256:98b9e0894530c6b7c675a113010dca7f7f7c399e02c46c0fb5532bf9fc1609f4"},
+]
+
+[package.dependencies]
+aenum = ">=3.1.11"
+pydantic = ">=1.10.5,<2.0.0"
+python-dateutil = ">=2.8.2"
+urllib3 = ">=1.25.3"
+
 [[package]]
 name = "graphql-core"
 version = "3.2.3"
@@ -3023,7 +3052,6 @@
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
     {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
-    {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
@@ -3421,7 +3449,7 @@
 
 [[package]]
 name = "langchain-core"
-version = "0.1.0"
+version = "0.1.1"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -5529,6 +5557,7 @@
     {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"},
     {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"},
     {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"},
+    {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"},
@@ -8480,9 +8509,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 
 [extras]
 cli = ["typer"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "e3bacf389a13d283c4dd29e3a673e1863826b4e98785c666fefc10cf714c2f6f"
+content-hash = "ab4b1efe33110b575d2fb65bd5ecb90e92d1bd83dd5eac87080e4d07268df72f"
libs/community/pyproject.toml

@@ -28,6 +28,7 @@ openai = {version = "<2", optional = true}
 arxiv = {version = "^1.4", optional = true}
 pypdf = {version = "^3.4.0", optional = true}
 aleph-alpha-client = {version="^2.15.0", optional = true}
+gradientai = {version="^1.4.0", optional = true}
 pgvector = {version = "^0.1.6", optional = true}
 atlassian-python-api = {version = "^3.36.0", optional=true}
 html2text = {version="^2020.1.16", optional=true}
@@ -203,6 +204,7 @@ extended_testing = [
     "telethon",
     "psychicapi",
     "gql",
+    "gradientai",
     "requests-toolbelt",
     "html2text",
     "numexpr",
libs/community/tests/unit_tests/embeddings/test_gradient_ai.py

@@ -1,7 +1,8 @@
-from typing import Dict
+import sys
+from typing import Any, Dict, List
+from unittest.mock import MagicMock, patch
 
 import pytest
-from pytest_mock import MockerFixture
 
 from langchain_community.embeddings import GradientEmbeddings
 
@@ -11,117 +12,93 @@ _GRADIENT_WORKSPACE_ID = "valid_workspace_12345"
 _GRADIENT_BASE_URL = "https://api.gradient.ai/api"
 _DOCUMENTS = [
     "pizza",
-    "another pizza",
+    "another long pizza",
     "a document",
-    "another long pizza",
+    "another pizza",
     "super long document with many tokens",
 ]
 
 
-class MockResponse:
-    def __init__(self, json_data: Dict, status_code: int):
-        self.json_data = json_data
-        self.status_code = status_code
-
-    def json(self) -> Dict:
-        return self.json_data
-
-
-def mocked_requests_post(
-    url: str,
-    headers: dict,
-    json: dict,
-) -> MockResponse:
-    assert url.startswith(_GRADIENT_BASE_URL)
-    assert _MODEL_ID in url
-    assert json
-    assert headers
-
-    assert headers.get("authorization") == f"Bearer {_GRADIENT_SECRET}"
-    assert headers.get("x-gradient-workspace-id") == f"{_GRADIENT_WORKSPACE_ID}"
-
-    assert "inputs" in json and "input" in json["inputs"][0]
-    embeddings = []
-    for inp in json["inputs"]:
-        # verify correct ordering
-        inp = inp["input"]
-        if "pizza" in inp:
-            v = [1.0, 0.0, 0.0]
-        elif "document" in inp:
-            v = [0.0, 0.9, 0.0]
-        else:
-            v = [0.0, 0.0, -1.0]
-        if len(inp) > 10:
-            v[2] += 0.1
-        embeddings.append({"embedding": v})
-
-    return MockResponse(
-        json_data={"embeddings": embeddings},
-        status_code=200,
-    )
-
-
-def test_gradient_llm_sync(
-    mocker: MockerFixture,
-) -> None:
-    mocker.patch("requests.post", side_effect=mocked_requests_post)
-
-    embedder = GradientEmbeddings(
-        gradient_api_url=_GRADIENT_BASE_URL,
-        gradient_access_token=_GRADIENT_SECRET,
-        gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
-        model=_MODEL_ID,
-    )
-    assert embedder.gradient_access_token == _GRADIENT_SECRET
-    assert embedder.gradient_api_url == _GRADIENT_BASE_URL
-    assert embedder.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
-    assert embedder.model == _MODEL_ID
-
-    response = embedder.embed_documents(_DOCUMENTS)
-    want = [
-        [1.0, 0.0, 0.0],  # pizza
-        [1.0, 0.0, 0.1],  # pizza + long
-        [0.0, 0.9, 0.0],  # doc
-        [1.0, 0.0, 0.1],  # pizza + long
-        [0.0, 0.9, 0.1],  # doc + long
-    ]
-
-    assert response == want
-
-
-def test_gradient_llm_large_batch_size(
-    mocker: MockerFixture,
-) -> None:
-    mocker.patch("requests.post", side_effect=mocked_requests_post)
-
-    embedder = GradientEmbeddings(
-        gradient_api_url=_GRADIENT_BASE_URL,
-        gradient_access_token=_GRADIENT_SECRET,
-        gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
-        model=_MODEL_ID,
-    )
-    assert embedder.gradient_access_token == _GRADIENT_SECRET
-    assert embedder.gradient_api_url == _GRADIENT_BASE_URL
-    assert embedder.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
-    assert embedder.model == _MODEL_ID
-
-    response = embedder.embed_documents(_DOCUMENTS * 1024)
-    want = [
-        [1.0, 0.0, 0.0],  # pizza
-        [1.0, 0.0, 0.1],  # pizza + long
-        [0.0, 0.9, 0.0],  # doc
-        [1.0, 0.0, 0.1],  # pizza + long
-        [0.0, 0.9, 0.1],  # doc + long
-    ] * 1024
-
-    assert response == want
-
-
-def test_gradient_wrong_setup(
-    mocker: MockerFixture,
-) -> None:
-    mocker.patch("requests.post", side_effect=mocked_requests_post)
-
+class GradientEmbeddingsModel(MagicMock):
+    """MockGradientModel."""
+
+    def embed(self, inputs: List[Dict[str, str]]) -> Any:
+        """Just duplicate the query m times."""
+        output = MagicMock()
+
+        embeddings = []
+        for i, inp in enumerate(inputs):
+            # verify correct ordering
+            inp = inp["input"]
+            if "pizza" in inp:
+                v = [1.0, 0.0, 0.0]
+            elif "document" in inp:
+                v = [0.0, 0.9, 0.0]
+            else:
+                v = [0.0, 0.0, -1.0]
+            if len(inp) > 10:
+                v[2] += 0.1
+            output_inner = MagicMock()
+            output_inner.embedding = v
+            embeddings.append(output_inner)
+
+        output.embeddings = embeddings
+        return output
+
+    async def aembed(self, *args) -> Any:
+        return self.embed(*args)
+
+
+class MockGradient(MagicMock):
+    """Mock Gradient package."""
+
+    def __init__(self, access_token: str, workspace_id, host):
+        assert access_token == _GRADIENT_SECRET
+        assert workspace_id == _GRADIENT_WORKSPACE_ID
+        assert host == _GRADIENT_BASE_URL
+
+    def get_embeddings_model(self, slug: str) -> GradientEmbeddingsModel:
+        assert slug == _MODEL_ID
+        return GradientEmbeddingsModel()
+
+    def close(self) -> None:
+        """Mock Gradient close."""
+        return
+
+
+class MockGradientaiPackage(MagicMock):
+    """Mock Gradientai package."""
+
+    Gradient = MockGradient
+    __version__ = "1.4.0"
+
+
+def test_gradient_llm_sync() -> None:
+    with patch.dict(sys.modules, {"gradientai": MockGradientaiPackage()}):
+        embedder = GradientEmbeddings(
+            gradient_api_url=_GRADIENT_BASE_URL,
+            gradient_access_token=_GRADIENT_SECRET,
+            gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
+            model=_MODEL_ID,
+        )
+        assert embedder.gradient_access_token == _GRADIENT_SECRET
+        assert embedder.gradient_api_url == _GRADIENT_BASE_URL
+        assert embedder.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
+        assert embedder.model == _MODEL_ID
+
+        response = embedder.embed_documents(_DOCUMENTS)
+        want = [
+            [1.0, 0.0, 0.0],  # pizza
+            [1.0, 0.0, 0.1],  # pizza + long
+            [0.0, 0.9, 0.0],  # doc
+            [1.0, 0.0, 0.1],  # pizza + long
+            [0.0, 0.9, 0.1],  # doc + long
+        ]
+
+        assert response == want
+
+
+def test_gradient_wrong_setup() -> None:
     with pytest.raises(Exception):
         GradientEmbeddings(
             gradient_api_url=_GRADIENT_BASE_URL,
@@ -130,6 +107,8 @@ def test_gradient_wrong_setup(
             model=_MODEL_ID,
         )
 
+
+def test_gradient_wrong_setup2() -> None:
     with pytest.raises(Exception):
         GradientEmbeddings(
             gradient_api_url=_GRADIENT_BASE_URL,
@@ -138,6 +117,8 @@ def test_gradient_wrong_setup(
             model=_MODEL_ID,
         )
 
+
+def test_gradient_wrong_setup3() -> None:
     with pytest.raises(Exception):
         GradientEmbeddings(
             gradient_api_url="-",  # empty
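The rewritten tests no longer intercept `requests.post`; they plant a fake `gradientai` package in `sys.modules`, so the deferred `import gradientai` inside the validator resolves to the mock. The mechanism in isolation (the fake module here is illustrative):

    import sys
    from unittest.mock import MagicMock, patch

    fake_pkg = MagicMock()
    fake_pkg.__version__ = "1.4.0"  # satisfies the version gate

    with patch.dict(sys.modules, {"gradientai": fake_pkg}):
        import gradientai  # resolves to fake_pkg while the patch is active

        assert gradientai.__version__ == "1.4.0"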
libs/langchain/langchain/embeddings/gradient_ai.py

@@ -1,6 +1,3 @@
-from langchain_community.embeddings.gradient_ai import (
-    GradientEmbeddings,
-    TinyAsyncGradientEmbeddingClient,
-)
+from langchain_community.embeddings.gradient_ai import GradientEmbeddings
 
-__all__ = ["GradientEmbeddings", "TinyAsyncGradientEmbeddingClient"]
+__all__ = ["GradientEmbeddings"]