Harrison/conversational retrieval agent (#8639)

Co-authored-by: Bagatur <baskaryan@gmail.com>
Harrison Chase committed 2023-08-02 18:05:15 -07:00 (committed by GitHub)
parent 71f98db2fe
commit 43dffe39fb
6 changed files with 764 additions and 1 deletion

langchain/agents/agent_toolkits/__init__.py

@@ -3,6 +3,12 @@ from langchain.agents.agent_toolkits.amadeus.toolkit import AmadeusToolkit
from langchain.agents.agent_toolkits.azure_cognitive_services import (
    AzureCognitiveServicesToolkit,
)
from langchain.agents.agent_toolkits.conversational_retrieval.openai_functions import (
    create_conversational_retrieval_agent,
)
from langchain.agents.agent_toolkits.conversational_retrieval.tool import (
    create_retriever_tool,
)
from langchain.agents.agent_toolkits.csv.base import create_csv_agent
from langchain.agents.agent_toolkits.file_management.toolkit import (
    FileManagementToolkit,
@@ -59,16 +65,18 @@ __all__ = [
"ZapierToolkit",
"create_csv_agent",
"create_json_agent",
"create_multion_agent",
"create_openapi_agent",
"create_pandas_dataframe_agent",
"create_pbi_agent",
"create_pbi_chat_agent",
"create_python_agent",
"create_multion_agent",
"create_retriever_tool",
"create_spark_dataframe_agent",
"create_spark_sql_agent",
"create_sql_agent",
"create_vectorstore_agent",
"create_vectorstore_router_agent",
"create_xorbits_agent",
"create_conversational_retrieval_agent",
]
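
With these exports in place, both new helpers are importable directly from the toolkits package:

from langchain.agents.agent_toolkits import (
    create_conversational_retrieval_agent,
    create_retriever_tool,
)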

langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py

@@ -0,0 +1,87 @@
from typing import Any, List, Optional

from langchain.agents.agent import AgentExecutor
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
    AgentTokenBufferMemory,
)
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.chat_models.openai import ChatOpenAI
from langchain.memory.token_buffer import ConversationTokenBufferMemory
from langchain.prompts.chat import MessagesPlaceholder
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.memory import BaseMemory
from langchain.schema.messages import SystemMessage
from langchain.tools.base import BaseTool


def _get_default_system_message() -> SystemMessage:
    return SystemMessage(
        content=(
            "Do your best to answer the questions. "
            "Feel free to use any tools available to look up "
            "relevant information, only if necessary"
        )
    )


def create_conversational_retrieval_agent(
    llm: BaseLanguageModel,
    tools: List[BaseTool],
    remember_intermediate_steps: bool = True,
    memory_key: str = "chat_history",
    system_message: Optional[SystemMessage] = None,
    verbose: bool = False,
    max_token_limit: int = 2000,
    **kwargs: Any,
) -> AgentExecutor:
"""A convenience method for creating a conversational retrieval agent.
Args:
llm: The language model to use, should be ChatOpenAI
tools: A list of tools the agent has access to
remember_intermediate_steps: Whether the agent should remember intermediate
steps or not. Intermediate steps refer to prior action/observation
pairs from previous questions. The benefit of remembering these is if
there is relevant information in there, the agent can use it to answer
follow up questions. The downside is it will take up more tokens.
memory_key: The name of the memory key in the prompt.
system_message: The system message to use. By default, a basic one will
be used.
verbose: Whether or not the final AgentExecutor should be verbose or not,
defaults to False.
max_token_limit: The max number of tokens to keep around in memory.
Defaults to 2000.
Returns:
An agent executor initialized appropriately
"""
    if not isinstance(llm, ChatOpenAI):
        raise ValueError("Only supported with ChatOpenAI models.")
    if remember_intermediate_steps:
        memory: BaseMemory = AgentTokenBufferMemory(
            memory_key=memory_key, llm=llm, max_token_limit=max_token_limit
        )
    else:
        memory = ConversationTokenBufferMemory(
            memory_key=memory_key,
            return_messages=True,
            output_key="output",
            llm=llm,
            max_token_limit=max_token_limit,
        )
    _system_message = system_message or _get_default_system_message()
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=_system_message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
    )
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=verbose,
        return_intermediate_steps=remember_intermediate_steps,
        **kwargs,
    )
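
For reference, a minimal end-to-end sketch of how the two new helpers compose. The FAISS/OpenAIEmbeddings retriever is illustrative only (not part of this commit) and assumes the faiss package plus an OpenAI API key are available:

from langchain.agents.agent_toolkits import (
    create_conversational_retrieval_agent,
    create_retriever_tool,
)
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

# Any BaseRetriever works; a tiny FAISS index keeps the sketch self-contained.
retriever = FAISS.from_texts(
    ["Our refund window is 30 days from the date of purchase."],
    OpenAIEmbeddings(),
).as_retriever()

tool = create_retriever_tool(
    retriever,
    name="search_policies",
    description="Searches company policy documents. Input should be a question.",
)

# The llm must be a ChatOpenAI instance; the helper raises ValueError otherwise.
agent_executor = create_conversational_retrieval_agent(
    ChatOpenAI(temperature=0), [tool], verbose=True
)
result = agent_executor({"input": "How long is the refund window?"})
print(result["output"])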

langchain/agents/agent_toolkits/conversational_retrieval/tool.py

@@ -0,0 +1,22 @@
from langchain.schema import BaseRetriever
from langchain.tools import Tool


def create_retriever_tool(
    retriever: BaseRetriever, name: str, description: str
) -> Tool:
    """Create a tool to do retrieval of documents.

    Args:
        retriever: The retriever to use for the retrieval.
        name: The name for the tool. This will be passed to the language model,
            so it should be unique and somewhat descriptive.
        description: The description for the tool. This will be passed to the
            language model, so it should be descriptive.

    Returns:
        Tool class to pass to an agent.
    """
    return Tool(
        name=name, description=description, func=retriever.get_relevant_documents
    )
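
Because the Tool's func is bound to retriever.get_relevant_documents, the agent's string input becomes the retrieval query and the retrieved documents come back as the observation. A self-contained sketch using TFIDFRetriever (an illustrative choice, not part of this commit; it needs scikit-learn but no API key):

from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.retrievers import TFIDFRetriever

retriever = TFIDFRetriever.from_texts(
    ["Our refund window is 30 days from the date of purchase."]
)
tool = create_retriever_tool(
    retriever,
    name="search_policies",
    description="Searches company policy documents. Input should be a question.",
)
docs = tool.run("How long is the refund window?")  # -> retriever.get_relevant_documents(...)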

langchain/agents/openai_functions_agent/agent_token_buffer_memory.py

@@ -0,0 +1,63 @@
"""Memory used to save agent output AND intermediate steps."""
from typing import Any, Dict, List
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, get_buffer_string
class AgentTokenBufferMemory(BaseChatMemory):
"""Memory used to save agent output AND intermediate steps."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: BaseLanguageModel
memory_key: str = "history"
max_token_limit: int = 12000
"""The max number of tokens to keep in the buffer.
Once the buffer exceeds this many tokens, the oldest messages will be pruned."""
return_messages: bool = True
output_key = "output"
intermediate_steps_key = "intermediate_steps"
    @property
    def buffer(self) -> List[BaseMessage]:
        """List of messages in the buffer."""
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer."""
        if self.return_messages:
            final_buffer: Any = self.buffer
        else:
            final_buffer = get_buffer_string(
                self.buffer,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )
        return {self.memory_key: final_buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Save context from this conversation to the buffer, pruning if needed."""
        input_str, output_str = self._get_input_output(inputs, outputs)
        self.chat_memory.add_user_message(input_str)
        steps = _format_intermediate_steps(outputs[self.intermediate_steps_key])
        for msg in steps:
            self.chat_memory.add_message(msg)
        self.chat_memory.add_ai_message(output_str)
        # Prune buffer if it exceeds max token limit
        buffer = self.chat_memory.messages
        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
        if curr_buffer_length > self.max_token_limit:
            while curr_buffer_length > self.max_token_limit:
                buffer.pop(0)
                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
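
To observe the pruning behavior in isolation, save_context can be driven by hand. A sketch: the outputs dict mirrors what AgentExecutor passes when return_intermediate_steps=True, and constructing ChatOpenAI assumes an OpenAI API key is set (the token counting itself happens locally via tiktoken):

from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
    AgentTokenBufferMemory,
)
from langchain.chat_models import ChatOpenAI

memory = AgentTokenBufferMemory(llm=ChatOpenAI(), max_token_limit=200)
memory.save_context(
    {"input": "What is LangChain?"},
    {"output": "A framework for building LLM applications.", "intermediate_steps": []},
)
# Human/AI messages (plus any formatted intermediate steps) accumulate under
# the "history" key; once their token count exceeds max_token_limit, the
# oldest messages are dropped first.
print(memory.load_memory_variables({})["history"])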