langchain[patch]: fix agent_token_buffer_memory not working with openai tools (#20708)

- **Description:** fix a bug in the agent_token_buffer_memory
- **Issue:** agent_token_buffer_memory was not working with openai tools
- **Dependencies:** None
- **Twitter handle:** @pokidyshef
This commit is contained in:
Nikita Pokidyshev 2024-04-25 01:51:58 +03:00 committed by GitHub
parent 6353991498
commit 9e983c9500
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -4,8 +4,9 @@ from typing import Any, Dict, List
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain.agents.format_scratchpad.openai_functions import (
+from langchain.agents.format_scratchpad import (
     format_to_openai_function_messages,
+    format_to_tool_messages,
 )
 from langchain.memory.chat_memory import BaseChatMemory
@@ -23,6 +24,7 @@ class AgentTokenBufferMemory(BaseChatMemory):
     return_messages: bool = True
     output_key: str = "output"
     intermediate_steps_key: str = "intermediate_steps"
+    format_as_tools: bool = False

     @property
     def buffer(self) -> List[BaseMessage]:
@@ -53,7 +55,12 @@ class AgentTokenBufferMemory(BaseChatMemory):
         """Save context from this conversation to buffer. Pruned."""
         input_str, output_str = self._get_input_output(inputs, outputs)
         self.chat_memory.add_user_message(input_str)
-        steps = format_to_openai_function_messages(outputs[self.intermediate_steps_key])
+        format_to_messages = (
+            format_to_tool_messages
+            if self.format_as_tools
+            else format_to_openai_function_messages
+        )
+        steps = format_to_messages(outputs[self.intermediate_steps_key])
         for msg in steps:
             self.chat_memory.add_message(msg)
         self.chat_memory.add_ai_message(output_str)