openai[patch]: Allow disablling safe_len_embeddings(OpenAIEmbeddings) (#19743)

OpenAI API compatible server may not support `safe_len_embedding`, 

use `disable_safe_len_embeddings=True` to disable it.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
YISH 2024-04-26 00:45:52 +08:00 committed by GitHub
parent 5b83130855
commit ed26149a29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -129,6 +129,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
check_embedding_ctx_length: bool = True
"""Whether to check the token length of inputs and automatically split inputs
longer than embedding_ctx_length."""
class Config:
"""Configuration for this pydantic object."""
@ -511,6 +514,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = []
for text in texts:
response = self.client.create(
input=text,
**self._invocation_params,
)
if not isinstance(response, dict):
response = response.dict()
embeddings.extend(r["embedding"] for r in response["data"])
return embeddings
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
engine = cast(str, self.deployment)
@ -529,6 +544,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = []
for text in texts:
response = await self.async_client.create(
input=text,
**self._invocation_params,
)
if not isinstance(response, dict):
response = response.dict()
embeddings.extend(r["embedding"] for r in response["data"])
return embeddings
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
engine = cast(str, self.deployment)