Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-12 15:59:56 +00:00)
small fixes to Vertex (#10934)
Fixed tests, updated the required version of the SDK, and made a few minor changes after the recent improvement (https://github.com/langchain-ai/langchain/pull/10910).
parent 4e58b78102
commit 9d4b710a48
@@ -138,7 +138,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             values["client"] = ChatModel.from_pretrained(values["model_name"])
         except ImportError:
-            raise_vertex_import_error(minimum_expected_version="1.29.0")
+            raise_vertex_import_error()
         return values

     def _generate(
@@ -173,15 +173,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = {**self._default_params, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)

         chat = self._start_chat(history, params)
         response = chat.send_message(question.content)
-        text = self._enforce_stop_words(response.text, stop)
-        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+        return ChatResult(
+            generations=[ChatGeneration(message=AIMessage(content=response.text))]
+        )

     async def _agenerate(
         self,
@@ -209,15 +210,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         logger.warning("ChatVertexAI does not currently support async streaming.")
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = {**self._default_params, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)

         chat = self._start_chat(history, params)
         response = await chat.send_message_async(question.content)
-        text = self._enforce_stop_words(response.text, stop)
-        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+        return ChatResult(
+            generations=[ChatGeneration(message=AIMessage(content=response.text))]
+        )

     def _stream(
         self,
@@ -228,7 +230,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = {**self._default_params, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
@@ -236,10 +238,9 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         chat = self._start_chat(history, params)
         responses = chat.send_message_streaming(question.content, **params)
         for response in responses:
-            text = self._enforce_stop_words(response.text, stop)
             if run_manager:
-                run_manager.on_llm_new_token(text)
-            yield ChatGenerationChunk(message=AIMessageChunk(content=text))
+                run_manager.on_llm_new_token(response.text)
+            yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))

     def _start_chat(
         self, history: _ChatHistory, params: dict
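The chat-model hunks above all make the same two moves: the per-call `stop` list is now folded into the request parameters via the new `_prepare_params` helper, and the client-side `_enforce_stop_words` truncation of `response.text` is dropped. As a rough illustration of the behavior being removed, here is a minimal stand-in for langchain's `enforce_stop_tokens` (a simplified sketch, not the exact implementation):

import re
from typing import List


def truncate_at_stop_tokens(text: str, stop: List[str]) -> str:
    # Cut the completion at the first occurrence of any stop word,
    # roughly what the removed client-side path did after each response.
    return re.split("|".join(re.escape(s) for s in stop), text)[0]


result = truncate_at_stop_tokens("Thought: done\nObservation: 42", ["Observation:"])
assert result == "Thought: done\n"

After this commit the same effect comes from the model halting at `stop_sequences` server-side, so the returned text no longer needs post-processing.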
@@ -18,7 +18,6 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import BaseLLM, create_base_retry_decorator
-from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, root_validator
 from langchain.schema import (
     Generation,
@@ -151,13 +150,6 @@ class _VertexAIBase(BaseModel):
     model_name: Optional[str] = None
     "Underlying model name."

-    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
-        if stop is None and self.stop is not None:
-            stop = self.stop
-        if stop:
-            return enforce_stop_tokens(text, stop)
-        return text
-
     @classmethod
     def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
         if cls.task_executor is None:
@@ -220,6 +212,14 @@ class _VertexAICommon(_VertexAIBase):
         init_vertexai(**params)
         return None

+    def _prepare_params(
+        self,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> dict:
+        stop_sequences = stop or self.stop
+        return {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+

 class VertexAI(_VertexAICommon, BaseLLM):
     """Google Vertex AI large language models."""
@@ -228,7 +228,6 @@ class VertexAI(_VertexAICommon, BaseLLM):
     "The name of the Vertex AI large language model."
     tuned_model_name: Optional[str] = None
     "The name of a tuned model. If provided, model_name is ignored."
     streaming: bool = False

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -267,10 +266,8 @@ class VertexAI(_VertexAICommon, BaseLLM):
         stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> LLMResult:
-        stop_sequences = stop or self.stop
         should_stream = stream if stream is not None else self.streaming
-
-        params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         generations = []
         for prompt in prompts:
             if should_stream:
@@ -294,8 +291,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
-        stop_sequences = stop or self.stop
-        params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         generations = []
         for prompt in prompts:
             res = await acompletion_with_retry(
@@ -311,8 +307,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
-        stop_sequences = stop or self.stop
-        params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         for stream_resp in stream_completion_with_retry(
             self, prompt, run_manager=run_manager, **params
         ):
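This file centralizes the previously duplicated `{**self._default_params, "stop_sequences": ...}` merge into the new `_prepare_params` helper. A free-standing sketch of its merge semantics (function and parameter names here are hypothetical, not the actual class):

from typing import Any, List, Optional


def prepare_params_sketch(
    default_params: dict,
    instance_stop: Optional[List[str]],
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> dict:
    # Per-call stop words take precedence over the instance default,
    # mirroring `stop_sequences = stop or self.stop` in the diff above.
    stop_sequences = stop or instance_stop
    return {**default_params, "stop_sequences": stop_sequences, **kwargs}


defaults = {"temperature": 0.0, "max_output_tokens": 128}

# A per-call stop list wins over the instance-level default:
params = prepare_params_sketch(defaults, instance_stop=["Human:"], stop=["Observation:"])
assert params["stop_sequences"] == ["Observation:"]

# Extra kwargs override the defaults because they are unpacked last:
params = prepare_params_sketch(defaults, instance_stop=None, temperature=0.9)
assert params["temperature"] == 0.9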
@@ -5,7 +5,7 @@ if TYPE_CHECKING:
     from google.auth.credentials import Credentials


-def raise_vertex_import_error(minimum_expected_version: str = "1.26.1") -> None:
+def raise_vertex_import_error(minimum_expected_version: str = "1.33.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.

     Args:
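This hunk only bumps the helper's default minimum SDK version from 1.26.1 to 1.33.0, which is why the first file could drop its explicit `minimum_expected_version="1.29.0"` argument. The helper's job is roughly the following (a sketch; the real message text may differ):

def raise_vertex_import_error(minimum_expected_version: str = "1.33.0") -> None:
    # Point users at the SDK install command with the expected floor version.
    raise ImportError(
        "Could not import VertexAI. Please install it with "
        f"`pip install google-cloud-aiplatform>={minimum_expected_version}`"
    )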
@@ -7,6 +7,7 @@ pip install google-cloud-aiplatform>=1.25.0
 Your end-user credentials would be used to make the calls (make sure you've run
 `gcloud auth login` first).
 """
+from typing import Optional
 from unittest.mock import MagicMock, Mock, patch

 import pytest
@@ -27,7 +28,7 @@ def test_vertexai_single_call(model_name: str) -> None:
     response = model([message])
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
-    assert model._llm_type == "chat-vertexai"
+    assert model._llm_type == "vertexai"
     assert model.model_name == model.client._model_id


@@ -127,7 +128,8 @@ def test_vertexai_single_call_failes_no_message() -> None:
     )


-def test_vertexai_args_passed() -> None:
+@pytest.mark.parametrize("stop", [None, "stop1"])
+def test_vertexai_args_passed(stop: Optional[str]) -> None:
     response_text = "Goodbye"
     user_prompt = "Hello"
     prompt_params = {
@@ -149,12 +151,19 @@ def test_vertexai_args_passed() -> None:
     model = ChatVertexAI(**prompt_params)
     message = HumanMessage(content=user_prompt)
-    response = model([message])
+    if stop:
+        response = model([message], stop=[stop])
+    else:
+        response = model([message])

     assert response.content == response_text
     mock_send_message.assert_called_once_with(user_prompt)
+    expected_stop_sequence = [stop] if stop else None
     start_chat.assert_called_once_with(
-        context=None, message_history=[], **prompt_params
+        context=None,
+        message_history=[],
+        **prompt_params,
+        stop_sequences=expected_stop_sequence
     )
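The reworked test checks that a per-call `stop` reaches the SDK as `stop_sequences` on `start_chat`. In plain, non-mocked usage that corresponds to something like the following, assuming the google-cloud-aiplatform SDK is installed and `gcloud auth login` has been run (the model name shown is just the then-current chat default):

from langchain.chat_models import ChatVertexAI
from langchain.schema import HumanMessage

model = ChatVertexAI(model_name="chat-bison")
# The per-call stop list is forwarded to the SDK as stop_sequences,
# so generation halts at "stop1" on the server side.
response = model([HumanMessage(content="Hello")], stop=["stop1"])
print(response.content)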