mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 01:19:31 +00:00
community: ✨ Use new OVHcloud batch embedding (#26209)
- **Description:** Perform batch embedding on the server side instead of the client side.
- **Twitter handle:** @wildagsx

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
a54f390090
commit
4b8cd7a09a
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
@ -41,17 +42,55 @@ class OVHCloudEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
List[float]: Embeddings for the text.
|
List[float]: Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
return self._send_request_to_ai_endpoints("text/plain", text, "text2vec")
|
||||||
|
|
||||||
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Embed a list of documents.
|
||||||
|
Args:
|
||||||
|
texts (List[str]): The list of texts to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[List[float]]: List of embeddings, one for each input text.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self._send_request_to_ai_endpoints(
|
||||||
|
"application/json", json.dumps(texts), "batch_text2vec"
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> List[float]:
|
||||||
|
"""Embed a single query text.
|
||||||
|
Args:
|
||||||
|
text (str): The text to embed.
|
||||||
|
Returns:
|
||||||
|
List[float]: Embeddings for the text.
|
||||||
|
"""
|
||||||
|
return self._generate_embedding(text)
|
||||||
|
|
||||||
|
def _send_request_to_ai_endpoints(
|
||||||
|
self, contentType: str, payload: str, route: str
|
||||||
|
) -> Any:
|
||||||
|
"""Send a HTTPS request to OVHcloud AI Endpoints
|
||||||
|
Args:
|
||||||
|
contentType (str): The content type of the request, application/json or text/plain.
|
||||||
|
payload (str): The payload of the request.
|
||||||
|
route (str): The route of the request, batch_text2vec or text2vec.
|
||||||
|
""" # noqa: E501
|
||||||
headers = {
|
headers = {
|
||||||
"content-type": "text/plain",
|
"content-type": contentType,
|
||||||
"Authorization": f"Bearer {self.access_token}",
|
"Authorization": f"Bearer {self.access_token}",
|
||||||
}
|
}
|
||||||
|
|
||||||
session = requests.session()
|
session = requests.session()
|
||||||
while True:
|
while True:
|
||||||
response = session.post(
|
response = session.post(
|
||||||
f"https://{self.model_name}.endpoints.{self.region}.ai.cloud.ovh.net/api/text2vec",
|
(
|
||||||
|
f"https://{self.model_name}.endpoints.{self.region}"
|
||||||
|
f".ai.cloud.ovh.net/api/{route}"
|
||||||
|
),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
data=text,
|
data=payload,
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
if response.status_code == 429:
|
if response.status_code == 429:
|
||||||
@ -74,22 +113,3 @@ class OVHCloudEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
||||||
"""Create a retry decorator for PremAIEmbeddings.
|
|
||||||
Args:
|
|
||||||
texts (List[str]): The list of texts to embed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[List[float]]: List of embeddings, one for each input text.
|
|
||||||
"""
|
|
||||||
return [self._generate_embedding(text) for text in texts]
|
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
|
||||||
"""Embed a single query text.
|
|
||||||
Args:
|
|
||||||
text (str): The text to embed.
|
|
||||||
Returns:
|
|
||||||
List[float]: Embeddings for the text.
|
|
||||||
"""
|
|
||||||
return self._generate_embedding(text)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user