mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 04:25:46 +00:00
community[patch]: add VoyageEmbeddings truncation (#17638)
This commit is contained in:
parent
d7c26c89b2
commit
a058c8812d
@ -86,6 +86,15 @@ class VoyageEmbeddings(BaseModel, Embeddings):
|
||||
show_progress_bar: bool = False
|
||||
"""Whether to show a progress bar when embedding. Must have tqdm installed if set
|
||||
to True."""
|
||||
truncation: Optional[bool] = None
|
||||
"""Whether to truncate the input texts to fit within the context length.
|
||||
|
||||
If True, over-length input texts will be truncated to fit within the context
|
||||
length, before vectorized by the embedding model. If False, an error will be
|
||||
raised if any given text exceeds the context length. If not specified
|
||||
(defaults to None), we will truncate the input text before sending it to the
|
||||
embedding model if it slightly exceeds the context window length. If it
|
||||
significantly exceeds the context window length, an error will be raised."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -104,12 +113,14 @@ class VoyageEmbeddings(BaseModel, Embeddings):
|
||||
self, input: List[str], input_type: Optional[str] = None
|
||||
) -> Dict:
|
||||
api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
|
||||
params = {
|
||||
params: Dict = {
|
||||
"url": self.voyage_api_base,
|
||||
"headers": {"Authorization": f"Bearer {api_key}"},
|
||||
"json": {"model": self.model, "input": input, "input_type": input_type},
|
||||
"timeout": self.request_timeout,
|
||||
}
|
||||
if self.truncation is not None:
|
||||
params["json"]["truncation"] = self.truncation
|
||||
return params
|
||||
|
||||
def _get_embeddings(
|
||||
|
Loading…
Reference in New Issue
Block a user