This commit is contained in:
Bagatur
2023-10-13 15:38:35 -07:00
parent 26b66a59fa
commit 7dda1bf45a
5 changed files with 35 additions and 35 deletions

View File

@@ -40,12 +40,12 @@ def _create_retry_decorator(embeddings: DashScopeEmbeddings) -> Callable[[Any],
)
def embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)
@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
def __embed_with_retry(**kwargs: Any) -> Any:
resp = embeddings.client.call(**kwargs)
if resp.status_code == 200:
return resp.output["embeddings"]
@@ -61,7 +61,7 @@ def embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any:
response=resp,
)
return _embed_with_retry(**kwargs)
return __embed_with_retry(**kwargs)
class DashScopeEmbeddings(BaseModel, Embeddings):
@@ -135,7 +135,7 @@ class DashScopeEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
embeddings = embed_with_retry(
embeddings = _embed_with_retry(
self, input=texts, text_type="document", model=self.model
)
embedding_list = [item["embedding"] for item in embeddings]
@@ -150,7 +150,7 @@ class DashScopeEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
embedding = embed_with_retry(
embedding = _embed_with_retry(
self, input=text, text_type="query", model=self.model
)[0]["embedding"]
return embedding

View File

@@ -40,17 +40,17 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
)
def embed_with_retry(
def _embed_with_retry(
embeddings: GooglePalmEmbeddings, *args: Any, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator()
@retry_decorator
def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
def __embed_with_retry(*args: Any, **kwargs: Any) -> Any:
return embeddings.client.generate_embeddings(*args, **kwargs)
return _embed_with_retry(*args, **kwargs)
return __embed_with_retry(*args, **kwargs)
class GooglePalmEmbeddings(BaseModel, Embeddings):
@@ -83,5 +83,5 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
embedding = embed_with_retry(self, self.model_name, text)
embedding = _embed_with_retry(self, self.model_name, text)
return embedding["embedding"]

View File

@@ -94,27 +94,27 @@ def _check_response(response: dict) -> dict:
return response
def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)
@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
def __embed_with_retry(**kwargs: Any) -> Any:
response = embeddings.client.create(**kwargs)
return _check_response(response)
return _embed_with_retry(**kwargs)
return __embed_with_retry(**kwargs)
async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
async def _async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
@_async_retry_decorator(embeddings)
async def _async_embed_with_retry(**kwargs: Any) -> Any:
async def __async_embed_with_retry(**kwargs: Any) -> Any:
response = await embeddings.client.acreate(**kwargs)
return _check_response(response)
return await _async_embed_with_retry(**kwargs)
return await __async_embed_with_retry(**kwargs)
class LocalAIEmbeddings(BaseModel, Embeddings):
@@ -265,13 +265,13 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return embed_with_retry(
return _embed_with_retry(
self,
input=[text],
**self._invocation_params,
)["data"][
0
]["embedding"]
)[
"data"
][0]["embedding"]
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint."""
@@ -281,7 +281,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (
await async_embed_with_retry(
await _async_embed_with_retry(
self,
input=[text],
**self._invocation_params,

View File

@@ -34,15 +34,15 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
)
def embed_with_retry(embeddings: MiniMaxEmbeddings, *args: Any, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: MiniMaxEmbeddings, *args: Any, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator()
@retry_decorator
def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
def __embed_with_retry(*args: Any, **kwargs: Any) -> Any:
return embeddings.embed(*args, **kwargs)
return _embed_with_retry(*args, **kwargs)
return __embed_with_retry(*args, **kwargs)
class MiniMaxEmbeddings(BaseModel, Embeddings):
@@ -144,7 +144,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
embeddings = embed_with_retry(self, texts=texts, embed_type=self.embed_type_db)
embeddings = _embed_with_retry(self, texts=texts, embed_type=self.embed_type_db)
return embeddings
def embed_query(self, text: str) -> List[float]:
@@ -156,7 +156,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
embeddings = embed_with_retry(
embeddings = _embed_with_retry(
self, texts=[text], embed_type=self.embed_type_query
)
return embeddings[0]

View File

@@ -95,27 +95,27 @@ def _check_response(response: dict, skip_empty: bool = False) -> dict:
return response
def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
def _embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)
@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
def __embed_with_retry(**kwargs: Any) -> Any:
response = embeddings.client.create(**kwargs)
return _check_response(response, skip_empty=embeddings.skip_empty)
return _embed_with_retry(**kwargs)
return __embed_with_retry(**kwargs)
async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
async def _async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
@_async_retry_decorator(embeddings)
async def _async_embed_with_retry(**kwargs: Any) -> Any:
async def __async_embed_with_retry(**kwargs: Any) -> Any:
response = await embeddings.client.acreate(**kwargs)
return _check_response(response, skip_empty=embeddings.skip_empty)
return await _async_embed_with_retry(**kwargs)
return await __async_embed_with_retry(**kwargs)
class OpenAIEmbeddings(BaseModel, Embeddings):
@@ -371,7 +371,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
_iter = range(0, len(tokens), _chunk_size)
for i in _iter:
response = embed_with_retry(
response = _embed_with_retry(
self,
input=tokens[i : i + _chunk_size],
**self._invocation_params,
@@ -389,7 +389,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
for i in range(len(texts)):
_result = results[i]
if len(_result) == 0:
average = embed_with_retry(
average = _embed_with_retry(
self,
input="",
**self._invocation_params,
@@ -443,7 +443,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
batched_embeddings: List[List[float]] = []
_chunk_size = chunk_size or self.chunk_size
for i in range(0, len(tokens), _chunk_size):
response = await async_embed_with_retry(
response = await _async_embed_with_retry(
self,
input=tokens[i : i + _chunk_size],
**self._invocation_params,
@@ -460,7 +460,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
_result = results[i]
if len(_result) == 0:
average = (
await async_embed_with_retry(
await _async_embed_with_retry(
self,
input="",
**self._invocation_params,