Harrison/conversational retrieval agent (#8639)
Co-authored-by: Bagatur <baskaryan@gmail.com>
langchain/agents/agent_toolkits/__init__.py
@@ -3,6 +3,12 @@ from langchain.agents.agent_toolkits.amadeus.toolkit import AmadeusToolkit
 from langchain.agents.agent_toolkits.azure_cognitive_services import (
     AzureCognitiveServicesToolkit,
 )
+from langchain.agents.agent_toolkits.conversational_retrieval.openai_functions import (
+    create_conversational_retrieval_agent,
+)
+from langchain.agents.agent_toolkits.conversational_retrieval.tool import (
+    create_retriever_tool,
+)
 from langchain.agents.agent_toolkits.csv.base import create_csv_agent
 from langchain.agents.agent_toolkits.file_management.toolkit import (
     FileManagementToolkit,
@@ -59,16 +65,18 @@ __all__ = [
     "ZapierToolkit",
     "create_csv_agent",
     "create_json_agent",
     "create_multion_agent",
     "create_openapi_agent",
     "create_pandas_dataframe_agent",
     "create_pbi_agent",
     "create_pbi_chat_agent",
     "create_python_agent",
+    "create_retriever_tool",
     "create_spark_dataframe_agent",
     "create_spark_sql_agent",
     "create_sql_agent",
     "create_vectorstore_agent",
     "create_vectorstore_router_agent",
     "create_xorbits_agent",
+    "create_conversational_retrieval_agent",
 ]
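
With these exports in place, both new helpers become importable from the toolkit package's top level, which is how the rest of this commit expects them to be used:

from langchain.agents.agent_toolkits import (
    create_conversational_retrieval_agent,
    create_retriever_tool,
)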
langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py
@@ -0,0 +1,87 @@
+from typing import Any, List, Optional
+
+from langchain.agents.agent import AgentExecutor
+from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
+    AgentTokenBufferMemory,
+)
+from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
+from langchain.chat_models.openai import ChatOpenAI
+from langchain.memory.token_buffer import ConversationTokenBufferMemory
+from langchain.prompts.chat import MessagesPlaceholder
+from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.memory import BaseMemory
+from langchain.schema.messages import SystemMessage
+from langchain.tools.base import BaseTool
+
+
+def _get_default_system_message() -> SystemMessage:
+    return SystemMessage(
+        content=(
+            "Do your best to answer the questions. "
+            "Feel free to use any tools available to look up "
+            "relevant information, but only if necessary."
+        )
+    )
+
+
+def create_conversational_retrieval_agent(
+    llm: BaseLanguageModel,
+    tools: List[BaseTool],
+    remember_intermediate_steps: bool = True,
+    memory_key: str = "chat_history",
+    system_message: Optional[SystemMessage] = None,
+    verbose: bool = False,
+    max_token_limit: int = 2000,
+    **kwargs: Any,
+) -> AgentExecutor:
+    """A convenience method for creating a conversational retrieval agent.
+
+    Args:
+        llm: The language model to use. Should be an instance of ChatOpenAI.
+        tools: A list of tools the agent has access to.
+        remember_intermediate_steps: Whether the agent should remember intermediate
+            steps or not. Intermediate steps refer to prior action/observation
+            pairs from previous questions. The benefit of remembering these is
+            that if there is relevant information in there, the agent can use it
+            to answer follow-up questions. The downside is it will take up more tokens.
+        memory_key: The name of the memory key in the prompt.
+        system_message: The system message to use. By default, a basic one will
+            be used.
+        verbose: Whether the final AgentExecutor should be verbose.
+            Defaults to False.
+        max_token_limit: The max number of tokens to keep around in memory.
+            Defaults to 2000.
+
+    Returns:
+        An agent executor initialized appropriately.
+    """
+
+    if not isinstance(llm, ChatOpenAI):
+        raise ValueError("Only supported with ChatOpenAI models.")
+    if remember_intermediate_steps:
+        memory: BaseMemory = AgentTokenBufferMemory(
+            memory_key=memory_key, llm=llm, max_token_limit=max_token_limit
+        )
+    else:
+        memory = ConversationTokenBufferMemory(
+            memory_key=memory_key,
+            return_messages=True,
+            output_key="output",
+            llm=llm,
+            max_token_limit=max_token_limit,
+        )
+
+    _system_message = system_message or _get_default_system_message()
+    prompt = OpenAIFunctionsAgent.create_prompt(
+        system_message=_system_message,
+        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
+    )
+    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
+    return AgentExecutor(
+        agent=agent,
+        tools=tools,
+        memory=memory,
+        verbose=verbose,
+        return_intermediate_steps=remember_intermediate_steps,
+        **kwargs,
+    )
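
End to end, the factory above is meant to be driven as in the following minimal sketch. The FAISS vectorstore, the sample text, and the tool name/description are illustrative assumptions, not part of this commit (running it also assumes the faiss package and an OpenAI API key):

from langchain.agents.agent_toolkits import (
    create_conversational_retrieval_agent,
    create_retriever_tool,
)
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings  # hypothetical embedding choice
from langchain.vectorstores import FAISS  # hypothetical vectorstore choice

# Build a retriever over some example documents (illustrative data).
vectorstore = FAISS.from_texts(
    ["LangChain supports conversational retrieval agents."],
    OpenAIEmbeddings(),
)
tool = create_retriever_tool(
    vectorstore.as_retriever(),
    name="search_docs",  # illustrative name
    description="Searches the example document store.",
)

# The factory requires a ChatOpenAI model (see the isinstance check above).
llm = ChatOpenAI(temperature=0)
agent_executor = create_conversational_retrieval_agent(llm, [tool], verbose=True)

result = agent_executor({"input": "What does LangChain support?"})
print(result["output"])

Note that return_intermediate_steps mirrors remember_intermediate_steps, so with the default settings the result dict also carries the action/observation pairs that AgentTokenBufferMemory saves between turns.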
langchain/agents/agent_toolkits/conversational_retrieval/tool.py
@@ -0,0 +1,22 @@
+from langchain.schema import BaseRetriever
+from langchain.tools import Tool
+
+
+def create_retriever_tool(
+    retriever: BaseRetriever, name: str, description: str
+) -> Tool:
+    """Create a tool that retrieves documents via the given retriever.
+
+    Args:
+        retriever: The retriever to use for the retrieval.
+        name: The name for the tool. This will be passed to the language model,
+            so it should be unique and somewhat descriptive.
+        description: The description for the tool. This will be passed to the
+            language model, so it should be descriptive.
+
+    Returns:
+        Tool class to pass to an agent.
+    """
+    return Tool(
+        name=name, description=description, func=retriever.get_relevant_documents
+    )
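
Since the tool simply wraps retriever.get_relevant_documents, invoking it amounts to passing a query string and getting documents back. A small sketch, where retriever is assumed to be any existing BaseRetriever instance and the name/description are illustrative:

from langchain.agents.agent_toolkits import create_retriever_tool

docs_tool = create_retriever_tool(
    retriever,  # assumed: any BaseRetriever instance defined elsewhere
    name="search_state_of_union",  # illustrative name
    description="Searches and returns documents regarding the state of the union.",
)
# The agent's query string is passed straight through to the retriever,
# which returns a list of Document objects.
documents = docs_tool.run("What did the president say about the economy?")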
langchain/agents/openai_functions_agent/agent_token_buffer_memory.py
@@ -0,0 +1,63 @@
+"""Memory used to save agent output AND intermediate steps."""
+from typing import Any, Dict, List
+
+from langchain.agents.openai_functions_agent.base import _format_intermediate_steps
+from langchain.memory.chat_memory import BaseChatMemory
+from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.messages import BaseMessage, get_buffer_string
+
+
+class AgentTokenBufferMemory(BaseChatMemory):
+    """Memory used to save agent output AND intermediate steps."""
+
+    human_prefix: str = "Human"
+    ai_prefix: str = "AI"
+    llm: BaseLanguageModel
+    memory_key: str = "history"
+    max_token_limit: int = 12000
+    """The max number of tokens to keep in the buffer.
+    Once the buffer exceeds this many tokens, the oldest messages will be pruned."""
+    return_messages: bool = True
+    output_key: str = "output"
+    intermediate_steps_key: str = "intermediate_steps"
+
+    @property
+    def buffer(self) -> List[BaseMessage]:
+        """Message buffer of memory."""
+        return self.chat_memory.messages
+
+    @property
+    def memory_variables(self) -> List[str]:
+        """Will always return list of memory variables.
+
+        :meta private:
+        """
+        return [self.memory_key]
+
+    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Return history buffer."""
+        if self.return_messages:
+            final_buffer: Any = self.buffer
+        else:
+            final_buffer = get_buffer_string(
+                self.buffer,
+                human_prefix=self.human_prefix,
+                ai_prefix=self.ai_prefix,
+            )
+        return {self.memory_key: final_buffer}
+
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
+        """Save context from this conversation to buffer. Pruned."""
+        input_str, output_str = self._get_input_output(inputs, outputs)
+        self.chat_memory.add_user_message(input_str)
+        steps = _format_intermediate_steps(outputs[self.intermediate_steps_key])
+        for msg in steps:
+            self.chat_memory.add_message(msg)
+        self.chat_memory.add_ai_message(output_str)
+        # Prune the buffer if it exceeds the max token limit, dropping oldest first.
+        buffer = self.chat_memory.messages
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+        if curr_buffer_length > self.max_token_limit:
+            while curr_buffer_length > self.max_token_limit:
+                buffer.pop(0)
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
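
To see the pruning behavior in isolation: messages (including formatted intermediate steps) accumulate until the LLM's token count for the buffer exceeds max_token_limit, after which the oldest messages are dropped. A small sketch; the tiny limit is deliberately unrealistic so pruning triggers quickly, and counting tokens with ChatOpenAI uses its local tokenizer:

from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
    AgentTokenBufferMemory,
)
from langchain.chat_models import ChatOpenAI

memory = AgentTokenBufferMemory(
    memory_key="chat_history",
    llm=ChatOpenAI(),
    max_token_limit=50,  # illustrative only; the default is 12000
)
# save_context expects the agent's outputs to include intermediate steps,
# which is why create_conversational_retrieval_agent pairs this memory
# with return_intermediate_steps=True.
memory.save_context(
    {"input": "hi"},
    {"output": "hello!", "intermediate_steps": []},
)
print(memory.load_memory_variables({})["chat_history"])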