community[patch]: Voyage AI updates default model and batch size (#17655)

- **Description:** update the default model and batch size in
VoyageEmbeddings
    - **Issue:** N/A
    - **Dependencies:** N/A
    - **Twitter handle:** N/A

---------

Co-authored-by: fodizoltan <zoltan@conway.expert>
This commit is contained in:
Yujie Qian 2024-03-01 10:22:24 -08:00 committed by GitHub
parent ae471a7dcb
commit cbb65741a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 15 deletions

View File

@ -27,7 +27,7 @@
"id": "137cfde9-b88c-409a-9394-a9e31a6bf30d",
"metadata": {},
"source": [
"Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key."
"Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key. Please refer to the documentation for further details on the available models: https://docs.voyageai.com/embeddings/"
]
},
{
@ -37,7 +37,9 @@
"metadata": {},
"outputs": [],
"source": [
"embeddings = VoyageEmbeddings(voyage_api_key=\"[ Your Voyage API key ]\")"
"embeddings = VoyageEmbeddings(\n",
" voyage_api_key=\"[ Your Voyage API key ]\", model=\"voyage-2\"\n",
")"
]
},
{

View File

@ -69,15 +69,15 @@ class VoyageEmbeddings(BaseModel, Embeddings):
from langchain_community.embeddings import VoyageEmbeddings
voyage = VoyageEmbeddings(voyage_api_key="your-api-key")
voyage = VoyageEmbeddings(voyage_api_key="your-api-key", model="voyage-2")
text = "This is a test query."
query_result = voyage.embed_query(text)
"""
model: str = "voyage-01"
model: str
voyage_api_base: str = "https://api.voyageai.com/v1/embeddings"
voyage_api_key: Optional[SecretStr] = None
batch_size: int = 8
batch_size: int
"""Maximum number of texts to embed in each API request."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
@ -86,15 +86,12 @@ class VoyageEmbeddings(BaseModel, Embeddings):
show_progress_bar: bool = False
"""Whether to show a progress bar when embedding. Must have tqdm installed if set
to True."""
truncation: Optional[bool] = None
truncation: bool = True
"""Whether to truncate the input texts to fit within the context length.
If True, over-length input texts will be truncated to fit within the context
length, before vectorized by the embedding model. If False, an error will be
raised if any given text exceeds the context length. If not specified
(defaults to None), we will truncate the input text before sending it to the
embedding model if it slightly exceeds the context window length. If it
significantly exceeds the context window length, an error will be raised."""
raised if any given text exceeds the context length."""
class Config:
"""Configuration for this pydantic object."""
@ -107,6 +104,22 @@ class VoyageEmbeddings(BaseModel, Embeddings):
values["voyage_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY")
)
if "model" not in values:
values["model"] = "voyage-01"
logger.warning(
"model will become a required arg for VoyageAIEmbeddings, "
"we recommend to specify it when using this class. "
"Currently the default is set to voyage-01."
)
if "batch_size" not in values:
values["batch_size"] = (
72
if "model" in values and (values["model"] in ["voyage-2", "voyage-02"])
else 7
)
return values
def _invocation_params(
@ -116,11 +129,14 @@ class VoyageEmbeddings(BaseModel, Embeddings):
params: Dict = {
"url": self.voyage_api_base,
"headers": {"Authorization": f"Bearer {api_key}"},
"json": {"model": self.model, "input": input, "input_type": input_type},
"json": {
"model": self.model,
"input": input,
"input_type": input_type,
"truncation": self.truncation,
},
"timeout": self.request_timeout,
}
if self.truncation is not None:
params["json"]["truncation"] = self.truncation
return params
def _get_embeddings(
@ -186,7 +202,9 @@ class VoyageEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
return self._get_embeddings([text], input_type="query")[0]
return self._get_embeddings(
[text], batch_size=self.batch_size, input_type="query"
)[0]
def embed_general_texts(
self, texts: List[str], *, input_type: Optional[str] = None

View File

@ -2,7 +2,7 @@
from langchain_community.embeddings.voyageai import VoyageEmbeddings
# Please set VOYAGE_API_KEY in the environment variables
MODEL = "voyage-01"
MODEL = "voyage-2"
def test_voyagi_embedding_documents() -> None:
@ -14,10 +14,22 @@ def test_voyagi_embedding_documents() -> None:
assert len(output[0]) == 1024
def test_voyagi_with_default_model() -> None:
"""Test voyage embeddings."""
embedding = VoyageEmbeddings()
assert embedding.model == "voyage-01"
assert embedding.batch_size == 7
documents = [f"foo bar {i}" for i in range(72)]
output = embedding.embed_documents(documents)
assert len(output) == 72
assert len(output[0]) == 1024
def test_voyage_embedding_documents_multiple() -> None:
"""Test voyage embeddings."""
documents = ["foo bar", "bar foo", "foo"]
embedding = VoyageEmbeddings(model=MODEL, batch_size=2)
assert embedding.model == MODEL
output = embedding.embed_documents(documents)
assert len(output) == 3
assert len(output[0]) == 1024