mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 00:23:25 +00:00
community[patch]: changed default for VertexAIEmbeddings (#14614)
Replace this entire comment with: - **Description:** @kurtisvg has raised a point that it's a good idea to have a fixed version for embeddings (since otherwise a user might run a query with one version vs a vectorstore where another version was used). In order to avoid breaking changes, I'd suggest to give users a warning, and make a `model_name` a required argument in 1.5 months.
This commit is contained in:
parent
138bc49759
commit
b99274c9d8
@ -29,21 +29,28 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
|
|||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validates that the python package exists in environment."""
|
"""Validates that the python package exists in environment."""
|
||||||
cls._try_init_vertexai(values)
|
cls._try_init_vertexai(values)
|
||||||
|
if values["model_name"] == "textembedding-gecko-default":
|
||||||
|
logger.warning(
|
||||||
|
"Model_name will become a required arg for VertexAIEmbeddings "
|
||||||
|
"starting from Feb-01-2024. Currently the default is set to "
|
||||||
|
"textembedding-gecko@001"
|
||||||
|
)
|
||||||
|
values["model_name"] = "textembedding-gecko@001"
|
||||||
try:
|
try:
|
||||||
from vertexai.language_models import TextEmbeddingModel
|
from vertexai.language_models import TextEmbeddingModel
|
||||||
|
|
||||||
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise_vertex_import_error()
|
raise_vertex_import_error()
|
||||||
|
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
# the default value would be removed after Feb-01-2024
|
||||||
|
model_name: str = "textembedding-gecko-default",
|
||||||
project: Optional[str] = None,
|
project: Optional[str] = None,
|
||||||
location: str = "us-central1",
|
location: str = "us-central1",
|
||||||
request_parallelism: int = 5,
|
request_parallelism: int = 5,
|
||||||
max_retries: int = 6,
|
max_retries: int = 6,
|
||||||
model_name: str = "textembedding-gecko",
|
|
||||||
credentials: Optional[Any] = None,
|
credentials: Optional[Any] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
|
@ -5,6 +5,8 @@ pip install google-cloud-aiplatform>=1.35.0
|
|||||||
Your end-user credentials would be used to make the calls (make sure you've run
|
Your end-user credentials would be used to make the calls (make sure you've run
|
||||||
`gcloud auth login` first).
|
`gcloud auth login` first).
|
||||||
"""
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain_community.embeddings import VertexAIEmbeddings
|
from langchain_community.embeddings import VertexAIEmbeddings
|
||||||
|
|
||||||
|
|
||||||
@ -15,6 +17,7 @@ def test_embedding_documents() -> None:
|
|||||||
assert len(output) == 1
|
assert len(output) == 1
|
||||||
assert len(output[0]) == 768
|
assert len(output[0]) == 768
|
||||||
assert model.model_name == model.client._model_id
|
assert model.model_name == model.client._model_id
|
||||||
|
assert model.model_name == "textembedding-gecko@001"
|
||||||
|
|
||||||
|
|
||||||
def test_embedding_query() -> None:
|
def test_embedding_query() -> None:
|
||||||
@ -50,3 +53,15 @@ def test_paginated_texts() -> None:
|
|||||||
assert len(output) == 8
|
assert len(output) == 8
|
||||||
assert len(output[0]) == 768
|
assert len(output[0]) == 768
|
||||||
assert model.model_name == model.client._model_id
|
assert model.model_name == model.client._model_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_warning(caplog: pytest.LogCaptureFixture) -> None:
|
||||||
|
_ = VertexAIEmbeddings()
|
||||||
|
assert len(caplog.records) == 1
|
||||||
|
record = caplog.records[0]
|
||||||
|
assert record.levelname == "WARNING"
|
||||||
|
expected_message = (
|
||||||
|
"Model_name will become a required arg for VertexAIEmbeddings starting from "
|
||||||
|
"Feb-01-2024. Currently the default is set to textembedding-gecko@001"
|
||||||
|
)
|
||||||
|
assert record.message == expected_message
|
||||||
|
Loading…
Reference in New Issue
Block a user