mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 00:47:27 +00:00
openai[patch]: Allow disabling safe_len_embeddings (OpenAIEmbeddings) (#19743)
An OpenAI-API-compatible server may not support `safe_len_embedding`; set `check_embedding_ctx_length=False` to disable it. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
5b83130855
commit
ed26149a29
@ -129,6 +129,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
http_async_client: Union[Any, None] = None
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
http_client as well if you'd like a custom client for sync invocations."""
|
||||
check_embedding_ctx_length: bool = True
|
||||
"""Whether to check the token length of inputs and automatically split inputs
|
||||
longer than embedding_ctx_length."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -511,6 +514,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
if not self.check_embedding_ctx_length:
|
||||
embeddings: List[List[float]] = []
|
||||
for text in texts:
|
||||
response = self.client.create(
|
||||
input=text,
|
||||
**self._invocation_params,
|
||||
)
|
||||
if not isinstance(response, dict):
|
||||
response = response.dict()
|
||||
embeddings.extend(r["embedding"] for r in response["data"])
|
||||
return embeddings
|
||||
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
engine = cast(str, self.deployment)
|
||||
@ -529,6 +544,18 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
if not self.check_embedding_ctx_length:
|
||||
embeddings: List[List[float]] = []
|
||||
for text in texts:
|
||||
response = await self.async_client.create(
|
||||
input=text,
|
||||
**self._invocation_params,
|
||||
)
|
||||
if not isinstance(response, dict):
|
||||
response = response.dict()
|
||||
embeddings.extend(r["embedding"] for r in response["data"])
|
||||
return embeddings
|
||||
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
engine = cast(str, self.deployment)
|
||||
|
Loading…
Reference in New Issue
Block a user