mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 06:13:36 +00:00
Switched VertexAI models from preview (#13657)
Replace this entire comment with: - **Description:** VertexAI models are now GA, moved away from using preview ones from the SDK - **Issue:** #13606 --------- Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
a47f1da884
commit
85bb3a418c
@ -132,16 +132,14 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|||||||
"""Validate that the python package exists in environment."""
|
"""Validate that the python package exists in environment."""
|
||||||
cls._try_init_vertexai(values)
|
cls._try_init_vertexai(values)
|
||||||
try:
|
try:
|
||||||
if is_codey_model(values["model_name"]):
|
from vertexai.language_models import ChatModel, CodeChatModel
|
||||||
from vertexai.preview.language_models import CodeChatModel
|
|
||||||
|
|
||||||
values["client"] = CodeChatModel.from_pretrained(values["model_name"])
|
|
||||||
else:
|
|
||||||
from vertexai.preview.language_models import ChatModel
|
|
||||||
|
|
||||||
values["client"] = ChatModel.from_pretrained(values["model_name"])
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise_vertex_import_error()
|
raise_vertex_import_error()
|
||||||
|
if is_codey_model(values["model_name"]):
|
||||||
|
model_cls = CodeChatModel
|
||||||
|
else:
|
||||||
|
model_cls = ChatModel
|
||||||
|
values["client"] = model_cls.from_pretrained(values["model_name"])
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
|
@ -17,7 +17,7 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
|
|||||||
"""Validates that the python package exists in environment."""
|
"""Validates that the python package exists in environment."""
|
||||||
cls._try_init_vertexai(values)
|
cls._try_init_vertexai(values)
|
||||||
try:
|
try:
|
||||||
from vertexai.preview.language_models import TextEmbeddingModel
|
from vertexai.language_models import TextEmbeddingModel
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise_vertex_import_error()
|
raise_vertex_import_error()
|
||||||
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
|
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
|
||||||
|
@ -195,19 +195,19 @@ class _VertexAICommon(_VertexAIBase):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_params(self) -> Dict[str, Any]:
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
if self.is_codey_model:
|
params = {
|
||||||
return {
|
"temperature": self.temperature,
|
||||||
"temperature": self.temperature,
|
"max_output_tokens": self.max_output_tokens,
|
||||||
"max_output_tokens": self.max_output_tokens,
|
"candidate_count": self.n,
|
||||||
}
|
}
|
||||||
else:
|
if not self.is_codey_model:
|
||||||
return {
|
params.update(
|
||||||
"temperature": self.temperature,
|
{
|
||||||
"max_output_tokens": self.max_output_tokens,
|
"top_k": self.top_k,
|
||||||
"top_k": self.top_k,
|
"top_p": self.top_p,
|
||||||
"top_p": self.top_p,
|
}
|
||||||
"candidate_count": self.n,
|
)
|
||||||
}
|
return params
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _try_init_vertexai(cls, values: Dict) -> None:
|
def _try_init_vertexai(cls, values: Dict) -> None:
|
||||||
@ -250,24 +250,33 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
|||||||
tuned_model_name = values.get("tuned_model_name")
|
tuned_model_name = values.get("tuned_model_name")
|
||||||
model_name = values["model_name"]
|
model_name = values["model_name"]
|
||||||
try:
|
try:
|
||||||
if not is_codey_model(model_name):
|
from vertexai.language_models import (
|
||||||
from vertexai.preview.language_models import TextGenerationModel
|
CodeGenerationModel,
|
||||||
|
TextGenerationModel,
|
||||||
|
)
|
||||||
|
from vertexai.preview.language_models import (
|
||||||
|
CodeGenerationModel as PreviewCodeGenerationModel,
|
||||||
|
)
|
||||||
|
from vertexai.preview.language_models import (
|
||||||
|
TextGenerationModel as PreviewTextGenerationModel,
|
||||||
|
)
|
||||||
|
|
||||||
if tuned_model_name:
|
if is_codey_model(model_name):
|
||||||
values["client"] = TextGenerationModel.get_tuned_model(
|
model_cls = CodeGenerationModel
|
||||||
tuned_model_name
|
preview_model_cls = PreviewCodeGenerationModel
|
||||||
)
|
|
||||||
else:
|
|
||||||
values["client"] = TextGenerationModel.from_pretrained(model_name)
|
|
||||||
else:
|
else:
|
||||||
from vertexai.preview.language_models import CodeGenerationModel
|
model_cls = TextGenerationModel
|
||||||
|
preview_model_cls = PreviewTextGenerationModel
|
||||||
|
|
||||||
|
if tuned_model_name:
|
||||||
|
values["client"] = model_cls.get_tuned_model(tuned_model_name)
|
||||||
|
values["client_preview"] = preview_model_cls.get_tuned_model(
|
||||||
|
tuned_model_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
values["client"] = model_cls.from_pretrained(model_name)
|
||||||
|
values["client_preview"] = preview_model_cls.from_pretrained(model_name)
|
||||||
|
|
||||||
if tuned_model_name:
|
|
||||||
values["client"] = CodeGenerationModel.get_tuned_model(
|
|
||||||
tuned_model_name
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
values["client"] = CodeGenerationModel.from_pretrained(model_name)
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise_vertex_import_error()
|
raise_vertex_import_error()
|
||||||
|
|
||||||
@ -287,12 +296,9 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
|||||||
The integer number of tokens in the text.
|
The integer number of tokens in the text.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result = self.client.count_tokens([text])
|
result = self.client_preview.count_tokens([text])
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise NotImplementedError(
|
raise_vertex_import_error()
|
||||||
"Your google-cloud-aiplatform version didn't implement count_tokens."
|
|
||||||
"Please, install it with pip install google-cloud-aiplatform>=1.35.0"
|
|
||||||
)
|
|
||||||
|
|
||||||
return result.total_tokens
|
return result.total_tokens
|
||||||
|
|
||||||
@ -319,12 +325,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
|||||||
res = completion_with_retry(
|
res = completion_with_retry(
|
||||||
self, prompt, run_manager=run_manager, **params
|
self, prompt, run_manager=run_manager, **params
|
||||||
)
|
)
|
||||||
if self.is_codey_model:
|
generations.append([_response_to_generation(r) for r in res.candidates])
|
||||||
generations.append([_response_to_generation(res)])
|
|
||||||
else:
|
|
||||||
generations.append(
|
|
||||||
[_response_to_generation(r) for r in res.candidates]
|
|
||||||
)
|
|
||||||
return LLMResult(generations=generations)
|
return LLMResult(generations=generations)
|
||||||
|
|
||||||
async def _agenerate(
|
async def _agenerate(
|
||||||
|
@ -7,7 +7,7 @@ if TYPE_CHECKING:
|
|||||||
from google.auth.credentials import Credentials
|
from google.auth.credentials import Credentials
|
||||||
|
|
||||||
|
|
||||||
def raise_vertex_import_error(minimum_expected_version: str = "1.35.0") -> None:
|
def raise_vertex_import_error(minimum_expected_version: str = "1.36.0") -> None:
|
||||||
"""Raise ImportError related to Vertex SDK being not available.
|
"""Raise ImportError related to Vertex SDK being not available.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -16,7 +16,7 @@ def raise_vertex_import_error(minimum_expected_version: str = "1.35.0") -> None:
|
|||||||
ImportError: an ImportError that mentions a required version of the SDK.
|
ImportError: an ImportError that mentions a required version of the SDK.
|
||||||
"""
|
"""
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import VertexAI. Please, install it with "
|
"Please, install or upgrade the google-cloud-aiplatform library: "
|
||||||
f"pip install google-cloud-aiplatform>={minimum_expected_version}"
|
f"pip install google-cloud-aiplatform>={minimum_expected_version}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
"""Test Vertex AI API wrapper.
|
"""Test Vertex AI API wrapper.
|
||||||
In order to run this test, you need to install VertexAI SDK (that is is the private
|
In order to run this test, you need to install VertexAI SDK:
|
||||||
preview) and be whitelisted to list the models themselves:
|
pip install google-cloud-aiplatform>=1.36.0
|
||||||
In order to run this test, you need to install VertexAI SDK
|
|
||||||
pip install google-cloud-aiplatform>=1.35.0
|
|
||||||
|
|
||||||
Your end-user credentials would be used to make the calls (make sure you've run
|
Your end-user credentials would be used to make the calls (make sure you've run
|
||||||
`gcloud auth login` first).
|
`gcloud auth login` first).
|
||||||
@ -150,10 +148,10 @@ async def test_model_garden_agenerate(
|
|||||||
assert len(output.generations) == 2
|
assert len(output.generations) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_vertex_call_trigger_count_tokens() -> None:
|
def test_vertex_call_count_tokens() -> None:
|
||||||
llm = VertexAI()
|
llm = VertexAI()
|
||||||
output = llm.get_num_tokens("Hi")
|
output = llm.get_num_tokens("How are you?")
|
||||||
assert output == 2
|
assert output == 4
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("google.cloud.aiplatform")
|
@pytest.mark.requires("google.cloud.aiplatform")
|
||||||
|
Loading…
Reference in New Issue
Block a user