mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 02:53:16 +00:00
# Implemented appending arbitrary messages to the base chat message history, the in-memory and cosmos ones. <!-- Thank you for contributing to LangChain! Your PR will appear in our next release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. --> As discussed, this is the alternative to #4480: an add_message method is added that takes a BaseMessage as input, so that the user can control what is in the base message, such as kwargs. <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting <!-- If you're adding a new integration, include an integration test and an example notebook showing its use! --> ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: @hwchase17 --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
402 lines
11 KiB
Python
402 lines
11 KiB
Python
"""Common schema objects."""
|
|
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
Generic,
|
|
List,
|
|
NamedTuple,
|
|
Optional,
|
|
Sequence,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
from pydantic import BaseModel, Extra, Field, root_validator
|
|
|
|
|
|
def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render messages as a newline-separated ``"<role>: <content>"`` transcript.

    Args:
        messages: Messages to render, in order.
        human_prefix: Role label used for ``HumanMessage`` instances.
        ai_prefix: Role label used for ``AIMessage`` instances.

    Returns:
        One line per message joined with newlines; an empty string when
        ``messages`` is empty.

    Raises:
        ValueError: If a message is not a ``HumanMessage``, ``AIMessage``,
            ``SystemMessage`` or ``ChatMessage``.
    """

    def _role_for(msg: BaseMessage) -> str:
        # Checked in a fixed order; the first matching type decides the label.
        if isinstance(msg, HumanMessage):
            return human_prefix
        if isinstance(msg, AIMessage):
            return ai_prefix
        if isinstance(msg, SystemMessage):
            return "System"
        if isinstance(msg, ChatMessage):
            return msg.role
        raise ValueError(f"Got unsupported message type: {msg}")

    return "\n".join(f"{_role_for(msg)}: {msg.content}" for msg in messages)
|
|
|
|
|
|
class AgentAction(NamedTuple):
    """Action an agent has decided to take.

    A three-field tuple: the tool, the input for that tool, and a log string.
    """

    # Tool identifier (a plain string).
    tool: str
    # Input handed to the tool: either a string or a dict.
    tool_input: Union[str, dict]
    # Free-form log text stored alongside the action.
    log: str
|
|
|
|
|
|
class AgentFinish(NamedTuple):
    """Final return value of an agent.

    A two-field tuple: the returned values and a log string.
    """

    # Values the agent returns, as a dict.
    return_values: dict
    # Free-form log text stored alongside the return values.
    log: str
|
|
|
|
|
|
class Generation(BaseModel):
    """One generated output."""

    text: str
    """The generated text."""

    generation_info: Optional[Dict[str, Any]] = None
    """Raw generation info response from the provider.

    May include things like the reason for finishing (e.g. in OpenAI).
    """
    # TODO: add log probs
|
|
|
|
|
|
class BaseMessage(BaseModel):
    """Common base for every message object.

    Concrete subclasses must provide the ``type`` property.
    """

    # Text content of the message.
    content: str
    # Extra keyword data attached to the message; defaults to a fresh dict.
    additional_kwargs: dict = Field(default_factory=dict)

    @property
    @abstractmethod
    def type(self) -> str:
        """Return the type tag of the message, used for serialization."""
|
|
|
|
|
|
class HumanMessage(BaseMessage):
    """Message spoken by the human."""

    # NOTE(review): presumably marks the message as part of an example
    # conversation -- confirm against callers.
    example: bool = False

    @property
    def type(self) -> str:
        """Return ``"human"``, the serialization tag for this message type."""
        return "human"
|
|
|
|
|
|
class AIMessage(BaseMessage):
    """Message spoken by the AI."""

    # NOTE(review): presumably marks the message as part of an example
    # conversation -- confirm against callers.
    example: bool = False

    @property
    def type(self) -> str:
        """Return ``"ai"``, the serialization tag for this message type."""
        return "ai"
|
|
|
|
|
|
class SystemMessage(BaseMessage):
    """System message."""

    @property
    def type(self) -> str:
        """Return ``"system"``, the serialization tag for this message type."""
        return "system"
|
|
|
|
|
|
class ChatMessage(BaseMessage):
    """Message attributed to an arbitrary speaker."""

    # Name of the speaker's role.
    role: str

    @property
    def type(self) -> str:
        """Return ``"chat"``, the serialization tag for this message type."""
        return "chat"
|
|
|
|
|
|
def _message_to_dict(message: BaseMessage) -> dict:
    """Serialize one message to ``{"type": <type tag>, "data": <fields>}``."""
    type_tag = message.type
    payload = message.dict()
    return {"type": type_tag, "data": payload}
|
|
|
|
|
|
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
    """Serialize each message, in order, with ``_message_to_dict``."""
    return list(map(_message_to_dict, messages))
|
|
|
|
|
|
def _message_from_dict(message: dict) -> BaseMessage:
    """Rebuild a message object from its ``{"type", "data"}`` dictionary.

    Args:
        message: Dictionary with a ``"type"`` tag and a ``"data"`` mapping of
            constructor keyword arguments.

    Returns:
        An instance of the message class matching the tag.

    Raises:
        ValueError: If the tag is not ``"human"``, ``"ai"``, ``"system"``
            or ``"chat"``.
    """
    _type = message["type"]
    # Ordered (tag, class) pairs, compared with ``==`` one after another.
    known_types = (
        ("human", HumanMessage),
        ("ai", AIMessage),
        ("system", SystemMessage),
        ("chat", ChatMessage),
    )
    for tag, message_cls in known_types:
        if _type == tag:
            return message_cls(**message["data"])
    raise ValueError(f"Got unexpected type: {_type}")
|
|
|
|
|
|
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
    """Rebuild each message, in order, with ``_message_from_dict``."""
    return list(map(_message_from_dict, messages))
|
|
|
|
|
|
class ChatGeneration(Generation):
    """Single generation that carries a message.

    ``text`` is filled in from the message content by ``set_text``.
    """

    text = ""
    message: BaseMessage

    @root_validator
    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Copy the content of ``message`` into ``text``."""
        message = values["message"]
        values["text"] = message.content
        return values
