Stream refac and vertex streaming (#10470)

---------

Co-authored-by: Terry Cruz Melo <tcruz@vozy.co>
Co-authored-by: Terry Cruz Melo <33166112+TerryCM@users.noreply.github.com>

parent f421af8b80
commit 0749a642f5
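
This commit centralizes the chat models' stream-folding logic: two shared helpers, `_generate_from_stream` and `_agenerate_from_stream`, are added to `langchain.chat_models.base`, and `ChatAnthropic`, `JinaChat`, `ChatKonko`, `ChatLiteLLM`, and `ChatOpenAI` are rewritten to delegate to them instead of accumulating chunks inline. `ChatVertexAI` additionally gains a `streaming` flag, a `_stream` implementation built on Vertex AI's `send_message_streaming`, and a `_start_chat` helper shared by the sync and streaming paths.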
libs/langchain/langchain/chat_models/anthropic.py

@@ -4,7 +4,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.llms.anthropic import _AnthropicCommon
 from langchain.schema.messages import (
     AIMessage,
@@ -162,22 +166,22 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            completion = ""
-            for chunk in self._stream(messages, stop, run_manager, **kwargs):
-                completion += chunk.text
-        else:
-            prompt = self._convert_messages_to_prompt(
-                messages,
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
             )
-            params: Dict[str, Any] = {
-                "prompt": prompt,
-                **self._default_params,
-                **kwargs,
-            }
-            if stop:
-                params["stop_sequences"] = stop
-            response = self.client.completions.create(**params)
-            completion = response.completion
+            return _generate_from_stream(stream_iter)
+        prompt = self._convert_messages_to_prompt(
+            messages,
+        )
+        params: Dict[str, Any] = {
+            "prompt": prompt,
+            **self._default_params,
+            **kwargs,
+        }
+        if stop:
+            params["stop_sequences"] = stop
+        response = self.client.completions.create(**params)
+        completion = response.completion
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])

@@ -189,22 +193,22 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            completion = ""
-            async for chunk in self._astream(messages, stop, run_manager, **kwargs):
-                completion += chunk.text
-        else:
-            prompt = self._convert_messages_to_prompt(
-                messages,
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
             )
-            params: Dict[str, Any] = {
-                "prompt": prompt,
-                **self._default_params,
-                **kwargs,
-            }
-            if stop:
-                params["stop_sequences"] = stop
-            response = await self.async_client.completions.create(**params)
-            completion = response.completion
+            return await _agenerate_from_stream(stream_iter)
+        prompt = self._convert_messages_to_prompt(
+            messages,
+        )
+        params: Dict[str, Any] = {
+            "prompt": prompt,
+            **self._default_params,
+            **kwargs,
+        }
+        if stop:
+            params["stop_sequences"] = stop
+        response = await self.async_client.completions.create(**params)
+        completion = response.completion
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])
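The streaming branch previously flattened the chunks into a plain string and rebuilt the `AIMessage` by hand; it now returns early with a `ChatResult` folded from the chunks themselves, preserving any per-chunk metadata. A hedged usage sketch of the new path (assumes the `anthropic` package is installed and `ANTHROPIC_API_KEY` is set; not part of this diff):

# Minimal sketch: with streaming=True, generate() now goes through
# _stream() and _generate_from_stream() under the hood.
from langchain.chat_models import ChatAnthropic
from langchain.schema.messages import HumanMessage

chat = ChatAnthropic(streaming=True)
result = chat.generate([[HumanMessage(content="Say hello in one word.")]])
print(result.generations[0][0].text)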
libs/langchain/langchain/chat_models/base.py

@@ -49,6 +49,30 @@ def _get_verbosity() -> bool:
     return langchain.verbose


+def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
+    generation: Optional[ChatGenerationChunk] = None
+    for chunk in stream:
+        if generation is None:
+            generation = chunk
+        else:
+            generation += chunk
+    assert generation is not None
+    return ChatResult(generations=[generation])
+
+
+async def _agenerate_from_stream(
+    stream: AsyncIterator[ChatGenerationChunk],
+) -> ChatResult:
+    generation: Optional[ChatGenerationChunk] = None
+    async for chunk in stream:
+        if generation is None:
+            generation = chunk
+        else:
+            generation += chunk
+    assert generation is not None
+    return ChatResult(generations=[generation])
+
+
 class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     """Base class for Chat models."""
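The helper simply folds the chunk stream with `+`, which concatenates message content and merges `generation_info`. A minimal sketch of what it does, using hand-built chunks (illustrative only, not part of the diff):

# Fold a stream of ChatGenerationChunk objects into one ChatResult.
# The chunks here are constructed by hand purely for illustration.
from langchain.chat_models.base import _generate_from_stream
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatGenerationChunk

chunks = iter(
    [
        ChatGenerationChunk(message=AIMessageChunk(content="Hello, ")),
        ChatGenerationChunk(message=AIMessageChunk(content="world!")),
    ]
)
result = _generate_from_stream(chunks)
print(result.generations[0].text)  # -> "Hello, world!"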
libs/langchain/langchain/chat_models/jinachat.py

@@ -27,7 +27,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import (
     AIMessage,
@@ -319,16 +323,10 @@ class JinaChat(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
+            stream_iter = self._stream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            )
+            return _generate_from_stream(stream_iter)

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
@@ -384,16 +382,10 @@ class JinaChat(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
-            generation: Optional[ChatGenerationChunk] = None
-            async for chunk in self._astream(
+            stream_iter = self._astream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            )
+            return await _agenerate_from_stream(stream_iter)

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
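`JinaChat` had already folded chunks with an inline `generation += chunk` loop; the refactor deletes that duplicated loop in both `_generate` and `_agenerate` and delegates to the shared helpers shown above.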
libs/langchain/langchain/chat_models/konko.py

@@ -21,6 +21,7 @@ from langchain.adapters.openai import convert_dict_to_message, convert_message_t
 from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
+from langchain.chat_models.base import _generate_from_stream
 from langchain.chat_models.openai import ChatOpenAI, _convert_delta_to_message_chunk
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import ChatGeneration, ChatResult
@@ -224,16 +225,10 @@ class ChatKonko(ChatOpenAI):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
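`ChatKonko` subclasses `ChatOpenAI`, and only its synchronous path is touched here, which is why just `_generate_from_stream` is imported in this file.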
libs/langchain/langchain/chat_models/litellm.py

@@ -19,7 +19,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.llms.base import create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import (
@@ -320,16 +324,10 @@ class ChatLiteLLM(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
@@ -421,16 +419,10 @@ class ChatLiteLLM(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            async for chunk in self._astream(
+            stream_iter = self._astream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            )
+            return await _agenerate_from_stream(stream_iter)

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
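The async path mirrors the sync one: `_astream` is wrapped by `_agenerate_from_stream` and awaited. A hedged sketch of hitting that path (the model name is illustrative; `litellm` and a provider API key are assumed):

import asyncio

from langchain.chat_models import ChatLiteLLM
from langchain.schema.messages import HumanMessage

async def main() -> None:
    # streaming=True routes agenerate() through _astream and
    # _agenerate_from_stream.
    chat = ChatLiteLLM(model="gpt-3.5-turbo", streaming=True)
    result = await chat.agenerate([[HumanMessage(content="Say hi")]])
    print(result.generations[0][0].text)

asyncio.run(main())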
libs/langchain/langchain/chat_models/openai.py

@@ -22,7 +22,11 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream,
+)
 from langchain.llms.base import create_base_retry_decorator
 from langchain.pydantic_v1 import Field, root_validator
 from langchain.schema import ChatGeneration, ChatResult
@@ -330,17 +334,10 @@ class ChatOpenAI(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            for chunk in self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
-
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
         response = self.completion_with_retry(
@@ -411,16 +408,10 @@ class ChatOpenAI(BaseChatModel):
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
-            generation: Optional[ChatGenerationChunk] = None
-            async for chunk in self._astream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
-                if generation is None:
-                    generation = chunk
-                else:
-                    generation += chunk
-            assert generation is not None
-            return ChatResult(generations=[generation])
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await _agenerate_from_stream(stream_iter)

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
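`ChatOpenAI` (like `ChatKonko` and `ChatLiteLLM`) resolves `should_stream` from a per-call `stream` argument, falling back to the instance's `streaming` flag, so a single call can opt in to the streaming path. A hedged sketch (assumes the `openai` package and `OPENAI_API_KEY`; relies on extra `generate()` kwargs being forwarded to `_generate()`):

from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import HumanMessage

chat = ChatOpenAI()  # streaming defaults to False
# stream=True forces this one call through _generate_from_stream,
# even though the instance was not constructed with streaming=True.
result = chat.generate([[HumanMessage(content="Say hi")]], stream=True)
print(result.generations[0][0].text)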
libs/langchain/langchain/chat_models/vertexai.py

@@ -1,9 +1,11 @@
 """Wrapper around Google VertexAI chat-based models."""
+from __future__ import annotations
+
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import BaseChatModel, _generate_from_stream
 from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.pydantic_v1 import root_validator
 from langchain.schema import (
@@ -12,14 +14,21 @@ from langchain.schema import (
 )
 from langchain.schema.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     HumanMessage,
     SystemMessage,
 )
+from langchain.schema.output import ChatGenerationChunk
 from langchain.utilities.vertexai import raise_vertex_import_error

 if TYPE_CHECKING:
-    from vertexai.language_models import ChatMessage, InputOutputTextPair
+    from vertexai.language_models import (
+        ChatMessage,
+        ChatSession,
+        CodeChatSession,
+        InputOutputTextPair,
+    )


 @dataclass
@@ -91,10 +100,23 @@ def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
     return example_pairs


+def _get_question(messages: List[BaseMessage]) -> HumanMessage:
+    """Get the human message at the end of a list of input messages to a chat model."""
+    if not messages:
+        raise ValueError("You should provide at least one message to start the chat!")
+    question = messages[-1]
+    if not isinstance(question, HumanMessage):
+        raise ValueError(
+            f"Last message in the list should be from human, got {question.type}."
+        )
+    return question
+
+
 class ChatVertexAI(_VertexAICommon, BaseChatModel):
     """`Vertex AI` Chat large language models API."""

     model_name: str = "chat-bison"
+    streaming: bool = False

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -118,6 +140,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
         """Generate next turn in the conversation.
@@ -127,6 +150,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
                 does not support context.
             stop: The list of stop words (optional).
             run_manager: The CallbackManager for LLM run, it's not used at the moment.
+            stream: Whether to use the streaming endpoint.

         Returns:
             The ChatResult that contains outputs generated by the model.
@@ -134,27 +158,53 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         Raises:
             ValueError: if the last message in the list is not from human.
         """
-        if not messages:
-            raise ValueError(
-                "You should provide at least one message to start the chat!"
-            )
-        question = messages[-1]
-        if not isinstance(question, HumanMessage):
-            raise ValueError(
-                f"Last message in the list should be from human, got {question.type}."
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
             )
+            return _generate_from_stream(stream_iter)
+
+        question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        context = history.context if history.context else None
         params = {**self._default_params, **kwargs}
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
-        if not self.is_codey_model:
-            chat = self.client.start_chat(
-                context=context, message_history=history.history, **params
-            )
-        else:
-            chat = self.client.start_chat(message_history=history.history, **params)
+
+        chat = self._start_chat(history, params)
         response = chat.send_message(question.content)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        question = _get_question(messages)
+        history = _parse_chat_history(messages[:-1])
+        params = {**self._default_params, **kwargs}
+        examples = kwargs.get("examples", None)
+        if examples:
+            params["examples"] = _parse_examples(examples)
+
+        chat = self._start_chat(history, params)
+        responses = chat.send_message_streaming(question.content, **params)
+        for response in responses:
+            text = self._enforce_stop_words(response.text, stop)
+            if run_manager:
+                run_manager.on_llm_new_token(text)
+            yield ChatGenerationChunk(message=AIMessageChunk(content=text))
+
+    def _start_chat(
+        self, history: _ChatHistory, params: dict
+    ) -> Union[ChatSession, CodeChatSession]:
+        if not self.is_codey_model:
+            return self.client.start_chat(
+                context=history.context, message_history=history.history, **params
+            )
+        else:
+            return self.client.start_chat(message_history=history.history, **params)
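With `_stream` in place, `ChatVertexAI` can push tokens through the standard chat-model streaming interface. A hedged usage sketch (assumes a GCP project with Vertex AI enabled and `google-cloud-aiplatform` installed; not part of the diff):

from langchain.chat_models import ChatVertexAI
from langchain.schema.messages import HumanMessage

chat = ChatVertexAI()  # model_name defaults to "chat-bison"
# stream() drives the new _stream() implementation, yielding
# message chunks as Vertex AI returns them.
for chunk in chat.stream([HumanMessage(content="Tell me a short joke.")]):
    print(chunk.content, end="", flush=True)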