mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 11:09:07 +00:00
The OVHcloud AI Endpoints now reject, with a 401 status code, calls from users with expired or invalid tokens (previously, such calls were treated as anonymous). Therefore, the Authorization header must be omitted when there is no token. Related to: #23178 --------- Signed-off-by: Joffref <mariusjoffre@gmail.com>
99 lines
3.4 KiB
Python
99 lines
3.4 KiB
Python
import logging
|
|
import time
|
|
from typing import Any, List
|
|
|
|
import requests
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.pydantic_v1 import BaseModel, Extra
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)


class OVHCloudEmbeddings(BaseModel, Embeddings):
    """OVHcloud AI Endpoints Embeddings.

    Each text is POSTed as a plain-text body to
    ``https://{model_name}.endpoints.{region}.ai.cloud.ovh.net/api/text2vec``
    and the decoded JSON response is returned as its embedding.
    """

    # OVHcloud AI Endpoints access token, sent as a Bearer token.
    access_token: str = ""

    # OVHcloud AI Endpoints model name for embeddings generation; it is used
    # as the sub-domain of the endpoint URL.
    model_name: str = ""

    # OVHcloud AI Endpoints region; it is part of the endpoint host name.
    region: str = "kepler"

    class Config:
        """Configuration for this pydantic object."""

        # Reject unknown keyword arguments instead of silently ignoring them.
        extra = Extra.forbid

    def __init__(self, **kwargs: Any):
        """Populate the fields, then check that none of them is empty.

        Args:
            **kwargs: Field values (``access_token``, ``model_name``,
                ``region``); unknown keys are rejected because of
                ``Extra.forbid``.

        Raises:
            ValueError: If ``access_token``, ``model_name`` or ``region`` is
                empty.
        """
        super().__init__(**kwargs)
        if not self.access_token:
            raise ValueError("Access token is required for OVHCloud embeddings.")
        if not self.model_name:
            raise ValueError("Model name is required for OVHCloud embeddings.")
        if not self.region:
            raise ValueError("Region is required for OVHCloud embeddings.")

    def _generate_embedding(self, text: str) -> List[float]:
        """Generate the embedding of a single text with OVHcloud AI Endpoints.

        Rate-limited requests (HTTP 429) are retried: the call waits for the
        number of seconds given in the ``RateLimit-Reset`` response header
        when it is positive, and retries immediately otherwise.

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: Embeddings for the text (the decoded JSON response).

        Raises:
            ValueError: If the request is unauthorized (HTTP 401) or fails
                with any other non-200 status code.
        """
        url = (
            f"https://{self.model_name}.endpoints.{self.region}"
            ".ai.cloud.ovh.net/api/text2vec"
        )
        headers = {"content-type": "text/plain"}
        # Calls carrying an invalid token are rejected with a 401 instead of
        # being treated as anonymous, so only send the header when a token is
        # set. ``__init__`` requires one, but instances created or modified
        # without going through it may not have one.
        if self.access_token:
            headers["Authorization"] = f"Bearer {self.access_token}"
        # http.client encodes ``str`` bodies as ISO-8859-1 (and fails on other
        # characters) while urllib3 2.x encodes them as UTF-8; encoding here
        # makes the bytes sent independent of the installed library versions.
        payload = text.encode("utf-8")

        # NOTE(review): there is no request timeout and no cap on 429 retries,
        # so a call can block indefinitely; consider making both configurable.
        # The session is closed on exit rather than leaked on every call.
        with requests.Session() as session:
            while True:
                response = session.post(url, headers=headers, data=payload)
                if response.status_code == 200:
                    return response.json()

                if response.status_code == 429:
                    # Rate limit exceeded: wait for the reset if one is
                    # announced, otherwise retry immediately.
                    reset_time = int(response.headers.get("RateLimit-Reset", 0))
                    logger.info("Rate limit exceeded. Waiting %d seconds.", reset_time)
                    if reset_time > 0:
                        time.sleep(reset_time)
                    continue

                if response.status_code == 401:
                    raise ValueError("Unauthorized, retry with new token")

                raise ValueError(
                    "Request failed with status code: "
                    f"{response.status_code}, {response.text}"
                )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents, sending one request per text.

        Args:
            texts (List[str]): The list of texts to embed.

        Returns:
            List[List[float]]: List of embeddings, one for each input text,
            in the same order as ``texts``.
        """
        return [self._generate_embedding(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: Embeddings for the text.
        """
        return self._generate_embedding(text)