mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-03 15:55:44 +00:00
x
This commit is contained in:
@@ -6,20 +6,20 @@ from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from typing import List, Optional, Sequence, Union
|
||||
from typing import Iterator, List, Optional, Sequence, Union
|
||||
|
||||
from langchain.automaton.runnables import (
|
||||
create_llm_program,
|
||||
)
|
||||
from langchain.automaton.tool_utils import generate_tool_info
|
||||
from langchain.automaton.typedefs import (
|
||||
AdHocMessage,
|
||||
Agent,
|
||||
AgentFinish,
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
MessageLike,
|
||||
MessageLog,
|
||||
PrimingMessage,
|
||||
)
|
||||
from langchain.prompts import SystemMessagePromptTemplate
|
||||
from langchain.schema import (
|
||||
@@ -118,7 +118,7 @@ class ThinkActPromptGenerator(PromptValue):
|
||||
prompt value can do additional priming of the LLM, used in to_string method right now.
|
||||
"""
|
||||
|
||||
message_log: MessageLog
|
||||
messages: Sequence[MessageLike]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -126,10 +126,12 @@ class ThinkActPromptGenerator(PromptValue):
|
||||
def to_string(self) -> str:
|
||||
"""The string variant of the prompt."""
|
||||
finalized = []
|
||||
messages = self.message_log.messages
|
||||
messages = self.messages
|
||||
for idx, message in enumerate(messages):
|
||||
if isinstance(message, PrimingMessage):
|
||||
component = message.content
|
||||
if isinstance(message, AdHocMessage):
|
||||
if message.type != "prime":
|
||||
raise AssertionError()
|
||||
component = message.data
|
||||
finalized.append(component)
|
||||
continue
|
||||
|
||||
@@ -152,7 +154,7 @@ class ThinkActPromptGenerator(PromptValue):
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of Messages."""
|
||||
messages = []
|
||||
for message in self.message_log.messages:
|
||||
for message in self.messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
messages.append(message)
|
||||
elif isinstance(message, FunctionResult):
|
||||
@@ -162,9 +164,9 @@ class ThinkActPromptGenerator(PromptValue):
|
||||
return messages
|
||||
|
||||
@classmethod
|
||||
def from_message_log(cls, message_log: MessageLog):
|
||||
def from_message_log(cls, messages: Sequence[MessageLike]):
|
||||
"""Create a ThinkActPromptGenerator from a MessageLog, allowing positional arguments."""
|
||||
return cls(message_log=message_log)
|
||||
return cls(messages=messages)
|
||||
|
||||
|
||||
class ThinkActAgent(Agent):
|
||||
@@ -186,19 +188,23 @@ class ThinkActAgent(Agent):
|
||||
)
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
def run(self, message_log: MessageLog) -> None:
|
||||
def run(
|
||||
self,
|
||||
messages: Sequence[MessageLike],
|
||||
*,
|
||||
max_iterations: int = 100,
|
||||
) -> Iterator[MessageLike]:
|
||||
"""Run the agent."""
|
||||
if not message_log:
|
||||
raise AssertionError(f"Expected at least one message in message_log")
|
||||
|
||||
for _ in range(self.max_iterations):
|
||||
last_message = message_log[-1]
|
||||
|
||||
if isinstance(last_message, AgentFinish):
|
||||
all_messages = list(messages)
|
||||
for _ in range(max_iterations):
|
||||
if all_messages and isinstance(all_messages[-1], AgentFinish):
|
||||
break
|
||||
|
||||
# Prime the LLM to start with "Thought: " after an observation
|
||||
if isinstance(last_message, (FunctionResult, HumanMessage)):
|
||||
message_log.add_messages([PrimingMessage(content="Thought:")])
|
||||
if all_messages and isinstance(
|
||||
all_messages[-1], (FunctionResult, HumanMessage)
|
||||
):
|
||||
all_messages.append(AdHocMessage(type="prime", data="Thought:"))
|
||||
|
||||
message_log.add_messages(self.think_act.invoke(message_log))
|
||||
new_messages = self.think_act.invoke(all_messages)
|
||||
yield from new_messages
|
||||
all_messages.extend(new_messages)
|
||||
|
||||
@@ -45,10 +45,11 @@ class FunctionResult(InternalMessage):
|
||||
return f"FunctionResult(name={self.name}, result={self.result}, error={self.error})"
|
||||
|
||||
|
||||
class PrimingMessage(InternalMessage):
|
||||
class AdHocMessage(InternalMessage):
|
||||
"""A message that is used to prime the language model."""
|
||||
|
||||
content: str
|
||||
type: str
|
||||
data: Any # Make sure this is serializable
|
||||
|
||||
|
||||
class AgentFinish(InternalMessage):
|
||||
|
||||
Reference in New Issue
Block a user