"""Wrapper around Google VertexAI chat-based models."""
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from pydantic import root_validator
|
|
|
|
from langchain.callbacks.manager import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain.chat_models.base import BaseChatModel
|
|
from langchain.llms.vertexai import _VertexAICommon
|
|
from langchain.schema import (
|
|
AIMessage,
|
|
BaseMessage,
|
|
ChatGeneration,
|
|
ChatResult,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
from langchain.utilities.vertexai import raise_vertex_import_error
|
|
|
|
|
|
@dataclass
class _MessagePair:
    """A pair of question and answer messages, analogous to Vertex AI's
    InputOutputTextPair."""

    question: HumanMessage
    answer: AIMessage


@dataclass
class _ChatHistory:
    """A chat history: question/answer message pairs plus an optional
    system message."""

    history: List[_MessagePair] = field(default_factory=list)
    system_message: Optional[SystemMessage] = None


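# For illustration (message contents are made up), a one-turn conversation with
# a system prompt maps onto these dataclasses as:
#
#     _ChatHistory(
#         history=[
#             _MessagePair(
#                 question=HumanMessage(content="What is 2+2?"),
#                 answer=AIMessage(content="4"),
#             )
#         ],
#         system_message=SystemMessage(content="You are a calculator."),
#     )

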
def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
    """Parse a sequence of messages into a structured chat history.

    A sequence should be either (SystemMessage, HumanMessage, AIMessage,
    HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
    AIMessage, ...).

    Args:
        history: The list of messages to re-create the history of the chat.
    Returns:
        A parsed chat history.
    Raises:
        ValueError: If the number of messages (after an optional leading
            SystemMessage) is odd, or if a human message is not followed by
            an AI message (e.g., Human, Human, AI or AI, AI, Human).
    """
    if not history:
        return _ChatHistory()
    first_message = history[0]
    system_message = first_message if isinstance(first_message, SystemMessage) else None
    chat_history = _ChatHistory(system_message=system_message)
    messages_left = history[1:] if system_message else history
    if len(messages_left) % 2 != 0:
        raise ValueError(
            f"Number of messages in history should be even, got {len(messages_left)}!"
        )
    # Walk the remaining messages two at a time as (human question, AI answer).
    for question, answer in zip(messages_left[::2], messages_left[1::2]):
        if not isinstance(question, HumanMessage) or not isinstance(answer, AIMessage):
            raise ValueError(
                "A human message should be followed by an AI message, "
                f"got {question.type}, {answer.type}."
            )
        chat_history.history.append(_MessagePair(question=question, answer=answer))
    return chat_history


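# A minimal usage sketch of ``_parse_chat_history`` (illustrative only; the
# message contents are made up). An alternating conversation is split into
# (question, answer) pairs, with any leading SystemMessage pulled out:
#
#     parsed = _parse_chat_history(
#         [
#             SystemMessage(content="You are a helpful bot."),
#             HumanMessage(content="Hi!"),
#             AIMessage(content="Hello! How can I help?"),
#         ]
#     )
#     assert parsed.system_message.content == "You are a helpful bot."
#     assert len(parsed.history) == 1

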
class ChatVertexAI(_VertexAICommon, BaseChatModel):
    """Wrapper around Vertex AI large language models."""

    model_name: str = "chat-bison"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the required python package exists in the environment."""
        cls._try_init_vertexai(values)
        try:
            from vertexai.preview.language_models import ChatModel
        except ImportError:
            raise_vertex_import_error()
        values["client"] = ChatModel.from_pretrained(values["model_name"])
        return values
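    # Note: the root validator above runs when the model is instantiated, so
    # constructing a ``ChatVertexAI`` eagerly initializes the Vertex AI SDK and
    # loads the underlying ``ChatModel`` client into ``self.client``.
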
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate the next turn in the conversation.

        Args:
            messages: The history of the conversation as a list of messages.
            stop: The list of stop words (optional).
            run_manager: The CallbackManager for the LLM run; it is not used
                at the moment.

        Returns:
            The ChatResult that contains outputs generated by the model.

        Raises:
            ValueError: if the last message in the list is not from a human.
        """
        if not messages:
            raise ValueError(
                "You should provide at least one message to start the chat!"
            )
        question = messages[-1]
        if not isinstance(question, HumanMessage):
            raise ValueError(
                f"Last message in the list should be from a human, got {question.type}."
            )

        history = _parse_chat_history(messages[:-1])
        context = history.system_message.content if history.system_message else None
        params = {**self._default_params, **kwargs}
        chat = self.client.start_chat(context=context, **params)
        # Replay the prior turns into the SDK chat session. This relies on the
        # private ``_history`` attribute of the preview ChatSession.
        for pair in history.history:
            chat._history.append((pair.question.content, pair.answer.content))
        # Generation parameters (including ``kwargs`` overrides) were already
        # supplied to ``start_chat``; passing ``self._default_params`` again
        # here would silently override those per-call overrides.
        response = chat.send_message(question.content)
        text = self._enforce_stop_words(response.text, stop)
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise NotImplementedError(
            "Vertex AI doesn't support async requests at the moment."
        )
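
# A minimal end-to-end usage sketch (kept as a comment so importing this module
# has no side effects). It assumes Google Cloud credentials and the
# ``google-cloud-aiplatform`` package are already set up, and that
# ``ChatVertexAI`` is re-exported from ``langchain.chat_models``:
#
#     from langchain.chat_models import ChatVertexAI
#     from langchain.schema import HumanMessage, SystemMessage
#
#     chat = ChatVertexAI()  # defaults to the "chat-bison" model
#     result = chat(
#         [
#             SystemMessage(content="You translate English to French."),
#             HumanMessage(content="I love programming."),
#         ]
#     )
#     print(result.content)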