This commit is contained in:
Eugene Yurtsev
2023-09-07 15:54:08 -04:00
parent 0396570c82
commit ba61d64c04
2 changed files with 30 additions and 23 deletions

View File

@@ -6,20 +6,20 @@ from __future__ import annotations
import ast
import re
from typing import List, Optional, Sequence, Union
from typing import Iterator, List, Optional, Sequence, Union
from langchain.automaton.runnables import (
create_llm_program,
)
from langchain.automaton.tool_utils import generate_tool_info
from langchain.automaton.typedefs import (
AdHocMessage,
Agent,
AgentFinish,
FunctionCall,
FunctionResult,
MessageLike,
MessageLog,
PrimingMessage,
)
from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import (
@@ -118,7 +118,7 @@ class ThinkActPromptGenerator(PromptValue):
prompt value can do additional priming of the LLM, used in to_string method right now.
"""
message_log: MessageLog
messages: Sequence[MessageLike]
class Config:
arbitrary_types_allowed = True
@@ -126,10 +126,12 @@ class ThinkActPromptGenerator(PromptValue):
def to_string(self) -> str:
"""The string variant of the prompt."""
finalized = []
messages = self.message_log.messages
messages = self.messages
for idx, message in enumerate(messages):
if isinstance(message, PrimingMessage):
component = message.content
if isinstance(message, AdHocMessage):
if message.type != "prime":
raise AssertionError()
component = message.data
finalized.append(component)
continue
@@ -152,7 +154,7 @@ class ThinkActPromptGenerator(PromptValue):
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
messages = []
for message in self.message_log.messages:
for message in self.messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, FunctionResult):
@@ -162,9 +164,9 @@ class ThinkActPromptGenerator(PromptValue):
return messages
@classmethod
def from_message_log(cls, message_log: MessageLog):
def from_message_log(cls, messages: Sequence[MessageLike]):
"""Create a ThinkActPromptGenerator from a MessageLog, allowing positional arguments."""
return cls(message_log=message_log)
return cls(messages=messages)
class ThinkActAgent(Agent):
@@ -186,19 +188,23 @@ class ThinkActAgent(Agent):
)
self.max_iterations = max_iterations
def run(self, message_log: MessageLog) -> None:
def run(
self,
messages: Sequence[MessageLike],
*,
max_iterations: int = 100,
) -> Iterator[MessageLike]:
"""Run the agent."""
if not message_log:
raise AssertionError(f"Expected at least one message in message_log")
for _ in range(self.max_iterations):
last_message = message_log[-1]
if isinstance(last_message, AgentFinish):
all_messages = list(messages)
for _ in range(max_iterations):
if all_messages and isinstance(all_messages[-1], AgentFinish):
break
# Prime the LLM to start with "Thought: " after an observation
if isinstance(last_message, (FunctionResult, HumanMessage)):
message_log.add_messages([PrimingMessage(content="Thought:")])
if all_messages and isinstance(
all_messages[-1], (FunctionResult, HumanMessage)
):
all_messages.append(AdHocMessage(type="prime", data="Thought:"))
message_log.add_messages(self.think_act.invoke(message_log))
new_messages = self.think_act.invoke(all_messages)
yield from new_messages
all_messages.extend(new_messages)

View File

@@ -45,10 +45,11 @@ class FunctionResult(InternalMessage):
return f"FunctionResult(name={self.name}, result={self.result}, error={self.error})"
class PrimingMessage(InternalMessage):
class AdHocMessage(InternalMessage):
"""A message that is used to prime the language model."""
content: str
type: str
data: Any # Make sure this is serializable
class AgentFinish(InternalMessage):