community[patch]: Voyage AI updates default model and batch size (#17655)

- **Description:** Update the default model and batch size in
  `VoyageEmbeddings`.
- **Issue:** N/A
- **Dependencies:** N/A
- **Twitter handle:** N/A

---------

Co-authored-by: fodizoltan <zoltan@conway.expert>
This commit is contained in:
Yujie Qian 2024-03-01 10:22:24 -08:00 committed by GitHub
parent ae471a7dcb
commit cbb65741a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 15 deletions

View File

@ -27,7 +27,7 @@
"id": "137cfde9-b88c-409a-9394-a9e31a6bf30d", "id": "137cfde9-b88c-409a-9394-a9e31a6bf30d",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key." "Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key. Please refer to the documentation for further details on the available models: https://docs.voyageai.com/embeddings/"
] ]
}, },
{ {
@ -37,7 +37,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"embeddings = VoyageEmbeddings(voyage_api_key=\"[ Your Voyage API key ]\")" "embeddings = VoyageEmbeddings(\n",
" voyage_api_key=\"[ Your Voyage API key ]\", model=\"voyage-2\"\n",
")"
] ]
}, },
{ {

View File

@ -69,15 +69,15 @@ class VoyageEmbeddings(BaseModel, Embeddings):
from langchain_community.embeddings import VoyageEmbeddings from langchain_community.embeddings import VoyageEmbeddings
voyage = VoyageEmbeddings(voyage_api_key="your-api-key") voyage = VoyageEmbeddings(voyage_api_key="your-api-key", model="voyage-2")
text = "This is a test query." text = "This is a test query."
query_result = voyage.embed_query(text) query_result = voyage.embed_query(text)
""" """
model: str = "voyage-01" model: str
voyage_api_base: str = "https://api.voyageai.com/v1/embeddings" voyage_api_base: str = "https://api.voyageai.com/v1/embeddings"
voyage_api_key: Optional[SecretStr] = None voyage_api_key: Optional[SecretStr] = None
batch_size: int = 8 batch_size: int
"""Maximum number of texts to embed in each API request.""" """Maximum number of texts to embed in each API request."""
max_retries: int = 6 max_retries: int = 6
"""Maximum number of retries to make when generating.""" """Maximum number of retries to make when generating."""
@ -86,15 +86,12 @@ class VoyageEmbeddings(BaseModel, Embeddings):
show_progress_bar: bool = False show_progress_bar: bool = False
"""Whether to show a progress bar when embedding. Must have tqdm installed if set """Whether to show a progress bar when embedding. Must have tqdm installed if set
to True.""" to True."""
truncation: Optional[bool] = None truncation: bool = True
"""Whether to truncate the input texts to fit within the context length. """Whether to truncate the input texts to fit within the context length.
If True, over-length input texts will be truncated to fit within the context If True, over-length input texts will be truncated to fit within the context
length, before vectorized by the embedding model. If False, an error will be length, before vectorized by the embedding model. If False, an error will be
raised if any given text exceeds the context length. If not specified raised if any given text exceeds the context length."""
(defaults to None), we will truncate the input text before sending it to the
embedding model if it slightly exceeds the context window length. If it
significantly exceeds the context window length, an error will be raised."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -107,6 +104,22 @@ class VoyageEmbeddings(BaseModel, Embeddings):
values["voyage_api_key"] = convert_to_secret_str( values["voyage_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY") get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY")
) )
if "model" not in values:
values["model"] = "voyage-01"
logger.warning(
"model will become a required arg for VoyageAIEmbeddings, "
"we recommend to specify it when using this class. "
"Currently the default is set to voyage-01."
)
if "batch_size" not in values:
values["batch_size"] = (
72
if "model" in values and (values["model"] in ["voyage-2", "voyage-02"])
else 7
)
return values return values
def _invocation_params( def _invocation_params(
@ -116,11 +129,14 @@ class VoyageEmbeddings(BaseModel, Embeddings):
params: Dict = { params: Dict = {
"url": self.voyage_api_base, "url": self.voyage_api_base,
"headers": {"Authorization": f"Bearer {api_key}"}, "headers": {"Authorization": f"Bearer {api_key}"},
"json": {"model": self.model, "input": input, "input_type": input_type}, "json": {
"model": self.model,
"input": input,
"input_type": input_type,
"truncation": self.truncation,
},
"timeout": self.request_timeout, "timeout": self.request_timeout,
} }
if self.truncation is not None:
params["json"]["truncation"] = self.truncation
return params return params
def _get_embeddings( def _get_embeddings(
@ -186,7 +202,9 @@ class VoyageEmbeddings(BaseModel, Embeddings):
Returns: Returns:
Embedding for the text. Embedding for the text.
""" """
return self._get_embeddings([text], input_type="query")[0] return self._get_embeddings(
[text], batch_size=self.batch_size, input_type="query"
)[0]
def embed_general_texts( def embed_general_texts(
self, texts: List[str], *, input_type: Optional[str] = None self, texts: List[str], *, input_type: Optional[str] = None

View File

@ -2,7 +2,7 @@
from langchain_community.embeddings.voyageai import VoyageEmbeddings from langchain_community.embeddings.voyageai import VoyageEmbeddings
# Please set VOYAGE_API_KEY in the environment variables # Please set VOYAGE_API_KEY in the environment variables
MODEL = "voyage-01" MODEL = "voyage-2"
def test_voyagi_embedding_documents() -> None: def test_voyagi_embedding_documents() -> None:
@ -14,10 +14,22 @@ def test_voyagi_embedding_documents() -> None:
assert len(output[0]) == 1024 assert len(output[0]) == 1024
def test_voyagi_with_default_model() -> None:
    """Check the fallback model and batch size when no model is passed.

    Constructs ``VoyageEmbeddings`` with no arguments, verifies that the
    instance reports ``voyage-01`` as its model and 7 as its batch size, then
    embeds 72 short documents and checks that one 1024-dimensional vector is
    returned per document.

    NOTE(review): presumably this hits the live Voyage API and therefore needs
    ``VOYAGE_API_KEY`` in the environment -- confirm against the test setup.
    """
    client = VoyageEmbeddings()

    # Defaults applied when the caller specifies neither model nor batch size.
    assert client.model == "voyage-01"
    assert client.batch_size == 7

    # 72 inputs with a batch size of 7 means the texts span several batches.
    texts = [f"foo bar {i}" for i in range(72)]
    vectors = client.embed_documents(texts)

    assert len(vectors) == 72
    first_vector = vectors[0]
    assert len(first_vector) == 1024
def test_voyage_embedding_documents_multiple() -> None: def test_voyage_embedding_documents_multiple() -> None:
"""Test voyage embeddings.""" """Test voyage embeddings."""
documents = ["foo bar", "bar foo", "foo"] documents = ["foo bar", "bar foo", "foo"]
embedding = VoyageEmbeddings(model=MODEL, batch_size=2) embedding = VoyageEmbeddings(model=MODEL, batch_size=2)
assert embedding.model == MODEL
output = embedding.embed_documents(documents) output = embedding.embed_documents(documents)
assert len(output) == 3 assert len(output) == 3
assert len(output[0]) == 1024 assert len(output[0]) == 1024