From 0749a642f5b9aa4cb83bf0d9edd1b0d61c4abefc Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Wed, 20 Sep 2023 11:49:16 -0700
Subject: [PATCH] Stream refac and vertex streaming (#10470)

---------

Co-authored-by: Terry Cruz Melo
Co-authored-by: Terry Cruz Melo <33166112+TerryCM@users.noreply.github.com>
---
 .../langchain/chat_models/anthropic.py        | 66 +++++++-------
 libs/langchain/langchain/chat_models/base.py  | 24 ++++++
 .../langchain/chat_models/jinachat.py         | 30 +++----
 libs/langchain/langchain/chat_models/konko.py | 15 ++--
 .../langchain/chat_models/litellm.py          | 32 +++----
 .../langchain/langchain/chat_models/openai.py | 35 +++-----
 .../langchain/chat_models/vertexai.py         | 86 +++++++++++++++----
 7 files changed, 168 insertions(+), 120 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/anthropic.py b/libs/langchain/langchain/chat_models/anthropic.py
index a5c04dde630..07634f6dcb6 100644
--- a/libs/langchain/langchain/chat_models/anthropic.py
+++ b/libs/langchain/langchain/chat_models/anthropic.py
@@ -4,7 +4,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.llms.anthropic import _AnthropicCommon
 from langchain.schema.messages import (
     AIMessage,
@@ -162,22 +166,22 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            completion = ""
-            for chunk in self._stream(messages, stop, run_manager, **kwargs):
-                completion += chunk.text
-        else:
-            prompt = self._convert_messages_to_prompt(
-                messages,
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
             )
-            params: Dict[str, Any] = {
-                "prompt": prompt,
-                **self._default_params,
-                **kwargs,
-            }
-            if stop:
-                params["stop_sequences"] = stop
-            response = self.client.completions.create(**params)
-            completion = response.completion
+            return _generate_from_stream(stream_iter)
+        prompt = self._convert_messages_to_prompt(
+            messages,
+        )
+        params: Dict[str, Any] = {
+            "prompt": prompt,
+            **self._default_params,
+            **kwargs,
+        }
+        if stop:
+            params["stop_sequences"] = stop
+        response = self.client.completions.create(**params)
+        completion = response.completion
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])
 
@@ -189,22 +193,22 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            completion = ""
-            async for chunk in self._astream(messages, stop, run_manager, **kwargs):
-                completion += chunk.text
-        else:
-            prompt = self._convert_messages_to_prompt(
-                messages,
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
             )
-            params: Dict[str, Any] = {
-                "prompt": prompt,
-                **self._default_params,
-                **kwargs,
-            }
-            if stop:
-                params["stop_sequences"] = stop
-            response = await self.async_client.completions.create(**params)
-            completion = response.completion
+            return await _agenerate_from_stream(stream_iter)
+        prompt = self._convert_messages_to_prompt(
+            messages,
+        )
+        params: Dict[str, Any] = {
+            "prompt": prompt,
+            **self._default_params,
+            **kwargs,
+        }
+        if stop:
+            params["stop_sequences"] = stop
+        response = await self.async_client.completions.create(**params)
+        completion = response.completion
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])
 
diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index 61bf020823f..094418c94b6 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -49,6 +49,30 @@ def _get_verbosity() -> bool:
     return langchain.verbose
 
 
+def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
+    generation: Optional[ChatGenerationChunk] = None
+    for chunk in stream:
+        if generation is None:
+            generation = chunk
+        else:
+            generation += chunk
+    assert generation is not None
+    return ChatResult(generations=[generation])
+
+
+async def _agenerate_from_stream(
+    stream: AsyncIterator[ChatGenerationChunk],
+) -> ChatResult:
+    generation: Optional[ChatGenerationChunk] = None
+    async for chunk in stream:
+        if generation is None:
+            generation = chunk
+        else:
+            generation += chunk
+    assert generation is not None
+    return ChatResult(generations=[generation])
+
+
 class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     """Base class for Chat models."""
 
diff --git a/libs/langchain/langchain/chat_models/jinachat.py b/libs/langchain/langchain/chat_models/jinachat.py
index 7c1e6f90341..d0e80bd288c 100644
--- a/libs/langchain/langchain/chat_models/jinachat.py
+++ b/libs/langchain/langchain/chat_models/jinachat.py
@@ -27,7 +27,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import (
     AIMessage,
@@ -319,16 +323,10 @@ class JinaChat(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
+            stream_iter = self._stream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            )
+            return _generate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
@@ -384,16 +382,10 @@ class JinaChat(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            generation: Optional[ChatGenerationChunk] = None
-            async for chunk in self._astream(
+            stream_iter = self._astream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            )
+            return await _agenerate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
diff --git a/libs/langchain/langchain/chat_models/konko.py b/libs/langchain/langchain/chat_models/konko.py
index e27ee42057d..9735fb2ce5a 100644
--- a/libs/langchain/langchain/chat_models/konko.py
+++ b/libs/langchain/langchain/chat_models/konko.py
@@ -21,6 +21,7 @@ from langchain.adapters.openai import convert_dict_to_message, convert_message_t
 from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
+from langchain.chat_models.base import _generate_from_stream
 from langchain.chat_models.openai import ChatOpenAI, _convert_delta_to_message_chunk
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import ChatGeneration, ChatResult
@@ -224,16 +225,10 @@ class ChatKonko(ChatOpenAI):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
diff --git a/libs/langchain/langchain/chat_models/litellm.py b/libs/langchain/langchain/chat_models/litellm.py
index f9ecf67073a..f568cf2760c 100644
--- a/libs/langchain/langchain/chat_models/litellm.py
+++ b/libs/langchain/langchain/chat_models/litellm.py
@@ -19,7 +19,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.llms.base import create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import (
@@ -320,16 +324,10 @@ class ChatLiteLLM(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
@@ -421,16 +419,10 @@ class ChatLiteLLM(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            async for chunk in self._astream(
+            stream_iter = self._astream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            )
+            return await _agenerate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py
index 47f29eaf2ae..664201b008d 100644
--- a/libs/langchain/langchain/chat_models/openai.py
+++ b/libs/langchain/langchain/chat_models/openai.py
@@ -22,7 +22,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.llms.base import create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import ChatGeneration, ChatResult
@@ -330,17 +334,10 @@ class ChatOpenAI(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
-
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
         response = self.completion_with_retry(
@@ -411,16 +408,10 @@ class ChatOpenAI(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            async for chunk in self._astream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await _agenerate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
diff --git a/libs/langchain/langchain/chat_models/vertexai.py b/libs/langchain/langchain/chat_models/vertexai.py
index 79a80b6b4c4..0b0909f74d4 100644
--- a/libs/langchain/langchain/chat_models/vertexai.py
+++ b/libs/langchain/langchain/chat_models/vertexai.py
@@ -1,9 +1,11 @@
 """Wrapper around Google VertexAI chat-based models."""
+from __future__ import annotations
+
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import BaseChatModel, _generate_from_stream
 from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.pydantic_v1 import root_validator
 from langchain.schema import (
@@ -12,14 +14,21 @@ from langchain.schema import (
 )
 from langchain.schema.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     HumanMessage,
     SystemMessage,
 )
+from langchain.schema.output import ChatGenerationChunk
 from langchain.utilities.vertexai import raise_vertex_import_error
 
 if TYPE_CHECKING:
-    from vertexai.language_models import ChatMessage, InputOutputTextPair
+    from vertexai.language_models import (
+        ChatMessage,
+        ChatSession,
+        CodeChatSession,
+        InputOutputTextPair,
+    )
 
 
 @dataclass
@@ -91,10 +100,23 @@ def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
     return example_pairs
 
 
+def _get_question(messages: List[BaseMessage]) -> HumanMessage:
+    """Get the human message at the end of a list of input messages to a chat model."""
+    if not messages:
+        raise ValueError("You should provide at least one message to start the chat!")
+    question = messages[-1]
+    if not isinstance(question, HumanMessage):
+        raise ValueError(
+            f"Last message in the list should be from human, got {question.type}."
+        )
+    return question
+
+
 class ChatVertexAI(_VertexAICommon, BaseChatModel):
     """`Vertex AI` Chat large language models API."""
 
     model_name: str = "chat-bison"
+    streaming: bool = False
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -118,6 +140,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
         """Generate next turn in the conversation.
@@ -127,6 +150,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
                 does not support context.
             stop: The list of stop words (optional).
             run_manager: The CallbackManager for LLM run, it's not used at the moment.
+            stream: Whether to use the streaming endpoint.
 
         Returns:
             The ChatResult that contains outputs generated by the model.
@@ -134,27 +158,53 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         Raises:
             ValueError: if the last message in the list is not from human.
         """
-        if not messages:
-            raise ValueError(
-                "You should provide at least one message to start the chat!"
-            )
-        question = messages[-1]
-        if not isinstance(question, HumanMessage):
-            raise ValueError(
-                f"Last message in the list should be from human, got {question.type}."
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
             )
+            return _generate_from_stream(stream_iter)
+
+        question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        context = history.context if history.context else None
         params = {**self._default_params, **kwargs}
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
-        if not self.is_codey_model:
-            chat = self.client.start_chat(
-                context=context, message_history=history.history, **params
-            )
-        else:
-            chat = self.client.start_chat(message_history=history.history, **params)
+
+        chat = self._start_chat(history, params)
         response = chat.send_message(question.content)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        question = _get_question(messages)
+        history = _parse_chat_history(messages[:-1])
+        params = {**self._default_params, **kwargs}
+        examples = kwargs.get("examples", None)
+        if examples:
+            params["examples"] = _parse_examples(examples)
+
+        chat = self._start_chat(history, params)
+        responses = chat.send_message_streaming(question.content, **params)
+        for response in responses:
+            text = self._enforce_stop_words(response.text, stop)
+            if run_manager:
+                run_manager.on_llm_new_token(text)
+            yield ChatGenerationChunk(message=AIMessageChunk(content=text))
+
+    def _start_chat(
+        self, history: _ChatHistory, params: dict
+    ) -> Union[ChatSession, CodeChatSession]:
+        if not self.is_codey_model:
+            return self.client.start_chat(
+                context=history.context, message_history=history.history, **params
+            )
+        else:
+            return self.client.start_chat(message_history=history.history, **params)
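
For context, a minimal sketch (not part of the patch) of what the new _generate_from_stream helper in chat_models/base.py does: it folds a stream of ChatGenerationChunk objects into a single generation using the chunks' "+" operator and wraps the result in a ChatResult, which is exactly the accumulation loop this patch removes from each chat model. Assuming langchain at this revision is installed, the helper can be exercised directly with hand-built chunks:

    from langchain.chat_models.base import _generate_from_stream
    from langchain.schema.messages import AIMessageChunk
    from langchain.schema.output import ChatGenerationChunk

    # Two chunks such as a streaming backend might yield.
    chunks = iter(
        [
            ChatGenerationChunk(message=AIMessageChunk(content="Hello")),
            ChatGenerationChunk(message=AIMessageChunk(content=", world")),
        ]
    )
    # The helper concatenates the chunks with "+" and returns a ChatResult.
    result = _generate_from_stream(chunks)
    print(result.generations[0].message.content)  # -> "Hello, world"

The new ChatVertexAI streaming path feeds this same helper: when stream=True (or the new streaming field) is set, _generate delegates to _stream, which wraps each chat.send_message_streaming response in a ChatGenerationChunk and lets _generate_from_stream aggregate them.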