This commit is contained in:
Eugene Yurtsev
2023-09-01 11:23:02 -04:00
parent e4cf87ab34
commit bf12905e7f
3 changed files with 51 additions and 18 deletions

View File

@@ -3,22 +3,23 @@ from __future__ import annotations
import ast
import re
from typing import Sequence, Union, Optional
from typing import Sequence, Union, Optional, List
from langchain.automaton.prompt_generators import MessageLogPromptValue
from langchain.automaton.runnables import create_llm_program
from langchain.automaton.typedefs import (
MessageLog,
AgentFinish, MessageLike, FunctionCall,
AgentFinish,
MessageLike,
FunctionCall,
FunctionResult,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import SystemMessage, BaseMessage
from langchain.tools import BaseTool
class ActionParser:
"""A utility class to encode and decode action blocks."""
class ActionEncoder:
def __init__(self) -> None:
    """Initialize the ActionEncoder.

    Compiles the regex used to find ``<action>...</action>`` blocks in
    model output; the blob is captured in the named group ``action_blob``.
    """
    # re.DOTALL lets `.` match newlines, so the action blob may span
    # multiple lines; `.*?` keeps the match non-greedy per block.
    self.pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)
@@ -40,6 +41,29 @@ class ActionParser:
else:
return None
def encode_as_str(self, function_call: FunctionCall) -> str:
    """Serialize a function call into an ``<action>...</action>`` text block.

    The payload is a dict-literal-looking string with ``action`` and
    ``action_input`` keys. A ``Final Answer`` call carries a plain string
    argument, so its payload is wrapped in quotes; any other call's
    arguments are interpolated verbatim (their repr for a dict).
    """
    name = function_call.name
    args = function_call.arguments
    if name == "Final Answer":
        body = f"{{'action': 'Final Answer', 'action_input': '{args}'}}"
    else:
        body = f"{{'action': '{name}', 'action_input': {args}}}"
    return f"<action>{body}</action>"
def prompt_generator(log: MessageLog) -> List[BaseMessage]:
    """Generate a prompt from a log of message like objects.

    BaseMessage entries pass through unchanged; FunctionResult entries are
    rendered as SystemMessage observations. Anything else is dropped.
    """
    prompt: List[BaseMessage] = []
    for entry in log.messages:
        if isinstance(entry, BaseMessage):
            prompt.append(entry)
        elif isinstance(entry, FunctionResult):
            observation = SystemMessage(
                content=f"Observation: {entry.result}",
            )
            prompt.append(observation)
        # Unknown entry kinds contribute nothing to the prompt.
    return prompt
class ChatAgent:
"""An agent for chat models."""
@@ -52,11 +76,12 @@ class ChatAgent:
max_iterations: int = 10,
) -> None:
"""Initialize the chat automaton."""
action_encoder = ActionEncoder()
self.llm_program = create_llm_program(
llm,
prompt_generator=MessageLogPromptValue.from_message_log,
tools=tools,
parser=OpenAIFunctionsParser(),
parser=action_encoder.decode,
)
self.max_iterations = max_iterations

View File

@@ -3,10 +3,8 @@ from __future__ import annotations
import json
from typing import Sequence, List
from langchain.automaton.prompt_generators import MessageLogPromptValue
from langchain.automaton.runnables import create_llm_program
from langchain.automaton.typedefs import (
MessageLike,
MessageLog,
AgentFinish,
FunctionCall,
@@ -43,13 +41,19 @@ class OpenAIFunctionsParser(BaseGenerationOutputParser):
)
def message_adapter(message: MessageLike) -> List[BaseMessage]:
    """Adapt one message-like object into zero or one chat messages.

    BaseMessage instances are returned as-is; a FunctionResult becomes a
    FunctionMessage with a JSON-encoded result; everything else yields
    an empty list.
    """
    if isinstance(message, BaseMessage):
        return [message]
    if isinstance(message, FunctionResult):
        payload = json.dumps(message.result)
        return [FunctionMessage(name=message.name, content=payload)]
    # Unrecognized kinds are silently filtered out.
    return []
def prompt_generator(log: MessageLog) -> List[BaseMessage]:
    """Generate a prompt from a log of message like objects.

    BaseMessage entries pass through unchanged; FunctionResult entries
    become FunctionMessage objects with a JSON-encoded result. Anything
    else is dropped.
    """
    prompt: List[BaseMessage] = []
    for entry in log.messages:
        if isinstance(entry, BaseMessage):
            prompt.append(entry)
        elif isinstance(entry, FunctionResult):
            prompt.append(
                FunctionMessage(name=entry.name, content=json.dumps(entry.result))
            )
        # Unknown entry kinds contribute nothing to the prompt.
    return prompt
class OpenAIAgent:
@@ -63,7 +67,7 @@ class OpenAIAgent:
"""Initialize the chat automaton."""
self.llm_program = create_llm_program(
llm,
prompt_generator=MessageLogPromptValue.from_message_log,
prompt_generator=prompt_generator,
tools=tools,
parser=OpenAIFunctionsParser(),
)

View File

@@ -44,7 +44,9 @@ def create_tool_invoker(
def create_llm_program(
llm: BaseLanguageModel,
prompt_generator: Callable[[MessageLog], PromptValue],
prompt_generator: Callable[
[MessageLog], Union[str, PromptValue, Sequence[BaseMessage]]
],
*,
tools: Optional[Sequence[BaseTool]] = None,
stop: Optional[Sequence[str]] = None,
@@ -91,6 +93,8 @@ def create_llm_program(
)
messages.append(parsed_result)
if not messages:
raise AssertionError(f"Expected at least one message")
last_message = messages[-1]
if tool_invoker and isinstance(last_message, FunctionCall):