diff --git a/libs/langchain/langchain/chat_models/vertexai.py b/libs/langchain/langchain/chat_models/vertexai.py
index 91da917ea6f..42475bfe96e 100644
--- a/libs/langchain/langchain/chat_models/vertexai.py
+++ b/libs/langchain/langchain/chat_models/vertexai.py
@@ -12,10 +12,7 @@ from langchain.callbacks.manager import (
 from langchain.chat_models.base import BaseChatModel, _generate_from_stream
 from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.pydantic_v1 import root_validator
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-)
+from langchain.schema import ChatGeneration, ChatResult
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -177,16 +174,22 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = self._prepare_params(stop=stop, **kwargs)
+        params = self._prepare_params(stop=stop, stream=False, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
-        chat = self._start_chat(history, params)
-        response = chat.send_message(question.content)
-        return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=response.text))]
-        )
+        msg_params = {}
+        if "candidate_count" in params:
+            msg_params["candidate_count"] = params.pop("candidate_count")
+
+        chat = self._start_chat(history, **params)
+        response = chat.send_message(question.content, **msg_params)
+        generations = [
+            ChatGeneration(message=AIMessage(content=r.text))
+            for r in response.candidates
+        ]
+        return ChatResult(generations=generations)
 
     async def _agenerate(
         self,
@@ -219,11 +222,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         if examples:
             params["examples"] = _parse_examples(examples)
-        chat = self._start_chat(history, params)
-        response = await chat.send_message_async(question.content)
-        return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=response.text))]
-        )
+        msg_params = {}
+        if "candidate_count" in params:
+            msg_params["candidate_count"] = params.pop("candidate_count")
+        chat = self._start_chat(history, **params)
+        response = await chat.send_message_async(question.content, **msg_params)
+        generations = [
+            ChatGeneration(message=AIMessage(content=r.text))
+            for r in response.candidates
+        ]
+        return ChatResult(generations=generations)
 
     def _stream(
         self,
@@ -239,7 +247,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         if examples:
             params["examples"] = _parse_examples(examples)
-        chat = self._start_chat(history, params)
+        chat = self._start_chat(history, **params)
         responses = chat.send_message_streaming(question.content, **params)
         for response in responses:
             if run_manager:
@@ -247,11 +255,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
 
     def _start_chat(
-        self, history: _ChatHistory, params: dict
+        self, history: _ChatHistory, **kwargs: Any
     ) -> Union[ChatSession, CodeChatSession]:
         if not self.is_codey_model:
             return self.client.start_chat(
-                context=history.context, message_history=history.history, **params
+                context=history.context, message_history=history.history, **kwargs
             )
         else:
-            return self.client.start_chat(message_history=history.history, **params)
+            return self.client.start_chat(message_history=history.history, **kwargs)
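The chat-model change above fans each response out into one ChatGeneration per
returned candidate instead of reading a single response.text. A minimal usage
sketch of the resulting behavior, assuming Vertex AI credentials are configured
(the prompt and model name are illustrative):

    from langchain.chat_models import ChatVertexAI
    from langchain.schema.messages import HumanMessage

    # n maps onto the SDK's candidate_count, so one generate() call
    # can return several alternative completions.
    model = ChatVertexAI(model_name="chat-bison", temperature=0.3, n=2)
    result = model.generate(messages=[[HumanMessage(content="Hello")]])

    # One prompt yields one generation list holding n candidates.
    for generation in result.generations[0]:
        print(generation.message.content)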
diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py
index c8625699c97..a2dc147edff 100644
--- a/libs/langchain/langchain/llms/vertexai.py
+++ b/libs/langchain/langchain/llms/vertexai.py
@@ -175,7 +175,10 @@ class _VertexAICommon(_VertexAIBase):
     "The default custom credentials (google.auth.credentials.Credentials) to use "
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
+    n: int = 1
+    """How many completions to generate for each prompt."""
     streaming: bool = False
+    """Whether to stream the results or not."""
 
     @property
     def _llm_type(self) -> str:
@@ -203,6 +206,7 @@ class _VertexAICommon(_VertexAIBase):
             "max_output_tokens": self.max_output_tokens,
             "top_k": self.top_k,
             "top_p": self.top_p,
+            "candidate_count": self.n,
         }
 
     @classmethod
@@ -215,10 +219,16 @@ class _VertexAICommon(_VertexAIBase):
     def _prepare_params(
         self,
         stop: Optional[List[str]] = None,
+        stream: bool = False,
         **kwargs: Any,
     ) -> dict:
         stop_sequences = stop or self.stop
-        return {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params_mapping = {"n": "candidate_count"}
+        params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
+        params = {**self._default_params, "stop_sequences": stop_sequences, **params}
+        if stream or self.streaming:
+            params.pop("candidate_count")
+        return params
 
 
 class VertexAI(_VertexAICommon, BaseLLM):
@@ -260,6 +270,9 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 values["client"] = CodeGenerationModel.from_pretrained(model_name)
         except ImportError:
             raise_vertex_import_error()
+
+        if values["streaming"] and values["n"] > 1:
+            raise ValueError("Only one candidate can be generated with streaming!")
         return values
 
     def _generate(
@@ -271,7 +284,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         should_stream = stream if stream is not None else self.streaming
-        params = self._prepare_params(stop=stop, **kwargs)
+        params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
         generations = []
         for prompt in prompts:
             if should_stream:
@@ -285,7 +298,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 res = completion_with_retry(
                     self, prompt, run_manager=run_manager, **params
                 )
-                generations.append([_response_to_generation(res)])
+                generations.append([_response_to_generation(r) for r in res.candidates])
         return LLMResult(generations=generations)
 
     async def _agenerate(
@@ -301,7 +314,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
             res = await acompletion_with_retry(
                 self, prompt, run_manager=run_manager, **params
             )
-            generations.append([_response_to_generation(res)])
+            generations.append([_response_to_generation(r) for r in res.candidates])
         return LLMResult(generations=generations)
 
     def _stream(
         self,
@@ -311,7 +324,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
-        params = self._prepare_params(stop=stop, **kwargs)
+        params = self._prepare_params(stop=stop, stream=True, **kwargs)
         for stream_resp in stream_completion_with_retry(
             self, prompt, run_manager=run_manager, **params
         ):
diff --git a/libs/langchain/langchain/utilities/vertexai.py b/libs/langchain/langchain/utilities/vertexai.py
index 244292db43e..0df556307a6 100644
--- a/libs/langchain/langchain/utilities/vertexai.py
+++ b/libs/langchain/langchain/utilities/vertexai.py
@@ -5,7 +5,7 @@ if TYPE_CHECKING:
     from google.auth.credentials import Credentials
 
 
-def raise_vertex_import_error(minimum_expected_version: str = "1.33.0") -> None:
+def raise_vertex_import_error(minimum_expected_version: str = "1.35.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.
 
     Args:
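Two details of the LLM-side plumbing are worth calling out: _prepare_params
remaps an n keyword to the SDK's candidate_count, and candidate_count is
dropped whenever the call streams, since the streaming endpoint yields a
single candidate. The new validator makes the conflicting configuration fail
fast. A short sketch, assuming the Vertex SDK is installed and a client can be
initialized (the model name is illustrative):

    from langchain.llms import VertexAI

    # The new field n surfaces in _default_params as candidate_count.
    llm = VertexAI(model_name="text-bison", n=2)
    assert llm._default_params["candidate_count"] == 2

    # Streaming with n > 1 is rejected by the validate_environment check.
    try:
        VertexAI(model_name="text-bison", streaming=True, n=2)
    except ValueError as err:
        print(err)  # Only one candidate can be generated with streaming!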
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py b/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py
index 7a50f70a5cb..17a7c0ac723 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py
@@ -20,7 +20,10 @@ from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
 
 @pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
 def test_vertexai_instantiation(model_name: str) -> None:
-    model = ChatVertexAI(model_name=model_name)
+    if model_name:
+        model = ChatVertexAI(model_name=model_name)
+    else:
+        model = ChatVertexAI()
     assert model._llm_type == "vertexai"
     assert model.model_name == model.client._model_id
 
@@ -38,6 +41,15 @@ def test_vertexai_single_call(model_name: str) -> None:
     assert isinstance(response.content, str)
 
 
+def test_candidates() -> None:
+    model = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
+    message = HumanMessage(content="Hello")
+    response = model.generate(messages=[[message]])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 1
+    assert len(response.generations[0]) == 2
+
+
 @pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_vertexai_agenerate() -> None:
@@ -153,7 +165,8 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
     with patch(
         "vertexai.language_models._language_models.ChatModel.start_chat"
     ) as start_chat:
-        mock_response = Mock(text=response_text)
+        mock_response = MagicMock()
+        mock_response.candidates = [Mock(text=response_text)]
         mock_chat = MagicMock()
         start_chat.return_value = mock_chat
         mock_send_message = MagicMock(return_value=mock_response)
@@ -167,7 +180,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
         response = model([message])
 
         assert response.content == response_text
-        mock_send_message.assert_called_once_with(user_prompt)
+        mock_send_message.assert_called_once_with(user_prompt, candidate_count=1)
         expected_stop_sequence = [stop] if stop else None
         start_chat.assert_called_once_with(
            context=None,
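Since responses now carry a candidates list, the unit-test mock has to mirror
the SDK's shape instead of stubbing a bare .text attribute. The same pattern in
isolation, using only the standard library (the response text is illustrative):

    from unittest.mock import MagicMock, Mock

    # The code under test iterates response.candidates and reads .text
    # from each candidate, so the mock exposes exactly that structure.
    mock_response = MagicMock()
    mock_response.candidates = [Mock(text="Goodbye")]

    assert [c.text for c in mock_response.candidates] == ["Goodbye"]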
diff --git a/libs/langchain/tests/integration_tests/llms/test_vertexai.py b/libs/langchain/tests/integration_tests/llms/test_vertexai.py
index 561e10a002c..2bee6af181c 100644
--- a/libs/langchain/tests/integration_tests/llms/test_vertexai.py
+++ b/libs/langchain/tests/integration_tests/llms/test_vertexai.py
@@ -29,29 +29,31 @@ def test_vertex_call() -> None:
 
 @pytest.mark.scheduled
 def test_vertex_generate() -> None:
-    llm = VertexAI(temperate=0)
-    output = llm.generate(["Please say foo:"])
+    llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
+    output = llm.generate(["Say foo:"])
     assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    assert len(output.generations[0]) == 2
 
 
 @pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_vertex_agenerate() -> None:
-    llm = VertexAI(temperate=0)
+    llm = VertexAI(temperature=0)
     output = await llm.agenerate(["Please say foo:"])
     assert isinstance(output, LLMResult)
 
 
 @pytest.mark.scheduled
 def test_vertex_stream() -> None:
-    llm = VertexAI(temperate=0)
+    llm = VertexAI(temperature=0)
     outputs = list(llm.stream("Please say foo:"))
     assert isinstance(outputs[0], str)
 
 
 @pytest.mark.asyncio
 async def test_vertex_consistency() -> None:
-    llm = VertexAI(temperate=0)
+    llm = VertexAI(temperature=0)
     output = llm.generate(["Please say foo:"])
     streaming_output = llm.generate(["Please say foo:"], stream=True)
     async_output = await llm.agenerate(["Please say foo:"])
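Taken together, the integration test above doubles as the user-facing recipe
for multi-candidate completion on the plain LLM. A usage sketch mirroring
test_vertex_generate, assuming a configured Google Cloud project:

    from langchain.llms import VertexAI

    llm = VertexAI(model_name="text-bison@001", temperature=0.3, n=2)
    output = llm.generate(["Say foo:"])

    # One prompt produces one generations list with two candidates in it.
    assert len(output.generations) == 1
    assert len(output.generations[0]) == 2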