mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 15:03:21 +00:00
community: Fix OVHcloud 401 Unauthorized on embedding. (#23260)
They are now rejecting with code 401 calls from users with expired or invalid tokens (while before they were being considered anonymous). Thus, the authorization header has to be removed when there is no token. Related to: #23178 --------- Signed-off-by: Joffref <mariusjoffre@gmail.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List
|
||||
|
||||
import requests
|
||||
from langchain_core.embeddings import Embeddings
|
||||
@@ -11,14 +11,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class OVHCloudEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
Usage:
|
||||
OVH_AI_ENDPOINTS_ACCESS_TOKEN="your-token" python3 langchain_embedding.py
|
||||
NB: Make sure you are using a valid token.
|
||||
In the contrary, document indexing will be long due to rate-limiting.
|
||||
OVHcloud AI Endpoints Embeddings.
|
||||
"""
|
||||
|
||||
""" OVHcloud AI Endpoints Access Token"""
|
||||
access_token: Optional[str] = None
|
||||
access_token: str = ""
|
||||
|
||||
""" OVHcloud AI Endpoints model name for embeddings generation"""
|
||||
model_name: str = ""
|
||||
@@ -33,10 +30,8 @@ class OVHCloudEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
if self.access_token is None:
|
||||
logger.warning(
|
||||
"No access token provided indexing will be slow due to rate limiting."
|
||||
)
|
||||
if self.access_token == "":
|
||||
raise ValueError("Access token is required for OVHCloud embeddings.")
|
||||
if self.model_name == "":
|
||||
raise ValueError("Model name is required for OVHCloud embeddings.")
|
||||
if self.region == "":
|
||||
@@ -72,7 +67,9 @@ class OVHCloudEmbeddings(BaseModel, Embeddings):
|
||||
else:
|
||||
"""Rate limit reset time has passed, retry immediately"""
|
||||
continue
|
||||
|
||||
if response.status_code == 401:
|
||||
""" Unauthorized, retry with new token """
|
||||
raise ValueError("Unauthorized, retry with new token")
|
||||
""" Handle other non-200 status codes """
|
||||
raise ValueError(
|
||||
"Request failed with status code: {status_code}, {text}".format(
|
||||
|
@@ -1,8 +0,0 @@
|
||||
from langchain_community.embeddings.ovhcloud import OVHCloudEmbeddings
|
||||
|
||||
|
||||
def test_ovhcloud_embed_documents() -> None:
|
||||
llm = OVHCloudEmbeddings(model_name="multilingual-e5-base")
|
||||
docs = ["Hello", "World"]
|
||||
output = llm.embed_documents(docs)
|
||||
assert len(output) == len(docs)
|
@@ -4,22 +4,28 @@ from langchain_community.embeddings.ovhcloud import OVHCloudEmbeddings
|
||||
|
||||
|
||||
def test_ovhcloud_correct_instantiation() -> None:
|
||||
llm = OVHCloudEmbeddings(model_name="multilingual-e5-base")
|
||||
llm = OVHCloudEmbeddings(model_name="multilingual-e5-base", access_token="token")
|
||||
assert isinstance(llm, OVHCloudEmbeddings)
|
||||
llm = OVHCloudEmbeddings(
|
||||
model_name="multilingual-e5-base", region="kepler", access_token="token"
|
||||
)
|
||||
assert isinstance(llm, OVHCloudEmbeddings)
|
||||
|
||||
|
||||
def test_ovhcloud_empty_model_name_should_raise_error() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
OVHCloudEmbeddings(model_name="")
|
||||
OVHCloudEmbeddings(model_name="", region="kepler", access_token="token")
|
||||
|
||||
|
||||
def test_ovhcloud_empty_region_should_raise_error() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
OVHCloudEmbeddings(model_name="multilingual-e5-base", region="")
|
||||
OVHCloudEmbeddings(
|
||||
model_name="multilingual-e5-base", region="", access_token="token"
|
||||
)
|
||||
|
||||
|
||||
def test_ovhcloud_empty_access_token_should_not_raise_error() -> None:
|
||||
llm = OVHCloudEmbeddings(
|
||||
model_name="multilingual-e5-base", region="kepler", access_token=""
|
||||
)
|
||||
assert isinstance(llm, OVHCloudEmbeddings)
|
||||
def test_ovhcloud_empty_access_token_should_raise_error() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
OVHCloudEmbeddings(
|
||||
model_name="multilingual-e5-base", region="kepler", access_token=""
|
||||
)
|
||||
|
Reference in New Issue
Block a user