|
|
|
|
|
|
class ChatResult(BaseModel):
    """Container for all relevant information of a chat result."""

    generations: List[ChatGeneration]
    """The generated items."""

    llm_output: Optional[dict] = None
    """Arbitrary output specific to the LLM provider."""
|
|
|
|
|
|
class LLMResult(BaseModel):
    """Container for all relevant information of an LLM result."""

    generations: List[List[Generation]]
    """The generated items.

    This is a list of lists because each input could have multiple
    generations.
    """

    llm_output: Optional[dict] = None
    """Arbitrary output specific to the LLM provider."""
|
|
|
|
|
|
class PromptValue(BaseModel, ABC):
    """Abstract prompt that can be rendered as a string or as messages."""

    @abstractmethod
    def to_string(self) -> str:
        """Return the prompt as a string."""

    @abstractmethod
    def to_messages(self) -> List[BaseMessage]:
        """Return the prompt as a list of messages."""
|
|
|
|
|
|
class BaseMemory(BaseModel, ABC):
    """Abstract interface for memory used in chains."""

    class Config:
        """Pydantic configuration: forbid extra fields, allow arbitrary types."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def memory_variables(self) -> List[str]:
        """Return the input keys this memory class will load dynamically."""

    @abstractmethod
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return key-value pairs given the text input to the chain.

        If None, return all memories.
        """

    @abstractmethod
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this model run to memory."""

    @abstractmethod
    def clear(self) -> None:
        """Clear the memory contents."""
|
|
|
|
|
|
class BaseChatMessageHistory(ABC):
    """Base interface for chat message history.

    See `ChatMessageHistory` for default implementation.

    Subclasses must implement ``clear`` and should override ``add_message``.
    The convenience methods ``add_user_message`` and ``add_ai_message`` wrap
    their text in a ``HumanMessage`` / ``AIMessage`` and delegate to
    ``add_message``.

    Example:
        .. code-block:: python

            class FileChatMessageHistory(BaseChatMessageHistory):
                storage_path: str
                session_id: str

                @property
                def messages(self):
                    path = os.path.join(self.storage_path, self.session_id)
                    with open(path, "r", encoding="utf-8") as f:
                        messages = json.loads(f.read())
                    return messages_from_dict(messages)

                def add_message(self, message: BaseMessage) -> None:
                    messages = messages_to_dict(self.messages)
                    messages.append(_message_to_dict(message))
                    path = os.path.join(self.storage_path, self.session_id)
                    with open(path, "w", encoding="utf-8") as f:
                        json.dump(messages, f)

                def clear(self):
                    path = os.path.join(self.storage_path, self.session_id)
                    with open(path, "w", encoding="utf-8") as f:
                        f.write("[]")
    """

    # Messages currently held by the store, as message objects.
    messages: List[BaseMessage]

    def add_user_message(self, message: str) -> None:
        """Add a user message to the store.

        Args:
            message: Text that is wrapped in a ``HumanMessage``.
        """
        self.add_message(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        """Add an AI message to the store.

        Args:
            message: Text that is wrapped in an ``AIMessage``.
        """
        self.add_message(AIMessage(content=message))

    def add_message(self, message: BaseMessage) -> None:
        """Add a self-created message to the store.

        Args:
            message: The message object to store.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this
                method. It is deliberately not marked abstract.
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Remove all messages from the store."""
|
|
|
|
|
|
class Document(BaseModel):
    """A document: page content plus a metadata dict."""

    # Text content of the document.
    page_content: str
    # Metadata attached to the document; defaults to a fresh dict.
    metadata: dict = Field(default_factory=dict)
|
|
|
|
|
|
class BaseRetriever(ABC):
    """Abstract interface for fetching documents relevant to a query.

    Subclasses must implement both the synchronous and the asynchronous
    lookup method.
    """

    @abstractmethod
    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get the documents relevant for a query.

        Args:
            query: String to find relevant documents for.

        Returns:
            List of relevant documents.
        """

    @abstractmethod
    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Asynchronously get the documents relevant for a query.

        Args:
            query: String to find relevant documents for.

        Returns:
            List of relevant documents.
        """
|
|
|
|
|
|
# Alias kept for backwards compatibility.
Memory = BaseMemory

# Type variable for the value returned by ``BaseOutputParser.parse``.
T = TypeVar("T")
|
|
|
|
|
|
class BaseOutputParser(BaseModel, ABC, Generic[T]):
    """Class to parse the output of an LLM call.

    Output parsers help structure language model responses.
    """

    @abstractmethod
    def parse(self, text: str) -> T:
        """Parse the output of an LLM call.

        A method which takes in a string (assumed output of a language model)
        and parses it into some structure.

        Args:
            text: output of language model

        Returns:
            structured output
        """

    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
        """Optional method to parse the output of an LLM call with a prompt.

        The prompt is largely provided in the event the OutputParser wants
        to retry or fix the output in some way, and needs information from
        the prompt to do so. This default implementation ignores the prompt
        and delegates to ``parse``.

        Args:
            completion: output of language model
            prompt: prompt value

        Returns:
            structured output
        """
        return self.parse(completion)

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @property
    def _type(self) -> str:
        """Return the type key.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"_type property is not implemented in class {self.__class__.__name__}."
            " This is required for serialization."
        )

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of output parser.

        Args:
            **kwargs: Forwarded unchanged to the parent ``dict`` method.

        Returns:
            The parent's dictionary with an added ``"_type"`` entry.

        Raises:
            NotImplementedError: If the subclass does not implement ``_type``.
        """
        # Forward kwargs so caller-supplied options are not silently dropped.
        output_parser_dict = super().dict(**kwargs)
        output_parser_dict["_type"] = self._type
        return output_parser_dict
|
|
|
|
|
|
class OutputParserException(ValueError):
    """Exception that output parsers should raise to signify a parsing error.

    This exists to differentiate parsing errors from other code or execution
    errors that also may arise inside the output parser.
    OutputParserExceptions will be available to catch and handle in ways to
    fix the parsing error, while other errors will be raised.
    """

    def __init__(
        self,
        error: Any,
        observation: str | None = None,
        llm_output: str | None = None,
        send_to_llm: bool = False,
    ):
        """Store the error together with optional context for the LLM.

        Args:
            error: Passed to the ``ValueError`` constructor.
            observation: Stored on the instance as ``observation``.
            llm_output: Stored on the instance as ``llm_output``.
            send_to_llm: Stored on the instance as ``send_to_llm``; when
                truthy, both ``observation`` and ``llm_output`` are required.

        Raises:
            ValueError: If ``send_to_llm`` is truthy while ``observation`` or
                ``llm_output`` is ``None``.
        """
        super().__init__(error)
        missing_context = observation is None or llm_output is None
        if send_to_llm and missing_context:
            raise ValueError(
                "Arguments 'observation' & 'llm_output'"
                " are required if 'send_to_llm' is True"
            )
        self.observation = observation
        self.llm_output = llm_output
        self.send_to_llm = send_to_llm
|
|
|
|
|
|
class BaseDocumentTransformer(ABC):
    """Abstract interface for document transformations.

    Subclasses must implement both the synchronous and the asynchronous
    transformation method.
    """

    @abstractmethod
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform a sequence of documents.

        Args:
            documents: The documents to transform.
            **kwargs: Additional implementation-specific options.

        Returns:
            The transformed documents.
        """

    @abstractmethod
    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents.

        Args:
            documents: The documents to transform.
            **kwargs: Additional implementation-specific options.

        Returns:
            The transformed documents.
        """
|