IMPROVEMENT: add input_type to VoyageEmbeddings (#13488)

- **Description:** add input_type to VoyageEmbeddings
This commit is contained in:
Yujie Qian 2023-11-16 16:35:36 -08:00 committed by GitHub
parent ea6e017b85
commit 41a433fa33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -101,17 +101,21 @@ class VoyageEmbeddings(BaseModel, Embeddings):
) )
return values return values
def _invocation_params(self, input: List[str]) -> Dict: def _invocation_params(
self, input: List[str], input_type: Optional[str] = None
) -> Dict:
api_key = cast(SecretStr, self.voyage_api_key).get_secret_value() api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
params = { params = {
"url": self.voyage_api_base, "url": self.voyage_api_base,
"headers": {"Authorization": f"Bearer {api_key}"}, "headers": {"Authorization": f"Bearer {api_key}"},
"json": {"model": self.model, "input": input}, "json": {"model": self.model, "input": input, "input_type": input_type},
"timeout": self.request_timeout, "timeout": self.request_timeout,
} }
return params return params
def _get_embeddings(self, texts: List[str], batch_size: int) -> List[List[float]]: def _get_embeddings(
self, texts: List[str], batch_size: int, input_type: Optional[str] = None
) -> List[List[float]]:
embeddings: List[List[float]] = [] embeddings: List[List[float]] = []
if self.show_progress_bar: if self.show_progress_bar:
@ -127,9 +131,18 @@ class VoyageEmbeddings(BaseModel, Embeddings):
else: else:
_iter = range(0, len(texts), batch_size) _iter = range(0, len(texts), batch_size)
if input_type and input_type not in ["query", "document"]:
raise ValueError(
f"input_type {input_type} is invalid. Options: None, 'query', "
"'document'."
)
for i in _iter: for i in _iter:
response = embed_with_retry( response = embed_with_retry(
self, **self._invocation_params(input=texts[i : i + batch_size]) self,
**self._invocation_params(
input=texts[i : i + batch_size], input_type=input_type
),
) )
embeddings.extend(r["embedding"] for r in response["data"]) embeddings.extend(r["embedding"] for r in response["data"])
@ -144,7 +157,9 @@ class VoyageEmbeddings(BaseModel, Embeddings):
Returns: Returns:
List of embeddings, one for each text. List of embeddings, one for each text.
""" """
return self._get_embeddings(texts, batch_size=self.batch_size) return self._get_embeddings(
texts, batch_size=self.batch_size, input_type="document"
)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Call out to Voyage Embedding endpoint for embedding query text. """Call out to Voyage Embedding endpoint for embedding query text.
@ -155,4 +170,6 @@ class VoyageEmbeddings(BaseModel, Embeddings):
Returns: Returns:
Embedding for the text. Embedding for the text.
""" """
return self.embed_documents([text])[0] return self._get_embeddings(
[text], batch_size=self.batch_size, input_type="query"
)[0]