Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-12 07:50:39 +00:00)
small fixes to Vertex (#10934)
Fixed tests, updated the required version of the SDK, and made a few minor changes after the recent improvement (https://github.com/langchain-ai/langchain/pull/10910).
This commit: 9d4b710a48 (parent: 4e58b78102)
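Review note: the substantive change across these hunks is that stop words are now forwarded to the Vertex AI service as `stop_sequences` instead of being stripped from responses on the client. A minimal usage sketch of the resulting behavior follows; the model name, prompt, and stop word are illustrative, and it assumes google-cloud-aiplatform>=1.33.0 plus configured GCP credentials.

# A minimal sketch, not part of the diff: model name, prompt, and stop word
# are illustrative; assumes google-cloud-aiplatform>=1.33.0 and GCP credentials.
from langchain.chat_models import ChatVertexAI
from langchain.schema import HumanMessage

chat = ChatVertexAI(model_name="chat-bison")
# `stop` now reaches the service as `stop_sequences` rather than being
# enforced by post-hoc truncation of the returned text.
response = chat([HumanMessage(content="List three colors.")], stop=["3."])
print(response.content)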
@@ -138,7 +138,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             values["client"] = ChatModel.from_pretrained(values["model_name"])
         except ImportError:
-            raise_vertex_import_error(minimum_expected_version="1.29.0")
+            raise_vertex_import_error()
         return values
 
     def _generate(
@@ -173,15 +173,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
 
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = {**self._default_params, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
 
         chat = self._start_chat(history, params)
         response = chat.send_message(question.content)
-        text = self._enforce_stop_words(response.text, stop)
-        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+        return ChatResult(
+            generations=[ChatGeneration(message=AIMessage(content=response.text))]
+        )
 
     async def _agenerate(
         self,
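With stop handling delegated to the service, `_generate` now wraps the raw `response.text` directly. A sketch of the structure it returns; the content string is illustrative.

# Sketch of the ChatResult produced above; the content string is illustrative.
from langchain.schema import AIMessage, ChatGeneration, ChatResult

result = ChatResult(
    generations=[ChatGeneration(message=AIMessage(content="Hello!"))]
)
print(result.generations[0].message.content)  # -> "Hello!"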
@@ -209,15 +210,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         logger.warning("ChatVertexAI does not currently support async streaming.")
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = {**self._default_params, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
 
         chat = self._start_chat(history, params)
         response = await chat.send_message_async(question.content)
-        text = self._enforce_stop_words(response.text, stop)
-        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+        return ChatResult(
+            generations=[ChatGeneration(message=AIMessage(content=response.text))]
+        )
 
     def _stream(
         self,
@@ -228,7 +230,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = {**self._default_params, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
@@ -236,10 +238,9 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         chat = self._start_chat(history, params)
         responses = chat.send_message_streaming(question.content, **params)
         for response in responses:
-            text = self._enforce_stop_words(response.text, stop)
             if run_manager:
-                run_manager.on_llm_new_token(text)
-            yield ChatGenerationChunk(message=AIMessageChunk(content=text))
+                run_manager.on_llm_new_token(response.text)
+            yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
 
     def _start_chat(
         self, history: _ChatHistory, params: dict
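The streaming path likewise stops truncating chunks client-side; each callback and yielded chunk now carries the service's `response.text` verbatim. A minimal consumption sketch, assuming a langchain version where chat models expose `.stream()`; the prompt and stop word are illustrative.

# A minimal streaming sketch, assuming chat models expose .stream();
# the prompt and stop word are illustrative.
from langchain.chat_models import ChatVertexAI
from langchain.schema import HumanMessage

chat = ChatVertexAI(model_name="chat-bison")
for chunk in chat.stream([HumanMessage(content="Count to ten.")], stop=["five"]):
    print(chunk.content, end="", flush=True)  # raw text, no client-side stop trimming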
@@ -18,7 +18,6 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import BaseLLM, create_base_retry_decorator
-from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, root_validator
 from langchain.schema import (
     Generation,
@@ -151,13 +150,6 @@ class _VertexAIBase(BaseModel):
     model_name: Optional[str] = None
     "Underlying model name."
 
-    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
-        if stop is None and self.stop is not None:
-            stop = self.stop
-        if stop:
-            return enforce_stop_tokens(text, stop)
-        return text
-
     @classmethod
     def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
         if cls.task_executor is None:
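For contrast, the deleted helper truncated completions on the client by delegating to `enforce_stop_tokens`. Below is a self-contained sketch of that behavior, modeled on `langchain.llms.utils.enforce_stop_tokens`; the `re.escape` guard is an addition here, not part of the original helper.

import re
from typing import List

def enforce_stop_tokens_sketch(text: str, stop: List[str]) -> str:
    """Cut the text at the first occurrence of any stop sequence."""
    # Modeled on langchain.llms.utils.enforce_stop_tokens; re.escape added here.
    return re.split("|".join(map(re.escape, stop)), text, maxsplit=1)[0]

assert enforce_stop_tokens_sketch("one two STOP three", ["STOP"]) == "one two "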
@@ -220,6 +212,14 @@ class _VertexAICommon(_VertexAIBase):
         init_vertexai(**params)
         return None
 
+    def _prepare_params(
+        self,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> dict:
+        stop_sequences = stop or self.stop
+        return {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+
 
 class VertexAI(_VertexAICommon, BaseLLM):
     """Google Vertex AI large language models."""
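The new `_prepare_params` gives a per-call `stop` precedence over the instance-level `self.stop`, and merges caller kwargs last so they can override the defaults. A standalone sketch of that merge order; the function name and values are illustrative.

# Standalone sketch of the _prepare_params merge order; names and values illustrative.
from typing import Any, List, Optional

def prepare_params_sketch(
    default_params: dict,
    instance_stop: Optional[List[str]],
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> dict:
    stop_sequences = stop or instance_stop  # per-call stop wins
    return {**default_params, "stop_sequences": stop_sequences, **kwargs}

assert prepare_params_sketch(
    {"temperature": 0.0}, instance_stop=["Human:"], stop=["Observation:"]
) == {"temperature": 0.0, "stop_sequences": ["Observation:"]}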
@@ -228,7 +228,6 @@ class VertexAI(_VertexAICommon, BaseLLM):
     "The name of the Vertex AI large language model."
     tuned_model_name: Optional[str] = None
     "The name of a tuned model. If provided, model_name is ignored."
-    streaming: bool = False
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
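Note: `_generate` below still reads `self.streaming`, so removing the field here only makes sense if the flag now lives on a shared base class; that hunk is not shown in this excerpt, but `_VertexAICommon` is the likely home.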
@@ -267,10 +266,8 @@ class VertexAI(_VertexAICommon, BaseLLM):
         stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> LLMResult:
-        stop_sequences = stop or self.stop
         should_stream = stream if stream is not None else self.streaming
-        params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         generations = []
         for prompt in prompts:
             if should_stream:
@@ -294,8 +291,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
-        stop_sequences = stop or self.stop
-        params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         generations = []
         for prompt in prompts:
             res = await acompletion_with_retry(
@@ -311,8 +307,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
-        stop_sequences = stop or self.stop
-        params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         for stream_resp in stream_completion_with_retry(
             self, prompt, run_manager=run_manager, **params
         ):
@@ -5,7 +5,7 @@ if TYPE_CHECKING:
     from google.auth.credentials import Credentials
 
 
-def raise_vertex_import_error(minimum_expected_version: str = "1.26.1") -> None:
+def raise_vertex_import_error(minimum_expected_version: str = "1.33.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.
 
     Args:
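The default SDK floor in the import-error message moves from 1.26.1 to 1.33.0, which also lets the chat model drop its explicit `minimum_expected_version="1.29.0"` override (first hunk above). A sketch of verifying the new floor up front; it uses `importlib.metadata` and `packaging`, and the check itself is not part of this commit.

# Sketch of checking the new SDK floor; not part of this commit.
from importlib import metadata
from packaging import version

if version.parse(metadata.version("google-cloud-aiplatform")) < version.parse("1.33.0"):
    raise ImportError("Please upgrade: pip install 'google-cloud-aiplatform>=1.33.0'")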
@@ -7,6 +7,7 @@ pip install google-cloud-aiplatform>=1.25.0
 Your end-user credentials would be used to make the calls (make sure you've run
 `gcloud auth login` first).
 """
+from typing import Optional
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
@@ -27,7 +28,7 @@ def test_vertexai_single_call(model_name: str) -> None:
     response = model([message])
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
-    assert model._llm_type == "chat-vertexai"
+    assert model._llm_type == "vertexai"
     assert model.model_name == model.client._model_id
@@ -127,7 +128,8 @@ def test_vertexai_single_call_failes_no_message() -> None:
     )
 
 
-def test_vertexai_args_passed() -> None:
+@pytest.mark.parametrize("stop", [None, "stop1"])
+def test_vertexai_args_passed(stop: Optional[str]) -> None:
     response_text = "Goodbye"
     user_prompt = "Hello"
     prompt_params = {
@@ -149,12 +151,19 @@
 
         model = ChatVertexAI(**prompt_params)
         message = HumanMessage(content=user_prompt)
-        response = model([message])
+        if stop:
+            response = model([message], stop=[stop])
+        else:
+            response = model([message])
 
         assert response.content == response_text
         mock_send_message.assert_called_once_with(user_prompt)
+        expected_stop_sequence = [stop] if stop else None
         start_chat.assert_called_once_with(
-            context=None, message_history=[], **prompt_params
+            context=None,
+            message_history=[],
+            **prompt_params,
+            stop_sequences=expected_stop_sequence
         )
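With the parametrization above, the test runs once with `stop=None` and once with `stop="stop1"`, asserting in both cases that the computed `stop_sequences` value reaches the mocked `start_chat`; `pytest -k test_vertexai_args_passed` exercises both cases.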