mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 14:43:08 +00:00
IMPROVEMENT: add input_type to VoyageEmbeddings (#13488)
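- **Description:** add `input_type` to `VoyageEmbeddings`. The value is now sent as the `input_type` field of the embedding request: `embed_documents` sends `"document"`, `embed_query` sends `"query"`, and the internal helpers default to `None`. Any other non-empty value is rejected with a `ValueError`.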
This commit is contained in:
parent
ea6e017b85
commit
41a433fa33
@ -101,17 +101,21 @@ class VoyageEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _invocation_params(self, input: List[str]) -> Dict:
|
def _invocation_params(
|
||||||
|
self, input: List[str], input_type: Optional[str] = None
|
||||||
|
) -> Dict:
|
||||||
api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
|
api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
|
||||||
params = {
|
params = {
|
||||||
"url": self.voyage_api_base,
|
"url": self.voyage_api_base,
|
||||||
"headers": {"Authorization": f"Bearer {api_key}"},
|
"headers": {"Authorization": f"Bearer {api_key}"},
|
||||||
"json": {"model": self.model, "input": input},
|
"json": {"model": self.model, "input": input, "input_type": input_type},
|
||||||
"timeout": self.request_timeout,
|
"timeout": self.request_timeout,
|
||||||
}
|
}
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def _get_embeddings(self, texts: List[str], batch_size: int) -> List[List[float]]:
|
def _get_embeddings(
|
||||||
|
self, texts: List[str], batch_size: int, input_type: Optional[str] = None
|
||||||
|
) -> List[List[float]]:
|
||||||
embeddings: List[List[float]] = []
|
embeddings: List[List[float]] = []
|
||||||
|
|
||||||
if self.show_progress_bar:
|
if self.show_progress_bar:
|
||||||
@ -127,9 +131,18 @@ class VoyageEmbeddings(BaseModel, Embeddings):
|
|||||||
else:
|
else:
|
||||||
_iter = range(0, len(texts), batch_size)
|
_iter = range(0, len(texts), batch_size)
|
||||||
|
|
||||||
|
if input_type and input_type not in ["query", "document"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"input_type {input_type} is invalid. Options: None, 'query', "
|
||||||
|
"'document'."
|
||||||
|
)
|
||||||
|
|
||||||
for i in _iter:
|
for i in _iter:
|
||||||
response = embed_with_retry(
|
response = embed_with_retry(
|
||||||
self, **self._invocation_params(input=texts[i : i + batch_size])
|
self,
|
||||||
|
**self._invocation_params(
|
||||||
|
input=texts[i : i + batch_size], input_type=input_type
|
||||||
|
),
|
||||||
)
|
)
|
||||||
embeddings.extend(r["embedding"] for r in response["data"])
|
embeddings.extend(r["embedding"] for r in response["data"])
|
||||||
|
|
||||||
@ -144,7 +157,9 @@ class VoyageEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
return self._get_embeddings(texts, batch_size=self.batch_size)
|
return self._get_embeddings(
|
||||||
|
texts, batch_size=self.batch_size, input_type="document"
|
||||||
|
)
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Call out to Voyage Embedding endpoint for embedding query text.
|
"""Call out to Voyage Embedding endpoint for embedding query text.
|
||||||
@ -155,4 +170,6 @@ class VoyageEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
Embedding for the text.
|
Embedding for the text.
|
||||||
"""
|
"""
|
||||||
return self.embed_documents([text])[0]
|
return self._get_embeddings(
|
||||||
|
[text], batch_size=self.batch_size, input_type="query"
|
||||||
|
)[0]
|
||||||
|
Loading…
Reference in New Issue
Block a user