mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 17:07:25 +00:00
Clean up OpenAI Embeddings to fix method name and comments (#2687)
**Problem:** OpenAI Embeddings has a few minor issues: the method name and comment for `_completion_with_retry` seem to be copy-paste errors, and a few comments around the usage of `embedding_ctx_length` seem to be incorrect. **Solution:** Clean up these issues. --------- Co-authored-by: Vijay Rajaram <vrajaram3@gatech.edu>
This commit is contained in:
parent
ad3c5dd186
commit
28bef6f87d
@ -43,14 +43,14 @@ def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any
|
|||||||
|
|
||||||
|
|
||||||
def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
|
def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
|
||||||
"""Use tenacity to retry the completion call."""
|
"""Use tenacity to retry the embedding call."""
|
||||||
retry_decorator = _create_retry_decorator(embeddings)
|
retry_decorator = _create_retry_decorator(embeddings)
|
||||||
|
|
||||||
@retry_decorator
|
@retry_decorator
|
||||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
def _embed_with_retry(**kwargs: Any) -> Any:
|
||||||
return embeddings.client.create(**kwargs)
|
return embeddings.client.create(**kwargs)
|
||||||
|
|
||||||
return _completion_with_retry(**kwargs)
|
return _embed_with_retry(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIEmbeddings(BaseModel, Embeddings):
|
class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||||
@ -231,10 +231,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
|
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
|
||||||
"""Call out to OpenAI's embedding endpoint."""
|
"""Call out to OpenAI's embedding endpoint."""
|
||||||
# replace newlines, which can negatively affect performance.
|
# handle large input text
|
||||||
if self.embedding_ctx_length > 0:
|
if self.embedding_ctx_length > 0:
|
||||||
return self._get_len_safe_embeddings([text], engine=engine)[0]
|
return self._get_len_safe_embeddings([text], engine=engine)[0]
|
||||||
else:
|
else:
|
||||||
|
# replace newlines, which can negatively affect performance.
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
return embed_with_retry(self, input=[text], engine=engine)["data"][0][
|
return embed_with_retry(self, input=[text], engine=engine)["data"][0][
|
||||||
"embedding"
|
"embedding"
|
||||||
@ -253,7 +254,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
# handle large batches of texts
|
# handle batches of large input text
|
||||||
if self.embedding_ctx_length > 0:
|
if self.embedding_ctx_length > 0:
|
||||||
return self._get_len_safe_embeddings(texts, engine=self.document_model_name)
|
return self._get_len_safe_embeddings(texts, engine=self.document_model_name)
|
||||||
else:
|
else:
|
||||||
@ -275,7 +276,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
text: The text to embed.
|
text: The text to embed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Embeddings for the text.
|
Embedding for the text.
|
||||||
"""
|
"""
|
||||||
embedding = self._embedding_func(text, engine=self.query_model_name)
|
embedding = self._embedding_func(text, engine=self.query_model_name)
|
||||||
return embedding
|
return embedding
|
||||||
|
Loading…
Reference in New Issue
Block a user