Stream refactor and vertex streaming (#10470)

---------

Co-authored-by: Terry Cruz Melo <tcruz@vozy.co>
Co-authored-by: Terry Cruz Melo <33166112+TerryCM@users.noreply.github.com>
Bagatur 2023-09-20 11:49:16 -07:00 committed by GitHub
parent f421af8b80
commit 0749a642f5
7 changed files with 168 additions and 120 deletions

langchain/chat_models/anthropic.py

@@ -4,7 +4,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.llms.anthropic import _AnthropicCommon
 from langchain.schema.messages import (
     AIMessage,
@@ -162,22 +166,22 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            completion = ""
-            for chunk in self._stream(messages, stop, run_manager, **kwargs):
-                completion += chunk.text
-        else:
-            prompt = self._convert_messages_to_prompt(
-                messages,
-            )
-            params: Dict[str, Any] = {
-                "prompt": prompt,
-                **self._default_params,
-                **kwargs,
-            }
-            if stop:
-                params["stop_sequences"] = stop
-            response = self.client.completions.create(**params)
-            completion = response.completion
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
+        prompt = self._convert_messages_to_prompt(
+            messages,
+        )
+        params: Dict[str, Any] = {
+            "prompt": prompt,
+            **self._default_params,
+            **kwargs,
+        }
+        if stop:
+            params["stop_sequences"] = stop
+        response = self.client.completions.create(**params)
+        completion = response.completion
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])
@@ -189,22 +193,22 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            completion = ""
-            async for chunk in self._astream(messages, stop, run_manager, **kwargs):
-                completion += chunk.text
-        else:
-            prompt = self._convert_messages_to_prompt(
-                messages,
-            )
-            params: Dict[str, Any] = {
-                "prompt": prompt,
-                **self._default_params,
-                **kwargs,
-            }
-            if stop:
-                params["stop_sequences"] = stop
-            response = await self.async_client.completions.create(**params)
-            completion = response.completion
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await _agenerate_from_stream(stream_iter)
+        prompt = self._convert_messages_to_prompt(
+            messages,
+        )
+        params: Dict[str, Any] = {
+            "prompt": prompt,
+            **self._default_params,
+            **kwargs,
+        }
+        if stop:
+            params["stop_sequences"] = stop
+        response = await self.async_client.completions.create(**params)
+        completion = response.completion
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])
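A minimal usage sketch of the refactored path from the caller's side (not part of this diff; assumes an ANTHROPIC_API_KEY in the environment and the anthropic SDK installed):

# Minimal sketch, not part of this commit.
from langchain.chat_models import ChatAnthropic
from langchain.schema.messages import HumanMessage

chat = ChatAnthropic(streaming=True)
# With streaming=True, _generate() now folds the _stream() chunks via
# _generate_from_stream() instead of concatenating chunk.text by hand.
result = chat([HumanMessage(content="Hello!")])
print(result.content)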

langchain/chat_models/base.py

@@ -49,6 +49,30 @@ def _get_verbosity() -> bool:
     return langchain.verbose
 
 
+def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
+    generation: Optional[ChatGenerationChunk] = None
+    for chunk in stream:
+        if generation is None:
+            generation = chunk
+        else:
+            generation += chunk
+    assert generation is not None
+    return ChatResult(generations=[generation])
+
+
+async def _agenerate_from_stream(
+    stream: AsyncIterator[ChatGenerationChunk],
+) -> ChatResult:
+    generation: Optional[ChatGenerationChunk] = None
+    async for chunk in stream:
+        if generation is None:
+            generation = chunk
+        else:
+            generation += chunk
+    assert generation is not None
+    return ChatResult(generations=[generation])
+
+
 class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     """Base class for Chat models."""

langchain/chat_models/jinachat.py

@@ -27,7 +27,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import (
     AIMessage,
@@ -319,16 +323,10 @@ class JinaChat(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
+            stream_iter = self._stream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            )
+            return _generate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
@@ -384,16 +382,10 @@ class JinaChat(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            generation: Optional[ChatGenerationChunk] = None
-            async for chunk in self._astream(
+            stream_iter = self._astream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            )
+            return await _agenerate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}

langchain/chat_models/konko.py

@@ -21,6 +21,7 @@ from langchain.adapters.openai import convert_dict_to_message, convert_message_t
 from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
+from langchain.chat_models.base import _generate_from_stream
 from langchain.chat_models.openai import ChatOpenAI, _convert_delta_to_message_chunk
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import ChatGeneration, ChatResult
@@ -224,16 +225,10 @@ class ChatKonko(ChatOpenAI):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}

langchain/chat_models/litellm.py

@@ -19,7 +19,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.llms.base import create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import (
@@ -320,16 +324,10 @@ class ChatLiteLLM(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
@@ -421,16 +419,10 @@ class ChatLiteLLM(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            async for chunk in self._astream(
+            stream_iter = self._astream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            )
+            return await _agenerate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}

langchain/chat_models/openai.py

@@ -22,7 +22,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.llms.base import create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import ChatGeneration, ChatResult
@@ -330,17 +334,10 @@ class ChatOpenAI(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
-
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
         response = self.completion_with_retry(
@@ -411,16 +408,10 @@ class ChatOpenAI(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            async for chunk in self._astream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await _agenerate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
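A sketch of the per-call override this hunk relies on: the `stream` argument to `_generate()` takes precedence over the instance-level `streaming` flag. `_generate()` is internal and shown here only to illustrate the `should_stream` logic (assumes OPENAI_API_KEY is set):

# Sketch, not part of this commit: per-call stream flag beats the
# instance default.
from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import HumanMessage

chat = ChatOpenAI()  # streaming defaults to False
result = chat._generate([HumanMessage(content="Hi!")], stream=True)
print(result.generations[0].message.content)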

langchain/chat_models/vertexai.py

@@ -1,9 +1,11 @@
 """Wrapper around Google VertexAI chat-based models."""
+from __future__ import annotations
+
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import BaseChatModel, _generate_from_stream
 from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.pydantic_v1 import root_validator
 from langchain.schema import (
@@ -12,14 +14,21 @@ from langchain.schema import (
 )
 from langchain.schema.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     HumanMessage,
     SystemMessage,
 )
+from langchain.schema.output import ChatGenerationChunk
 from langchain.utilities.vertexai import raise_vertex_import_error
 
 if TYPE_CHECKING:
-    from vertexai.language_models import ChatMessage, InputOutputTextPair
+    from vertexai.language_models import (
+        ChatMessage,
+        ChatSession,
+        CodeChatSession,
+        InputOutputTextPair,
+    )
 
 
 @dataclass
@@ -91,10 +100,23 @@ def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
     return example_pairs
 
 
+def _get_question(messages: List[BaseMessage]) -> HumanMessage:
+    """Get the human message at the end of a list of input messages to a chat model."""
+    if not messages:
+        raise ValueError("You should provide at least one message to start the chat!")
+    question = messages[-1]
+    if not isinstance(question, HumanMessage):
+        raise ValueError(
+            f"Last message in the list should be from human, got {question.type}."
+        )
+    return question
+
+
 class ChatVertexAI(_VertexAICommon, BaseChatModel):
     """`Vertex AI` Chat large language models API."""
 
     model_name: str = "chat-bison"
+    streaming: bool = False
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -118,6 +140,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
         """Generate next turn in the conversation.
@@ -127,6 +150,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
                 does not support context.
             stop: The list of stop words (optional).
             run_manager: The CallbackManager for LLM run, it's not used at the moment.
+            stream: Whether to use the streaming endpoint.
 
         Returns:
             The ChatResult that contains outputs generated by the model.
@@ -134,27 +158,53 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         Raises:
             ValueError: if the last message in the list is not from human.
         """
-        if not messages:
-            raise ValueError(
-                "You should provide at least one message to start the chat!"
-            )
-        question = messages[-1]
-        if not isinstance(question, HumanMessage):
-            raise ValueError(
-                f"Last message in the list should be from human, got {question.type}."
-            )
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
+
+        question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        context = history.context if history.context else None
         params = {**self._default_params, **kwargs}
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
-        if not self.is_codey_model:
-            chat = self.client.start_chat(
-                context=context, message_history=history.history, **params
-            )
-        else:
-            chat = self.client.start_chat(message_history=history.history, **params)
+        chat = self._start_chat(history, params)
         response = chat.send_message(question.content)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        question = _get_question(messages)
+        history = _parse_chat_history(messages[:-1])
+        params = {**self._default_params, **kwargs}
+        examples = kwargs.get("examples", None)
+        if examples:
+            params["examples"] = _parse_examples(examples)
+        chat = self._start_chat(history, params)
+        responses = chat.send_message_streaming(question.content, **params)
+        for response in responses:
+            text = self._enforce_stop_words(response.text, stop)
+            if run_manager:
+                run_manager.on_llm_new_token(text)
+            yield ChatGenerationChunk(message=AIMessageChunk(content=text))
+
+    def _start_chat(
+        self, history: _ChatHistory, params: dict
+    ) -> Union[ChatSession, CodeChatSession]:
+        if not self.is_codey_model:
+            return self.client.start_chat(
+                context=history.context, message_history=history.history, **params
+            )
+        else:
+            return self.client.start_chat(message_history=history.history, **params)