mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
x
This commit is contained in:
@@ -3,22 +3,23 @@ from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from typing import Sequence, Union, Optional
|
||||
from typing import Sequence, Union, Optional, List
|
||||
|
||||
from langchain.automaton.prompt_generators import MessageLogPromptValue
|
||||
from langchain.automaton.runnables import create_llm_program
|
||||
from langchain.automaton.typedefs import (
|
||||
MessageLog,
|
||||
AgentFinish, MessageLike, FunctionCall,
|
||||
AgentFinish,
|
||||
MessageLike,
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import SystemMessage, BaseMessage
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class ActionParser:
|
||||
"""A utility class to encode and decode action blocks."""
|
||||
|
||||
class ActionEncoder:
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the ActionParser."""
|
||||
self.pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)
|
||||
@@ -40,6 +41,29 @@ class ActionParser:
|
||||
else:
|
||||
return None
|
||||
|
||||
def encode_as_str(self, function_call: FunctionCall) -> str:
|
||||
"""Encode the action."""
|
||||
if function_call.name == "Final Answer":
|
||||
return f"<action>{{'action': 'Final Answer', 'action_input': '{function_call.arguments}'}}</action>"
|
||||
return f"<action>{{'action': '{function_call.name}', 'action_input': {function_call.arguments}}}</action>"
|
||||
|
||||
|
||||
def prompt_generator(log: MessageLog) -> List[BaseMessage]:
|
||||
"""Generate a prompt from a log of message like objects."""
|
||||
messages = []
|
||||
for message in log.messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
messages.append(message)
|
||||
elif isinstance(message, FunctionResult):
|
||||
messages.append(
|
||||
SystemMessage(
|
||||
content=f"Observation: {message.result}",
|
||||
)
|
||||
)
|
||||
else:
|
||||
pass
|
||||
return messages
|
||||
|
||||
|
||||
class ChatAgent:
|
||||
"""An agent for chat models."""
|
||||
@@ -52,11 +76,12 @@ class ChatAgent:
|
||||
max_iterations: int = 10,
|
||||
) -> None:
|
||||
"""Initialize the chat automaton."""
|
||||
action_encoder = ActionEncoder()
|
||||
self.llm_program = create_llm_program(
|
||||
llm,
|
||||
prompt_generator=MessageLogPromptValue.from_message_log,
|
||||
tools=tools,
|
||||
parser=OpenAIFunctionsParser(),
|
||||
parser=action_encoder.decode,
|
||||
)
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
|
||||
@@ -3,10 +3,8 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Sequence, List
|
||||
|
||||
from langchain.automaton.prompt_generators import MessageLogPromptValue
|
||||
from langchain.automaton.runnables import create_llm_program
|
||||
from langchain.automaton.typedefs import (
|
||||
MessageLike,
|
||||
MessageLog,
|
||||
AgentFinish,
|
||||
FunctionCall,
|
||||
@@ -43,13 +41,19 @@ class OpenAIFunctionsParser(BaseGenerationOutputParser):
|
||||
)
|
||||
|
||||
|
||||
def message_adapter(message: MessageLike) -> List[BaseMessage]:
|
||||
if isinstance(message, BaseMessage):
|
||||
return [message]
|
||||
elif isinstance(message, FunctionResult):
|
||||
return [FunctionMessage(name=message.name, content=json.dumps(message.result))]
|
||||
else:
|
||||
return []
|
||||
def prompt_generator(log: MessageLog) -> List[BaseMessage]:
|
||||
"""Generate a prompt from a log of message like objects."""
|
||||
messages = []
|
||||
for message in log.messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
messages.append(message)
|
||||
elif isinstance(message, FunctionResult):
|
||||
messages.append(
|
||||
FunctionMessage(name=message.name, content=json.dumps(message.result))
|
||||
)
|
||||
else:
|
||||
pass
|
||||
return messages
|
||||
|
||||
|
||||
class OpenAIAgent:
|
||||
@@ -63,7 +67,7 @@ class OpenAIAgent:
|
||||
"""Initialize the chat automaton."""
|
||||
self.llm_program = create_llm_program(
|
||||
llm,
|
||||
prompt_generator=MessageLogPromptValue.from_message_log,
|
||||
prompt_generator=prompt_generator,
|
||||
tools=tools,
|
||||
parser=OpenAIFunctionsParser(),
|
||||
)
|
||||
|
||||
@@ -44,7 +44,9 @@ def create_tool_invoker(
|
||||
|
||||
def create_llm_program(
|
||||
llm: BaseLanguageModel,
|
||||
prompt_generator: Callable[[MessageLog], PromptValue],
|
||||
prompt_generator: Callable[
|
||||
[MessageLog], Union[str, PromptValue, Sequence[BaseMessage]]
|
||||
],
|
||||
*,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
@@ -91,6 +93,8 @@ def create_llm_program(
|
||||
)
|
||||
messages.append(parsed_result)
|
||||
|
||||
if not messages:
|
||||
raise AssertionError(f"Expected at least one message")
|
||||
last_message = messages[-1]
|
||||
|
||||
if tool_invoker and isinstance(last_message, FunctionCall):
|
||||
|
||||
Reference in New Issue
Block a user